#!/usr/bin/env python
# -*- coding: utf-8

import os
import numpy
import argparse
from collections import Counter
from numpy import log2 as log

import anvio
import anvio.tables as t
import anvio.dbops as dbops
import anvio.dictio as dictio
import anvio.terminal as terminal


__author__ = "A. Murat Eren"
__copyright__ = "Copyright 2015, The anvio Project"
__credits__ = []
__license__ = "GPL 3.0"
__version__ = anvio.__version__
__maintainer__ = "A. Murat Eren"
__email__ = "a.murat.eren@gmail.com"
__status__ = "Development"


# short alias for the pretty-printing helper of anvio.terminal
pp = terminal.pretty_print
# module-wide reporters: `progress` is used by VariabilityWrapper.analyze, and
# `run` is used for info/warning output by both classes below
progress = terminal.Progress()
run = terminal.Run(width = 30)

# Maps the two competing nucleotides observed at a position to a single-letter
# state code (the "avatar"). Both orderings of a pair share the same code
# (e.g. 'AT' and 'TA' are both 'w'); the letters follow the IUPAC ambiguity
# codes, in lower case. 'NN' is the state used for positions with no coverage.
states_dict = {'AA': 'a',
               'TT': 't',
               'CC': 'c',
               'GG': 'g',
               'NN': 'n',
               'AT': 'w',
               'TA': 'w',
               'AC': 'm',
               'CA': 'm',
               'CT': 'y',
               'TC': 'y',
               'AG': 'r',
               'GA': 'r',
               'CG': 's',
               'GC': 's',
               'GT': 'k',
               'TG': 'k'}

# every single-letter state code that can appear in an avatar string
valid_chars = set(states_dict.values())

# Maps a state code back to the distinct nucleotides it stands for ('a' -> 'A',
# 'w' -> 'AT'). The nucleotides are sorted so the label does not depend on set
# iteration order, which varies between interpreter runs when string hashing
# is randomized.
reverse_states = {states_dict[pair]: ''.join(sorted(set(pair))) for pair in states_dict}


def entropy(l):
    '''Return the un-weighted Shannon entropy (log base 2) of the state codes in `l`.

    `l` is a sequence that supports `count` and `len` (the caller in this file
    passes a string of avatar characters). Only the characters listed in the
    module-level `valid_chars` contribute to the result.

    An empty `l` yields 0.0 instead of failing on a division by zero.
    '''
    if not l:
        return 0.0

    num_items = float(len(l))

    E_Cs = []
    for char in valid_chars:
        # the tiny constant keeps log() away from zero for characters that do
        # not occur in `l`
        P_C = (l.count(char) / num_items) + 0.0000000000000000001
        E_Cs.append(P_C * log(P_C))

    # return un-weighted entropy
    return -(sum(E_Cs))


