/*
 * Copyright 1995,96 Thierry Bousch
 * Licensed under the Gnu Public License, Version 2
 *
 * $Id: Cyclic.c,v 2.8 1996/09/14 09:39:13 bousch Exp $
 *
 * Arithmetics on Z/nZ, where n can be prime or not.
 */

#include <assert.h>
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "saml.h"
#include "saml-errno.h"
#include "mnode.h"
#include "builtin.h"
#include "mp-arch.h"

/*
 * One element of Z/nZ.  Nodes are shared: cyclic_new() keeps every live
 * node in the hash table `htable', chained through `next', and hands out
 * copy_mnode() of an existing node when one with the same value and
 * modulus is already present.
 */
typedef struct _cyclic {
	struct mnode_header hdr;	/* common mnode header; must stay first,
					   nodes are cast to/from s_mnode* */
	struct _cyclic *next;		/* next node in the same hash bucket */
	__u32 n, modulus;		/* residue, kept reduced (0 <= n < modulus)
					   by the constructors in this file */
} cyclic_mnode;

/*
 * Define BINARY_HASHSIZE to keep the hash table size a power of two
 * (64, doubled on each growth), so that hash() can reduce the key with a
 * bit mask instead of an expensive division.  Without it the size starts
 * at 59 and grows to 2*size+1, staying odd, which may mix the keys better
 * at the cost of a `%' on every lookup.
 */
#define BINARY_HASHSIZE

#ifdef BINARY_HASHSIZE
#define INITIAL_HASHSIZE  64
#else
#define INITIAL_HASHSIZE  59
#endif

static cyclic_mnode **htable;
static unsigned int hashsize = 0;
static unsigned int entries = 0;

/* Forward declarations of the CyclicInt methods and helpers. */
static s_mnode* cyclic_build (const char *str);
static s_mnode* cyclic_new (__u32 x, __u32 mod);
static void cyclic_free (cyclic_mnode *c);
static gr_string* cyclic_stringify (cyclic_mnode *c);
static s_mnode* cyclic_add (cyclic_mnode *c1, cyclic_mnode *c2);
static s_mnode* cyclic_sub (cyclic_mnode *c1, cyclic_mnode *c2);
static s_mnode* cyclic_mul (cyclic_mnode *c1, cyclic_mnode *c2);
static int cyclic_notzero (cyclic_mnode *c);
static s_mnode* cyclic_zero (cyclic_mnode *model);
static s_mnode* cyclic_negate (cyclic_mnode *c);
static s_mnode* cyclic_one (cyclic_mnode *model);
static s_mnode* cyclic_invert (cyclic_mnode *c);
static s_mnode* cyclic_sqrt (cyclic_mnode *n);
static s_mnode* int2cyclic (s_mnode *intg, cyclic_mnode *model);

/*
 * Method table for the "CyclicInt" type, registered under ST_CYCLIC by
 * init_MathType_Cyclic().  The initializer is positional, so the entries
 * must stay in the field order of unsafe_s_mtype (declared outside this
 * file).  NULL slots are presumably operations this type does not provide
 * -- confirm against the s_mtype declaration.  mn_std_div, mn_field_gcd
 * and mn_std_differ are generic helpers not defined in this file.
 */
static unsafe_s_mtype MathType_Cyclic = {
	"CyclicInt",
	cyclic_free, cyclic_build, cyclic_stringify,
	NULL, NULL,
	cyclic_add, cyclic_sub, cyclic_mul, mn_std_div, mn_field_gcd,
	cyclic_notzero, NULL, NULL, mn_std_differ, NULL,
	cyclic_zero, cyclic_negate, cyclic_one, cyclic_invert,
	cyclic_sqrt
};

/*
 * Bucket index for the element (x mod `mod'): both halves of the key are
 * folded together with XOR, then reduced into [0, hashsize).
 */
static inline int hash (__u32 x, __u32 mod)
{
	__u32 key = x ^ mod;

#ifdef BINARY_HASHSIZE
	key &= hashsize - 1;	/* hashsize is a power of two */
#else
	key %= hashsize;
#endif
	return key;
}

