view gen/LLVMGen.d @ 39:1a7a308f75b2 new_gen

Added some struct tests, and implemented a wrong struct assignment: it assumes 8 bytes for every struct, because we have no DType available at that point. Slight improvement to an error message (member access to unknown members).
author Anders Halager <halager@gmail.com>
date Mon, 21 Apr 2008 22:47:12 +0200
parents 858b9805843d
children 9fb190ad81a4
line wrap: on
line source

module gen.LLVMGen;

import tango.io.Stdout,
       Int = tango.text.convert.Integer;
import tango.core.Array : find;

import llvm.llvm;

import ast.Decl,
       ast.Stmt,
       ast.Exp;

import misc.Error;

import lexer.Token;

import sema.SymbolTableBuilder,
       sema.Visitor;

/**
 * Builds the source text of a delegate literal, meant to be mixed in inside
 * LLVMGen, that forwards its operands to b.buildICmp with a fixed predicate.
 *
 * Params:
 *     p = name of the IntPredicate member to compare with, e.g. "EQ" or "SLT"
 *
 * Returns: the delegate literal as a string
 */
private char[] genBuildCmp(char[] p)
{
    // Everything up to and including the `IntPredicate.` qualifier.
    char[] head = `
        (Value l, Value r, char[] n)
        {
            return b.buildICmp(IntPredicate.`;
    // The remaining call arguments and the closing brace of the literal.
    char[] tail = `, l, r, n);
        }`;
    return head ~ p ~ tail;
}

class LLVMGen
{
public:
    /**
     * Creates the instruction builder, fills the operator dispatch table and
     * allocates an empty symbol table.
     *
     * NOTE(review): opToLLVM is declared `static` in this class, yet every
     * construction overwrites it with delegates bound to this instance's
     * builder `b` - confirm that only one LLVMGen is alive at a time.
     */
    this()
    {
        // Short local name for the operator enum used as the table's key.
        alias BinaryExp.Operator op;
        b = new Builder;
        
        // Arithmetic operators map directly onto Builder methods (Div uses
        // buildSDiv); comparisons use delegate literals generated by
        // genBuildCmp, which call b.buildICmp with the named predicate.
        opToLLVM = [
            op.Add      : &b.buildAdd,
            op.Sub      : &b.buildSub,
            op.Mul      : &b.buildMul,
            op.Div      : &b.buildSDiv,
            op.Eq       : mixin(genBuildCmp("EQ")),
            op.Ne       : mixin(genBuildCmp("NE")),
            op.Lt       : mixin(genBuildCmp("SLT")),
            op.Le       : mixin(genBuildCmp("SLE")),
            op.Gt       : mixin(genBuildCmp("SGT")),
            op.Ge       : mixin(genBuildCmp("SGE"))
        ];
        table = new SimpleSymbolTable();
    }

    /**
     * Releases the instruction builder by calling its dispose() method.
     *
     * NOTE(review): if this destructor is run by the garbage collector, `b`
     * may already have been finalized - TODO confirm dispose() is safe then.
     */
    ~this()
    {
        b.dispose();
    }

