#! /usr/bin/python3

import getopt
import sys
import os.path
import re
import xml.dom.minidom
import build.nroff

# Action lengths are rounded up to a multiple of this many bytes, and struct
# argument types are asserted (in the generated code) to be a multiple of it.
OFP_ACTION_ALIGN = 8

# OpenFlow version number, as written in ofp-actions.c comments, mapped to the
# version ID used in ofp_header.
version_map = {
    "1.0": 0x01,
    "1.1": 0x02,
    "1.2": 0x03,
    "1.3": 0x04,
    "1.4": 0x05,
    "1.5": 0x06,
}
# Inverse of version_map: ofp_header version ID -> version number string.
version_reverse_map = {ver_id: ver_name
                       for ver_name, ver_id in version_map.items()}

# Vendor name mapped to a (vendor ID, length of the action header) tuple.
vendor_map = {
    "OF": (0x00000000, 4),
    "ONF": (0x4f4e4600, 10),
    "NX": (0x00002320, 10),
}

# Basic (non-struct) types that may be used as action arguments.  For each:
# its size and alignment in bytes, plus the names of the C byte-order
# conversion functions applied to it in generated code ("ntoh" in put_*()
# helpers, "hton" in ofpact_decode()); None means no conversion is emitted.
types = {
    'uint8_t':  {"size": 1, "align": 1, "ntoh": None,     "hton": None},
    'ovs_be16': {"size": 2, "align": 2, "ntoh": "ntohs",  "hton": "htons"},
    'ovs_be32': {"size": 4, "align": 4, "ntoh": "ntohl",  "hton": "htonl"},
    'ovs_be64': {"size": 8, "align": 8, "ntoh": "ntohll", "hton": "htonll"},
    'uint16_t': {"size": 2, "align": 2, "ntoh": None,     "hton": None},
    'uint32_t': {"size": 4, "align": 4, "ntoh": None,     "hton": None},
    'uint64_t': {"size": 8, "align": 8, "ntoh": None,     "hton": None},
}

# Most recent input line read by get_line().
line = ""

# 'struct ...' argument types encountered while parsing ofp-actions.c.
arg_structs = set()

def round_up(x, y):
    """Return the smallest multiple of 'y' that is greater than or equal to 'x'.

    Expects integer arguments with y > 0.  Integer floor division keeps the
    result exact for arbitrarily large ints.  Float division would lose
    precision above 2**53 and truncate toward zero for negative sums.
    """
    return (x + (y - 1)) // y * y

def open_file(fn):
    """Make 'fn' the current input file.

    Records the name in the global 'file_name', opens it in text mode as
    the global 'input_file' (left open for the caller to close) and resets
    the global 'line_number' to 0.  get_line() and error() read these.
    """
    global file_name, input_file, line_number
    file_name = fn
    input_file = open(fn)
    line_number = 0

def get_line():
    """Read the next line from the global 'input_file'.

    Stores the text in the global 'line', bumps the global 'line_number'
    and returns the line.  Hitting end of input is reported with fatal().
    """
    global line, line_number
    line = input_file.readline()
    line_number += 1
    if line != "":
        return line
    fatal("unexpected end of input")
    return line

# Total number of errors reported through error().
n_errors = 0
def error(msg):
    """Write 'msg' to stderr, prefixed with "FILE:LINE: " for the current
    input position, and count it in the global 'n_errors'."""
    global n_errors
    report = "%s:%d: %s\n" % (file_name, line_number, msg)
    sys.stderr.write(report)
    n_errors += 1

def fatal(msg):
    """Report 'msg' through error(), then stop the program with status 1."""
    error(msg)
    raise SystemExit(1)

def usage():
    """Print this script's command-line help to stdout and exit with status 0."""
    program = os.path.basename(sys.argv[0])
    help_template = '''\
%(argv0)s, for extracting OpenFlow action data
usage: %(argv0)s [prototypes | definitions] OFP-ACTIONS.c
usage: %(argv0)s ovs-actions OVS-ACTIONS.XML

Commands:

  prototypes OFP-ACTIONS.C
    Reads ofp-actions.c and prints a set of prototypes to #include early in
    ofp-actions.c.

  definitions OFP-ACTIONS.C
    Reads ofp-actions.c and prints a set of definitions to #include late in
    ofp-actions.c.

  ovs-actions OVS-ACTIONS.XML
    Reads ovs-actions.xml and prints documentation in troff format.\
'''
    print(help_template % {"argv0": program})
    sys.exit(0)

