/*      $NetBSD: edahdi.c,v 1.11 2011/10/01 15:59:00 chs Exp $  */

/*
* Copyright (c) 1996 Leo Weppelman, Waldi Ravens.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
* are met:
* 1. Redistributions of source code must retain the above copyright
*    notice, this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright
*    notice, this list of conditions and the following disclaimer in the
*    documentation and/or other materials provided with the distribution.
* 3. All advertising materials mentioning features or use of this software
*    must display the following acknowledgement:
*      This product includes software developed by
*                      Leo Weppelman and Waldi Ravens.
* 4. The name of the author may not be used to endorse or promote products
*    derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
* OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
* IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
* INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
* NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
* THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/

/*
* This code implements a simple editor for partition id's on disks containing
* AHDI partition info.
*
* Credits for the code handling disklabels goes to Waldi Ravens.
*
*/
#include <sys/types.h>
#include <sys/param.h>
#include <sys/stat.h>
#include <sys/disklabel.h>

#include <machine/ahdilabel.h>

#include <fcntl.h>
#include <stdlib.h>
#include <term.h>
#include <termios.h>
#include <unistd.h>
#include <stdio.h>
#include <string.h>
#include <err.h>
#include <ctype.h>

/*
* Internal partition tables:
*/
typedef struct {
       char    id[4];
       u_int   start;
       u_int   end;
       u_int   rsec;
       u_int   rent;
       int     mod;
} part_t;

typedef struct {
       int     nparts;
       part_t  *parts;
} ptable_t;

/*
* I think we can savely assume a fixed blocksize - AHDI won't support
* something different...
*/
#define BLPM            ((1024 * 1024) / DEV_BSIZE)

/*
* #Partition entries shown on the screen at once
*/
#define MAX_PSHOWN      16      /* #partitions shown on screen  */

/*
* Tokens:
*/
#define T_INVAL         0
#define T_QUIT          1
#define T_WRITE         2
#define T_NEXT          3
#define T_PREV          4
#define T_NUMBER        5
#define T_EOF           6

void    ahdi_cksum(void *);
u_int   ahdi_getparts(int, ptable_t *, u_int, u_int);
int     bsd_label(int, u_int);
int     dkcksum(struct disklabel *);
int     edit_parts(int, ptable_t *);
void   *disk_read(int, u_int, u_int);
void    disk_write(int, u_int, u_int, void  *);
char   *get_id(void);
int     lex(int *);
int     show_parts(ptable_t *, int);
void    update_disk(ptable_t *, int, int);

int
main(int argc, char *argv[])
{
       int             fd;
       ptable_t        ptable;
       int             rv;
       struct stat     st;

       if (argc != 2) {
               char    *prog = strrchr(argv[0], '/');

               if (prog == NULL)
                       prog = argv[0];
               else prog++;
               fprintf(stderr, "Usage: %s <raw_disk_device>", prog);
               exit(1);
       }
       if ((fd = open(argv[1], O_RDWR)) < 0)
               err(1, "Cannot open '%s'.", argv[1]);
       if (fstat(fd, &st) < 0)
               err(1, "Cannot stat '%s'.", argv[1]);
       if (!S_ISCHR(st.st_mode))
               errx(1, "'%s' must be a character special device.", argv[1]);

       if ((rv = bsd_label(fd, LABELSECTOR)) < 0)
               errx(1, "I/O error");
       if (rv == 0) {
               warnx("Disk has no ahdi partitions");
               return (2);
       }

       setupterm(NULL, STDOUT_FILENO, NULL);

       ptable.nparts = 0;
       ptable.parts  = NULL;

       if ((ahdi_getparts(fd, &ptable, AHDI_BBLOCK, AHDI_BBLOCK) != 0)
               || (ptable.nparts == 0))
               exit (1);

       edit_parts(fd, &ptable);
       return (0);
}

