Mercurial > projects > dang
view gen/LLVMGen.d @ 16:bd5f9f81c24b
Oops... while loops were generating incorrect code.
Need to add a "br label %while_cond" just before the while_cond label.
author | Anders Halager <halager@gmail.com> |
---|---|
date | Fri, 18 Apr 2008 15:53:30 +0200 |
parents | e5caf9971207 |
children | 7e79c42d20f5 |
line wrap: on
line source
module gen.LLVMGen;

import tango.io.Stdout,
       Int = tango.text.convert.Integer;
import tango.core.Array : find;

import ast.Decl,
       ast.Stmt,
       ast.Exp;

import lexer.Token;

import sema.SymbolTableBuilder;

/**
 * Walks the AST and writes textual LLVM IR to Stdout.
 *
 * Fresh SSA value names and labels are obtained from a SimpleSymbolTable,
 * which appends ".N" suffixes to keep repeated names unique.
 */
class LLVMGen
{
public:
    /**
     * Fills the (static) source-type -> LLVM-type and operator -> LLVM
     * instruction tables and creates the name table.
     *
     * NOTE(review): the tables are static but are (re)assigned on every
     * construction; Ref.opCall reads typeToLLVM, so it is empty until the
     * first LLVMGen is constructed.
     */
    this()
    {
        typeToLLVM =
        [
            "int"[] : "i32"[],
            "byte"   : "i8",
            "short"  : "i16",
            "long"   : "i64",
            "bool"   : "i1",
            "float"  : "float",
            "double" : "double",
            "void"   : "void"
        ];
        alias BinaryExp.Operator op;
        // Integers are treated as signed (see the slt/sle predicates), so
        // division and the ordering comparisons use the signed forms.
        // LLVM has no plain "div" instruction and no "glt"/"gle" predicates.
        opToLLVM =
        [
            op.Add : "add"[],
            op.Sub : "sub",
            op.Mul : "mul",
            op.Div : "sdiv",
            op.Eq  : "icmp eq",
            op.Ne  : "icmp ne",
            op.Lt  : "icmp slt",
            op.Le  : "icmp sle",
            op.Gt  : "icmp sgt",
            op.Ge  : "icmp sge"
        ];
        table = new SimpleSymbolTable();
    }

    /**
     * Generates IR for every top-level declaration inside one outer
     * name scope.
     */
    void gen(Decl[] decls)
    {
        // Fill in scopes
        table.enterScope;

        foreach(decl ; decls)
            genRootDecl(decl);

        table.leaveScope;
    }

    /**
     * Generates IR for a top-level declaration: a function definition
     * ("define ...") or a global variable ("@name = global ...").
     * Other declaration kinds produce no output.
     */
    void genRootDecl(Decl decl)
    {
        switch(decl.declType)
        {
            case DeclType.FuncDecl:
                FuncDecl funcDecl = cast(FuncDecl)decl;
                auto return_type = typeToLLVM[funcDecl.type.token.get];

                printBeginLine("define ");
                print(return_type);
                print(" @");
                genIdentifier(funcDecl.identifier);
                print("(");

                table.enterScope;
                Identifier[] args;
                // Incoming parameters are named %.0, %.1, ...; each one is
                // copied into an alloca named after the source identifier
                // further down.
                foreach(i, funcArg ; funcDecl.funcArgs)
                {
                    print(typeToLLVM[funcArg.type.token.get]);
                    print(" %");
                    print("."~Int.toString(i));
                    args ~= funcArg.identifier;
                    // Reserve the argument name in the function scope.
                    table.find(funcArg.identifier.get);
                    if(i+1 < funcDecl.funcArgs.length)
                        print(", ");
                }

                printEndLine(") {");

                indent;

                foreach(i, arg ; args)
                {
                    auto sym = arg.env.find(arg);
                    auto type = typeToLLVM[sym.type.get];
                    printBeginLine("%"~arg.get);
                    printEndLine(" = alloca " ~ type);
                    printBeginLine("store " ~ type ~ " %.");
                    print(Int.toString(i));
                    print(", " ~ type ~ "* %");
                    printEndLine(arg.get);
                }

                printEndLine();

                foreach (stmt; funcDecl.statements)
                    genStmt(stmt);

                // void functions get an implicit trailing "ret void".
                if (return_type == "void")
                {
                    printBeginLine("ret void");
                    printEndLine();
                }
                table.leaveScope;
                dedent;
                printBeginLine("}");
                printEndLine();
                break;

            case DeclType.VarDecl:
                auto varDecl = cast(VarDecl)decl;
                printBeginLine("@");
                genIdentifier(varDecl.identifier);
                print(" = ");
                if(varDecl.init)
                {
                    // Only integer-literal initializers are supported.
                    if(cast(IntegerLit)varDecl.init)
                        printEndLine("global i32 " ~
                            (cast(IntegerLit)varDecl.init).token.get);
                    else
                        assert(0,"Declaring an variable to an expression is not allowed");
                }
                else
                    // The "global" keyword is required for a valid global
                    // definition; it was previously missing here.
                    printEndLine("global i32 0");
                printEndLine();
                break;

            default:
        }
    }

