view gen/LLVMGen.d @ 7:2ce5209f1954

Starting to work on bool support; for now, only == works
author Anders Halager <halager@gmail.com>
date Fri, 18 Apr 2008 12:50:54 +0200
parents 2c5a8f4c254a
children 2f493057cf17
line wrap: on
line source

module gen.LLVMGen;

import tango.io.Stdout,
       Int = tango.text.convert.Integer;
import tango.core.Array : find;

import ast.Decl,
       ast.Stmt,
       ast.Exp;

import lexer.Token;

import sema.SymbolTableBuilder;

/**
 * Walks the AST and prints textual LLVM IR to Stdout.
 *
 * A SimpleSymbolTable hands out unique names for temporaries, casts,
 * labels and locals. Source-level type names and binary operators are
 * translated to their LLVM spellings through the static typeToLLVM and
 * opToLLVM tables, which are filled in by the constructor.
 */
class LLVMGen
{
public:
    /// Fill the type/operator lookup tables and create the name table.
    this()
    {
        typeToLLVM =
        [
            "int"[]  : "i32"[],
            "byte"   : "i8",
            "short"  : "i16",
            "long"   : "i64",
            "bool"   : "i1",
            "float"  : "float",
            "double" : "double",
            "void"   : "void"
        ];
        alias BinaryExp.Operator op;
        opToLLVM = [
            op.Add      : "add"[],
            op.Sub      : "sub",
            op.Mul      : "mul",
            // LLVM has no plain "div" instruction (since LLVM 2.0 it is split
            // into sdiv/udiv/fdiv). Integers are treated as signed here
            // (negation is emitted as "sub 0, x"), so use sdiv.
            op.Div      : "sdiv",
            op.Eq       : "icmp eq"
        ];
        table = new SimpleSymbolTable();
    }

    /**
     * Emit IR for a whole module.
     *
     * Runs SymbolTableBuilder over decls first ("fill in scopes"; presumably
     * this populates the env fields queried later via env.find), then emits
     * every top-level declaration inside one outer naming scope.
     */
    void gen(Decl[] decls)
    {
        // Fill in scopes
        (new SymbolTableBuilder).visit(decls);

        table.enterScope;

        foreach(decl ; decls)
                genRootDecl(decl);

        table.leaveScope;
    }

    /**
     * Emit IR for one top-level declaration.
     *
     * FuncDecl: prints a "define" header whose parameters are named %.0,
     * %.1, ... It then copies each parameter into an alloca'd slot named
     * after the source identifier, emits the body statements, and appends
     * "ret void" when the return type is void.
     *
     * VarDecl: prints an @-prefixed i32 global. Only integer-literal
     * initializers are supported; anything else trips an assert.
     *
     * Other declaration kinds are ignored.
     */
    void genRootDecl(Decl decl)
    {
        switch(decl.declType)
        {
            case DeclType.FuncDecl:
                FuncDecl funcDecl = cast(FuncDecl)decl;
                auto return_type = typeToLLVM[funcDecl.type.token.get];

                printBeginLine("define ");
                print(return_type);
                print(" @");
                genIdentifier(funcDecl.identifier);
                print("(");

                table.enterScope;
                Identifier[] args;
                foreach(i, funcArg ; funcDecl.funcArgs)
                {
                    print(typeToLLVM[funcArg.type.token.get]);
                    print(" %");
                    print("."~Integer.toString(i));
                    args ~= funcArg.identifier;
                    // Reserve the argument's name in this function's scope so
                    // later requests for the same name get a numeric suffix.
                    table.find(funcArg.identifier.get);
                    if(i+1 < funcDecl.funcArgs.length)
                        print(", ");
                }
                
                printEndLine(") {");
                
                indent;

                // Copy each incoming %.N argument into a stack slot named after
                // the source identifier, so loads/stores can address it by name.
                foreach(i, arg ; args)
                {
                    auto sym = arg.env.find(arg);
                    auto type = typeToLLVM[sym.type.get];
                    printBeginLine("%"~arg.get);
                    printEndLine(" = alloca " ~ type);
                    printBeginLine("store " ~ type ~ " %.");
                    print(Integer.toString(i));
                    print(", " ~ type ~ "* %");
                    printEndLine(arg.get);
                }

                printEndLine();

                foreach (stmt; funcDecl.statements)
                    genStmt(stmt);
                // Void functions get an implicit trailing return.
                if (return_type == "void")
                {
                    printBeginLine("ret void");
                    printEndLine();
                }
                table.leaveScope;
                dedent;
                printBeginLine("}");
                printEndLine();
                
                break;

            case DeclType.VarDecl:
                auto varDecl = cast(VarDecl)decl;
                printBeginLine("@");
                genIdentifier(varDecl.identifier);
                
                print(" = ");
                if(varDecl.init)
                {
                    if(cast(IntegerLit)varDecl.init)
                        printEndLine("global i32 " ~ (cast(IntegerLit)varDecl.init).token.get);
                    else
                        assert(0,"Declaring an variable to an expression is not allowed");
                }
                else
                {
                    // Must carry the "global" keyword like the initialized case;
                    // "@x = i32 0" is not a valid global definition.
                    printEndLine("global i32 0");
                }

                printEndLine();
                break;
        
            default:
        }
    }

