# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains the definition of the Inception Resnet V1 architecture.
As described in http://arxiv.org/abs/1602.07261.
  Inception-v4, Inception-ResNet and the Impact of Residual Connections
    on Learning
  Christian Szegedy, Sergey Ioffe, Vincent Vanhoucke, Alex Alemi
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import tensorflow.contrib.slim as slim
from .utils import is_trainable


# Inception-Resnet-A
def block35(net,
            scale=1.0,
            activation_fn=tf.nn.relu,
            scope=None,
            reuse=None,
            trainable_variables=True):
    """Build one Inception-ResNet-A (35x35) residual block.

    Three towers run in parallel, all with 32 filters:

    - a single 1x1 convolution;
    - a 1x1 followed by a 3x3 convolution;
    - a 1x1 followed by two 3x3 convolutions.

    Their outputs are concatenated on axis 3. The result is projected back
    to the input depth by a 1x1 convolution with no normalizer and no
    activation, multiplied by ``scale``, and added to ``net``.

    Parameters
    ----------

      net:
        Input feature map. Channels are read from axis 3, so NHWC layout is
        assumed.

      scale: float
        Multiplier applied to the residual branch before the addition.

      activation_fn:
        Applied to the sum. A falsy value (e.g. ``None``) skips it.

      scope:
        Variable scope name. Defaults to ``'Block35'``.

      reuse:
        Forwarded to every ``slim.conv2d`` call.

      trainable_variables: bool
        Forwarded as ``trainable`` to every ``slim.conv2d`` call.

    Returns
    -------
      The block output, with the same shape as ``net``.
    """
    # Options shared by every convolution in this block.
    conv_kwargs = dict(reuse=reuse, trainable=trainable_variables)
    with tf.variable_scope(scope, 'Block35', [net]):
        with tf.variable_scope('Branch_0'):
            branch_0 = slim.conv2d(
                net, 32, 1, scope='Conv2d_1x1', **conv_kwargs)
        with tf.variable_scope('Branch_1'):
            branch_1 = slim.conv2d(
                net, 32, 1, scope='Conv2d_0a_1x1', **conv_kwargs)
            branch_1 = slim.conv2d(
                branch_1, 32, 3, scope='Conv2d_0b_3x3', **conv_kwargs)
        with tf.variable_scope('Branch_2'):
            branch_2 = slim.conv2d(
                net, 32, 1, scope='Conv2d_0a_1x1', **conv_kwargs)
            branch_2 = slim.conv2d(
                branch_2, 32, 3, scope='Conv2d_0b_3x3', **conv_kwargs)
            branch_2 = slim.conv2d(
                branch_2, 32, 3, scope='Conv2d_0c_3x3', **conv_kwargs)
        mixed = tf.concat([branch_0, branch_1, branch_2], 3)
        # The projection must restore the input depth so the residual sum
        # below is well-formed.
        depth = net.get_shape()[3]
        residual = slim.conv2d(
            mixed,
            depth,
            1,
            normalizer_fn=None,
            activation_fn=None,
            scope='Conv2d_1x1',
            **conv_kwargs)
        net += scale * residual
        if activation_fn:
            net = activation_fn(net)
    return net


