#!/usr/bin/env python3

# Module: pl_gen_fused
#
# **** License ****
# Copyright (c) 2017-2021, AT&T Intellectual Property.  All rights reserved.
#
# SPDX-License-Identifier: LGPL-2.1-only
# **** End License ****
#
# Generate fused-mode pipeline code from node and feature declarations
#
# Generates the following in the header file if requested:
# - per-node disposition enum derived from node declaration next_nodes
#   (if num_next > 0)
# - per-node fused function declaration
# - per-node feature iterator function declaration (if node fills in
#   feat_iterator field)
# - function declarations for fused graph entry points
# - function declarations for fused node feature invocation
#
# Generates the following in the implementation source file if requested:
# - fused graph entry point functions calling fused node functions for
#   requested entry points
# - fused node feature invocation for requested feature points
#

import sys
import argparse
import os
import io

# Parsed node declarations, keyed by node name (filled in by
# parse_source_file)
nodes = {}
# Parsed feature declarations: maps a feature point name to a dict of
# the features registered on it, which is keyed by feature name
# (filled in by parse_source_file)
feats_for_feat_point = {}


def remove_quotes(string):
    """
    Remove the enclosing pair of double quotes from a string

    Raises RuntimeError if the string does not both start and end
    with a double quote. A string consisting of a single '"' is
    rejected too, since that one character cannot be both the opening
    and the closing quote.
    """
    properly_quoted = (len(string) >= 2 and
                       string.startswith('"') and string.endswith('"'))
    if not properly_quoted:
        raise RuntimeError('expected double quotes around string "{}"'.format(string))
    return string[1:-1]


class NodeDecl:
    """
    Parsed pipeline node declaration

    The fields are filled in one at a time through the set_*() and
    add_next_node() methods. None of the derived properties depend on
    the order in which those methods were called.
    """
    def __init__(self):
        self.name = None
        # Maps disposition -> next node name as written in the source
        # (possibly without a domain prefix, see get_next_node())
        self.next_nodes = {}
        # Dispositions in the order they were added
        self.__disp_order = []
        self.handler = None
        self.default_disp = None
        # Since PL_PROC has a value of 0 in the C header and if the
        # type field is omitted it will be initialised to 0 and hence
        # set to PL_PROC
        self.node_type = 'PL_PROC'
        self.num_next = None
        self.feat_iterate = None
        self.feat_type_find = None

    def set_handler(self, handler):
        self.handler = handler

    def set_name(self, name):
        """Set the node name from its double-quoted form"""
        self.name = remove_quotes(name)

    def add_next_node(self, next_node, disp):
        """
        Adds a next node to the declaration

        The first next node is the default disposition which will be
        treated as the "fall-through" case when generating the fused
        graph entry point.
        """
        if self.default_disp is None:
            self.default_disp = disp
        self.__disp_order.append(disp)
        self.next_nodes[disp] = next_node

    def set_type(self, node_type):
        self.node_type = node_type

    def set_num_next_sym(self, num_next):
        self.num_next = num_next

    def set_feat_iterate(self, feat_iterate):
        self.feat_iterate = feat_iterate

    def set_feat_type_find(self, feat_type_find):
        self.feat_type_find = feat_type_find

    @property
    def fused_no_dyn_feats_handler(self):
        """
        Name of the fused handler used when dynamic features are off

        Only nodes that are feature points (feat_iterate or
        feat_type_find set) have a distinct handler for this mode; all
        other nodes share their normal fused handler.
        """
        if self.feat_iterate is not None or self.feat_type_find is not None:
            return self.handler.replace('_process', '_fused_no_dyn_feats')
        return self.fused_handler

    @property
    def fused_handler(self):
        """Name of the fused handler, derived from the handler name"""
        return self.handler.replace('_process', '_fused')

    @property
    def references_self(self):
        """
        True if any next node of this node is the node itself

        Evaluated on demand so that the answer is correct even when
        the next nodes were added before the name was set.
        """
        return any(self.get_next_node(disp) == self.name
                   for disp in self.next_nodes)

    @property
    def domain(self):
        """Returns the part of the node name before the first ':'"""
        # NOTE(review): for a name without a ':' find() returns -1, so
        # this yields the name minus its last character - confirm that
        # node names are always domain-qualified
        return self.name[:self.name.find(':')]

    @property
    def c_name(self):
        """Returns the name of the node in a form usable in a C identifier"""
        name = self.name[self.name.find(':') + 1:]
        return name.replace('-', '_').replace(' ', '_')

    @property
    def ordered_disps(self):
        """Returns the dispositions in the order they were added"""
        return self.__disp_order

    def get_next_node(self, disp):
        """
        Returns the next node name corresponding to the given disposition

        A next node name without a domain is qualified with the domain
        of this node.
        """
        next_node_name = self.next_nodes[disp]
        if ':' not in next_node_name:
            return self.domain + ':' + next_node_name
        return next_node_name

    def validate(self):
        """Raises RuntimeError if the name or the handler is missing"""
        if self.name is None:
            raise RuntimeError('no name for node')
        if self.handler is None:
            raise RuntimeError('handler not set for node {}'.format(self.name))