    /**
     * Generates IR for a declaration inside a function body: an alloca
     * for a local variable, followed by a store of its initializer if one
     * is present.
     */
    void genDecl(Decl decl)
    {
        switch(decl.declType)
        {
            case DeclType.VarDecl:
                auto varDecl = cast(VarDecl)decl;
                printBeginLine("%");
                // NOTE(review): if the name is already taken this prints a
                // suffixed name ("x.1"), but stores/loads below use the raw
                // identifier -- confirm redeclaration is rejected earlier.
                print(table.find(varDecl.identifier.get));
                print(" = alloca ");
                printEndLine(typeToLLVM[varDecl.type.get]);
                if(varDecl.init)
                {
                    auto assignExp = new AssignExp(varDecl.identifier, varDecl.init);
                    assignExp.env = decl.env;
                    assignExp.identifier.env = decl.env;
                    genExpression(assignExp);
                }
                break;

            default:
        }
    }

    /**
     * Makes two operands have the same type by emitting a cast of the
     * operand whose type has the lower index in intTypes.
     * The casted operand (through its pointer) is updated to the new
     * value name and type.
     */
    void unify(Ref* a, Ref* b)
    {
        if (a.type != b.type)
        {
            auto a_val = intTypes.find(a.type);
            auto b_val = intTypes.find(b.type);
            // swap types so a is always the "largest" type
            if (a_val < b_val)
            {
                Ref* tmp = b;
                b = a;
                a = tmp;
            }

            auto res = table.find("%.cast");
            printBeginLine(res);
            printCastFromTo(b, a);
            print(*b);
            print(" to ");
            printEndLine(a.type);

            b.type = a.type;
            b.name = res;
        }
    }

    /**
     * Generates IR for an expression.
     *
     * Returns: a Ref naming the value that holds the result. Integer
     * literals are returned directly, without emitting any instruction.
     * Assignments return an empty Ref().
     */
    Ref genExpression(Exp exp)
    {
        switch(exp.expType)
        {
            case ExpType.Binary:
                auto binaryExp = cast(BinaryExp)exp;

                auto left = genExpression(binaryExp.left);
                auto right = genExpression(binaryExp.right);

                unify(&left, &right);

                auto res = Ref(left.type, table.find);
                printBeginLine(res.name);
                print(" = "~opToLLVM[binaryExp.op]~" ");
                print(left);
                print(", ");
                printEndLine(right.name);

                // exp always returns known type (== returns bool no matter
                // what the params are)
                if (binaryExp.resultType)
                    res.type = typeToLLVM[binaryExp.resultType];

                return res;

            case ExpType.IntegerLit:
                auto integetLit = cast(IntegerLit)exp;
                auto t = integetLit.token;
                return Ref("int", t.get, true);

            case ExpType.Negate:
                auto negateExp = cast(NegateExp)exp;
                auto target = genExpression(negateExp.exp);
                auto res = table.find;
                printBeginLine(res);
                // Negation is emitted as 0 - x.
                print(" = sub "~target.type~" 0, ");
                printEndLine(target.name);
                return Ref(target.type, res);

            case ExpType.AssignExp:
                auto assignExp = cast(AssignExp)exp;
                auto sym = exp.env.find(assignExp.identifier);

                Ref val = genExpression(assignExp.exp);
                Ref r = Ref(typeToLLVM[sym.type.get], val.name);

                // Cast the value to the variable's type when they differ.
                if (val.type != r.type)
                {
                    auto res = table.find("%.cast");
                    printBeginLine(res);
                    printCastFromTo(val.type, r.type);
                    print(val);
                    print(" to ");
                    printEndLine(r.type);
                    r.name = res;
                }

                printBeginLine("store ");
                print(r);
                print(", ");
                print(r.type ~ "* %");
                printEndLine(assignExp.identifier.get);
                break;

            case ExpType.CallExp:
                auto callExp = cast(CallExp)exp;
                auto func_sym = exp.env.find(cast(Identifier)callExp.exp);
                auto func_type = typeToLLVM[func_sym.type.get];
                Ref[] args;
                foreach(i, arg ; callExp.args)
                    args ~= genExpression(arg);
                // Calls to void functions produce no named result.
                char[] res = "";
                if (func_type != "void")
                {
                    res = table.find;
                    printBeginLine(res);
                    print(" = call ");
                }
                else
                    printBeginLine("call ");
                print(func_type);
                print(" @");

                print(func_sym.id.get);

                print("(");
                foreach(i, arg ; args)
                {
                    print(arg);
                    if(i+1 < args.length)
                        print(", ");
                }
                printEndLine(")");
                return Ref(func_sym.type.get, res);

            case ExpType.Identifier:
                auto identifier = cast(Identifier)exp;
                auto sym = exp.env.find(identifier);
                char[] res = table.find;
                printBeginLine(res);
                print(" = load ");
                print(typeToLLVM[sym.type.get]);
                // NOTE(review): always loads from a local (%) slot; globals
                // are emitted as @name by genRootDecl -- confirm intended.
                print("* %");
                printEndLine(sym.id.name);
                return Ref(sym.type.get, res);
        }
        return Ref();
    }

