//----------------------------------*-C++-*----------------------------------//
// Copyright 1998 The Regents of the University of California. 
// All rights reserved. See LEGAL.LLNL for full text and disclaimer.
//---------------------------------------------------------------------------//

#ifndef __CXX_Extensions__h
#define __CXX_Extensions__h
#include "CXX_Objects.h"

// Prototype PyObject header. PythonType's constructor copies it by value over
// the PyObject head of each newly allocated PyTypeObject before filling in
// ob_type/ob_size. Only declared here (extern); its definition lives elsewhere.
extern "C" {
    extern PyObject py_object_initializer;
}

#include STANDARD_HEADER(vector)

NAMESPACE_BEGIN(Py)

class MethodTable {
    private:
        MethodTable(const MethodTable& m); //unimplemented
        void operator=(const MethodTable& m); //unimplemented
        
    protected:
        STD::vector<PyMethodDef> t; // accumulator of PyMethodDef's
        PyMethodDef *mt; // Actual method table produced when full
        
        static PyMethodDef method (
            const char* method_name, 
            PyCFunction f, 
            int flags = 1,
            const char* doc="") 
        {
            PyMethodDef m;
            m.ml_name = const_cast<char*>(method_name);
            m.ml_meth = f;
            m.ml_flags = flags;
            m.ml_doc = const_cast<char*>(doc);
            return m;
        }
    public:
        MethodTable() {
            t.push_back(method(0, 0, 0, 0));
            mt = 0;
        }
        
        virtual ~MethodTable() {
            delete [] mt;
        }
        
        void add(const char* method_name, PyCFunction f, const char* doc="", int flag=1) {
            if (!mt) {
                t.insert (t.end()-1, method(method_name, f, flag, doc));
            }
            else {
                throw RuntimeError("Too late to add a module method!");
            }
        }
        
        PyMethodDef* table() {    
            if (!mt) {
                int t1size = t.size();
                mt = new PyMethodDef[t1size];
                int j = 0;
                for(STD::vector<PyMethodDef>::iterator i = t.begin(); i != t.end(); i++) {
                    mt[j++] = *i;
                }
            }
            return mt;
        }
}; // end class MethodTable

class ExtensionModule {
private:
    char* module_name;
    MethodTable* method_table;
    ExtensionModule(const ExtensionModule&); //unimplemented
    void operator=(const ExtensionModule&); //unimplemented
    
public:
    ExtensionModule (char* name) {
        method_table = new MethodTable;
        module_name = name;
    }
    
    virtual ~ExtensionModule () {
        delete method_table;
    }
    
    void add(const char* method_name, PyCFunction f, const char* doc="", int flag=1) {
        method_table->add (method_name, f, doc, flag);
    }
    
    // Initialize returns the new module dictionary so you can add to it.
    Dict initialize () {
        // Both Py_InitModule and PyModule_GetDict return borrowed refs
        PyObject* pm = Py_InitModule(module_name, method_table->table());
        return Dict(PyModule_GetDict(pm));
    }
    
};


// Owns a PyTypeObject describing extension type T, plus the optional
// sequence/mapping/number/buffer slot tables, and offers one setter per slot.
template<class T>
class PythonType {
private:
    // Not copyable: the type object and slot tables are owned through raw
    // pointers. Declared but deliberately not defined.
    PythonType (const PythonType& tb); //unimplemented
    void operator=(const PythonType& t); //unimplemented
protected:
    PyTypeObject* table;                // owned; exposed via type_object()
    PySequenceMethods* sequence_table;  // owned; 0 until a sequence slot is set
    PyMappingMethods* mapping_table;    // owned; 0 until a mapping slot is set
    PyNumberMethods* number_table;      // owned; 0 until a number slot is set
    PyBufferProcs* buffer_table;        // owned; 0 until a buffer slot is set

    // Throws a C++ RuntimeError. The slot stubs below do not use it: they
    // are called from the interpreter's C code, through which a C++
    // exception must not propagate. Kept for source compatibility.
    static void missing_method() {
        throw RuntimeError("Extension object missing a required method.");
    }

