#!/usr/bin/python3

import argparse
import multiprocessing
import contextlib
from collections import defaultdict
import os
import io
import sys
import warnings
import shutil
import logging
import traceback
import tempfile
import ast
from termcolor import cprint
import importlib
import json
import docutils
import docutils.core
import docutils.nodes

# Root logger; main() configures it through logging.basicConfig()
log = logging.getLogger()


class Fail(Exception):
    """
    Error whose message is printed to stderr as-is by the script entry point,
    which then exits with status 1
    """


class RunEnv:
    """
    Test environment for python tests

    Wraps one collected snippet ``entry`` (a dict providing at least the
    ``code`` and ``src`` keys). run() records its outcome in the entry itself,
    under the ``result``, ``warnings``, ``exception``, ``stdout`` and
    ``stderr`` keys.
    """
    # Sample BUFR file used to populate the test database; the path is
    # relative to the directory the script is run from
    TEST_BUFR = "extra/bufr/synop-sunshine.bufr"

    def __init__(self, entry):
        self.entry = entry
        # Namespaces handed to exec(). locals() here contains `self` and
        # `entry`, so snippets can reach the assertion helpers below as
        # self.assertEqual(...)
        self.globals = dict(globals())
        self.locals = dict(locals())
        # Names the snippet binds by itself / names it uses as `name.attr`
        # without binding them first (filled in by check_reqs)
        self.assigned = set()
        self.wanted = set()
        # Buffers that replace sys.stdout / sys.stderr while the snippet runs
        self.stdout = io.StringIO()
        self.stderr = io.StringIO()

        # Import modules expected by the snippets
        for modname in ("dballe", "wreport", "os", "datetime"):
            self.globals[modname] = importlib.__import__(modname)

    def check_reqs(self):
        """
        Check if the snippet expects some well-known variables to be defined

        Fills ``self.assigned`` with the names bound by simple assignments and
        by ``with ... as name`` clauses, and ``self.wanted`` with the names
        that are used as ``name.attribute`` and were not seen as assigned.

        This is a heuristic: it depends on the order in which ast.walk visits
        the nodes, which is not the order of the source lines.
        """
        tree = ast.parse(self.entry["code"], self.entry["src"], "exec")
        for node in ast.walk(tree):
            if isinstance(node, ast.Assign):
                if len(node.targets) != 1:
                    continue
                if not isinstance(node.targets[0], ast.Name):
                    continue
                self.assigned.add(node.targets[0].id)
            elif isinstance(node, ast.Attribute):
                if not isinstance(node.value, ast.Name):
                    continue
                if node.value.id in self.assigned:
                    continue
                self.wanted.add(node.value.id)
            elif isinstance(node, ast.With):
                for item in node.items:
                    # optional_vars is None for a `with` without `as`, and can
                    # be a tuple or an attribute: only plain names are tracked
                    target = item.optional_vars
                    if isinstance(target, ast.Name):
                        self.assigned.add(target.id)

    # Provide some assertion methods the snippets can use

    def assertEqual(self, a, b):
        """
        Raise AssertionError if a != b
        """
        if a != b:
            raise AssertionError(f"{a!r} != {b!r}")

    @contextlib.contextmanager
    def setup(self, capture_output=True):
        """
        Context manager that prepares the environment a snippet runs in.

        It creates a temporary directory with a copy of TEST_BUFR (as
        ``test.bufr``) and a sqlite database with its messages imported,
        defines the well-known variables found in ``self.wanted`` (``db``,
        ``tr``, ``explorer``, ``msg``) in ``self.locals``, and exports the
        database URL as ``$DBA_DB``.

        While the context is active, the current directory is the temporary
        directory and, if capture_output is true, sys.stdout and sys.stderr
        are replaced with ``self.stdout`` and ``self.stderr``.
        """
        import dballe

        # Build a test environment in a temporary directory
        with tempfile.TemporaryDirectory() as root:
            test_bufr = os.path.join(root, "test.bufr")
            shutil.copy(self.TEST_BUFR, test_bufr)
            db_url = "sqlite:" + os.path.join(root, "db.sqlite")
            db = dballe.DB.connect(db_url + "?wipe=yes")
            importer = dballe.Importer("BUFR")
            with importer.from_file(test_bufr) as f:
                db.import_messages(f)

            # Create well-known variables if needed
            if "db" in self.wanted or "tr" in self.wanted or "explorer" in self.wanted:
                self.locals["db"] = db
            if "explorer" in self.wanted:
                self.locals["explorer"] = dballe.Explorer()
                with self.locals["explorer"].rebuild() as update:
                    with self.locals["db"].transaction() as tr:
                        update.add_db(tr)
            if "tr" in self.wanted:
                self.locals["tr"] = self.locals["db"].transaction()
            if "msg" in self.wanted:
                importer = dballe.Importer("BUFR")
                with importer.from_file(self.TEST_BUFR) as f:
                    for msg in f:
                        self.locals["msg"] = msg[0]

            # Set env variables the snippets may use
            os.environ["DBA_DB"] = db_url

            # Run the code, in the temp directory, discarding stdout
            curdir = os.getcwd()
            orig_stdout, orig_stderr = sys.stdout, sys.stderr
            try:
                # Redirect inside the try block, so that the streams are
                # restored even if changing directory fails
                if capture_output:
                    sys.stdout = self.stdout
                    sys.stderr = self.stderr
                os.chdir(root)
                yield
            finally:
                if capture_output:
                    sys.stderr = orig_stderr
                    sys.stdout = orig_stdout
                # Leave the temporary directory before it gets removed
                os.chdir(curdir)

    def store_exc(self):
        """
        Store the exception currently being handled as ``entry["exception"]``.

        The value is ``(lineno, repr(exception))``, with the line number taken
        from the innermost traceback frame that belongs to the snippet. If no
        frame belongs to the snippet, it is ``(None, formatted traceback)``.
        """
        _, value, tb = sys.exc_info()
        for frame, lineno in traceback.walk_tb(tb):
            # Frames are walked from the outermost to the innermost, so the
            # last match wins
            if frame.f_code.co_filename == self.entry["src"]:
                self.entry["exception"] = (lineno, repr(value))
        if "exception" not in self.entry:
            self.entry["exception"] = None, traceback.format_exc()

    def run(self):
        """
        Run a compiled test snippet

        Sets ``entry["result"]`` to ``"ok"``, ``"fail-compile"`` or
        ``"fail-run"``, and stores the warnings recorded along the way in
        ``entry["warnings"]``.
        """
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            self.entry["warnings"] = w

            try:
                code = compile(self.entry["code"], self.entry["src"], "exec")
            except Exception:
                self.entry["result"] = "fail-compile"
                self.store_exc()
                return

            self.check_reqs()

            with self.setup():
                try:
                    exec(code, self.globals, self.locals)
                except Exception:
                    self.entry["result"] = "fail-run"
                    self.store_exc()
                    return
                finally:
                    # Collect the captured output whether the snippet failed
                    # or not
                    self.entry["stdout"] = self.stdout.getvalue()
                    self.entry["stderr"] = self.stderr.getvalue()

        self.entry["result"] = "ok"


