#!/usr/bin/env python3

from sys import stderr, stdout

def monkey_patch():
    '''Install a placeholder ``__imatmul__`` on ``svgelements.SVGElement``

    The installed method only prints a notice and hands the element
    back unchanged, so ``element @= other`` leaves the element as is.
    '''
    import svgelements

    def _noop_imatmul(element, operand):
        # Announce the call, then return the left operand untouched
        print('Dummy SVGElement.__imatmul__')
        return element

    svgelements.SVGElement.__imatmul__ = _noop_imatmul

# Apply the patch once, at import time
monkey_patch()

# --------------------------------------------------------------------
def open_svg(input):
    '''Parse an SVG document, leaving transformations un-reified

    Parameters
    ----------
    input : file-like or str
        Source handed unchanged to ``svgelements.SVG.parse``

    Returns
    -------
    svg : svgelements.SVG
        The parsed document (parsed with ``reify=False``)
    '''
    from svgelements import SVG

    document = SVG.parse(input, reify=False)
    # Diagnostic only: report the type of the first top-level element
    first = document[0]
    print(type(first), file=stderr)
    return document
    
# --------------------------------------------------------------------
def get_bb(svg):
    '''Return the bounding box of an object as a 4-tuple

    Parameters
    ----------
    svg : object
        Anything with a ``bbox()`` method yielding exactly four values

    Returns
    -------
    bb : tuple
        The four values from ``svg.bbox()`` in their original order
        (callers in this file name them ``llx, lly, urx, ury``)
    '''
    x_low, y_low, x_high, y_high = svg.bbox()
    return (x_low, y_low, x_high, y_high)

# --------------------------------------------------------------------
def flush_left(svg):
    '''Translate the drawing by minus its first bounding-box corner

    Parameters
    ----------
    svg : svgelements object
        Drawing to move; modified via ``*=``

    Returns
    -------
    svg : svgelements object
        The translated drawing
    '''
    from svgelements import Matrix

    corner_x, corner_y, _, _ = get_bb(svg)
    shift = Matrix.translate(-corner_x, -corner_y)
    svg *= shift
    return svg

# --------------------------------------------------------------------
def scale_unit(svg):
    '''Scale the drawing by the reciprocal of its largest far coordinate

    The factor is ``1 / max(urx, ury)`` of the current bounding box, so
    the larger of the two upper bounding-box values becomes 1.

    Parameters
    ----------
    svg : svgelements object
        Drawing to scale; modified via ``*=``

    Returns
    -------
    svg : svgelements object
        The scaled drawing
    '''
    from svgelements import Matrix

    _, _, far_x, far_y = get_bb(svg)
    factor = 1 / max(far_x, far_y)
    svg *= Matrix.scale(factor)
    return svg

# --------------------------------------------------------------------
def scale_to_cm(svg,px_per_cm = 37.79531157367762334497):
    '''Scale the drawing by ``1 / px_per_cm``

    Parameters
    ----------
    svg : svgelements object
        Drawing to scale; modified via ``*=``
    px_per_cm : float
        Number of drawing units per centimetre (the default is
        presumably derived from a 96 px/inch convention - confirm)

    Returns
    -------
    svg : svgelements object
        The scaled drawing
    '''
    from svgelements import Matrix

    factor = 1 / px_per_cm
    shrink = Matrix.scale(factor)
    svg *= shrink
    return svg

# --------------------------------------------------------------------
def shift_to_centre(svg):
    '''Centre the drawing on the origin and mirror its Y axis

    The drawing is translated by minus the midpoint of its bounding
    box and then scaled by ``(1,-1)``.  Using the true midpoint
    ``((llx+urx)/2, (lly+ury)/2)`` keeps the result centred even when
    the lower-left corner is not already at the origin; when it is at
    the origin this equals the former ``(urx/2, ury/2)``.

    Parameters
    ----------
    svg : svgelements object
        Drawing to centre; modified via ``*=``

    Returns
    -------
    svg : svgelements object
        The centred, Y-mirrored drawing after ``reify()`` and
        ``render()`` have been called on it
    '''
    from svgelements import Matrix

    llx,lly,urx,ury = get_bb(svg)
    centre_x        = (llx + urx) / 2
    centre_y        = (lly + ury) / 2
    svg             *= Matrix.translate(-centre_x,-centre_y)
    # Mirror about the X axis (negate Y coordinates)
    svg             *= Matrix.scale(1,-1)
    # NOTE(review): reify()/render() presumably bake the accumulated
    # transformation into the geometry - confirm against svgelements
    svg.reify()
    svg.render()
    return svg

