import copy, weakref
import numpy

import cuda
from variable import Variable

class Function(object):
    """Function of variable(s) to variable(s) that leaves footprint to the
    output variables on application.

    All function implementations defined in :mod:`chainer.functions` inherit
    this class.

    The main feature of this class is keeping track of function applications as
    a backward graph. When a function is applied to :class:`Variable` objects,
    the function is copied, and its :meth:`forward` method is called on
    :data:`~Variable.data` fields of input variables, and at the same time it
    chains references from output variables to the function and from the
    function to its inputs.

    .. note::

       Strictly speaking, when a function is applied to some variable, a special
       :class:`Function` object called *splitter* is inserted between the
       variable and the function. The splitter is used to manipulate multiple
       function applications on the same variable, where gradients from
       different backward paths are accumulated at the variable.

    .. note::

       :meth:`__call__` copies the function instance before the forward
       computation and chaining. This enables us to reuse one function object
       for multiple function applications, where the different calls must use
       different references to the function object. Note that the copy is
       shallow, so implementations of :class:`Function` must take care of any
       member attributes shared across forward and backward computations.

    .. admonition:: Example

       Let ``x`` be an instance of :class:`Variable` and ``f`` an instance of
       :class:`Function` taking only one argument. Then a line

       >>> y = f(x)

       computes a new variable ``y`` and creates backward references. Actually,
       backward references are set as per the following diagram::

           x <--- (splitter) <--- x' <--- f' <--- y

       where prime "'" indicates a copy of the original object. If another
       application of the function occurs as

       >>> z = f(x)

       then the splitter acts like a branch as the following new diagram::

                               |--- x'  <--- f'  <--- y
           x <--- (splitter) <-+
                               |--- x'' <--- f'' <--- z

       Note that the splitter is implicitly inserted and the user does not
       need to take any special care of it; just remember that such branching
       is correctly managed by chainer.

    Every function implementation should provide :meth:`forward_cpu`,
    :meth:`forward_gpu`, :meth:`backward_cpu` and :meth:`backward_gpu`.
    Alternatively, one can provide :meth:`forward` and :meth:`backward` instead
    of separate methods. Backward methods have default implementations that
    just return a tuple of ``None`` (one per input), which indicates that the
    function is non-differentiable.

    Function implementations are classified into two types: parameterized ones
    and non-parameterized ones. A parameterized function holds parameter arrays
    and corresponding gradient arrays. Implementations can choose any way to
    keep these arrays, but it is recommended to keep them as attributes to
    easily migrate between CPU and GPU. Parameterized functions must provide
    accessors to these arrays called :meth:`parameters` and :meth:`gradients`.

    Attributes:
        inputs: A tuple or list of input variables.
        outputs: A tuple or list of weak references to the output variables.
        rank: Position of the function in the backward graph: the maximum
            ``rank`` of its input variables, or 0 if it has no inputs. It is
            ``None`` until the function is applied with graph building.
        parameter_names: A tuple or list of names of parameter attributes.
            It is set to an empty tuple by default. This attribute is used by
            the default implementation of :meth:`parameters` property to gather
            the collection of parameter arrays. Implementation of parameterized
            function should override this field as an attribute or a property,
            or otherwise it should override :meth:`parameters` property.
        gradient_names: A tuple or list of names of gradient attributes. The
            detail is same as :data:`parameter_names`.

    """
    parameter_names = ()
    gradient_names = ()

    def __init__(self):
        self.inputs = None
        self.outputs = None
        self.rank = None

    def __call__(self, *inputs):
        """Applies forward propagation on input variables with chaining backward
        reference.

        Basic behavior is also expressed in documentation of :class:`Function`
        class. This function first copies itself to avoid conflict over multiple
        invocations.

        .. note::

           If the :data:`~Variable.data` attribute of input variables reside on
           GPU device, then, before it calls :meth:`forward` method, the
           appropriate device is selected, so in most cases implementor does not
           need to take care of device selection.

        Args:
            inputs: Tuple of input :class:`Variable` objects. All input
                variables must have same volatile flag.

        Returns:
            One
            :class:`Variable` object or a tuple of multiple
            :class:`Variable` objects.

        """
        # First copy itself to avoid duplication within the graph.
        self = copy.copy(self)

        if any(x.volatile for x in inputs):
            # Volatile mode: compute outputs without building the graph.
            # Mixing volatile and non-volatile inputs is not supported.
            assert all(x.volatile for x in inputs)

            in_data = tuple(x.data for x in inputs)
            with cuda.using_device(*in_data):
                out_data = self.forward(in_data)
            assert isinstance(out_data, tuple)

            # Return a tuple (not a list) so both branches honor the same
            # documented return type.
            outputs = tuple(Variable(y, volatile=True) for y in out_data)
            if len(outputs) == 1:
                return outputs[0]
            return outputs

        # Build graph.
        # Be careful that forward references must be weak.
        self.inputs = []
        for x in inputs:
            splitter = x.splitter()
            if splitter is None:
                # No live splitter for this variable yet: create one. The
                # variable only holds a weak reference to it; presumably the
                # splitter is kept alive by the branch variables it creates
                # (add_branch sets it as their creator).
                splitter = Split(x)
                x.splitter = weakref.ref(splitter)
            # Each application consumes its own branch of the variable.
            self.inputs.append(splitter.add_branch())

        self.rank = max(x.rank for x in self.inputs) if self.inputs else 0

        in_data = tuple(x.data for x in self.inputs)
        with cuda.using_device(*in_data):
            outputs = self.forward(in_data)
        assert isinstance(outputs, tuple)

        ret = tuple(Variable(y) for y in outputs)
        for y in ret:
            y.set_creator(self)

        # Make forward references weak so that outputs are not kept alive by
        # the function that created them.
        self.outputs = tuple(weakref.ref(y) for y in ret)

        if len(ret) == 1:
            return ret[0]
        return ret

    def forward(self, inputs):
        """Applies forward propagation to input arrays.

        It delegates the procedure to :meth:`forward_cpu` or :meth:`forward_gpu`
        by default. Which it selects is determined by the type of input arrays.
        Implementations of :class:`Function` must implement either cpu/gpu
        methods or this method.

        Args:
            inputs: Tuple of input array(s).

        Returns:
            Tuple of output array(s).

        .. warning::

            Implementations of :class:`Function` must take care that the
            return value must be a tuple even if it returns only one array.

        """
        if any(isinstance(x, cuda.GPUArray) for x in inputs):
            return self.forward_gpu(inputs)
        else:
            return self.forward_cpu(inputs)

    def forward_cpu(self, inputs):
        """Applies forward propagation to input arrays on CPU.

        Args:
            inputs: Tuple of :class:`~numpy.ndarray` object(s).

        Returns:
            tuple: Tuple of :class:`~numpy.ndarray` object(s).

        .. warning::

            Implementations of :class:`Function` must take care that the
            return value must be a tuple even if it returns only one array.

        """
        raise NotImplementedError()

    def forward_gpu(self, inputs):
        """Applies forward propagation to input arrays on GPU.

        Args:
            inputs: Tuple of :class:`~pycuda.gpuarray.GPUArray` object(s).

        Returns:
            tuple: Tuple of :class:`~pycuda.gpuarray.GPUArray` object(s).

        .. warning::

            Implementations of :class:`Function` must take care that the
            return value must be a tuple even if it returns only one array.

        """
        raise NotImplementedError()

    def backward(self, inputs, grad_outputs):
        """Applies backprop to output gradient arrays.

        It delegates the procedure to :meth:`backward_cpu` or
        :meth:`backward_gpu` by default. Which it selects is determined by the
        type of input arrays and output gradient arrays. Implementations of
        :class:`Function` must implement either cpu/gpu methods or this method,
        if the function is intended to be backprop-ed.

        Args:
            inputs: Tuple of input arrays.
            grad_outputs: Tuple of output gradient arrays.

        Returns:
            tuple: Tuple of input gradient arrays. Some or all of them can be
            ``None``, if the function is not differentiable on
            inputs.

        .. warning::

            Implementations of :class:`Function` must take care that the
            return value must be a tuple even if it returns only one array.

        """
        # Check each sequence separately instead of concatenating them, so a
        # tuple/list mix does not raise TypeError.
        on_gpu = (any(isinstance(x, cuda.GPUArray) for x in inputs) or
                  any(isinstance(gy, cuda.GPUArray) for gy in grad_outputs))
        if on_gpu:
            return self.backward_gpu(inputs, grad_outputs)
        else:
            return self.backward_cpu(inputs, grad_outputs)

    def backward_cpu(self, inputs, grad_outputs):
        """Applies backprop to output gradient arrays on CPU.

        Args:
            inputs: Tuple of input :class:`~numpy.ndarray` object(s).
            grad_outputs: Tuple of output gradient :class:`~numpy.ndarray`
                object(s).

        Returns:
            tuple: Tuple of input gradient :class:`~numpy.ndarray` object(s).
            Some or all of them can be ``None``, if the function is not
            differentiable on corresponding inputs.

        .. warning::

            Implementations of :class:`Function` must take care that the
            return value must be a tuple even if it returns only one array.

        """
        return tuple(None for _ in inputs)

    def backward_gpu(self, inputs, grad_outputs):
        """Applies backprop to output gradient arrays on GPU.

        Args:
            inputs: Tuple of input :class:`~pycuda.gpuarray.GPUArray` object(s).
            grad_outputs: Tuple of output gradient
                :class:`~pycuda.gpuarray.GPUArray` object(s).

        Returns:
            tuple: Tuple of input gradient :class:`~pycuda.gpuarray.GPUArray`
            object(s). Some or all of them can be ``None``, if the function is
            not differentiable on corresponding inputs.

        .. warning::

            Implementations of :class:`Function` must take care that the
            return value must be a tuple even if it returns only one array.

        """
        return tuple(None for _ in inputs)

    def unchain(self):
        """Purges in/out variables and removes this function from the backward
        graph.

        This method is called from :meth:`Variable.unchain_backward` method.
        It is safe to call it more than once, or on a function that was never
        applied with graph building.

        """
        for y in self.outputs or ():
            y_ref = y()
            if y_ref is not None:
                y_ref.creator = None
        for x in self.inputs or ():
            # The lambda is unreferenced once the call returns, so this is a
            # weak reference that already points to a collected object (dead).
            x.splitter = weakref.ref(lambda: 0)
        self.inputs = None

    def to_gpu(self, device=None):
        """Migrates the function to GPU and returns self.

        The default implementation moves all fields of type
        :class:`~numpy.ndarray` onto GPU.

        Args:
            device (int or :class:`pycuda.driver.Device` or ``None``): Device ID
                of GPU that the function will be migrated on. If this is
                ``None``, the current device is used.

        Returns:
            self.

        """
        with cuda.using_device(device):
            # Iterate over a snapshot of the attributes: works on both
            # Python 2 and 3 and is unaffected by the setattr calls below.
            for k, v in list(self.__dict__.items()):
                if isinstance(v, numpy.ndarray):
                    setattr(self, k, cuda.to_gpu(v))
                elif (isinstance(v, cuda.GPUArray) and
                      v.gpudata.device != device):
                    # NOTE(review): when ``device`` is None this comparison is
                    # always true, so every GPU array is copied even if it
                    # already lives on the current device -- confirm intent.
                    setattr(self, k, cuda.copy(v, out_device=device))
        return self

    def to_cpu(self):
        """Migrates the function to CPU and returns self.

        The default implementation moves all fields of type
        :class:`pycuda.gpuarray.GPUArray` onto CPU.

        Returns:
            self.

        """
        # Snapshot the attributes (Python 2/3 compatible, safe under setattr).
        for k, v in list(self.__dict__.items()):
            if isinstance(v, cuda.GPUArray):
                setattr(self, k, cuda.to_cpu(v))
        return self

    @property
    def parameters(self):
        """A tuple of parameter arrays.

        Default implementation collects parameter arrays based on
        :data:`parameter_names` attribute. Assigning a sequence sets the
        attributes named in :data:`parameter_names` pairwise.

        """
        return tuple(getattr(self, name) for name in self.parameter_names)

    @parameters.setter
    def parameters(self, values):
        for name, value in zip(self.parameter_names, values):
            setattr(self, name, value)

    @property
    def gradients(self):
        """A tuple of gradient arrays.

        Default implementation collects gradient arrays based on
        :data:`gradient_names` attribute. Assigning a sequence sets the
        attributes named in :data:`gradient_names` pairwise.

        """
        return tuple(getattr(self, name) for name in self.gradient_names)

    @gradients.setter
    def gradients(self, values):
        for name, value in zip(self.gradient_names, values):
            setattr(self, name, value)


