<?php $include_dir="/home/hyper-archives/boost-commit/include"; include("$include_dir/msg-header.inc") ?>
Subject: [Boost-commit] svn:boost r56820 - in sandbox/numeric_bindings/libs/numeric/bindings/tools: . templates
From: rutger_at_[hidden]
Date: 2009-10-14 08:42:24
Author: rutger
Date: 2009-10-14 08:42:23 EDT (Wed, 14 Oct 2009)
New Revision: 56820
URL: http://svn.boost.org/trac/boost/changeset/56820
Log:
improved cblas and cublas support
Text files modified: 
   sandbox/numeric_bindings/libs/numeric/bindings/tools/blas_generator.py  |    17 ++++++                                  
   sandbox/numeric_bindings/libs/numeric/bindings/tools/cblas.py           |    43 ++++++++++++-----                       
   sandbox/numeric_bindings/libs/numeric/bindings/tools/cublas.py          |    96 ++++++++++++++++++++++----------------- 
   sandbox/numeric_bindings/libs/numeric/bindings/tools/templates/blas.hpp |     2                                         
   4 files changed, 101 insertions(+), 57 deletions(-)
Modified: sandbox/numeric_bindings/libs/numeric/bindings/tools/blas_generator.py
==============================================================================
--- sandbox/numeric_bindings/libs/numeric/bindings/tools/blas_generator.py	(original)
+++ sandbox/numeric_bindings/libs/numeric/bindings/tools/blas_generator.py	2009-10-14 08:42:23 EDT (Wed, 14 Oct 2009)
@@ -71,25 +71,38 @@
       arg_list = []
       lapack_arg_list = []
       cblas_arg_list = []
+      cublas_arg_list = []
+
+      if info_map[ subroutine ][ "has_cblas_order_arg" ]:
+        cblas_arg_list += [ "CblasColMajor" ]
+
       for arg in info_map[ subroutine ][ 'arguments' ]:
         print "Subroutine ", subroutine, " arg ", arg
         arg_list += [ info_map[ subroutine ][ 'argument_map' ][ arg ][ 'code' ][ 'level_0' ] ]
         lapack_arg_list += [ info_map[ subroutine ][ 'argument_map' ][ arg ][ 'code' ][ 'call_blas_header' ] ]
         cblas_arg_list += [ info_map[ subroutine ][ 'argument_map' ][ arg ][ 'code' ][ 'call_cblas_header' ] ]
+
       sub_template = sub_template.replace( "$LEVEL0", ", ".join( arg_list ) )
       sub_template = sub_template.replace( "$CALL_BLAS_HEADER", ", ".join( lapack_arg_list ) )
       sub_template = sub_template.replace( "$CALL_CBLAS_HEADER", ", ".join( cblas_arg_list ) )
       sub_template = sub_template.replace( "$SUBROUTINE", subroutine )
+
+      # CBLAS stuff
       if 'cblas_routine' in info_map[ subroutine ]:
         cblas_routine = info_map[ subroutine ][ 'cblas_routine' ]
       else:
-        cblas_routine = '//TODO'
+        cblas_routine = '// TODO'
       sub_template = sub_template.replace( "$CBLAS_ROUTINE", cblas_routine )
 
+      # CUBLAS stuff
       if 'cublas_routine' in info_map[ subroutine ]:
         cublas_routine = info_map[ subroutine ][ 'cublas_routine' ]
+        for arg in info_map[ subroutine ][ 'arguments' ]:
+          cublas_arg_list += [ info_map[ subroutine ][ 'argument_map' ][ arg ][ 'code' ][ 'call_cublas_header' ] ]
       else:
-        cublas_routine = '//TODO'
+        cublas_routine = '// NOT FOUND'
+
+      sub_template = sub_template.replace( "$CALL_CUBLAS_HEADER", ", ".join( cublas_arg_list ) )
       sub_template = sub_template.replace( "$CUBLAS_ROUTINE", cublas_routine )
 
       sub_template = sub_template.replace( '$groupname', group_name.lower() )
Modified: sandbox/numeric_bindings/libs/numeric/bindings/tools/cblas.py
==============================================================================
--- sandbox/numeric_bindings/libs/numeric/bindings/tools/cblas.py	(original)
+++ sandbox/numeric_bindings/libs/numeric/bindings/tools/cblas.py	2009-10-14 08:42:23 EDT (Wed, 14 Oct 2009)
@@ -17,13 +17,15 @@
 import pprint
 
 def parse_file( filename, info_map, template_map ):
+
     pp = pprint.PrettyPrinter( indent = 2 )
     source = open( filename ).read() 
 
     for match in re.compile( '(void|float|double) +cblas_([^\(]+)\(([^\)]+)\)', re.M | re.S ).findall( source ):