    /**
     * Generates IR for a statement: return, local declaration,
     * expression statement, if/else, or while.
     */
    void genStmt(Stmt stmt)
    {
        switch(stmt.stmtType)
        {
            case StmtType.Return:
                auto ret = cast(ReturnStmt)stmt;
                auto sym = stmt.env.parentFunction();
                auto type = typeToLLVM[sym.type.get];
                Ref res = genExpression(ret.exp);
                // Cast the returned value to the function's return type.
                if (type != res.type)
                {
                    auto cast_res = table.find("%.cast");
                    printBeginLine(cast_res);
                    printCastFromTo(res.type, type);
                    print(res);
                    print(" to ");
                    printEndLine(type);
                    res.name = cast_res;
                    res.type = type;
                }
                printBeginLine("ret ");
                printEndLine(res);
                break;

            case StmtType.Decl:
                auto declStmt = cast(DeclStmt)stmt;
                genDecl(declStmt.decl);
                break;

            case StmtType.Exp:
                auto expStmt = cast(ExpStmt)stmt;
                genExpression(expStmt.exp);
                break;

            case StmtType.If:
                auto ifStmt = cast(IfStmt)stmt;
                Ref val = genExpression(ifStmt.cond);
                auto cond = table.find("%.cond");
                printBeginLine(cond);
                print(" = icmp ne ");
                print(val);
                printEndLine(", 0");

                auto then_branch = table.find("then");
                auto else_branch = table.find("else");
                auto done_label = table.find("done");
                printBeginLine("br i1 ");
                print(cond);
                print(", label %");
                print(then_branch);
                print(", label %");
                // Without an else body, a false condition goes straight to done.
                printEndLine(ifStmt.else_body? else_branch : done_label);

                printBeginLine(then_branch);
                printEndLine(":");

                indent();
                foreach (s; ifStmt.then_body)
                    genStmt(s);
                printBeginLine("br label %");
                printEndLine(done_label);
                dedent();

                if (ifStmt.else_body)
                {
                    printBeginLine(else_branch);
                    printEndLine(":");

                    indent();
                    foreach (s; ifStmt.else_body)
                        genStmt(s);
                    printBeginLine("br label %");
                    printEndLine(done_label);
                    dedent();
                }

                printBeginLine(done_label);
                printEndLine(":");

                break;

            case StmtType.While:
                auto wStmt = cast(WhileStmt)stmt;
                auto body_label = table.find("while_body");
                auto cond_label = table.find("while_cond");
                auto done_label = table.find("while_done");

                // The preceding block must end with a terminator before the
                // condition label starts a new basic block.
                printBeginLine("br label %");
                printEndLine(cond_label);

                printBeginLine(cond_label);
                printEndLine(":");

                indent();

                Ref val = genExpression(wStmt.cond);
                auto cond = table.find("%.cond");
                printBeginLine(cond);
                print(" = icmp ne ");
                print(val);
                printEndLine(", 0");

                printBeginLine("br i1 ");
                print(cond);
                print(", label %");
                print(body_label);
                print(", label %");
                printEndLine(done_label);

                dedent();

                printBeginLine(body_label);
                printEndLine(":");

                indent();
                foreach (s; wStmt.stmts)
                    genStmt(s);

                // Loop back to re-evaluate the condition.
                printBeginLine("br label %");
                printEndLine(cond_label);

                dedent();

                printBeginLine(done_label);
                printEndLine(":");

                break;
        }
    }

