from brian import *
from statements import *
from expressions import *
from codeitems import *
from blocks import *
from languages import *
from dependencies import *
from formatting import *

# Public API of this module. Class hierarchy of the exported symbols:
#   Symbol
#     RuntimeSymbol, ArraySymbol, SliceIndex, ArrayIndex
#       NeuronGroupStateVariableSymbol (subclass of ArraySymbol)
__all__ = [
    'language_invariant_symbol_method',
    'Symbol',
    'RuntimeSymbol',                    # Symbol subclass
    'ArraySymbol',                      # Symbol subclass
    'NeuronGroupStateVariableSymbol',   # ArraySymbol subclass
    'SliceIndex',                       # Symbol subclass
    'ArrayIndex',                       # Symbol subclass
    'get_neuron_group_symbols',
    ]


def _describe_callable(func):
    '''
    Return a readable name for ``func``: ``Class.name`` for Python 2 bound or
    unbound methods (which carry ``im_class``), otherwise ``__qualname__``
    when available (Python 3), falling back to ``__name__``.
    '''
    im_class = getattr(func, 'im_class', None)
    if im_class is not None:
        return im_class.__name__ + '.' + func.__name__
    return getattr(func, '__qualname__', func.__name__)


def _generated_method_doc(langs, fallback):
    '''
    Build the default docstring for a method produced by
    :func:`language_invariant_symbol_method`, listing each language and the
    method it dispatches to (sorted by language name so the output is
    deterministic), plus the fallback if one is given.
    '''
    lines = ['',
             'Method generated by :func:`language_invariant_symbol_method`.',
             '',
             'Languages and methods follow:',
             '']
    for langname in sorted(langs):
        lines.append('``{0}``'.format(langname))
        lines.append('    :meth:`{0}`'.format(langs[langname].__name__))
    if fallback is not None:
        lines.append('``fallback``')
        lines.append('    :meth:`{0}`'.format(_describe_callable(fallback)))
    return '\n'.join(lines)


def language_invariant_symbol_method(basemethname, langs, fallback=None,
                                     doc=None):
    '''
    Helper function to create methods for :class:`Symbol` classes.
    
    Sometimes it is clearer to write a separate method for each language the
    :class:`Symbol` supports. This function can generate a method that can take
    any language, and calls the desired method. For example, if you had defined
    two methods ``load_python`` and ``load_c`` then you would define the
    ``load`` method as follows::
    
        load = language_invariant_symbol_method('load',
            {'python':load_python, 'c':load_c})

    ``basemethname``
        The ``__name__`` given to the generated method.
    ``langs``
        Dict mapping language names (``self.language.name``) to methods.
    ``fallback``
        A method to call if no language-specific method was found. If there
        is no entry for the language and no ``fallback``, the generated
        method returns ``None``.
    ``doc``
        Docstring for the generated method. If omitted, one listing the
        languages and their methods is generated automatically.
    '''
    def meth(self, *args, **kwds):
        langname = self.language.name
        if langname in langs:
            return langs[langname](self, *args, **kwds)
        if fallback is not None:
            return fallback(self, *args, **kwds)
        # No implementation for this language: return None, as callers
        # have always received in this case.
        return None
    meth.__name__ = basemethname
    if doc is None:
        try:
            doc = _generated_method_doc(langs, fallback)
        except Exception:
            # Docstring generation is best-effort (e.g. a dispatch target
            # without a __name__); never let it break class creation.
            doc = '\n'.join([
                '',
                'Method generated by :func:`language_invariant_symbol_method`.',
                '',
                'Automatic generation of language list failed.',
                ])
    meth.__doc__ = doc
    return meth
        