class FeatureDecl:
    """
    Parsed pipeline feature declaration

    The node name, feature point and visit-after references may be
    written without a domain, in which case the corresponding
    properties qualify them with the domain of the feature's own
    name. References that were never set are reported as None.
    """
    def __init__(self):
        self.name = None
        self.__node_name = None
        self.__feature_point = None
        self.id = None
        self.visit_before = None
        self.__visit_after = None
        # Next FeatureDecl to generate after this one, if any
        self.next_feature = None
        self.feat_type = None
        self.always_on = None

    def set_name(self, name):
        self.name = remove_quotes(name)

    def set_node_name(self, node_name):
        self.__node_name = remove_quotes(node_name)

    def set_feature_point(self, feature_point):
        self.__feature_point = remove_quotes(feature_point)

    def set_id(self, id):
        self.id = id

    def set_visit_before(self, visit_before):
        self.visit_before = remove_quotes(visit_before)

    def set_visit_after(self, visit_after):
        self.__visit_after = remove_quotes(visit_after)

    def set_feat_type(self, feat_type):
        self.feat_type = feat_type

    def set_always_on(self, always_on):
        self.always_on = always_on

    def _qualify(self, name):
        """
        Prefix name with this feature's domain if it has none

        None is passed through unchanged so that unset references can
        still be detected by validate().
        """
        if name is not None and ':' not in name:
            return self.domain + ':' + name
        return name

    @property
    def domain(self):
        """Returns the part of the feature name before the first ':'"""
        return self.name[:self.name.find(':')]

    @property
    def node_name(self):
        """Domain-qualified name of the node implementing the feature"""
        return self._qualify(self.__node_name)

    @property
    def feature_point(self):
        """Domain-qualified name of the feature point node"""
        return self._qualify(self.__feature_point)

    @property
    def visit_after(self):
        """Domain-qualified name of the feature to be visited after"""
        return self._qualify(self.__visit_after)

    def validate(self):
        """
        Raises RuntimeError if the name, node name or feature point
        has not been set
        """
        if self.name is None:
            raise RuntimeError('no name for feature')
        if self.node_name is None:
            raise RuntimeError(
                'node_name not set for feature {}'.format(self.name))
        if self.feature_point is None:
            raise RuntimeError(
                'feature_point not set for feature {}'.format(self.name))


def _parse_field_assignment(line):
    """
    Split a designated initialiser line of the form '.field = value,'

    Returns a (field, value) tuple with surrounding whitespace
    stripped from both parts, or None if the line does not contain a
    '.' followed by an '='. The value ends at the first ',' after the
    '=', or at the end of the line if there is no ','.
    """
    field_start = line.find('.')
    if field_start < 0:
        return None
    field_end = line.find('=', field_start)
    if field_end < 0:
        return None
    value_start = field_end + len('=')
    value_end = line.find(',', value_start)
    if value_end < 0:
        # No trailing comma: take the rest of the line rather than
        # letting a -1 index silently drop the final character
        value_end = len(line)
    return (line[field_start + 1:field_end].strip(),
            line[value_start:value_end].strip())


