/* extipl.c for FreeBSD/Linux
 *	install and test the extended IPL
 *				Author: KIMURA Takamichi <takamiti@tsden.org>
 *				last update : 1999/04/05
 *
 * CAUTION:
 *   TAKE CARE!
 *   I (WE) TAKE NO RESPONSIBILITY FOR ANY TROUBLE; RUN AT YOUR OWN RISK.
 */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <ctype.h>
#include <fcntl.h>
#include <errno.h>
#include <unistd.h>
#include <termios.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <regex.h>
#ifdef __linux__
#include <linux/unistd.h>
#endif
#include "extipl.h"
#include "extndipl.src"

/* Directory prepended to bare device names given on the command line. */
#define DEVDIR			"/dev"
/* Extended regex that fdtest() expects a floppy device path to match. */
#define FD_DEVICE		"/dev/.*fd.*"
/* open(2) flags and mode for newly created backup / debug output files. */
#define FILE_CREATE		(O_CREAT | O_TRUNC | O_RDWR)
#define FILE_MODE		(S_IRUSR | S_IWUSR)
/* Mode applied by save() after writing an IPL file (owner read-only). */
#define FILE_RDONLY		S_IRUSR

/*
 * Status codes returned by the command handlers.  main() prints "Ok." for
 * OK, nothing for OK_NOMSG, and "Aborted." for anything else.
 * TRY_NEXT and ERR_BAD_IPLMAGIC are not referenced in this file.
 */
#define TRY_NEXT		 2
#define OK			 1
#define OK_NOMSG		 0
#define ERR			-1
#define ERR_BAD_IPLMAGIC	-2
/* Base name of the file wripl() writes instead of the device in debug mode. */
#define DEBUG_OUT		"_BootSec"
#define EMPTYNAME		""

/* Forward declarations for the file-local functions defined below. */
static int  help(char *, char *);
static char *takedevname(char *);
static int  rdipl(char *, char *, struct offset_s *);
static int  wripl(char *, char *, struct offset_s *);
static int  save(char *, char *);
static int  saveipl(char *, char *);
static int  load(char *, char *, int);
static int  restore(char *, char *);
static int  fdtest(char *, char *);
static int  install(char *, char *);
static int  clrboot(char *, char *);
static struct sysident_s *get_sysidnt(int);
static struct tblchain_s *read_bootsec(char *, struct offset_s *, int *);
static void print_table(struct tblchain_s *, int, int);
static int  showtbl(char *, char *);
static int  makeboot(char *, struct tblchain_s *, struct tblchain_s *, int);
static int  setboot(char *, char *);
static void tblsort(struct partition_s *);
static void tblpack(struct partition_s *);
static char *ask(char *);
static int  sure(char *);
static void hexdump(char *, int);
#ifdef __linux__
/*
 * Defines _llseek() as a raw system-call wrapper via the <linux/unistd.h>
 * _syscall5 macro.  long_seek() uses it to pass a 64-bit byte offset as two
 * 32-bit halves.  NOTE(review): _syscall5 is absent from modern kernel
 * headers, so this presumably builds only on old systems; confirm.
 */
static _syscall5(int, _llseek,  uint,  fd, ulong, hi, ulong, lo, loff_t *, res, uint, wh);
static int long_seek(int, off_t, int);
#endif

/*
 * Sub-command dispatch table.  main() compares the first non-option
 * argument against `command` and calls `func(device, optional_arg)`;
 * help() prints `help` for every entry.  An all-NULL entry ends the table.
 */
struct cmdtab {
    char *command;			/* sub-command name */
    int (*func)(char *, char *);	/* handler(device, arg) */
    char *help;				/* one-line description */
};

static struct cmdtab cmdtab[] = {
    { "install", install, "install 'extended_IPL' on specified device" },
    { "fdtest",  fdtest,  "testing 'extended_IPL' from diskette" },
    { "save",    saveipl, "save current IPL code to file" },
    { "restore", restore, "restore last IPL code from file" },
    { "show",    showtbl, "show partition table" },
    { "chgboot", setboot, "change bootable partition" },
    { "clrboot", clrboot, "clear active flag" },
    { "help",    help,    "show this" },
    { NULL,      NULL,    NULL }
};

/* Location of the master boot record (sector 0, CHS 0/0/1); set up in main(). */
static offset_s mbr;
/* Command-line option flags, set by main() from getopt(). */
static int opt_sort = 0;	/* -s: wripl() sorts the table before writing */
static int opt_pack = 0;	/* -p: wripl() packs the table (overrides -s) */
static int opt_debug = 0;	/* -d: write to a _BootSec.* file, not the device */
static int opt_force = 0;	/* -f = 1, -F = 2: install() backup/prompt policy */