    // Reports an unimplemented slot the way the C API expects: a pending
    // Python RuntimeError, followed by an error return from the stub.
    static void set_missing_method_error() {
        PyErr_SetString(PyExc_RuntimeError,
                        "Extension object missing a required method.");
    }

    // Default tp_dealloc. PythonExtension<T>::behaviors() replaces it with a
    // deallocator that runs the C++ destructor.
    static void standard_dealloc(PyObject* p) {
        PyMem_DEL(p);
    }

    // Placeholder implementations installed by init_*() for slots that must
    // be non-null once their table exists. Each sets a Python error and
    // returns the C-API error value (-1 for int results, 0/NULL for objects).
    static int missing_method_inquiry(PyObject*) {
        set_missing_method_error();
        return -1;
    }

    static PyObject* missing_method_unary(PyObject*) {
        set_missing_method_error();
        return 0;
    }

    static PyObject* missing_method_binary(PyObject*, PyObject*) {
        set_missing_method_error();
        return 0;
    }

    static PyObject* missing_method_ternary(PyObject*,PyObject*,PyObject*) {
        set_missing_method_error();
        return 0;
    }

    static PyObject* missing_method_intargfunc(PyObject*,int) {
        set_missing_method_error();
        return 0;
    }

    static PyObject* missing_method_intintargfunc(PyObject*,int,int) {
        set_missing_method_error();
        return 0;
    }

    static int missing_method_coercion(PyObject**,PyObject**) {
        set_missing_method_error();
        return -1;
    }

    static int missing_method_getreadbufferproc(PyObject*, int, void**) {
        set_missing_method_error();
        return -1;
    }

    static int missing_method_getwritebufferproc(PyObject*,int, void**) {
        set_missing_method_error();
        return -1;
    }

    static int missing_method_getsegcountproc(PyObject*,int*) {
        set_missing_method_error();
        return -1;
    }

    // The init_*() helpers allocate a slot table on first use. The "()"
    // value-initializes it, so every field is zeroed, including any slot
    // not named below. They then attach it to the type object and install
    // the stubs above for the required slots.
    void init_sequence() {
        if(!sequence_table) {
            sequence_table = new PySequenceMethods();
            table->tp_as_sequence = sequence_table;
            sequence_table->sq_length = missing_method_inquiry;
            sequence_table->sq_concat = missing_method_binary;
            sequence_table->sq_repeat = missing_method_intargfunc;
            sequence_table->sq_item = missing_method_intargfunc;
            sequence_table->sq_slice = missing_method_intintargfunc;
            sequence_table->sq_ass_item = 0;   // optional: read-only by default
            sequence_table->sq_ass_slice = 0;
        }
    }

    void init_mapping() {
        if(!mapping_table) {
            mapping_table = new PyMappingMethods();
            table->tp_as_mapping = mapping_table;
            mapping_table->mp_length = missing_method_inquiry;
            mapping_table->mp_subscript = missing_method_binary;
            mapping_table->mp_ass_subscript = 0; // optional
        }
    }

    void init_number() {
        if(!number_table) {
            number_table = new PyNumberMethods();
            table->tp_as_number = number_table;
            number_table->nb_add = missing_method_binary;
            number_table->nb_subtract = missing_method_binary;
            number_table->nb_multiply = missing_method_binary;
            number_table->nb_divide = missing_method_binary;
            number_table->nb_remainder = missing_method_binary;
            number_table->nb_divmod = missing_method_binary;
            number_table->nb_power = missing_method_ternary;
            number_table->nb_negative = missing_method_unary;
            number_table->nb_positive = missing_method_unary;
            number_table->nb_absolute = missing_method_unary;
            number_table->nb_nonzero = missing_method_inquiry;
            number_table->nb_invert = missing_method_unary;
            number_table->nb_lshift = 0;
            number_table->nb_rshift = 0;
            number_table->nb_and = missing_method_binary;
            number_table->nb_xor = missing_method_binary;
            number_table->nb_or = missing_method_binary;
            number_table->nb_coerce = missing_method_coercion;
            number_table->nb_int = missing_method_unary;
            number_table->nb_long = missing_method_unary;
            number_table->nb_float = missing_method_unary;
            number_table->nb_oct = missing_method_unary;
            number_table->nb_hex = missing_method_unary;
        }
    }

