/* z-socket.c: compressed socket library
   This has nothing to do with cryptography.
   Copyright (C) 1998 Paul Sheer

   This program is free software; you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation; either version 2 of the License, or
   (at your option) any later version.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with this program; if not, write to the Free Software
   Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */


#include "mostincludes.h"
#include <sys/types.h>
#if defined(HAVE_UNISTD_H)
#include <unistd.h>
#endif
#include <netdb.h>		/* struct hostent */
#include <sys/socket.h>		/* AF_INET */
#include <netinet/in.h>		/* struct in_addr */
#ifdef HAVE_SETSOCKOPT
#include <netinet/ip.h>		/* IP options */
#endif
#include <arpa/inet.h>
#ifndef SYS_TIME_H
#include <sys/time.h>		/* alex: this redefines struct timeval */
#endif				/* SCO_FLAVOR */
#ifdef HAVE_SYS_PARAM_H
#include <sys/param.h>
#endif
#include "zlib/zlib.h"
#include "z-socket.h"
#include "diffie-socket.h"
#include "diffie/compat.h"
#include "src/mad.h"

/* Upper bound on the number of descriptors (the original one plus those
   created through z_socket_dup()/z_socket_dup2()) that can be bound to one
   compressed connection; it sizes ZSocket.fd[]. */
#define MAX_FDS_PER_CONN 32

/* Receive exactly `len` bytes from socket `s` into `buf`, restarting recv()
   when it is interrupted by a signal.
   Returns the number of bytes received (len, or 0 when len <= 0), or -1 if
   recv() fails or the peer closes the connection before everything arrived. */
static int recv_all (int s, unsigned char *buf, int len, unsigned int flags)
{
    int got = 0;

    while (got < len) {
	int n = recv (s, buf + got, len - got, flags);
	if (n > 0)
	    got += n;
	else if (n < 0 && errno == EINTR)
	    continue;
	else
	    return -1;
    }
    return got;
}

/* Send all `len` bytes of `buf` on socket `s`, restarting send() after
   EINTR and continuing after partial writes.
   Returns len (0 when len <= 0), or -1 if send() fails or makes no progress. */
static int send_all (int s, unsigned char *buf, int len, unsigned int flags)
{
    int done;

    for (done = 0; done < len;) {
	int n = send (s, buf + done, len - done, flags);
	if (n == -1) {
	    if (errno == EINTR)
		continue;
	    return -1;
	}
	if (n == 0)
	    return -1;
	done += n;
    }
    return done;
}

/* Per-connection compression state.  One ZSocket can be shared by several
   descriptors that were duplicated from the same connection. */
typedef struct zsocket {
    z_stream z_read;		/* inflate stream for incoming data */
    z_stream z_write;		/* deflate stream for outgoing data */
    int fd[MAX_FDS_PER_CONN];	/* descriptors bound to this connection */
    int n_fd;			/* number of valid entries in fd[] */
    int shutdown;		/* bit 0 set by shutdown how==0, bit 1 by how==1, both by how==2 */
    unsigned char *buf;		/* compressed frame currently being inflated */
    long deflate_time, write_time;	/* smoothed timings used by adjust_compression() */
    int n_writes;		/* count of sampled (> 100 byte) writes */
    int level;			/* deflate compression level, kept within 0..9 */
    struct zsocket *prev, *next;	/* links in the `connections` list */
} ZSocket;

/* Head of the doubly linked list of every live compressed connection; new
   connections are pushed at the front.  Static storage starts out null. */
static ZSocket *connections;

/* Look up the compressed connection that descriptor `fd` is bound to.
   Returns 0 when `fd` is not a compressed connection (a plain socket). */
static ZSocket *z_socket_index (int fd)
{
    ZSocket *conn = connections;

    while (conn) {
	int *p = conn->fd;
	int *end = conn->fd + conn->n_fd;
	for (; p < end; p++) {
	    if (*p == fd)
		return conn;
	}
	conn = conn->next;
    }
    return 0;
}

/* Library-wide option flags.  z_socket_connect() only announces compression
   to the peer while Z_SOCKET_COMPRESSION_OFF is clear, and that bit is set
   by default.  Incoming connections are recognised by z_socket_accept_fd()
   regardless of these flags.  Stored as unsigned int so it matches the
   accessors below without any sign conversion. */
static unsigned int z_socket_flags = Z_SOCKET_COMPRESSION_OFF;