/*
 * Entry point:  extipl [-dfFhsp] command device-name [arg]
 *
 * Options set the opt_* globals (-s sort / -p pack the table when writing,
 * -p wins if both are given; -d debug output; -f / -F force levels used by
 * install).  `mbr` is set to sector 0 (CHS 0/0/1).  The command is looked
 * up in cmdtab and its handler is run with the resolved device path and the
 * optional argument (NULL if absent).  A handler result of OK or OK_NOMSG
 * exits with 0; any other result is passed to exit() unchanged.
 */
int main(argc, argv)
int argc;
char **argv;
{
    char *device;
    int n, r;

    /* getopt() reports the end of options with -1 */
    while ((n = getopt(argc, argv, "dfFhsp?")) != -1) {
	switch (n) {
	case 'h':
	case '?':
	    exit(help(NULL, NULL));
	case 's':
	    opt_sort = 1;
	    break;
	case 'p':
	    opt_pack = 1;
	    break;
	case 'd':
	    opt_debug = 1;
	    break;
	case 'f':
	    opt_force = 1;
	    break;
	case 'F':
	    opt_force = 2;
	    break;
	default:
	    exit(1);
	}
    }
    argc -= optind;
    argv += optind;

    if (argc < 2) exit(help(NULL, NULL));
    if (opt_sort && opt_pack) opt_sort = 0;	/* -p takes precedence */
#ifdef DEBUG
    opt_debug = 1;
#endif

    mbr.chs.head = mbr.chs.cyl = 0;
    mbr.chs.sect = 1;
    mbr.lba = mbr.base = 0;

    for (n = 0; cmdtab[n].command != NULL; n++) {
	if (strcmp(argv[0], cmdtab[n].command) != 0)
	    continue;
	device = takedevname(argv[1]);
	/* argv[2] is the optional argument, or the terminating NULL */
	r = (*(cmdtab[n].func))(device, argv[2]);
	if (r != OK_NOMSG) {
	    fprintf(stderr, "%s.\n", (r == OK) ? "Ok": "Aborted");
	    r = (r == OK) ? 0: r;
	}
	exit(r);
    }
    fprintf(stderr, "extipl: Unknown command \"%s\"\n", *argv);
    help(NULL, NULL);
    exit(ERR);
}

/*
 * Print the version banner, the usage line and one line per cmdtab entry
 * to stderr.  Both arguments are ignored; they exist only so the function
 * matches the cmdtab handler signature.  Returns OK_NOMSG.
 */
static int help(arg1, arg2)
char *arg1, *arg2;
{
    int i;

    (void)arg1;
    (void)arg2;
    fprintf(stderr, "\n*** Extended IPL %s ***", VERSION);
    fprintf(stderr, "\nUsage: extipl command device-name [arg]\n");
    fprintf(stderr, "    valid commands are:\n");
    for (i = 0; cmdtab[i].command != NULL; i++)
	fprintf(stderr, "\t%s\t- %s\n", cmdtab[i].command, cmdtab[i].help);
    return(OK_NOMSG);
}

/*
 * Resolve a device argument to a path.  Names starting with "/dev" are
 * used as-is; anything else gets "/dev/" prepended.  The program exits
 * with status ERR if no name was given, the path does not fit, stat(2)
 * fails, or the path is neither a character nor a block special file.
 * Returns either `argv` itself or a pointer to a static buffer.
 */
static char *takedevname(argv)
char *argv;
{
    static char name[LBUF_SIZE];
    char *device;
    struct stat st;

    if (argv == NULL) {
	fprintf(stderr, "extipl : no device name\n");
	exit(ERR);
    }
    if (strncmp(argv, DEVDIR, strlen(DEVDIR)) == 0) {
	device = argv;
    } else {
	int len = snprintf(name, sizeof(name), "%s/%s", DEVDIR, argv);

	/* never operate on a silently truncated device path */
	if (len < 0 || (size_t)len >= sizeof(name)) {
	    fprintf(stderr, "extipl : device name too long: %s\n", argv);
	    exit(ERR);
	}
	device = name;
    }

    if (stat(device, &st) == -1) {
	perror(device);
	exit(ERR);
    }
    /*
     * The old bit test `st_mode & S_IFCHR` also matched block devices,
     * because S_IFBLK contains the S_IFCHR bit.  Linux disks are block
     * devices, so both kinds stay accepted, now stated explicitly.
     */
    if (!S_ISCHR(st.st_mode) && !S_ISBLK(st.st_mode)) {
	fprintf(stderr, "%s: not character or block special\n", device);
	exit(ERR);
    }
    return(device);
}