    /**
     * Emit IR for a declaration inside a function body.
     *
     * VarDecl: allocates a stack slot named via the symbol table. If there is
     * an initializer, it is lowered as an assignment to the new variable.
     * Other declaration kinds are ignored.
     */
    void genDecl(Decl decl)
    {
        switch(decl.declType)
        {
            case DeclType.VarDecl:
                auto varDecl = cast(VarDecl)decl;
                printBeginLine("%");
                print(table.find(varDecl.identifier.get));
                print(" = alloca ");
                printEndLine(typeToLLVM[varDecl.type.get]);
                if(varDecl.init)
                {
                    // Reuse the assignment lowering; the synthesized node needs
                    // the declaration's env so symbol lookups succeed.
                    auto assignExp = new AssignExp(varDecl.identifier, varDecl.init);
                    assignExp.env = decl.env;
                    assignExp.identifier.env = decl.env;
                    genExpression(assignExp);
                }
                break;
        
            default:
        }
    }

    /**
     * Bring two operands to a common type.
     *
     * If the types differ, the operand whose type ranks lower in intTypes is
     * cast to the other operand's type. The cast is a zext, or a trunc when
     * both types have the same rank (e.g. neither is an integer type). The
     * casted Ref is updated in place with the new type and name.
     */
    void unify(Ref* a, Ref* b)
    {
        if (a.type != b.type)
        {
            auto a_val = intTypes.find(a.type);
            auto b_val = intTypes.find(b.type);
            // swap types so a is always the "largest" type 
            if (a_val < b_val)
            {
                Ref* tmp = b;
                b = a;
                a = tmp;
            }

            auto res = table.find("%.cast");
            printBeginLine(res);
            printCastFromTo(b, a);
            print(*b);
            print(" to ");
            printEndLine(a.type);

            b.type = a.type;
            b.name = res;
        }
    }

