#
# scmver.core
#
#   Copyright (c) 2019-2021 Akinori Hattori <hattya@gmail.com>
#
#   SPDX-License-Identifier: MIT
#

import datetime
import importlib
import os
import re
import sys
import textwrap
from typing import cast, Any, Callable, Dict, Mapping, NamedTuple, Optional, Pattern, Tuple, Union


__all__ = [
    'generate',
    'get_version',
    'load_version',
    'next_version',
    'stat',
    'SCMInfo',
    'Version',
    'VersionError',
]

# (label, number) view of a pre/post/dev segment, as returned by the
# Version.pre, Version.post and Version.dev properties.
SV = Tuple[str, int]
# (separator, label, optional separator, number): the form in which Version
# stores its pre/post/dev segments internally.
ISV = Tuple[str, str, str, int]

_TEMPLATE = textwrap.dedent("""\
    # file generated by scmver; DO NOT EDIT.

    version = '{version}'
""")

# Parser for PEP 440 version strings, optionally prefixed with "v". Besides
# the canonical spellings it accepts the alternative spellings PEP 440 allows:
# alpha/beta/c/pre/preview for pre-releases, rev/r for post-releases, the
# implicit "-N" post-release form, and "-", "_" or "." separators. The
# separators are captured in the *_sep / *_opt_sep groups so that Version can
# reproduce them. The post-release number may be omitted when a post label
# (post_s) matched, but is required for the implicit "-N" form (conditional
# group on post_s).
_pep440_re = re.compile(r"""
    \A
    v?
    # public version identifiers
    (?:             # epoch segment
        (?P<epoch>[0-9]+) !
    )?
    (?P<release>    # release segment
        [0-9]+ (?:\. [0-9]+)*
    )
    (?:             # pre-release segment
        (?P<pre_sep>[-._])?
        (?P<pre_s>
            a  | alpha |
            b  | beta  |
            rc | c     | pre (?:view)?
        )
        (?P<pre_opt_sep>[-._])?
        (?P<pre_n>[0-9]*)
    )?
    (?:             # post-release segment
        (?:
            (?:
                (?P<post_sep>[-._])?
                (?P<post_s>
                    post | r (?:ev)?
                )
                (?P<post_opt_sep>[-._])?
            ) |
            -
        )
        (?P<post_n>(?(post_s)[0-9]* | [0-9]+))
    )?
    (?:             # development release segment
        (?P<dev_sep>[-._])?
        (?P<dev_s>
            dev
        )
        (?P<dev_opt_sep>[-._])?
        (?P<dev_n>[0-9]*)
    )?
    # local version identifiers
    (?:
        \+
        (?P<local>
            [a-z0-9] (?:[a-z0-9-_.]* [a-z0-9])?
        )
    )?
    \Z
""", re.IGNORECASE | re.VERBOSE)
# Separators accepted between local version segments; Version.normalize()
# splits on these and rejoins with ".".
_sep_re = re.compile(r'[-._]')
# Default pattern next_version() uses to extract the version from an SCM tag.
# Searching matches from the first digit (including a "v" directly before it)
# to the end of the tag, e.g. "1.0" from "foo-1.0".
_version_re = re.compile(r'(?P<version>v?\d+.*)\Z')


def generate(path: str, version: Optional[str], info: Optional['SCMInfo'] = None, template: str = _TEMPLATE) -> None:
    """Render *template* and write it to *path*, overwriting any existing file.

    The template is rendered with ``str.format()``. ``{version}`` is always
    available; ``{revision}`` and ``{branch}`` are available when *info* is
    given.

    The file is written as UTF-8. The default template produces a Python
    module, and Python decodes source files as UTF-8 unless they declare
    otherwise. Writing with the platform's locale encoding could therefore
    produce a module that fails to import, e.g. when the branch name is not
    ASCII.
    """
    kwargs: Dict[str, Any] = {'version': version}
    if info:
        kwargs.update(revision=info.revision,
                      branch=info.branch)
    with open(path, 'w', encoding='utf-8') as fp:
        fp.write(template.format(**kwargs))