/* Replace all option flags. */
void z_socket_set_flags (unsigned int flags)
{
    z_socket_flags = flags;
}

/* Return the current option flags. */
unsigned int z_socket_get_flags (void)
{
    return z_socket_flags;
}

/* Convenience switch for later z_socket_connect() calls: a non-zero `on`
   clears every flag (compression enabled); zero sets the flags to exactly
   Z_SOCKET_COMPRESSION_OFF.  Other flag bits are not preserved either way. */
void z_socket_set_compression (int on)
{
    z_socket_set_flags (on ? 0U : (unsigned int) Z_SOCKET_COMPRESSION_OFF);
}

/* accept() wrapper: accepts a connection on `sock` and passes the new
   descriptor to z_socket_accept_fd(), which enables compression if the peer
   announced it.  A failed accept() result is returned unchanged. */
int z_socket_accept (int sock, struct sockaddr *addr, unsigned int *addrlen)
{
    int newfd = accept (sock, addr, (unsigned int *) addrlen);

    return (newfd < 0) ? newfd : z_socket_accept_fd (newfd);
}

int z_socket_accept_fd (int fd)
{
    ZSocket *c = 0;
    char buf[Z_SOCKET_MAGIC_LEN];
    if (fd < 0)
	return fd;
    recv (fd, buf, Z_SOCKET_MAGIC_LEN, MSG_PEEK);
    if (strncmp ((char *) buf, Z_SOCKET_MAGIC, Z_SOCKET_MAGIC_LEN))
	return fd;
    if (recv_all (fd, (unsigned char *) buf, Z_SOCKET_MAGIC_LEN, 0) != Z_SOCKET_MAGIC_LEN) {
	close (fd);
	return -1;
    }
    c = malloc (sizeof (ZSocket));
    memset (c, 0, sizeof (ZSocket));
    c->next = connections;
    if (c->next)
	c->next->prev = c;
    connections = c;
    c->fd[c->n_fd++] = fd;
    c->level = 6;
    deflateInit (&c->z_write, c->level);
    inflateInit (&c->z_read);
    return fd;
}

int z_socket_connect (int fd, struct sockaddr *addr, int addrlen)
{
    ZSocket *c = 0;
    int result;
    result = connect (fd, addr, addrlen);
    if (result < 0)
	return result;
    if ((z_socket_flags & Z_SOCKET_COMPRESSION_OFF))
	return result;
    if (send (fd, Z_SOCKET_MAGIC, Z_SOCKET_MAGIC_LEN, 0) != Z_SOCKET_MAGIC_LEN) {
	close (fd);
	return -1;
    }
    c = malloc (sizeof (ZSocket));
    memset (c, 0, sizeof (ZSocket));
    c->next = connections;
    if (c->next)
	c->next->prev = c;
    connections = c;
    c->fd[c->n_fd++] = fd;
    c->level = 6;
    deflateInit (&c->z_write, c->level);
    inflateInit (&c->z_read);
    return fd;
}

/* Record `level` as the connection's compression level.
   The z_write stream itself is deliberately left alone.  The peer inflates
   a single zlib stream for the whole connection (inflateInit() is called
   once, when the connection is set up).  Re-running deflateEnd() and
   deflateInit() here would start a new stream with a fresh zlib header
   mid-connection, which that inflate cannot accept.
   NOTE(review): applying the level to the live stream would need
   deflateParams() at a point where any bytes it emits get transmitted. */
static void z_socket_set_level (ZSocket * c, int level)
{
    c->level = level;
}

/* Adaptive compression heuristic, called after each compressed send/write
   with the measured write and deflate intervals (get_sys_time() units).
   Only writes of more than 100 bytes are sampled; failed writes pass
   len == -1.  Each sample updates exponentially smoothed averages
   (avg = (avg + sample) * 4/5).  On the 30th sampled write, and every 10th
   one after that, c->level is raised when writing dominates or lowered
   when deflating does, staying within 0..9.
   NOTE(review): c->level is assigned before z_socket_set_level() is called
   and nothing here touches the deflate stream directly; confirm whether
   z_write is meant to pick up the new level. */