static int rdipl(device, buf, offset)
char *device, *buf;
struct offset_s *offset;
{
    int fd, n;

    if ((fd = open(device, O_RDONLY)) < 0) {
	perror(device);
	return(ERR);
    }
#ifdef __linux__
    long_seek(fd, (off_t)offset->lba, SEEK_SET);
#else
    lseek(fd, (off_t)offset->lba * SECTOR_SIZE, SEEK_SET);
#endif
    n = read(fd, buf, SECTOR_SIZE);
    close(fd);
    return(n);
}

/*
 * Write the SECTOR_SIZE-byte buffer `buf` to sector number offset->lba of
 * `device`, then sync().  With -s / -p, the partition table inside `buf`
 * (starting at IPL_SIZE) is first sorted / packed in place.  With -d, the
 * sector goes to the local file "_BootSec.<lba>" instead of the device.
 * Returns OK when the whole sector was written, ERR otherwise.
 */
static int wripl(device, buf, offset)
char *device, *buf;
struct offset_s *offset;
{
    int fd, r;

    if (opt_sort) {
	tblsort((struct partition_s *)(buf + IPL_SIZE));
    } else if (opt_pack) {
	tblpack((struct partition_s *)(buf + IPL_SIZE));
    }

    if (opt_debug) {
	char debugout[LBUF_SIZE];

	/*
	 * Zero-pad the sector number.  The old "%10ld" padded with blanks,
	 * giving names like "_BootSec.         0".  The cast keeps the
	 * argument matching the conversion whatever integer type lba has.
	 */
	snprintf(debugout, sizeof(debugout), "%s.%010lu",
		 DEBUG_OUT, (unsigned long)offset->lba);
	printf("<<DEBUG MODE: write to \"%s\">>\n", debugout);
	if ((fd = open(debugout, FILE_CREATE, FILE_MODE)) < 0) {
	    perror(debugout);
	    return(ERR);
	}
    } else {
	if ((fd = open(device, O_RDWR)) < 0) {
	    perror(device);
	    return(ERR);
	}
	/*
	 * If the seek fails, the file position is still 0.  Writing anyway
	 * would put this sector on top of the MBR.
	 */
#ifdef __linux__
	if (long_seek(fd, (off_t)offset->lba, SEEK_SET) != OK) {
	    close(fd);
	    return(ERR);
	}
#else
	if (lseek(fd, (off_t)offset->lba * (off_t)SECTOR_SIZE, SEEK_SET)
		== (off_t)-1) {
	    perror(device);
	    close(fd);
	    return(ERR);
	}
#endif
    }
    if (write(fd, buf, SECTOR_SIZE) == SECTOR_SIZE) {
	r = OK;
    } else {
	perror(opt_debug ? DEBUG_OUT : device);
	r = ERR;
    }
    close(fd);
    sync();
    return(r);
}

/*
 * Write the SECTOR_SIZE-byte buffer `buf` to `file` (created or truncated),
 * then make the file owner-read-only.  If `file` is the empty string, the
 * first unused name from "fdiskIPL.001" to "fdiskIPL.999" is generated and
 * stored back into `file`.  In that case the caller must pass a writable
 * buffer of at least 13 bytes.  Returns OK or ERR.
 */
static int save(file, buf)
char *file, *buf;
{
    int fd, i;

    if (*file == 0) {
	/* try every name up to and including fdiskIPL.999 */
	for (i = 1; i <= 999; i++) {
	    sprintf(file, "fdiskIPL.%03d", i);
	    if (access(file, F_OK) != 0)
		break;
	}
	if (i > 999) {
	    fprintf(stderr, "no unused fdiskIPL.NNN file name left\n");
	    return(ERR);
	}
    }
    if ((fd = open(file, FILE_CREATE, FILE_MODE)) < 0) {
	perror(file);
	return(ERR);
    }
    if (write(fd, buf, SECTOR_SIZE) != SECTOR_SIZE) {
	perror(file);
	close(fd);		/* previously leaked on this path */
	return(ERR);
    }
    fchmod(fd, FILE_RDONLY);
    close(fd);
    return(OK);
}

static int saveipl(device, arg)
char *device, *arg;
{
    char sectbuf[SECTOR_SIZE];

    if (rdipl(device, sectbuf, &mbr) != SECTOR_SIZE)
	return(ERR);
    return( save(((arg == NULL) ? "master.ipl": arg), sectbuf) );
}

