#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# Pavel Korshunov <pavel.korshunov@idiap.ch>
# Mon  18 Aug 23:12:22 CEST 2015
#
# Copyright (C) 2012-2015 Idiap Research Institute, Martigny, Switzerland
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, version 3 of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

"""
The script generates text file lists for grandtest protocol using the file names and
the directory structure of the voicePA database. These lists are later used to build an SQL-based
bob.db interface for voicePA database.
"""
from __future__ import print_function

import os
import math
import argparse
import numpy

def listFiles(rootdir="."):
    """Collect every ``.wav`` sample found below *rootdir*.

    Only directories whose path contains ``"male"`` are inspected; since
    ``"female"`` also contains ``"male"``, one test covers both genders.

    Args:
        rootdir: Directory that is walked recursively.

    Returns:
        A list of ``(stem, client_id)`` tuples.  ``stem`` is the sample path
        relative to *rootdir* with the ``.wav`` extension removed.
        ``client_id`` is the four characters that follow the first
        ``"male"`` plus one separator character in the walked directory path.

    Note:
        The gender marker is searched in the full walked path (which starts
        with *rootdir*), so a *rootdir* that itself contains ``"male"`` will
        shift the extracted client id.
    """
    gender = "male"  # "female" also contains "male", so we cover both
    objs = []
    for root, _dirs, files in os.walk(rootdir):
        if gender not in root:
            continue

        # os.path.relpath copes with a trailing separator on rootdir and with
        # rootdir occurring several times inside root (e.g. the default "."
        # combined with a dotted directory name); splitting root on rootdir
        # silently dropped a path component or raised in those cases.
        relative = os.path.relpath(root, rootdir)
        if relative == os.path.curdir:
            subpath = []
        else:
            subpath = relative.split(os.path.sep)

        # The client id starts right after "<gender><separator>".
        start_id_indx = root.index(gender) + len(gender) + 1
        client_id = root[start_id_indx:start_id_indx + 4]

        for sample in files:
            stem, ext = os.path.splitext(sample)
            if ext == ".wav":
                objs.append((os.path.join(*(subpath + [stem])), client_id))
    return objs


def generateIds(rootdir=".", gender="male", filter="genuine"):
    """Return the sub-directory names of the first matching gender folder.

    The tree below *rootdir* is walked top-down.  The first directory whose
    base name equals *gender* and whose path contains *filter* wins, and the
    names of its immediate sub-directories are returned.

    Args:
        rootdir: Directory that is walked recursively.
        gender: Exact base name of the directory to look for.
        filter: Substring that must occur in the directory path.

    Returns:
        The list of sub-directory names of the first match, or an empty list
        when no directory matches.
    """
    for current, subdirs, _files in os.walk(rootdir):
        if filter not in current:
            continue
        if os.path.basename(current) == gender:
            return subdirs
    return []


def getSubIds(ids, set="train"):
    """Return the part of *ids* that belongs to the requested subset.

    The ids are split, in their given order, into three consecutive chunks.
    ``"train"`` and ``"dev"`` each receive ``len(ids) // 3`` entries and
    ``"eval"`` receives everything that is left, so with fewer than three
    ids only ``"eval"`` is non-empty.

    The split had been commented out, which made the function return an
    empty list for every input; it is restored here.

    Args:
        ids: Sequence of client ids.
        set: Name of the subset: ``"train"``, ``"dev"`` or ``"eval"``.

    Returns:
        A new list with the ids of the subset.  The list is empty when *ids*
        is empty or *set* is not one of the three known names.
    """
    if not ids:
        return []

    sublen = len(ids) // 3

    if set == "train":
        return ids[:sublen]
    if set == "dev":
        return ids[sublen:2 * sublen]
    if set == "eval":
        return ids[2 * sublen:]
    return []


#  numpy.choice(ids)


def command_line_arguments(command_line_parameters):
    """Parse the program options.

    Args:
        command_line_parameters: List of argument strings, or ``None`` to let
            argparse read ``sys.argv``.

    Returns:
        The argparse namespace; ``database_directory`` is ``None`` when the
        option is not given.
    """
    # The module docstring doubles as the description shown by ``--help``.
    option_parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description=__doc__,
    )
    option_parser.add_argument(
        '-d', '--database-directory',
        help="The root directory of database data.",
        required=False,
    )
    return option_parser.parse_args(command_line_parameters)