def run_python_test(entry):
    """
    Execute one collected python snippet in a fresh RunEnv.

    The run results are stored in the entry itself, which is also returned so
    that this function can be mapped over a list of entries.
    """
    RunEnv(entry).run()
    return entry


def format_python_result(entry, verbose=False):
    """
    Print a colored report for a snippet annotated with run results.

    A headline is always printed: red if the entry has an exception, yellow if
    it only has warnings, green otherwise. The code is listed, with exceptions
    and warnings shown under the line they refer to, when there is something
    to annotate, when the snippet wrote to stderr, or when verbose is set and
    the snippet wrote to stdout. Captured stderr is always shown, captured
    stdout only if verbose is set.
    """
    if "exception" in entry:
        headline_color, ok = "red", False
    elif entry["warnings"]:
        headline_color, ok = "yellow", False
    else:
        headline_color, ok = "green", True

    cprint(f"{entry['src']}: {entry['source']}:{entry['line']}: {entry['result']}", headline_color, attrs=["bold"])

    # Group exception and warnings by the line number they refer to
    notes = defaultdict(list)
    if "exception" in entry:
        exc_lineno, exc_text = entry["exception"]
        notes[exc_lineno].append((exc_text, "red"))
    for warning in entry["warnings"] or ():
        notes[warning.lineno].append((f"{warning.category}, {warning.message}", "yellow"))

    captured_out = entry.get("stdout")
    captured_err = entry.get("stderr")
    show_stdout = bool(verbose and captured_out)

    if notes or captured_err or show_stdout:
        for number, text in enumerate(entry["code"].splitlines(), start=1):
            attached = notes.pop(number, None)
            cprint(f" {number:2d} {text}", "red" if attached else "grey", attrs=["bold"])
            for note, note_color in attached or ():
                cprint(f"  ↪ {note}", note_color)

        # What is left does not match any line of the snippet
        for leftover in notes.values():
            for note, note_color in leftover:
                for note_line in note.splitlines():
                    cprint(f"  ↪ {note_line}", note_color)

    if show_stdout:
        for out_line in captured_out.splitlines():
            cprint(f"  O {out_line}", "blue")

    if captured_err:
        for err_line in captured_err.splitlines():
            cprint(f"  E {err_line}", "blue")

    # Separate problematic entries from what follows
    if not ok:
        print()


