#!/usr/bin/perl
#
#########################################################################
#                                                                       #
# Open Dynamics Engine, Copyright (C) 2001,2002 Russell L. Smith.       #
# All rights reserved.  Email: russ@q12.org   Web: www.q12.org          #
#                                                                       #
# This library is free software; you can redistribute it and/or         #
# modify it under the terms of EITHER:                                  #
#   (1) The GNU Lesser General Public License as published by the Free  #
#       Software Foundation; either version 2.1 of the License, or (at  #
#       your option) any later version. The text of the GNU Lesser      #
#       General Public License is included with this library in the     #
#       file LICENSE.TXT.                                               #
#   (2) The BSD-style license that is included with this library in     #
#       the file LICENSE-BSD.TXT.                                       #
#                                                                       #
# This library is distributed in the hope that it will be useful,       #
# but WITHOUT ANY WARRANTY; without even the implied warranty of        #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the files    #
# LICENSE.TXT and LICENSE-BSD.TXT for more details.                     #
#                                                                       #
#########################################################################

# multi-dot-product code generator. this code generator is based on the
# dot-product code generator.
#
# code generation parameters, set in a parameters file:
#    FNAME    : name of source file to generate - a .c file will be made
#    N1       : block size (number of `a' vectors to dot)
#    UNROLL1  : inner loop unrolling factor (1..)
#    FETCH    : max num of a[i]'s and b[i]'s to load ahead of muls
#    LAT1     : load -> mul latency (>=1)
#    LAT2     : mul -> add latency (>=1). if this is 1, use fused mul-add
#
#############################################################################

# perl 5.26 removed `.' from the default @INC, which both `require' and
# `do' use to resolve relative file names. put it back (at the end, where
# older perls had it) so that the helper file and a relative parameters
# file name are still found when running from the build directory.
push (@INC, '.') unless grep { $_ eq '.' } @INC;

# helper routines. error() and output() are not defined in this file, so
# they presumably come from here.
require ("BuildUtil");

# get and check code generation parameters
error ("Usage: BuildMultidot <parameters-file>") if $#ARGV != 0;

# the parameters file is perl code that is expected to set the globals
# tested below. `do' returns undef and sets $@ if the file was read but
# failed to compile or died - report that instead of the misleading
# "parameters not defined" message.
my $loaded = do $ARGV[0];
if (!defined($loaded) && $@) {
  error ("can't load parameters file $ARGV[0]: $@");
}

if (!defined($FNAME) || !defined($N1) || !defined($UNROLL1) ||
    !defined($FETCH) || !defined($LAT1) || !defined($LAT2)) {
  error ("code generation parameters not defined");
}

# check parameters. N1 and UNROLL1 are pasted into the generated C text
# (function name, variable names, pointer increments) and used as loop
# bounds, so they must be plain integers: a fractional UNROLL1 would make
# the unrolled body consume more elements than the C loop counter accounts
# for. FETCH, LAT1 and LAT2 are only ever compared against, so a lower
# bound is enough for them.
error ("bad N1") if $N1 !~ /\A\d+\z/ || $N1 < 2;
error ("bad UNROLL1") if $UNROLL1 !~ /\A\d+\z/ || $UNROLL1 < 1;
error ("bad FETCH") if $FETCH < 1;
error ("bad LAT1") if $LAT1 < 1;
error ("bad LAT2") if $LAT2 < 1;

#############################################################################

# create the generated source file. the three-argument form keeps any
# whitespace or mode characters in FNAME from being interpreted by open(),
# and $! says why the open failed. the bareword handle FOUT is kept because
# it is never written to in this file - presumably output() uses it.
open (FOUT, '>', "$FNAME.c") or die "can't open $FNAME.c for writing: $!";

# fixed preamble of the generated C file
output (<<PREAMBLE);
/* generated code, do not edit. */

#include "ode/matrix.h"


PREAMBLE

# function signature: one `a' pointer per vector to dot, followed by the
# shared `b' vector, the result array and the element count
output ("void dMultidot$N1 (");
for (my $vec = 0; $vec < $N1; ++$vec) {
  output ("const dReal *a$vec, ");
}
output ("const dReal *b, dReal *outsum, int n)\n{\n");

# one declaration statement for all locals. for every unrolled step there
# is a p<step><vec> per vector (loaded a value), an m<step><vec> per vector
# (product, only when multiply and add are emitted separately, i.e.
# LAT2 > 1) and a q<step> (loaded b value); the sum<vec> accumulators come
# last.
output ("dReal ");
for (my $step = 0; $step < $UNROLL1; ++$step) {
  for (my $vec = 0; $vec < $N1; ++$vec) {
    output ("p$step$vec,");
    if ($LAT2 > 1) {
      output ("m$step$vec,");
    }
  }
  output ("q$step,");
}
for (my $vec = 0; $vec < $N1; ++$vec) {
  output ("sum$vec");
  # comma after every accumulator except the last one
  unless ($vec >= ($N1-1)) {
    output (",");
  }
}
output (";\n");