class Split(Function):
    """Special function to branch the graph at a variable node.

    Split does not implement forward: it is intended to be used implicitly by
    :meth:`Function.__call__`, which obtains one branch variable per function
    application through :meth:`add_branch`. On backprop, it accumulates the
    gradients flowing back from its branches into a single gradient for the
    original variable.

    """
    def __init__(self, var):
        # Split does not call Function.__init__: it wires itself to ``var``
        # directly, with exactly one input and no branches yet.
        self.inputs = [var]
        self.outputs = []
        self.rank = var.rank

    def add_branch(self):
        """Creates and returns a new branch variable.

        The branch wraps the input variable's data array (no copy) and has
        this splitter set as its creator. Only a weak reference to the branch
        is stored in :attr:`outputs`.

        Returns:
            Variable: The new branch variable.

        """
        x = self.inputs[0]
        output = Variable(x.data)
        output.set_creator(self)
        self.outputs.append(weakref.ref(output))
        return output

    def backward(self, inputs, grad_outputs):
        """Accumulates the branch gradients into one input gradient.

        Args:
            inputs: Tuple holding the input array (not used).
            grad_outputs: Gradient arrays of the branches. Entries may be
                ``None``; they are skipped when more than one branch exists.

        Returns:
            tuple: A one-element tuple holding the summed gradient, or
            ``(None,)`` when every branch gradient is ``None``.

        """
        if len(grad_outputs) == 1:
            # Single branch: pass the gradient through without copying the
            # array. ``tuple`` only normalizes the container type, since
            # backward must return a tuple.
            return tuple(grad_outputs)

        gx = None
        grad_outputs = [gy for gy in grad_outputs if gy is not None]
        device_changed = False
        try:
            for gy in grad_outputs:
                if gx is not None:
                    gx += gy
                elif isinstance(gy, cuda.GPUArray):
                    # Select gy's device; it also applies to the in-place
                    # ``+=`` above on later iterations.
                    cuda.use_device(gy, pop=False)
                    device_changed = True
                    # Copy the first gradient so the in-place additions do
                    # not overwrite a branch's own gradient array.
                    gx = cuda.copy_async(gy)
                else:
                    gx = gy.copy()
        finally:
            if device_changed:
                # Pop the context selected by use_device(..., pop=False).
                cuda.Context.pop()

        return gx,