def readIds(filename):
    """Read the client list stored in *filename*.

    Each line is parsed by ``numpy.genfromtxt`` into two columns named
    ``id`` and ``set_label``.

    Args:
        filename: Path of the text file to read.

    Returns:
        A one-dimensional structured numpy array with the fields ``id`` and
        ``set_label``.  Byte-string columns are converted to unicode so the
        values compare equal to ordinary ``str`` objects.
    """
    def convertfunc(x):
        # Identity converter -- presumably here to keep genfromtxt from
        # interpreting the columns as numbers; verify before removing.
        return x

    names = ('id', 'set_label')
    converters = {0: convertfunc, 1: convertfunc}

    # The context manager closes the handle; it used to be left open.
    with open(filename, "r") as client_file:
        ids = numpy.genfromtxt(client_file, dtype=None, names=names,
                               converters=converters, invalid_raise=True)

    # Turn byte-string fields ('S') into unicode fields ('U').
    new_dtype = [(name, str(ids.dtype[name]).replace('S', 'U'))
                 for name in ids.dtype.names]

    # A single-line file yields a 0-d array; promote it so that callers can
    # always iterate over and mask the result.
    return numpy.atleast_1d(numpy.array(ids, new_dtype))

#    ids = []
#    with open("clients%s.txt" % (clientset), "w") as f:
#        lines = f.readlines()
#        for line in lines:
#            if gender in line:
#                ids += line.split()
#    return ids

def _attack_in_protocol(protocol, stem):
    """Tell whether the attack sample path *stem* belongs to *protocol*."""
    # '-ss-' / '-vc-' mark the samples that the 'synthetic' protocol selects
    # (presumably speech synthesis and voice conversion -- TODO confirm).
    synthetic = '-ss-' in stem or '-vc-' in stem

    if protocol in ('smalltest', 'grandtest'):
        return True
    if protocol == 'avspoofPA':
        return 'r107' in stem
    if protocol == 'r106':
        return 'r106' in stem
    if protocol == 'seboffice':
        return 'seboffice' in stem
    if protocol == 'mobile':
        # A single test, so a path naming both rooms is listed only once.
        return 'r106' in stem or 'seboffice' in stem
    if protocol == 'replay':
        return not synthetic
    if protocol == 'synthetic':
        return synthetic
    if protocol == 'samsungs3':
        return 'r106-samsungs3' in stem or 'seboffice-samsungs3' in stem
    if protocol == 'iphone3gs':
        return 'r106-iphone3gs' in stem or 'seboffice-iphone3gs' in stem
    return False


def main(command_line_parameters=None):
    """Traverse the voicePA folder structure and write the file lists.

    The client ids and their subset labels are read from
    ``clients-smalltest.txt`` in the current working directory.  For every
    subset (train, dev, eval), data type and protocol, one text file is
    written to the current working directory: ``real-<protocol>-<set>.txt``
    for genuine data and ``attack-<protocol>-<set>.txt`` for attacks.  Each
    file holds one sample path per line, without a trailing newline.

    Args:
        command_line_parameters: List of argument strings, or ``None`` to use
            ``sys.argv``.
    """
    args = command_line_arguments(command_line_parameters)

    rootdir = args.database_directory or os.path.curdir

    # Add "attack" here to also generate the attack lists.
    data_types = ["genuine"]
    # Other protocols understood by _attack_in_protocol: 'grandtest',
    # 'avspoofPA', 'r106', 'seboffice', 'replay', 'synthetic', 'mobile',
    # 'samsungs3', 'iphone3gs'.
    protocol_names = ['smalltest']
    sets = ["train", "dev", "eval"]
    clientset = "-smalltest"

    # read IDs from the provided client lists
    client_ids = readIds("clients%s.txt" % (clientset))

    # traverse database and construct the whole list of files
    file_list = listFiles(rootdir)

    for subset in sets:
        set_ids = client_ids['id'][client_ids['set_label'] == subset]
        print("set %s, set_ids: %s " % (subset, " ".join(set_ids)))

        # Hash-based lookup instead of scanning the numpy array per file.
        id_lookup = frozenset(set_ids.tolist())

        for data_type in data_types:
            # samples of this subset's clients for the current data type
            candidates = [stem for stem, client_id in file_list
                          if client_id in id_lookup and data_type in stem]

            for protocol in protocol_names:
                if data_type == "attack":
                    list4set = [stem for stem in candidates
                                if _attack_in_protocol(protocol, stem)]
                    filename = "%s-%s-%s.txt" % (data_type, protocol, subset)
                else:
                    list4set = candidates
                    filename = "real-%s-%s.txt" % (protocol, subset)

                # write file names either into real or attack file list
                with open(filename, "w") as f:
                    f.write('\n'.join(list4set))


# Run the list generation only when executed as a script, not on import.
if __name__ == "__main__":
    main()