/*
 * Read at most `len` bytes of `file` into `buf`.  Returns the number of
 * bytes read (> 0), or ERR if the file cannot be opened or read, or is
 * empty.  The descriptor is closed on every path.
 */
static int load(file, buf, len)
char *file, *buf;
int len;
{
    int fd, r;

    if ((fd = open(file, O_RDONLY)) < 0) {
	perror(file);
	return(ERR);
    }
    r = read(fd, buf, len);
    if (r < 0)
	perror(file);
    else if (r == 0)
	fprintf(stderr, "%s: empty file\n", file);	/* errno is not set here */
    close(fd);
    return((r > 0) ? r : ERR);
}

/*
 * "fdtest" command: write a test boot sector to `device`, which is
 * expected to be a floppy.  The sector holds the built-in fdtestIPL image,
 * or up to SECTOR_SIZE bytes loaded from file `arg`, with IPL_MAGIC stored
 * at IPL_MAGIC_POS.  If `device` does not match FD_DEVICE, the user must
 * first confirm the name; the write itself always needs confirmation.
 * Returns wripl()'s status or ERR.
 */
static int fdtest(device, arg)
char *device, *arg;
{
    char sectbuf[SECTOR_SIZE];
    unsigned short magic = IPL_MAGIC;
    int n;
    regex_t exp;

    if ((n = regcomp(&exp, FD_DEVICE, REG_EXTENDED | REG_NOSUB)) == 0) {
	n = regexec(&exp, device, (size_t)0, NULL, 0);
	regfree(&exp);
    }
    if (n != 0) {
	printf("\"%s\" is correct name of floppy device", device);
	if (!sure(NULL)) return(ERR);
    }

    /*
     * Start from an all-zero sector.  Previously only the partition table
     * was cleared, so bytes not covered by the IPL image were
     * uninitialized stack contents written to the diskette.
     */
    memset(sectbuf, 0, sizeof(sectbuf));
    if (arg == NULL) {
	n = sizeof(fdtestIPL);
	if (n > SECTOR_SIZE) n = SECTOR_SIZE;
	memcpy(sectbuf, fdtestIPL, n);
    } else {
	n = load(arg, sectbuf, SECTOR_SIZE);
	if (n <= 0) return(ERR);
    }
    /* memcpy: no misaligned/aliasing store through an unsigned short * */
    memcpy(sectbuf + IPL_MAGIC_POS, &magic, sizeof(magic));

    printf("Please insert blank diskette in FLOPPY unit.\n");
    printf("Write Extended IPL to \"%s\"", device);
    return(sure(NULL) ? wripl(device, sectbuf, &mbr) : ERR);
}

/*
 * "install" command: replace the start of the MBR boot code on `device`
 * with the built-in extendedIPL image, or with up to IPL_SIZE bytes loaded
 * from file `arg`.  The rest of the sector, including the partition table,
 * is kept.  The MBR must already carry IPL_MAGIC.  Backup policy:
 *   default : prompt for a file name, save, then ask for confirmation;
 *   -f      : save to the first free "fdiskIPL.NNN" without prompting,
 *             and abort if that backup fails;
 *   -F      : no backup and no prompt.
 * Returns wripl()'s status, or ERR if anything before the write fails.
 */
static int install(device, arg)
char *device, *arg;
{
    char *name, sectbuf[SECTOR_SIZE], iplcode[SECTOR_SIZE];
    char tmpname[LBUF_SIZE];
    int n;

    if (rdipl(device, sectbuf, &mbr) != SECTOR_SIZE)
	return(ERR);

    if (*(unsigned short *)(sectbuf + IPL_MAGIC_POS) != IPL_MAGIC) {
	fprintf(stderr, "%s: IPL Magic not found\n", device);
	return(ERR);
    }

    if (arg == NULL) {
	n = sizeof(extendedIPL);
	if (n > IPL_SIZE) n = IPL_SIZE;
	memcpy(iplcode, extendedIPL, n);
    } else {
	n = load(arg, iplcode, IPL_SIZE);
	if (n <= 0) return(ERR);
    }

    switch (opt_force) {
    case 0:
	printf("*** Before exchange the master boot program,\n");
	printf("*** You had better keep the original IPL code.\n");
	name = ask("Enter file name to save:");
	if (save(name, sectbuf) != OK)  return(ERR);
	printf("Current IPL saved to '%s'.\n", name);
	printf("Install Extended-IPL to \"%s\"", device);
	if (!sure(NULL)) return(ERR);
	break;
    case 1:
	/*
	 * The backup result used to be ignored, so the MBR was overwritten
	 * even when no copy could be made.
	 */
	*tmpname = 0;
	if (save(tmpname, sectbuf) != OK) return(ERR);
	printf("Current IPL saved to '%s'.\n", tmpname);
	break;
    default:
	/* -F: no backup */
	break;
    }
    memcpy(sectbuf, iplcode, n);
    return(wripl(device, sectbuf, &mbr));
}

