#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import sys

import anvio
import anvio.db as db
import anvio.tables as t
import anvio.utils as utils
import anvio.dictio as dictio
import anvio.terminal as terminal
import anvio.filesnpaths as filesnpaths

from anvio.errors import ConfigError, FilesNPathsError


# Module metadata. __version__ is taken from the imported anvio package rather
# than being hard-coded here.
__author__ = "A. Murat Eren"
__copyright__ = "Copyright 2015, The anvio Project"
__credits__ = []
__license__ = "GPL 3.0"
__version__ = anvio.__version__
__maintainer__ = "A. Murat Eren"
__email__ = "a.murat.eren@gmail.com"
__status__ = "Development"


# Module-level terminal helpers shared by the code below: `run` is used to
# report key/value information (run.info), `progress` to report the current
# step of a longer operation (progress.new / update / end).
run = terminal.Run()
progress = terminal.Progress()


def generate_gexf_network_file(units, samples_dict, output_file, sample_mapping_dict = None,
                               unit_mapping_dict = None, project = None, sample_size=8, unit_size=2,
                               skip_sample_labels = False, skip_unit_labels = False):
    """Write a GEXF network description connecting samples to units.

    Every sample and every unit becomes a node. An undirected edge is written
    between a unit and a sample whenever `samples_dict[sample][unit]` is
    greater than zero, and that value is used as the edge weight.

    All values interpolated into the XML (ids, labels, category titles,
    attribute values, project name) are XML-escaped, so names containing
    `&`, `<`, `>` or `"` no longer produce a malformed file.

    Parameters
    ----------
    units : sequence
        Unit identifiers, written as nodes in the given order.
    samples_dict : dict
        Maps sample -> {unit -> numeric value}. Samples are written in
        sorted order.
    output_file : str
        Path of the file to write (opened with mode 'w').
    sample_mapping_dict : dict, optional
        Maps category -> {sample -> value}. Every key except 'colors' is
        declared as a node attribute and written for each sample node. If a
        'colors' key is present, its per-sample value is converted with
        `utils.HTMLColorToRGB` and written as the node colour.
    unit_mapping_dict : dict, optional
        Maps category -> {unit -> value}. Every key except 'colors' and
        'labels' is declared as an edge attribute and written for each unit
        node and each edge. If a 'labels' key is present, its per-unit value
        is used as the unit node label.
    project : str, optional
        When given, written into an additional <creator> element.
    sample_size, unit_size : int
        Values of the <viz:size> element for sample and unit nodes.
    skip_sample_labels, skip_unit_labels : bool
        When true, the corresponding nodes are written without a label.

    Raises
    ------
    KeyError
        If a mapping dictionary or `samples_dict` lacks an entry that is
        looked up for a sample or unit.
    """
    # Imported locally so that the function carries its own dependency.
    from xml.sax.saxutils import escape

    def xml_attr(value):
        """Format `value` with %s and escape it for a double-quoted XML attribute."""
        return escape('%s' % (value,), {'"': '&quot;'})

    samples = sorted(samples_dict.keys())
    sample_mapping_categories = sorted([k for k in sample_mapping_dict.keys() if k != 'colors']) if sample_mapping_dict else None
    unit_mapping_categories = sorted([k for k in unit_mapping_dict.keys() if k not in ['colors', 'labels']]) if unit_mapping_dict else None

    # `in` instead of dict.has_key(), which does not exist on Python 3.
    sample_colors = sample_mapping_dict['colors'] if sample_mapping_dict and 'colors' in sample_mapping_dict else None
    unit_labels = unit_mapping_dict['labels'] if unit_mapping_dict and 'labels' in unit_mapping_dict else None

    # The context manager guarantees the file is closed even if a lookup
    # below raises half-way through.
    with open(output_file, 'w') as output:

        def write_attribute_declarations(categories):
            """Write one <attribute> declaration per category, ids following list order."""
            for attr_id, category in enumerate(categories):
                output.write('''    <attribute id="%d" title="%s" type="string" />\n''' % (attr_id, xml_attr(category)))

        def write_attvalues(categories, mapping_dict, key):
            """Write the <attvalues> element holding `key`'s value for every category."""
            output.write('''        <attvalues>\n''')
            for attr_id, category in enumerate(categories):
                output.write('''            <attvalue id="%d" value="%s"/>\n''' % (attr_id, xml_attr(mapping_dict[category][key])))
            output.write('''        </attvalues>\n''')

        output.write('''<?xml version="1.0" encoding="UTF-8"?>\n''')
        # NOTE(review): the viz namespace URI below has three slashes after
        # "http:", which looks like a typo. It is left untouched here because
        # changing it alters the emitted file; confirm against the GEXF spec.
        output.write('''<gexf xmlns:viz="http:///www.gexf.net/1.1draft/viz" xmlns="http://www.gexf.net/1.2draft" version="1.2">\n''')
        output.write('''<meta lastmodifieddate="2010-01-01+23:42">\n''')
        output.write('''    <creator>Oligotyping pipeline</creator>\n''')
        if project:
            output.write('''    <creator>Network description for %s</creator>\n''' % escape('%s' % (project,)))
        output.write('''</meta>\n''')
        output.write('''<graph type="static" defaultedgetype="undirected">\n\n''')

        if sample_mapping_dict:
            output.write('''<attributes class="node" type="static">\n''')
            write_attribute_declarations(sample_mapping_categories)
            output.write('''</attributes>\n\n''')

        # Unit categories are declared as *edge* attributes; the same values
        # are also attached to the unit nodes further below.
        if unit_mapping_dict:
            output.write('''<attributes class="edge">\n''')
            write_attribute_declarations(unit_mapping_categories)
            output.write('''</attributes>\n\n''')

        output.write('''<nodes>\n''')
        for sample in samples:
            if skip_sample_labels:
                output.write('''    <node id="%s">\n''' % xml_attr(sample))
            else:
                output.write('''    <node id="%s" label="%s">\n''' % (xml_attr(sample), xml_attr(sample)))
            output.write('''        <viz:size value="%d"/>\n''' % sample_size)
            if sample_colors is not None:
                output.write('''        <viz:color r="%d" g="%d" b="%d" a="1"/>\n''' %\
                                                 utils.HTMLColorToRGB(sample_colors[sample], scaled = False))

            if sample_mapping_categories:
                write_attvalues(sample_mapping_categories, sample_mapping_dict, sample)

            output.write('''    </node>\n''')

        for unit in units:
            if not skip_unit_labels and unit_labels is not None:
                output.write('''    <node id="%s" label="%s">\n''' % (xml_attr(unit), xml_attr(unit_labels[unit])))
            else:
                output.write('''    <node id="%s">\n''' % xml_attr(unit))
            output.write('''        <viz:size value="%d" />\n''' % unit_size)

            if unit_mapping_categories:
                write_attvalues(unit_mapping_categories, unit_mapping_dict, unit)

            output.write('''    </node>\n''')

        output.write('''</nodes>\n''')

        edge_id = 0
        output.write('''<edges>\n''')
        for sample in samples:
            for unit in units:
                weight = samples_dict[sample][unit]
                # only units actually observed in the sample get an edge
                if not weight > 0.0:
                    continue

                if unit_mapping_dict:
                    output.write('''    <edge id="%d" source="%s" target="%s" weight="%f">\n''' % (edge_id, xml_attr(unit), xml_attr(sample), weight))
                    if unit_mapping_categories:
                        write_attvalues(unit_mapping_categories, unit_mapping_dict, unit)
                    output.write('''    </edge>\n''')
                else:
                    output.write('''    <edge id="%d" source="%s" target="%s" weight="%f" />\n''' % (edge_id, xml_attr(unit), xml_attr(sample), weight))

                edge_id += 1
        output.write('''</edges>\n''')
        output.write('''</graph>\n''')
        output.write('''</gexf>\n''')