def extract_ofp_actions(fn, definitions):
    """Parse 'enum ofp_raw_action_type' in 'fn' and print generated C code.

    'fn' is read via open_file()/get_line().  Inside the enum, every
    enumerator must be preceded by a comment of the form

        /* DESTINATIONS: ARGTYPE */

    DESTINATIONS is a ', '-separated list of entries such as 'OF1.0(1)',
    'NX1.0+(13)', 'OF1.1-1.2(5)' or 'OF1.0(7) is deprecated (REASON)'.  Each
    entry gives a vendor (a key of vendor_map), a version or version range
    (keys of version_map, '+' meaning "and later") and the action type
    number.  ARGTYPE is 'void', a key of 'types' or 'struct NAME',
    optionally followed by ', ...' (variable-length action) and/or 'VLMFF'.
    Text in [square brackets] within the comment is ignored.

    If 'definitions' is true, prints the code to be #included late in
    ofp-actions.c: struct size assertions, the all_raw_instances[] table,
    put_*() function bodies and ofpact_decode().  Otherwise prints the
    corresponding prototypes.  Struct argument types are also recorded in
    the module-level 'arg_structs'.

    Stops through fatal() on malformed input.  Exits with status 1 after
    parsing if some (vendor, type, version) combination is claimed by two
    different enumerators.
    """
    # Enumerator names already seen, for duplicate detection.
    names = set()
    # vendor ID -> action type -> OpenFlow version ID -> info dict.
    domain = {code: {} for code, _ in vendor_map.values()}
    # enumerator name -> list of info dicts, one per version it covers.
    enums = {}

    # Duplicate-definition errors found here.  error() also bumps the
    # module-level counter; this local one decides whether to bail out.
    n_errors = 0

    open_file(fn)

    # Skip forward to the start of the enum.
    while True:
        get_line()
        if re.match('enum ofp_raw_action_type {', line):
            break

    while True:
        get_line()
        if line.startswith('/*') or not line or line.isspace():
            continue
        elif re.match('}', line):
            break

        if not line.lstrip().startswith('/*'):
            fatal("unexpected syntax between actions")

        # Gather the (possibly multi-line) comment into one string.
        comment = line.lstrip()[2:].strip()
        while not comment.endswith('*/'):
            get_line()
            if line.startswith('/*') or not line or line.isspace():
                fatal("unexpected syntax within action")
            comment += ' %s' % line.lstrip('* \t').rstrip(' \t\r\n')
        # Drop [bracketed] text, then the closing "*/".
        comment = re.sub(r'\[[^]]*\]', '', comment)
        comment = comment[:-2].rstrip()

        m = re.match(r'([^:]+):\s+(.*)$', comment)
        if not m:
            fatal("unexpected syntax between actions")

        dsts = m.group(1)
        # Only the first '.' is removed.  A trailing ', ...' thus becomes
        # ', ..', which is what the base_argtype computation below strips.
        argtypes = m.group(2).strip().replace('.', '', 1)

        arg_vl_mff_map = 'VLMFF' in argtypes
        argtype = argtypes.replace('VLMFF', '', 1).rstrip()

        get_line()
        m = re.match(r'\s+(([A-Z]+)_RAW([0-9]*)_([A-Z0-9_]+)),?', line)
        if not m:
            fatal("syntax error expecting enum value")

        enum = m.group(1)
        if enum in names:
            fatal("%s specified twice" % enum)

        names.add(enum)

        for dst in dsts.split(', '):
            m = re.match(r'([A-Z]+)([0-9.]+)(\+|-[0-9.]+)?(?:\((\d+)\))(?: is deprecated \(([^)]+)\))?$', dst)
            if not m:
                fatal("%r: syntax error in destination" % dst)
            vendor_name = m.group(1)
            version1_name = m.group(2)
            version2_name = m.group(3)
            type_ = int(m.group(4))
            deprecation = m.group(5)

            if vendor_name not in vendor_map:
                fatal("%s: unknown vendor" % vendor_name)
            vendor = vendor_map[vendor_name][0]

            if version1_name not in version_map:
                fatal("%s: unknown OpenFlow version" % version1_name)
            v1 = version_map[version1_name]

            if version2_name is None:
                v2 = v1
            elif version2_name == "+":
                v2 = max(version_map.values())
            elif version2_name[1:] not in version_map:
                fatal("%s: unknown OpenFlow version" % version2_name[1:])
            else:
                v2 = version_map[version2_name[1:]]

            if v2 < v1:
                # version2_name still carries its leading '-'; report just
                # the version number itself.
                fatal("%s%s: %s precedes %s"
                      % (version1_name, version2_name,
                         version2_name[1:], version1_name))

            for version in range(v1, v2 + 1):
                domain[vendor].setdefault(type_, {})
                if version in domain[vendor][type_]:
                    v = domain[vendor][type_][version]
                    msg = "%#x,%d in OF%s means both %s and %s" % (
                        vendor, type_, version_reverse_map[version],
                        v["enum"], enum)
                    error("%s: %s." % (dst, msg))
                    sys.stderr.write("%s:%d: %s: Here is the location "
                                     "of the previous definition.\n"
                                     % (v["file_name"], v["line_number"],
                                        dst))
                    n_errors += 1
                else:
                    header_len = vendor_map[vendor_name][1]

                    base_argtype = argtype.replace(', ..', '', 1)
                    if base_argtype in types:
                        # The argument follows the header at its natural
                        # alignment; the whole action is padded to
                        # OFP_ACTION_ALIGN.
                        arg_align = types[base_argtype]['align']
                        arg_len = types[base_argtype]['size']
                        arg_ofs = round_up(header_len, arg_align)
                        min_length = round_up(arg_ofs + arg_len,
                                              OFP_ACTION_ALIGN)
                    elif base_argtype == 'void':
                        min_length = round_up(header_len, OFP_ACTION_ALIGN)
                        arg_len = 0
                        arg_ofs = 0
                    elif re.match(r'struct [a-zA-Z0-9_]+$', base_argtype):
                        # Struct sizes are only known to the C compiler.
                        min_length = 'sizeof(%s)' % base_argtype
                        arg_structs.add(base_argtype)
                        arg_len = 0
                        arg_ofs = 0
                        # should also emit OFP_ACTION_ALIGN assertion
                    else:
                        fatal("bad argument type %s" % argtype)

                    ellipsis = argtype != base_argtype
                    if ellipsis:
                        max_length = '65536 - OFP_ACTION_ALIGN'
                    else:
                        max_length = min_length

                    info = {"enum": enum,
                            "deprecation": deprecation,
                            "file_name": file_name,
                            "line_number": line_number,
                            "min_length": min_length,
                            "max_length": max_length,
                            "arg_ofs": arg_ofs,
                            "arg_len": arg_len,
                            "base_argtype": base_argtype,
                            "arg_vl_mff_map": arg_vl_mff_map,
                            "version": version,
                            "type": type_}
                    domain[vendor][type_][version] = info

                    enums.setdefault(enum, [])
                    enums[enum].append(info)

    input_file.close()

    if n_errors:
        sys.exit(1)

    print("""\
/* Generated automatically; do not modify!     -*- buffer-read-only: t -*- */
""")

    if definitions:
        print("/* Verify that structs used as actions are reasonable sizes. */")
        for s in sorted(arg_structs):
            print("BUILD_ASSERT_DECL(sizeof(%s) %% OFP_ACTION_ALIGN == 0);" % s)

        print("\nstatic struct ofpact_raw_instance all_raw_instances[] = {")
        for vendor in domain:
            for type_ in domain[vendor]:
                for version in domain[vendor][type_]:
                    d = domain[vendor][type_][version]
                    print("    { { 0x%08x, %2d, 0x%02x }, " % (
                        vendor, type_, version))
                    print("      %s," % d["enum"])
                    print("      HMAP_NODE_NULL_INITIALIZER,")
                    print("      HMAP_NODE_NULL_INITIALIZER,")
                    print("      %s," % d["min_length"])
                    print("      %s," % d["max_length"])
                    print("      %s," % d["arg_ofs"])
                    print("      %s," % d["arg_len"])
                    # 'count' by keyword: passing it positionally is
                    # deprecated since Python 3.13.
                    print("      \"%s\"," % re.sub('_RAW[0-9]*', '', d["enum"],
                                                 count=1))
                    if d["deprecation"]:
                        # Escape '"' and '\' for a C string literal.
                        print("      \"%s\"," % re.sub(r'(["\\])', r'\\\1', d["deprecation"]))
                    else:
                        print("      NULL,")
                    print("    },")
        print("};")

    # One put_<ACTION>() helper per enumerator.  A version parameter is
    # needed only when the encoding differs between the versions covered.
    for versions in enums.values():
        need_ofp_version = False
        for v in versions:
            assert v["arg_len"] == versions[0]["arg_len"]
            assert v["base_argtype"] == versions[0]["base_argtype"]
            if (v["min_length"] != versions[0]["min_length"] or
                v["arg_ofs"] != versions[0]["arg_ofs"] or
                v["type"] != versions[0]["type"]):
                need_ofp_version = True
        base_argtype = versions[0]["base_argtype"]

        decl = "static inline "
        if base_argtype.startswith('struct'):
            decl += "%s *" % base_argtype
        else:
            decl += "void"
        decl += "\nput_%s(struct ofpbuf *openflow" % versions[0]["enum"].replace('_RAW', '', 1)
        if need_ofp_version:
            decl += ", enum ofp_version version"
        if base_argtype != 'void' and not base_argtype.startswith('struct'):
            decl += ", %s arg" % base_argtype
        decl += ")"
        if definitions:
            decl += "{\n"
            decl += "    "
            if base_argtype.startswith('struct'):
                decl += "return "
            decl += "ofpact_put_raw(openflow, "
            if need_ofp_version:
                decl += "version"
            else:
                decl += "%s" % versions[0]["version"]
            decl += ", %s, " % versions[0]["enum"]
            if base_argtype.startswith('struct') or base_argtype == 'void':
                decl += "0"
            else:
                ntoh = types[base_argtype]['ntoh']
                if ntoh:
                    decl += "%s(arg)" % ntoh
                else:
                    decl += "arg"
            decl += ");\n"
            decl += "}"
        else:
            decl += ";"
        print(decl)
        print("")

    if definitions:
        print("""\
static enum ofperr
ofpact_decode(const struct ofp_action_header *a, enum ofp_raw_action_type raw,
              enum ofp_version version, uint64_t arg,
              const struct vl_mff_map *vl_mff_map,
              uint64_t *tlv_bitmap, struct ofpbuf *out)
{
    switch (raw) {\
""")
        for versions in enums.values():
            enum = versions[0]["enum"]
            print("    case %s:" % enum)
            base_argtype = versions[0]["base_argtype"]
            arg_vl_mff_map = versions[0]["arg_vl_mff_map"]
            if base_argtype == 'void':
                print("        return decode_%s(out);" % enum)
            else:
                if base_argtype.startswith('struct'):
                    arg = "ALIGNED_CAST(const %s *, a)" % base_argtype
                else:
                    hton = types[base_argtype]['hton']
                    if hton:
                        arg = "%s(arg)" % hton
                    else:
                        arg = "arg"
                if arg_vl_mff_map:
                    print("        return decode_%s(%s, version, vl_mff_map, tlv_bitmap, out);" % (enum, arg))
                else:
                    print("        return decode_%s(%s, version, out);" % (enum, arg))
            print("")
        print("""\
    default:
        OVS_NOT_REACHED();
    }
}\
""")
    else:
        for versions in enums.values():
            enum = versions[0]["enum"]
            prototype = "static enum ofperr decode_%s(" % enum
            base_argtype = versions[0]["base_argtype"]
            arg_vl_mff_map = versions[0]["arg_vl_mff_map"]
            if base_argtype != 'void':
                if base_argtype.startswith('struct'):
                    prototype += "const %s *, enum ofp_version, " % base_argtype
                else:
                    prototype += "%s, enum ofp_version, " % base_argtype
                if arg_vl_mff_map:
                    prototype += 'const struct vl_mff_map *, uint64_t *, '
            prototype += "struct ofpbuf *);"
            print(prototype)

        print("""
static enum ofperr ofpact_decode(const struct ofp_action_header *,
                                 enum ofp_raw_action_type raw,
                                 enum ofp_version version,
                                 uint64_t arg, const struct vl_mff_map *vl_mff_map,
                                 uint64_t *tlv_bitmap, struct ofpbuf *out);
""")

