#! /usr/bin/env python

# -----------------------------------------------------------------------
# Convert a macro prototype to a LaTeX \newcommand
# By Scott Pakin <scott+nc@pakin.org>
# -----------------------------------------------------------------------
# Copyright (C) 2010 Scott Pakin, scott+nc@pakin.org
#
# This package may be distributed and/or modified under the conditions
# of the LaTeX Project Public License, either version 1.3c of this
# license or (at your option) any later version.  The latest version of
# this license is in:
#
#     http://www.latex-project.org/lppl.txt
#
# and version 1.3c or later is part of all distributions of LaTeX version
# 2006/05/20 or later.
# -----------------------------------------------------------------------

from spark import GenericScanner, GenericParser, GenericASTTraversal
import re
import copy

class ParseError(Exception):
    """Signal any failure encountered while processing a macro prototype.

    Instances carry whatever arguments they were raised with; the code in
    this file raises it with a message and a character position.
    """


class Token:
    """Represent a single lexed token.

    Attributes:
        type        -- token category (e.g. 'command', 'argument', 'quoted').
        attr        -- the matched source text, or None if none was supplied.
        charOffset  -- offset of the token within the scanned input.

    A token compares against a bare type string (or another token) purely
    by its type.
    """

    def __init__(self, type, charOffset, attr=None):
        self.type = type
        self.attr = attr
        self.charOffset = charOffset

    def __cmp__(self, o):
        # Python 2 three-way comparison hook.  Written without the cmp()
        # builtin so the method remains callable where cmp() does not exist.
        return (self.type > o) - (self.type < o)

    def __eq__(self, o):
        # Same answer __cmp__ gives for equality, but also honored by
        # interpreters that ignore __cmp__.
        return self.type == o

    def __ne__(self, o):
        return not self.__eq__(o)

    def __str__(self):
        # __str__ must return a string; attr defaults to None, so fall back
        # to the token type instead of raising TypeError.
        if self.attr is None:
            return str(self.type)
        return str(self.attr)


class AST:
    """Represent an abstract syntax tree node.

    Attributes:
        type        -- node category (e.g. 'decl', 'arglist', 'arg').
        charOffset  -- offset in the input of the text the node came from.
        attr        -- node-specific payload, or None.
        kids        -- list of child nodes.

    Indexing and len() operate on the list of children.
    """

    def __init__(self, type, charOffset, attr=None, kids=None):
        self.type = type
        self.charOffset = charOffset
        self.attr = attr
        # A literal [] default would be one list shared by every node built
        # without explicit children; give each such node its own list.
        if kids is None:
            kids = []
        self.kids = kids

    def __getitem__(self, child):
        return self.kids[child]

    def __len__(self):
        return len(self.kids)


class CmdScanner(GenericScanner):
    "Defines a lexer for macro prototypes."

    # SPARK's GenericScanner uses the docstring of each t_* method as the
    # regular expression for that token, so those docstrings are patterns,
    # not documentation, and must stay exactly as written.  Each handler
    # receives the matched text and records a Token tagged with the offset
    # at which the text started.

    def __init__(self):
        GenericScanner.__init__(self)
        self.charOffset = 0

    def tokenize(self, input):
        # Return the list of Tokens lexed from input.  Both the result list
        # and the running offset are reset here so that a scanner instance
        # can be reused without offsets carrying over from the previous
        # input.
        self.rv = []
        self.charOffset = 0
        GenericScanner.tokenize(self, input)
        return self.rv

    def _addToken(self, tokenType, text):
        # Record a token of the given type for the matched text, then
        # advance the running offset past that text.
        self.rv.append(Token(type=tokenType,
                             attr=text,
                             charOffset=self.charOffset))
        self.charOffset += len(text)

    def t_whitespace(self, whiteSpace):
        r' [\s\r\n]+ '
        # Whitespace produces no token but still counts toward offsets.
        self.charOffset += len(whiteSpace)

    def t_command(self, cmd):
        r' MACRO '
        self._addToken('command', cmd)

    def t_argument_type(self, arg):
        r' OPT '
        self._addToken('argtype', arg)

    def t_argument(self, arg):
        r' \#\d+ '
        self._addToken('argument', arg)

    def t_equal(self, equal):
        r' = '
        # The token type is the matched text itself ("="), which is how the
        # grammar in CmdParser refers to it.
        self._addToken(equal, equal)

    def t_quoted(self, quoted):
        r' \{[^\}]*\} '
        # attr keeps the surrounding braces.
        self._addToken('quoted', quoted)

    def t_identifier(self, ident):
        r' [A-Za-z]+ '
        self._addToken('ident', ident)

    def t_delimiter(self, delim):
        r' [()\[\]] '
        self._addToken('delim', delim)

    def t_other(self, other):
        r' [^()\[\]\{\}\#\s\r\n]+ '
        self._addToken('other', other)


