# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in the project root
# for license information.

from __future__ import absolute_import, print_function, unicode_literals

import functools
import inspect
import io
import platform
import os
import sys
import threading
import traceback

import ptvsd
from ptvsd.common import compat, fmt, options, timestamp


LEVELS = ("debug", "info", "warning", "error")
"""Logging levels, lowest to highest importance.
"""

# The stream that stderr-bound messages are written to. It is bound once, at
# import time, to sys.__stderr__ rather than sys.stderr, so rebinding sys.stderr
# later does not change where these messages go.
stderr = sys.__stderr__

stderr_levels = {"warning", "error"}
"""What should be logged to stderr.
"""

file_levels = set(LEVELS)
"""What should be logged to file, when it is not None.
"""

filename_prefix = "ptvsd"
"""Prefix for log file names that are automatically generated by to_file().
"""

file = None
"""If not None, which file to log to.

This can be automatically set by to_file().
"""

timestamp_format = "09.3f"
"""Format spec used for timestamps. Can be changed to dial precision up or down.
"""

# Held by newline() and write() while they write to stderr and to the log file.
_lock = threading.Lock()

# Thread-local storage. NOTE(review): not referenced by the code visible in this
# part of the file - presumably used elsewhere; confirm before removing.
_tls = threading.local()

# Name of the log file picked by to_file(); exposed to callers via filename().
_filename = None


def newline(level="info"):
    """Writes a bare newline to stderr, if messages of this level go to stderr.

    Used to clean up the output when log messages written to stderr are intermixed
    with regular prints from other sources. Nothing is written to the log file,
    and a failure to write to stderr is ignored.
    """
    with _lock:
        if level not in stderr_levels:
            return
        try:
            stderr.write("\n")
        except Exception:
            # Best-effort only - logging must never break the caller.
            pass


def write(level, text):
    """Writes text as a single log message of the given level.

    The message is prefixed with the upper-cased first letter of the level and
    the value of timestamp.current(), formatted using timestamp_format. Lines
    after the first are indented to line up with the first one, and a blank line
    is appended after the message.

    The message goes to stderr if level is in stderr_levels, and to the log file
    if one is set and level is in file_levels. Failures to write are ignored.

    Returns text, unchanged.
    """
    assert level in LEVELS

    now = timestamp.current()
    prefix = fmt("{0}+{1:" + timestamp_format + "}: ", level[0].upper(), now)

    # Continuation lines get as much leading whitespace as the prefix is wide.
    continuation = "\n" + " " * len(prefix)
    output = prefix + text.replace("\n", continuation) + "\n\n"

    with _lock:
        if level in stderr_levels:
            try:
                stderr.write(output)
            except Exception:
                # Best-effort only - logging must never break the caller.
                pass

        if file and level in file_levels:
            try:
                file.write(output)
                # Flush right away, so that the message is not left in a buffer.
                file.flush()
            except Exception:
                pass

    return text


def write_format(level, format_string, *args, **kwargs):
    """Formats a message with fmt(), and logs it at the given level via write().

    For levels other than "error", if the message would go neither to stderr nor
    to the log file, it is not formatted at all, and None is returned. Errors are
    always formatted, so that error() can return the text even if it's not logged.

    If formatting fails, the failure is logged via exception(), and then re-raised.

    Returns the formatted text, as returned by write().
    """
    if level != "error":
        # Don't spend cycles doing expensive formatting if nobody will see it.
        is_logged = level in stderr_levels or (file and level in file_levels)
        if not is_logged:
            return

    try:
        text = fmt(format_string, *args, **kwargs)
    except Exception:
        exception()
        raise
    else:
        return write(level, text)


# Level-specific shortcuts for write_format(). Each takes a format string followed
# by the arguments to format it with and, like write_format(), returns the
# formatted text - or None, if messages of that level are not logged anywhere.
debug = functools.partial(write_format, "debug")
info = functools.partial(write_format, "info")
warning = functools.partial(write_format, "warning")


def error(*args, **kwargs):
    """Logs a message at the "error" level.

    The arguments are the same as for write_format(), minus the level.

    Returns an AssertionError wrapping the formatted message, so that the caller
    can log and raise in a single statement::

        raise log.error(...)

    which has the same effect as::

        log.error(...)
        assert False, fmt(...)
    """
    message = write_format("error", *args, **kwargs)
    return AssertionError(message)


def stack(title="Stack trace"):
    """Logs the current call stack as a "debug" message, headed by title."""
    frames = traceback.format_stack()
    debug("{0}:\n\n{1}", title, "\n".join(frames))


def exception(format_string="", *args, **kwargs):
    """Logs an exception with its full traceback, and the stack where it was logged.

    If format_string is not empty, it is formatted with fmt(*args, **kwargs), and
    the result precedes the exception traceback, separated from it by a blank line.

    Keyword arguments that are consumed here, rather than used for formatting:

    exc_info
        The exception to log, as a (type, value, traceback) tuple. The default
        is sys.exc_info() - i.e. the exception being handled currently.

    level
        The level of the logged message. The default is "error".

    Returns the exception object, for convenient re-raising::

        try:
            ...
        except Exception:
            raise log.exception()  # log it and re-raise
    """

    level = kwargs.pop("level", "error")
    exc_info = kwargs.pop("exc_info", sys.exc_info())

    separator = "\n\n" if format_string else ""
    format_string = (
        format_string + separator + "{exception}\nStack where logged:\n{stack}"
    )

    exception = "".join(traceback.format_exception(*exc_info))

    # Start from the caller's frame, so that this function itself is not logged.
    caller = inspect.currentframe()
    if caller:
        caller = caller.f_back
    try:
        stack = "".join(traceback.format_stack(caller))
    finally:
        del caller  # avoid cycles

    write_format(
        level, format_string, *args, exception=exception, stack=stack, **kwargs
    )

    return exc_info[1]