def run_tests(entries, verbose=False):
    """
    Run the collected snippets and print a report for each of them.

    Entries are grouped by their ``lang`` value, with ``default`` counting as
    ``python``. Python snippets are run in a separate worker process; entries
    in any other language are ignored.
    """
    grouped = {}
    for item in entries:
        language = item["lang"]
        if language == "default":
            language = "python"
        grouped.setdefault(language, []).append(item)

    # Only python supported so far
    python_entries = grouped.get("python")
    if not python_entries:
        return

    with multiprocessing.Pool(1) as pool:
        results = pool.map(run_python_test, python_entries)

    for result in results:
        format_python_result(result, verbose)


def get_code_from_rst(fname):
    """
    Return the code in the first code block in the given rst

    The result is an entry dict with the ``src``, ``lang``, ``code``,
    ``source`` and ``line`` keys, describing the first literal block that
    contains some text, or None if there is no such block.
    """
    with open(fname, "rt") as fd:
        doctree = docutils.core.publish_doctree(
                fd, source_class=docutils.io.FileInput, settings_overrides={"input_encoding": "unicode"})

    # Newer docutils versions provide findall() as the replacement of
    # traverse(): use it when available
    find_nodes = getattr(doctree, "findall", None) or doctree.traverse
    for node in find_nodes(docutils.nodes.literal_block):
        # Take the text of the whole block as a plain string: a block can be
        # split in several child nodes (e.g. one per token when it is syntax
        # highlighted), and the first text node alone would truncate the code
        code = node.astext()
        if not code:
            continue
        return {
            "src": fname,
            "lang": "python",
            "code": code,
            "source": node.source,
            "line": node.line,
        }

    return None


def main():
    """
    Command line entry point.

    With --run, the positional argument is a .rst file: its first code snippet
    is run in this process. Otherwise it is a JSON file with the collected
    snippets, which are all run.

    Raises Fail if --run is used and no code is found in the file.
    """
    parser = argparse.ArgumentParser(description="Run code found in documentation code blocks")
    parser.add_argument("--verbose", "-v", action="store_true", help="verbose output")
    parser.add_argument("--debug", action="store_true", help="debug output")
    parser.add_argument("-r", "--run", action="store_true", help="run the first code snippet found in a .rst file")
    parser.add_argument("code", action="store", nargs="?", default="doc/test_code.json",
                        help="JSON file with the collected test code")
    args = parser.parse_args()

    # Setup logging: --debug wins over --verbose
    if args.debug:
        level = logging.DEBUG
    elif args.verbose:
        level = logging.INFO
    else:
        level = logging.WARN
    logging.basicConfig(level=level, stream=sys.stderr, format="%(asctime)-15s %(levelname)s %(message)s")

    verbose = args.debug or args.verbose

    if not args.run:
        with open(args.code, "rt") as fd:
            entries = json.load(fd)
        run_tests(entries, verbose=verbose)
        return

    entry = get_code_from_rst(args.code)
    if entry is None:
        raise Fail(f"No code found in {args.code}")
    RunEnv(entry).run()
    format_python_result(entry, verbose=verbose)


if __name__ == "__main__":
    # Fail carries a message for the user: print it without a traceback and
    # exit with a non-zero status
    try:
        exit_status = main()
    except Fail as e:
        print(e, file=sys.stderr)
        exit_status = 1
    sys.exit(exit_status)