class CmdParser(GenericParser):
    "Defines a parser for macro prototypes."

    # The docstring of each p_* method is the grammar rule (or rules) that
    # the method reduces; args holds the matched right-hand-side items in
    # order.  Every method returns the AST node built for the rule.

    def __init__(self, start='decl'):
        GenericParser.__init__(self, start)

    def error(self, token):
        # Positions handed to ParseError are one more than the offset.
        position = token.charOffset + 1
        raise ParseError("Syntax error", position)

    def p_optarg(self, args):
        ' optarg ::= argtype delim defvals delim '
        keyword, opener, contents, closer = args
        delimiters = (opener.attr, closer.attr)
        return AST('optarg', keyword.charOffset, delimiters, [contents])

    def p_rawtext(self, args):
        ' rawtext ::= other '
        text = args[0]
        return AST('rawtext', text.charOffset, text.attr)

    def p_defval(self, args):
        ' defval ::= argument = quoted '
        formal = args[0]
        default = args[2]
        return AST('defval', formal.charOffset, (formal.attr, default.attr))

    def p_defvals_1(self, args):
        ' defvals ::= defval '
        return AST('defvals', args[0].charOffset, kids=args)

    def p_defvals_2(self, args):
        '''
        defvals ::= defval rawtext defvals
        defvals ::= defval ident defvals
        defvals ::= defval quoted defvals
        '''
        first, separator, rest = args
        # The text between two default values is kept as the node's
        # attribute rather than as a child.
        between = (separator.type, separator.attr, separator.charOffset)
        return AST('defvals', first.charOffset, between, [first, rest])

    # Top-level macro argument
    def p_arg_1(self, args):
        '''
        arg ::= quoted
        arg ::= argument
        '''
        item = args[0]
        return AST('arg', item.charOffset, [item.type, item.attr])

    def p_arg_2(self, args):
        ' arg ::= optarg '
        optional = args[0]
        # The optarg node's children are adopted directly by the arg node.
        return AST('arg', optional.charOffset,
                   [optional.type, optional.attr],
                   optional.kids)

    def p_arg_3(self, args):
        ' arg ::= rawtext '
        text = args[0]
        if text.attr == "*":
            return AST('arg', text.charOffset, [text.type, text.attr])
        # "*" is the only unquoted literal accepted at the top level.
        raise ParseError('Literal text must be quoted between "{" and "}"',
                         text.charOffset + 1)

    def p_arglist_1(self, args):
        ' arglist ::= arg '
        return AST('arglist', args[0].charOffset, kids=args)

    def p_arglist_2(self, args):
        ' arglist ::= arg arglist '
        return AST('arglist', args[0].charOffset, kids=args)

    def p_decl_1(self, args):
        ' decl ::= command ident '
        command, name = args
        return AST('decl', command.charOffset, (command.attr, name.attr), [])

    def p_decl_2(self, args):
        ' decl ::= command ident arglist '
        command, name, arguments = args
        return AST('decl', command.charOffset,
                   (command.attr, name.attr),
                   [arguments])