/*
 * "restore" command: restore the MBR of `device` from file `arg`.  The file
 * must be exactly SECTOR_SIZE bytes, such as one written by "save", and
 * carry IPL_MAGIC.  The user chooses what to restore and then confirms:
 *   c : the boot code only (first IPL_SIZE bytes);
 *   t : the partition table only (TABLE_SIZE bytes at IPL_SIZE);
 *   a : the whole sector.
 * With -d, the file's table entries and its last two bytes are hex-dumped
 * first.  Returns wripl()'s status or ERR.
 */
static int restore(device, arg)
char *device, *arg;
{
    char buff[SECTOR_SIZE], sectbuf[SECTOR_SIZE];
    char *wrbuf, *p;
    int i;

    if (arg == NULL) {
	fprintf(stderr, "extipl: restore needs a file name\n");
	return(ERR);
    }
    if (load(arg, buff, SECTOR_SIZE) != SECTOR_SIZE) {
	fprintf(stderr, "%s: not a %d-byte sector image\n", arg, SECTOR_SIZE);
	return(ERR);
    }

    if (*(unsigned short *)(buff + IPL_MAGIC_POS) != IPL_MAGIC) {
	fprintf(stderr, "%s: illegal IPL Magic number\n", arg);
	return(ERR);
    }

    if (opt_debug) {
	for (p = buff + IPL_SIZE, i = 0; i < 4; i++, p += TBL_ENTRY_SIZE)
	    hexdump(p, TBL_ENTRY_SIZE);
	hexdump(buff + 510, 2);
    }

    if (rdipl(device, sectbuf, &mbr) != SECTOR_SIZE)
	return(ERR);

    wrbuf = sectbuf;
    printf("\nC)ode:  restore ipl code only");
    printf("\nT)able: restore partition table only");
    printf("\nA)ll:   restore ipl code and partition table");
    /* tolower() needs an unsigned char value; a plain char may be negative */
    switch (tolower((unsigned char)*ask("\n  Restore(c/t/a)?"))) {
    case 'c':
	printf("\nRestore ipl code only");
	memcpy(sectbuf, buff, IPL_SIZE);
	break;
    case 't':
	printf("\nRestore partition table only");
	memcpy(sectbuf + IPL_SIZE, buff + IPL_SIZE, TABLE_SIZE);
	break;
    case 'a':
	printf("\nOver write whole data");
	wrbuf = buff;
	break;
    default:
	return(ERR);
    }
    return(sure(NULL) ? wripl(device, wrbuf, &mbr) : ERR);
}

/*
 * "clrboot" command: clear the active bit (0x80 of bootind) on every MBR
 * partition entry, then write the sector back.  `arg` is unused.
 * Returns wripl()'s status, or ERR if the MBR cannot be read.
 */
static int clrboot(device, arg)
char *device, *arg;
{
    char sector[SECTOR_SIZE];
    struct partition_s *entry, *end;

    if (rdipl(device, sector, &mbr) != SECTOR_SIZE)
	return(ERR);

    entry = (struct partition_s *)(sector + IPL_SIZE);
    end = entry + NR_PARTITION;
    while (entry < end) {
	entry->bootind &= 0x7f;
	entry++;
    }

    return(wripl(device, sector, &mbr));
}

/*
 * Look up partition type `id` in the sysident table.  The table is
 * terminated by an entry with a negative id.  That terminator is returned
 * when `id` is not found, so the result is never NULL.
 */
static struct sysident_s *get_sysidnt(id)
int id;
{
    struct sysident_s *entry = sysident;

    while (entry->id >= 0 && entry->id != id)
	entry++;
    return(entry);
}

/*
 * Read the boot sector at `offset` and, recursively, every table reachable
 * through entries for which try_recursiv() is true.  Returns a malloc'ed
 * tree, which the caller releases with free_chain() followed by free().
 * Returns NULL if the sector cannot be read, lacks IPL_MAGIC, memory runs
 * out, or MAX_TABLES tables have already been read.  *nest is incremented
 * once per table read successfully.
 *
 * offset->base is 0 while reading the MBR.  Below that, it holds the
 * sector_offset of the MBR entry that was followed, and nested entries
 * are located relative to it.
 */