# Inception-Resnet-B
def block17(net,
            scale=1.0,
            activation_fn=tf.nn.relu,
            scope=None,
            reuse=None,
            trainable_variables=True):
    """Build one Inception-ResNet-B (17x17) residual block.

    Two towers run in parallel:

    - a 1x1 convolution with 128 filters;
    - a factorized 1x1 -> 1x7 -> 7x1 stack with 128 filters per layer.

    Their outputs are concatenated on axis 3. The result is projected back
    to the input depth by a 1x1 convolution with no normalizer and no
    activation, multiplied by ``scale``, and added to ``net``.

    Parameters
    ----------

      net:
        Input feature map. Channels are read from axis 3, so NHWC layout is
        assumed.

      scale: float
        Multiplier applied to the residual branch before the addition.

      activation_fn:
        Applied to the sum. A falsy value (e.g. ``None``) skips it.

      scope:
        Variable scope name. Defaults to ``'Block17'``.

      reuse:
        Forwarded to every ``slim.conv2d`` call.

      trainable_variables: bool
        Forwarded as ``trainable`` to every ``slim.conv2d`` call.

    Returns
    -------
      The block output, with the same shape as ``net``.
    """
    # Options shared by every convolution in this block.
    conv_kwargs = dict(reuse=reuse, trainable=trainable_variables)
    with tf.variable_scope(scope, 'Block17', [net]):
        with tf.variable_scope('Branch_0'):
            branch_0 = slim.conv2d(
                net, 128, 1, scope='Conv2d_1x1', **conv_kwargs)
        with tf.variable_scope('Branch_1'):
            branch_1 = slim.conv2d(
                net, 128, 1, scope='Conv2d_0a_1x1', **conv_kwargs)
            branch_1 = slim.conv2d(
                branch_1, 128, [1, 7], scope='Conv2d_0b_1x7', **conv_kwargs)
            branch_1 = slim.conv2d(
                branch_1, 128, [7, 1], scope='Conv2d_0c_7x1', **conv_kwargs)
        mixed = tf.concat([branch_0, branch_1], 3)
        # Restore the input depth so the residual sum below is well-formed.
        depth = net.get_shape()[3]
        residual = slim.conv2d(
            mixed,
            depth,
            1,
            normalizer_fn=None,
            activation_fn=None,
            scope='Conv2d_1x1',
            **conv_kwargs)
        net += scale * residual
        if activation_fn:
            net = activation_fn(net)
    return net


# Inception-Resnet-C
def block8(net,
           scale=1.0,
           activation_fn=tf.nn.relu,
           scope=None,
           reuse=None,
           trainable_variables=True):
    """Build one Inception-ResNet-C (8x8) residual block.

    Two towers run in parallel:

    - a 1x1 convolution with 192 filters;
    - a factorized 1x1 -> 1x3 -> 3x1 stack with 192 filters per layer.

    Their outputs are concatenated on axis 3. The result is projected back
    to the input depth by a 1x1 convolution with no normalizer and no
    activation, multiplied by ``scale``, and added to ``net``.

    Parameters
    ----------

      net:
        Input feature map. Channels are read from axis 3, so NHWC layout is
        assumed.

      scale: float
        Multiplier applied to the residual branch before the addition.

      activation_fn:
        Applied to the sum. A falsy value (e.g. ``None``) skips it.

      scope:
        Variable scope name. Defaults to ``'Block8'``.

      reuse:
        Forwarded to every ``slim.conv2d`` call.

      trainable_variables: bool
        Forwarded as ``trainable`` to every ``slim.conv2d`` call.

    Returns
    -------
      The block output, with the same shape as ``net``.
    """
    # Options shared by every convolution in this block.
    conv_kwargs = dict(reuse=reuse, trainable=trainable_variables)
    with tf.variable_scope(scope, 'Block8', [net]):
        with tf.variable_scope('Branch_0'):
            branch_0 = slim.conv2d(
                net, 192, 1, scope='Conv2d_1x1', **conv_kwargs)
        with tf.variable_scope('Branch_1'):
            branch_1 = slim.conv2d(
                net, 192, 1, scope='Conv2d_0a_1x1', **conv_kwargs)
            branch_1 = slim.conv2d(
                branch_1, 192, [1, 3], scope='Conv2d_0b_1x3', **conv_kwargs)
            branch_1 = slim.conv2d(
                branch_1, 192, [3, 1], scope='Conv2d_0c_3x1', **conv_kwargs)
        mixed = tf.concat([branch_0, branch_1], 3)
        # Restore the input depth so the residual sum below is well-formed.
        depth = net.get_shape()[3]
        residual = slim.conv2d(
            mixed,
            depth,
            1,
            normalizer_fn=None,
            activation_fn=None,
            scope='Conv2d_1x1',
            **conv_kwargs)
        net += scale * residual
        if activation_fn:
            net = activation_fn(net)
    return net


