view gen/LLVMGen.d @ 26:b4dc2b2c0e38 new_gen

Added a DType class
author Anders Halager <halager@gmail.com>
date Sat, 19 Apr 2008 22:19:14 +0200
parents 14c1abba773f
children 9031487e97d7
line wrap: on
line source

module gen.LLVMGen;

import tango.io.Stdout,
       Int = tango.text.convert.Integer;
import tango.core.Array : find;

import llvm.llvm;

import ast.Decl,
       ast.Stmt,
       ast.Exp;

import lexer.Token;

import sema.SymbolTableBuilder,
       sema.Visitor;

/// Produces the source text of a delegate literal `(Value l, Value r, char[] n)`
/// that returns `b.buildICmp(IntPredicate.<p>, l, r, n)`. The result is meant to
/// be pasted in with `mixin(...)` at a place where `b` is in scope.
///
/// Params:
///     p = name of the IntPredicate member to compare with (e.g. "EQ", "SLT")
private char[] genBuildCmp(char[] p)
{
    // Kept as one expression of adjacent literals so the text (including its
    // leading newline and indentation) matches the original template exactly.
    return "\n"
        "        (Value l, Value r, char[] n)\n"
        "        {\n"
        "            return b.buildICmp(IntPredicate." ~ p ~ ", l, r, n);\n"
        "        }";
}

/**
 * Emits LLVM IR for a list of AST declarations.
 *
 * gen() creates the module "main_module", declares every function, generates
 * code for each root declaration, runs the module verifier and writes the
 * bitcode to "test.bc".
 */
class LLVMGen
{
public:
    /// Sets up the type-name and operator lookup tables, the IR builder and
    /// the symbol table used for named values.
    this()
    {
        typeToLLVM =
        [
            "bool"[] : cast(Type) Type.Int1,
            "byte"   : Type.Int8,
            "short"  : Type.Int16,
            "int"    : Type.Int32,
            "long"   : Type.Int64,
            "float"  : Type.Float,
            "double" : Type.Double,
            "void"   : Type.Void
        ];
        alias BinaryExp.Operator op;
        b = new Builder;

        // Every entry is a delegate that goes through this instance's builder
        // `b`; comparisons are signed integer compares.
        opToLLVM = [
            op.Add      : &b.buildAdd,
            op.Sub      : &b.buildSub,
            op.Mul      : &b.buildMul,
            op.Div      : &b.buildSDiv,
            op.Eq       : mixin(genBuildCmp("EQ")),
            op.Ne       : mixin(genBuildCmp("NE")),
            op.Lt       : mixin(genBuildCmp("SLT")),
            op.Le       : mixin(genBuildCmp("SLE")),
            op.Gt       : mixin(genBuildCmp("SGT")),
            op.Ge       : mixin(genBuildCmp("SGE"))
        ];
        table = new SimpleSymbolTable();
    }

    /// Releases the IR builder.
    ~this()
    {
        b.dispose();
    }

    /**
     * Generates the whole module for `decls`.
     *
     * The verifier message is written to Stderr and the bitcode is written to
     * "test.bc". The module is disposed when this method exits.
     */
    void gen(Decl[] decls)
    {
        // create module
        m = new Module("main_module");
        scope(exit) m.dispose();

        table.enterScope;

        // Declare every function in the module before any body is generated;
        // genRootDecl and call expressions look functions up by name.
        auto registerFunc =
            (FuncDecl funcDecl)
            {
                Type[] param_types;
                foreach (param; funcDecl.funcArgs)
                    param_types ~= typeToLLVM[param.type.get];
                auto ret_t = typeToLLVM[funcDecl.type.get];
                auto func_t = FunctionType.Get(ret_t, param_types);
                m.addFunction(func_t, funcDecl.identifier.get);
            };
        auto visitor = new VisitFuncDecls(registerFunc);
        visitor.visit(decls);

        foreach (decl; decls)
            genRootDecl(decl);

        table.leaveScope;

        char[] err;
        m.verify(err);
        Stderr(err).newline;

        //m.optimize(false);

        m.writeBitcodeToFile("test.bc");
    }