## ------------------------ ##
## Documentation Generation ##
## ------------------------ ##

def action_to_xml(action_node, body):
    """Append the nroff rendering of an <action> element to 'body'.

    Despite its name, this converts XML *to* nroff.  'body' is a list of
    nroff text fragments; it is extended in place and nothing is returned.

    - <syntax> children go under a "Syntax:" .IP heading.  Second and later
      <syntax> elements are joined with .IQ.
    - A <conformance> child goes under a "Conformance:" heading.
    - Any other child node is rendered with
      build.nroff.block_xml_to_nroff().
    """
    syntax = 0
    for node in action_node.childNodes:
        if node.nodeType == node.ELEMENT_NODE and node.tagName == 'syntax':
            # Drop a .PP that directly precedes the syntax heading.  Guard
            # against an empty 'body', e.g. when <syntax> is the very first
            # thing rendered in a group.
            if body and body[-1].strip() == '.PP':
                del body[-1]
            if syntax:
                body += ['.IQ\n']
            else:
                body += ['.IP "\\fBSyntax:\\fR"\n']
            body += [build.nroff.inline_xml_to_nroff(x, r'\fR')
                     for x in node.childNodes] + ['\n']
            syntax += 1
        elif (node.nodeType == node.ELEMENT_NODE
              and node.tagName == 'conformance'):
            body += ['.IP "\\fBConformance:\\fR"\n']
            body += [build.nroff.block_xml_to_nroff(node.childNodes)]
        else:
            body += [build.nroff.block_xml_to_nroff([node])]