def reduction_a(net, k, l, m, n, trainable_variables=True, reuse=None):
    """Build the Reduction-A module, which halves the spatial resolution.

    Three towers are concatenated on axis 3:

    - ``Branch_0``: a 3x3 stride-2 VALID convolution with ``n`` filters;
    - ``Branch_1``: 1x1 (``k`` filters) -> 3x3 (``l`` filters) -> 3x3
      stride-2 VALID (``m`` filters);
    - ``Branch_2``: a 3x3 stride-2 VALID max-pool.

    No enclosing variable scope is opened here. The caller is expected to
    wrap the call in one (e.g. ``'Mixed_6a'``).

    Parameters
    ----------

      net:
        Input feature map (NHWC assumed).

      k, l, m, n: int
        Filter counts of the convolutions, as described above.

      trainable_variables: bool
        Forwarded as ``trainable`` to every ``slim.conv2d`` call.

      reuse:
        Forwarded to every ``slim.conv2d`` call.

    Returns
    -------
      The concatenated tower outputs.
    """
    # Options shared by every convolution in this module.
    conv_kwargs = dict(reuse=reuse, trainable=trainable_variables)
    with tf.variable_scope('Branch_0'):
        branch_0 = slim.conv2d(
            net,
            n,
            3,
            stride=2,
            padding='VALID',
            scope='Conv2d_1a_3x3',
            **conv_kwargs)
    with tf.variable_scope('Branch_1'):
        branch_1 = slim.conv2d(
            net, k, 1, scope='Conv2d_0a_1x1', **conv_kwargs)
        branch_1 = slim.conv2d(
            branch_1, l, 3, scope='Conv2d_0b_3x3', **conv_kwargs)
        branch_1 = slim.conv2d(
            branch_1,
            m,
            3,
            stride=2,
            padding='VALID',
            scope='Conv2d_1a_3x3',
            **conv_kwargs)
    with tf.variable_scope('Branch_2'):
        pooled = slim.max_pool2d(
            net, 3, stride=2, padding='VALID', scope='MaxPool_1a_3x3')
    return tf.concat([branch_0, branch_1, pooled], 3)


def reduction_b(net, trainable_variables=True, reuse=None):
    """Build the Reduction-B module, which halves the spatial resolution.

    Four towers are concatenated on axis 3:

    - ``Branch_0``: 1x1 (256 filters) -> 3x3 stride-2 VALID (384 filters);
    - ``Branch_1``: 1x1 (256 filters) -> 3x3 stride-2 VALID (256 filters);
    - ``Branch_2``: 1x1 (256 filters) -> 3x3 (256 filters) -> 3x3 stride-2
      VALID (256 filters);
    - ``Branch_3``: a 3x3 stride-2 VALID max-pool.

    No enclosing variable scope is opened here. The caller is expected to
    wrap the call in one (e.g. ``'Mixed_7a'``).

    Parameters
    ----------

      net:
        Input feature map (NHWC assumed).

      trainable_variables: bool
        Forwarded as ``trainable`` to every ``slim.conv2d`` call.

      reuse:
        Forwarded to every ``slim.conv2d`` call.

    Returns
    -------
      The concatenated tower outputs.
    """
    # Options shared by every convolution in this module.
    conv_kwargs = dict(reuse=reuse, trainable=trainable_variables)
    with tf.variable_scope('Branch_0'):
        branch_0 = slim.conv2d(
            net, 256, 1, scope='Conv2d_0a_1x1', **conv_kwargs)
        branch_0 = slim.conv2d(
            branch_0,
            384,
            3,
            stride=2,
            padding='VALID',
            scope='Conv2d_1a_3x3',
            **conv_kwargs)
    with tf.variable_scope('Branch_1'):
        branch_1 = slim.conv2d(
            net, 256, 1, scope='Conv2d_0a_1x1', **conv_kwargs)
        branch_1 = slim.conv2d(
            branch_1,
            256,
            3,
            stride=2,
            padding='VALID',
            scope='Conv2d_1a_3x3',
            **conv_kwargs)
    with tf.variable_scope('Branch_2'):
        branch_2 = slim.conv2d(
            net, 256, 1, scope='Conv2d_0a_1x1', **conv_kwargs)
        branch_2 = slim.conv2d(
            branch_2, 256, 3, scope='Conv2d_0b_3x3', **conv_kwargs)
        branch_2 = slim.conv2d(
            branch_2,
            256,
            3,
            stride=2,
            padding='VALID',
            scope='Conv2d_1a_3x3',
            **conv_kwargs)
    with tf.variable_scope('Branch_3'):
        pooled = slim.max_pool2d(
            net, 3, stride=2, padding='VALID', scope='MaxPool_1a_3x3')
    return tf.concat([branch_0, branch_1, branch_2, pooled], 3)