def to_file(filename=None):
    """Starts logging to a file, unless that is already being done.

    Does nothing, and returns None, if a log file is already open, or if
    options.log_dir is not set.

    Otherwise, the name of the log file is, in order of preference: the name
    remembered from an earlier call, the filename argument, or a name generated
    from options.log_dir, filename_prefix, and the ID of the current process.

    The file is opened for writing as UTF-8, and a header describing the platform,
    the Python implementation, and the ptvsd version is logged at "info" level.

    Returns the name of the log file.
    """
    global file, _filename

    # TODO: warn when options.log_dir is unset, after fixing improper use in ptvsd.server
    already_open = file is not None
    if already_open or options.log_dir is None:
        return

    if not _filename:
        _filename = filename

    if _filename is None:
        # The early return above already covers an unset log_dir, so this check
        # cannot fire right now; it is what the TODO is meant to re-enable.
        if options.log_dir is None:
            warning(
                "ptvsd.to_file() cannot generate log file name - ptvsd.options.log_dir is not set"
            )
            return
        pid = os.getpid()
        _filename = fmt("{0}/{1}-{2}.log", options.log_dir, filename_prefix, pid)

    file = io.open(_filename, "w", encoding="utf-8")

    bitness = 64 if sys.maxsize > 2 ** 32 else 32
    info(
        "{0} {1}\n{2} {3} ({4}-bit)\nptvsd {5}",
        platform.platform(),
        platform.machine(),
        platform.python_implementation(),
        platform.python_version(),
        bitness,
        ptvsd.__version__,
    )
    return _filename


def filename():
    """Returns the name of the log file as picked by to_file(), or None if no
    name has been picked yet.
    """
    return _filename


def describe_environment(header):
    """Logs a description of the runtime environment at "info" level.

    The message starts with header, followed by the CPU count and a list of
    system paths: the sys prefixes, the site-packages directories reported by
    site and found on sys.path, every sysconfig path, and the locations of the
    os and threading modules. Wherever a path differs from its real path (as
    computed by os.path.realpath), the latter is shown next to it in parentheses.

    Paths that cannot be obtained because an attribute is missing are shown as
    "<missing>". Any other failure to obtain a path is logged via exception(),
    and that entry is left out. A path that evaluates to None is reported as is,
    rather than aborting the whole report.
    """
    import multiprocessing
    import sysconfig
    import site  # noqa

    # sysconfig and site are imported above so that they are present in
    # sys.modules, which is the namespace used to evaluate expressions below.

    result = [header, "\n\n"]

    def report(*args, **kwargs):
        # Appends a formatted fragment to the message being built.
        result.append(fmt(*args, **kwargs))

    def report_paths(get_paths, label=None):
        # Reports the path or paths designated by get_paths, which is either a
        # callable returning them, or a string with an expression to evaluate.
        prefix = fmt("    {0}: ", label or get_paths)

        expr = None
        if not callable(get_paths):
            expr = get_paths
            # Only ever used with expressions that are hardcoded in this function,
            # so eval() is not exposed to external input. Names in the expression
            # are resolved against sys.modules.
            get_paths = lambda: eval(expr, {}, sys.modules)
        try:
            paths = get_paths()
        except AttributeError:
            report("{0}<missing>\n", prefix)
            return
        except Exception:
            exception(
                "Error evaluating {0}",
                repr(expr) if expr else compat.srcnameof(get_paths),
            )
            return

        if not isinstance(paths, (list, tuple)):
            paths = [paths]

        # None cannot be ordered against strings on Python 3, so any such entries
        # are sorted last; the order of everything else is the same as for a
        # plain sorted().
        for p in sorted(paths, key=lambda p: (p is None, p)):
            report("{0}{1}", prefix, p)
            # os.path.realpath() raises TypeError for None, so there is nothing
            # further to report for such an entry.
            if p is not None:
                rp = os.path.realpath(p)
                if p != rp:
                    report("({0})", rp)
            report("\n")

            # Only the first path is labeled; the rest are aligned with it.
            prefix = " " * len(prefix)

    report("CPU count: {0}\n\n", multiprocessing.cpu_count())
    report("System paths:\n")
    report_paths("sys.prefix")
    report_paths("sys.base_prefix")
    report_paths("sys.real_prefix")
    report_paths("site.getsitepackages()")
    report_paths("site.getusersitepackages()")

    site_packages = [
        p
        for p in sys.path
        if os.path.exists(p) and os.path.basename(p) == "site-packages"
    ]
    report_paths(lambda: site_packages, "sys.path (site-packages)")

    for name in sysconfig.get_path_names():
        expr = fmt("sysconfig.get_path({0!r})", name)
        report_paths(expr)

    report_paths("os.__file__")
    report_paths("threading.__file__")

    result = "".join(result).rstrip("\n")
    info("{0}", result)


# The following are helper shortcuts for printf debugging. They must never be used
# in production code.


def _repr(value):
    """Logs repr(value) as a warning tagged with "$REPR"."""
    warning("$REPR {0!r}", value)


def _vars(*names):
    """Logs local variables of the caller as a warning tagged with "$VARS".

    If any names are given, only those of them that exist among the caller's
    locals are logged; otherwise, all of the caller's locals are.
    """
    caller_locals = inspect.currentframe().f_back.f_locals
    shown = caller_locals
    if names:
        shown = {name: caller_locals[name] for name in names if name in caller_locals}
    warning("$VARS {0!r}", shown)