def flattenAST(ast):
    """Flatten an AST into a list of arguments and return that list.

    The tree is walked in postorder; each visited node is given an argList
    attribute built from its own attributes and its children's argLists,
    and the root's argList is the result.
    """

    class Flattener(GenericASTTraversal):
        "Attach an argList to every node of the tree being traversed."

        def n_defval(self, node):
            formal, default = node.attr[0], node.attr[1]
            node.argList = (formal, default, node.charOffset)

        def n_defvals(self, node):
            first = node.kids[0].argList
            if node.attr:
                # node.attr is the separator between this default value and
                # the ones that follow.
                node.argList = [first, node.attr] + node.kids[1].argList
            else:
                node.argList = [first]

        def n_arg(self, node):
            if node.attr[0] != "optarg":
                node.argList = tuple(node.attr) + (node.charOffset,)
            else:
                node.argList = node.attr + node.kids[0].argList

        def n_arglist(self, node):
            flattened = [node.kids[0].argList]
            if len(node.kids) == 2:
                flattened = flattened + node.kids[1].argList
            node.argList = flattened

        def n_decl(self, node):
            header = (node.attr[0], node.attr[1], node.charOffset)
            if node.kids:
                node.argList = [header] + node.kids[0].argList
            else:
                node.argList = [header]

        def default(self, node):
            message = 'Internal error -- node type "%s" was unexpected' % node.type
            raise ParseError(message, 1+node.charOffset)

    walker = Flattener(ast)
    walker.postorder()
    return ast.argList


def checkArgList(argList):
    """Raise an error if any problems are detected with the given argument list.

    argList is a flattened argument list: tuples whose first element names
    the kind of entry ("argument", "quoted", "rawtext", ...) and whose third
    element is a character offset, plus "optarg" lists whose entries from
    index 2 onward are either ("#n", default, offset) formals or
    (kind, text, offset) separators.

    Returns None.  Raises ParseError(message, position) when
      * formal parameters are not numbered 1, 2, 3, ... in order,
      * "*" appears more than once at the top level,
      * an optional argument contains more than nine formals, or
      * literal text contains a "#" that is not preceded by a backslash.
    Every position is one more than the character offset of the offending
    text.
    """

    def getFormals(sublist):
        "Return the formal-parameter numbers in the order in which they appear."
        formals = []
        for item in sublist:
            kind = item[0]
            if kind == "argument":
                # Top-level formal: ("argument", "#n", offset)
                formals.append((int(item[1][1:]), item[2]))
            elif kind[:1] == "#":
                # Formal inside an optional argument: ("#n", default, offset)
                formals.append((int(kind[1:]), item[2]))
            elif kind == "optarg":
                formals.extend(getFormals(item[2:]))
        return formals

    def findBareHash(text):
        "Return the index of the first unescaped '#' in text, or -1."
        idx = text.find("#")
        while idx != -1:
            if idx == 0 or text[idx-1] != "\\":
                return idx
            # This one was escaped; keep looking for a later bare one.
            idx = text.find("#", idx + 1)
        return -1

    def checkLiteral(item):
        "Complain if a rawtext/quoted item contains an unescaped '#'."
        if item[0] not in ("rawtext", "quoted"):
            return
        hashidx = findBareHash(item[1])
        if hashidx != -1:
            # item[1] is the text exactly as it appeared in the input
            # (braces included for quoted text), so offset + index is the
            # position of the "#" itself.
            raise ParseError(
                'The "#" character cannot be used as a literal character unless escaped with "\\"',
                1 + item[2] + hashidx)

    # Ensure the formals appear in strict increasing order.
    prevformal = 0
    for form, pos in getFormals(argList):
        if form != prevformal + 1:
            raise ParseError(
                "Expected parameter %d but saw parameter %d" % (prevformal+1, form),
                1+pos)
        prevformal = form

    # Ensure that "*" appears at most once at the top level.
    seenStar = False
    for arg in argList:
        if arg[0] == "rawtext" and arg[1] == "*":
            if seenStar:
                raise ParseError("Only one star parameter is allowed", 1+arg[2])
            seenStar = True

    # Ensure that no optional argument contains more than nine formals.
    for arg in argList:
        if arg[0] != "optarg":
            continue
        optFormals = 0
        for oarg in arg[2:]:
            if oarg[0][:1] == "#":
                optFormals += 1
                if optFormals > 9:
                    raise ParseError(
                        "An optional argument can contain at most nine formals",
                        1+oarg[2])

    # Ensure that "#" is used only where it's allowed.
    for arg in argList:
        if arg[0] == "optarg":
            for oarg in arg[2:]:
                checkLiteral(oarg)
        else:
            checkLiteral(arg)