# --------------------------------------------------------------------
def flipy(svg):
    '''Mirror the Y axis, then move the lower Y bound back to zero

    Steps: translate by minus the first bounding-box corner, scale by
    ``(1,-1)``, call ``reify()`` and ``render()``, re-read the bounding
    box and translate vertically by minus its second value.  The last
    translation is applied after the ``reify()`` call.

    Parameters
    ----------
    svg : svgelements object
        Drawing to flip; modified via ``*=``

    Returns
    -------
    svg : svgelements object
        The flipped drawing
    '''
    from svgelements import Matrix

    corner_x, corner_y, _, _ = get_bb(svg)
    svg *= Matrix.translate(-corner_x, -corner_y)
    svg *= Matrix.scale(1, -1)
    svg.reify()
    svg.render()

    # The mirror changed the vertical extent; measure it again
    _, mirrored_low_y, _, _ = get_bb(svg)
    svg *= Matrix.translate(0, -mirrored_low_y)
    return svg

# ====================================================================
def get_paths(svg):
    '''Collect the paths of an SVG image or group

    Each ``svgelements.Path`` found is re-parsed from its ``d()``
    string with ``svg.path.parse_path``.  Each ``svgelements.Group``
    is processed recursively and contributes ONE entry: the list
    returned for that group.  Other elements are ignored.

    Parameters
    ----------
    svg : iterable
        SVG document or group to scan

    Returns
    -------
    paths : list
        ``svg.path.Path`` objects and, for groups, nested lists of
        the same shape
    '''
    from svgelements import Path, Group
    from svg.path import parse_path

    collected = []

    for element in svg:
        if isinstance(element, Path):
            print('Parse path',file=stderr)
            collected.append(parse_path(element.d()))
        if isinstance(element, Group):
            print('Parse group',file=stderr)
            nested = get_paths(element)
            collected.append(nested)

    return collected

# ====================================================================
def tikz_point(e,f):
    '''Format a point as a TikZ coordinate

    Parameters
    ----------
    e : complex
        The point; real part is X, imaginary part is Y
    f : str
        Floating point format specification applied to both parts

    Return
    ------
    coord : str
        The point as ``(x,y)``
    '''
    return '({0:{2}},{1:{2}})'.format(e.real, e.imag, f)
    
# --------------------------------------------------------------------
def tikz_move(e,f):
    '''Format a move-to operation

    Parameters
    ----------
    e : segment
        Move segment; its ``start`` attribute is the point moved to
    f : str
        Floating point format

    Return
    ------
    coord : str
        The start point formatted for TikZ
    '''
    target = e.start
    return tikz_point(target, f)

# --------------------------------------------------------------------
def tikz_line(e,f):
    '''Format a line-to operation

    Parameters
    ----------
    e : segment
        Line segment; its ``end`` attribute is the point drawn to
    f : str
        Floating point format

    Return
    ------
    line : str
        ``-- `` followed by the formatted end point
    '''
    target = tikz_point(e.end, f)
    return f'-- {target}'

# --------------------------------------------------------------------
def tikz_curve(e,f):
    '''Format a curve-to operation

    Parameters
    ----------
    e : segment
        Curve segment with ``control1``, ``control2`` and ``end``
        attributes (two control points and the end point)
    f : str
        Floating point format

    Return
    ------
    curve : str
        ``.. controls (c1) and (c2) ..(end)``
    '''
    first  = tikz_point(e.control1, f)
    second = tikz_point(e.control2, f)
    target = tikz_point(e.end, f)
    return f'.. controls {first} and {second} ..{target}'

