<?php $include_dir="/home/hyper-archives/boost-commit/include"; include("$include_dir/msg-header.inc") ?>
Subject: [Boost-commit] svn:boost r73618 - trunk/libs/spirit/example/qi/compiler_tutorial/conjure3
From: joel_at_[hidden]
Date: 2011-08-08 23:35:07
Author: djowel
Date: 2011-08-08 23:35:05 EDT (Mon, 08 Aug 2011)
New Revision: 73618
URL: http://svn.boost.org/trac/boost/changeset/73618
Log:
refactoring: moving low-level llvm stuff into separate classes
Text files modified: 
   trunk/libs/spirit/example/qi/compiler_tutorial/conjure3/compiler.cpp |   166 +++++++++++++++++++++++++++------------ 
   trunk/libs/spirit/example/qi/compiler_tutorial/conjure3/compiler.hpp |    34 +++++++                                 
   2 files changed, 146 insertions(+), 54 deletions(-)
Modified: trunk/libs/spirit/example/qi/compiler_tutorial/conjure3/compiler.cpp
==============================================================================
--- trunk/libs/spirit/example/qi/compiler_tutorial/conjure3/compiler.cpp	(original)
+++ trunk/libs/spirit/example/qi/compiler_tutorial/conjure3/compiler.cpp	2011-08-08 23:35:05 EDT (Mon, 08 Aug 2011)
@@ -11,6 +11,7 @@
 
 #include <boost/foreach.hpp>
 #include <boost/variant/apply_visitor.hpp>
+#include <boost/range/adaptor/transformed.hpp>
 #include <boost/assert.hpp>
 #include <boost/lexical_cast.hpp>
 #include <set>
@@ -258,13 +259,13 @@
         //  the function. This is used for mutable variables etc.
         llvm::AllocaInst*
         create_entry_block_alloca(
-            llvm::Function* function,
+            llvm::Function* f,
             char const* name,
             llvm::LLVMContext& context)
         {
             llvm::IRBuilder<> builder(
-                &function->getEntryBlock(),
-                function->getEntryBlock().begin());
+                &f->getEntryBlock(),
+                f->getEntryBlock().begin());
 
             return builder.CreateAlloca(
                 llvm::Type::getIntNTy(context, int_size), 0, name);
@@ -273,10 +274,10 @@
 
     value llvm_compiler::var(char const* name)
     {
-        llvm::Function* function = llvm_builder.GetInsertBlock()->getParent();
+        llvm::Function* f = llvm_builder.GetInsertBlock()->getParent();
         llvm::IRBuilder<> builder(
-            &function->getEntryBlock(),
-            function->getEntryBlock().begin());
+            &f->getEntryBlock(),
+            f->getEntryBlock().begin());
 
         llvm::AllocaInst* alloca = builder.CreateAlloca(
             llvm::Type::getIntNTy(context(), int_size), 0, name);
@@ -284,13 +285,66 @@
         return value(alloca, true, &llvm_builder);
     }
 
+    namespace
+    {
+        struct llvm_value
+        {
+            typedef llvm::Value* result_type;
+            llvm::Value* operator()(value const& x) const
+            {
+                return x;
+            }
+        };
+
+        template <typename C>
+        llvm::Value* call_impl(
+            llvm::IRBuilder<>& llvm_builder,
+            function callee,
+            C const& args_)
+        {
+            // Sigh. LLVM requires CreateCall arguments to be random access.
+            // It would have been better if it could accept forward iterators.
+            // I guess it needs the arguments to be in contiguous memory.
+            // So, we have to put the args into a temporary std::vector.
+            std::vector<llvm::Value*> args(
+                args_.begin(), args_.end());
+
+            // Check the args for null values. We can't have null values.
+            // Return 0 if we find one to flag error.
+            BOOST_FOREACH(llvm::Value* arg, args)
+            {
+                if (arg == 0)
+                    return 0;
+            }
+
+            return llvm_builder.CreateCall(
+                callee, args.begin(), args.end(), "call_tmp");
+        }
+    }
+
+    template <typename Container>
     value llvm_compiler::call(
-        llvm::Function* callee,
-        std::vector<llvm::Value*> const& args)
+        function callee,
+        Container const& args)
     {
-        return value(
-            llvm_builder.CreateCall(callee, args.begin(), args.end(), "call_tmp"),
-            false, &llvm_builder);
+        llvm::Value* call = call_impl(
+            llvm_builder, callee,
+            args | boost::adaptors::transformed(llvm_value()));
+
+        if (call == 0)
+            return val();
+        return value(call, false, &llvm_builder);
+    }
+
+    function llvm_compiler::get_function(char const* name) const
+    {
+        return vm.module()->getFunction(name);
+    }
+
+    function llvm_compiler::get_current_function() const
+    {
+        // get the current function
+        return llvm_builder.GetInsertBlock()->getParent();
     }
 
     void llvm_compiler::init_fpm()