def inception_resnet_v1_batch_norm(inputs,
                                   dropout_keep_prob=0.8,
                                   bottleneck_layer_size=128,
                                   reuse=None,
                                   scope='InceptionResnetV1',
                                   mode=tf.estimator.ModeKeys.TRAIN,
                                   trainable_variables=None,
                                   weight_decay=1e-5,
                                   **kwargs):
    """
    Creates the Inception Resnet V1 model with batch normalization applied to
    the convolutional and fully connected layers.

    This wraps :py:func:`inception_resnet_v1` in a ``slim.arg_scope`` that
    gives every ``slim.conv2d`` and ``slim.fully_connected`` layer:

    - truncated-normal weight initialization (stddev 0.1);
    - L2 weight regularization scaled by ``weight_decay``;
    - ``slim.batch_norm`` as the normalizer.

    Layers that explicitly pass ``normalizer_fn=None`` override the
    normalizer; the residual projections inside the blocks do this.

    Parameters
    ----------

      inputs:
        4-D tensor of size [batch_size, height, width, 3].

      dropout_keep_prob: float
        The fraction of activations kept by the dropout before the
        bottleneck layer.

      bottleneck_layer_size: int
        Number of outputs of the final (linear) ``Bottleneck`` layer.

      reuse:
        Whether or not the network and its variables should be reused. To be
        able to reuse, ``scope`` must be given.

      scope:
        Optional variable_scope.

      mode: :any:`tf.estimator.ModeKeys`
        Batch normalization and dropout run in training mode only when this
        equals ``tf.estimator.ModeKeys.TRAIN``. It is also passed to
        ``is_trainable``.

      trainable_variables: :any:`list`
        Layer names to treat as trainable, passed to ``is_trainable``
        (e.g. ``"Conv2d_1a_3x3"``; batch-norm parameters use the key
        ``"<layer>_BN"``). ``None`` is forwarded as-is. Presumably it means
        "everything trainable"; confirm in ``utils.is_trainable``.

      weight_decay: float
        Scale of the L2 regularizer applied to the layer weights.

      **kwargs:
        Accepted for interface compatibility and ignored.

    Returns
    -------
      net:
        Output of the ``Bottleneck`` fully connected layer (no activation),
        with ``bottleneck_layer_size`` outputs per example.

      end_points:
        Dict of intermediate tensors keyed by layer name.

    """

    batch_norm_params = {
        # Decay for the moving averages.
        'decay': 0.995,
        # epsilon to prevent 0s in variance.
        'epsilon': 0.001,
        # force in-place updates of mean and variance estimates
        'updates_collections': None,
    }

    with slim.arg_scope(
            [slim.conv2d, slim.fully_connected],
            weights_initializer=tf.truncated_normal_initializer(stddev=0.1),
            weights_regularizer=slim.l2_regularizer(weight_decay),
            normalizer_fn=slim.batch_norm,
            normalizer_params=batch_norm_params):
        return inception_resnet_v1(
            inputs,
            dropout_keep_prob=dropout_keep_prob,
            bottleneck_layer_size=bottleneck_layer_size,
            reuse=reuse,
            scope=scope,
            mode=mode,
            trainable_variables=trainable_variables,
        )