class Symbol(object):
    '''
    Common base class of every symbol used in code generation.
    
    An instance stores ``name`` (a string) and ``language`` (a
    :class:`Language` object). Subclasses customise how the symbol is
    resolved by overriding any of the hook methods below; the defaults do
    nothing beyond using the bare name.
    '''
    supported_languages = []

    def __init__(self, name, language):
        self.name = name
        self.language = language
        if self.supported():
            return
        raise NotImplementedError(
                "Language "+language.name+" not supported for symbol "+name)

    def supported(self):
        '''
        Tell whether the language passed to the constructor is handled.
        
        The default looks ``self.language.name`` up in the class-level list
        ``supported_languages``; subclasses may replace this test.
        '''
        return self.language.name in self.supported_languages

    def update_namespace(self, read, write, vectorisable, namespace):
        '''
        Hook invoked by :meth:`resolve` before any code is built; override it
        to put data into ``namespace``. Does nothing by default.
        '''

    def load(self, read, write, vectorisable):
        '''
        Hook invoked by :meth:`resolve` for code placed before the item.
        Returns an empty :class:`Block` unless overridden.
        '''
        return Block()

    def save(self, read, write, vectorisable):
        '''
        Hook invoked by :meth:`resolve` for code placed after the item.
        Returns an empty :class:`Block` unless overridden.
        '''
        return Block()

    def read(self):
        '''
        Text emitted wherever the symbol's value is read; ``name`` by default.
        '''
        return self.name

    def write(self):
        '''
        Text emitted wherever the symbol is assigned to; ``name`` by default.
        '''
        return self.name

    def resolve(self, read, write, vectorisable, item, namespace):
        '''
        Return a new code item in which this symbol has been resolved.
        
        As an illustration, starting from the statement::
        
            x += 1
            
        producing the C++ code::
        
            for(int i=0; i<n; i++)
            {
                double &x = __arr_x[i];
                x += 1;
            }
        
        means wrapping ``x+=1`` inside a loop.
        
        Arguments:
        
        ``read``
            True if the symbol's value is read; determined by the main
            :func:`resolve` function from the dependency analysis.
        ``write``
            True if a value is assigned to the symbol.
        ``vectorisable``
            True if the item may be vectorised. Python can only vectorise a
            single multi-valued index, so with two or more only the innermost
            loop is vectorised.
        ``item``
            The code item to resolve.
        ``namespace``
            Namespace that receives any data.
            
        By default this calls :meth:`update_namespace`, then builds a
        :class:`Block` from :meth:`load`'s result, ``item`` and :meth:`save`'s
        result (in that order), and finally records this symbol's name in the
        block's ``resolved`` set.
        '''
        self.update_namespace(read, write, vectorisable, namespace)
        prologue = self.load(read, write, vectorisable)
        epilogue = self.save(read, write, vectorisable)
        wrapped = Block(prologue, item, epilogue)
        wrapped.resolved = wrapped.resolved.union([self.name])
        return wrapped

    def dependencies(self):
        '''
        Set of dependencies of the symbol; empty unless overridden.
        '''
        return set()

    def resolution_requires_loop(self):
        '''
        Whether resolving this symbol introduces a loop. Used by
        :func:`resolve` to choose a good resolution order. ``False`` here.
        '''
        return False

    def multi_valued(self):
        '''
        Whether the symbol takes several values, e.g. a loop index. In::
        
            for(int i=0; i<n; i++)
            {
                double &x = arr[i];
                ...
            }
        
        ``i`` is multi-valued (its resolution needs a loop) while ``x`` is
        single-valued and depends on ``i``. Returns the class/instance
        attribute ``multiple_values`` when present, otherwise ``False``.
        '''
        return getattr(self, 'multiple_values', False)


class RuntimeSymbol(Symbol):
    '''
    Symbol whose value the execution context itself places in the namespace
    at runtime (for instance ``t`` or ``dt``), so its name can be used as-is.
    Loading, saving and reading/writing all use the :class:`Symbol` defaults.
    '''
    def supported(self):
        '''
        Accept every language unconditionally (always ``True``).
        '''
        return True