    /**
     * Emit IR for an expression and return a Ref naming its value.
     *
     * Integer literals are returned as atomic Refs and emit no code.
     * Identifiers are loaded from their %-named slot.
     * Assignments store into the target slot, casting first if needed, and
     * return the default (void) Ref.
     * Expression kinds not handled here also yield the default Ref.
     */
    Ref genExpression(Exp exp)
    {
        switch(exp.expType)
        {
            case ExpType.Binary:
                auto binaryExp = cast(BinaryExp)exp;

                auto left = genExpression(binaryExp.left);
                auto right = genExpression(binaryExp.right);

                unify(&left, &right);

                auto res = Ref(left.type, table.find);
                printBeginLine(res.name);
                print(" = "~opToLLVM[binaryExp.op]~" ");
                print(left);
                print(", ");
                printEndLine(right.name);

                // exp always returns known type (== returns bool no matter
                // what the params are)
                if (binaryExp.resultType)
                    res.type = typeToLLVM[binaryExp.resultType];

                return res;
            case ExpType.IntegerLit:
                auto integetLit = cast(IntegerLit)exp;
                auto t = integetLit.token;
                return Ref("int", t.get, true);
            case ExpType.Negate:
                auto negateExp = cast(NegateExp)exp;
                auto target = genExpression(negateExp.exp);
                auto res = table.find;
                printBeginLine(res);
                print(" = sub "~target.type~" 0, ");
                printEndLine(target.name);
                return Ref(target.type, res);
            case ExpType.AssignExp:
                auto assignExp = cast(AssignExp)exp;
                auto sym = exp.env.find(assignExp.identifier);

                Ref val = genExpression(assignExp.exp);
                Ref r = Ref(typeToLLVM[sym.type.get], val.name);

                // Convert the value to the variable's declared type first.
                if (val.type != r.type)
                {
                    auto res = table.find("%.cast");
                    printBeginLine(res);
                    printCastFromTo(val.type, r.type);
                    print(val);
                    print(" to ");
                    printEndLine(r.type);
                    r.name = res;
                }

                printBeginLine("store ");
                print(r);
                print(", ");
                print(r.type ~ "* %");
                printEndLine(assignExp.identifier.get);
                break;
            case ExpType.CallExp:
                auto callExp = cast(CallExp)exp;
                auto func_sym = exp.env.find(cast(Identifier)callExp.exp);
                auto func_type = typeToLLVM[func_sym.type.get];
                Ref[] args;
                foreach(i, arg ; callExp.args)
                    args ~= genExpression(arg);

                // Only non-void calls produce a named result.
                char[] res = "";
                if (func_type != "void")
                {
                    res = table.find;
                    printBeginLine(res);
                    print(" = call ");
                }
                else
                    printBeginLine("call ");

                print(func_type);
                print(" @");
                
                print(func_sym.id.get);

                print("(");
                foreach(i, arg ; args)
                {
                    print(arg);
                    if(i+1 < args.length)
                        print(", ");
                }
                printEndLine(")");
                return Ref(func_sym.type.get, res);
            case ExpType.Identifier:
                auto identifier = cast(Identifier)exp;
                auto sym = exp.env.find(identifier);
                char[] res = table.find;
                printBeginLine(res);
                print(" = load ");
                print(typeToLLVM[sym.type.get]);
                print("* %");
                printEndLine(sym.id.name);
                return Ref(sym.type.get, res);
        }
        return Ref();
    }

    /**
     * Emit IR for one statement.
     *
     * Return: evaluates the expression, casts it to the enclosing function's
     * return type if needed, then emits "ret".
     * If: compares the condition against 0 and branches to a "then" block
     * that falls through to the "else" label. Only the then-branch is
     * generated here.
     */
    void genStmt(Stmt stmt)
    {
        switch(stmt.stmtType)
        {
            case StmtType.Return:
                auto ret = cast(ReturnStmt)stmt;
                auto sym = stmt.env.parentFunction();
                auto type = typeToLLVM[sym.type.get];

                Ref res = genExpression(ret.exp);

                if (type != res.type)
                {
                    auto cast_res = table.find("%.cast");
                    printBeginLine(cast_res);
                    printCastFromTo(res.type, type);
                    print(res);
                    print(" to ");
                    printEndLine(type);
                    res.name = cast_res;
                    res.type = type;
                }
                printBeginLine("ret ");
                printEndLine(res);
                break;
            case StmtType.Decl:
                auto declStmt = cast(DeclStmt)stmt;
                genDecl(declStmt.decl);
                break;
            case StmtType.Exp:
                auto expStmt = cast(ExpStmt)stmt;
                genExpression(expStmt.exp);
                break;
            case StmtType.If:
                auto ifStmt = cast(IfStmt)stmt;
                Ref val = genExpression(ifStmt.cond);
                // br needs an i1; normalize any condition via "!= 0".
                auto cond = table.find("%.cond");
                printBeginLine(cond);
                print(" = icmp ne ");
                print(val);
                printEndLine(", 0");

                auto then_branch = table.find("then");
                auto else_branch = table.find("else");
                printBeginLine("br i1 ");
                print(cond);
                print(", label %");
                print(then_branch);
                print(", label %");
                printEndLine(else_branch);

                printBeginLine(then_branch);
                printEndLine(":");

                indent();
                foreach (s; ifStmt.then)
                    genStmt(s);
                printBeginLine("br label %");
                printEndLine(else_branch);
                dedent();

                printBeginLine(else_branch);
                printEndLine(":");

                break;
        }
    }

    /// Print an identifier's text verbatim.
    void genIdentifier(Identifier identifier)
    {
        print(identifier.get);
    }

    /// Increase the output indentation by one tabType step.
    void indent()
    {
        tabIndex ~= tabType;
    }