static void adjust_compression (ZSocket * c, long write_time, long deflate_time, int len)
{
    if (len <= 100)
	return;
    /* get_sys_time() wraps periodically, which can make an interval
       negative; treat such samples as zero instead of skewing the average. */
    if (write_time < 0)
	write_time = 0;
    if (deflate_time < 0)
	deflate_time = 0;
    /* Multiply rather than `<< 2`: left-shifting a negative value is UB. */
    c->write_time = (c->write_time + write_time) * 4 / 5;
    c->deflate_time = (c->deflate_time + deflate_time) * 4 / 5;
    if (++c->n_writes < 30 || c->n_writes % 10 != 0)
	return;
    /* Rewind the counter so it never overflows; the next decision still
       comes 10 sampled writes from now. */
    c->n_writes = 20;
    if (c->write_time > c->deflate_time / 20) {
/* writing is bottlenecking - increase compression level */
	if (c->level != 9)
	    c->level++;
	z_socket_set_level (c, c->level);
#ifdef Z_DEBUG
	printf ("\nwt=%d, dt=%d, l=%d++", (int) c->write_time, (int) c->deflate_time, c->level);
#endif
    } else if (c->write_time < c->deflate_time / 50) {
/* compression is too slow - decrease compression level */
	if (c->level != 0)
	    c->level--;
	z_socket_set_level (c, c->level);
#ifdef Z_DEBUG
	printf ("\nwt=%d, dt=%d, l=%d--", (int) c->write_time, (int) c->deflate_time, c->level);
#endif
    } else {
#ifdef Z_DEBUG
	printf ("\nwt=%d, dt=%d, l=%d", (int) c->write_time, (int) c->deflate_time, c->level);
#endif
    }
}

/* Wall-clock time stamp in units of 100 microseconds, counted from the
   first call.  Only the low 17 bits of the elapsed seconds are kept, so
   the value wraps back towards 0 every 131072 s (about 36 hours) on every
   platform.  A difference of two readings taken across a wrap is negative.
   Returns 0 when gettimeofday() is not available. */
static long get_sys_time (void)
{
#ifdef HAVE_GETTIMEOFDAY
    static long start_sec = -1;
    struct timeval tv;

    /* POSIX leaves the behaviour of a non-null timezone argument
       unspecified, so none is passed. */
    gettimeofday (&tv, NULL);
    if (start_sec == -1)
	start_sec = tv.tv_sec;
    return ((tv.tv_sec - start_sec) & 0x1FFFFL) * 10000L + tv.tv_usec / 100;
#else
    return 0;
#endif
}

/* Debugging aid: switch the 0 below to 1 to delay every compressed frame in
   proportion to its size `l` (l / 50 microseconds), simulating a slow link.
   While disabled, sym_slow() is a no-op that does not evaluate its argument. */
#if 0
static void sym_slow (int l)
{
    usleep (l / 50);
}
#else
#define sym_slow(x)	((void) 0)
#endif

int z_socket_send (int s, const void *msg, int len, unsigned int flags)
{
    ZSocket *c;
    unsigned char *t;
    int l;
    unsigned char k[4];
    long deflate_time;
    long write_time;
    c = z_socket_index (s);
    if (!c || (flags & MSG_OOB))
	return send (s, msg, len, flags);
    if (!len)
	return 0;
    t = malloc (len + (len >> 8) + 32);
    c->z_write.next_in = (unsigned char *) msg;
    c->z_write.avail_in = len;
    c->z_write.next_out = t;
    c->z_write.avail_out = len + (len >> 8) + 32;
    deflate_time = get_sys_time ();
    deflate (&c->z_write, Z_FULL_FLUSH);
    write_time = get_sys_time ();
    deflate_time = write_time - deflate_time;
    l = (unsigned long) c->z_write.next_out - (unsigned long) t;
    k[0] = l >> 24;
    k[1] = l >> 16;
    k[2] = l >> 8;
    k[3] = l >> 0;
    sym_slow (l);
    if (send_all (s, k, 4, 0) != 4)
	len = -1;
    else if (send_all (s, t, l, 0) != l)
	len = -1;
    write_time = get_sys_time () - write_time;
    adjust_compression (c, write_time, deflate_time, len);
    free (t);
    return len;
}

/* Read exactly `len` bytes from descriptor `s` into `buf`, restarting
   read() after EINTR.  Returns len (0 when len <= 0), or -1 on error or
   when end-of-file arrives before `len` bytes were read. */
static int read_all (int s, unsigned char *buf, int len)
{
    unsigned char *p = buf;
    int left = len;

    while (left > 0) {
	int n = read (s, p, left);
	if (n < 0 && errno == EINTR)
	    continue;
	if (n <= 0)
	    return -1;
	p += n;
	left -= n;
    }
    return len - left;
}