/*
 * Rehash every node into a freshly allocated table of new_size buckets
 * and make it the current table.  First called from init_MathType_Cyclic()
 * with INITIAL_HASHSIZE (while htable == NULL and hashsize == 0), then by
 * cyclic_new() whenever the number of entries exceeds the table size.
 *
 * The new table is built beside the old one and the old one is released
 * only afterwards.  A failed allocation therefore never loses the
 * existing nodes.  In the old realloc-in-place version, the chains had
 * already been dismantled and the only table pointer was overwritten with
 * NULL.
 */
static void resize_htable (unsigned int new_size)
{
	cyclic_mnode **new_table, *p, *next;
	unsigned int i, h, old_size = hashsize;

	/* calloc zeroes the buckets and checks new_size * sizeof for overflow */
	new_table = calloc(new_size, sizeof *new_table);
	if (new_table == NULL) {
		panic_out_of_memory();
		return;		/* old table is still intact */
	}

	/* hash() reduces modulo the global hashsize, so switch it first */
	hashsize = new_size;
	for (i = 0; i < old_size; i++)
		for (p = htable[i]; p != NULL; p = next) {
			next = p->next;
			h = hash(p->n, p->modulus);
			p->next = new_table[h];
			new_table[h] = p;
		}

	free(htable);		/* free(NULL) is fine on the first call */
	htable = new_table;
}

/*
 * Register the CyclicInt type.  This does three things:
 *   - installs MathType_Cyclic under ST_CYCLIC;
 *   - allocates the initial element hash table;
 *   - registers int2cyclic() for the ST_INTEGER -> ST_CYCLIC pair.
 * It must run before any element is created, because cyclic_new() indexes
 * htable, which is NULL until resize_htable() has been called.
 */
void init_MathType_Cyclic (void)
{
	register_mtype(ST_CYCLIC, (s_mtype*)&MathType_Cyclic);
	resize_htable(INITIAL_HASHSIZE);
	register_CV_routine(ST_INTEGER, ST_CYCLIC, (void*)int2cyclic);
}

/*
 * Return (x1 * x2) mod p, using the double-width product from mp-arch.h.
 *
 * NOTE(review): assuming mp-arch.h follows GMP's longlong.h conventions,
 * udiv_qrnnd needs the high word to be below p.  That holds when
 * x1, x2 < p, which is the case for the reduced residues passed here.
 * Some architectures also require a normalized divisor; confirm.
 */
static inline __u32 product_mod (__u32 x1, __u32 x2, __u32 p)
{
	__u32 hi, lo, q, r;

	umul_ppmm(hi, lo, x1, x2);	/* (hi:lo) = x1 * x2 */
	udiv_qrnnd(q, r, hi, lo, p);	/* (hi:lo) = q * p + r */
	(void) q;			/* only the remainder is needed */
	return r;
}

/*
 * Return x^e mod p by right-to-left binary exponentiation; e == 0 gives 1.
 * x is assumed to be already reduced (x < p), as product_mod expects.
 */
static __u32 power_mod (__u32 x, __u32 e, __u32 p)
{
	__u32 result = 1;

	for (;;) {
		/* invariant: result * x^e is the requested power (mod p) */
		if (e & 1)
			result = product_mod(result, x, p);
		e >>= 1;
		if (e == 0)
			break;
		x = product_mod(x, x, p);	/* skipped after the last bit */
	}
	return result;
}

/*
 * Return the node for residue x modulo mod; callers in this file always
 * pass a reduced x (x < mod).
 *
 * Nodes are shared.  If one with the same value already sits in the hash
 * table, a copy_mnode() of it is returned.  Otherwise a new node is pushed
 * onto the head of its bucket.  The table is grown once the number of
 * entries exceeds its size.
 */
static s_mnode* cyclic_new (__u32 x, __u32 mod)
{
	int h = hash(x, mod);
	cyclic_mnode *node = htable[h];

	while (node != NULL) {
		if (node->n == x && node->modulus == mod)
			return copy_mnode((s_mnode*)node);
		node = node->next;
	}

	/* Not in the table yet: allocate and link a fresh node. */
	node = (cyclic_mnode*) __mnalloc(ST_CYCLIC, sizeof(cyclic_mnode));
	node->n = x;
	node->modulus = mod;
	node->next = htable[h];
	htable[h] = node;

	entries++;
	if (entries > hashsize) {
#ifdef BINARY_HASHSIZE
		resize_htable(2 * hashsize);		/* stay a power of two */
#else
		resize_htable(2 * hashsize + 1);	/* stay odd */
#endif
	}
	return (s_mnode*) node;
}

