#! /usr/bin/env python3

"""
The tree module extract, from a table, a tree limited to a given depth.
"""

from __future__ import print_function

import argparse
import pickle
from collections import defaultdict

try:
    from collections.abc import Mapping
except ImportError:  # Python 2: the ABCs still live in ``collections``
    from collections import Mapping

from wagoner.utils import *
from wagoner.table import Table

__all__ = ["Tree"]


class Tree(Mapping):
    """
    A tree is a read-only mapping of nodes to mappings of nodes to weights.

    Each node is a pair (string, length): when a word of that length ends with
    that string, it can be followed by any of the successor nodes, chosen
    according to the associated weights. The root node is (">", 0) and the
    terminal nodes are the ones whose string is "<".
    """

    def __init__(self, tree):
        """
        Create a new tree from the given tree content.

        :param tree: the tree content, a mapping of nodes to mappings of
                     successor nodes to weights; it is stored as is, not
                     copied.
        """
        self.__content = tree

    @classmethod
    def from_table(cls, table, length, prefix=0, flatten=False):
        """
        Extract from the given table a tree for word length, taking only
        prefixes of prefix length (if greater than 0) into account to compute
        successors.

        :param table: the table to extract the tree from; its
                      weighted_choices method is used as returning a mapping
                      of successors to weights;
        :param length: the length of words generated by the extracted tree;
                       greater or equal to 1;
        :param prefix: if greater than 0, the length of the prefixes used for
                       computing successors;
        :param flatten: whether to flatten the table or not (passed through
                        to table.weighted_choices);
        :return: the tree corresponding to words of length from table, with
                 its dead branches removed.
        """
        tree = defaultdict(dict)  # The expanded tree, before trimming
        root = (">", 0)
        pending = [root]  # The nodes still to expand
        # Several parents can lead to the same node (typically when prefix
        # is used); remembering the discovered nodes guarantees that each of
        # them is expanded only once.
        seen = {root}
        while pending:
            node = pending.pop()
            suffix, size = node
            if size < length:
                # The word length is not reached yet, expand
                choices = table.weighted_choices(suffix, exclude={"<"},
                                                 flatten=flatten)
                for successor, weight in choices.items():
                    expanded = suffix + successor
                    if prefix > 0:
                        expanded = expanded[-prefix:]
                    new_node = (expanded, size + 1)
                    tree[node][new_node] = weight
                    if new_node not in seen:
                        seen.add(new_node)
                        pending.append(new_node)
            else:
                # The word length is reached, only add < if present
                choices = table.weighted_choices(suffix, flatten=flatten)
                if "<" in choices:
                    tree[node][("<", size + 1)] = 1
                else:
                    # Dead end: kept empty so that trim_tree removes it
                    tree[node] = dict()
        return cls(cls.trim_tree(tree))

    @staticmethod
    def trim_tree(tree):
        """
        Remove the dead branches of tree, that is, the resulting tree accepts
        the same language as the original one (that is, the same words that end
        with the < character), but parts of the tree that lead to nothing are
        removed.

        The given tree is left untouched; a new plain dict is returned.

        :param tree: the tree;
        :return: the tree without dead branches.
        """
        current = tree
        while True:
            # Nodes without any successor are dead
            alive = {node for node, successors in current.items()
                     if successors}
            # Drop the dead nodes and every edge leading to a missing node;
            # edges to terminal nodes ("<") are always kept
            trimmed = {node: {successor: weight
                              for successor, weight in current[node].items()
                              if successor in alive or successor[0] == "<"}
                       for node in alive}
            # Removing edges may create new dead nodes: iterate until stable
            if trimmed == current:
                return trimmed
            current = trimmed

    def __getitem__(self, key):
        return self.__content[key]

    def __iter__(self):
        return iter(self.__content)

    def __len__(self):
        return len(self.__content)

    def __str__(self):
        return str(self.__content)

    def random_word(self, *args, **kwargs):
        """
        Return a random word from this tree. The length of the word depends on
        this tree.

        The walk starts at the root node (">", 0) and follows successors
        picked by random_weighted_choice until a terminal node ("<") is
        reached; the word is made of the last character of each node
        traversed, the terminal one excluded.

        :return: a random word from this tree.
        :raise KeyError: if a node of the walk is missing from the content
                         (for instance the root of an empty tree, when the
                         content is a dict).

        args and kwargs are ignored.
        """
        letters = []
        current = (">", 0)
        while True:
            # presumably picks a successor with a probability proportional
            # to its weight -- see wagoner.utils.random_weighted_choice
            current = random_weighted_choice(self[current])
            if current[0] == "<":
                break
            letters.append(current[0][-1])
        return "".join(letters)


def process_arguments():
    """
    Process the command line arguments. The arguments are:
     * the content (a table or a text) to build the tree from;
     * -l (or --length) for the length of words generable by the tree
       (default: 10);
     * -p (or --prefix) for the maximum length of prefixes to consider
       (default: 0, meaning no limit is applied);
     * -f (or --flatten) if the table must be flattened;
     * -o (or --output) for the output destination, opened for binary
       writing; if missing, the output attribute is None.

    :return: the namespace of parsed arguments, with attributes content,
             length, prefix, flatten and output.
    """
    parser = argparse.ArgumentParser(
        description="Generate trees from the given content",
        epilog="This script can generate trees from tables or texts; "
               "if a text is given, the corresponding table is first built, "
               "then the tree is built.")
    parser.add_argument("content", help="the content (a table or a text)")
    # nonzero_natural and natural are the converters provided by
    # wagoner.utils (presumably rejecting values outside their range)
    parser.add_argument("--length", "-l", type=nonzero_natural, default=10,
                        dest="length",
                        help="the length of words generable by the tree "
                             "(default: 10)")
    parser.add_argument("--prefix", "-p", type=natural, default=0,
                        dest="prefix",
                        help="if not 0, the maximum length of prefixes to "
                             "consider when choosing the next character "
                             "(default: 0)")
    parser.add_argument("--flatten", "-f", action="store_true", default=False,
                        dest="flatten", help="flatten the table")
    parser.add_argument("--output", "-o", type=argparse.FileType('wb'),
                        default=None, dest="output",
                        help="the output destination; "
                             "if missing, print the tree")
    return parser.parse_args()

if __name__ == "__main__":
    args = process_arguments()
    # The content is first tried as a pickled table, then as a plain text.
    # SECURITY: unpickling can execute arbitrary code; only give this script
    # content files coming from a trusted source.
    try:
        with open(args.content, "rb") as table_file:
            table = pickle.load(table_file)
    except (pickle.UnpicklingError, EOFError, IndexError, KeyError,
            ValueError):
        # Feeding pickle with something that is not a pickle does not always
        # raise UnpicklingError; the other errors it may raise on garbage
        # input are caught as well. ImportError and AttributeError are left
        # out on purpose: they denote a real pickle whose classes are missing.
        with open(args.content, "r") as text_file:
            table = Table.from_words(extract_words(text_file),
                                     prefix=args.prefix, flatten=args.flatten)
    tree = Tree.from_table(table, args.length, prefix=args.prefix,
                           flatten=args.flatten)
    if args.output:
        # The output was opened in binary mode by argparse: pickle the tree
        pickle.dump(tree, args.output)
    else:
        print(tree)