/* Write all `len` bytes of `buf` to descriptor `s`, restarting write()
   after EINTR and continuing after partial writes.
   Returns len (0 when len <= 0), or -1 if write() fails or makes no progress. */
static int write_all (int s, unsigned char *buf, int len)
{
    int done = 0;

    for (;;) {
	int n;
	if (done >= len)
	    return done;
	n = write (s, buf + done, len - done);
	if (n == -1 && errno == EINTR)
	    continue;
	if (n <= 0)
	    return -1;
	done += n;
    }
}


int z_socket_write (int s, void *msg, int len)
{
    ZSocket *c;
    unsigned char *t;
    int l;
    unsigned char k[4];
    long deflate_time;
    long write_time;
    if (!len)
	return 0;
    c = z_socket_index (s);
    if (!c)
	return write (s, msg, len);
    t = malloc (len + (len >> 8) + 32);
    c->z_write.next_in = (unsigned char *) msg;
    c->z_write.avail_in = len;
    c->z_write.next_out = t;
    c->z_write.avail_out = len + (len >> 8) + 32;
    deflate_time = get_sys_time ();
    deflate (&c->z_write, Z_FULL_FLUSH);
    write_time = get_sys_time ();
    deflate_time = write_time - deflate_time;
    l = (unsigned long) c->z_write.next_out - (unsigned long) t;
    k[0] = l >> 24;
    k[1] = l >> 16;
    k[2] = l >> 8;
    k[3] = l >> 0;
    sym_slow (l);
    if (write_all (s, k, 4) != 4)
	len = -1;
    else if (write_all (s, t, l) != l)
	len = -1;
    write_time = get_sys_time () - write_time;
    adjust_compression (c, write_time, deflate_time, len);
    free (t);
    return len;
}

int z_socket_recv (int s, void *buf, int len, unsigned int flags)
{
    ZSocket *c;
    if (!len)
	return 0;
    c = z_socket_index (s);
    if (!c || (flags & MSG_OOB))	/* not opened with z_socket */
	return recv (s, buf, len, flags);
    if (!c->z_read.avail_in) {
	int l;
	unsigned char k[4];
	if (recv_all (s, k, 4, 0) != 4)
	    return -1;
	l = ((int) k[0] << 24) | ((int) k[1] << 16) | ((int) k[2] << 8) | ((int) k[3] << 0);
	c->buf = malloc (l);
	c->z_read.next_in = c->buf;
	c->z_read.avail_in = l;
	sym_slow (l);
	if (recv_all (s, (unsigned char *) c->buf, l, 0) != l) {
	    free (c->buf);
	    c->z_read.avail_in = 0;
	    return -1;
	}
    }
    c->z_read.next_out = (unsigned char *) buf;
    c->z_read.avail_out = len;
    if (flags & MSG_PEEK) {
	int result;
	z_stream z_save;
	inflateCopy (&z_save, &c->z_read);
	inflate (&z_save, Z_FULL_FLUSH);
	result = (unsigned long) z_save.next_out - (unsigned long) buf;
	inflateEnd (&z_save);
	return result;
    }
    inflate (&c->z_read, Z_FULL_FLUSH);
    if (!c->z_read.avail_in)
	free (c->buf);
    return (unsigned long) c->z_read.next_out - (unsigned long) buf;
}

int z_socket_read (int s, void *buf, int len)
{
    ZSocket *c;
    c = z_socket_index (s);
    if (!c)			/* not opened with z_socket */
	return read (s, buf, len);
    if (!len)
	return 0;
    if (!c->z_read.avail_in) {
	int l;
	unsigned char k[4];
	if (read_all (s, k, 4) != 4)
	    return -1;
	l = ((int) k[0] << 24) | ((int) k[1] << 16) | ((int) k[2] << 8) | ((int) k[3] << 0);
	c->buf = malloc (l);
	c->z_read.next_in = c->buf;
	c->z_read.avail_in = l;
	sym_slow (l);
	if (read_all (s, c->buf, l) != l) {
	    free (c->buf);
	    c->z_read.avail_in = 0;
	    return -1;
	}
    }
    c->z_read.next_out = (unsigned char *) buf;
    c->z_read.avail_out = len;
    inflate (&c->z_read, Z_FULL_FLUSH);
    if (!c->z_read.avail_in)
	free (c->buf);
    return (unsigned long) c->z_read.next_out - (unsigned long) buf;
}