    /**
     * Generates LLVM code for a whole program.
     *
     * Creates the module "main_module", declares the llvm.memcpy.i32
     * intrinsic, registers a prototype for every function declaration and
     * then generates all root declarations. Afterwards the module is
     * verified (the verifier message is printed to Stderr) and written to
     * "test.bc". The module is disposed when this method exits.
     *
     * Params:
     *     decls = the root declarations of the program
     */
    void gen(Decl[] decls)
    {
        m = new Module("main_module");
        scope(exit) m.dispose();

        table.enterScope;

        // Declare the intrinsic used for struct assignment:
        // void llvm.memcpy.i32(i8*, i8*, i32, i32)
        BytePtr = PointerType.Get(Type.Int8);
        auto memcpy_t = FunctionType.Get(Type.Void, [BytePtr, BytePtr, Type.Int32, Type.Int32]);
        llvm_memcpy = m.addFunction(memcpy_t, "llvm.memcpy.i32");

        // Prototypes are added before any body is generated; the bodies
        // look the functions up again by name.
        auto declareFunction =
            (FuncDecl fd)
            {
                Type[] arg_types;
                foreach (arg; fd.funcArgs)
                {
                    DType arg_dtype = arg.env.find(arg.identifier).type;
                    Type arg_type = arg_dtype.llvm();
                    // struct arguments are passed as a pointer to the struct
                    if (cast(DStruct)arg_dtype)
                        arg_type = PointerType.Get(arg_type);
                    arg_types ~= arg_type;
                }
                auto return_type = fd.env.find(fd.identifier).type.llvm();
                auto proto_t = FunctionType.Get(return_type, arg_types);
                m.addFunction(proto_t, fd.identifier.get);
            };
        auto prototypes = new VisitFuncDecls(declareFunction);
        prototypes.visit(decls);

        foreach (root; decls)
            genRootDecl(root);

        table.leaveScope;

        char[] verify_msg;
        m.verify(verify_msg);
        Stderr(verify_msg).newline;

        // m.optimize(true);

        m.writeBitcodeToFile("test.bc");
    }

    /**
     * Generates code for one declaration at module level.
     *
     * Functions get their body generated into the prototype registered by
     * gen(), variables become zero-initialized globals, and structs are
     * registered in the module as a named type. Other declarations are
     * ignored.
     *
     * Params:
     *     decl = the root declaration to generate
     */
    void genRootDecl(Decl decl)
    {
        switch(decl.declType)
        {
            case DeclType.FuncDecl:
            {
                FuncDecl funcDecl = cast(FuncDecl)decl;

                // Recover the return type from the function's
                // pointer-to-function type.
                auto llfunc = m.getNamedFunction(funcDecl.identifier.get);
                auto fn_ptr_t = cast(PointerType)llfunc.type;
                auto fn_t = cast(FunctionType)fn_ptr_t.elementType();
                auto ret_t = fn_t.returnType();

                b.positionAtEnd(llfunc.appendBasicBlock("entry"));

                table.enterScope;
                foreach (i, arg; funcDecl.funcArgs)
                {
                    auto name = arg.identifier.get;
                    auto param = llfunc.getParam(i);
                    param.name = name;

                    if (cast(PointerType)param.type)
                    {
                        // pointer parameters are used in place
                        table[name] = param;
                    }
                    else
                    {
                        // other parameters are copied into a stack slot so
                        // they can be loaded and stored like locals
                        auto slot = b.buildAlloca(param.type, name);
                        b.buildStore(param, slot);
                        table[name] = slot;
                    }
                }

                foreach (stmt; funcDecl.statements)
                    genStmt(stmt);

                // a body that does not end in a return gets one added
                // automatically (returning 0 unless the function is void)
                if (!b.getInsertBlock().terminated())
                {
                    if (ret_t is Type.Void)
                        b.buildRetVoid();
                    else
                        b.buildRet(ConstantInt.GetS(ret_t, 0));
                }

                table.leaveScope;
                break;
            }

            case DeclType.VarDecl:
            {
                auto varDecl = cast(VarDecl)decl;
                auto sym = varDecl.env.find(varDecl.identifier);
                Type global_t = sym.type.llvm();
                GlobalVariable global = m.addGlobal(global_t, sym.id.get);
                global.initializer = ConstantInt.GetS(global_t, 0);
                table[varDecl.identifier.get] = global;
                break;
            }

            case DeclType.StructDecl:
            {
                auto structDecl = cast(StructDecl)decl;
                // one LLVM field per member, in declaration order
                Type[] member_types;
                foreach (member; structDecl.vars)
                {
                    auto member_sym = member.env.find(member.identifier);
                    member_types ~= member_sym.type.llvm();
                }

                StructType struct_t = StructType.Get(member_types);
                m.addTypeName(structDecl.identifier.get, struct_t);
                break;
            }

            default:
                break;
        }
    }