def parse_source_file(filename):
    """
    Parse a node/feature source file

    Adds the parsed nodes and features to the nodes and
    feats_for_feat_point global dictionaries respectively.

    The file is scanned line by line: a line containing
    'PL_REGISTER_NODE(' or 'PL_REGISTER_FEATURE(' starts a
    declaration, '.field = value,' lines fill it in and a line
    containing '};' completes it. A declaration is validated when it
    completes, so RuntimeError is raised for an incomplete one.
    Details of each completed declaration are printed if the global
    args.debug is set.
    """
    parsing_node_decl = None
    parsing_node_decl_next = False
    parsing_feature_decl = None

    # The with statement guarantees the file is closed, even if a
    # declaration fails validation part way through
    with open(filename, 'r') as f:
        for line in f:
            if parsing_node_decl is not None:
                if parsing_node_decl_next:
                    # Inside the '.next = {' initialiser: each entry
                    # looks like '[DISPOSITION] = "next-node",'
                    if '[' in line:
                        start_disp = line.find('[') + len('[')
                        end_disp = line.find(']', start_disp)
                        start_next_node = line.find('"', end_disp + 1) + len('"')
                        end_next_node = line.find('"', start_next_node)
                        parsing_node_decl.add_next_node(
                            line[start_next_node:end_next_node].strip(),
                            line[start_disp:end_disp].strip())
                    elif '}' in line:
                        parsing_node_decl_next = False
                elif '};' in line:
                    if args.debug:
                        print("{}: Node '{}': handler '{}', next nodes:".format(filename, parsing_node_decl.name, parsing_node_decl.handler))
                        for (disp, next_node) in parsing_node_decl.next_nodes.items():
                            print('     {} = {}'.format(disp, next_node))
                    parsing_node_decl.validate()
                    nodes[parsing_node_decl.name] = parsing_node_decl
                    parsing_node_decl = None
                else:
                    assignment = _parse_field_assignment(line)
                    if assignment is None:
                        continue
                    field, value = assignment
                    setters = {
                        'name': parsing_node_decl.set_name,
                        'handler': parsing_node_decl.set_handler,
                        'type': parsing_node_decl.set_type,
                        'num_next': parsing_node_decl.set_num_next_sym,
                        'feat_iterate': parsing_node_decl.set_feat_iterate,
                        'feat_type_find': parsing_node_decl.set_feat_type_find,
                    }
                    if field in setters:
                        setters[field](value)
                    elif field == 'next':
                        parsing_node_decl_next = True
            elif parsing_feature_decl is not None:
                if '};' in line:
                    if args.debug:
                        print("{}: Feature '{}': node '{}', point '{}', id '{}', after '{}', type '{}'".format(
                            filename, parsing_feature_decl.name,
                            parsing_feature_decl.node_name,
                            parsing_feature_decl.feature_point,
                            parsing_feature_decl.id,
                            parsing_feature_decl.visit_after,
                            parsing_feature_decl.feat_type))
                    parsing_feature_decl.validate()
                    point_feats = feats_for_feat_point.setdefault(
                        parsing_feature_decl.feature_point, {})
                    point_feats[parsing_feature_decl.name] = parsing_feature_decl
                    parsing_feature_decl = None
                else:
                    assignment = _parse_field_assignment(line)
                    if assignment is None:
                        continue
                    field, value = assignment
                    setters = {
                        'name': parsing_feature_decl.set_name,
                        'node_name': parsing_feature_decl.set_node_name,
                        'feature_point': parsing_feature_decl.set_feature_point,
                        'visit_after': parsing_feature_decl.set_visit_after,
                        'visit_before': parsing_feature_decl.set_visit_before,
                        'id': parsing_feature_decl.set_id,
                        'feat_type': parsing_feature_decl.set_feat_type,
                        'always_on': parsing_feature_decl.set_always_on,
                    }
                    if field in setters:
                        setters[field](value)
            elif 'PL_REGISTER_NODE(' in line:
                parsing_node_decl = NodeDecl()
            elif 'PL_REGISTER_FEATURE(' in line:
                parsing_feature_decl = FeatureDecl()


def write_indent(f, indent_lvl, string):
    """
    Write string to f as one line, indented by indent_lvl tabs

    An empty string is written as a bare newline, without any
    indentation, so that blank lines carry no trailing whitespace.
    """
    tabs = '\t' * indent_lvl if string else ''
    f.write('{}{}\n'.format(tabs, string))


def _lookup_next_node(node, disp):
    """Return the declaration of the node reached from node via disp"""
    next_node = node.get_next_node(disp)
    if next_node not in nodes:
        raise RuntimeError(
            'unknown next node {} for node {}'.format(next_node, node.name))
    return nodes[next_node]