class LaTeXgenerator():
    """Generate LaTeX code from a list of arguments.

    The argument list is the flattened form of a parsed prototype:
      * entry 0 is (command, macroName, offset);
      * ("argument", "#n", offset) is a mandatory formal parameter;
      * ("quoted", "{text}", offset) is literal text (braces included);
      * ("rawtext", "*", offset) is the star marker;
      * ["optarg", (open, close), item, ...] is an optional argument whose
        items are ("#n", "{default}", offset) formals or
        (kind, text, offset) separators, kind being "rawtext", "ident" or
        "quoted".

    generate() is the entry point; it prints the generated code and leaves
    the same lines in self.codeList.
    """

    def __init__(self):
        "Initialize all of LaTeXgenerator's instance variables."
        self.argList = []             # List of arguments provided to generate()
        self.topLevelName = "???"     # Base macro name
        self.haveStar = False         # True=need to define \ifNAME@star
        self.haveAt = False           # True=need to use \makeatletter...\makeatother
        self.numFormals = 0           # Total number of formal arguments
        self.codeList = []            # List of lines of code to output
        self.argGroups = []           # Arguments grouped one list per macro

    def toRoman(self, num):
        """Convert a decimal number to a lowercase roman numeral.

        Raises ParseError if num exceeds 4000.
        """
        dec2rom = [("m",  1000),
                   ("cm",  900),
                   ("d",   500),
                   ("cd",  400),
                   ("c",   100),
                   ("xc",   90),
                   ("l",    50),
                   ("xl",   40),
                   ("x",    10),
                   ("ix",    9),
                   ("v",     5),
                   ("iv",    4),
                   ("i",     1)]
        if num > 4000:
            raise ParseError("Too many arguments", 0)
        pieces = []
        for rom, dec in dec2rom:
            while num >= dec:
                pieces.append(rom)
                num -= dec
        return "".join(pieces)

    def partitionArgList(self):
        """Group arguments, one group per macro to generate.

        Fills self.argGroups from self.argList.  Group 0 holds what the
        top-level \\newcommand can take by itself; every later group
        becomes a helper macro.
        """
        self.argGroups = []
        argIdx = 1

        # Specially handle the first group because it's limited by
        # \newcommand's semantics: at most one leading optional argument,
        # delimited by [ ] and holding a single formal, followed by
        # mandatory arguments, nine arguments in all.
        group = []
        if len(self.argList) <= 1:
            # No arguments whatsoever
            self.argGroups.append(group)
            return
        arg = self.argList[argIdx]
        if arg[0] == "optarg" and arg[1] == ("[", "]") and len(arg) == 3:
            group.append(arg)
            argIdx += 1
        while (len(group) < 9
               and argIdx < len(self.argList)
               and self.argList[argIdx][0] == "argument"):
            group.append(self.argList[argIdx])
            argIdx += 1
        self.argGroups.append(group)

        # Handle the remaining groups, each ending before an optional
        # argument.
        group = []
        numFormals = 0
        for arg in self.argList[argIdx:]:
            kind = arg[0]
            if kind == "rawtext":
                # Treat "*" as an optional argument.
                if arg[1] == "*":
                    if group != []:
                        self.argGroups.append(group)
                        group = []
                    numFormals = 0
                group.append(arg)
            elif kind == "quoted":
                group.append(arg)
            elif kind == "argument":
                group.append(arg)
                numFormals += 1
                if numFormals == 9:
                    # A macro definition can take no more than nine
                    # parameters, so close the group here.
                    self.argGroups.append(group)
                    group = []
                    numFormals = 0
            elif kind == "optarg":
                # Note that we know from checkArgList() that there are
                # no more than nine formals specified within the
                # optional argument.
                if group != []:
                    self.argGroups.append(group)
                    group = []
                numFormals = sum(1 for oarg in arg[2:] if oarg[0][0] == "#")
                group.append(list(arg))
        if group != []:
            self.argGroups.append(group)

    def argsToString(self, argList, mode, argSubtract=0):
        '''
            Produce a string version of a list of arguments.
            mode is one of "define", "call", or "calldefault".
            argSubtract is subtracted from each argument number.

            "define" writes formals as #n (a parameter text), "call" writes
            them as {#n}, and "calldefault" writes an optional argument's
            default values in place of its formals.  Raises ParseError on
            an unknown mode or an unknown kind of argument.
        '''
        if mode not in ["define", "call", "calldefault"]:
            raise ParseError('Internal error (mode="%s")' % mode, argList[0][2])
        findArgRE = re.compile(r'#(\d+)')

        def savedArgName(match):
            # With more than nine formals each argument lives in its own
            # \NAME@arg@<roman> macro, so references inside default values
            # must use those names.
            return "\\" + self.topLevelName + "@arg@" + self.toRoman(int(match.group(1)))

        argStr = ""
        for arg in argList:
            if arg[0] == "argument":
                if mode == "define":
                    argStr += "#%d" % (int(arg[1][1:]) - argSubtract)
                else:
                    argStr += "{#%d}" % (int(arg[1][1:]) - argSubtract)
            elif arg[0] == "rawtext":
                argStr += arg[1]
            elif arg[0] == "quoted":
                # Drop the enclosing braces.
                argStr += arg[1][1:-1]
            elif arg[0] == "optarg":
                argStr += arg[1][0]
                for oarg in arg[2:]:
                    if oarg[0][0] == "#":
                        if mode == "define":
                            argStr += "#%d" % (int(oarg[0][1:]) - argSubtract)
                        elif mode == "call":
                            argStr += "{#%d}" % (int(oarg[0][1:]) - argSubtract)
                        elif self.numFormals > 9:
                            argStr += findArgRE.sub(savedArgName, oarg[1])
                        else:
                            argStr += oarg[1]
                    elif oarg[0] == "quoted":
                        argStr += oarg[1][1:-1]
                    elif oarg[0] in ("rawtext", "ident"):
                        # The grammar also allows a bare identifier between
                        # default values; it is literal text just like
                        # rawtext.
                        argStr += oarg[1]
                    else:
                        raise ParseError('Internal error ("%s")' % oarg[0],
                                         oarg[2])
                argStr += arg[1][1]
            else:
                raise ParseError('Internal error ("%s")' % arg[0], arg[2])
        return argStr

    def callMacro(self, macroNum):
        """Return an array of strings suitable for calling macro macroNum.

        The array is empty when there is no such macro.  The call tests for
        the opening delimiter or the star when the macro's first argument
        is optional.
        """
        if macroNum >= len(self.argGroups):
            # No more macros.
            return []
        macroName = "\\%s@%s" % (self.topLevelName, self.toRoman(macroNum))
        nextArg = self.argGroups[macroNum][0]
        callSeq = []
        if self.numFormals > 9:
            # More than nine formal parameters
            if nextArg[0] == "optarg":
                callSeq.append("  \\@ifnextchar%s{%s}{%s%s}%%" % \
                                   (nextArg[1][0], macroName, macroName,
                                    self.argsToString([nextArg], mode="calldefault")))
            elif nextArg[0] == "rawtext" and nextArg[1] == "*":
                callSeq.append("  \\@ifstar{\\%s@startrue%s*}{\\%s@starfalse%s*}%%" % \
                                   (self.topLevelName, macroName,
                                    self.topLevelName, macroName))
            else:
                callSeq.append("  %s" % macroName)
        else:
            # Nine or fewer formal parameters: every argument collected so
            # far is passed along to the next macro.
            argStr = ""
            for g in range(0, macroNum):
                argStr += self.argsToString(self.argGroups[g], mode="call")
            if nextArg[0] == "optarg":
                callSeq.append("  \\@ifnextchar%s{%s%s}{%s%s%s}%%" % \
                                   (nextArg[1][0],
                                    macroName, argStr, macroName, argStr,
                                    self.argsToString([nextArg], mode="calldefault")))
            elif nextArg[0] == "rawtext" and nextArg[1] == "*":
                callSeq.append("  \\@ifstar{\\%s@startrue%s%s*}{\\%s@starfalse%s%s*}%%" % \
                                   (self.topLevelName, macroName, argStr,
                                    self.topLevelName, macroName, argStr))
            else:
                callSeq.append("  %s%s%%" % (macroName, argStr))
        return callSeq

    def putCodeHere(self):
        'Return an array of strings representing "Put code here".'
        code = []
        if self.haveStar:
            code.extend(["  \\if%s@star" % self.topLevelName,
                         '    % Put code for the "*" case here.',
                         "  \\else",
                         '    % Put code for the non-"*" case here.',
                         "  \\fi",
                         "  %% Put code common to both cases here (and/or above the \\if%s@star)." % self.topLevelName])
        else:
            code.append("  % Put your code here.")
        if self.numFormals == 0:
            return code
        if self.numFormals > 9:
            firstArgName = "\\%s@arg@i" % self.topLevelName
            lastArgName = "\\%s@arg@%s" % (self.topLevelName, self.toRoman(self.numFormals))
        else:
            firstArgName = "#1"
            lastArgName = "#%d" % self.numFormals
        if self.numFormals == 1:
            code.append("  %% You can refer to the argument as %s." % firstArgName)
        elif self.numFormals == 2:
            code.append("  %% You can refer to the arguments as %s and %s." % (firstArgName, lastArgName))
        else:
            code.append("  %% You can refer to the arguments as %s through %s." % (firstArgName, lastArgName))
        return code

    def produceTopLevel(self):
        "Generate the code for the top-level macro definition."
        # Generate the macro definition.
        defStr = "\\newcommand{\\%s}" % self.topLevelName
        argList = self.argGroups[0]
        if argList != []:
            defStr += "[%d]" % len(argList)
            firstArg = argList[0]
            if firstArg[0] == "optarg":
                defVal = firstArg[2][1][1:-1]
                if "]" in defVal:
                    # Brace the default so that its "]" is not mistaken for
                    # the end of \newcommand's optional argument.
                    defVal = "{%s}" % defVal
                defStr += "[%s]" % defVal
        defStr += "{%"
        self.codeList.append(defStr)

        # Generate the macro body.
        if len(self.argGroups) == 1:
            # Single macro definition.
            self.codeList.extend(self.putCodeHere())
        else:
            # More macros are forthcoming.
            if self.numFormals > 9:
                # More than nine formal parameters: stash each argument in
                # its own macro.
                for f in range(1, len(argList)+1):
                    self.codeList.append("  \\def\\%s@arg@%s{#%d}%%" % (self.topLevelName, self.toRoman(f), f))
            self.codeList.extend(self.callMacro(1))
        self.codeList.append("}")

    def produceRemainingMacros(self):
        "Generate code for all macros except the first."
        formalsSoFar = len(self.argGroups[0])
        lastGroup = len(self.argGroups) - 1
        for groupNum in range(1, len(self.argGroups)):
            # Generate the macro header.
            self.codeList.append("")
            argList = self.argGroups[groupNum]
            defStr = "\\def\\%s@%s" % (self.topLevelName, self.toRoman(groupNum))
            if self.numFormals > 9:
                # Only this group's arguments, renumbered to start at #1.
                defStr += self.argsToString(argList, mode="define", argSubtract=formalsSoFar)
            else:
                # Every argument seen so far plus this group's.
                for g in range(0, groupNum+1):
                    defStr += self.argsToString(self.argGroups[g], mode="define")
            self.codeList.append(defStr + "{%")

            # Generate the macro body.
            if self.numFormals > 9:
                # More than nine formal parameters: save this group's
                # arguments under their global numbers.
                for arg in argList:
                    if arg[0] == "argument":
                        formalNums = [int(arg[1][1:])]
                    elif arg[0] == "optarg":
                        formalNums = [int(oarg[0][1:])
                                      for oarg in arg[2:] if oarg[0][0] == "#"]
                    else:
                        formalNums = []
                    for formalNum in formalNums:
                        self.codeList.append("  \\def\\%s@arg@%s{#%d}%%" % \
                                                 (self.topLevelName,
                                                  self.toRoman(formalNum),
                                                  formalNum - formalsSoFar))
            if groupNum == lastGroup:
                self.codeList.extend(self.putCodeHere())
            else:
                self.codeList.extend(self.callMacro(groupNum + 1))

            # Generate the macro trailer.
            self.codeList.append("}")

            # Increment the count of formals seen so far.
            for arg in argList:
                if arg[0] == "argument":
                    formalsSoFar += 1
                elif arg[0] == "optarg":
                    formalsSoFar += sum(1 for o in arg[2:] if o[0][0] == "#")

    def generate(self, argList):
        """Generate LaTeX code from an argument list.

        Prints the generated lines to standard output; they also remain
        available in self.codeList.
        """
        # Start from a clean slate so that an instance can be reused.
        self.numFormals = 0
        self.codeList = []

        # Group arguments and identify characteristics that affect the output.
        self.argList = argList
        self.partitionArgList()
        self.haveAt = len(self.argGroups) > 1
        self.haveStar = any(arg[0] == "rawtext" and arg[1] == "*"
                            for arg in self.argList)
        self.topLevelName = self.argList[0][1]
        for arg in self.argList:
            if arg[0] == "argument":
                self.numFormals += 1
            elif arg[0] == "optarg":
                self.numFormals += sum(1 for oarg in arg[2:] if oarg[0][0] == "#")

        # Output LaTeX code.
        if self.haveAt:
            self.codeList.append("\\makeatletter")
        if self.haveStar:
            self.codeList.append("\\newif\\if%s@star" % self.topLevelName)
        self.produceTopLevel()
        self.produceRemainingMacros()
        if self.haveAt:
            self.codeList.append("\\makeatother")
        for codeLine in self.codeList:
            print(codeLine)