@@ -382,30 +436,40 @@
         }
     }
 
+    namespace
+    {
+        struct compile_args
+        {
+            compiler& c;
+            compile_args(compiler& c) : c(c) {}
+
+            typedef value result_type;
+            value operator()(ast::expression const& expr) const
+            {
+                return c(expr);
+            }
+        };
+    }
+
     value compiler::operator()(ast::function_call const& x)
     {
-        llvm::Function* callee = vm.module()->getFunction(x.function_name.name);
+        function callee = get_function(x.function_name.name);
         if (!callee)
         {
-            error_handler(x.function_name.id, "Function not found: " + x.function_name.name);
+            error_handler(x.function_name.id,
+                "Function not found: " + x.function_name.name);
             return val();
         }
 
-        if (callee->arg_size() != x.args.size())
+        if (callee.arg_size() != x.args.size())
         {
-            error_handler(x.function_name.id, "Wrong number of arguments: " + x.function_name.name);
+            error_handler(x.function_name.id,
+                "Wrong number of arguments: " + x.function_name.name);
             return val();
         }
 
-        std::vector<llvm::Value*> args;
-        BOOST_FOREACH(ast::expression const& expr, x.args)
-        {
-            args.push_back((*this)(expr));
-            if (args.back() == 0)
-                return val();
-        }
-
-        return call(callee, args);
+        return call(callee,
+            x.args | boost::adaptors::transformed(compile_args(*this)));
     }
 
     namespace
@@ -641,11 +705,11 @@
         if (!condition.is_valid())
             return false;
 
-        llvm::Function* function = builder().GetInsertBlock()->getParent();
+        llvm::Function* f = get_current_function();
 
         // Create blocks for the then and else cases.  Insert the 'then' block at the
         // end of the function.
-        llvm::BasicBlock* then_block = llvm::BasicBlock::Create(context(), "if.then", function);
+        llvm::BasicBlock* then_block = llvm::BasicBlock::Create(context(), "if.then", f);
         llvm::BasicBlock* else_block = 0;
         llvm::BasicBlock* exit_block = 0;
 