class ArraySymbol(Symbol):
    '''
    Symbol whose value is one element of an array.
    
    Schematically: ``name = arr[index]``.
    
    ``arr`` (numpy array)
        Array providing the values.
    ``name``, ``language``
        Symbol name and language.
    ``index``
        Name of the index; defaults to ``'_index_'+name``.
    ``array_name``
        Name under which the array is exposed; defaults to ``'_arr_'+name``.
        
    Loading in Python declares read-dependencies on both ``index`` and
    ``array_name``; :meth:`dependencies` reports the one on ``index``.
    '''
    supported_languages = ['python', 'c', 'gpu']

    def __init__(self, arr, name, language, index=None, array_name=None):
        self.arr = arr
        self.index = '_index_'+name if index is None else index
        self.array_name = '_arr_'+name if array_name is None else array_name
        Symbol.__init__(self, name, language)

    # ---- Python ----
    def load_python(self, read, write, vectorisable):
        '''
        Without ``read``, return an empty :class:`Block` that only carries
        the dependencies. With ``read``, return a :class:`CodeStatement`::
        
            name = array_name[index]
        '''
        deps = set([Read(self.index), Read(self.array_name)])
        if read:
            statement = '{0} = {1}[{2}]'.format(self.name, self.array_name,
                                                self.index)
            return CodeStatement(statement, deps, set())
        placeholder = Block()
        placeholder.dependencies = deps
        return placeholder

    def write_python(self):
        '''
        Assignment target ``array_name[index]``.
        '''
        return ''.join([self.array_name, '[', self.index, ']'])

    # ---- C ----
    def load_c(self, read, write, vectorisable):
        '''
        Define the value with :class:`CDefineFromArray` (a const copy, not a
        reference).
        '''
        # Values are copied into const locals rather than bound as
        # references; writes therefore go straight to the array through
        # write_c (see the ``write`` dispatch below).
        return CDefineFromArray(self.name, self.array_name, self.index,
                                const=True, reference=False,
                                dtype=self.arr.dtype)

    def write_c(self):
        '''
        Assignment target ``array_name[index]`` (needed because load_c does
        not use references).
        '''
        return ''.join([self.array_name, '[', self.index, ']'])

    # ---- All languages ----
    def update_namespace(self, read, write, vectorisable, namespace):
        '''
        Store ``arr`` in ``namespace`` under ``array_name``.
        '''
        namespace[self.array_name] = self.arr

    def dependencies(self):
        '''
        A single read-dependency on ``index``.
        '''
        return set([Read(self.index)])

    # If load_c switched to references, 'c'/'gpu' writes could instead fall
    # back to Symbol.write; as written, every language writes to the array.
    write = language_invariant_symbol_method('write',
        {'python': write_python, 'c': write_c, 'gpu': write_c})
    load = language_invariant_symbol_method('load',
        {'python': load_python, 'c': load_c, 'gpu': load_c})