# The buck starts here.
if __name__ == '__main__':
    # Script entry point.  With no command-line arguments, prototypes are
    # read from standard input one line at a time, each preceded by a
    # prompt; otherwise the command-line arguments are joined into a single
    # prototype and processed once.
    #
    # These two modules are imported only when the file is run as a script.
    import sys
    import string

    def processLine():
        """Parse the current value of oneLine and print the generated code.

        Reads the globals oneLine, prompt and isStdin.  Blank lines and
        lines starting with "%" are ignored.  A ParseError is reported on
        standard error as a caret line followed by the message.
        """
        global oneLine
        try:
            # The prompt is printed with a trailing comma, which in Python 2
            # leaves the stream's softspace flag set so that the next print
            # would begin with a space; clear the flag first.
            sys.stdout.softspace = 0        # Cancel the %$#@! space.
            oneLine = string.strip(oneLine)
            if oneLine=="" or oneLine[0]=="%":
                # Nothing to do.  Note that returning here also skips the
                # trailing blank line printed below in stdin mode.
                return
            if not isStdin:
                # Echo the prototype after the prompt.
                print prompt, oneLine
            # Pipeline: lex -> parse -> flatten -> validate -> generate.
            scanner = CmdScanner()
            parser = CmdParser()
            tokens = scanner.tokenize(oneLine)
            ast = parser.parse(tokens)
            argList = flattenAST(ast)
            checkArgList(argList)
            gen = LaTeXgenerator()
            gen.generate(argList)
        except ParseError,(message, pos):
            # The exception's two arguments are unpacked as (message, pos).
            # The caret is indented by the prompt width plus pos so that it
            # lines up beneath the prototype text that follows the prompt.
            sys.stderr.write((" "*(len(prompt)+pos)) + "^\n")
            sys.stderr.write("%s: %s.\n" % (sys.argv[0], message))
        if isStdin:
            # Separate successive results with a blank line.
            print ""

    # NOTE(review): the raised limit is presumably there to accommodate deep
    # recursion while parsing or flattening long prototypes -- confirm.
    sys.setrecursionlimit(5000)
    prompt = "% Prototype:"
    if len(sys.argv) <= 1:
        # Interactive/piped mode: prompt, read a line, process, repeat until
        # end of file.
        isStdin = 1
        print prompt + " ",
        while 1:
            oneLine = sys.stdin.readline()
            if not oneLine:
                break
            processLine()
            print prompt + " ",
    else:
        # One-shot mode: the prototype is given on the command line.
        isStdin = 0
        oneLine = string.join(sys.argv[1:])
        processLine()