# --------------------------------------------------------------------
def tikz_close(e,f):
    '''Format a close (cycle) operation

    Parameters
    ----------
    e : segment
        Unused; accepted for a signature parallel to the other
        ``tikz_*`` formatters
    f : str
        Unused

    Return
    ------
    close : str
        Always ``-- cycle``
    '''
    cycle = '-- cycle'
    return cycle

# ====================================================================
def path_to_tikz(path,f='8.5f'):
    '''Convert a svg.path.Path (or a list of them) to TikZ operations

    Parameters
    ----------
    path : svg.path.Path or list
        Path to convert.  May also be a (possibly nested) list or
        tuple of paths, which is how groups are represented by
        ``get_paths``; every level is converted recursively.
    f : str
        floating point format to use

    Returns
    -------
    code : list
        List of TikZ path operations as strings.  The list starts
        with a ``% `` comment marker, and one more marker is emitted
        at the start of every nested path or list.
    '''
    from svg.path import Line, Move, Close, CubicBezier, Path

    code = ['% ']
    for e in path:
        if   isinstance(e, Move):        code.append(tikz_move(e, f))
        elif isinstance(e, Line):        code.append(tikz_line(e, f))
        elif isinstance(e, Close):       code.append(tikz_close(e,f))
        elif isinstance(e, CubicBezier): code.append(tikz_curve(e,f))
        # Nested groups arrive as plain lists inside lists; recurse into
        # them just like into a Path so their contents are not dropped
        elif isinstance(e, (Path, list, tuple)):
            code.extend(path_to_tikz(e,f))
        else:
            # Any other segment type is reported and skipped
            print(f'Unknown type: {type(e)}',file=stderr)

    return code

# ====================================================================
def doc_header(output=stdout):
    '''Print the standalone LaTeX document header

    Parameters
    ----------
    output : file-like
        Stream to write to
    '''
    header_lines = (
        '%',
        r'\documentclass{standalone}',
        r'\usepackage{tikz}',
        r'\begin{document}%',
    )
    print('\n'.join(header_lines), file=output)

# --------------------------------------------------------------------
def doc_footer(output=stdout):
    '''Print the standalone LaTeX document footer

    Parameters
    ----------
    output : file-like
        Stream to write to
    '''
    footer_lines = ('%', r'\end{document}')
    print('\n'.join(footer_lines), file=output)

# --------------------------------------------------------------------
def path_to_pic(path,name,output,format):
    '''Write a TikZ pic definition built from a path

    The definition has the form
    ``\\tikzset{ NAME/.pic={ \\path[pic actions] ... ; } }`` with one
    path operation per line.

    Parameters
    ----------
    path : Path
        The path to convert (passed on to ``path_to_tikz``)
    name : str
        The name of the pic
    output : File
        Output file
    format : str
        Floating point format
    '''
    indent = '\n    '

    # Opening lines are written before the path is converted
    print(r'\tikzset{%', file=output)
    print(f'  {name}/.pic={{%', file=output)
    print(r'    \path[pic actions]', file=output, end=indent)

    operations = path_to_tikz(path, format)
    print(indent.join(operations), file=output)

    # Terminate the \path, the pic body and the \tikzset in turn
    for closing in ('    ;', '  }', '}%'):
        print(closing, file=output)

# ====================================================================