class SliceIndex(Symbol):
    '''
    Multi-valued symbol running over a contiguous range.
    
    Schematically: ``name = slice(start, end)``
    
    ``name``, ``language``
        Symbol name and language.
    ``start``
        First value (integer or string).
    ``end``
        One past the last value (integer or string).
    ``all``
        ``True`` when the range is known to span everything, which lets the
        vectorised Python form skip building a slice statement.
    '''
    supported_languages = ['python', 'c', 'gpu']
    multiple_values = True

    def __init__(self, name, start, end, language, all=False):
        self.start, self.end, self.all = start, end, all
        Symbol.__init__(self, name, language)

    # ---- Python ----
    def resolve_python(self, read, write, vectorisable, item, namespace):
        '''
        Not vectorisable: wrap ``item`` in a Python for loop over
        ``xrange(start, end)``.
        
        Vectorisable and ``all``: store ``slice(None)`` in the namespace under
        ``name`` and return ``item`` untouched.
        
        Vectorisable otherwise: put this statement in front of ``item``::
        
            name = slice(start, end)
        '''
        if not vectorisable:
            loop_range = 'xrange({start}, {end})'.format(start=self.start,
                                                         end=self.end)
            return PythonForBlock(self.name, loop_range, item)
        if self.all:
            namespace[self.name] = slice(None)
            return item
        statement = '{name} = slice({start}, {end})'.format(
            name=self.name, start=self.start, end=self.end)
        return Block(CodeStatement(statement, self.dependencies(), set()),
                     item)

    # ---- C ----
    def resolve_c(self, read, write, vectorisable, item, namespace):
        '''
        Wrap ``item`` in a C for loop from ``start`` up to (excluding) ``end``.
        '''
        loop_spec = 'int {0}={1}; {0}<{2}; {0}++'.format(self.name,
                                                         self.start, self.end)
        return CForBlock(self.name, loop_spec, item)

    # ---- GPU ----
    def resolve_gpu(self, read, write, vectorisable, item, namespace):
        '''
        Not vectorisable: same as :meth:`resolve_c`. Vectorisable: record
        ``_gpu_vector_index = name`` and ``_gpu_vector_slice = (start, end)``
        in the namespace and return ``item`` unchanged, leaving the GPU code
        to build the index later.
        '''
        if not vectorisable:
            return self.resolve_c(read, write, vectorisable, item, namespace)
        # Generation of the GPU vector index is deferred to the Code object.
        namespace['_gpu_vector_index'] = self.name
        namespace['_gpu_vector_slice'] = (self.start, self.end)
        return item

    # ---- All languages ----
    def resolution_requires_loop(self):
        '''
        ``False`` for Python, ``True`` for every other language.
        '''
        return self.language.name != 'python'

    resolve = language_invariant_symbol_method('resolve',
        {'c': resolve_c, 'python': resolve_python, 'gpu': resolve_gpu})