+        print "----"
         return_type  = match[0]
         blas_routine = match[1].split("_sub")[0].upper().strip()
-        print "CBLAS routine: ", blas_routine
+        print "CBLAS routine:", match[1] , "   BLAS equivalent:", blas_routine
         arguments = {}
         for arg in match[2].replace('\n','').split( ',' ):
             arg = arg.strip()
@@ -41,24 +43,26 @@
             # read aliases, if they are there
             my_key = blas_routine[ 1: ].lower() + '.all.cblas_alias'
             alias_map = {}
-            print my_key
+            #print my_key
             if netlib.my_has_key( my_key, template_map ) != None:
-                print "Has key.."
+                #print "Has key.."
                 for line in template_map[ netlib.my_has_key( my_key, template_map ) ].splitlines():
-                    print "Line:", line
+                    #print "Line:", line
                     alias_map[ line.split( "," )[0] ] = line.split(",")[1]
 
-            print alias_map
-               
+            #print alias_map
+
             # Try to match and insert arguments
+            # argument_map is the data gathered through the Fortran interface
             for arg in info_map[ blas_routine ][ 'argument_map' ]:
                 cblas_arg = ''
                 if arg in arguments:
                     cblas_arg = arg
-                if arg in alias_map:
+                elif arg in alias_map:
                     if alias_map[ arg ] in arguments:
                         cblas_arg = alias_map[ arg ]
-                print "Looking for argument ", arg, " CBLAS equivalent: ", cblas_arg
+
+                print "Looking for BLAS argument ", arg, " CBLAS equivalent: ", cblas_arg
                 if cblas_arg in arguments:
                     print "Found matching argument, inserting call_cblas_header stuff"
                     call_cblas_header = info_map[ blas_routine ][ "argument_map" ][ arg ][ "code" ][ "call_blas_header" ]
@@ -69,13 +73,26 @@
                     call_cblas_header = call_cblas_header.replace( "complex_ptr", "void_ptr" );
 
                     print "Result:   ", call_cblas_header
-                    info_map[ blas_routine ][ "argument_map" ][ arg ][ "code" ][ "call_cblas_header" ] = call_cblas_header
+                    if arg == 'UPLO':
+                        info_map[ blas_routine ][ "argument_map" ][ arg ][ "code" ][ "call_cblas_header" ] = \
+                            "( uplo == 'U' ? CblasUpper : CblasLower )"
+                    elif arg == 'DIAG':
+                        info_map[ blas_routine ][ "argument_map" ][ arg ][ "code" ][ "call_cblas_header" ] = \
+                            "( uplo == 'N' ? CblasNonUnit : CblasUnit )"
+                    elif arg == 'SIDE':
+                        info_map[ blas_routine ][ "argument_map" ][ arg ][ "code" ][ "call_cblas_header" ] = \
+                            "( uplo == 'L' ? CblasLeft : CblasRight )"
+                    elif  arg == 'TRANS' or arg == 'TRANSA' or arg == 'TRANSB':
+                        info_map[ blas_routine ][ "argument_map" ][ arg ][ "code" ][ "call_cblas_header" ] = \
+                          "( " + arg.lower() + " == 'N' ? CblasNoTrans : ( " + arg.lower() + " == 'T' ? CblasTrans : CblasConjTrans ) )"
+                    else:
+                        info_map[ blas_routine ][ "argument_map" ][ arg ][ "code" ][ "call_cblas_header" ] = call_cblas_header
                 else:
                     exit(0)
 
-
-
-
-
+            if "ORDER" in arguments:
+                info_map[ blas_routine ][ "has_cblas_order_arg" ] = True
+            else:
+                info_map[ blas_routine ][ "has_cblas_order_arg" ] = False
 
 
Modified: sandbox/numeric_bindings/libs/numeric/bindings/tools/cublas.py
==============================================================================
--- sandbox/numeric_bindings/libs/numeric/bindings/tools/cublas.py	(original)
+++ sandbox/numeric_bindings/libs/numeric/bindings/tools/cublas.py	2009-10-14 08:42:23 EDT (Wed, 14 Oct 2009)
@@ -8,7 +8,7 @@
 # http://www.boost.org/LICENSE_1_0.txt)
 #
 
-import re, os.path, copy
+import re, os.path, copy, netlib
 from types import StringType
 
 # for debugging purposes
@@ -18,53 +18,67 @@
     pp = pprint.PrettyPrinter( indent = 2 )
     source = open( filename ).read() 
 