/* Unbind descriptor `fd` from connection `c`.  When no descriptor remains
   bound, the following are released, after which `c` is unlinked from the
   `connections` list and freed:
   - both zlib streams;
   - any partially consumed input frame.
   NOTE(review): if `fd` is not bound to `c` the connection is still torn
   down (unchanged behaviour); the callers here always pass a descriptor
   that z_socket_index() matched to `c`. */
void z_socket_remove_connection (ZSocket * c, int fd)
{
    int i;

    for (i = 0; i < c->n_fd; i++) {
	if (c->fd[i] != fd)
	    continue;
	/* Close the gap.  Source and destination overlap, so this must be
	   memmove(); memcpy() on overlapping memory is undefined. */
	memmove (&c->fd[i], &c->fd[i + 1], (size_t) (c->n_fd - i - 1) * sizeof (c->fd[0]));
	c->n_fd--;
	if (c->n_fd)		/* still fds bound to this connection */
	    return;
	break;
    }
    /* The read paths free c->buf as soon as z_read.avail_in drops to 0, so
       the buffer is only live while compressed input is still pending. */
    if (c->z_read.avail_in)
	free (c->buf);
    /* z_read was created with inflateInit(), so it must be released with
       inflateEnd(); handing it to deflateEnd() is invalid. */
    inflateEnd (&c->z_read);
    deflateEnd (&c->z_write);
    if (c->next)
	c->next->prev = c->prev;
    if (c->prev)
	c->prev->next = c->next;
    if (c == connections)
	connections = c->next;
    free (c);
}

/* Bind `fd_new` (the result of dup()/dup2() on `fd`) to fd's compressed
   connection, if fd has one.
   Returns `fd_new` unchanged when it is negative, when `fd` is a plain
   socket, or when `fd_new` is already bound to the connection.
   Returns -1 with errno EMFILE, after closing `fd_new`, when the connection
   already holds MAX_FDS_PER_CONN descriptors. */
static int z_socket_xdup (int fd, int fd_new)
{
    ZSocket *c;
    int i;

    if (fd_new < 0)
	return fd_new;
    c = z_socket_index (fd);
    if (!c)
	return fd_new;
    /* dup2(fd, fd), or dup2() onto a descriptor already bound here, hands
       back a descriptor we already track.  Recording it twice would keep the
       connection alive after its last close (and, when full, would close
       a descriptor that is still in use). */
    for (i = 0; i < c->n_fd; i++)
	if (c->fd[i] == fd_new)
	    return fd_new;
    if (c->n_fd >= MAX_FDS_PER_CONN) {
	close (fd_new);
	errno = EMFILE;
	return -1;
    }
    c->fd[c->n_fd++] = fd_new;
    return fd_new;
}

/* dup2() wrapper that also binds `fd_new` to fd's compressed connection
   (if any); returns dup2()'s result or -1 from z_socket_xdup(). */
int z_socket_dup2 (int fd, int fd_new)
{
    return z_socket_xdup (fd, dup2 (fd, fd_new));
}

/* dup() wrapper that also binds the new descriptor to fd's compressed
   connection (if any); returns dup()'s result or -1 from z_socket_xdup(). */
int z_socket_dup (int fd)
{
    return z_socket_xdup (fd, dup (fd));
}

/* shutdown() replacement.  For a compressed connection, it records which
   directions were shut down: how 0 sets bit 0, how 1 sets bit 1, and how 2
   sets both.  Once both bits are set, `s` is unbound via
   z_socket_remove_connection().  The socket itself is always passed on to
   shutdown(), whose result is returned. */
int z_socket_shutdown (int s, int how)
{
    ZSocket *c = z_socket_index (s);

    if (c) {
	int bits;
	switch (how) {
	case 0:
	    bits = 1;
	    break;
	case 1:
	    bits = 2;
	    break;
	case 2:
	    bits = 3;
	    break;
	default:
	    bits = 0;
	    break;
	}
	c->shutdown |= bits;
	if (c->shutdown == 3)
	    z_socket_remove_connection (c, s);
    }
    return shutdown (s, how);
}

/* close() replacement: unbinds `s` from its compressed connection (the
   connection is freed once no descriptor uses it) and then closes the
   socket.  On WIN32 sockets are closed with closesocket() on both the plain
   and the compressed path; previously the compressed path used close(). */
int z_socket_close (int s)
{
    ZSocket *c = z_socket_index (s);

    if (c)
	z_socket_remove_connection (c, s);
#ifdef WIN32
    return closesocket (s);
#else
    return close (s);
#endif
}