def group_xml_to_nroff(group_node):
    """Render a <group> element of ovs-actions.xml as one nroff section.

    The result begins with a page break (.bp) and a .SH heading made from
    the upper-cased 'title' attribute.  The group's children follow:
    <action> elements are rendered by action_to_xml(), everything else by
    build.nroff.block_xml_to_nroff().  Returns the section as a string.
    """
    title = group_node.attributes['title'].nodeValue

    rendered = []
    for child in group_node.childNodes:
        if child.nodeType == child.ELEMENT_NODE and child.tagName == 'action':
            action_to_xml(child, rendered)
            continue
        rendered.append(build.nroff.block_xml_to_nroff([child]))

    heading = '.SH \"%s\"\n' % build.nroff.text_to_nroff(title.upper())
    return ''.join(['.bp\n', heading] + rendered)

def make_ovs_actions(ovs_actions_xml):
    """Print the ovs-actions(7) manpage, in nroff, generated from XML.

    'ovs_actions_xml' is anything xml.dom.minidom.parse() accepts (a file
    name or a file object).  Children of the document element are handled
    as follows:
    - each <group> becomes a section via group_xml_to_nroff();
    - whitespace-only text and comments are skipped;
    - any other node is rendered with build.nroff.block_xml_to_nroff().

    The .TH line uses the module-level 'version' (set from --ovs-version),
    or "UNKNOWN" when that is unset or None.  Exits with status 1 if the
    module-level 'n_errors' is nonzero after rendering.
    """
    document = xml.dom.minidom.parse(ovs_actions_xml)
    doc = document.documentElement

    # 'version' is only assigned when the script runs as __main__.  Use
    # globals().get() so importing this module and calling the function
    # directly does not raise NameError.
    global version
    if globals().get("version") is None:
        version = "UNKNOWN"

    # '\\-' (rather than '\-', an invalid Python escape) emits nroff's '\-'.
    print('''\
'\\" tp
.\\" -*- mode: troff; coding: utf-8 -*-
.TH "ovs\\-actions" 7 "%s" "Open vSwitch" "Open vSwitch Manual"
.fp 5 L CR              \\" Make fixed-width font available as \\fL.
.de ST
.  PP
.  RS -0.15in
.  I "\\\\$1"
.  RE
..

.de SU
.  PP
.  I "\\\\$1"
..

.de IQ
.  br
.  ns
.  IP "\\\\$1"
..

.de TQ
.  br
.  ns
.  TP "\\\\$1"
..
.de URL
\\\\$2 \\(laURL: \\\\$1 \\(ra\\\\$3
..
.if \\n[.g] .mso www.tmac
.SH NAME
ovs\\-actions \\- OpenFlow actions and instructions with Open vSwitch extensions
.
.PP
''' % version)

    parts = []
    for node in doc.childNodes:
        if node.nodeType == node.ELEMENT_NODE and node.tagName == "group":
            parts.append(group_xml_to_nroff(node))
        elif node.nodeType == node.TEXT_NODE:
            assert node.data.isspace()
        elif node.nodeType == node.COMMENT_NODE:
            pass
        else:
            parts.append(build.nroff.block_xml_to_nroff([node]))
    s = ''.join(parts)

    if n_errors:
        sys.exit(1)

    # Life is easier with nroff if we don't try to feed it Unicode.
    # Fortunately, we only use a few characters outside the ASCII range.
    to_nroff = str.maketrans({u'\u2208': r'\[mo]',
                              u'\u2260': r'\[!=]',
                              u'\u2264': r'\[<=]',
                              u'\u2265': r'\[>=]',
                              u'\u00d7': r'\[mu]'})

    output = []
    for oline in s.split("\n"):
        oline = oline.strip().translate(to_nroff)
        if oline:
            output.append(oline)

    # nroff tends to ignore .bp requests if they come after .PP requests,
    # so remove .PPs that precede .bp.
    for i, oline in enumerate(output):
        if oline == '.bp':
            j = i - 1
            while j >= 0 and output[j] == '.PP':
                output[j] = None
                j -= 1
    for oline in output:
        if oline is not None:
            print(oline)