    /**
     * Generates code for a declaration that appears inside a function body.
     *
     * Only variable declarations are handled: the variable gets a stack slot
     * that is registered in the symbol table, and its initializer (if any)
     * is assigned to it. Every other kind of declaration is ignored.
     *
     * Params:
     *     decl = the local declaration to generate
     */
    void genDecl(Decl decl)
    {
        if (decl.declType != DeclType.VarDecl)
            return;

        auto varDecl = cast(VarDecl)decl;
        auto name = varDecl.identifier.get;
        auto sym = varDecl.env.find(varDecl.identifier);

        // register the slot first, the initializer is stored through it
        table[name] = b.buildAlloca(sym.type.llvm(), name);

        if (varDecl.init)
            buildAssign(varDecl.identifier, varDecl.init);
    }

    /// Message texts for the errors thrown by the code generator.
    struct PE
    {
        /// Two types that cannot be converted into each other; the callers
        /// supply the two type names through .arg() for %0 and %1.
        static char[] NoImplicitConversion =
            "Can't find an implicit conversion between %0 and %1";
        /// A return statement without an expression in a non-void function.
        static char[] VoidRetInNonVoidFunc =
            "Only void functions can return without an expression";
    }

    /**
     * Gives two values the same type by sign-extending the narrower one to
     * the type of the wider one. Values that already share a type are left
     * untouched.
     *
     * Params:
     *     left  = first operand, replaced by its extended value if narrower
     *     right = second operand, replaced by its extended value otherwise
     *
     * Throws: the NoImplicitConversion error when the types differ and at
     *         least one of them is not an integer type
     */
    void sextSmallerToLarger(ref Value left, ref Value right)
    {
        if (left.type == right.type)
            return;

        // a conversion is only known between integer types (iX)
        IntegerType left_t = cast(IntegerType) left.type;
        IntegerType right_t = cast(IntegerType) right.type;
        if (left_t is null || right_t is null)
        {
            throw error(__LINE__, PE.NoImplicitConversion)
                .arg(left.type.toString)
                .arg(right.type.toString);
        }

        if (left_t.numBits() < right_t.numBits())
            left = b.buildSExt(left, right_t, ".cast");
        else
            right = b.buildSExt(right, left_t, ".cast");
    }

    /**
     * Generates the code that evaluates an expression.
     *
     * Params:
     *     exp = the expression to evaluate
     *
     * Returns: the LLVM value holding the result. For an identifier of
     *          struct type this is the pointer stored in the symbol table,
     *          not a loaded value.
     *
     * Throws: whatever sextSmallerToLarger and buildAssign throw for
     *         operands that cannot be converted
     */
    Value genExpression(Exp exp)
    {
        switch(exp.expType)
        {
            case ExpType.Binary:
                auto binaryExp = cast(BinaryExp)exp;

                auto left = genExpression(binaryExp.left);
                auto right = genExpression(binaryExp.right);

                // both operands must have the same type before the operator
                // is applied
                sextSmallerToLarger(left, right);

                OpBuilder op = opToLLVM[binaryExp.op];

                return op(left, right, ".");

            case ExpType.IntegerLit:
                auto integerLit = cast(IntegerLit)exp;
                auto val = integerLit.token.get;
                // Int is the name this module imports
                // tango.text.convert.Integer under
                return ConstantInt.GetS(Type.Int32, Int.parse(val));
            case ExpType.Negate:
                auto negateExp = cast(NegateExp)exp;
                auto target = genExpression(negateExp.exp);
                return b.buildNeg(target, "neg");
            case ExpType.AssignExp:
                auto assignExp = cast(AssignExp)exp;
                return buildAssign(assignExp.identifier, assignExp.exp);
            case ExpType.CallExp:
                auto callExp = cast(CallExp)exp;
                auto func_sym = exp.env.find(cast(Identifier)callExp.exp);
                Value[] args;
                foreach (arg; callExp.args)
                    args ~= genExpression(arg);

                auto callee = m.getNamedFunction(func_sym.id.get);
                auto callee_ptr_t = cast(PointerType)callee.type;
                auto callee_t = cast(FunctionType)callee_ptr_t.elementType();
                // The LLVM verifier rejects an instruction that has a name
                // but produces void, so only name calls that yield a value.
                char[] call_name = ".call";
                if (callee_t.returnType() is Type.Void)
                    call_name = null;
                return b.buildCall(callee, args, call_name);
            case ExpType.Identifier:
                auto identifier = cast(Identifier)exp;
                auto sym = exp.env.find(identifier);
                // structs are handled through their address, everything
                // else is loaded from its slot
                if(cast(DStruct)sym.type)
                    return table.find(sym.id.get);
                else
                    return b.buildLoad(table.find(sym.id.get), sym.id.get);
            case ExpType.MemberLookup:
                auto v = getPointer(exp);
                return b.buildLoad(v, v.name);
        }
        assert(0, "Reached end of switch in genExpression");
        return null;
    }