/*
 * Free method: unlink c from its hash bucket, release its memory and
 * update the entry count.  Presumably invoked by the mnode layer when the
 * last reference goes away (confirm in mnode.c).
 */
static void cyclic_free (cyclic_mnode* c)
{
	cyclic_mnode **link = &htable[hash(c->n, c->modulus)];

	/* walk the bucket until *link is c, or the end of the chain */
	while (*link != NULL && *link != c)
		link = &(*link)->next;
	assert(*link == c);		/* every live node is in the table */
	if (*link == c)
		*link = c->next;	/* unlink */
	free(c);
	--entries;
}

static s_mnode* cyclic_build (const char *str)
{
	unsigned int x, mod;

	if (sscanf(str, "%u:%u", &x, &mod) == 2 && mod > 1) {
		x = x % mod;
		return cyclic_new(x, mod);
	}
	return mnode_error(SE_STRING, "cyclic_build");
}

/* Zero method: the residue 0 in the same ring Z/mZ as model. */
static s_mnode* cyclic_zero (cyclic_mnode* model)
{
	__u32 m = model->modulus;

	return cyclic_new(0, m);
}

/* One method: the residue 1 in the same ring Z/mZ as model. */
static s_mnode* cyclic_one (cyclic_mnode* model)
{
	__u32 m = model->modulus;

	return cyclic_new(1, m);
}

/*
 * Conversion ST_INTEGER -> ST_CYCLIC.  Reduce intg modulo model's modulus
 * with integer mnode arithmetic, shift a negative remainder into
 * [0, modulus), and build the residue from its decimal representation.
 * A NULL model yields an SE_ICAST error node.
 */
static s_mnode* int2cyclic (s_mnode* intg, cyclic_mnode* model)
{
	__u32 modulo, residue;
	s_mnode *m_int, *rem_int, *adjusted;
	gr_string *digits;

	assert(intg->type == ST_INTEGER);
	if (model == NULL)
		return mnode_error(SE_ICAST, "int2cyclic");

	modulo = model->modulus;
	m_int = mnode_build(ST_INTEGER, u32toa(modulo));
	rem_int = mnode_mod(intg, m_int);
	/* the remainder may come back negative; bring it into [0, modulo) */
	if (mnode_isneg(rem_int)) {
		adjusted = mnode_add(rem_int, m_int);
		unlink_mnode(rem_int);
		rem_int = adjusted;
	}
	unlink_mnode(m_int);

	digits = mnode_stringify(rem_int);
	unlink_mnode(rem_int);
	digits = grs_append1(digits, '\0');	/* NUL-terminate for strtoul */
	residue = strtoul(digits->s, NULL, 10);
	free(digits);
	return cyclic_new(residue, modulo);
}

/*
 * Stringify method: the residue in decimal; the modulus is not printed.
 * Assumes new_gr_string(30) provides at least 30 bytes.  A __u32 needs at
 * most 10 digits plus the NUL.
 */
static gr_string* cyclic_stringify (cyclic_mnode* c)
{
	gr_string *out = new_gr_string(30);

	out->len = sprintf(out->s, "%u", c->n);
	return out;
}

/*
 * Add two residues of the same modulus m; different moduli yield SE_NSMOD.
 * The 32-bit sum may wrap when m > 2^31.  In that case, or when the sum
 * reaches m, subtracting m (modulo 2^32) gives the correct residue.
 */
static s_mnode* cyclic_add (cyclic_mnode* c1, cyclic_mnode* c2)
{
	__u32 m = c1->modulus, a, sum;

	if (m != c2->modulus)
		return mnode_error(SE_NSMOD, "cyclic_add");
	a = c1->n;
	sum = a + c2->n;
	if (sum < a || sum >= m)	/* carry out of 32 bits, or sum >= m */
		sum -= m;
	return cyclic_new(sum, m);
}

/*
 * Subtract two residues of the same modulus m; different moduli yield
 * SE_NSMOD.  A borrow (c2->n > c1->n) shows up as the 32-bit difference
 * exceeding c1->n, and adding m (modulo 2^32) corrects it.
 */