# ====================================================================
def doit(input='soldier_side.svg',
         centre=True,
         unit=True,
         output=stdout,
         standalone=False,
         format='8.5f'):
    '''Convert an SVG image to TikZ pic definitions

    The image is read, repositioned and rescaled according to the
    options, and each top-level entry returned by ``get_paths`` is
    written as a pic named ``pic_000``, ``pic_001``, ...

    Parameters
    ----------
    input : file-like or str
        SVG source, passed to ``open_svg``
    centre : bool
        If true use ``shift_to_centre``, otherwise ``flipy``
    unit : bool
        If true use ``scale_unit``, otherwise ``scale_to_cm``
    output : file-like
        Stream the TikZ code is written to
    standalone : bool
        If true wrap the pics in a complete document that also draws
        every pic on a grid, six per row, two units apart
    format : str
        Floating point format for coordinates
    '''
    svg = open_svg(input)
    before = get_bb(svg)

    # Move to the origin first when either normalisation is requested
    if unit or centre:
        svg = flush_left(svg)
    rescale = scale_unit if unit else scale_to_cm
    svg = rescale(svg)
    reposition = shift_to_centre if centre else flipy
    svg = reposition(svg)
    after = get_bb(svg)

    paths = get_paths(svg)

    if standalone:
        doc_header(output)

    old_x0, old_y0, old_x1, old_y1 = before
    new_x0, new_y0, new_x1, new_y1 = after
    print(f'% Original bounding box: ({old_x0},{old_y0}) ({old_x1},{old_y1})',
          file=output)
    print(f'% New bounding box:      ({new_x0},{new_y0}) ({new_x1},{new_y1})',
          file=output)

    pics = []
    for index, path in enumerate(paths):
        pic_name = f'pic_{index:03d}'
        path_to_pic(path, pic_name, output, format)
        pics.append(pic_name)

    if not standalone:
        return

    # Lay the pics out left to right, starting a new row after six
    per_row, pitch = 6, 2
    print(r'\begin{tikzpicture}%', file=output)
    for index, pic_name in enumerate(pics):
        row, column = divmod(index, per_row)
        x = column * pitch
        y = row * pitch
        print(fr'  \pic[draw,fill=red] at ({x},{y}) {{{pic_name}}};%',
              file=output)
    print(r'\end{tikzpicture}%', file=output)
    doc_footer(output)
    
# ====================================================================
def test(svgfile='soldier_side.svg',tikzfile=stdout,standalone=True):
    '''Convert one SVG file with the default options

    Parameters
    ----------
    svgfile : str
        Name of the SVG file to read
    tikzfile : str or file-like
        Name of the file to write, or an already open stream.  A
        stream passed in by the caller (including the default
        standard output) is written to but NOT closed.
    standalone : bool
        Passed on to ``doit``
    '''
    with open(svgfile,'r') as source:
        if isinstance(tikzfile,str):
            # We opened it, so we close it
            with open(tikzfile,'w') as output:
                doit(input=source,output=output,standalone=standalone)
        else:
            # Caller-owned stream: using it as a context manager would
            # close it (e.g. sys.stdout) on exit
            doit(input=source,output=tikzfile,standalone=standalone)
    

# ====================================================================
if __name__ == '__main__':
    # Command line entry point: parse options, pick an output file and
    # run the conversion
    from argparse import ArgumentParser, FileType

    ap = ArgumentParser(description='Convert SVG to unit Tikz picture')
    ap.add_argument('input',type=FileType('r'),nargs='?',default='soldier_side.svg',
                    help='Input SVG file')
    ap.add_argument('-o','--output',type=FileType('w'),default=None,
                    help='Output Tikz file')
    ap.add_argument('-s','--standalone',action='store_true',default=True,
                    help='Write standalone document')
    ap.add_argument('-p','--picture',action='store_false',dest='standalone',
                    help='Only write picture')
    ap.add_argument('-f','--format',type=str,default='8.5f',
                    help='Format for coordinates')
    ap.add_argument('-O','--original-size',dest='unit',action='store_false',
                    help='Do no scale paths to unit size')
    ap.add_argument('-S','--unit-size',dest='unit',action='store_true',
                    help='Scale paths to unit size')
    ap.add_argument('-Z','--no-centre',dest='centre',action='store_false',
                    help='Do not centre paths')
    ap.add_argument('-C','--centre',dest='centre',action='store_true',
                    help='Centre paths')

    args = ap.parse_args()

    if args.output is None:
        # No output given: write next to the input, with a .tex suffix
        from pathlib import Path
        inname = Path(args.input.name)
        outname = inname.with_suffix('.tex')
        args.output  = open(outname,'w')

    # Close both files deterministically once the conversion is done
    # (also when it fails) instead of leaving them to interpreter exit
    with args.input, args.output:
        doit(input      = args.input,
             unit       = args.unit,
             centre     = args.centre,
             output     = args.output,
             standalone = args.standalone,
             format     = args.format)