    /**
     * Generates code for a statement at the builder's current position.
     *
     * Handles return, declaration, expression, if, while and switch
     * statements. Control-flow statements append new basic blocks to the
     * enclosing function and leave the builder positioned at the block where
     * execution continues after the statement.
     *
     * A branch to the continuation block is only added to a body whose last
     * block is not already terminated (for example by a return). The check
     * is made on the block the builder ended up in, since a nested if,
     * while or switch moves the builder away from the block the body
     * started in.
     *
     * Params:
     *     stmt = the statement to generate
     *
     * Throws: VoidRetInNonVoidFunc for an empty return in a non-void
     *         function, NoImplicitConversion when a returned value cannot be
     *         converted to the function's return type
     */
    void genStmt(Stmt stmt)
    {
        switch(stmt.stmtType)
        {
            case StmtType.Return:
                auto ret = cast(ReturnStmt)stmt;
                auto sym = stmt.env.parentFunction();
                Type t = sym.type.llvm();
                if (ret.exp is null)
                {
                    if (t is Type.Void)
                    {
                        b.buildRetVoid();
                        return;
                    }
                    else
                        throw error(__LINE__, PE.VoidRetInNonVoidFunc);
                }

                Value v = genExpression(ret.exp);
                if (v.type != t)
                {
                    // integers are converted to the return type by sign
                    // extension or truncation, nothing else is convertible
                    IntegerType v_t = cast(IntegerType) v.type;
                    IntegerType i_t = cast(IntegerType) t;
                    if (v_t is null || i_t is null)
                        throw error(__LINE__, PE.NoImplicitConversion)
                            .arg(v.type.toString)
                            .arg(t.toString);

                    if (v_t.numBits() < i_t.numBits())
                        v = b.buildSExt(v, t, ".cast");
                    else
                        v = b.buildTrunc(v, t, ".cast");
                }
                b.buildRet(v);
                break;
            case StmtType.Decl:
                auto declStmt = cast(DeclStmt)stmt;
                genDecl(declStmt.decl);
                break;
            case StmtType.Exp:
                auto expStmt = cast(ExpStmt)stmt;
                genExpression(expStmt.exp);
                break;
            case StmtType.If:
                auto ifStmt = cast(IfStmt)stmt;
                Value cond = genExpression(ifStmt.cond);
                // a condition that is not an i1 is compared against zero
                if (cond.type !is Type.Int1)
                {
                    Value False = ConstantInt.GetS(cond.type, 0);
                    cond = b.buildICmp(IntPredicate.NE, cond, False, ".cond");
                }
                auto func_name = stmt.env.parentFunction().id.get;
                Function func = m.getNamedFunction(func_name);
                bool has_else = ifStmt.else_body.length > 0;

                auto thenBB = func.appendBasicBlock("then");
                auto elseBB = has_else? func.appendBasicBlock("else") : null;
                auto mergeBB = func.appendBasicBlock("merge");

                b.buildCondBr(cond, thenBB, has_else? elseBB : mergeBB);
                b.positionAtEnd(thenBB);
                foreach (s; ifStmt.then_body)
                    genStmt(s);
                if (b.getInsertBlock().terminated() is false)
                    b.buildBr(mergeBB);

                if (has_else)
                {
                    b.positionAtEnd(elseBB);
                    foreach (s; ifStmt.else_body)
                        genStmt(s);
                    // like the then branch: an else body ending in a return
                    // must not get a second terminator
                    if (b.getInsertBlock().terminated() is false)
                        b.buildBr(mergeBB);
                }

                b.positionAtEnd(mergeBB);
                break;
            case StmtType.While:
                auto wStmt = cast(WhileStmt)stmt;
                auto func_name = stmt.env.parentFunction().id.get;
                Function func = m.getNamedFunction(func_name);

                auto condBB = func.appendBasicBlock("cond");
                auto bodyBB = func.appendBasicBlock("body");
                auto doneBB = func.appendBasicBlock("done");

                // the condition lives in its own block so the body can
                // branch back to it
                b.buildBr(condBB);
                b.positionAtEnd(condBB);
                Value cond = genExpression(wStmt.cond);
                if (cond.type !is Type.Int1)
                {
                    Value False = ConstantInt.GetS(cond.type, 0);
                    cond = b.buildICmp(IntPredicate.NE, cond, False, ".cond");
                }
                b.buildCondBr(cond, bodyBB, doneBB);

                b.positionAtEnd(bodyBB);
                foreach (s; wStmt.stmts)
                    genStmt(s);
                if (b.getInsertBlock().terminated() is false)
                    b.buildBr(condBB);

                b.positionAtEnd(doneBB);
                break;
            case StmtType.Switch:
                auto sw = cast(SwitchStmt)stmt;
                Value cond = genExpression(sw.cond);

                auto func_name = stmt.env.parentFunction().id.get;
                Function func = m.getNamedFunction(func_name);

                BasicBlock oldBB = b.getInsertBlock();
                BasicBlock defBB;
                BasicBlock endBB = func.appendBasicBlock("sw.end");
                if (sw.defaultBlock)
                {
                    defBB = Function.InsertBasicBlock(endBB, "sw.def");
                    b.positionAtEnd(defBB);
                    foreach (case_statement; sw.defaultBlock)
                        genStmt(case_statement);
                    // test the block the body ended in, not defBB itself
                    if (!b.getInsertBlock().terminated())
                        b.buildBr(endBB);
                    // the switch instruction itself belongs to the block
                    // the statement started in
                    b.positionAtEnd(oldBB);
                }
                else
                    defBB = endBB;
                auto SI = b.buildSwitch(cond, defBB, sw.cases.length);
                foreach (c; sw.cases)
                {
                    BasicBlock prevBB;
                    foreach (i, val; c.values)
                    {
                        auto BB = Function.InsertBasicBlock(defBB, "sw.bb");
                        SI.addCase(ConstantInt.GetS(cond.type, c.values_converted[i]), BB);

                        // only the block of the case's last value gets the
                        // statements
                        if (i + 1 == c.values.length)
                        {
                            b.positionAtEnd(BB);
                            foreach (case_statement; c.stmts)
                                genStmt(case_statement);
                            // test the block the body ended in, not BB itself
                            if (!b.getInsertBlock().terminated())
                                b.buildBr(c.followedByDefault? defBB : endBB);
                        }

                        // blocks of the earlier values just fall through to
                        // the next value's block
                        if (prevBB !is null && !prevBB.terminated())
                        {
                            b.positionAtEnd(prevBB);
                            b.buildBr(BB);
                        }
                        prevBB = BB;
                    }
                }
                b.positionAtEnd(endBB);
                break;
        }
    }