class ContigVariability:
    '''Summarizes the variable nucleotide positions of a single split across layers.

    Parameters:
        split_name: name of the split (only stored).
        split_data: dict keyed by layer name. Each value is read through its
            'competing_nucleotides' (position -> two-letter string that is a key
            of `states_dict`), 'variability' and 'coverage' entries, the latter
            two indexed by position.
        split_sequence: sequence of the split, indexed by position.
        num_positions_from_each_split: maximum number of positions kept by
            `position_selection_heuristics`.

    After `analyze_split_summary`, `positions_dict` holds one record per
    retained position and `positions_selected` lists the chosen positions.
    '''
    def __init__(self, split_name, split_data, split_sequence, num_positions_from_each_split = 2):
        self.split_name = split_name
        self.split_data = split_data
        self.split_sequence = split_sequence

        # layers are always handled in sorted-name order; `layer_index` gives the
        # index of a layer within the per-position lists of `positions_dict`
        self.layer_names = sorted(self.split_data.keys())
        self.layer_index = dict(zip(self.layer_names, range(0, len(self.layer_names))))

        self.positions_dict = {}
        self.num_positions_from_each_split = num_positions_from_each_split
        self.positions_selected = None


    def TOP(self, key, num):
        '''Return up to `num` currently selected positions with the largest `key` value.

        Positions are ranked on (value, position) in descending order, so ties on
        the value are broken in favor of the larger position.
        '''
        ranked = sorted([(self.positions_dict[p][key], p) for p in self.positions_selected], reverse = True)
        return [x[1] for x in ranked[:num]]


    def populate_positions_dict(self):
        '''We take a split, and create a record of each variable position in it.'''
        counter = Counter()
        for layer in self.split_data.values():
            counter.update(layer['competing_nucleotides'].keys())

        # initiate a null dict only with the information regarding how many times
        # a position is observed across layers with a variability index > 0
        for pos in counter:
            self.positions_dict[pos] = {'occurrence': counter[pos],
                                        'variability': 0.0,
                                        'mean_coverage': 0.0,
                                        'coverages': [],
                                        'contents': [],
                                        'avatars': [],
                                        'entropy': 0.0 }

        # fill in more information to finalize the dict
        for p in self.positions_dict:
            record = self.positions_dict[p]

            for layer_name in self.layer_names:
                layer = self.split_data[layer_name]
                record['variability'] += layer['variability'][p]
                record['coverages'].append(layer['coverage'][p])

                if p in layer['competing_nucleotides']:
                    content = layer['competing_nucleotides'][p]
                elif layer['coverage'][p]:
                    # covered but not variable in this layer: use the base found
                    # in the split sequence, doubled to form a states_dict key
                    content = self.split_sequence[p] + self.split_sequence[p]
                else:
                    # no coverage in this layer
                    content = 'NN'

                record['contents'].append(content)
                record['avatars'].append(states_dict[content])

            # number the distinct avatars of this position from 1; sorting keeps
            # the numbering independent of set iteration order
            avatar_set = set(record['avatars'])
            avatar_conversion = dict(zip(sorted(avatar_set), range(1, len(avatar_set) + 1)))
            record['identities'] = [avatar_conversion[a] for a in record['avatars']]
            record['mean_coverage'] = numpy.mean(record['coverages'])
            record['entropy'] = entropy(''.join(record['avatars']))

        # remove any position with 0 entropy:
        positions_with_zero_entropy = [p for p in self.positions_dict if self.positions_dict[p]['entropy'] < 0.00000001]
        for pos in positions_with_zero_entropy:
            self.positions_dict.pop(pos)

        # a real list (not a dict view) so it can be compared, sliced and reused
        self.positions_selected = list(self.positions_dict.keys())


    def position_selection_heuristics(self):
        '''The main purpose of this function is to normalize the number of positions across
           multiple splits by selecting only top `self.num_positions_from_each_split` positions
           from each split.
        '''

        # first keep the positions with the highest entropy
        self.positions_selected = self.TOP('entropy', self.num_positions_from_each_split * 3)
        # then narrow those down by variability
        self.positions_selected = self.TOP('variability', self.num_positions_from_each_split * 2)
        # and finally by mean coverage
        self.positions_selected = self.TOP('mean_coverage', self.num_positions_from_each_split)


    def analyze_split_summary(self):
        '''Populate `positions_dict` and then select the positions to report.'''
        # get a dictionary with all positions in all layers with their occurrences
        self.populate_positions_dict()

        # set/select interesting positions
        self.position_selection_heuristics()


    def text_report(self):
        '''Print, for every selected position, its content, state code, variability and coverage in each layer.'''
        run.warning('', lc = 'crimson')
        for pos in self.positions_selected:
            run.info('%d (entropy: %f)' % (pos, self.positions_dict[pos]['entropy']), None, header = True)
            for layer_name in self.layer_names:
                i = self.layer_index[layer_name]
                layer = self.split_data[layer_name]
                content = self.positions_dict[pos]['contents'][i]
                run.info(layer_name, '%s %s %f %d' % (content, states_dict[content], layer['variability'][pos], self.positions_dict[pos]['coverages'][i]))