-    for match in re.compile( '(float|double|void) ?CUBLASAPI ?cublas([SDCZ][a-z0-9]+) ?\(([^\(]+)\(([^\)]+)\)', re.M | re.S ).findall( source ):
-        blas_routine = match[1].upper()
-
-        if blas_routine in info_map:
-            print "FOUND!"
-            info_map[ blas_routine ][ "cublas_routine" ] = 'cublas' + match[1]
-            pp.pprint( info_map[ blas_routine ] )
-       
-        #print blas_routine
-        #print match[0], match[1]
-
-        #print match
-
-        
-
-
-
-
-    #print info_map.keys
-
-
-    #print source
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-#parse_file( "./cublas.h", dict() )
-
-
-
+    for match in re.compile( '^(cuComplex|cuDoubleComplex|float|double|void) ?CUBLASAPI ?cublas([SDCZ][a-z0-9]+) ?\(([^\)]+)\)', re.M | re.S ).findall( source ):
+        print "----"
 
+        return_type  = match[0]
+        blas_routine = match[1].upper().strip()
+        print "CUBLAS routine:", match[1], "   BLAS equivalent:", blas_routine
+
+        arguments = {}
+        for arg in match[2].replace('\n','').split( ',' ):
+            arg = arg.strip()
+            arg_name = arg.split( " " )[-1].replace( "*", "" ).strip().upper()
+            arguments[ arg_name ] = {}
+            arguments[ arg_name ][ "original" ] = arg
+            arguments[ arg_name ][ "pointer" ] = "*" in arg
 
+        pp.pprint( arguments )
 
+        if blas_routine in info_map:
+            print "Found ", blas_routine, " in Fortran info_map."
+            info_map[ blas_routine ][ "cublas_routine" ] = 'cublas' + match[1]
+            #pp.pprint( info_map[ blas_routine ] )
 
+            # read aliases, if they are there
+            my_key = blas_routine[ 1: ].lower() + '.all.cblas_alias'
+            alias_map = {}
+            print my_key
+            if netlib.my_has_key( my_key, template_map ) != None:
+                #print "Has key.."
+                for line in template_map[ netlib.my_has_key( my_key, template_map ) ].splitlines():
+                    #print "Line:", line
+                    alias_map[ line.split( "," )[0] ] = line.split(",")[1]
+
+            for arg in info_map[ blas_routine ][ 'argument_map' ]:
+                cublas_arg = ''
+                if arg in arguments:
+                    cublas_arg = arg
+                elif 'S' + arg in arguments:
+                        cublas_arg = 'S' + arg
+                # E.g., BLAS DPARAM equals CUBLAS SPARAM
+                elif 'S' + arg[1:] in arguments and arg == 'DPARAM':
+                        cublas_arg = 'S' + arg[1:]
+                elif arg in alias_map:
+                    if alias_map[ arg ] in arguments:
+                        cublas_arg = alias_map[ arg ]
+
+                print "Looking for BLAS argument ", arg, " CUBLAS equivalent: ", cublas_arg
+
+                if cublas_arg in arguments:
+                    print "Found matching argument, inserting call_cublas_header stuff"
+                    call_cublas_header = info_map[ blas_routine ][ "argument_map" ][ arg ][ "code" ][ "call_blas_header" ]
+
+                    print "Original: ", call_cublas_header
+                    if not arguments[ cublas_arg ][ "pointer" ]:
+                        call_cublas_header = call_cublas_header.replace( "&", "" )
 
+                    call_cublas_header = call_cublas_header.replace( "complex_ptr", "void_ptr" );
 
+                    info_map[ blas_routine ][ "argument_map" ][ arg ][ "code" ][ "call_cublas_header" ] = call_cublas_header
 
+                else:
+                    exit(0)
 
 
 
Modified: sandbox/numeric_bindings/libs/numeric/bindings/tools/templates/blas.hpp
==============================================================================
--- sandbox/numeric_bindings/libs/numeric/bindings/tools/templates/blas.hpp	(original)
+++ sandbox/numeric_bindings/libs/numeric/bindings/tools/templates/blas.hpp	2009-10-14 08:42:23 EDT (Wed, 14 Oct 2009)
@@ -52,7 +52,7 @@
 #if defined BOOST_NUMERIC_BINDINGS_BLAS_CBLAS
     $RETURN_STATEMENT$CBLAS_ROUTINE( $CALL_CBLAS_HEADER );
 #elif defined BOOST_NUMERIC_BINDINGS_BLAS_CUBLAS
-    $RETURN_STATEMENT$CUBLAS_ROUTINE( ... ); // FIXME
+    $RETURN_STATEMENT$CUBLAS_ROUTINE( $CALL_CUBLAS_HEADER );
 #else
     $RETURN_STATEMENTBLAS_$SUBROUTINE( $CALL_BLAS_HEADER );
 #endif