def gen_invoke_fused_node(f, indent_lvl, from_feature_point, dyn_feats, from_case_feature, node):
    """
    Generate the action of invoking a node

    If from_feature_point is False then encountering continue nodes
    causes the function to return true immediately and to not generate
    calls to pl_release_storage(). Otherwise, a continue statement is
    generated (or "return true" if from_case_feature is set). For
    output nodes if from_feature_point is False then encountering
    output nodes causes the function to jump to a cleanup label (where
    pl_release_storage() can be called), otherwise the function
    returns false without calling pl_release_storage().

    Handles simple loops in the graph of a node referring to itself by
    generating a do/while loop. A node referring to itself through its
    default disposition is rejected. More complex loops aren't
    detected or processed correctly and will cause the generator to
    recurse without end.

    The node's default disposition is used to generate a fall-through
    case (i.e. not embedded in an if-statement). This will typically
    be the "accept" case, but it doesn't have to be.

    Returns True if the generated code assigns to the "resp" variable,
    either for this node or for any node invoked after it, so that the
    caller knows whether "resp" needs to be declared.

    Raises RuntimeError for declarations that can't be fused: an
    invalid node type, continue/output nodes with next nodes,
    processing nodes without next nodes, unknown next nodes and
    unsupported self references.
    """
    else_str = ''
    resp_assign = 'resp = '
    self_ref_disp = None
    ret = True
    # don't assign resp variable if only one next node, since we
    # assume that a PL_PROC node always returns a valid value and so
    # in that case we wouldn't check it and the C compiler might warn
    # about a useless assignment
    if len(node.next_nodes) <= 1:
        resp_assign = ''
        # there would be no escape from the loop if only one next
        # node, and we'd need to add additional code to handle this
        # case to avoid a compile error
        if node.references_self:
            raise RuntimeError(
                'node {} cannot refer to itself without more than one next node'.format(node.name))
        ret = False

    if node.node_type == 'PL_CONTINUE':
        if node.next_nodes:
            raise RuntimeError(
                'continue node {} cannot have next nodes'.format(node.name))
        if not from_feature_point or from_case_feature:
            write_indent(f, indent_lvl, 'return true;')
        else:
            write_indent(f, indent_lvl, 'continue;')
        return False

    if node.node_type == 'PL_OUTPUT':
        if node.next_nodes:
            raise RuntimeError(
                'output node {} cannot have next nodes'.format(node.name))
        handler = node.fused_handler if dyn_feats else node.fused_no_dyn_feats_handler
        write_indent(f, indent_lvl, '{}{}(pl_pkt, NULL);'.format(resp_assign, handler))
        if from_feature_point:
            write_indent(f, indent_lvl, 'return false;')
        else:
            write_indent(f, indent_lvl, 'goto cleanup;')
        return False

    if node.node_type != 'PL_PROC':
        raise RuntimeError(
            'invalid node type: {} for node {}'.format(node.node_type, node.name))

    # A processing node always falls through to its default next node,
    # so one has to exist ...
    if node.default_disp is None:
        raise RuntimeError(
            'node {} must have at least one next node'.format(node.name))
    # ... and it can't be the node itself, since generating the
    # fall-through case would then recurse into this node forever
    if node.get_next_node(node.default_disp) == node.name:
        raise RuntimeError(
            'node {} cannot refer to itself through its default disposition {}'.format(
                node.name, node.default_disp))

    if from_feature_point and not from_case_feature:
        storage_ctx = "storage_ctx"
    else:
        storage_ctx = "NULL"

    if node.references_self:
        write_indent(f, indent_lvl, 'do {')
        indent_lvl = indent_lvl + 1
    handler = node.fused_handler if dyn_feats else node.fused_no_dyn_feats_handler
    write_indent(f, indent_lvl, '{}{}(pl_pkt, {});'.format(resp_assign, handler, storage_ctx))

    for disp in node.ordered_disps:
        if node.get_next_node(disp) == node.name:
            if self_ref_disp is not None:
                raise RuntimeError(
                    'node {} cannot have multiple disps that refer to itself in fused mode'.format(node.name))
            self_ref_disp = disp
            continue
        # default disp is handled below
        if disp == node.default_disp:
            continue

        write_indent(f, indent_lvl, '{}if (unlikely(resp == {})) {{'.format(else_str, disp))
        else_str = '} else '
        if gen_invoke_fused_node(f, indent_lvl + 1, from_feature_point, dyn_feats,
                                 from_case_feature, _lookup_next_node(node, disp)):
            ret = True
    # Only close the if/else chain if one was actually opened: a node
    # whose only non-default disposition is the self reference
    # generates no if-statement at all
    if else_str:
        write_indent(f, indent_lvl, '}')
    if node.references_self:
        indent_lvl = indent_lvl - 1
        write_indent(f, indent_lvl, '}} while (unlikely(resp == {}));'.format(self_ref_disp))
    write_indent(f, indent_lvl, '')
    if gen_invoke_fused_node(f, indent_lvl, from_feature_point, dyn_feats,
                             from_case_feature, _lookup_next_node(node, node.default_disp)):
        ret = True

    return ret


def gen_fused_graph(f, entry, dyn_feats):
    """
    Generate the fused mode graph function for an entry point

    entry is the name of the node the graph starts from; RuntimeError
    is raised if no such node was parsed. dyn_feats selects between
    the variant with dynamic features and the no_dyn_feats variant,
    which differ in the function name and in the handlers invoked.
    """
    if entry not in nodes:
        raise RuntimeError('Unknown entry-point node: {}'.format(entry))
    node = nodes[entry]

    if dyn_feats:
        signature = 'pipeline_fused_{}(struct pl_packet *pl_pkt)'
    else:
        signature = 'pipeline_fused_no_dyn_feats_{}(struct pl_packet *pl_pkt)'
    write_indent(f, 0, 'ALWAYS_INLINE bool')
    write_indent(f, 0, signature.format(node.c_name))
    write_indent(f, 0, '{')

    # The body is generated into a scratch buffer first, because
    # whether the "resp" declaration is needed is only known once the
    # whole graph has been walked
    body = io.StringIO()
    needs_resp = gen_invoke_fused_node(body, 1, False, dyn_feats, False, node)
    if needs_resp:
        write_indent(f, 1, 'int resp;')
        write_indent(f, 0, '')
    f.write(body.getvalue())

    for lvl, text in ((0, 'cleanup:'),
                      (1, 'pl_release_storage(pl_pkt);'),
                      (1, 'return false;'),
                      (0, '}')):
        write_indent(f, lvl, text)