class NetworkDescriptonSamples:
    """Build sample/gene network descriptions from a merged profile.

    Typical use: create an instance, set `runinfo_path` and
    `annotation_db_path`, call `init()`, then call
    `generate_genes_network()` and/or `generate_functions_network()`.
    Output files are written next to the RUNINFO file.
    """

    def __init__(self):
        # Path to the RUNINFO file; must be set by the caller before init().
        self.runinfo_path = None
        # Path to the annotation database; must be set by the caller before init().
        self.annotation_db_path = None
        # Not read anywhere in this class.
        self.use_named_functions_only = False
        # Deserialized content of the RUNINFO file, filled in by init().
        self.runinfo = None
        # gene -> function text for genes with a non-hypothetical function;
        # shaped like the `unit_mapping_dict` of generate_gexf_network_file.
        self.functions = {'labels': {}}


    def P(self, x):
        """Return `x` joined onto the directory that contains the RUNINFO file."""
        return os.path.join(os.path.dirname(self.runinfo_path), x)


    def init(self):
        """Open both databases, validate them, and load genes and functions.

        Sets `runinfo`, `profile_db_path`, `profile_db`, `annotation_db`,
        `genes_dict`, `samples`, `genes`, `genes_in_contigs` and
        `samples_dict`, and fills `functions['labels']`.

        Raises
        ------
        ConfigError
            If the annotation database is not the one the profile was
            generated with, if the profile is not a merged one, or if the
            gene coverages table is missing from the profile database.
        """
        filesnpaths.is_file_exists(self.runinfo_path)
        filesnpaths.is_file_exists(self.annotation_db_path)

        self.runinfo = dictio.read_serialized_object(self.runinfo_path)
        self.profile_db_path = self.P(self.runinfo['profile_db'])
        self.profile_db = db.DB(self.profile_db_path, t.profile_db_version)

        self.annotation_db = db.DB(self.annotation_db_path, t.annotation_db_version)

        # `raise E(msg)` instead of `raise E, msg`: identical on Python 2 and
        # also valid on Python 3.
        if self.annotation_db.get_meta_value('annotation_hash') != self.profile_db.get_meta_value('annotation_hash'):
            raise ConfigError("OK. The run described in '%s' was not profiled using the annotation database\
                                      you sent as a parameter ('%s'). Hmm :/"% (self.runinfo_path, self.annotation_db_path))

        table_names = self.profile_db.get_table_names()

        if not int(self.profile_db.get_meta_value('merged')):
            raise ConfigError("The profile database describes a single run. Current implementation of this\
                                      program restricts its use to merged runs. Sorry :/")

        if t.gene_coverages_table_name not in table_names:
            raise ConfigError("There is no '%s' table in the profile database. This does not really make\
                                      any sense :/" % t.gene_coverages_table_name)

        progress.new('Init')
        progress.update('Reading genes table from %s ...' % self.profile_db_path)
        self.genes_dict = self.profile_db.get_table_as_dict(t.gene_coverages_table_name)
        self.samples = sorted(set(r['sample_id'] for r in self.genes_dict.values()))
        self.genes = sorted(set(r['prot'] for r in self.genes_dict.values()))

        progress.update('Reading annotations table from %s  ...' % self.annotation_db_path)
        self.genes_in_contigs = self.annotation_db.get_table_as_dict(t.genes_contigs_table_name)

        # read functions.
        progress.update('Sifting through genes with non-hypothetical functions ...')
        # Membership is tested once per annotated gene; a set keeps each test
        # constant-time instead of scanning the sorted gene list every time.
        profiled_genes = set(self.genes)
        for gene in self.genes_in_contigs:
            function = self.genes_in_contigs[gene]['function']
            if function and 'hypothetical' not in function.lower() and gene in profiled_genes:
                self.functions['labels'][gene] = function

        self.samples_dict = self.get_samples_dict(self.samples, self.genes_dict)
        progress.end()
        run.info('init', 'Genes database initialized for %d genes in %d samples.' % (len(self.genes), len(self.samples)))


    def generate_genes_network(self):
        """Write SAMPLE-GENE-NETWORK.gexf with every profiled gene as a unit."""
        network_desc_output_path = self.P('SAMPLE-GENE-NETWORK.gexf')
        progress.new('Processing')
        progress.update('Generating network description for %d genes across %d samples ... ' % (len(self.genes), len(self.samples_dict)))
        generate_gexf_network_file(self.genes, self.samples_dict, network_desc_output_path)
        progress.end()
        run.info('network with genes', network_desc_output_path)


    def generate_functions_network(self):
        """Write SAMPLE-FUNCTION-NETWORK.gexf restricted to genes that have a function label."""
        genes_with_functions = sorted(self.functions['labels'].keys())
        progress.new('Processing')
        progress.update('Generating network description for %d genes w/functions across %d samples ... ' % (len(genes_with_functions), len(self.samples_dict)))

        network_desc_output_path = self.P('SAMPLE-FUNCTION-NETWORK.gexf')
        generate_gexf_network_file(genes_with_functions, self.samples_dict,
                                   network_desc_output_path, unit_mapping_dict = self.functions)
        progress.end()
        run.info('network with functions', network_desc_output_path)


    def get_samples_dict(self, samples, table, unit = 'prot'):
        """Return {sample: {unit value: mean coverage}} built from `table`.

        Parameters
        ----------
        samples : iterable
            Sample names; each one gets an entry even if no row refers to it.
        table : dict
            Rows keyed by an arbitrary id; every row must provide
            'sample_id', 'mean_coverage' and the key named by `unit`.
        unit : str
            Row key whose value identifies the unit (default 'prot').

        Raises
        ------
        KeyError
            If a row refers to a sample that is not in `samples`.
        """
        samples_dict = dict((sample, {}) for sample in samples)
        for row in table.values():
            samples_dict[row['sample_id']][row[unit]] = row['mean_coverage']
        return samples_dict


if __name__ == '__main__':
    # Command line entry point: takes a RUNINFO file and the matching
    # annotation database, and writes both network description files.
    import argparse

    parser = argparse.ArgumentParser(description='Generate a network description file')
    parser.add_argument('runinfo', metavar = 'RUNINFO',
                        help = 'anvio RUNINFO file')
    parser.add_argument('annotation_db', metavar = 'ANNOTATION',
                        help = 'Annotation database that has been used for the run described in RUNINFO file')

    args = parser.parse_args()

    # `except X as e` and `print(e)` behave the same as the older
    # `except X, e` / `print e` forms on Python 2.6+ and are also valid on
    # Python 3. Each error type keeps its own exit status.
    try:
        network = NetworkDescriptonSamples()
        network.runinfo_path = args.runinfo
        network.annotation_db_path = args.annotation_db
        network.init()
        network.generate_genes_network()
        network.generate_functions_network()
    except ConfigError as e:
        print(e)
        sys.exit(-1)
    except FilesNPathsError as e:
        print(e)
        sys.exit(-2)