class VariabilityWrapper:
    '''Collects selected variable positions from many splits and writes them as a matrix.

    `args` is the parsed command line namespace; `analyze` reads its
    `summary_dict`, `splits_of_interest`, `annotation_db`,
    `num_positions_from_each_split`, `minimum_scatter` and `output` attributes.
    '''
    def __init__(self, args):
        self.args = args
        # layer name -> {unit name -> state code}; created by the first add_split call
        self.samples_dict = None
        self.splits = []
        self.split_positions = {}
        self.layer_names = []
        # column names of the matrix, one per selected position ('<split>_pos_<position>')
        self.units = []

    def analyze(self):
        '''Process every split of interest and write the output matrix.'''
        args = self.args

        # per-split summaries are looked up relative to the directory of the summary index
        run_files_path = os.path.dirname(os.path.abspath(args.summary_dict))
        summary_index = dictio.read_serialized_object(args.summary_dict)
        num_positions_from_each_split = args.num_positions_from_each_split
        min_scatter = args.minimum_scatter

        # one split name per line; blank lines are skipped
        with open(args.splits_of_interest) as splits_file:
            splits = [line.strip() for line in splits_file if line.strip()]
        num_splits = len(splits)
        run.info('Splits', '%d splits found' % (num_splits))

        annotation_db = dbops.AnnotationDatabase(args.annotation_db)
        splits_info = annotation_db.db.get_table_as_dict(t.splits_info_table_name)
        contig_sequences = annotation_db.db.get_table_as_dict(t.contig_sequences_table_name)
        annotation_db.disconnect()

        progress.new('Analyzing splits')
        for i, split_name in enumerate(splits):
            progress.update('%d of %d' % (i + 1, num_splits))
            d = dictio.read_serialized_object(os.path.join(run_files_path, summary_index[split_name]))

            # cut the sequence of the split out of its parent contig
            parent = splits_info[split_name]['parent']
            start = splits_info[split_name]['start']
            end = splits_info[split_name]['end']

            split_sequence = contig_sequences[parent]['sequence'][start:end]

            self.add_split(split_name, d, split_sequence, num_positions_from_each_split)

        progress.update('Generating output ...')
        self.create_TAB_delim_file(args.output, min_scatter)
        progress.end()
        run.info('Output matrix', args.output)


    def add_split(self, split_name, split_data, split_sequence, num_positions_from_each_split = 2):
        '''Analyze one split and record the state code of each selected position for every layer.'''
        v = ContigVariability(split_name, split_data, split_sequence, num_positions_from_each_split)
        v.analyze_split_summary()

        self.splits.append(split_name)
        self.split_positions[split_name] = v.positions_selected

        # create an empty samples_dict
        if not self.samples_dict:
            self.samples_dict = {}
            for layer_name in v.layer_names:
                self.samples_dict[layer_name] = {}
            self.layer_names = v.layer_names

        for pos in v.positions_selected:
            unit = '%s_pos_%d' % (split_name, pos)
            self.units.append(unit)

            for layer_name in v.layer_names:
                i = v.layer_index[layer_name]
                self.samples_dict[layer_name][unit] = v.positions_dict[pos]['avatars'][i]


    def create_TAB_delim_file(self, path, min_scatter = 1):
        '''Write the layers x units matrix to `path` as a TAB-delimited file.

        When `min_scatter` is larger than 1, a unit is left out if its second
        most common state is found in fewer than `min_scatter` layers.
        '''
        # units to discard due to min_scatter:
        units_to_discard = set()

        if min_scatter > 1:
            for unit in self.units:
                values = [self.samples_dict[layer_name][unit] for layer_name in self.layer_names]
                group_sizes = Counter(values).most_common()

                # group_sizes[1] is the second largest group; its size is the scatter
                if len(group_sizes) > 1 and group_sizes[1][1] < min_scatter:
                    units_to_discard.add(unit)

        units = [unit for unit in self.units if unit not in units_to_discard]

        with open(path, 'w') as output:
            output.write('\t'.join(['layers'] + units) + '\n')
            for layer_name in self.layer_names:
                values = [self.samples_dict[layer_name][unit] for unit in units]
                output.write('\t'.join([layer_name] + [str(reverse_states[v]) for v in values]) + '\n')




##############################################################################

# Command line interface. `parser`, `args` and `w` are kept as module-level names.
parser = argparse.ArgumentParser(description='Generate Variability Matrix')
parser.add_argument('summary_dict', metavar = 'SUMMARY_DICT', default = None,
                    help = 'Summary file')
parser.add_argument('-a', '--annotation-db', metavar = "ANNOTATION_DB", required = True,
                    help = 'Annotation database.')
parser.add_argument('-s', '--splits-of-interest', metavar = "SPLITS", required = True,
                    help = 'List of splits to analyze.')
# NOTE: --samples-of-interest is parsed but not read anywhere in this file
parser.add_argument('-S', '--samples-of-interest', metavar = "SAMPLES", default = None,
                    help = 'List of samples to retain. If not declared, all samples are used.')
parser.add_argument('-n', '--num-positions-from-each-split', type=int, default = 2,
                    help = 'Each split may have one or more variable positions. What is the maximum number of positions '
                           'to report from each split is described via this parameter.')
# NOTE: --min-ratio is parsed but not read anywhere in this file
parser.add_argument('-r', '--min-ratio', type=float, default = 0, metavar = 'RATIO',
                    help = 'Minimum ratio of the competing nucleotides at a given position. Default is %(default)d.')
# The code discards a position only when its scatter is strictly less than -m,
# hence the example below uses -m 3 for a position with a scatter of 2.
parser.add_argument('-m', '--minimum-scatter', type=int, default = 1,
                    help = 'This is tricky. If you have N samples in your dataset, a given variable position x in one '
                           'of your splits can split your N samples into t groups based on the identity of the '
                           'variation they harbor at position x. For instance, t would have been 1, if all samples had the same '
                           'type of variation at position x (which would not be very interesting, because in this case '
                           'position x would have zero contribution to a deeper understanding of how these samples differ '
                           'based on variability). When t > 1, it would mean that identities at position x across samples '
                           'do differ. But how much scattering occurs based on position x when t > 1? If t=2, how many '
                           'samples ended in each group? Obviously, even distribution of samples across groups may tell '
                           'us something different than uneven distribution of samples across groups. So, this parameter '
                           'filters out any x if "the number of samples in the second largest group" (=scatter) is less '
                           'than -m. Here is an example: lets assume you have 7 samples. While 5 of those have AG, 2 '
                           'of them have TC at position x. This would mean scatter of x is 2. If you set -m to 3, this '
                           'position would not be reported in your output matrix. The default value for -m is '
                           '%(default)d, which means every x that is returned from ContigVariability class is reported.')
parser.add_argument('-o', '--output', type=str, default = 'variability.txt', help = 'Output path for the matrix')
args = parser.parse_args()


w = VariabilityWrapper(args)
w.analyze()