    /**
     * Generates code that yields the address of an expression.
     *
     * An identifier yields the value stored for it in the symbol table and
     * a member lookup on an identifier yields a GEP to that member. Any
     * other expression is evaluated and spilled to a fresh stack slot, whose
     * address is returned.
     *
     * Params:
     *     exp = the expression to take the address of
     *
     * Returns: a value of pointer type
     *
     * Throws: an Error for a member lookup whose target is not a plain
     *         identifier; such lookups are not supported
     */
    Value getPointer(Exp exp)
    {
        switch(exp.expType)
        {
            case ExpType.Identifier:
                auto identifier = cast(Identifier)exp;
                auto sym = exp.env.find(identifier);
                return table.find(sym.id.get);
            case ExpType.MemberLookup:
                auto mem = cast(MemberLookup)exp;
                // Evaluating the lookup through genExpression would come
                // straight back here and recurse without end, so targets
                // other than identifiers are rejected.
                if (mem.target.expType != ExpType.Identifier)
                    throw error(__LINE__,
                        "Member access is only supported directly on a named variable");

                auto identifier = cast(Identifier)mem.target;
                auto child = mem.child;
                auto sym = exp.env.find(identifier);
                Value v = table.find(sym.id.get);
                auto st = cast(DStruct)sym.type;
                assert(st !is null, "Member access on a value that is not a struct");

                // The field index is the member's position in st.members.
                // NOTE(review): this assumes st.members is iterated in the
                // order the LLVM struct fields were created - TODO confirm.
                int index = 0;
                bool found = false;
                foreach(char[] name, DType type ; st.members)
                {
                    if(name == child.get)
                    {
                        found = true;
                        break;
                    }
                    index++;
                }
                assert(found, "Member access to an unknown member");

                // first index steps through the pointer, second one selects
                // the field
                Value[] vals;
                vals ~= ConstantInt.Get(IntegerType.Int32, 0, false);
                vals ~= ConstantInt.Get(IntegerType.Int32, index, false);

                return b.buildGEP(v, vals, sym.id.get~"."~child.get);
            default:
                // spill the value to the stack and hand out the slot, not
                // the store instruction
                Value val = genExpression(exp);
                auto AI = b.buildAlloca(val.type, ".s");
                b.buildStore(val, AI);
                return AI;
        }
        assert(0, "Reached end of switch in getPointer");
        return null;
    }