def gen_fused_feature_invoke_by_case_find(f, node, feat_point, dyn_feats):
    """
    Generate the fused feature find function for the given node

    The generated function maps a feature type straight to a feature
    id through a switch statement for features marked always_on, so
    no external lookup is needed for those. For any other value it
    falls back to the node's feat_type_find function when dyn_feats
    is set, and returns 0 otherwise.
    """
    head_features = dict(feats_for_feat_point.get(feat_point, {}))

    if dyn_feats:
        signature = '{}_fused(unsigned int feat)'
    else:
        signature = '{}_fused_no_dyn_features(unsigned int feat)'
    write_indent(f, 0, 'ALWAYS_INLINE unsigned int')
    write_indent(f, 0, signature.format(node.feat_type_find))
    write_indent(f, 0, '{')
    write_indent(f, 1, '')
    write_indent(f, 1, 'switch (feat) {')

    for head in head_features.values():
        # Walk the chain of features linked through next_feature
        feature = head
        while feature:
            # NOTE(review): always_on holds the raw text parsed from
            # the source, so any non-empty value (even "false") counts
            # as set here - confirm this is intended
            if feature.always_on:
                write_indent(f, 1, 'case {}:'.format(feature.feat_type))
                write_indent(f, 2, 'return {};'.format(feature.id))
            feature = feature.next_feature

    write_indent(f, 1, 'default:')
    write_indent(f, 2, 'break;')
    write_indent(f, 1, '}')
    fallback = 'return {}(feat);'.format(node.feat_type_find) if dyn_feats else 'return 0;'
    write_indent(f, 1, fallback)
    write_indent(f, 0, '}')


def gen_fused_feature_invoke_by_case(f, node, feat_point, dyn_feats):
    """
    Generate fused feature invocation for the given feature point,
    selecting the feature by its id

    The generated function switches on the feature id it is passed
    and invokes only the node of the matching registered feature. An
    id of 0 returns true straight away and ids without a generated
    case are passed on to pl_node_invoke_feature().
    """
    if args.debug:
        print("::{} {} {}".format(node, feat_point, dyn_feats))

    head_features = dict(feats_for_feat_point.get(feat_point, {}))

    if dyn_feats:
        signature = 'pipeline_fused_{}_features(struct pl_packet *pl_pkt, unsigned int feat)'
    else:
        signature = 'pipeline_fused_{}_no_dyn_features(struct pl_packet *pl_pkt, unsigned int feat)'

    opening = (
        (0, 'ALWAYS_INLINE bool'),
        (0, signature.format(node.c_name)),
        (0, '{'),
        (1, 'int resp = true;'),
        (1, ''),
        (1, 'switch (feat) {'),
        (1, 'case 0:'),
        (2, 'return true;'),
    )
    for lvl, text in opening:
        write_indent(f, lvl, text)

    for head in head_features.values():
        # Walk the chain of features linked through next_feature
        feature = head
        while feature:
            write_indent(f, 1, 'case {}:'.format(feature.id))
            gen_invoke_fused_node(f, 2, True, dyn_feats, True, nodes[feature.node_name])
            write_indent(f, 2, '')
            feature = feature.next_feature

    closing = (
        (1, 'default:'),
        (2, 'if (!pl_node_invoke_feature({}_node_ptr, feat, pl_pkt, NULL))'.format(node.c_name)),
        (3, 'return false;'),
        (2, 'break;'),
        (1, '}'),
        (1, 'return resp;'),
        (0, '}'),
    )
    for lvl, text in closing:
        write_indent(f, lvl, text)