# clear the accumulators
for (my $vec = 0; $vec < $N1; ++$vec) {
  output ("sum$vec = 0;\n");
}

# open the unrolled main loop. n is biased by -UNROLL1 so the loop only
# runs while a full group of UNROLL1 elements is left.
output (<<LOOPHEAD);
n -= $UNROLL1;
while (n >= 0) {
LOOPHEAD

# emit the body of the unrolled loop. the statements for the UNROLL1 steps
# are interleaved by a small scheduler: every pass of the loop below may
# issue one load group, one multiply group and one add group, each taking
# one `slot', subject to the FETCH / LAT1 / LAT2 constraints.
@load = ();		# slot where a[i]'s and b[i]'s loaded
@mul = ();		# slot where multiply i happened
@add = ();		# slot where add i happened

# the loads for step 0 are always emitted first and recorded as slot 0
for (my $vec = 0; $vec < $N1; ++$vec) {
  output ("p0$vec = a$vec [0];\n");
}
output ("q0 = b[0];\n");
push (@load,0);

$slot=0;		# one slot for every load/mul/add/nop issued
SCHEDULE: while (1) {
  $startslot = $slot;

  # do next load: only while fewer than FETCH loaded steps are waiting for
  # their multiply, and while there are steps left to load
  my $waiting = scalar(@load) - scalar(@mul);
  if ($waiting < $FETCH && scalar(@load) < $UNROLL1) {
    my $step = scalar(@load);	# index of the step being loaded
    push (@load,$slot);
    for (my $vec = 0; $vec < $N1; ++$vec) {
      output ("p$step$vec = a$vec [$step];\n");
    }
    output ("q$step = b[$step];\n");
    $slot++;
  }

  # do next multiply: its operands must have been loaded at least LAT1
  # slots ago
  my $mstep = scalar(@mul);	# index of the next step to multiply
  if (scalar(@load) > $mstep && $slot >= ($load[$mstep] + $LAT1) &&
      $mstep < $UNROLL1) {
    push (@mul,$slot);
    if ($LAT2 > 1) {
      # separate multiply; the add is scheduled further down
      for (my $vec = 0; $vec < $N1; ++$vec) {
        output ("m$mstep$vec = p$mstep$vec * q$mstep;\n");
      }
    }
    else {
      # LAT2 == 1: multiply and accumulate in a single statement, so the
      # body is complete once the last multiply has been emitted
      for (my $vec = 0; $vec < $N1; ++$vec) {
        output ("sum$vec += p$mstep$vec * q$mstep;\n");
      }
      last SCHEDULE if scalar(@mul) >= $UNROLL1;
    }
    $slot++;
  }

  # do next add: the product must be at least LAT2 slots old. the body is
  # complete once the last add has been emitted.
  if ($LAT2 > 1) {
    my $astep = scalar(@add);	# index of the next step to accumulate
    if (scalar(@mul) > $astep && $slot >= ($mul[$astep] + $LAT2)) {
      push (@add,$slot);
      for (my $vec = 0; $vec < $N1; ++$vec) {
        output ("sum$vec += m$astep$vec;\n");
      }
      $slot++;
      last SCHEDULE if scalar(@add) >= $UNROLL1;
    }
  }

  # nothing could be issued in this pass: let one empty slot go by so the
  # latency conditions above can eventually be met
  if ($slot == $startslot) {
    # comment ("nop");
    $slot++;
  }
}

# end of the unrolled loop body: step every input pointer and the counter
# past the UNROLL1 elements just consumed, then close the while loop that
# was opened before the scheduled statements.
for ($j=0; $j<$N1; $j++) {
  output ("a$j += $UNROLL1;\n");
}
output ("b += $UNROLL1;\n");
output ("n -= $UNROLL1;\n");
output ("}\n");

# clean-up loop: undo the -UNROLL1 bias on n and handle the remaining
# elements one at a time.
output ("n += $UNROLL1;\n");
output ("while (n > 0) {\n");
output ("q0 = *b;\n");
for ($j=0; $j<$N1; $j++) {
  output ("sum$j += (*a$j) * q0;\n");
  output ("a$j++;\n");
}
output ("b++;\n");
output ("n--;\n");
output ("}\n");

# store the results and close the generated function
for ($j=0; $j<$N1; $j++) {
  output ("outsum[$j] = sum$j;\n");
}
output ("}\n");

# close the output file explicitly and check the result: errors from
# buffered writes (e.g. a full disk) only show up here, and without the
# check a truncated .c file would be left behind with a success exit
# status.
close (FOUT) or die "error writing $FNAME.c: $!";
