Mercurial > projects > dang
view gen/LLVMGen.d @ 45:9bc660cbdbec new_gen
If statements are back
Also fixed a bug in the codegen that affected return statements in the else
branch; a return there is now optional.
Also found an issue with the way we generate LLVM IR for if statements: it
does not change the semantics, but the generated code looks ugly.
if (cond_1)
if (cond_2)
statement;
return 0;
Becomes:
br cond_1, then, merge
then:
br cond_2, then2, merge2
merge:
ret 0
then2:
statements
merge2:
br merge
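This is because we use appendBasicBlock on the function, which always places a new block at the end of the function rather than directly after the current block.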
author | Anders Halager <halager@gmail.com> |
---|---|
date | Wed, 23 Apr 2008 16:43:42 +0200 |
parents | 495188f9078e |
children | 90fb4fdfefdd |
line wrap: on
line source
module gen.LLVMGen;

import tango.io.Stdout,
       Int = tango.text.convert.Integer;
import tango.core.Array : find;

import llvm.llvm;

import ast.Decl,
       ast.Stmt,
       ast.Exp;

import misc.Error;

import lexer.Token;

import sema.SymbolTableBuilder,
       sema.Visitor;

/**
  Translates the analysed AST (declarations, statements and expressions) into
  an LLVM module and writes it out as bitcode.
 */
class LLVMGen
{
public:
    /**
      Create the IR builder, the table mapping binary operators to the
      builder calls that implement them, and an empty symbol table.
     */
    this()
    {
        alias BinaryExp.Operator op;
        b = new Builder;

        // Every entry is a delegate bound to a heap object: the Builder for
        // the arithmetic ops, this LLVMGen for the comparisons. A D1 function
        // literal that used `this` here would instead point into this
        // constructor's stack frame, which is gone once this() returns.
        // Division and comparisons are the signed variants.
        opToLLVM = [
            op.Add : &b.buildAdd,
            op.Sub : &b.buildSub,
            op.Mul : &b.buildMul,
            op.Div : &b.buildSDiv,
            op.Eq  : &this.buildCmpEQ,
            op.Ne  : &this.buildCmpNE,
            op.Lt  : &this.buildCmpSLT,
            op.Le  : &this.buildCmpSLE,
            op.Gt  : &this.buildCmpSGT,
            op.Ge  : &this.buildCmpSGE
        ];
        table = new SimpleSymbolTable();
    }

    ~this()
    {
        b.dispose();
    }

    /**
      Generate code for all top-level declarations, verify the module,
      optionally optimize it, and write it to "out.bc".

      Params:
          decls    = the top-level declarations of the program
          optimize = when true, m.optimize() is run before writing
          inline   = passed through to m.optimize()
     */
    void gen(Decl[] decls, bool optimize, bool inline)
    {
        // create module
        m = new Module("main_module");
        scope(exit) m.dispose();

        table.enterScope;

        // Declare llvm.memcpy.i32(i8*, i8*, i32, i32); buildAssign uses it
        // to copy struct values.
        BytePtr = PointerType.Get(Type.Int8);
        auto memcpy_t = FunctionType.Get(Type.Void,
                [BytePtr, BytePtr, Type.Int32, Type.Int32]);
        llvm_memcpy = m.addFunction(memcpy_t, "llvm.memcpy.i32");

        // Declare every function before generating any body, so bodies and
        // calls can look functions up with m.getNamedFunction regardless of
        // declaration order. Struct-typed parameters are passed by pointer.
        auto registerFunc =
            (FuncDecl fd)
            {
                Type[] param_types;
                foreach (p; fd.funcArgs)
                {
                    DType t = p.env.find(p.identifier).type;
                    if (cast(DStruct)t)
                    {
                        Type pointer = PointerType.Get(t.llvm());
                        param_types ~= pointer;
                    }
                    else
                        param_types ~= t.llvm();
                }
                auto ret_t = fd.env.find(fd.identifier).type.llvm();
                auto func_t = FunctionType.Get(ret_t, param_types);
                m.addFunction(func_t, fd.identifier.get);
            };
        auto visitor = new VisitFuncDecls(registerFunc);
        visitor.visit(decls);

        foreach (decl; decls)
            genRootDecl(decl);

        table.leaveScope;

        char[] err;
        m.verify(err);
        Stderr(err).newline;

        if (optimize)
            m.optimize(inline);

        m.writeBitcodeToFile("out.bc");
    }

    /**
      Generate code for one top-level declaration.

      - Functions: emit the body of the already-declared function.
      - Variables: become zero-initialized globals.
      - Structs: register a named struct type whose fields follow the
        declaration order of structDecl.vars.

      Other declaration kinds are ignored.
     */
    void genRootDecl(Decl decl)
    {
        switch(decl.declType)
        {
            case DeclType.FuncDecl:
                FuncDecl funcDecl = cast(FuncDecl)decl;

                auto llfunc = m.getNamedFunction(funcDecl.identifier.get);
                auto func_tp = cast(PointerType)llfunc.type;
                auto func_t = cast(FunctionType)func_tp.elementType();
                auto ret_t = func_t.returnType();

                auto bb = llfunc.appendBasicBlock("entry");
                b.positionAtEnd(bb);

                table.enterScope;
                foreach (i, v; funcDecl.funcArgs)
                {
                    auto name = v.identifier.get;
                    auto param = llfunc.getParam(i);
                    param.name = name;
                    if (!cast(PointerType)param.type)
                    {
                        // Non-pointer parameters are spilled to a stack slot
                        // so the body can treat them like local variables.
                        auto AI = b.buildAlloca(param.type, name);
                        b.buildStore(param, AI);
                        table[name] = AI;
                    }
                    else
                        table[name] = param;
                }

                foreach (stmt; funcDecl.statements)
                    genStmt(stmt);

                // If the function didn't end with a return, add one
                // automatically (returning 0 for non-void functions).
                if (b.getInsertBlock().terminated() is false)
                {
                    if (ret_t is Type.Void)
                        b.buildRetVoid();
                    else
                        b.buildRet(ConstantInt.GetS(ret_t, 0));
                }

                table.leaveScope;
                break;

            case DeclType.VarDecl:
                auto varDecl = cast(VarDecl)decl;
                auto sym = varDecl.env.find(varDecl.identifier);
                Type t = sym.type.llvm();
                GlobalVariable g = m.addGlobal(t, sym.id.get);
                g.initializer = ConstantInt.GetS(t, 0);
                table[varDecl.identifier.get] = g;
                break;

            case DeclType.StructDecl:
                auto structDecl = cast(StructDecl)decl;

                Type[] types;
                foreach (field; structDecl.vars)
                {
                    auto sym = field.env.find(field.identifier);
                    types ~= sym.type.llvm();
                }

                StructType t = StructType.Get(types);
                m.addTypeName(structDecl.identifier.get, t);
                break;

            default:
                break;
        }
    }

    /**
      Generate code for a declaration inside a function body. A local
      variable gets a stack slot and, when it has an initializer, an
      assignment. Other declaration kinds are ignored.
     */
    void genDecl(Decl decl)
    {
        switch(decl.declType)
        {
            case DeclType.VarDecl:
                auto varDecl = cast(VarDecl)decl;
                auto name = varDecl.identifier.get;
                auto sym = varDecl.env.find(varDecl.identifier);
                auto AI = b.buildAlloca(sym.type.llvm(), name);
                table[name] = AI;
                if (varDecl.init)
                    buildAssign(varDecl.identifier, varDecl.init);
                break;

            default:
                break;
        }
    }

    /// Error message templates used with error(...).arg(...).
    struct PE
    {
        static char[] NoImplicitConversion =
            "Can't find an implicit conversion between %0 and %1";
        static char[] VoidRetInNonVoidFunc =
            "Only void functions can return without an expression";
        static char[] UnsupportedMemberTarget =
            "Member lookup is only supported directly on a named variable";
    }

    /**
      Make two integer operands the same type by sign-extending the narrower
      one to the width of the wider.

      Throws: an error (PE.NoImplicitConversion) if the types differ and
              either one is not an integer type.
     */
    void sextSmallerToLarger(ref Value left, ref Value right)
    {
        if (left.type != right.type)
        {
            // try to find a conversion - only works for iX
            IntegerType l = cast(IntegerType) left.type;
            IntegerType r = cast(IntegerType) right.type;
            if (l is null || r is null)
                throw error(__LINE__, PE.NoImplicitConversion)
                    .arg(left.type.toString)
                    .arg(right.type.toString);

            if (l.numBits() < r.numBits())
                left = b.buildSExt(left, r, ".cast");
            else
                right = b.buildSExt(right, l, ".cast");
        }
    }

    /**
      Generate code that computes the value of exp and return it.

      Identifiers of struct type evaluate to their storage pointer rather
      than a loaded value; this matches struct parameters being declared as
      pointers in gen().
     */
    Value genExpression(Exp exp)
    {
        switch(exp.expType)
        {
            case ExpType.Binary:
                auto binaryExp = cast(BinaryExp)exp;

                auto left = genExpression(binaryExp.left);
                auto right = genExpression(binaryExp.right);

                sextSmallerToLarger(left, right);

                OpBuilder op = opToLLVM[binaryExp.op];
                return op(left, right, ".");

            case ExpType.IntegerLit:
                auto integerLit = cast(IntegerLit)exp;
                auto text = integerLit.token.get;
                return ConstantInt.GetS(Type.Int32, Integer.parse(text));

            case ExpType.Negate:
                auto negateExp = cast(NegateExp)exp;
                auto target = genExpression(negateExp.exp);
                return b.buildNeg(target, "neg");

            case ExpType.AssignExp:
                auto assignExp = cast(AssignExp)exp;
                return buildAssign(assignExp.identifier, assignExp.exp);

            case ExpType.CallExp:
                auto callExp = cast(CallExp)exp;
                auto func_sym = exp.env.find(cast(Identifier)callExp.exp);
                Value[] args;
                foreach (arg; callExp.args)
                    args ~= genExpression(arg);
                return b.buildCall(m.getNamedFunction(func_sym.id.get), args, ".call");

            case ExpType.Identifier:
                auto identifier = cast(Identifier)exp;
                auto sym = exp.env.find(identifier);
                if (cast(DStruct)sym.type)
                    return table.find(sym.id.get);
                else
                    return b.buildLoad(table.find(sym.id.get), sym.id.get);

            case ExpType.MemberLookup:
                auto v = getPointer(exp);
                return b.buildLoad(v, v.name);
        }
        assert(0, "Reached end of switch in genExpression");
        return null;
    }

    /// Generate code for one statement at the builder's current position.
    void genStmt(Stmt stmt)
    {
        switch(stmt.stmtType)
        {
            case StmtType.Compound:
                auto stmts = cast(CompoundStatement)stmt;
                foreach (s; stmts.statements)
                    genStmt(s);
                break;

            case StmtType.Return:
                auto ret = cast(ReturnStmt)stmt;
                auto sym = stmt.env.parentFunction();
                Type t = sym.type.llvm();
                if (ret.exp is null)
                {
                    if (t is Type.Void)
                    {
                        b.buildRetVoid();
                        return;
                    }
                    else
                        throw error(__LINE__, PE.VoidRetInNonVoidFunc);
                }

                Value v = genExpression(ret.exp);
                if (v.type != t)
                {
                    // Integer returns are sign-extended or truncated to the
                    // declared return type.
                    IntegerType v_t = cast(IntegerType) v.type;
                    IntegerType i_t = cast(IntegerType) t;
                    if (v_t is null || i_t is null)
                        throw error(__LINE__, PE.NoImplicitConversion)
                            .arg(v.type.toString)
                            .arg(t.toString);

                    if (v_t.numBits() < i_t.numBits())
                        v = b.buildSExt(v, t, ".cast");
                    else
                        v = b.buildTrunc(v, t, ".cast");
                }
                b.buildRet(v);
                break;

            case StmtType.Decl:
                auto declStmt = cast(DeclStmt)stmt;
                genDecl(declStmt.decl);
                break;

            case StmtType.Exp:
                auto expStmt = cast(ExpStmt)stmt;
                genExpression(expStmt.exp);
                break;

            case StmtType.If:
                genIf(cast(IfStmt)stmt);
                break;

            case StmtType.While:
                genWhile(cast(WhileStmt)stmt);
                break;

            case StmtType.Switch:
                genSwitch(cast(SwitchStmt)stmt);
                break;
        }
    }

    /**
      Return a pointer to the storage denoted by exp.

      - Identifiers map to the value recorded in the symbol table.
      - `var.field` becomes a GEP into var.
      - Any other expression is evaluated into a fresh ".s" stack slot, and
        that slot is returned.

      Throws: an error (PE.UnsupportedMemberTarget) for a member lookup whose
              target is not a plain identifier.
     */
    Value getPointer(Exp exp)
    {
        switch(exp.expType)
        {
            case ExpType.Identifier:
                auto identifier = cast(Identifier)exp;
                auto sym = exp.env.find(identifier);
                return table.find(sym.id.get);

            case ExpType.MemberLookup:
                auto mem = cast(MemberLookup)exp;
                // genExpression(MemberLookup) itself calls getPointer, so
                // evaluating such a lookup here would recurse forever.
                if (mem.target.expType != ExpType.Identifier)
                    throw error(__LINE__, PE.UnsupportedMemberTarget);

                auto identifier = cast(Identifier)mem.target;
                auto child = mem.child;
                auto sym = exp.env.find(identifier);
                Value v = table.find(sym.id.get);
                auto st = cast(DStruct)sym.type;

                // NOTE(review): st.members is iterated as (char[] name, DType),
                // which looks like an associative array. AA iteration order is
                // unspecified, so this index may not match the field order
                // given to StructType.Get in genRootDecl -- confirm. If the
                // field is not found, i ends up equal to the member count.
                int i = 0;
                foreach (char[] name, DType type; st.members)
                {
                    if (name == child.get)
                        break;
                    i++;
                }

                Value[] vals;
                vals ~= ConstantInt.Get(IntegerType.Int32, 0, false);
                vals ~= ConstantInt.Get(IntegerType.Int32, i, false);

                return b.buildGEP(v, vals, sym.id.get~"."~child.get);

            default:
                // Materialize the value in a temporary and hand out its
                // address (not the store instruction).
                Value val = genExpression(exp);
                auto AI = b.buildAlloca(val.type, ".s");
                b.buildStore(val, AI);
                return AI;
        }
        assert(0, "Reached end of switch in getPointer");
        return null;
    }

    /**
      Store the value of exp into the location denoted by target.

      - Struct values (a pointer to the same struct type as the target) are
        copied with llvm.memcpy.i32.
      - Integer values are sign-extended or truncated to the target's type.

      Returns: the destination pointer for struct copies, otherwise the store
               instruction.
      Throws:  an error (PE.NoImplicitConversion) when the types differ and
               are not both integers.
     */
    private Value buildAssign(Exp target, Exp exp)
    {
        Value t = getPointer(target);
        Value v = genExpression(exp);

        auto a = cast(PointerType)t.type;

        assert(a, "Assing to type have to be of type PointerType");

        Type value_type = v.type;
        if (auto value_ptr = cast(PointerType)v.type)
        {
            value_type = value_ptr.elementType;

            if (a.elementType is value_type && cast(StructType)value_type)
            {
                Value from = b.buildBitCast(v, BytePtr, ".copy_from");
                Value to = b.buildBitCast(t, BytePtr, ".copy_to");
                // llvm.memcpy.i32(dest, src, len, align): destination first.
                // NOTE(review): len is hard-coded to 8 bytes regardless of
                // the struct's size, and align 32 promises a 32-byte
                // alignment of both pointers -- both need confirming.
                b.buildCall(llvm_memcpy,
                        [to, from,
                         ConstantInt.GetS(Type.Int32, 8),
                         ConstantInt.GetS(Type.Int32, 32)], null);
                return t;
            }
        }

        if (value_type != a.elementType)
        {
            IntegerType v_t = cast(IntegerType) value_type;
            IntegerType i_t = cast(IntegerType) a.elementType;
            if (v_t is null || i_t is null)
                throw error(__LINE__, PE.NoImplicitConversion)
                    .arg(a.elementType.toString)
                    .arg(v.type.toString);

            if (v_t.numBits() < i_t.numBits())
                v = b.buildSExt(v, a.elementType, ".cast");
            else
                v = b.buildTrunc(v, a.elementType, ".cast");
        }
        return b.buildStore(v, t);
    }

    /// Build (but do not throw) an error carrying msg; line is currently unused.
    Error error(uint line, char[] msg)
    {
        return new Error(msg);
    }

    /**
      Evaluate exp and convert the result to an i1 for a conditional branch.
      Any non-i1 value is compared (!= 0) against a zero of its own type.
     */
    private Value genCondition(Exp exp)
    {
        Value cond = genExpression(exp);
        if (cond.type !is Type.Int1)
        {
            Value False = ConstantInt.GetS(cond.type, 0);
            cond = b.buildICmp(IntPredicate.NE, cond, False, ".cond");
        }
        return cond;
    }

    /// The LLVM function that contains stmt, looked up by name.
    private Function enclosingFunction(Stmt stmt)
    {
        auto func_name = stmt.env.parentFunction().id.get;
        return m.getNamedFunction(func_name);
    }

    /**
      if / else: branch to "then" (and "else" if present), and have each
      branch that does not already end in a terminator fall through to
      "merge". The builder is left at "merge".
     */
    private void genIf(IfStmt ifStmt)
    {
        Value cond = genCondition(ifStmt.cond);
        Function func = enclosingFunction(ifStmt);
        bool has_else = (ifStmt.else_body !is null);

        // appendBasicBlock places the blocks at the end of the function.
        auto thenBB = func.appendBasicBlock("then");
        auto elseBB = has_else? func.appendBasicBlock("else") : null;
        auto mergeBB = func.appendBasicBlock("merge");

        b.buildCondBr(cond, thenBB, has_else? elseBB : mergeBB);

        // Check the builder's *current* block after each branch: nested
        // control flow in the body moves the builder away from thenBB/elseBB.
        b.positionAtEnd(thenBB);
        genStmt(ifStmt.then_body);
        if (b.getInsertBlock().terminated() is false)
            b.buildBr(mergeBB);

        if (has_else)
        {
            b.positionAtEnd(elseBB);
            genStmt(ifStmt.else_body);
            if (b.getInsertBlock().terminated() is false)
                b.buildBr(mergeBB);
        }

        b.positionAtEnd(mergeBB);
    }

    /**
      while: "cond" tests the condition and branches to "body" or "done";
      "body" jumps back to "cond" unless it already ends in a terminator.
      The builder is left at "done".
     */
    private void genWhile(WhileStmt wStmt)
    {
        Function func = enclosingFunction(wStmt);

        auto condBB = func.appendBasicBlock("cond");
        auto bodyBB = func.appendBasicBlock("body");
        auto doneBB = func.appendBasicBlock("done");

        b.buildBr(condBB);
        b.positionAtEnd(condBB);
        Value cond = genCondition(wStmt.cond);
        b.buildCondBr(cond, bodyBB, doneBB);

        b.positionAtEnd(bodyBB);
        foreach (s; wStmt.stmts)
            genStmt(s);
        if (b.getInsertBlock().terminated() is false)
            b.buildBr(condBB);

        b.positionAtEnd(doneBB);
    }

    /**
      switch: emits an LLVM switch on the (unconverted) condition value.

      - Each case value gets its own "sw.bb" block. A case listing several
        values chains its earlier blocks into the block for the last value,
        which holds the case's statements.
      - A case ends by jumping to "sw.def" when followedByDefault is set,
        otherwise to "sw.end".
      - Without a default block, "sw.end" is the default target.

      The builder is left at "sw.end".
     */
    private void genSwitch(SwitchStmt sw)
    {
        Value cond = genExpression(sw.cond);
        Function func = enclosingFunction(sw);

        BasicBlock oldBB = b.getInsertBlock();
        BasicBlock defBB;
        BasicBlock endBB = func.appendBasicBlock("sw.end");
        if (sw.defaultBlock)
        {
            defBB = Function.InsertBasicBlock(endBB, "sw.def");
            b.positionAtEnd(defBB);
            foreach (case_statement; sw.defaultBlock)
                genStmt(case_statement);
            if (!b.getInsertBlock().terminated())
                b.buildBr(endBB);
            b.positionAtEnd(oldBB);
        }
        else
            defBB = endBB;

        auto SI = b.buildSwitch(cond, defBB, sw.cases.length);
        foreach (c; sw.cases)
        {
            BasicBlock prevBB;
            foreach (i, val; c.values)
            {
                auto BB = Function.InsertBasicBlock(defBB, "sw.bb");
                SI.addCase(ConstantInt.GetS(cond.type, c.values_converted[i]), BB);

                if (i + 1 == c.values.length)
                {
                    b.positionAtEnd(BB);
                    foreach (case_statement; c.stmts)
                        genStmt(case_statement);
                    if (!b.getInsertBlock().terminated())
                        b.buildBr(c.followedByDefault? defBB : endBB);
                }

                if (prevBB !is null && !prevBB.terminated())
                {
                    b.positionAtEnd(prevBB);
                    b.buildBr(BB);
                }
                prevBB = BB;
            }
        }

        b.positionAtEnd(endBB);
    }

    // Comparison builders referenced from opToLLVM (signed integer compares).
    private Value buildCmpEQ(Value l, Value r, char[] n)
    {
        return b.buildICmp(IntPredicate.EQ, l, r, n);
    }

    private Value buildCmpNE(Value l, Value r, char[] n)
    {
        return b.buildICmp(IntPredicate.NE, l, r, n);
    }

    private Value buildCmpSLT(Value l, Value r, char[] n)
    {
        return b.buildICmp(IntPredicate.SLT, l, r, n);
    }

    private Value buildCmpSLE(Value l, Value r, char[] n)
    {
        return b.buildICmp(IntPredicate.SLE, l, r, n);
    }

    private Value buildCmpSGT(Value l, Value r, char[] n)
    {
        return b.buildICmp(IntPredicate.SGT, l, r, n);
    }

    private Value buildCmpSGE(Value l, Value r, char[] n)
    {
        return b.buildICmp(IntPredicate.SGE, l, r, n);
    }

private:
    // llvm stuff
    Module m;
    Builder b;
    Function llvm_memcpy;
    Type BytePtr;

    FuncDecl[char[]] functions;

    SimpleSymbolTable table;
    alias Value delegate(Value, Value, char[]) OpBuilder;

    // Per instance (not static): the entries are bound to this object's builder.
    OpBuilder[BinaryExp.Operator] opToLLVM;
}

/// Visitor that calls dg for every FuncDecl among the given declarations.
private class VisitFuncDecls : Visitor!(void)
{
    void delegate(FuncDecl) dg;

    this(void delegate(FuncDecl funcDecl) dg)
    {
        this.dg = dg;
    }

    override void visit(Decl[] decls)
    {
        foreach (decl; decls)
            if (auto f = cast(FuncDecl)decl)
                dg(f);
    }
}

/**
  A stack of name -> Value scopes; lookups search from the innermost
  scope outwards.
 */
private class SimpleSymbolTable
{
    Value[char[]][] namedValues;

    /// Push a new scope.
    void enterScope()
    {
        // Seeds the scope with a dummy entry -- presumably so the new map is
        // a non-null AA before put() assigns into it; TODO confirm.
        namedValues ~= cast(Value[char[]])["__dollar":null];
    }

    /// Pop the innermost scope.
    void leaveScope()
    {
        assert(namedValues.length > 0, "leaveScope without a matching enterScope");
        namedValues.length = namedValues.length - 1;
    }

    /// Bind key to val in the innermost scope and return val.
    Value put(Value val, char[] key)
    {
        namedValues[$ - 1][key] = val;
        return val;
    }

    /// The innermost binding of key, or null if no scope binds it.
    Value find(char[] key)
    {
        foreach_reverse (map; namedValues)
            if (auto val_ptr = key in map)
                return *val_ptr;
        return null;
    }

    alias find opIndex;
    alias put opIndexAssign;
}