def gen_fused_features_invoke(f, feat_point, dyn_feats):
    """
    Generate fused feature invocation for the given feature point

    The features are listed in the order they get added to the feature
    point array

    Feature points that declare feat_type_find are generated as a
    pair of by-case functions instead of the feature iteration loop.
    Features without an id get no case statement of their own.

    Raises RuntimeError if feat_point isn't a known node, if it isn't
    a feature point, or if a feature refers to an unknown node or
    visit-after feature.
    """
    if feat_point not in nodes:
        raise RuntimeError('Unknown feature point node: {}'.format(feat_point))
    node = nodes[feat_point]

    if node.feat_type_find:
        gen_fused_feature_invoke_by_case(f, node, feat_point, dyn_feats)
        gen_fused_feature_invoke_by_case_find(f, node, feat_point, dyn_feats)
        return

    if not node.feat_iterate:
        raise RuntimeError('Entry-point node not feature point: {}'.format(feat_point))
    if feat_point in feats_for_feat_point:
        features = dict(feats_for_feat_point[feat_point])
    else:
        features = {}

    write_indent(f, 0, 'ALWAYS_INLINE bool')
    if dyn_feats:
        write_indent(f, 0, 'pipeline_fused_{}_features(struct pl_packet *pl_pkt, struct pl_node *node)'.format(node.c_name))
    else:
        write_indent(f, 0, 'pipeline_fused_{}_no_dyn_features(struct pl_packet *pl_pkt __unused, struct pl_node *node)'.format(node.c_name))
    write_indent(f, 0, '{')
    write_indent(f, 1, 'unsigned int feature = ~0u;')
    write_indent(f, 1, 'void *context;')
    write_indent(f, 1, 'void *storage_ctx = NULL;')
    write_indent(f, 1, 'bool more;')
    resp_written = False
    # Features that no other feature is chained after; the rest are
    # reached through the next_feature links set up below
    head_features = dict(features)
    for feature in features.values():
        if feature.node_name not in nodes:
            raise RuntimeError(
                "unknown node {} for feature {}".format(feature.node_name, feature.name))
        if len(nodes[feature.node_name].next_nodes) > 1 and not resp_written:
            write_indent(f, 1, 'int resp;')
            resp_written = True
        if feature.visit_after:
            if feature.visit_after not in features:
                raise RuntimeError(
                    "unknown visit after {} for feature {}".format(feature.visit_after, feature.name))
            # We only support one feature referencing another feature
            # as to run after it. However, this is only an
            # optimisation to the C compiler to try to list the
            # features in the order they would be visited if multiple
            # features are enabled
            features[feature.visit_after].next_feature = feature
            del head_features[feature.name]
    write_indent(f, 1, '')
    write_indent(f, 1, 'for (more = {}(node, true, &feature, &context, &storage_ctx);'.format(node.feat_iterate))
    write_indent(f, 1, '     more;')
    write_indent(f, 1, '     more = {}(node, false, &feature, &context, &storage_ctx)) {{'.format(node.feat_iterate))
    write_indent(f, 2, 'switch (feature) {')
    for feature in head_features.values():
        while feature:
            # A feature without an id can't be given a case label, but
            # the features chained after it still have to be visited
            if feature.id is not None:
                write_indent(f, 2, 'case {}:'.format(feature.id))
                gen_invoke_fused_node(f, 3, True, dyn_feats, False, nodes[feature.node_name])
                write_indent(f, 3, '')
            feature = feature.next_feature
    if dyn_feats:
        write_indent(f, 2, 'default:')
        write_indent(f, 3, 'if (!pl_node_invoke_feature({}_node_ptr, feature, pl_pkt, storage_ctx))'.format(node.c_name))
        write_indent(f, 4, 'return false;')
        write_indent(f, 3, 'continue;')
    write_indent(f, 2, '}')
    write_indent(f, 1, '}')
    write_indent(f, 1, '')
    write_indent(f, 1, 'return true;')
    write_indent(f, 0, '}')


def gen_preamble(f):
    """
    Write out the preamble comment for generated source and header
    files, naming this script and the source files it was given
    """
    comment = ['/*\n',
               ' * Auto-generated by {} from:\n'.format(sys.argv[0]),
               ' *\n']
    comment.extend(' * {}\n'.format(source) for source in args.source_files)
    comment.append(' */\n')
    f.write(''.join(comment))


def gen_fused_impl(f, includes, entry_points, feat_points):
    """
    Generate fused implementation source file

    Writes the preamble and #include lines, both variants (with and
    without dynamic features) of each requested entry point and
    feature point function, and finally pl_gen_fused_init(), which
    assigns the generated node and feature point ids by node name.
    includes, entry_points and feat_points may each be None.
    """
    gen_preamble(f)
    for fixed_include in ('#include <pl_node.h>\n',
                          '#include <rte_branch_prediction.h>\n',
                          '#include "compiler.h"\n'):
        f.write(fixed_include)
    for include in includes or ():
        f.write('#include "{}"\n'.format(include))

    if entry_points is not None:
        for entry in entry_points:
            for dyn_feats in (False, True):
                f.write('\n')
                gen_fused_graph(f, entry, dyn_feats)
    if feat_points is not None:
        for feat_point in feat_points:
            for dyn_feats in (True, False):
                f.write('\n')
                gen_fused_features_invoke(f, feat_point, dyn_feats)

    f.write('void pl_gen_fused_init(struct pl_node_registration *node)\n')
    f.write('{\n')
    for node_name in sorted(nodes):
        node = nodes[node_name]
        id_stem = node.c_name.upper()
        # Only feature point nodes get a feature point id of their own
        is_feat_point = node.feat_iterate or node.feat_type_find
        point_stem = id_stem if is_feat_point else 'NONE'
        write_indent(f, 1, 'if (strcmp(node->name, "{}") == 0) {{'.format(node.name))
        write_indent(f, 2, 'node->node_decl_id = PL_NODE_{}_ID;'.format(id_stem))
        write_indent(f, 2, 'node->feature_point_id = PL_FEATURE_POINT_{}_ID; return; }}'.format(point_stem))
    f.write('}\n')