    /**
     * Generates code for a module-level declaration.
     *
     * Functions get an "entry" block, stack slots for their parameters and
     * their statements; variables become zero-initialised globals. Other
     * declaration kinds are ignored.
     */
    void genRootDecl(Decl decl)
    {
        switch(decl.declType)
        {
            case DeclType.FuncDecl:
                FuncDecl funcDecl = cast(FuncDecl)decl;

                auto llfunc = m.getNamedFunction(funcDecl.identifier.get);
                auto ret_t = typeToLLVM[funcDecl.type.get];

                auto bb = llfunc.appendBasicBlock("entry");
                b.positionAtEnd(bb);

                table.enterScope;
                foreach (i, v; funcDecl.funcArgs)
                {
                    // Copy each parameter into an alloca so it is looked up
                    // and loaded/stored like any other local.
                    auto name = v.identifier.get;
                    llfunc.getParam(i).name = name;
                    auto AI = b.buildAlloca(llfunc.getParam(i).type, name);
                    b.buildStore(llfunc.getParam(i), AI);
                    table[name] = AI;
                }

                foreach (stmt; funcDecl.statements)
                    genStmt(stmt);

                // if the function didn't end with a return, we automatically
                // add one (return 0 as default)
                if (b.getInsertBlock().terminated() is false)
                {
                    if (ret_t is Type.Void)
                        b.buildRetVoid();
                    else
                        b.buildRet(ConstantInt.GetS(ret_t, 0));
                }

                table.leaveScope;
                break;

            case DeclType.VarDecl:
                auto varDecl = cast(VarDecl)decl;
                Type t = typeToLLVM[varDecl.type.get];
                GlobalVariable g = m.addGlobal(t, varDecl.identifier.get);
                g.initializer = ConstantInt.GetS(t, 0);
                table[varDecl.identifier.get] = g;
                break;

            default:
                break;
        }
    }

    /**
     * Generates code for a declaration inside a function body.
     *
     * Only variable declarations are handled: an alloca is created and
     * registered in the symbol table, and the initialiser (if any) is stored.
     */
    void genDecl(Decl decl)
    {
        switch(decl.declType)
        {
            case DeclType.VarDecl:
                auto varDecl = cast(VarDecl)decl;
                Type t = typeToLLVM[varDecl.type.get];
                auto name = varDecl.identifier.get;
                auto AI = b.buildAlloca(t, name);
                table[name] = AI;
                if (varDecl.init)
                    buildAssign(varDecl.env.find(varDecl.identifier), varDecl.init);
                break;

            default:
                break;
        }
    }

    /**
     * Makes `left` and `right` the same integer type by sign-extending the
     * narrower one to the wider one's type.
     *
     * Throws: Exception when the types differ and either is not an IntegerType.
     */
    void sextSmallerToLarger(ref Value left, ref Value right)
    {
        if (left.type != right.type)
        {
            // try to find a conversion - only works for iX
            IntegerType l = cast(IntegerType) left.type;
            IntegerType r = cast(IntegerType) right.type;
            if (l is null || r is null)
                throw new Exception(
                        "Can't find a valid convertion between "
                        "a " ~ left.type.toString ~ " and a "
                        ~ right.type.toString);

            if (l.numBits() < r.numBits())
                left = b.buildSExt(left, r, ".cast");
            else
                right = b.buildSExt(right, l, ".cast");
        }
    }