@@ -676,7 +740,7 @@
         if (x.else_)
         {
             // Emit else block.
-            function->getBasicBlockList().push_back(else_block);
+            f->getBasicBlockList().push_back(else_block);
             builder().SetInsertPoint(else_block);
             if (!(*this)(*x.else_))
                 return false;
@@ -693,7 +757,7 @@
         if (exit_block != 0)
         {
             // Emit exit block
-            function->getBasicBlockList().push_back(exit_block);
+            f->getBasicBlockList().push_back(exit_block);
             builder().SetInsertPoint(exit_block);
         }
         return true;
@@ -701,9 +765,9 @@
 
     bool compiler::operator()(ast::while_statement const& x)
     {
-        llvm::Function* function = builder().GetInsertBlock()->getParent();
+        llvm::Function* f = get_current_function();
 
-        llvm::BasicBlock* cond_block = llvm::BasicBlock::Create(context(), "while.cond", function);
+        llvm::BasicBlock* cond_block = llvm::BasicBlock::Create(context(), "while.cond", f);
         llvm::BasicBlock* body_block = llvm::BasicBlock::Create(context(), "while.body");
         llvm::BasicBlock* exit_block = llvm::BasicBlock::Create(context(), "while.end");
 
@@ -713,7 +777,7 @@
         if (!condition.is_valid())
             return false;
         builder().CreateCondBr(condition, body_block, exit_block);
-        function->getBasicBlockList().push_back(body_block);
+        f->getBasicBlockList().push_back(body_block);
         builder().SetInsertPoint(body_block);
 
         if (!(*this)(x.body))
@@ -723,7 +787,7 @@
             builder().CreateBr(cond_block); // loop back
 
         // Emit exit block
-        function->getBasicBlockList().push_back(exit_block);
+        f->getBasicBlockList().push_back(exit_block);
         builder().SetInsertPoint(exit_block);
 
         return true;
@@ -778,21 +842,21 @@
         llvm::FunctionType* function_type =
             llvm::FunctionType::get(void_return ? void_type : int_type, ints, false);
 
-        llvm::Function* function =
+        llvm::Function* f =
             llvm::Function::Create(
                 function_type, llvm::Function::ExternalLinkage,
                 current_function_name, vm.module());
 
         // If function conflicted, the function already exists. If it has a
         // body, don't allow redefinition or reextern.
-        if (function->getName() != current_function_name)
+        if (f->getName() != current_function_name)
         {
             // Delete the one we just made and get the existing one.
-            function->eraseFromParent();
-            function = vm.module()->getFunction(current_function_name);
+            f->eraseFromParent();
+            f = get_function(current_function_name);
 
             // If function already has a body, reject this.
-            if (!function->empty())
+            if (!f->empty())
             {
                 error_handler(
                     x.function_name.id,
@@ -801,7 +865,7 @@
             }
 
             // If function took a different number of args, reject.
-            if (function->arg_size() != x.args.size())
+            if (f->arg_size() != x.args.size())
             {
                 error_handler(
                     x.function_name.id,
@@ -811,21 +875,21 @@
             }
 
             // Set names for all arguments.
-            llvm::Function::arg_iterator iter = function->arg_begin();
+            llvm::Function::arg_iterator iter = f->arg_begin();
             BOOST_FOREACH(ast::identifier const& arg, x.args)
             {
                 iter->setName(arg.name);
                 ++iter;
             }
         }
-        return function;
+        return f;
     }
 
-    void compiler::function_allocas(ast::function const& x, llvm::Function* function)
+    void compiler::function_allocas(ast::function const& x, llvm::Function* f)
     {
         // Create a variable for each argument and register the
         // argument in the symbol table so that references to it will succeed.
-        llvm::Function::arg_iterator iter = function->arg_begin();
+        llvm::Function::arg_iterator iter = f->arg_begin();
         BOOST_FOREACH(ast::identifier const& arg, x.args)
         {
             // Create an arg_ for this variable.
@@ -850,8 +914,8 @@
     {
         ///////////////////////////////////////////////////////////////////////
         // the signature:
-        llvm::Function* function = function_decl(x);
-        if (function == 0)
+        llvm::Function* f = function_decl(x);
+        if (f == 0)
             return false;
 
         ///////////////////////////////////////////////////////////////////////
@@ -860,21 +924,21 @@
         {
             // Create a new basic block to start insertion into.
             llvm::BasicBlock* block =
-                llvm::BasicBlock::Create(context(), "entry", function);
+                llvm::BasicBlock::Create(context(), "entry", f);
             builder().SetInsertPoint(block);
 
-            function_allocas(x, function);
+            function_allocas(x, f);
             return_block = llvm::BasicBlock::Create(context(), "return");
 
             if (!(*this)(*x.body))
             {
                 // Error reading body, remove function.
-                function->eraseFromParent();
+                f->eraseFromParent();
                 return false;
             }
 
             llvm::BasicBlock* last_block =
-                &function->getBasicBlockList().back();
+                &f->getBasicBlockList().back();
 
             // If the last block is unterminated, connect it to return_block
             if (last_block->getTerminator() == 0)
@@ -883,7 +947,7 @@
                 builder().CreateBr(return_block);
             }
 
-            function->getBasicBlockList().push_back(return_block);
+            f->getBasicBlockList().push_back(return_block);
             builder().SetInsertPoint(return_block);
 
             if (void_return)
@@ -894,10 +958,10 @@
             //~ vm.module()->dump();
 
             // Validate the generated code, checking for consistency.
-            llvm::verifyFunction(*function);
+            llvm::verifyFunction(*f);
 
             // Optimize the function.
-            fpm.run(*function);
+            fpm.run(*f);
         }
 
         return true;
Modified: trunk/libs/spirit/example/qi/compiler_tutorial/conjure3/compiler.hpp
==============================================================================
--- trunk/libs/spirit/example/qi/compiler_tutorial/conjure3/compiler.hpp	(original)
+++ trunk/libs/spirit/example/qi/compiler_tutorial/conjure3/compiler.hpp	2011-08-08 23:35:05 EDT (Mon, 08 Aug 2011)
@@ -90,6 +90,27 @@
     };
 
     ///////////////////////////////////////////////////////////////////////////
+    struct function
+    {
+        function()
+          : f(0) {}
+
+        operator llvm::Function*() const
+        { return f; }
+
+        std::size_t arg_size() const
+        { return f->arg_size(); }
+
+    private:
+
+        function(llvm::Function* f)
+          : f(f) {}
+
+        friend struct llvm_compiler;
+        llvm::Function* f;
+    };
+
+    ///////////////////////////////////////////////////////////////////////////
     //  The LLVM Compiler. Lower level compiler (does not deal with ASTs)
     ///////////////////////////////////////////////////////////////////////////
     struct llvm_compiler
@@ -109,10 +130,17 @@
         value val(llvm::Value* v);
 
         value var(char const* name);
+        value var(std::string const& name)
+        { return var(name.c_str()); }
+
+        template <typename Container>
+        value call(function callee, Container const& args);
+
+        function get_function(char const* name) const;
+        function get_function(std::string const& name) const
+        { return get_function(name.c_str()); }
 
-        value call(
-            llvm::Function* callee,
-            std::vector<llvm::Value*> const& args);
+        function get_current_function() const;
 
     protected: