#!python

#
# Copyright 2019-2020 NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


"""Generate a reference and some noisy reads

Generates 3 files:
    1. A FASTA reference file (default name ref.fasta)
    2. A FASTA reads file (default name reads.fasta)
    3. A PAF overlaps file with all-to-all overlaps for the reads (default name overlaps.paf)

Example usage:
    genome_simulator --reference_length 2700000 --num_reads 54000 --median_read_length=10000
"""

from __future__ import print_function
from __future__ import division

import argparse
import functools
import gzip
import multiprocessing
import random
import uuid

from tqdm import tqdm

from genomeworks import simulators
from genomeworks.io import fastaio
from genomeworks.io import pafio
from genomeworks.simulators import genomesim
from genomeworks.simulators import readsim


def _readgen(num_reads, args, read_generator, reference_string):
    """Generate a batch of noisy reads from the reference.

    Args:
        num_reads: number of reads to generate in this batch.
        args: parsed options providing ``median_read_length``,
            ``snv_error_rate``, ``insertion_error_rate`` and
            ``deletion_error_rate``.
        read_generator: object exposing ``generate_read``.
        reference_string: reference sequence handed to ``generate_read``.

    Returns:
        A list of ``(read_id, compressed_read, start, end)`` tuples, where
        ``compressed_read`` is the gzip-compressed UTF-8 encoding of the read
        and ``start``/``end`` are the values returned by ``generate_read``
        (presumably the read's coordinates on the reference - confirm).
    """
    def _simulate_one():
        sequence, ref_start, ref_end = read_generator.generate_read(
            reference_string,
            median_length=args.median_read_length,
            snv_error_rate=args.snv_error_rate,
            insertion_error_rate=args.insertion_error_rate,
            deletion_error_rate=args.deletion_error_rate,
        )
        # Reads are kept gzip-compressed; each one gets a random UUID-based id.
        compressed = gzip.compress(bytes(sequence, 'utf-8'))
        identifier = "read_" + str(uuid.uuid4())
        return (identifier, compressed, ref_start, ref_end)

    return [_simulate_one() for _ in range(num_reads)]


def main():
    """Simulate a reference genome and noisy reads and write them to disk.

    Parses the command-line options, builds a reference sequence with
    ``MarkovGenomeSimulator``, generates reads from it in a pool of worker
    processes, computes overlaps between the reads and writes three files:
    the PAF overlaps, the FASTA reference and the FASTA reads.

    Exits through ``parser.error`` (status 2) when ``--num_threads`` is
    smaller than 1.
    """
    parser = argparse.ArgumentParser(description="Create a reference and some reads")

    parser.add_argument('--reference_length',
                        type=int,
                        default=int(1e6),
                        help="Length of the simulated reference")
    parser.add_argument('--reference_filepath',
                        type=str,
                        default='ref.fasta',
                        help="Output path of the reference FASTA file")
    parser.add_argument('--reads_filepath',
                        type=str,
                        default='reads.fasta',
                        help="Output path of the reads FASTA file")
    parser.add_argument('--paf_filepath',
                        type=str,
                        default='overlaps.paf',
                        help="Output path of the PAF overlaps file")
    parser.add_argument('--random_seed',
                        type=int,
                        default=0,
                        help="Seed passed to random.seed()")
    parser.add_argument('--median_read_length',
                        type=int,
                        default=10000,
                        help="Median read length passed to the read simulator")
    parser.add_argument('--snv_error_rate',
                        type=float,
                        default=0.025,
                        help="SNV error rate passed to the read simulator")
    parser.add_argument('--insertion_error_rate',
                        type=float,
                        default=0.0125,
                        help="Insertion error rate passed to the read simulator")
    parser.add_argument('--deletion_error_rate',
                        type=float,
                        default=0.0125,
                        help="Deletion error rate passed to the read simulator")
    parser.add_argument('--num_reads',
                        type=int,
                        default=100,
                        help="Number of reads to simulate")
    parser.add_argument('--num_threads',
                        type=int,
                        default=multiprocessing.cpu_count(),
                        help="Number of worker processes")

    args = parser.parse_args()

    # multiprocessing.Pool rejects fewer than one process; fail early with a
    # usage message instead of a traceback.
    if args.num_threads < 1:
        parser.error("--num_threads must be at least 1")

    # NOTE(review): this seeds the `random` module of this process only;
    # whether reads generated in worker processes are reproducible depends on
    # how the workers' random state is initialised - confirm.
    random.seed(args.random_seed)

    genome_simulator = genomesim.MarkovGenomeSimulator()
    reference_string = genome_simulator.build_reference(args.reference_length,
                                                        transitions=simulators.HIGH_GC_HOMOPOLYMERIC_TRANSITIONS,
                                                        num_threads=args.num_threads)

    reference = [('Reference', reference_string)]
    read_generator = readsim.NoisyReadSimulator()
    readgen = functools.partial(_readgen, args=args, read_generator=read_generator, reference_string=reference_string)

    # Split the requested reads into up to 10 tasks per process, spreading the
    # remainder over the first tasks. Tasks that would generate no reads
    # (num_reads smaller than the task count) are dropped.
    num_workers = args.num_threads * 10
    quotient, remainder = divmod(args.num_reads, num_workers)
    reads_per_worker = [quotient + int(i < remainder) for i in range(num_workers)]
    reads_per_worker = [count for count in reads_per_worker if count > 0]

    print('Simulating reads:')
    # The context manager shuts the worker processes down once all results
    # have been collected (or if an error interrupts the simulation).
    with multiprocessing.Pool(args.num_threads) as pool:
        batches = tqdm(pool.imap(readgen, reads_per_worker), total=len(reads_per_worker))
        reads = [val for sublist in batches for val in sublist]

    overlaps = readsim.generate_overlaps(reads, gzip_compressed=True)

    # Write the overlaps
    pafio.write_paf(overlaps, args.paf_filepath)

    # Write the reference
    fastaio.write_fasta(reference, args.reference_filepath)

    # Write the reads
    fastaio.write_fasta(reads, args.reads_filepath, gzip_compressed=True)


# Run the simulator only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