## ------------ ##
## Main Program ##
## ------------ ##

if __name__ == '__main__':
    argv0 = sys.argv[0]

    # Parse options.  Assigning 'version' here, at module scope, defines the
    # global that make_ovs_actions() reads.
    try:
        opts, operands = getopt.gnu_getopt(sys.argv[1:], 'h',
                                           ['help', 'ovs-version='])
    except getopt.GetoptError as geo:
        sys.stderr.write("%s: %s\n" % (argv0, geo.msg))
        sys.exit(1)

    version = None
    for opt, optarg in opts:
        if opt in ('-h', '--help'):
            usage()
        elif opt == '--ovs-version':
            version = optarg
        else:
            sys.exit(0)

    if not operands:
        sys.stderr.write("%s: missing command argument "
                         "(use --help for help)\n" % argv0)
        sys.exit(1)

    # Command name -> (handler, number of operands the handler takes).
    commands = {
        "prototypes": (lambda fn: extract_ofp_actions(fn, False), 1),
        "definitions": (lambda fn: extract_ofp_actions(fn, True), 1),
        "ovs-actions": (make_ovs_actions, 1),
    }

    command, params = operands[0], operands[1:]
    if command not in commands:
        sys.stderr.write("%s: unknown command \"%s\" "
                         "(use --help for help)\n" % (argv0, command))
        sys.exit(1)

    handler, n_params = commands[command]
    if len(params) != n_params:
        sys.stderr.write("%s: \"%s\" requires %d arguments but %d "
                         "provided\n"
                         % (argv0, command, n_params, len(params)))
        sys.exit(1)

    handler(*params)