    /**
     * Generates the IR for `exp` and returns the resulting value.
     *
     * Handles binary, integer-literal, negate, assign, call and identifier
     * expressions.
     */
    Value genExpression(Exp exp)
    {
        switch(exp.expType)
        {
            case ExpType.Binary:
                auto binaryExp = cast(BinaryExp)exp;

                auto left = genExpression(binaryExp.left);
                auto right = genExpression(binaryExp.right);

                sextSmallerToLarger(left, right);

                OpBuilder op = opToLLVM[binaryExp.op];

                return op(left, right, ".");

            case ExpType.IntegerLit:
                auto integerLit = cast(IntegerLit)exp;
                auto val = integerLit.token.get;
                // `Int` is the alias this module gives tango.text.convert.Integer.
                return ConstantInt.GetS(Type.Int32, Int.parse(val));
            case ExpType.Negate:
                auto negateExp = cast(NegateExp)exp;
                auto target = genExpression(negateExp.exp);
                return b.buildNeg(target, "neg");
            case ExpType.AssignExp:
                auto assignExp = cast(AssignExp)exp;
                return buildAssign(exp.env.find(assignExp.identifier), assignExp.exp);
            case ExpType.CallExp:
                auto callExp = cast(CallExp)exp;
                auto func_sym = exp.env.find(cast(Identifier)callExp.exp);
                Value[] args;
                foreach (arg; callExp.args)
                    args ~= genExpression(arg);
                return b.buildCall(m.getNamedFunction(func_sym.id.get), args, ".call");
            case ExpType.Identifier:
                auto identifier = cast(Identifier)exp;
                auto sym = exp.env.find(identifier);
                return b.buildLoad(table.find(sym.id.get), sym.id.get);
        }
        assert(0, "Reached end of switch in genExpression");
        return null;
    }

    /**
     * Generates the IR for a statement: return, declaration, expression,
     * if (with optional else) and while.
     */
    void genStmt(Stmt stmt)
    {
        switch(stmt.stmtType)
        {
            case StmtType.Return:
                auto ret = cast(ReturnStmt)stmt;
                auto sym = stmt.env.parentFunction();
                Type t = typeToLLVM[sym.type.name];
                Value v = genExpression(ret.exp);
                // Fit the value to the enclosing function's return type.
                v = coerceInt(v, t,
                        "Inappropriate assignment"
                        ", types dont match");
                b.buildRet(v);
                break;
            case StmtType.Decl:
                auto declStmt = cast(DeclStmt)stmt;
                genDecl(declStmt.decl);
                break;
            case StmtType.Exp:
                auto expStmt = cast(ExpStmt)stmt;
                genExpression(expStmt.exp);
                break;
            case StmtType.If:
                auto ifStmt = cast(IfStmt)stmt;
                Value cond = toBool(genExpression(ifStmt.cond));
                auto func_name = stmt.env.parentFunction().id.get;
                Function func = m.getNamedFunction(func_name);
                bool has_else = ifStmt.else_body.length > 0;

                auto thenBB = func.appendBasicBlock("then");
                auto elseBB = has_else? func.appendBasicBlock("else") : null;
                auto mergeBB = func.appendBasicBlock("merge");

                b.buildCondBr(cond, thenBB, has_else? elseBB : mergeBB);
                b.positionAtEnd(thenBB);
                foreach (s; ifStmt.then_body)
                    genStmt(s);
                // Only fall through to the merge block when the body did not
                // already end the block (e.g. with a return).
                if (b.getInsertBlock().terminated() is false)
                    b.buildBr(mergeBB);

                if (has_else)
                {
                    b.positionAtEnd(elseBB);
                    foreach (s; ifStmt.else_body)
                        genStmt(s);
                    // Same guard as the then-branch: never emit a branch
                    // after an existing terminator.
                    if (b.getInsertBlock().terminated() is false)
                        b.buildBr(mergeBB);
                }

                b.positionAtEnd(mergeBB);
                break;
            case StmtType.While:
                auto wStmt = cast(WhileStmt)stmt;
                auto func_name = stmt.env.parentFunction().id.get;
                Function func = m.getNamedFunction(func_name);

                auto condBB = func.appendBasicBlock("cond");
                auto bodyBB = func.appendBasicBlock("body");
                auto doneBB = func.appendBasicBlock("done");

                b.buildBr(condBB);
                b.positionAtEnd(condBB);
                Value cond = toBool(genExpression(wStmt.cond));
                b.buildCondBr(cond, bodyBB, doneBB);

                b.positionAtEnd(bodyBB);
                foreach (s; wStmt.stmts)
                    genStmt(s);
                // Loop back to re-evaluate the condition unless the body
                // already terminated its block.
                if (b.getInsertBlock().terminated() is false)
                    b.buildBr(condBB);

                b.positionAtEnd(doneBB);
                break;
        }
    }