def get_version(root: str = '.', **kwargs: Any) -> Optional[str]:
    """Return the version of the working copy that contains *root*.

    Recognized options in *kwargs*:

    * keys ending in ``.tag`` are forwarded to stat();
    * ``spec``, ``local`` and ``version`` are forwarded to next_version();
    * ``write_to``: a path joined to *root*; when present, the computed
      version is also written there with generate() (passing ``template``
      if given);
    * ``fallback``: used only when no working copy is found. It may be a
      callable returning the version, a ``'module:attr'`` spec for
      load_version(), or a ``(spec, path)`` pair whose path is joined to
      *root* and used as the import location.

    Returns None when no working copy is found and no fallback is given.
    """
    def pick(options: Mapping[str, Any], *names: str) -> Dict[str, Any]:
        # the entries of *options* whose key is one of *names*, in the order of *options*
        return {name: options[name] for name in options if name in names}

    root = os.path.abspath(root)
    tag_options = {key: kwargs[key] for key in kwargs if key.endswith('.tag')}
    info = stat(root, **tag_options)
    if not info:
        if 'fallback' not in kwargs:
            return None
        fallback = kwargs['fallback']
        if callable(fallback):
            return cast(str, fallback())
        if isinstance(fallback, str):
            return load_version(fallback, None)
        return load_version(fallback[0], os.path.join(root, fallback[1]))

    version = next_version(info, **pick(kwargs, 'spec', 'local', 'version'))
    if 'write_to' in kwargs:
        target = os.path.join(root, kwargs['write_to'])
        generate(target, version, info, **pick(kwargs, 'template'))
    return version


def load_version(spec: str, path: Optional[str] = None) -> str:
    """Load a version from a ``'module:attribute'`` *spec*.

    *attribute* may be a dotted path (e.g. ``'pkg.mod:obj.version'``). If the
    object it resolves to is callable, it is called and its result returned.
    When *path* is given, it is appended to ``sys.path`` for the duration of
    the import only.

    Raises ValueError if *spec* does not contain exactly one ``':'``.
    """
    v = spec.split(':')
    if len(v) != 2:
        raise ValueError('invalid format')

    module_name, attr_path = v
    if path:
        sys.path.append(path)
        try:
            o = importlib.import_module(module_name)
        finally:
            # Remove the entry appended above. Search from the end instead of
            # blindly popping: the imported module may have appended its own
            # entries to sys.path, and those must be left untouched.
            for i in range(len(sys.path) - 1, -1, -1):
                if sys.path[i] == path:
                    del sys.path[i]
                    break
    else:
        o = importlib.import_module(module_name)

    for a in attr_path.split('.'):
        o = getattr(o, a)
    return cast(str, o() if callable(o) else o)


def next_version(info: 'SCMInfo', spec: str = 'post', local: Union[str, Callable[..., Optional[str]]] = '{local:%Y-%m-%d}', version: Pattern[str] = _version_re) -> Optional[str]:
    """Compute the version string for *info*.

    The version is taken from the ``version`` group of *version* searched in
    ``info.tag`` and parsed as a Version. When ``info.distance`` is positive,
    it is bumped with ``Version.update(spec, info.distance)``. Then a local
    version label is determined:

    * if *local* is callable, ``local(info)`` is used (whether or not the
      working copy is dirty);
    * otherwise, only when ``info.dirty``, *local* is formatted with
      ``distance``, ``revision``, ``branch``, ``utc`` (naive UTC datetime)
      and ``local`` (naive local datetime).

    A false label is omitted; otherwise it is appended after ``+``.

    Raises VersionError if the tag contains no version, or if the version is
    invalid or cannot be updated with *spec*.
    """
    m = version.search(info.tag)
    if not m:
        raise VersionError('cannot parse version from SCM tag')

    pv = Version(m.group('version'))
    if info.distance > 0:
        pv.update(spec, info.distance)

    lv: Optional[str]
    if callable(local):
        lv = local(info)
    elif info.dirty:
        # datetime.utcnow() is deprecated since Python 3.12; this produces the
        # same naive UTC datetime without the DeprecationWarning.
        utc = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        lv = local.format(distance=info.distance,
                          revision=info.revision,
                          branch=info.branch,
                          utc=utc,
                          local=datetime.datetime.now())
    else:
        lv = None
    return str(pv) if not lv else f'{pv}+{lv}'