    /**
     * Generates code that assigns the value of an expression to a target.
     *
     * A struct is copied into a target of the same struct type with
     * llvm.memcpy. Any other value is stored, after being sign-extended or
     * truncated when its integer type differs from the target's.
     *
     * Params:
     *     target = expression naming the location to assign to
     *     exp    = expression producing the value to assign
     *
     * Returns: the target pointer for a struct copy, otherwise the store
     *          instruction
     *
     * Throws: the NoImplicitConversion error when the types differ and at
     *         least one of them is not an integer type
     */
    private Value buildAssign(Exp target, Exp exp)
    {
        Value t = getPointer(target);
        Value v = genExpression(exp);

        auto a = cast(PointerType)t.type;

        assert(a, "Assing to type have to be of type PointerType");

        Type value_type = v.type;
        if (auto value_ptr = cast(PointerType)v.type)
        {
            // a pointer value stands for the thing it points to
            value_type = value_ptr.elementType;

            if (a.elementType is value_type && cast(StructType)value_type)
            {
                // bitcast "from" to i8*
                Value from = b.buildBitCast(v, BytePtr, ".copy_from");
                // bitcast "to" to i8*
                Value to = b.buildBitCast(t, BytePtr, ".copy_to");
                // llvm.memcpy.i32(dest, src, size, alignment): the
                // destination comes first. An alignment of 1 makes no claim
                // about how the two pointers are aligned.
                // FIXME: the size is hard-coded to 8 bytes for every struct.
                b.buildCall(llvm_memcpy, [to, from, ConstantInt.GetS(Type.Int32, 8), ConstantInt.GetS(Type.Int32, 1)], null);
                // return "to"
                return t;
            }
        }

        if (value_type != a.elementType)
        {
            // only integers can be converted to the target's type
            IntegerType v_t = cast(IntegerType) value_type;
            IntegerType i_t = cast(IntegerType) a.elementType;
            if (v_t is null || i_t is null)
                throw error(__LINE__, PE.NoImplicitConversion)
                    .arg(a.elementType.toString)
                    .arg(v.type.toString);

            if (v_t.numBits() < i_t.numBits())
                v = b.buildSExt(v, a.elementType, ".cast");
            else
                v = b.buildTrunc(v, a.elementType, ".cast");
        }
        return b.buildStore(v, t);
    }