static struct tblchain_s *read_bootsec(device, offset, nest)
char *device;
struct offset_s *offset;
int *nest;
{
    /*
     * A corrupt or cyclic chain, e.g. an entry pointing back at its own
     * table, used to recurse until the stack overflowed.
     */
    enum { MAX_TABLES = 256 };
    struct tblchain_s *chp;
    struct partition_s *table;
    struct offset_s position;
    int i;

    if (*nest >= MAX_TABLES) {
	fprintf(stderr, "%s: more than %d partition tables, chain ignored\n",
		device, MAX_TABLES);
	return(NULL);
    }
    if ((chp = malloc(sizeof(*chp))) == NULL) {
	perror("malloc");
	return(NULL);
    }
    if (rdipl(device, chp->sector, offset) != SECTOR_SIZE) {
	free(chp);
	return(NULL);
    }
    if (*(unsigned short *)(chp->sector + IPL_MAGIC_POS) != IPL_MAGIC) {
	free(chp);
	return(NULL);
    }
    (*nest)++;
    chp->offset = *offset;
    table = (struct partition_s *)(chp->sector + IPL_SIZE);
    for (i = 0; i < NR_PARTITION; table++, i++) {
	chp->next[i] = NULL;
	if (try_recursiv(table->sysind)) {
	    position.lba = table->sector_offset + offset->base;
	    position.chs = table->start_chs;
	    position.base = offset->base == 0 ? table->sector_offset : offset->base;
	    chp->next[i] = read_bootsec(device, &position, nest);
	}
    }
    return(chp);
}

/*
 * Print the NR_PARTITION entries of the table in `chain`, one per line.
 * Each line shows 'A' if the active (0x80) bit is set, the 1-based slot
 * number, the hex type and its name.  Types that are not followed
 * recursively also show nr_sector >> 11 labelled "MB" (assuming 512-byte
 * sectors).  Below the top level (nest > 1), unused slots are skipped and
 * lines are indented four blanks per level after a "-->" marker.
 * Sub-tables are printed while depth <= 0 (unlimited) or nest < depth.
 */
static void print_table(chain, nest, depth)
struct tblchain_s *chain;
int nest, depth;
{
    struct partition_s *entry;
    struct sysident_s *ident;
    int slot, level, descend;

    entry = (struct partition_s *)(chain->sector + IPL_SIZE);
    for (slot = 0; slot < NR_PARTITION; slot++, entry++) {
	if (nest > 1) {
	    if (entry->sysind == 0)
		continue;
	    for (level = 1; level < nest; level++)
		printf("    ");
	    printf("-->");
	}
	printf("%c[%d]", (entry->bootind & 0x80) ? 'A' : ' ', slot + 1);
	if (entry->sysind == 0) {
	    printf("\n");
	    continue;
	}
	ident = get_sysidnt(entry->sysind);
	printf(" %02X: %s", entry->sysind, (ident == NULL) ? "" : ident->name);
	if (!try_recursiv(entry->sysind))
	    printf(" // %dMB", (int)(entry->nr_sector >> 11));
	printf("\n");
	descend = (depth <= 0 || nest < depth);
	if (descend && try_recursiv(entry->sysind) && chain->next[slot] != NULL)
	    print_table(chain->next[slot], nest + 1, depth);
    }
}

/*
 * Recursively free every sub-table hanging off chp->next[].  `chp` itself
 * is NOT freed; the caller frees the root node.
 */
static void free_chain(chp)
struct tblchain_s *chp;
{
    struct tblchain_s *child;
    int slot;

    for (slot = 0; slot < NR_PARTITION; slot++) {
	child = chp->next[slot];
	if (child == NULL)
	    continue;
	free_chain(child);
	free(child);
    }
}

/*
 * "show" command: print the partition table tree of `device`.  `arg` limits
 * the printed nesting depth.  No argument, "all", or anything atoi() reads
 * as 0 or less means unlimited.  Returns OK_NOMSG, or ERR if the MBR cannot
 * be read.
 */
static int showtbl(device, arg)
char *device, *arg;
{
    struct tblchain_s *root;
    int tables = 0;
    int depth;

    root = read_bootsec(device, &mbr, &tables);
    if (root == NULL)
	return(ERR);

    if (arg == NULL || strcmp(arg, "all") == 0)
	depth = 0;
    else
	depth = atoi(arg);

    printf("=========\nPartition TABLE on \"%s\"\n=========\n", device);
    print_table(root, 1, depth);

    free_chain(root);
    free(root);
    return(OK_NOMSG);
}