    void init_buffer() {
        if(!buffer_table) {
            buffer_table = new PyBufferProcs();
            table->tp_as_buffer = buffer_table;
            buffer_table->bf_getreadbuffer = missing_method_getreadbufferproc;
            buffer_table->bf_getwritebuffer = missing_method_getwritebufferproc;
            buffer_table->bf_getsegcount = missing_method_getsegcountproc;
        }
    }
public:
    // Setting any slot in a sequence/mapping/number/buffer group allocates
    // that group's table and fills its required slots with stubs that raise
    // RuntimeError when called. Supply real implementations for every slot
    // your type supports (the "ass" slots are optional).

    // itemsize becomes tp_itemsize; tp_basicsize is sizeof(T).
    PythonType (int itemsize = 0)
        : table(new PyTypeObject()),   // value-initialized: all fields zero
          sequence_table(0),
          mapping_table(0),
          number_table(0),
          buffer_table(0)
    {
        *reinterpret_cast<PyObject*>(table) = py_object_initializer;
        table->ob_type = &PyType_Type;
        table->ob_size = 0;
        // const_cast: tp_name is char* in older headers (matches name()).
        table->tp_name = const_cast<char*>("unknown");
        table->tp_basicsize = sizeof(T);
        table->tp_itemsize = itemsize;
        table->tp_dealloc = (destructor) standard_dealloc;
        table->tp_print = 0;
        table->tp_getattr = 0;
        table->tp_setattr = 0;
        table->tp_compare = 0;
        table->tp_repr = 0;
        table->tp_as_number = 0;
        table->tp_as_sequence = 0;
        table->tp_as_mapping =  0;
        table->tp_hash = 0;
        table->tp_call = 0;
        table->tp_str = 0;
        table->tp_getattro = 0;
        table->tp_setattro = 0;
        table->tp_as_buffer = 0;
        table->tp_flags = 0L;
        table->tp_doc = 0;
        table->tp_xxx5 = 0L;
        table->tp_xxx6 = 0L;
        table->tp_xxx7 = 0L;
        table->tp_xxx8 = 0L;

#ifdef COUNT_ALLOCS
        table->tp_alloc = 0;
        table->tp_free = 0;
        table->tp_maxalloc = 0;
        table->tp_next = 0;
#endif
    }
    virtual ~PythonType (){
        delete table;
        delete sequence_table;
        delete mapping_table;
        delete number_table;
        delete buffer_table;
    };

    PyTypeObject* type_object () const {return table;}

    // nam/d are stored, not copied; they must outlive the type object.
    void name (const char* nam) {
        table->tp_name = const_cast<char *>(nam);
    }

    void doc (const char* d) {
        table->tp_doc = const_cast<char *>(d);
    }

    void dealloc(void (*f)(PyObject*)) {
        table->tp_dealloc = f;
    }

    void print (int (*f)(PyObject*, FILE *, int)) {
        table->tp_print = f;
    }

    void getattr (PyObject* (*f)(PyObject*, char*)) {
        table->tp_getattr = f;
    }

    void setattr (int (*f)(PyObject*, char*, PyObject*)) {
        table->tp_setattr = f;
    }

    void getattro (PyObject* (*f)(PyObject*, PyObject*)) {
        table->tp_getattro = f;
    }

    void setattro (int (*f)(PyObject*, PyObject*, PyObject*)) {
        table->tp_setattro = f;
    }

    void compare (int (*f)(PyObject*, PyObject*)) {
        table->tp_compare = f;
    }

    void repr (PyObject* (*f)(PyObject*)) {
        table->tp_repr = f;
    }

    void str (PyObject* (*f)(PyObject*)) {
        table->tp_str = f;
    }

    void hash (long (*f)(PyObject*)) {
        table->tp_hash = f;
    }