    /// Decrease the output indentation by one tabType step.
    void dedent()
    {
        assert(tabIndex.length >= tabType.length, "dedent without matching indent");
        tabIndex = tabIndex[0 .. $-tabType.length];
    }

    /// Start a line: current indentation followed by line (no newline).
    void printBeginLine(char[] line = "")
    {
        Stdout(tabIndex~line);
    }
    /// Start a line with "type name" of r.
    void printBeginLine(Ref r)
    {
        Stdout(tabIndex~r.type~" "~r.name);
    }

    /// Print line and terminate the current output line.
    void printEndLine(char[] line = "")
    {
        Stdout(line).newline;
    }

    /// Print "type name" of r and terminate the current output line.
    void printEndLine(Ref r)
    {
        Stdout(r.type~" "~r.name).newline;
    }

    /// Print text with no indentation and no newline.
    void print(char[] line)
    {
        Stdout(line);
    }

    /// Print "type name" of r with no indentation and no newline.
    void print(Ref r)
    {
        Stdout(r.type~" "~r.name);
    }

    /// Print " = zext " when t1 ranks below t2, otherwise " = trunc ".
    /// t1/t2 are positions in intTypes.
    void printCastFromTo(size_t t1, size_t t2)
    {
        if (t1 < t2)
            print(" = zext ");
        else
            print(" = trunc ");
    }

    /// Cast-kind selection by LLVM type name (ranked via intTypes).
    void printCastFromTo(char[] t1, char[] t2)
    {
        printCastFromTo(intTypes.find(t1), intTypes.find(t2));
    }

    /// Cast-kind selection by the types of two Refs.
    void printCastFromTo(Ref* t1, Ref* t2)
    {
        printCastFromTo(intTypes.find(t1.type), intTypes.find(t2.type));
    }

private:

    char[] tabIndex;                    // current indentation prefix
    const char[] tabType = "    "; // 4 spaces
    FuncDecl[char[]] functions;

    SimpleSymbolTable table;            // unique-name generator for emitted IR
    SymbolTable symbolTable;
    static char[][char[]] typeToLLVM;   // source type name -> LLVM type
    static char[][BinaryExp.Operator] opToLLVM; // operator -> LLVM instruction

    // LLVM integer types ordered by width; the index ranks types when
    // choosing zext vs trunc (types not listed rank past the end).
    static char[][] intTypes = [ "i1", "i8", "i16", "i32", "i64" ];
}

/**
 * A value produced by code generation: its LLVM type and the name (or
 * literal text) used to refer to it.
 */
struct Ref
{
    char[] type;
    char[] name;
    // Set to true for integer literals by LLVMGen.genExpression; presumably
    // marks a value that is not a named register (TODO confirm).
    bool atomic = false;

    /**
     * Build a Ref. type may be a source-level type name (it is translated
     * through LLVMGen.typeToLLVM) or an already-LLVM type, which is kept as-is.
     */
    static Ref opCall(char[] type = "void", char[] name = "", bool atomic = false)
    {
        Ref result;
        char[]* mapped = type in LLVMGen.typeToLLVM;
        result.type = mapped ? *mapped : type;
        result.name = name;
        result.atomic = atomic;
        return result;
    }
}

/**
 * Hands out unique names for emitted IR values, organised as a stack of scopes.
 *
 * Each scope maps a base name to the last numeric suffix handed out for it.
 */
class SimpleSymbolTable
{
    int[char[]][] variables;

    /// Push a new scope.
    void enterScope()
    {
        // Seeded with a dummy "__dollar" entry rather than left empty;
        // presumably so the AA is allocated and copies share its storage.
        int[char[]] fresh = cast(int[char[]])["__dollar":-1];
        variables ~= fresh;
    }

    /// Pop the innermost scope.
    void leaveScope()
    {
        auto depth = variables.length;
        variables.length = depth - 1;
    }

    /**
     * Return a unique name based on v.
     *
     * If v is already known in any scope (innermost first), that scope's
     * counter is bumped and "v.N" is returned. Otherwise v is registered in
     * the innermost scope and returned unchanged.
     */
    char[] find(char[] v = "%.tmp")
    {
        for (size_t level = variables.length; level > 0; level--)
        {
            if (v in variables[level - 1])
            {
                auto suffix = ++variables[level - 1][v];
                return v ~ "." ~ Integer.toString(suffix);
            }
        }
        variables[$-1][v] = 0;
        return v;
    }
}