def stat(path: str, **kwargs: Any) -> Optional['SCMInfo']:
    """Find and parse the working copy at *path* or one of its ancestors.

    Parsers come from the ``scmver.parse`` entry point group, where the entry
    point name is the marker file/directory to look for. The built-in parsers
    are used instead when pkg_resources cannot be imported, when no entry
    points are registered, or when loading one raises ImportError.

    Starting at *path* and moving up to the filesystem root, every parser
    whose marker exists in the current directory is tried in order. The first
    truthy result is returned. A false value for a marker name in *kwargs*
    disables that parser. All *kwargs* are also forwarded to the parsers.
    Returns None when nothing is found.
    """
    parsers: Tuple[Tuple[str, Callable[..., Optional[SCMInfo]]], ...]
    try:
        import pkg_resources

        parsers = tuple((entry.name, entry.load())
                        for entry in pkg_resources.iter_entry_points('scmver.parse'))
        if not parsers:
            raise ImportError
    except ImportError:
        from . import bazaar as bzr, fossil as fsl, git, mercurial as hg, subversion as svn

        parsers = (('.bzr', bzr.parse),
                   ('.fslckout', fsl.parse),
                   ('_FOSSIL_', fsl.parse),
                   ('.git', git.parse),
                   ('.hg', hg.parse),
                   ('.hg_archival.txt', hg.parse),
                   ('.svn', svn.parse))

    current = os.path.abspath(path)
    while True:
        for marker, parse in parsers:
            if not kwargs.get(marker, True):
                continue  # disabled by the caller
            if not os.path.exists(os.path.join(current, marker)):
                continue
            info = parse(current, name=marker, **kwargs)
            if info:
                return info
        parent = os.path.dirname(current)
        if parent == current:
            return None  # dirname() no longer changes: filesystem root reached
        current = parent


class SCMInfo(NamedTuple):
    """Working-copy information returned by the SCM parsers that stat() runs."""

    tag: str = '0.0'  # next_version() extracts the base version from this tag
    distance: int = 0  # when > 0, next_version() bumps the version by this amount
    revision: Optional[Union[int, str]] = None  # available to generate() templates and the local label format
    dirty: bool = False  # a string `local` format in next_version() is applied only when true
    branch: Optional[str] = None  # available to generate() templates and the local label format