def gen_node_disps(f, node):
    """
    Generate disposition enum for nodes

    The order of the enum values is determined by the order in which
    the dispositions appear in the node's next_node initialisation,
    followed by the node's num_next symbol as the final enumerator.
    Nothing is generated for a node without dispositions.

    Raises RuntimeError if the node has dispositions but no num_next
    symbol, since the enum would otherwise end with the text "None".
    """
    if not node.ordered_disps:
        return
    if node.num_next is None:
        raise RuntimeError(
            'num_next not set for node {} which has next nodes'.format(node.name))
    write_indent(f, 0, 'enum {}_dispositions {{'.format(node.c_name))
    for disp in node.ordered_disps:
        write_indent(f, 1, '{},'.format(disp))
    write_indent(f, 1, '{},'.format(node.num_next))
    write_indent(f, 0, '};')
    write_indent(f, 0, '')


def gen_node_fused_func_decls(f):
    """
    Generate the per-node declarations of the fused header

    For every node, in name order, this writes the disposition enum,
    the extern declaration of the handler and an inline fused wrapper
    which counts the node visit before calling the handler. Feature
    point nodes get two wrappers around their "_common" handler (one
    per fused mode) plus the declarations of their feature iterate or
    feature type find functions.
    """
    def emit_wrapper(wrapper_name, stat_id, return_stmt):
        # One static inline function: count the visit, then delegate
        for lvl, text in (
                (0, 'static ALWAYS_INLINE unsigned int'),
                (0, '{}(struct pl_packet *pl_pkt, void *context)'.format(wrapper_name)),
                (0, '{'),
                (1, 'pl_inc_node_stat(PL_NODE_{}_ID);'.format(stat_id)),
                (1, return_stmt),
                (0, '}')):
            write_indent(f, lvl, text)

    write_indent(f, 0, '')
    write_indent(f, 0, '/* Node fused declarations */')
    for node_name in sorted(nodes):
        node = nodes[node_name]
        stat_id = node.c_name.upper()
        gen_node_disps(f, node)
        write_indent(f, 0, 'extern unsigned int {}(struct pl_packet *, void *context);'.format(node.handler))

        if node.feat_iterate is None and node.feat_type_find is None:
            # Plain node: a single wrapper around the handler itself
            emit_wrapper(node.fused_handler, stat_id,
                         'return {}(pl_pkt, context);'.format(node.handler))
            continue

        write_indent(f, 0, '')
        write_indent(f, 0, 'extern unsigned int {}_common(struct pl_packet *, void *context __unused, enum pl_mode);'.format(node.handler))
        emit_wrapper(node.fused_handler, stat_id,
                     'return {}_common(pl_pkt, context, PL_MODE_FUSED);'.format(node.handler))
        write_indent(f, 0, '')
        emit_wrapper(node.fused_no_dyn_feats_handler, stat_id,
                     'return {}_common(pl_pkt, context, PL_MODE_FUSED_NO_DYN_FEATS);'.format(node.handler))
        write_indent(f, 0, '')
        if node.feat_iterate is not None:
            write_indent(f, 0, 'bool')
            write_indent(f, 0, '{}(struct pl_node *node, bool first, unsigned int *feature_id, void **context, void **storage_ctx);'.format(node.feat_iterate))
        else:
            write_indent(f, 0, 'int')
            write_indent(f, 0, '{}(uint32_t feat_type);'.format(node.feat_type_find))
            write_indent(f, 0, 'unsigned int {}_fused(unsigned int feat);'.format(node.feat_type_find))
            write_indent(f, 0, 'unsigned int {}_fused_no_dyn_features(unsigned int feat);'.format(node.feat_type_find))


