diff gen/CodeGen.d @ 51:c96cdcbdb9d6 new_gen

Rearranged some stuff, and renamed LLVMGen -> CodeGen
author Anders Halager <halager@gmail.com>
date Sat, 26 Apr 2008 15:54:54 +0200
parents
children 6decab6f45c4
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/gen/CodeGen.d	Sat Apr 26 15:54:54 2008 +0200
@@ -0,0 +1,652 @@
+module gen.CodeGen;
+
+import tango.io.Stdout,
+       Int = tango.text.convert.Integer;
+import tango.core.Array : find;
+
+import llvm.llvm;
+
+import ast.Decl,
+       ast.Stmt,
+       ast.Exp;
+
+import misc.Error;
+
+import lexer.Token;
+
+import sema.SymbolTableBuilder,
+       sema.Visitor;
+
+private char[] genBuildCmp(char[] p)
+{
+    return `
+        (Value l, Value r, char[] n)
+        {
+            return b.buildICmp(IntPredicate.`~p~`, l, r, n);
+        }`;
+}
+
+class CodeGen
+{
+public:
+    this()
+    {
+        alias BinaryExp.Operator op;
+        b = new Builder;
+        
+        opToLLVM = [
+            op.Add      : &b.buildAdd,
+            op.Sub      : &b.buildSub,
+            op.Mul      : &b.buildMul,
+            op.Div      : &b.buildSDiv,
+            op.Eq       : mixin(genBuildCmp("EQ")),
+            op.Ne       : mixin(genBuildCmp("NE")),
+            op.Lt       : mixin(genBuildCmp("SLT")),
+            op.Le       : mixin(genBuildCmp("SLE")),
+            op.Gt       : mixin(genBuildCmp("SGT")),
+            op.Ge       : mixin(genBuildCmp("SGE"))
+        ];
+        table = new SimpleSymbolTable();
+
+        createBasicTypes();
+    }
+
+    ~this()
+    {
+        b.dispose();
+    }
+
+    void gen(Decl[] decls, bool optimize, bool inline)
+    {
+        // create module
+        m = new Module("main_module");
+        scope(exit) m.dispose();
+
+        table.enterScope;
+
+        BytePtr = PointerType.Get(Type.Int8);
+        auto temp = FunctionType.Get(Type.Void, [BytePtr, BytePtr, Type.Int32, Type.Int32]);
+        llvm_memcpy = m.addFunction(temp, "llvm.memcpy.i32");
+        auto registerFunc =
+            (FuncDecl fd)
+            {
+                Type[] param_types;
+                foreach (p; fd.funcArgs)
+                {
+                    DType t = p.env.find(p.identifier).type;
+                    if(cast(DStruct)t)
+                    {
+                        Type pointer = PointerType.Get(llvm(t));
+                        param_types ~= pointer;
+                    }
+                    else
+                        param_types ~= llvm(t);
+                }
+                auto ret_t = llvm(fd.env.find(fd.identifier).type);
+                auto func_t = FunctionType.Get(ret_t, param_types);
+                auto llfunc = m.addFunction(func_t, fd.identifier.get);
+            };
+        auto visitor = new VisitFuncDecls(registerFunc);
+        visitor.visit(decls);
+
+        foreach (decl; decls)
+            genRootDecl(decl);
+
+        table.leaveScope;
+
+        char[] err;
+        m.verify(err);
+        Stderr(err).newline;
+
+        if(optimize)
+            m.optimize(inline);
+
+        m.writeBitcodeToFile("out.bc");
+    }
+
+    void genRootDecl(Decl decl)
+    {
+        switch(decl.declType)
+        {
+            case DeclType.FuncDecl:
+                FuncDecl funcDecl = cast(FuncDecl)decl;
+
+                auto llfunc = m.getNamedFunction(funcDecl.identifier.get);
+                auto func_tp = cast(PointerType)llfunc.type;
+                auto func_t = cast(FunctionType)func_tp.elementType();
+                auto ret_t = func_t.returnType();
+
+                auto bb = llfunc.appendBasicBlock("entry");
+                b.positionAtEnd(bb);
+
+                table.enterScope;
+                foreach (i, v; funcDecl.funcArgs)
+                {
+                    llfunc.getParam(i).name = v.identifier.get;
+                    auto name = v.identifier.get;
+                    if(!cast(PointerType)llfunc.getParam(i).type)
+                    {
+                        auto AI = b.buildAlloca(llfunc.getParam(i).type, name);
+ //                       Value va = b.buildLoad(val, name);
+                        b.buildStore(llfunc.getParam(i), AI);
+                        table[name] = AI;
+                    }
+                    else
+                        table[name] = llfunc.getParam(i);
+                }
+
+                foreach (stmt; funcDecl.statements)
+                    genStmt(stmt);
+
+                // if the function didn't end with a return, we automatically
+                // add one (return 0 as default)
+                if (b.getInsertBlock().terminated() is false)
+                    if (ret_t is Type.Void)
+                        b.buildRetVoid();
+                    else
+                        b.buildRet(ConstantInt.GetS(ret_t, 0));
+
+                table.leaveScope;
+                break;
+
+            case DeclType.VarDecl:
+                auto varDecl = cast(VarDecl)decl;
+                auto sym = varDecl.env.find(varDecl.identifier);
+                Type t = llvm(sym.type);
+                GlobalVariable g = m.addGlobal(t, sym.id.get);
+                g.initializer = ConstantInt.GetS(t, 0);
+                table[varDecl.identifier.get] = g;
+                break;
+        
+            case DeclType.StructDecl:
+                auto structDecl = cast(StructDecl)decl;
+                Type[] types;
+                foreach(varDecl ; structDecl.vars)
+                {
+                    auto sym = varDecl.env.find(varDecl.identifier);
+                    Type t = llvm(sym.type);
+                    types ~= t;
+                }
+
+                StructType t = StructType.Get(types);
+                m.addTypeName(structDecl.identifier.get, t);
+//                table[structDecl.identifier.get] = g;
+
+                break;
+
+            default:
+                break;
+        }
+    }
+
+    void genDecl(Decl decl)
+    {
+        switch(decl.declType)
+        {
+            case DeclType.VarDecl:
+                auto varDecl = cast(VarDecl)decl;
+                auto name = varDecl.identifier.get;
+                auto sym = varDecl.env.find(varDecl.identifier);
+                auto AI = b.buildAlloca(llvm(sym.type), name);
+                table[name] = AI;
+                if (varDecl.init)
+                    buildAssign(varDecl.identifier, varDecl.init);
+                break;
+        
+            default:
+        }
+    }
+
+    struct PE
+    {
+        static char[] NoImplicitConversion =
+            "Can't find an implicit conversion between %0 and %1";
+        static char[] VoidRetInNonVoidFunc =
+            "Only void functions can return without an expression";
+    }
+
+    void sextSmallerToLarger(ref Value left, ref Value right)
+    {
+        if (left.type != right.type)
+        {
+            // try to find a convertion - only works for iX
+            IntegerType l = cast(IntegerType) left.type;
+            IntegerType r = cast(IntegerType) right.type;
+            if (l is null || r is null)
+                throw error(__LINE__, PE.NoImplicitConversion)
+                    .arg(left.type.toString)
+                    .arg(right.type.toString);
+
+            if (l.numBits() < r.numBits())
+                left = b.buildSExt(left, r, ".cast");
+            else
+                right = b.buildSExt(right, l, ".cast");
+        }
+    }
+
+    Value genExpression(Exp exp)
+    {
+        switch(exp.expType)
+        {
+            case ExpType.Binary:
+                auto binaryExp = cast(BinaryExp)exp;
+
+                auto left = genExpression(binaryExp.left);
+                auto right = genExpression(binaryExp.right);
+
+                sextSmallerToLarger(left, right);
+
+                OpBuilder op = opToLLVM[binaryExp.op];
+
+                return op(left, right, ".");
+
+            case ExpType.IntegerLit:
+                auto integetLit = cast(IntegerLit)exp;
+                auto val = integetLit.token.get;
+                return ConstantInt.GetS(Type.Int32, Integer.parse(val));
+            case ExpType.Negate:
+                auto negateExp = cast(NegateExp)exp;
+                auto target = genExpression(negateExp.exp);
+                return b.buildNeg(target, "neg");
+            case ExpType.AssignExp:
+                auto assignExp = cast(AssignExp)exp;
+                return buildAssign(assignExp.identifier, assignExp.exp);
+            case ExpType.CallExp:
+                auto callExp = cast(CallExp)exp;
+                auto func_sym = exp.env.find(cast(Identifier)callExp.exp);
+                Value[] args;
+                foreach (arg; callExp.args)
+                {
+                    Value v = genExpression(arg);
+
+                    if(auto ptr = cast(PointerType)v.type)
+                    {
+                        args ~= v;
+                    }
+                    else
+                        args ~= v;
+
+                }
+                return b.buildCall(m.getNamedFunction(func_sym.id.get), args, ".call");
+            case ExpType.Identifier:
+                auto identifier = cast(Identifier)exp;
+                auto sym = exp.env.find(identifier);
+                if(cast(DStruct)sym.type)
+                    return table.find(sym.id.get);
+                else
+                    return b.buildLoad(table.find(sym.id.get), sym.id.get);
+            case ExpType.MemberLookup:
+                auto v = getPointer(exp);
+                return b.buildLoad(v, v.name);
+        }
+        assert(0, "Reached end of switch in genExpression");
+        return null;
+    }
+
+    void genStmt(Stmt stmt)
+    {
+        switch(stmt.stmtType)
+        {
+            case StmtType.Compound:
+                auto stmts = cast(CompoundStatement)stmt;
+                foreach (s; stmts.statements)
+                    genStmt(s);
+                break;
+            case StmtType.Return:
+                auto ret = cast(ReturnStmt)stmt;
+                auto sym = stmt.env.parentFunction();
+                Type t = llvm(sym.type);
+                if (ret.exp is null)
+                    if (t is Type.Void)
+                    {
+                        b.buildRetVoid();
+                        return;
+                    }
+                    else
+                        throw error(__LINE__, PE.VoidRetInNonVoidFunc);
+
+                Value v = genExpression(ret.exp);
+                if (v.type != t)
+                {
+                    IntegerType v_t = cast(IntegerType) v.type;
+                    IntegerType i_t = cast(IntegerType) t;
+                    if (v_t is null || i_t is null)
+                        throw error(__LINE__, PE.NoImplicitConversion)
+                            .arg(v.type.toString)
+                            .arg(t.toString);
+
+                    if (v_t.numBits() < i_t.numBits())
+                        v = b.buildSExt(v, t, ".cast");
+                    else
+                        v = b.buildTrunc(v, t, ".cast");
+                }
+                b.buildRet(v);
+                break;
+            case StmtType.Decl:
+                auto declStmt = cast(DeclStmt)stmt;
+                genDecl(declStmt.decl);
+                break;
+            case StmtType.Exp:
+                auto expStmt = cast(ExpStmt)stmt;
+                genExpression(expStmt.exp);
+                break;
+            case StmtType.If:
+                auto ifStmt = cast(IfStmt)stmt;
+                Value cond = genExpression(ifStmt.cond);
+                if (cond.type !is Type.Int1)
+                {
+                    Value False = ConstantInt.GetS(cond.type, 0);
+                    cond = b.buildICmp(IntPredicate.NE, cond, False, ".cond");
+                }
+                auto func_name = stmt.env.parentFunction().id.get;
+                Function func = m.getNamedFunction(func_name);
+                bool has_else = (ifStmt.else_body !is null);
+
+                auto thenBB = func.appendBasicBlock("then");
+                auto elseBB = has_else? func.appendBasicBlock("else") : null;
+                auto mergeBB = func.appendBasicBlock("merge");
+
+                b.buildCondBr(cond, thenBB, has_else? elseBB : mergeBB);
+                b.positionAtEnd(thenBB);
+                genStmt(ifStmt.then_body);
+                thenBB = b.getInsertBlock();
+                if (b.getInsertBlock().terminated() is false)
+                    b.buildBr(mergeBB);
+
+                if (has_else)
+                {
+                    b.positionAtEnd(elseBB);
+                    genStmt(ifStmt.else_body);
+                    elseBB = b.getInsertBlock();
+                    if (elseBB.terminated() is false)
+                        b.buildBr(mergeBB);
+                }
+
+                b.positionAtEnd(mergeBB);
+                break;
+            case StmtType.While:
+                auto wStmt = cast(WhileStmt)stmt;
+                auto func_name = stmt.env.parentFunction().id.get;
+                Function func = m.getNamedFunction(func_name);
+
+                auto condBB = func.appendBasicBlock("cond");
+                auto bodyBB = func.appendBasicBlock("body");
+                auto doneBB = func.appendBasicBlock("done");
+
+                b.buildBr(condBB);
+                b.positionAtEnd(condBB);
+                Value cond = genExpression(wStmt.cond);
+                if (cond.type !is Type.Int1)
+                {
+                    Value False = ConstantInt.GetS(cond.type, 0);
+                    cond = b.buildICmp(IntPredicate.NE, cond, False, ".cond");
+                }
+                b.buildCondBr(cond, bodyBB, doneBB);
+
+                b.positionAtEnd(bodyBB);
+                genStmt(wStmt.whileBody);
+                if (b.getInsertBlock().terminated() is false)
+                    b.buildBr(condBB);
+
+                b.positionAtEnd(doneBB);
+                break;
+            case StmtType.Switch:
+                auto sw = cast(SwitchStmt)stmt;
+                Value cond = genExpression(sw.cond);
+
+                auto func_name = stmt.env.parentFunction().id.get;
+                Function func = m.getNamedFunction(func_name);
+
+                BasicBlock oldBB = b.getInsertBlock();
+                BasicBlock defBB;
+                BasicBlock endBB = func.appendBasicBlock("sw.end");
+                if (sw.defaultBlock)
+                {
+                    defBB = Function.InsertBasicBlock(endBB, "sw.def");
+                    b.positionAtEnd(defBB);
+                    foreach (case_statement; sw.defaultBlock)
+                        genStmt(case_statement);
+                    if (!defBB.terminated())
+                        b.buildBr(endBB);
+                    b.positionAtEnd(oldBB);
+                }
+                else
+                    defBB = endBB;
+                auto SI = b.buildSwitch(cond, defBB, sw.cases.length);
+                foreach (c; sw.cases)
+                {
+                    BasicBlock prevBB;
+                    foreach (i, val; c.values)
+                    {
+                        auto BB = Function.InsertBasicBlock(defBB, "sw.bb");
+                        SI.addCase(ConstantInt.GetS(cond.type, c.values_converted[i]), BB);
+
+                        if (i + 1 == c.values.length)
+                        {
+                            b.positionAtEnd(BB);
+                            foreach (case_statement; c.stmts)
+                                genStmt(case_statement);
+                            if (!BB.terminated())
+                                b.buildBr(c.followedByDefault? defBB : endBB);
+                        }
+
+                        if (prevBB !is null && !prevBB.terminated())
+                        {
+                            b.positionAtEnd(prevBB);
+                            b.buildBr(BB);
+                        }
+                        prevBB = BB;
+                    }
+                }
+                b.positionAtEnd(endBB);
+                break;
+        }
+    }
+
+    Value getPointer(Exp exp)
+    {
+        switch(exp.expType)
+        {
+            case ExpType.Identifier:
+                auto identifier = cast(Identifier)exp;
+                auto sym = exp.env.find(identifier);
+                return table.find(sym.id.get);
+            case ExpType.MemberLookup:
+                auto mem = cast(MemberLookup)exp;
+                switch(mem.target.expType)
+                {
+                    case ExpType.Identifier:
+                        auto identifier = cast(Identifier)mem.target;
+                        auto child = mem.child;
+                        auto sym = exp.env.find(identifier);
+                        auto symChild = child.env.find(child);
+                        Value v = table.find(sym.id.get);
+                        DType t = sym.type;
+                        auto st = cast(DStruct)t;
+
+                        int i = 0;
+                        foreach(char[] name, DType type ; st.members)
+                            if(name == child.get)
+                                break;
+                            else
+                                i++;
+
+                        Value[] vals;   
+                        vals ~= ConstantInt.Get(IntegerType.Int32, 0, false);
+                        vals ~= ConstantInt.Get(IntegerType.Int32, i, false);
+
+                        Value val = b.buildGEP(v, vals, sym.id.get~"."~child.get);
+                        return val;
+
+                    default:
+                        Value val = genExpression(exp);
+                        auto AI = b.buildAlloca(val.type, ".s");
+                        return b.buildStore(val, AI);
+                }
+            default:
+                Value val = genExpression(exp);
+                auto AI = b.buildAlloca(val.type, ".s");
+                return b.buildStore(val, AI);
+        }
+        assert(0, "Reached end of switch in getPointer");
+        return null;
+    }
+
+    private Value buildAssign(Exp target, Exp exp)
+    {
+        Value t = getPointer(target);
+        Value v = genExpression(exp);
+
+        auto a = cast(PointerType)t.type;
+
+        assert(a, "Assing to type have to be of type PointerType");
+
+        Type value_type = v.type;
+        if (auto value_ptr = cast(PointerType)v.type)
+        {
+            value_type = value_ptr.elementType;
+
+            if (a.elementType is value_type && cast(StructType)value_type)
+            {
+                // bitcast "from" to i8*
+                Value from = b.buildBitCast(v, BytePtr, ".copy_from");
+                // bitcast "to" to i8*
+                Value to = b.buildBitCast(t, BytePtr, ".copy_to");
+                // call llvm.memcpy.i32( "from", "to", type_size, alignment (32 in clang) );
+                b.buildCall(llvm_memcpy, [from, to, ConstantInt.GetS(Type.Int32, 8), ConstantInt.GetS(Type.Int32, 32)], null);
+                // return "to"
+                return t;
+            }
+        }
+
+        if (value_type != a.elementType)
+        {
+            IntegerType v_t = cast(IntegerType) value_type;
+            IntegerType i_t = cast(IntegerType) a.elementType;
+            if (v_t is null || i_t is null)
+                throw error(__LINE__, PE.NoImplicitConversion)
+                    .arg(a.elementType.toString)
+                    .arg(v.type.toString);
+
+            if (v_t.numBits() < i_t.numBits())
+                v = b.buildSExt(v, a.elementType, ".cast");
+            else
+                v = b.buildTrunc(v, a.elementType, ".cast");
+        }
+        return b.buildStore(v, t);
+    }
+
+    Error error(uint line, char[] msg)
+    {
+        return new Error(msg);
+    }
+
+    /**
+      Get the LLVM Type corresponding to a DType.
+
+      Currently using the built-in associative array - not sure if it works
+      well when the hashes are so uniform.
+
+      Other possibilities would be to find a hash-function that works on
+      something as small as 4 bytes or to create a sparse array perhaps.
+     */
+    Type llvm(DType t)
+    {
+        if (auto llvm_t = t in type_map)
+            return *llvm_t;
+        return llvmCreateNew(t);
+    }
+
+    // Create an LLVM type and insert it into the type map, and return the
+    // result
+    Type llvmCreateNew(DType t)
+    {
+        if (auto i = cast(DInteger)t)
+        {
+            Type res = IntegerType.Get(i.byteSize() * 8);
+            type_map[t] = res;
+            return res;
+        }
+        assert(0, "Only integers are supported");
+    }
+
+    // Might as well insert all the basic types from the start
+    void createBasicTypes()
+    {
+        type_map[DType.Void] = Type.Void;
+
+        type_map[DType.Bool]   = Type.Int1;
+        type_map[DType.Byte]   = Type.Int8;
+        type_map[DType.UByte]  = Type.Int8;
+        type_map[DType.Short]  = Type.Int16;
+        type_map[DType.UShort] = Type.Int16;
+        type_map[DType.Int]    = Type.Int32;
+        type_map[DType.UInt]   = Type.Int32;
+        type_map[DType.Long]   = Type.Int64;
+        type_map[DType.ULong]  = Type.Int64;
+    }
+
+private:
+
+    // llvm stuff
+    Module m;
+    Builder b;
+    Function llvm_memcpy;
+    Type BytePtr;
+    Type[DType] type_map;
+
+    FuncDecl[char[]] functions;
+
+    SimpleSymbolTable table;
+    alias Value delegate(Value, Value, char[]) OpBuilder;
+    static OpBuilder[BinaryExp.Operator] opToLLVM;
+}
+
+private class VisitFuncDecls : Visitor!(void)
+{
+    void delegate(FuncDecl) dg;
+    this(void delegate(FuncDecl funcDecl) dg)
+    {
+        this.dg = dg;
+    }
+
+    override void visit(Decl[] decls)
+    {
+        foreach (decl; decls)
+            if (auto f = cast(FuncDecl)decl)
+                dg(f);
+    }
+}
+
+private class SimpleSymbolTable
+{
+    Value[char[]][] namedValues;
+
+    void enterScope()
+    {
+        namedValues ~= cast(Value[char[]])["__dollar":null];
+    }
+
+    void leaveScope()
+    {
+        namedValues.length = namedValues.length - 1;
+    }
+
+    Value put(Value val, char[] key)
+    {
+        namedValues[$ - 1][key] = val;
+        return val;
+    }
+
+    Value find(char[] key)
+    {
+        foreach_reverse (map; namedValues)
+            if(auto val_ptr = key in map)
+                return *val_ptr;
+        return null;
+    }
+
+    alias find opIndex;
+    alias put opIndexAssign;
+}
+