int
edit_parts(int fd, ptable_t *ptable)
{
       int     scr_base = 0;
       int     value;
       char    *error, *new_id;

       for (;;) {
               error = NULL;
               if (clear_screen)
                       tputs(clear_screen, 1, putchar);
               show_parts(ptable, scr_base);

               printf("\n\n");
               printf("q    : quit - don't update the disk\n");
               printf("w    : write changes to disk\n");
               printf(">    : next screen of partitions\n");
               printf("<    : previous screen of partitions\n");
               printf("<nr> : modify id of partition <nr>\n");
               printf("\n\nCommand? ");
               fflush(stdout);

               switch (lex(&value)) {
                   case T_EOF:
                       exit(0);

                   case T_INVAL:
                       error = "Invalid command";
                       break;
                   case T_QUIT :
                       for (value = 0; value < ptable->nparts; value++) {
                               if (ptable->parts[value].mod) {
                                       printf("\nThere are unwritten changes."
                                               " Quit anyway? [n] ");
                                       value = getchar();
                                       if ((value == 'y') || (value == 'Y'))
                                                       exit (0);
                                       while (value != '\n')
                                               value = getchar();
                                       break;
                               }
                       }
                       if (value == ptable->nparts)
                               exit(0);
                       break;
                   case T_WRITE:
                       error = "No changes to write";
                       for (value = 0; value < ptable->nparts; value++) {
                               if (ptable->parts[value].mod) {
                                       update_disk(ptable, fd, value);
                                       error = "";
                               }
                       }
                       break;
                   case T_NEXT :
                       if ((scr_base + MAX_PSHOWN) < ptable->nparts)
                               scr_base += MAX_PSHOWN;
                       break;
                   case T_PREV :
                       scr_base -= MAX_PSHOWN;
                       if (scr_base < 0)
                               scr_base = 0;
                       break;
                   case T_NUMBER:
                       if (value >= ptable->nparts) {
                               error = "Not that many partitions";
                               break;
                       }
                       if ((new_id = get_id()) == NULL) {
                               error = "Invalid id";
                               break;
                       }
                       strncpy(ptable->parts[value].id, new_id, 3);
                       ptable->parts[value].mod = 1;
                       scr_base = (value / MAX_PSHOWN) * MAX_PSHOWN;
                       break;
                   default :
                       error = "Internal error - unknown token";
                       break;
               }
               if (error != NULL) {
                       printf("\n\n%s", error);
                       fflush(stdout);
                       sleep(2);
               }
       }
}

int
show_parts(ptable_t *ptable, int nr)
{
       int     i;
       part_t  *p;
       u_int   megs;

       if (nr >= ptable->nparts)
               return (0);     /* Nothing to show */
       printf("\n\n");
       printf("nr      root  desc   id     start       end    MBs\n");

       p = &ptable->parts[nr];
       i = nr;
       for(; (i < ptable->nparts) && ((i - nr) < MAX_PSHOWN); i++, p++) {
               megs = ((p->end - p->start + 1) + (BLPM >> 1)) / BLPM;
               printf("%2d%s %8u  %4u  %s  %8u  %8u  (%3u)\n", i,
                       p->mod ? "*" : " ",
                       p->rsec, p->rent, p->id, p->start, p->end, megs);
       }
       return (1);
}

int
lex(int *value)
{
       char    c[1];
       int     rv, nch;

       rv = T_INVAL;

       *value = 0;
       for (;;) {
               if ((nch = read (0, c, 1)) != 1) {
                       if (nch == 0)
                               return (T_EOF);
                       else return (rv);
               }
               switch (*c) {
                       case 'q':
                               rv = T_QUIT;
                               goto out;
                       case 'w':
                               rv = T_WRITE;
                               goto out;
                       case '>':
                               rv = T_NEXT;
                               goto out;
                       case '<':
                               rv = T_PREV;
                               goto out;
                       default :
                               if (isspace((unsigned char)*c)) {
                                       if (rv == T_INVAL)
                                               break;
                                       goto out;
                               }
                               else if (isdigit((unsigned char)*c)) {
                                       *value = (10 * *value) + *c - '0';
                                       rv = T_NUMBER;
                               }
                               goto out;
               }
       }
       /* NOTREACHED */
out:
       /*
        * Flush rest of line before returning
        */
       while (read (0, c, 1) == 1)
               if ((*c == '\n') || (*c == '\r'))
                       break;
       return (rv);
}

char *
get_id(void)
{
       static char     buf[5];
              int      n;
       printf("\nEnter new id: ");
       if (fgets(buf, sizeof(buf), stdin) == NULL)
               return (NULL);
       for (n = 0; n < 3; n++) {
               if (!isalpha((unsigned char)buf[n]))
                       return (NULL);
               buf[n] = toupper((unsigned char)buf[n]);
       }
       buf[3] = '\0';
       return (buf);
}

int
bsd_label(int fd, u_int offset)
{
       u_char          *bblk;
       u_int           nsec;
       int             rv;

       nsec = (BBMINSIZE + (DEV_BSIZE - 1)) / DEV_BSIZE;
       bblk = disk_read(fd, offset, nsec);
       if (bblk) {
               u_int   *end, *p;

               end = (u_int *)&bblk[BBMINSIZE - sizeof(struct disklabel)];
               rv  = 1;
               for (p = (u_int *)bblk; p < end; ++p) {
                       struct disklabel *dl = (struct disklabel *)&p[1];
                       if (  (  (p[0] == NBDAMAGIC && offset == 0)
                             || (p[0] == AHDIMAGIC && offset != 0)
                             || (u_char *)dl - bblk == 7168
                             )
                          && dl->d_npartitions <= MAXPARTITIONS
                          && dl->d_magic2 == DISKMAGIC
                          && dl->d_magic  == DISKMAGIC
                          && dkcksum(dl)  == 0
                          )    {
                               rv = 0;
                               break;
                       }
               }
               free(bblk);
       }
       else rv = -1;

       return(rv);
}

int
dkcksum(struct disklabel *dl)
{
       u_short *start, *end, sum = 0;

       start = (u_short *)dl;
       end   = (u_short *)&dl->d_partitions[dl->d_npartitions];
       while (start < end)
               sum ^= *start++;
       return(sum);
}

void
ahdi_cksum(void *buf)
{
       unsigned short  *p = (unsigned short *)buf;
       unsigned short  csum = 0;
       int             i;

       p[255] = 0;
       for(i = 0; i < 256; i++)
               csum += *p++;
       *--p = (0x1234 - csum) & 0xffff;
}


u_int
ahdi_getparts(fd, ptable, rsec, esec)
       int                     fd;
       ptable_t                *ptable;
       u_int                   rsec,
                               esec;
{
       struct ahdi_part        *part, *end;
       struct ahdi_root        *root;
       u_int                   rv;

       root = disk_read(fd, rsec, 1);
       if (!root) {
               rv = rsec + (rsec == 0);
               goto done;
       }

       if (rsec == AHDI_BBLOCK)
               end = &root->ar_parts[AHDI_MAXRPD];
       else end = &root->ar_parts[AHDI_MAXARPD];
       for (part = root->ar_parts; part < end; ++part) {
               u_int   id;

               memcpy(&id, &part->ap_flg, sizeof (id));
               if (!(id & 0x01000000))
                       continue;
               if ((id &= 0x00ffffff) == AHDI_PID_XGM) {
                       u_int   offs = part->ap_st + esec;
                       rv = ahdi_getparts(fd, ptable, offs,
                                       esec == AHDI_BBLOCK ? offs : esec);
                       if (rv)
                               goto done;
               } else {
                       part_t  *p;
                       u_int   i = ++ptable->nparts;
                       ptable->parts = realloc(ptable->parts,
                                               i * sizeof *ptable->parts);
                       if (ptable->parts == NULL) {
                               fprintf(stderr, "Allocation error\n");
                               rv = 1;
                               goto done;
                       }
                       p = &ptable->parts[--i];
                       memcpy(&p->id, &id, sizeof (id));
                       p->start = part->ap_st + rsec;
                       p->end   = p->start + part->ap_size - 1;
                       p->rsec  = rsec;
                       p->rent  = part - root->ar_parts;
                       p->mod   = 0;
               }
       }
       rv = 0;
done:
       free(root);
       return(rv);
}

void *
disk_read(fd, start, count)
       int     fd;
       u_int   start,
               count;
{
       char    *buffer;
       off_t   offset;
       size_t  size;

       size   = count * DEV_BSIZE;
       offset = start * DEV_BSIZE;
       if ((buffer = malloc(size)) == NULL)
               errx(1, "No memory");

       if (lseek(fd, offset, SEEK_SET) != offset) {
               free(buffer);
               err(1, "Seek error");
       }
       if (read(fd, buffer, size) != size) {
               free(buffer);
               err(1, "Read error");
               exit(1);
       }
       return(buffer);
}

void
update_disk(ptable_t *ptable, int fd, int pno)
{
       struct ahdi_root        *root;
       struct ahdi_part        *apart;
       part_t                  *lpart;
       u_int                   rsec;
       int                     i;

       rsec = ptable->parts[pno].rsec;
       root = disk_read(fd, rsec, 1);

       /*
        * Look for additional mods on the same sector
        */
       for (i = 0; i < ptable->nparts; i++) {
               lpart = &ptable->parts[i];
               if (lpart->mod && (lpart->rsec == rsec)) {
                       apart = &root->ar_parts[lpart->rent];

                       /* Paranoia.... */
                       if ((lpart->end - lpart->start + 1) != apart->ap_size)
                               errx(1, "Updating wrong partition!");
                       apart->ap_id[0] = lpart->id[0];
                       apart->ap_id[1] = lpart->id[1];
                       apart->ap_id[2] = lpart->id[2];
                       lpart->mod = 0;
               }
       }
       if (rsec == 0)
               ahdi_cksum(root);
       disk_write(fd, rsec, 1, root);
       free(root);
}

void
disk_write(fd, start, count, buf)
       int     fd;
       u_int   start,
               count;
       void    *buf;
{
       off_t   offset;
       size_t  size;

       size   = count * DEV_BSIZE;
       offset = start * DEV_BSIZE;

       if (lseek(fd, offset, SEEK_SET) != offset)
               err(1, "Seek error");
       if (write(fd, buf, size) != size)
               err(1, "Write error");
}