    /// Prints an identifier's text.
    void genIdentifier(Identifier identifier)
    {
        print(identifier.get);
    }

    /// Increases the indentation used by printBeginLine by one tabType.
    void indent()
    {
        tabIndex ~= tabType;
    }

    /// Decreases the indentation by one tabType; must follow an indent().
    void dedent()
    {
        tabIndex = tabIndex[0 .. $-tabType.length];
    }

    /// Starts an output line: the current indentation followed by line.
    void printBeginLine(char[] line = "")
    {
        Stdout(tabIndex~line);
    }

    /// Starts an output line with "type name" of r.
    void printBeginLine(Ref r)
    {
        Stdout(tabIndex~r.type~" "~r.name);
    }

    /// Prints line followed by a newline.
    void printEndLine(char[] line = "")
    {
        Stdout(line).newline;
    }

    /// Prints "type name" of r followed by a newline.
    void printEndLine(Ref r)
    {
        Stdout(r.type~" "~r.name).newline;
    }

    /// Prints line without indentation or newline.
    void print(char[] line)
    {
        Stdout(line);
    }

    /// Prints "type name" of r without indentation or newline.
    void print(Ref r)
    {
        Stdout(r.type~" "~r.name);
    }

    /**
     * Prints the cast opcode for converting from the intTypes entry at
     * index t1 to the one at index t2.
     *
     * Widening uses sext, because integers are treated as signed (the
     * comparisons use signed predicates). The exception is i1 (index 0),
     * which uses zext so that true widens to 1 rather than -1.
     * Otherwise the value is truncated.
     *
     * NOTE(review): types missing from intTypes (e.g. float) all map to
     * intTypes.length, so such pairs fall into the trunc branch.
     */
    void printCastFromTo(size_t t1, size_t t2)
    {
        if (t1 < t2)
            print(t1 == 0 ? " = zext " : " = sext ");
        else
            print(" = trunc ");
    }

    /// ditto, taking LLVM type names.
    void printCastFromTo(char[] t1, char[] t2)
    {
        printCastFromTo(intTypes.find(t1), intTypes.find(t2));
    }

    /// ditto, taking the types of two Refs.
    void printCastFromTo(Ref* t1, Ref* t2)
    {
        printCastFromTo(intTypes.find(t1.type), intTypes.find(t2.type));
    }

private:
    char[] tabIndex;
    const char[] tabType = "    "; // 4 spaces
    FuncDecl[char[]] functions;
    SimpleSymbolTable table;
    SymbolTable symbolTable;
    static char[][char[]] typeToLLVM;
    static char[][BinaryExp.Operator] opToLLVM;
    // Integer types ordered from narrowest to widest; the index is used to
    // decide the direction of integer casts.
    static char[][] intTypes = [ "i1", "i8", "i16", "i32", "i64" ];
}

/**
 * A generated value: its LLVM type and its name (or literal text).
 */
struct Ref
{
    char[] type;
    char[] name;
    // Set for integer literals; NOTE(review): not read in this module.
    bool atomic = false;

    /**
     * Builds a Ref. Source-language type names found in
     * LLVMGen.typeToLLVM are translated to their LLVM types; other type
     * strings are used as-is.
     */
    static Ref opCall(char[] type = "void", char[] name = "", bool atomic = false)
    {
        Ref r;
        if(auto llvm_t = type in LLVMGen.typeToLLVM)
            r.type = *llvm_t;
        else
            r.type = type;
        r.name = name;
        r.atomic = atomic;
        return r;
    }
}

/**
 * Hands out unique names within a stack of scopes.
 */
class SimpleSymbolTable
{
    int[char[]][] variables;

    /// Pushes a new scope.
    void enterScope()
    {
        // NOTE(review): each scope starts with a dummy "__dollar" entry;
        // its purpose is not visible in this module.
        variables ~= cast(int[char[]])["__dollar":-1];
    }

    /// Pops the innermost scope.
    void leaveScope()
    {
        variables.length = variables.length - 1;
    }

    /**
     * Returns a unique name based on v.
     *
     * If v is already known in any enclosing scope (searched innermost
     * first), that scope's counter for v is incremented and "v.N" is
     * returned. Otherwise v is recorded in the innermost scope and
     * returned unchanged.
     */
    char[] find(char[] v = "%.tmp")
    {
        foreach_reverse(map ; variables)
        {
            if(v in map)
                return v~"."~Int.toString(++map[v]);
        }
        variables[$-1][v] = 0;
        return v;
    }
}