    void call (PyObject* (*f)(PyObject*, PyObject*, PyObject*)) {
        table->tp_call = f;
    }

    // Sequence methods
    void sequence_length(int (*f)(PyObject*)) {
        init_sequence();
        sequence_table->sq_length = f;
    }

    void sequence_concat(PyObject* (*f)(PyObject*,PyObject*)) {
        init_sequence();
        sequence_table->sq_concat = f;
    }

    void sequence_repeat(PyObject* (*f)(PyObject*, int)) {
        init_sequence();
        sequence_table->sq_repeat = f;
    }

    void sequence_item(PyObject* (*f)(PyObject*, int)) {
        init_sequence();
        sequence_table->sq_item = f;
    }

    void sequence_slice(PyObject* (*f)(PyObject*, int, int)) {
        init_sequence();
        sequence_table->sq_slice = f;
    }

    void sequence_ass_item(int (*f)(PyObject*, int, PyObject*)) {
        init_sequence();
        sequence_table->sq_ass_item = f;
    }

    void sequence_ass_slice(int (*f)(PyObject*, int, int, PyObject*)) {
        init_sequence();
        sequence_table->sq_ass_slice = f;
    }
    // Mapping
    void mapping_length(int (*f)(PyObject*)) {
        init_mapping();
        mapping_table->mp_length = f;
    }

    void mapping_subscript(PyObject* (*f)(PyObject*, PyObject*)) {
        init_mapping();
        mapping_table->mp_subscript = f;
    }

    void mapping_ass_subscript(int (*f)(PyObject*, PyObject*, PyObject*)) {
        init_mapping();
        mapping_table->mp_ass_subscript = f;
    }

    // Number
    void number_nonzero (int (*f)(PyObject*)) {
        init_number();
        number_table->nb_nonzero = f;
    }

    void number_coerce (int (*f)(PyObject**, PyObject**)) {
        init_number();
        number_table->nb_coerce = f;
    }

    void number_negative (PyObject* (*f)(PyObject*)) {
        init_number();
        number_table->nb_negative = f;
    }
    void number_positive (PyObject* (*f)(PyObject*)) {
        init_number();
        number_table->nb_positive = f;
    }
    void number_absolute (PyObject* (*f)(PyObject*)) {
        init_number();
        number_table->nb_absolute = f;
    }
    void number_invert (PyObject* (*f)(PyObject*)) {
        init_number();
        number_table->nb_invert = f;
    }
    void number_int (PyObject* (*f)(PyObject*)) {
        init_number();
        number_table->nb_int = f;
    }
    void number_float (PyObject* (*f)(PyObject*)) {
        init_number();
        number_table->nb_float = f;
    }
    void number_long (PyObject* (*f)(PyObject*)) {
        init_number();
        number_table->nb_long = f;
    }
    void number_oct (PyObject* (*f)(PyObject*)) {
        init_number();
        number_table->nb_oct = f;
    }
    void number_hex (PyObject* (*f)(PyObject*)) {
        init_number();
        number_table->nb_hex = f;
    }

    void number_add (PyObject* (*f)(PyObject*, PyObject*)) {
        init_number();
        number_table->nb_add = f;
    }
    void number_subtract (PyObject* (*f)(PyObject*, PyObject*)) {
        init_number();
        number_table->nb_subtract = f;
    }
    void number_multiply (PyObject* (*f)(PyObject*, PyObject*)) {
        init_number();
        number_table->nb_multiply = f;
    }
    void number_divide (PyObject* (*f)(PyObject*, PyObject*)) {
        init_number();
        number_table->nb_divide = f;
    }
    void number_remainder (PyObject* (*f)(PyObject*, PyObject*)) {
        init_number();
        number_table->nb_remainder = f;
    }
    void number_divmod (PyObject* (*f)(PyObject*, PyObject*)) {
        init_number();
        number_table->nb_divmod = f;
    }
    void number_lshift (PyObject* (*f)(PyObject*, PyObject*)) {
        init_number();
        number_table->nb_lshift = f;
    }
    void number_rshift (PyObject* (*f)(PyObject*, PyObject*)) {
        init_number();
        number_table->nb_rshift = f;
    }
    void number_and (PyObject* (*f)(PyObject*, PyObject*)) {
        init_number();
        number_table->nb_and = f;
    }
    void number_xor (PyObject* (*f)(PyObject*, PyObject*)) {
        init_number();
        number_table->nb_xor = f;
    }
    void number_or (PyObject* (*f)(PyObject*, PyObject*)) {
        init_number();
        number_table->nb_or = f;
    }

