#!python

from Cinema.Prompt import Launcher, Visualiser
import matplotlib.pyplot as plt
import numpy as np
import argparse
from Cinema.Interface.Utils import findData
import os

# Command-line options of the script, all registered on ``parser``.
parser = argparse.ArgumentParser()

# Geometry input; defaults to the empty string when no file is given.
parser.add_argument(
    '-g', '--gdml',
    dest='gdml',
    type=str,
    default='',
    help='input gdml file',
)

# Switch that turns on visualisation.
parser.add_argument(
    '-v', '--visualize',
    dest='visualize',
    action='store_true',
    help='flag to visualize gdml model',
)

# Random seed (integer, 4096 unless overridden).
parser.add_argument(
    '-s', '--seed',
    dest='seed',
    type=int,
    default=4096,
    help='random seed number',
)

# Neutron count; command-line values are parsed as float.
parser.add_argument(
    '-n', '--neutronNum',
    dest='neutronNum',
    type=float,
    default=100,
    help='neutron number',
)

# One or more names; stays None when the option is not given.
parser.add_argument(
    '-b', '--blacklist',
    dest='blacklist',
    type=str,
    nargs='+',
    help='solid mesh blacklist to inform the geometry mesh loader ',
)

# Mesh handling switches.
parser.add_argument(
    '-d', '--dumpmesh',
    dest='dumpmesh',
    action='store_true',
    help='dump mesh into disk',
)
parser.add_argument(
    '-m', '--mergemesh',
    dest='mergemesh',
    action='store_true',
    help='flag to merge mesh for efficient visualize',
)

# Stored under the name ``nSegments`` (integer, default 30).
parser.add_argument(
    '--nSeg',
    dest='nSegments',
    type=int,
    default=30,
    help='number of verts a volume',
)



# TODO: add an option selecting the geometry tree layers to be shown:
# parser.add_argument('-l', '--geoLayer', action='store', type=float, default=0,
#                     dest='geoLayer', help='geometry tree layers to be shown')


# Parse the command line and pull out the values used below.
args = parser.parse_args()
inputfile = args.gdml
printTraj = False
rdseed = args.seed

# Create the launcher and seed it before anything is loaded into it.
myLcher = Launcher()
myLcher.setSeed(rdseed)

if inputfile:
    # Use the path as given when it is an existing file; otherwise look the
    # name up through findData under 'gdml/' and check that result instead.
    if not os.path.isfile(inputfile):
        inputfile = findData(f'gdml/{inputfile}', '.')
        if not os.path.isfile(inputfile):
            raise IOError(f'The input GDML file {args.gdml} is not found.')
    myLcher.loadGeometry(inputfile)
else:
    # No GDML file was given on the command line.
    myLcher.loadFakeGeoPhysics()

# -n is parsed as a float, so convert it to an int once for both branches.
neutron_count = int(args.neutronNum)

if args.visualize:
    # Visualisation mode: run the neutrons one at a time with trajectory
    # recording on, and hand each trajectory to the visualiser.
    v = Visualiser(args.blacklist, printWorld=False, nSegments=args.nSegments,
                   mergeMesh=args.mergemesh, dumpMesh=args.dumpmesh)
    for i in range(neutron_count):
        myLcher.go(1, recordTrj=True, timer=False, save2Dis=False)
        trj = myLcher.getTrajectory()
        try:
            v.addTrj(trj)
        except ValueError as err:
            # Best effort: a trajectory that cannot be added is skipped so
            # the remaining ones are still shown. Report which one and why.
            print(f"skip trajectory {i}: ValueError raised by v.addTrj(trj): {err}")
    v.show()
else:
    # Plain run: all neutrons in a single call, no trajectory recording.
    myLcher.go(neutron_count, recordTrj=False)