    /**
     * Creates the error object for a code generation failure.
     *
     * Params:
     *     line = source line of the call site; currently not used
     *     msg  = message text, see PE
     *
     * Returns: a new Error carrying msg, ready to receive .arg() calls
     */
    Error error(uint line, char[] msg)
    {
        auto err = new Error(msg);
        return err;
    }

private:

    // llvm stuff
    // Module being generated; created in gen() and disposed when gen() exits.
    Module m;
    // Instruction builder; created in the constructor, disposed in the destructor.
    Builder b;
    // The llvm.memcpy.i32 declaration added to the module in gen().
    Function llvm_memcpy;
    // The i8* type; set in gen() and used for the bitcasts in buildAssign().
    Type BytePtr;

    // NOTE(review): not referenced by any method of this class - possibly unused.
    FuncDecl[char[]] functions;

    // Maps names to the LLVM values registered for them (stack slots,
    // globals and pointer parameters).
    SimpleSymbolTable table;
    // Signature of a binary operator builder: (left, right, result name).
    alias Value delegate(Value, Value, char[]) OpBuilder;
    // Operator dispatch table, filled in by the constructor.
    // NOTE(review): static, although its delegates are bound to one
    // instance's builder - confirm a single LLVMGen is used at a time.
    static OpBuilder[BinaryExp.Operator] opToLLVM;
}

/**
 * Visitor that hands every function declaration found directly in a list of
 * declarations to a callback.
 */
private class VisitFuncDecls : Visitor!(void)
{
    /// Callback invoked once per function declaration.
    void delegate(FuncDecl) dg;

    /// Stores the callback to invoke.
    this(void delegate(FuncDecl funcDecl) dg)
    {
        this.dg = dg;
    }

    /// Calls the callback for each element of decls that is a FuncDecl;
    /// all other declarations are skipped.
    override void visit(Decl[] decls)
    {
        foreach (decl; decls)
        {
            auto func = cast(FuncDecl)decl;
            if (func !is null)
                dg(func);
        }
    }
}

/**
 * A stack of name-to-value maps, one per scope. Names are looked up from the
 * innermost scope outwards.
 */
private class SimpleSymbolTable
{
    /// The scopes; the innermost one is the last element.
    Value[char[]][] namedValues;

    /// Opens a new innermost scope.
    void enterScope()
    {
        // Every scope starts with a "__dollar" entry that maps to null.
        // NOTE(review): presumably a placeholder so the map is created
        // right away - TODO confirm.
        Value[char[]] fresh = cast(Value[char[]])["__dollar":null];
        namedValues ~= fresh;
    }

    /// Drops the innermost scope together with all names defined in it.
    void leaveScope()
    {
        auto remaining = namedValues.length - 1;
        namedValues.length = remaining;
    }

    /// Defines or redefines key in the innermost scope; returns val.
    Value put(Value val, char[] key)
    {
        auto innermost = namedValues.length - 1;
        namedValues[innermost][key] = val;
        return val;
    }

    /// Returns the value of key from the innermost scope that defines it,
    /// or null when no scope does.
    Value find(char[] key)
    {
        for (size_t depth = namedValues.length; depth > 0; depth--)
        {
            auto slot = key in namedValues[depth - 1];
            if (slot !is null)
                return *slot;
        }
        return null;
    }

    alias find opIndex;
    alias put opIndexAssign;
}