/*
 * Interactive editor for the active flags of the table in `chp`, a node of
 * the tree rooted at `base`.  The whole tree from `base` is redisplayed on
 * every prompt, down to `depth` levels.  Keys:
 *   1..4    : make that entry active and clear the flag on the other used
 *             entries; if the entry has a sub-table, edit it recursively
 *   c       : clear the active flag on every used entry
 *   w, q, b : write and quit / quit without writing / back one level
 *   ?, h    : list the keys on the next redisplay
 * End of input is treated as 'q'.  Returns the key that ended the loop
 * ('w', 'q', 'b') or ERR.  On 'w', this level's sector is written.  A 'w'
 * returned from a sub-level also ends this level's loop, so every level on
 * the path gets written.
 */
static int makeboot(device, base, chp, depth)
char *device;
struct tblchain_s *base, *chp;
int depth;
{
    static char *helpmsg[] = {
	"1 .. 4 : specified bootable partition number",
	"     c : Clear all bootable flag",
	"     w : Write this table and quit",
	"     q : Quit without write",
	"     b : Back one step",
	NULL
	};
    struct partition_s *table, *active;
    struct sysident_s *si;
    int i, n, show_help, loop;
    int key;

    loop = 1;
    show_help = 0;
    while (loop) {
	printf("\n=========\nPartition TABLE on \"%s\"\n=========\n", device);
	print_table(base, 1, depth);
	if (show_help) {
	    for (show_help = i = 0; helpmsg[i] != NULL; i++)
		printf("\n%s.", helpmsg[i]);
	}
	/*
	 * Read the key as unsigned char, so an input byte 0xFF can no longer
	 * compare equal to ERR (-1).
	 */
	key = (unsigned char)*ask("\n>>> Select partition to make bootable (? for help):");
	/* End of input: quit without writing instead of re-prompting forever. */
	if (key == 0 && feof(stdin))
	    key = 'q';
	table = (struct partition_s *)(chp->sector + IPL_SIZE);
	switch (key) {
	case '?':
	case 'h':
	    show_help = 1;
	    break;
	case '1':
	case '2':
	case '3':
	case '4':				/* make bootable */
	    n = key - '1';
	    /*
	     * Point at the selected entry up front.  It used to be set only
	     * inside the loop below, which skips unused entries, so picking
	     * an unused slot left `active` uninitialized.
	     */
	    active = table + n;
	    si = get_sysidnt(active->sysind);
	    if (si->bootable < 0) {
		printf("\n\007Partition #%d(%s), can not make bootable.\n", n+1, si->name);
		(void)ask("Hit <Enter> to continue.. ");
		break;
	    }
	    if (!(active->bootind & 0x80)) {
		if (si->bootable == 0) {
		    printf("\nPartition #%d(%s) specified, ", n + 1, si->name);
		    if (!sure("Can you make bootable(y/n)?")) break;
		}
		for (i = 0; i < NR_PARTITION; i++, table++) {
		    if (table->sysind == 0) continue;
		    table->bootind &= 0x7f;
		    if (i == n)
			table->bootind |= 0x80;
		}
	    }
	    if (try_recursiv(active->sysind) && chp->next[n] != NULL) {
		key = makeboot(device, base, chp->next[n], depth + 1);
		if (key == 'w' || key == 'q' || key == ERR) loop = 0;
	    }
	    break;
	case 'c':				/* clear bootable */
	    for (i = 0; i < NR_PARTITION; i++, table++) {
		if (table->sysind == 0) continue;
		table->bootind &= 0x7f;
	    }
	    break;
	case 'w':
	case 'q':
	case 'b':
	    loop = 0;
	    break;
	default:
	    bell();
	    break;
	}
    }
    if (key == 'w' && wripl(device, chp->sector, &chp->offset) == ERR)
	key = ERR;
    return(key);
}

/*
 * "chgboot" command: read the partition table tree of `device` and run the
 * interactive makeboot() editor on it.  `arg` is unused.  Returns OK if the
 * user wrote the table, OK_NOMSG if they left without writing, and ERR if
 * reading or writing failed.
 */
static int setboot(device, arg)
char *device, *arg;
{
    struct tblchain_s *root;
    int tables = 0;
    int result;

    root = read_bootsec(device, &mbr, &tables);
    if (root == NULL)
	return(ERR);

    result = makeboot(device, root, root, 1);
    free_chain(root);
    free(root);

    if (result == ERR)
	return(ERR);
    return((result == 'w') ? OK : OK_NOMSG);
}

static void tblsort(table)
struct partition_s *table;
{
    struct partition_s *p, q;
    int n = NR_PARTITION;