    /**
     * Evaluates `exp`, fits it to the type of `sym` and stores it into the
     * value registered for the symbol's name.
     *
     * Returns: the value produced by the builder's store call.
     * Throws: Exception when the types differ and either is not an IntegerType.
     */
    private Value buildAssign(Symbol sym, Exp exp)
    {
        Type t = typeToLLVM[sym.type.name];
        auto name = sym.id.get;
        auto AI = table.find(name);
        Value v = genExpression(exp);
        v = coerceInt(v, t, "Inappropriate assignment, types don't match");
        return b.buildStore(v, AI);
    }

    /**
     * Returns `v` unchanged when it already has type `t`; otherwise
     * sign-extends (narrower source) or truncates it to `t`.
     *
     * Throws: Exception carrying `msg` when the types differ and either of
     *         them is not an IntegerType.
     */
    private Value coerceInt(Value v, Type t, char[] msg)
    {
        if (v.type != t)
        {
            IntegerType v_t = cast(IntegerType) v.type;
            IntegerType i_t = cast(IntegerType) t;
            if (v_t is null || i_t is null)
                throw new Exception(msg);

            if (v_t.numBits() < i_t.numBits())
                v = b.buildSExt(v, t, ".cast");
            else
                v = b.buildTrunc(v, t, ".cast");
        }
        return v;
    }

    /// Turns a condition value into an i1 by comparing it against zero,
    /// unless it already is an i1.
    private Value toBool(Value cond)
    {
        if (cond.type !is Type.Int1)
        {
            Value False = ConstantInt.GetS(cond.type, 0);
            cond = b.buildICmp(IntPredicate.NE, cond, False, ".cond");
        }
        return cond;
    }

private:

    // llvm stuff
    Module m;
    Builder b;

    FuncDecl[char[]] functions;

    // Named values (allocas and globals), keyed by identifier.
    SimpleSymbolTable table;
    // Source-level type name -> LLVM type.
    static Type[char[]] typeToLLVM;
    alias Value delegate(Value, Value, char[]) OpBuilder;
    // Per instance: the delegates are bound to this instance's builder `b`.
    OpBuilder[BinaryExp.Operator] opToLLVM;
}

/// Visitor that hands every FuncDecl found in a declaration list to a callback.
private class VisitFuncDecls : Visitor!(void)
{
    /// Callback invoked once per function declaration.
    void delegate(FuncDecl) dg;

    this(void delegate(FuncDecl funcDecl) dg)
    {
        this.dg = dg;
    }

    /// Calls `dg` for each element of `decls` that is a FuncDecl, in order;
    /// all other declarations are skipped.
    override void visit(Decl[] decls)
    {
        for (size_t idx = 0; idx < decls.length; ++idx)
        {
            auto func = cast(FuncDecl) decls[idx];
            if (func is null)
                continue;
            dg(func);
        }
    }
}

/// A stack of name -> Value maps; lookups search from the innermost scope
/// outwards.
private class SimpleSymbolTable
{
    /// One map per open scope; the last element is the innermost scope.
    Value[char[]][] namedValues;

    /// Opens a new innermost scope.
    void enterScope()
    {
        // The map is seeded with a placeholder "__dollar" entry whose value
        // is null.
        Value[char[]] seeded = cast(Value[char[]])["__dollar":null];
        namedValues ~= seeded;
    }

    /// Drops the innermost scope.
    void leaveScope()
    {
        auto remaining = namedValues.length - 1;
        namedValues.length = remaining;
    }

    /// Binds `key` to `val` in the innermost scope and returns `val`.
    Value put(Value val, char[] key)
    {
        auto innermost = namedValues.length - 1;
        namedValues[innermost][key] = val;
        return val;
    }

    /// Returns the value bound to `key` in the nearest enclosing scope, or
    /// null when no scope contains it.
    Value find(char[] key)
    {
        for (size_t depth = namedValues.length; depth > 0; --depth)
        {
            Value* hit = key in namedValues[depth - 1];
            if (hit !is null)
                return *hit;
        }
        return null;
    }

    // table[key] reads through find; table[key] = val writes through put.
    alias find opIndex;
    alias put opIndexAssign;
}