#! /usr/bin/env python

"""
Generates VAST taxonomy declarations from CIMs of other tools.

Usage:

   taxonomize --splunk < model.json > splunk.yaml
   taxonomize --sentinel > sentinel.yaml

The input requirements change with the tool-specific option. For example,
--splunk requires the JSON file from the TA on stdin, whereas --sentinel does
not process anything on stdin because the script downloads the declarations
from a website.
"""

import argparse
import json
import pandas
import re
import sys
import yaml
from yaml.representer import SafeRepresenter

# ---------------------------------------------------------------------------

# Tweak pyyaml to support folded blocks.
# See https://stackoverflow.com/a/20863889 for details.
class folded_str(str):
    """Marker string type for text that should be dumped as a folded block.

    It adds no behavior of its own; it only gives the YAML dumper a distinct
    type to attach a dedicated representer to.
    """


def change_style(style, representer):
    """Wraps a representer so that the nodes it produces carry a fixed style.

    Args:
        style: The value assigned to the ``style`` attribute of each node.
        representer: A callable taking ``(dumper, data)`` and returning a node.

    Returns:
        A callable with the same ``(dumper, data)`` signature as representer.
    """

    def styled_representer(dumper, data):
        # Delegate the actual conversion, then override the style only.
        node = representer(dumper, data)
        node.style = style
        return node

    return styled_representer

# Representer that behaves like the stock string representer, except that the
# resulting scalar is given the ">" style.
represent_folded_str = change_style(">", SafeRepresenter.represent_str)

# Register the representer so that folded_str values are dumped with it.
yaml.add_representer(folded_str, represent_folded_str)

# ---------------------------------------------------------------------------

def yamlify(input, width=72):
    """Serializes input to a YAML string for use in a VAST taxonomy.

    Args:
        input: The data to dump.
        width: The line width passed on to the YAML dumper.

    Returns:
        The YAML text; keys are passed through unsorted.
    """
    dump_options = {"sort_keys": False, "width": width}
    return yaml.dump(input, **dump_options)

def make_concept(prefix, name, description=None):
    """Builds the mapping for a single taxonomy concept.

    Args:
        prefix: Namespace joined to name with a dot; ignored when falsy.
        name: The unqualified concept name.
        description: Optional text; omitted from the result when falsy.

    Returns:
        A dict of the form ``{"concept": {"name": ..., "description": ...}}``.
    """
    if prefix:
        qualified_name = ".".join((prefix, name))
    else:
        qualified_name = name
    concept = {"name": qualified_name}
    if description:
        # Mark the text so that the dumper renders it as a folded block.
        concept["description"] = folded_str(description)
    return {"concept": concept}

class Sentinel:
    """Generates a taxonomy from the Azure Sentinel normalization schema.

    The schema is read from the HTML table published at URL.
    """

    URL = "https://docs.microsoft.com/en-us/azure/sentinel/normalization-schema"

    def __init__(self, prefix="sentinel.network"):
        # Namespace that is prepended to every generated concept name.
        self.prefix = prefix

    @staticmethod
    def _clean(text):
        """Normalizes a description cell; returns None for non-string cells."""
        # pandas represents empty cells as a float NaN, which has no
        # replace() method, so anything that is not a string has no
        # description.
        if not isinstance(text, str):
            return None
        # Replace the typographic apostrophe with its ASCII counterpart.
        return text.replace("\u2019", "'")

    def parse(self):
        """Downloads the schema table and renders it as taxonomy YAML.

        Returns:
            A YAML string with one concept per table row whose first column
            holds a string.

        Raises:
            Whatever pandas.read_html raises when the page cannot be fetched
            or no table matches.
        """
        # There is exactly one table on this page.
        table = pandas.read_html(self.URL, match="Field name")[0]
        # itertuples() puts the row index at position 0, so position 1 is the
        # first column (presumably the field name) and position 4 the fourth
        # column (presumably the description) -- verify against the page.
        # There is a bogus NaN entry in the last row, so we only consider
        # rows whose name is a string.
        concepts = [
            make_concept(self.prefix, row[1], self._clean(row[4]))
            for row in table.itertuples()
            if isinstance(row[1], str)
        ]
        return yamlify(concepts)

class Splunk:
    """Generates a taxonomy from a Splunk CIM data model in JSON form."""

    def __init__(self, prefix="splunk"):
        # Namespace that is prepended to every generated concept name.
        self.prefix = prefix

    @staticmethod
    def _description(field):
        """Returns the description stored in a field's comment, if any."""
        comment = field.get("comment")
        # Only a mapping can carry a description; any other comment value
        # (missing, None, a plain string) yields no description.
        if isinstance(comment, dict):
            return comment.get("description")
        return None

    def _concepts(self, model, fields):
        """Renders the given fields of one model as a YAML concept list."""
        qualifier = f"{self.prefix}.{model.lower()}"
        return yamlify([
            make_concept(qualifier, field["fieldName"],
                         self._description(field))
            for field in fields
        ])

    def parse(self, data):
        """Converts a data model into taxonomy YAML.

        Args:
            data: A file-like object containing the JSON data model. The
                model must have an "objects" list whose entries carry an
                "objectName" and optionally "fields" and "calculations".

        Returns:
            A string with a comment header per model, followed by the YAML
            for its extracted and calculated fields.

        Raises:
            KeyError: If "objects", "objectName", or a field's "fieldName"
                is missing.
            json.JSONDecodeError: If data does not contain valid JSON.
        """
        chunks = []
        document = json.load(data)
        for entry in document["objects"]:
            name = entry["objectName"]
            chunks.append(f"#\n# Model: {name}\n#\n")
            # Convert plain 1-to-1 mappings.
            fields = entry.get("fields")
            if fields:
                chunks.append("\n# extracted fields\n")
                chunks.append(self._concepts(name, fields))
            # Convert calculated fields.
            calculations = entry.get("calculations")
            if not calculations:
                continue
            chunks.append("\n# calculated fields\n")
            for calculation in calculations:
                output_fields = calculation.get("outputFields")
                # Dumping an empty list would emit a stray "[]" line in the
                # middle of the output, so skip calculations without output.
                if output_fields:
                    chunks.append(self._concepts(name, output_fields))
        return "".join(chunks)

def main(argv=None):
    """Runs the command-line interface.

    Args:
        argv: Optional list of command-line arguments. When None, argparse
            falls back to the arguments of the running process.

    Exits with status 1 when no taxonomy option is given. Passing both
    options at once is rejected by argparse, which exits with status 2.
    """
    parser = argparse.ArgumentParser()
    # The tools are alternatives; accepting both at once would silently
    # ignore one of them.
    tools = parser.add_mutually_exclusive_group()
    tools.add_argument("--sentinel", action="store_true",
                       help="generate Azure Sentinel taxonomy")
    tools.add_argument("--splunk", action="store_true",
                       help="generate splunk CIM taxonomy")
    args = parser.parse_args(argv)
    if args.sentinel:
        sentinel = Sentinel()
        print(sentinel.parse())
    elif args.splunk:
        splunk = Splunk()
        # The Splunk data model is expected on standard input.
        print(splunk.parse(sys.stdin))
    else:
        print("missing taxonomy, see --help for options", file=sys.stderr)
        sys.exit(1)

# Run the command-line interface only when executed as a script, not when the
# module is imported.
if __name__ == "__main__":
    main()