    printf("sorting partition table\n");
    do {
	for(p = table; p < table + NR_PARTITION - 1; p++) {
	    if (p[0].sysind == 0 || 
	       (p[0].sector_offset > p[1].sector_offset && p[1].sysind != 0)) {
			q = p[0];
			p[0] = p[1];
			p[1] = q;
	    }
	}
    } while(--n > 0);
}

/*
 * Move unused entries (sysind == 0) to the end of the NR_PARTITION-entry
 * table in place, keeping the used entries in their original order.
 *
 * The swap test used to be `p[0].sysind == 0 || p[1].sysind != 0`.  That
 * also swapped two used entries, shuffling them: A,-,B,- came out as
 * B,A,-,-.  Now only an unused entry followed by a used one is swapped,
 * which makes each pass a stable bubble pass.
 */
static void tblpack(table)
struct partition_s *table;
{
    struct partition_s *p, q;
    int n;

    printf("packing partition table\n");
    for (n = 0; n < NR_PARTITION; n++) {
	for (p = table; p < table + NR_PARTITION - 1; p++) {
	    if (p[0].sysind == 0 && p[1].sysind != 0) {
		q = p[0];
		p[0] = p[1];
		p[1] = q;
	    }
	}
    }
}

/*
 * Print `prompt` followed by a blank, read one line from stdin, and return
 * its first whitespace-delimited word.  Returns "" for a blank line or on
 * EOF/read error.  The result points into a static buffer that the next
 * call overwrites.
 */
static char *ask(prompt)
char *prompt;
{
    static char lbuf[LBUF_SIZE];
    char *p, *q;

    printf("%s ", prompt);
    fflush(stdout);
    if (fgets(lbuf, LBUF_SIZE, stdin) == NULL)
	*lbuf = 0;
    /*
     * Strip only a real newline.  The old `lbuf[strlen(lbuf) - 1] = 0`
     * wrote lbuf[-1] on EOF and chopped the last character of a line
     * without a newline.
     */
    lbuf[strcspn(lbuf, "\n")] = 0;
    p = lbuf;
    /* ctype functions need unsigned char values; plain char may be negative */
    while (*p && isspace((unsigned char)*p)) p++;
    q = p;
    while (*p && !isspace((unsigned char)*p)) p++;
    *p = 0;
    return(q);
}

/*
 * Ask a yes/no question: prompt `s`, or " ... Sure(y/n)?" when `s` is NULL.
 * Returns 1 if the answer starts with 'y' or 'Y', otherwise 0.
 */
static int sure(s)
char *s;
{
    char *answer = ask((s == NULL) ? " ... Sure(y/n)?" : s);

    /* tolower() requires an unsigned char value (or EOF) */
    return(tolower((unsigned char)*answer) == 'y');
}

/*
 * Print `len` bytes of `buf` on stdout as upper-case two-digit hex,
 * prefixed by "> " and separated by blanks.  The first byte is printed
 * unconditionally, even when len < 1.
 */
static void hexdump(buf, len)
char *buf;
int len;
{
    int k;

    printf("> %02X", buf[0] & 0xff);
    for (k = 1; k < len; k++)
	printf(" %02X", buf[k] & 0xff);
    printf("\n");
}

#ifdef __linux__
/* Hacked by Taketoshi Sano <xlj06203@nifty.ne.jp>  */
/*
 * Seek `fd` to sector number `offset`.  The sector number is turned into a
 * byte offset by shifting left 9 bits (512-byte sectors) and passed to the
 * raw _llseek system call as separate high/low 32-bit words.  Returns OK,
 * or ERR after perror() if the system call fails.
 */
static int long_seek(fd, offset, whence)
int fd, whence;
off_t offset;
{
    loff_t loffset, result;
    unsigned long loff_hi, loff_lo;

    loffset = (loff_t)offset << 9;
    loff_hi = (unsigned long)(loffset >> 32);
    loff_lo = (unsigned long)(loffset & 0xffffffff);
    if (opt_debug) {
	/*
	 * Cast explicitly: off_t and loff_t widths vary, and "%L" is a
	 * long double modifier, not an integer one.
	 */
	fprintf(stderr, " sector: %llu, loffset: %llu, loff_hi: %lu, loff_lo: %lu\n",
		(unsigned long long)offset, (unsigned long long)loffset,
		loff_hi, loff_lo);
	fflush(stderr);
    }
    if (_llseek(fd, loff_hi, loff_lo, &result, whence) != 0) {
	perror("llseek");
	return(ERR);
    }
    if (opt_debug) {
	fprintf(stderr, " result: %llu, sector: %llu\n",
		(unsigned long long)result, (unsigned long long)result >> 9);
	fflush(stderr);
    }
    return(OK);
}
#endif