def gen_fused_header(f, c_file_name, entry_points, feat_points):
    """
    Generate fused header file

    c_file_name is upper-cased and used for the include guard. The
    header holds the node and feature point id enums, the per-node
    fused declarations and the prototypes of the requested entry point
    and feature point functions. entry_points and feat_points may be
    None.

    Raises RuntimeError for an entry point or feature point that
    isn't a known node, or a feature point node that declares neither
    feat_iterate nor feat_type_find.
    """
    gen_preamble(f)
    guard = c_file_name.upper()
    ordered_nodes = [nodes[node_name] for node_name in sorted(nodes)]

    f.write('#ifndef __{}__\n'.format(guard))
    f.write('#define __{}__\n'.format(guard))
    f.write('\n')
    for include_line in ('#include "compiler.h"\n',
                         '#include "pl_common.h"\n',
                         '#include "pl_internal.h"\n',
                         '#include "util.h"\n',
                         '#include "../src/pipeline/pl_fused.h"\n'):
        f.write(include_line)

    f.write('\n')
    f.write('void pl_gen_fused_init(struct pl_node_registration *node);\n')
    f.write('\n')
    f.write('enum pl_node_decl_id {\n')
    for node in ordered_nodes:
        write_indent(f, 1, 'PL_NODE_{}_ID,'.format(node.c_name.upper()))
    f.write('PL_NODE_NUM_IDS};\n')

    f.write('\n')
    f.write('enum pl_feature_point_id {\n')
    write_indent(f, 1, 'PL_FEATURE_POINT_NONE_ID,')
    for node in ordered_nodes:
        if node.feat_iterate or node.feat_type_find:
            write_indent(f, 1, 'PL_FEATURE_POINT_{}_ID,'.format(node.c_name.upper()))
    f.write('PL_FEATURE_POINT_NUM_IDS};\n')

    gen_node_fused_func_decls(f)

    write_indent(f, 0, '/* Fused-mode graph entry points */')
    for entry in ([] if entry_points is None else entry_points):
        if entry not in nodes:
            raise RuntimeError(
                'Unknown entry-point node: {}'.format(entry))
        c_name = nodes[entry].c_name
        f.write('bool pipeline_fused_{}(struct pl_packet *pl_pkt);\n'.format(c_name))
        f.write('bool pipeline_fused_no_dyn_feats_{}(struct pl_packet *pl_pkt);\n'.format(c_name))
        f.write('\n')

    write_indent(f, 0, '/* Fused-mode feature invocations */')
    for feat_point in ([] if feat_points is None else feat_points):
        if feat_point not in nodes:
            raise RuntimeError(
                'Unknown feature point node: {}'.format(feat_point))
        node = nodes[feat_point]
        # The second parameter depends on how the feature point
        # selects its features: by feature id or by iteration
        if node.feat_type_find:
            prototypes = (
                'pipeline_fused_{}_features(struct pl_packet *pl_pkt, unsigned int feat);\n',
                'pipeline_fused_{}_no_dyn_features(struct pl_packet *pl_pkt, unsigned int feat);\n',
            )
        elif node.feat_iterate:
            prototypes = (
                'pipeline_fused_{}_features(struct pl_packet *pl_pkt, struct pl_node *node);\n',
                'pipeline_fused_{}_no_dyn_features(struct pl_packet *pl_pkt, struct pl_node *node);\n',
            )
        else:
            raise RuntimeError(
                'Feature point {} not declared as feature point node'.format(feat_point))
        for prototype in prototypes:
            f.write('bool\n')
            f.write(prototype.format(node.c_name))

        f.write('\n')
    f.write('#endif /* __{}__ */\n'.format(guard))


# Command line handling. The parsed arguments are kept in the
# module-level "args" variable, which the parsing and generation
# functions above read directly.
arg_parser = argparse.ArgumentParser(description='Generate pipeline fused mode files')
arg_parser.add_argument('--debug', action='store_true',
                        help='Enable printing of debugging information')
arg_parser.add_argument('--entry', action='append',
                        help='Generate function as an entry point into a fused graph')
arg_parser.add_argument('--feature-point', action='append',
                        help='Generate function for invoking fused features on a node')
arg_parser.add_argument('source_files', nargs='+', metavar='source-file',
                        help='a source file containing node or feature declarations')
arg_parser.add_argument('--impl-out', action='store',
                        help='File to output generated fused implementation to')
arg_parser.add_argument('--include', action='append',
                        help='Name of header to include in implementation')
arg_parser.add_argument('--header-out', action='store',
                        help='File to output generated fused header to')
args = arg_parser.parse_args()

# All source files are parsed before anything is generated, since
# nodes may refer to nodes declared in other files
for filename in args.source_files:
    parse_source_file(filename)

# An output name of '=' selects stdout. Real output files are opened
# in a with statement so that they are flushed and closed as soon as
# generation finishes, or fails.
if args.impl_out:
    if args.impl_out == '=':
        gen_fused_impl(sys.stdout, args.include, args.entry, args.feature_point)
    else:
        with open(args.impl_out, 'w') as f:
            gen_fused_impl(f, args.include, args.entry, args.feature_point)

if args.header_out:
    # The include guard is derived from the header's file name
    c_file_name = os.path.basename(args.header_out).replace('.', '_').replace('-', '_')
    if args.header_out == '=':
        gen_fused_header(sys.stdout, c_file_name, args.entry, args.feature_point)
    else:
        with open(args.header_out, 'w') as f:
            gen_fused_header(f, c_file_name, args.entry, args.feature_point)