    void number_power(PyObject* (*f)(PyObject*, PyObject*, PyObject*)) {
        init_number();
        number_table->nb_power = f;
    }
    // Buffer
    void buffer_getreadbuffer (int (*f)(PyObject*, int, void**)) {
        init_buffer();
        buffer_table->bf_getreadbuffer = f;
    }

    void buffer_getwritebuffer (int (*f)(PyObject*, int, void**)) {
        init_buffer();
        buffer_table->bf_getwritebuffer = f;
    }

    void buffer_getsegcount (int (*f)(PyObject*, int*)) {
        init_buffer();
        buffer_table->bf_getsegcount = f;
    }

}; // end of PythonType<T>
    
    
    // Class PythonExtension is what you inherit from to create
    // a new Python extension type. You give your class itself
    // as the template parameter (e.g. class Foo: public PythonExtension<Foo>).
    
    // There are two ways that extension objects can get destroyed.
    // 1. Their reference count goes to zero
    // 2. Someone does an explicit delete on a pointer.
    // In (1) the problem is to get the destructor called 
    //        We register a special deallocator in the Python type object
    //        (see behaviors()) to do this.
    // In (2) there is no problem, the dtor gets called.

    // PythonExtension does not use the usual Python heap allocator, 
    // instead using new/delete. We do the setting of the type object
    // and reference count, usually done by PyObject_New, in the 
    // base class ctor.

    // This special deallocator does a delete on the pointer.

   
    template<class T>
        class PythonExtension: public PyObject 
    {
    private:
        static void extension_object_deallocator (PyObject* t) {
            delete (T*)(t);
        }

        explicit PythonExtension(const PythonExtension<T>& other);
        void operator=(const PythonExtension<T>& rhs);
    protected:
        explicit PythonExtension()
        {
            ob_refcnt = 1;
            ob_type = type_object();
        }
    public:
        virtual ~PythonExtension() {} 
        
        static PythonType<T>& behaviors() {
            static PythonType<T>* p;
            if(!p) 
            {
                p = new PythonType<T>();
                p->dealloc(extension_object_deallocator);
            }
            
            return *p;
        }
        
        static MethodTable& methods() {
            static MethodTable* p;
            if(!p) p = new MethodTable;
            return *p;
        }
        
        static PyTypeObject* type_object() {
            return behaviors().type_object();
        }
        
        static int check (PyObject *p) {
            // is p like me?
            return p->ob_type == type_object();
        }
        
        static int check (const Object& ob) {
            return check(ob.ptr());
        }
        
        PyObject* getattr(char* name) {
            return Py_FindMethod(methods().table(), static_cast<PyObject*>(this), name);
        }
        
    };
    
    // ExtensionObject<T> is an Object that accepts only Python objects for which T::check() succeeds.
    template<class T>
        class ExtensionObject: public Object {
    public:
        
        explicit ExtensionObject (PyObject *pyob): Object(pyob) {
            validate();
        }
        
        ExtensionObject(const ExtensionObject<T>& other): Object(*other) {
            validate();
        }
        
        ExtensionObject& operator= (const Object& rhs) {
            return (*this = *rhs);
        }
        
        ExtensionObject& operator= (PyObject* rhsp) {
            if(ptr() == rhsp) return *this;
            set(rhsp);
            return *this;
        }
        
        virtual bool accepts (PyObject *pyob) const {
            return (pyob && T::check(pyob));
        }       
    };
    
    NAMESPACE_END
        // End of CXX_Extensions.h
#endif
        