class ArrayIndex(Symbol):
    '''
    Multi-valued symbol giving an index that iterates through an array.
    
    Schematically: ``name = array_name[array_slice]``
    
    ``name``, ``language``
        Symbol name and language.
    ``array_name``
        The name of the array we iterate through.
    ``array_len``
        The length of the array (int or string), by default has value
        ``'_len_'+array_name``.
    ``index_name``
        The name of the index into the array, by default has value
        ``'_index_'+array_name``.
    ``array_slice``
        A pair ``(start, end)`` giving a slice of the array; if left as
        ``None`` the whole array will be used.
        
    The vectorised Python form declares a read-dependency on ``array_name``.
    :meth:`dependencies` is not overridden, so the symbol itself reports no
    other dependencies.
    '''
    supported_languages = ['python', 'c', 'gpu']
    multiple_values = True
    def __init__(self, name, array_name, language, array_len=None,
                 index_name=None, array_slice=None):
        if index_name is None:
            index_name = '_index_'+array_name
        if array_len is None:
            array_len = '_len_'+array_name
        self.array_name = array_name
        self.array_len = array_len
        self.index_name = index_name
        self.array_slice = array_slice
        Symbol.__init__(self, name, language)
    def _slice_bounds(self):
        '''
        Return ``(start, end)``: the pair from ``array_slice`` if given,
        otherwise ``('0', array_len)``. Shared by the C and GPU resolvers.
        '''
        if self.array_slice is None:
            return '0', self.array_len
        start, end = self.array_slice
        return start, end
    # Python implementation
    def resolve_python(self, read, write, vectorisable, item, namespace):
        '''
        If vectorisable it will prepend one of these two forms to ``item``::
        
            name = array_name
            name = array_name[start:end]
            
        (where ``(start, end) = array_slice`` if provided).
        
        If not vectorisable, it will return a for loop over either
        ``array_name`` or ``array_name[start:end]``.
        '''
        if vectorisable:
            code = '{name} = {array_name}'
            start, end = '', ''
            if self.array_slice is not None:
                start, end = self.array_slice
                code += '[{start}:{end}]'
            code = code.format(start=start, end=end,
                name=self.name, array_name=self.array_name)
            return Block(
                CodeStatement(code, set([Read(self.array_name)]), set()),
                item)
        if self.array_slice is None:
            return PythonForBlock(self.name, self.array_name, item)
        start, end = self.array_slice
        container = '{array_name}[{start}:{end}]'.format(
            array_name=self.array_name, start=start, end=end)
        return PythonForBlock(self.name, container, item)
    # C implementation
    def resolve_c(self, read, write, vectorisable, item, namespace):
        '''
        Returns a C++ for loop of the form::
        
            for(int index_name=start; index_name<end; index_name++)
            {
                const int name = array_name[index_name];
                ...
            }
        
        If defined ``(start, end)=array_slice`` otherwise
        ``(start, end)=(0, array_len)``.
        '''
        start, end = self._slice_bounds()
        spec = 'int {index_name}={start}; {index_name}<{end}; {index_name}++'
        spec = spec.format(
            index_name=self.index_name, start=start, end=end)
        block = Block(
            CDefineFromArray(self.name, self.array_name, self.index_name,
                             dtype=int, reference=False, const=True),
            item)
        return CForBlock(self.name, spec, block)
    # GPU implementation
    def resolve_gpu(self, read, write, vectorisable, item, namespace):
        '''
        If not vectorisable, use :meth:`resolve_c`. If vectorisable, we set
        the following in the namespace::
        
            _gpu_vector_index = index_name
            _gpu_vector_slice = (start, end)
            
        Where ``start`` and ``end`` are as in :meth:`resolve_c`. This marks
        that we want to vectorise over this index, and the GPU code will handle
        this later. Finally, we prepend the item with::
        
            const int name = array_name[index_name];
        '''
        if not vectorisable:
            return self.resolve_c(read, write, vectorisable, item, namespace)
        # The bounds are computed locally: this class has no start/end
        # attributes (reading self.start/self.end raised AttributeError).
        start, end = self._slice_bounds()
        # This just defers the generation of the GPU vector index to the
        # Code object
        namespace['_gpu_vector_index'] = self.index_name
        namespace['_gpu_vector_slice'] = (start, end)
        return Block(
            CDefineFromArray(self.name, self.array_name, self.index_name,
                             dtype=int, reference=False, const=True),
            item)
    # Language invariant implementation
    def resolution_requires_loop(self):
        '''
        Returns ``True`` except for Python.
        '''
        return self.language.name!='python'
    # 'gpu' must be listed: it is in supported_languages and there is no
    # fallback, so without it resolve() would silently return None.
    resolve = language_invariant_symbol_method('resolve',
        {'c':resolve_c, 'python':resolve_python, 'gpu':resolve_gpu})


class NeuronGroupStateVariableSymbol(ArraySymbol):
    '''
    :class:`ArraySymbol` whose array is obtained from a group via
    ``group.state_(varname)``.
    
    Arguments:
    
    ``group``
        The :class:`NeuronGroup` holding the variable.
    ``varname``
        Name of the state variable inside ``group``.
    ``name``, ``language``
        Symbol name and language.
    ``index``
        Index name; ``None`` keeps the :class:`ArraySymbol` default.
    '''
    def __init__(self, group, varname, name, language, index=None):
        self.group, self.varname = group, varname
        ArraySymbol.__init__(self, group.state_(varname), name, language,
                             index=index)

def get_neuron_group_symbols(group, language, index='_neuron_index',
                             prefix=''):
    '''
    Build a dict of :class:`NeuronGroupStateVariableSymbol` objects, one per
    name in ``group._eqs._diffeq_names``, keyed by ``prefix`` + that name.
    
    Arguments:
    
    ``group``
        Group whose state variables become symbols.
    ``language``
        Language passed to every symbol.
    ``index``
        Neuron index name shared by all symbols (``_neuron_index`` unless
        given).
    ``prefix``
        String prepended to each variable name to form the symbol name (and
        dict key); empty by default.
    '''
    symbols = {}
    for varname in group._eqs._diffeq_names:
        symname = prefix+varname
        symbols[symname] = NeuronGroupStateVariableSymbol(
            group, varname, symname, language, index=index)
    return symbols