def _batch_norm_arg_scope(name, trainable_variables, mode):
    """Return a ``slim.arg_scope`` configuring ``slim.batch_norm`` for one layer.

    The batch-norm trainability is looked up under the key ``name + "_BN"``,
    separately from the layer's own key ``name``.
    """
    return slim.arg_scope(
        [slim.batch_norm],
        is_training=(mode == tf.estimator.ModeKeys.TRAIN),
        trainable=is_trainable(name + "_BN", trainable_variables, mode=mode))


def inception_resnet_v1(inputs,
                        dropout_keep_prob=0.8,
                        bottleneck_layer_size=128,
                        reuse=None,
                        scope='InceptionResnetV1',
                        mode=tf.estimator.ModeKeys.TRAIN,
                        trainable_variables=None,
                        **kwargs):
    """
    Creates the Inception Resnet V1 model.

    Each named layer ``X`` asks ``is_trainable`` twice: once with key ``X``
    for the layer weights, and once with key ``X_BN`` for any batch
    normalization configured through ``slim.arg_scope`` (see
    :py:func:`inception_resnet_v1_batch_norm`).

    Parameters
    ----------

      inputs:
        4-D tensor of size [batch_size, height, width, 3].

      dropout_keep_prob: float
        The fraction of activations kept by the dropout before the
        bottleneck layer.

      bottleneck_layer_size: int
        Number of outputs of the final (linear) ``Bottleneck`` layer.

      reuse:
        Whether or not the network and its variables should be reused. To be
        able to reuse, ``scope`` must be given.

      scope:
        Optional variable_scope.

      mode: :any:`tf.estimator.ModeKeys`
        Batch normalization and dropout run in training mode only when this
        equals ``tf.estimator.ModeKeys.TRAIN``. It is also passed to
        ``is_trainable``.

      trainable_variables: :any:`list`
        Layer names to treat as trainable, passed to ``is_trainable``.
        ``None`` is forwarded as-is. Presumably it means "everything
        trainable"; confirm in ``utils.is_trainable``.

      **kwargs:
        Accepted for interface compatibility and ignored.

    Returns
    -------
      net:
        Output of the ``Bottleneck`` fully connected layer (no activation),
        with ``bottleneck_layer_size`` outputs per example.

      end_points:
        Dict of intermediate tensors keyed by layer name.

    """
    end_points = {}
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    def trainable_for(layer_name):
        # Trainability of the weights of ``layer_name``.
        return is_trainable(layer_name, trainable_variables, mode=mode)

    # The scope names below determine the variable names, so they must stay
    # unchanged for checkpoints to keep loading. Spatial sizes in the
    # comments correspond to a 299 x 299 input.
    with tf.variable_scope(scope, 'InceptionResnetV1', [inputs], reuse=reuse):
        with slim.arg_scope([slim.dropout], is_training=is_training):
            with slim.arg_scope(
                    [slim.conv2d, slim.max_pool2d, slim.avg_pool2d],
                    stride=1,
                    padding='SAME'):

                # 149 x 149 x 32
                name = "Conv2d_1a_3x3"
                with _batch_norm_arg_scope(name, trainable_variables, mode):
                    net = slim.conv2d(
                        inputs,
                        32,
                        3,
                        stride=2,
                        padding='VALID',
                        scope=name,
                        reuse=reuse,
                        trainable=trainable_for(name))
                    end_points[name] = net

                # 147 x 147 x 32
                name = "Conv2d_2a_3x3"
                with _batch_norm_arg_scope(name, trainable_variables, mode):
                    net = slim.conv2d(
                        net,
                        32,
                        3,
                        padding='VALID',
                        scope=name,
                        reuse=reuse,
                        trainable=trainable_for(name))
                    end_points[name] = net

                # 147 x 147 x 64
                name = "Conv2d_2b_3x3"
                with _batch_norm_arg_scope(name, trainable_variables, mode):
                    net = slim.conv2d(
                        net,
                        64,
                        3,
                        scope=name,
                        reuse=reuse,
                        trainable=trainable_for(name))
                    end_points[name] = net

                # 73 x 73 x 64
                net = slim.max_pool2d(
                    net, 3, stride=2, padding='VALID', scope='MaxPool_3a_3x3')
                end_points['MaxPool_3a_3x3'] = net

                # 73 x 73 x 80
                name = "Conv2d_3b_1x1"
                with _batch_norm_arg_scope(name, trainable_variables, mode):
                    net = slim.conv2d(
                        net,
                        80,
                        1,
                        padding='VALID',
                        scope=name,
                        reuse=reuse,
                        trainable=trainable_for(name))
                    end_points[name] = net

                # 71 x 71 x 192
                name = "Conv2d_4a_3x3"
                with _batch_norm_arg_scope(name, trainable_variables, mode):
                    net = slim.conv2d(
                        net,
                        192,
                        3,
                        padding='VALID',
                        scope=name,
                        reuse=reuse,
                        trainable=trainable_for(name))
                    end_points[name] = net

                # 35 x 35 x 256
                name = "Conv2d_4b_3x3"
                with _batch_norm_arg_scope(name, trainable_variables, mode):
                    net = slim.conv2d(
                        net,
                        256,
                        3,
                        stride=2,
                        padding='VALID',
                        scope=name,
                        reuse=reuse,
                        trainable=trainable_for(name))
                    end_points[name] = net

                # 5 x Inception-resnet-A
                name = "block35"
                with _batch_norm_arg_scope(name, trainable_variables, mode):
                    net = slim.repeat(
                        net,
                        5,
                        block35,
                        scale=0.17,
                        reuse=reuse,
                        trainable_variables=trainable_for(name))
                    end_points[name] = net

                # Reduction-A
                name = "Mixed_6a"
                with _batch_norm_arg_scope(name, trainable_variables, mode):
                    trainable = trainable_for(name)
                    with tf.variable_scope(name):
                        net = reduction_a(
                            net,
                            192,
                            192,
                            256,
                            384,
                            trainable_variables=trainable,
                            reuse=reuse)
                    end_points[name] = net

                # 10 x Inception-Resnet-B
                name = "block17"
                with _batch_norm_arg_scope(name, trainable_variables, mode):
                    net = slim.repeat(
                        net,
                        10,
                        block17,
                        scale=0.10,
                        reuse=reuse,
                        trainable_variables=trainable_for(name))
                    end_points[name] = net

                # Reduction-B
                name = "Mixed_7a"
                with _batch_norm_arg_scope(name, trainable_variables, mode):
                    trainable = trainable_for(name)
                    with tf.variable_scope(name):
                        net = reduction_b(
                            net, trainable_variables=trainable, reuse=reuse)
                    end_points[name] = net

                # 5 x Inception-Resnet-C
                name = "block8"
                with _batch_norm_arg_scope(name, trainable_variables, mode):
                    net = slim.repeat(
                        net,
                        5,
                        block8,
                        scale=0.20,
                        reuse=reuse,
                        trainable_variables=trainable_for(name))
                    end_points[name] = net

                # Final Inception-Resnet-C block, without activation.
                name = "Mixed_8b"
                with _batch_norm_arg_scope(name, trainable_variables, mode):
                    # No explicit scope: block8 falls back to its default
                    # 'Block8' variable scope. Passing one would rename the
                    # variables.
                    net = block8(
                        net,
                        activation_fn=None,
                        reuse=reuse,
                        trainable_variables=trainable_for(name))
                    end_points[name] = net

                with tf.variable_scope('Logits'):
                    end_points['PrePool'] = net
                    # Global average pool over the static spatial extent.
                    # pylint: disable=no-member
                    net = slim.avg_pool2d(
                        net,
                        net.get_shape()[1:3],
                        padding='VALID',
                        scope='AvgPool_1a_8x8')
                    net = slim.flatten(net)

                    net = slim.dropout(
                        net,
                        dropout_keep_prob,
                        is_training=is_training,
                        scope='Dropout')

                    end_points['PreLogitsFlatten'] = net

                name = "Bottleneck"
                with _batch_norm_arg_scope(name, trainable_variables, mode):
                    net = slim.fully_connected(
                        net,
                        bottleneck_layer_size,
                        activation_fn=None,
                        scope=name,
                        reuse=reuse,
                        trainable=trainable_for(name))
                end_points[name] = net

    return net, end_points