class Version:
    """A PEP 440 version that remembers how its segments were spelled.

    ``str()`` keeps the separators and labels of the pre-, post- and
    development-release segments as they appeared in the input (e.g.
    ``1.0-alpha_2``), while ``normalize()`` returns the canonical PEP 440
    form. ``update()`` bumps a segment in place.

    A leading ``v`` and surrounding whitespace are accepted. VersionError is
    raised if the string is not a valid PEP 440 version.
    """

    __slots__ = ('epoch', 'release', '_pre', '_post', '_dev', 'local')

    # Each segment is None when absent, otherwise a tuple of
    # (separator, label, optional separator, number). A number of -1 means
    # the number was omitted (e.g. ``1.0a``); a label of None marks the
    # implicit post-release form (e.g. ``1.0-1``).
    _pre: Optional[ISV]
    _post: Optional[ISV]
    _dev: Optional[ISV]

    def __init__(self, version: str) -> None:
        m = _pep440_re.match(version.strip())
        if not m:
            raise VersionError(f'invalid version: {version!r}')

        self.epoch = int(m.group('epoch')) if m.group('epoch') else 0
        self.release = tuple(map(int, m.group('release').split('.')))
        for g in ('pre', 'post', 'dev'):
            s = m.group(g + '_s')
            n = m.group(g + '_n')
            # keep separators and label verbatim; an empty number is stored as -1
            setattr(self, '_' + g,
                    (m.group(g + '_sep') or '', s, m.group(g + '_opt_sep') or '', int(n) if n else -1)
                    if s or n else None)
        self.local = m.group('local')

    def __repr__(self) -> str:
        return f'<{self.__class__.__name__}({self})>'

    def __str__(self) -> str:
        """Render the version using the stored spelling of each segment."""
        def seg(v: ISV) -> Tuple[str, ...]:
            # an omitted number (-1) renders as nothing
            return (v[0], v[1], v[2], str(v[3]) if v[3] >= 0 else '')

        buf = []
        if self.epoch != 0:
            buf.append(f'{self.epoch}!')
        buf.append('.'.join(map(str, self.release)))
        if self._pre:
            buf += seg(self._pre)
        if self._post:
            if self._post[1]:
                buf += seg(self._post)
            else:
                # implicit post-release form: "1.0-1"
                buf.append(f'-{self._post[3]}')
        if self._dev:
            buf += seg(self._dev)
        if self.local:
            buf.append(f'+{self.local}')
        return ''.join(buf)

    @property
    def pre(self) -> Optional[SV]:
        """(label, number) of the pre-release segment, or None; number is -1 if omitted."""
        return self._pre[1::2] if self._pre else None

    @property
    def post(self) -> Optional[SV]:
        """(label, number) of the post-release segment, or None; label is None for "-N"."""
        return self._post[1::2] if self._post else None

    @property
    def dev(self) -> Optional[SV]:
        """(label, number) of the development-release segment, or None."""
        return self._dev[1::2] if self._dev else None

    def normalize(self) -> 'Version':
        """Return a new Version in PEP 440 normalized form; *self* is unchanged.

        Labels are lower-cased and canonicalized (``alpha`` -> ``a``,
        ``beta`` -> ``b``, ``c``/``pre``/``preview`` -> ``rc``, any
        post-release spelling -> ``.post``, ``.dev``), an omitted number
        becomes 0, and local segments are joined with ``.`` with leading zeros
        stripped from numeric segments.
        """
        def seg(s: str, v: ISV, sep: str = '.') -> ISV:
            return (sep, s, '', v[3] if v[3] >= 0 else 0)

        v = self.__class__(str(self).lower())
        if v._pre:
            s = v._pre[1]
            if s == 'alpha':
                s = 'a'
            elif s == 'beta':
                s = 'b'
            elif s in ('c', 'pre', 'preview'):
                s = 'rc'
            v._pre = seg(s, v._pre, sep='')
        if v._post:
            v._post = seg('post', v._post)
        if v._dev:
            v._dev = seg('dev', v._dev)
        if v.local:
            v.local = '.'.join(str(int(s)) if s.isdigit() else s for s in _sep_re.split(v.local))
        return v

    def update(self, spec: str, value: int = 1) -> None:
        """Bump the segment selected by *spec* by *value*, in place.

        *spec* is case-insensitive and is one of:

        * ``major``, ``minor``, ``micro``/``patch``: add *value* to that
          release component (clamped at 0; missing components are appended)
          and drop the pre-, post- and development-release segments;
        * ``pre``, ``dev``: add *value* to the number of the existing segment;
        * ``post``: add *value* to the post-release number, or append a
          ``.post`` segment when there is none and *value* >= 0;
        * ``major.dev``, ``minor.dev``, ``micro.dev``/``patch.dev``: zero the
          release components after the selected one, bump the selected one
          by 1 and append a ``.dev`` segment numbered from *value*.

        Raises VersionError if the version has a local version label, if the
        ``pre``/``dev`` segment to bump does not exist, if *spec* is not
        recognized, or if *value* is negative for a ``*.dev`` spec.
        """
        if self.local:
            raise VersionError('local version identifiers exists')

        def update(ver: int, val: int) -> int:
            # an omitted number (-1) counts as 0; a non-positive result from
            # it keeps the number omitted
            if ver < 0:
                return val if val > 0 else -1
            return ver + val

        def zero(v: int) -> int:
            # clamp negative results to 0
            return v if v >= 0 else 0

        spec = spec.lower()
        if spec == 'major':
            self.release = (zero(self.release[0] + value),) + self.release[1:]
            self._pre = self._post = self._dev = None
        elif spec == 'minor':
            if len(self.release) < 2:
                self.release += (zero(value),)
            else:
                self.release = self.release[:1] + (zero(self.release[1] + value),) + self.release[2:]
            self._pre = self._post = self._dev = None
        elif spec in ('micro', 'patch'):
            if len(self.release) < 2:
                self.release += (0, zero(value),)
            elif len(self.release) < 3:
                self.release += (zero(value),)
            else:
                self.release = self.release[:2] + (zero(self.release[2] + value),) + self.release[3:]
            self._pre = self._post = self._dev = None
        elif spec in ('pre', 'dev'):
            v = getattr(self, '_' + spec)
            if not v:
                raise VersionError(f'{"pre-" if spec != "dev" else "development "}release segment does not exist')
            setattr(self, '_' + spec, v[:3] + (update(v[3], value),))
        elif spec == 'post':
            if self._post:
                if self._post[1]:
                    self._post = self._post[:3] + (update(self._post[3], value),)
                else:
                    # implicit "-N" form always carries an explicit number
                    self._post = self._post[:3] + (zero(self._post[3] + value),)
            elif value >= 0:
                # NOTE(review): a value of 0 or 1 both yield an omitted number
                # (".post"), so a fresh segment never gets number 1 -- confirm intended.
                self._post = ('.', 'post', '', value if value > 1 else -1)
        elif spec.endswith('.dev'):
            spec = spec[:-len('.dev')]
            if spec == 'major':
                i = 1
            elif spec == 'minor':
                i = 2
            elif spec in ('micro', 'patch'):
                i = 3
            else:
                raise VersionError('invalid segment specifier')
            if value < 0:
                raise VersionError('invalid value')

            self.release = self.release[:i] + (0,) * (len(self.release) - i)
            self.update(spec)
            # NOTE(review): as for "post", a value of 0 or 1 yields ".dev" -- confirm intended.
            self._dev = ('.', 'dev', '', value if value > 1 else -1)
        else:
            # Reject unknown specifiers instead of silently leaving the version
            # unchanged, which would hide typos in a caller's configuration.
            raise VersionError('invalid segment specifier')


class VersionError(ValueError):
    """Raised when a version string is invalid or a Version cannot be updated as requested."""
