#!python

from ase.io import write
from susmost.acutils import load_ac_list
from susmost.enumphases import enum_phases, enum_surface_supercells
import sys
from ase.geometry import get_distances
from ase.data.vdw_alvarez import vdw_radii
from itertools import combinations
import numpy as np

def enum_compositions(S, N):
	"""Lazily yield every way to split ``S`` into ``N`` ordered integer parts.

	Each yielded item is a list of length ``N``. For ``N == 1`` the only item
	is ``[S]``; otherwise the first part runs over ``0..S`` in increasing
	order and the remaining ``N - 1`` parts enumerate the rest of the total,
	so results come out in lexicographic order.

	``N`` must be positive; this is asserted when iteration starts, not when
	the generator object is created.
	"""
	assert N > 0, N
	if N > 1:
		for head in range(S + 1):
			# Fix the leading part, then distribute what is left over the tail.
			yield from ([head] + tail for tail in enum_compositions(S - head, N - 1))
		return
	# A single part has to take the whole total.
	yield [S]

def vdw_hsph_check(atoms):
	"""Return the tightest adsorbate-adsorbate contact relative to vdW radii.

	Atoms whose ``'adsorption'`` array entry differs from ``'s'`` are selected
	(presumably ``'s'`` tags substrate atoms -- confirm against the data
	files). For every pair of selected atoms the distance, computed with
	``pbc=True`` in ``atoms.cell``, is divided by the sum of the two atoms'
	van der Waals radii, and the smallest such ratio is returned.

	When no atom is selected the function returns the boolean ``True``
	instead of a number.
	"""
	is_adsorbate = atoms.arrays['adsorption'] != 's'
	ads = atoms[is_adsorbate]
	if len(ads) == 0:
		return True
	coords = ads.positions
	_, dist = get_distances(coords, coords, cell=atoms.cell, pbc=True)
	# Replace the zero self-distances with a large value so they never win the minimum.
	np.fill_diagonal(dist, 10000)
	radii = vdw_radii[ads.numbers]
	# Pairwise sums of radii: (n, 1) + (1, n) broadcasts to (n, n).
	contact = radii[:, None] + radii[None, :]
	return np.min(dist / contact)

# With no arguments, print the usage summary (covering every command the
# script accepts, including enum-vdw) and exit successfully.
if len(sys.argv) == 1:
  print (
    f"Enumerates phases of a 2D system (adsorption layer, 2-d material with impurities, etc)\nUsage:\n{sys.argv[0]} <command>\n<command> ::= "
    "\n\tsym <empty-surface.xyz> <ac1.xyz> <ac2.xyz> ... <acn.xyz> : count number of symmetries for each adsorption complex (from files <acX.xyz>) and empty surface (from file <empty-surface.xyz>)"
    "\n\tenum <N> <empty-surface.xyz> <ac1.xyz> ... <acn.xyz> : enumerate phases with supercell size up to <N> unit cells"
    "\n\tenum-vdw <N> <empty-surface.xyz> <ac1.xyz> ... <acn.xyz> : same as enum, but skip phases where the minimal distance between adsorbate atoms divided by the sum of their van der Waals radii is below 0.3"
  )
  sys.exit(0)

# Command-line parsing. The first argument selects the command; 'sym' takes
# only structure files, while 'enum' / 'enum-vdw' take the maximal supercell
# size <N> followed by the structure files.
cmd = sys.argv[1]
if cmd not in ('sym', 'enum', 'enum-vdw'):
  # Explicit check rather than assert, so validation also runs under `python -O`.
  sys.exit(f"{sys.argv[0]}: unknown command {cmd!r}; expected one of: sym, enum, enum-vdw")
if cmd == 'sym':
  ac_fns = sys.argv[2:]
else:
  if len(sys.argv) < 3:
    sys.exit(f"{sys.argv[0]}: command {cmd!r} requires <N> (maximal supercell size)")
  try:
    maxN = int(sys.argv[2])
  except ValueError:
    sys.exit(f"{sys.argv[0]}: <N> must be an integer, got {sys.argv[2]!r}")
  ac_fns = sys.argv[3:]


# Load every structure file and report, per entry, how many symmetry
# operations it has followed by the operations themselves.
acs = load_ac_list(ac_fns)
for ac in acs:
	symops = ac.symops
	print(f'{ac.name} symop count = {len(symops)}')
	for symop in symops:
		print(f"\t{symop}")

# The 'sym' command only needs the report above.
if cmd == 'sym':
	sys.exit(0)

# Enumerate phases for every supercell size 1..maxN, every supercell built
# from the first loaded entry, and every composition of N over the loaded
# entries; each accepted phase is written to its own extxyz file.
check_vdw = (cmd == 'enum-vdw')
idx = 0  # running count of written phases, used as the first filename field
for N in range(1, maxN + 1):
	for sc_idx, supercell in enumerate(enum_surface_supercells(N, acs[0])):
		for ic, c in enumerate(enum_compositions(N, len(acs))):
			print(f'N={N}, supercell: {supercell[:2, :2].ravel()} composition: {c}')
			for i, phase in enumerate(enum_phases(acs, supercell, c)):
				# debug aid: print ('\t',N,i,phase._values.ravel(), phase._positions[:,:2].ravel())
				sample = phase.surface_sample()
				if check_vdw:
					ratio = vdw_hsph_check(sample)
					# Recorded on the sample before the filter decision is made.
					sample.info['min_vdw_D'] = ratio
					if ratio < 0.3:
						# Smallest distance / vdW-radii-sum ratio too small: skip this phase.
						continue
				idx += 1
				fn = f"phase_{idx}_{N}_{sc_idx}_{ic}_{i}.xyz"
				print("\t", fn)
				write(fn, [sample], 'extxyz')