static s_mnode* cyclic_sub (cyclic_mnode* c1, cyclic_mnode* c2)
{
	__u32 m = c1->modulus, a, diff;

	if (m != c2->modulus)
		return mnode_error(SE_NSMOD, "cyclic_sub");
	a = c1->n;
	diff = a - c2->n;
	if (diff > a)		/* borrow */
		diff += m;
	return cyclic_new(diff, m);
}

/* Multiply two residues of the same modulus; different moduli: SE_NSMOD. */
static s_mnode* cyclic_mul (cyclic_mnode* c1, cyclic_mnode* c2)
{
	__u32 m = c1->modulus;

	if (m != c2->modulus)
		return mnode_error(SE_NSMOD, "cyclic_mul");
	return cyclic_new(product_mod(c1->n, c2->n, m), m);
}

/* Return 1 if the residue is non-zero, 0 otherwise. */
static int cyclic_notzero (cyclic_mnode* c)
{
	return c->n ? 1 : 0;
}

/* Additive inverse: m - n, or the (shared) zero node itself for n == 0. */
static s_mnode* cyclic_negate (cyclic_mnode* c)
{
	__u32 m = c->modulus;

	if (c->n != 0)
		return cyclic_new(m - c->n, m);
	return copy_mnode((s_mnode*)c);		/* -0 == 0 */
}

/*
 * Disabled (#if 0): a recursive extended Euclid on signed ints, kept for
 * reference only.  It is compiled out, so nothing can call it.
 * NOTE(review): with int arguments it cannot represent moduli above
 * INT_MAX, which the __u32 moduli used in this file can reach.
 */
#if 0
/*
 * Returns g = gcd(a,b) and fills x and y with numbers such that
 * ax - by == g, with 0 <= x <= b and 0 <= y <= a (at least if a and b
 * are non-zero)
 */

static int solve_linear (int a, int b, int *x, int *y)
{
	int g, q, r, u, v;

	if (b == 0) {
		*x = 1;
		*y = 0;
		return a;
	}
	q = a / b;
	r = a % b;
	g = solve_linear(b, r, &u, &v);
	/* from b*u - r*v == g and r == a - q*b:  a*(b-v) - b*(a-u-q*v) == g */
	*x = b - v;
	*y = a - (u + q * v);
	return g;
}
#endif

/*
 * Invert method: return the multiplicative inverse of c in Z/mZ.
 *
 * The modulus need not be prime, so Fermat's x^(m-2) is not used: it is
 * only valid for prime m, and it maps the non-invertible 0 to a bogus
 * value.  The extended Euclidean algorithm is used instead.  An element
 * with gcd(n, m) != 1 (including 0) has no inverse and yields an
 * SE_DIVZERO error node.
 */
static s_mnode* cyclic_invert (cyclic_mnode *c)
{
	__u32 m = c->modulus;
	__u32 r0 = m, r1 = c->n;	/* successive remainders */
	__u32 u0 = 0, u1 = 1;		/* magnitudes of their coefficients */
	__u32 q, tmp;
	int positive = 0;		/* sign s of the coefficient of r0 */

	/*
	 * Invariant, with s = (positive ? +1 : -1):
	 *   r0 == s * u0 * n   and   r1 == -s * u1 * n   (mod m).
	 * The signs alternate at every step, so only magnitudes are kept.
	 * They never exceed m / gcd, which keeps the arithmetic overflow-free
	 * in 32-bit unsigned.
	 */
	while (r1 != 0) {
		q = r0 / r1;
		tmp = r0 - q * r1;  r0 = r1;  r1 = tmp;
		tmp = u0 + q * u1;  u0 = u1;  u1 = tmp;
		positive = !positive;
	}
	if (r0 != 1)		/* gcd(n, m) != 1, including n == 0 */
		return mnode_error(SE_DIVZERO, "cyclic_invert");
	/* here 1 <= u0 < m, so both branches give a residue in [1, m) */
	return cyclic_new(positive ? u0 : m - u0, m);
}

/* Square roots are not implemented for CyclicInt: always an error node. */
static s_mnode* cyclic_sqrt (cyclic_mnode* n)
{
	(void) n;	/* unused */
	return mnode_error(SE_NOTRDY, "cyclic_sqrt");
}
