view gen/CodeGen.d @ 173:50b98a06a200

Start of support for virtual functions
author Anders Halager <halager@gmail.com>
date Thu, 24 Jul 2008 20:40:04 +0200
parents f0385c044065
children 20ff3c31f600
line wrap: on
line source

module gen.CodeGen;

import tango.io.Stdout,
       Int = tango.text.convert.Integer;
import tango.core.Array : find, partition;

import llvm.llvm;

import ast.Decl,
       ast.Stmt,
       ast.Exp,
       ast.Module : DModule = Module;

import basic.SmallArray,
       basic.Attribute,
       basic.LiteralParsing;

import lexer.Token;

import sema.Scope,
       sema.Visitor;

/**
  Wrapper for Values representing rvalues (things you can only read from)
 **/
private struct RValue
{
    /// The wrapped LLVM value. It stays the first field because the
    /// RValue(v) struct literals used in this module fill fields by position.
    Value value;
    /// Backing flag for isSimple / isAggregate; defaults to true.
    private bool simple = true;

    /**
      True for a simple value such as an int or a pointer, i.e. basically
      anything except a struct. For a struct, value holds a pointer to it.
     **/
    bool isSimple()
    {
        return simple;
    }

    /// Negation of isSimple.
    bool isAggregate()
    {
        return !isSimple();
    }
}

/**
  Wrapper for Values representing lvalues (things you can write to)
 **/
private struct LValue
{
    /// The address this lvalue stands for; set by position via LValue(v).
    private Value value;

    /// Returns the wrapped address.
    Value getAddress()
    {
        return value;
    }
}

class CodeGen
{
public:
    this()
    {
        // Instruction builder used by the gen* methods of this class.
        b = new Builder();

        // Name -> value table; entries are added with table[name] = value.
        table = new SimpleSymbolTable;

        // Shared unsigned i32 constant 0, used as a GEP index.
        ZeroIndex = ConstantInt.GetU(Type.Int32, 0);

        createBasicTypes();
    }

    /**
       Generate a new module.
     **/
    /*
        Find all function decls and add the functions to the llvm module, so
        they can be referenced.

        Make sure all var-decls are located before functions, so we won't get
        problems when referencing the global vars.

        Generate the actual llvm code needed for all decls

        Optimize if requested

        Write to filehandle (can be a file or stdout)
     */
    void gen(DModule mod, uint handle, bool optimize, bool inline)
    {
        this.mod = mod;

        // A fresh LLVM module per call; it is disposed when gen returns.
        m = new .llvm.llvm.Module("main_module");
        scope(exit) m.dispose();

        table.enterScope();

        // Declare the memcpy intrinsic that genMemcpy calls. Its four
        // parameters are (destination, source, byte count, alignment).
        BytePtr = PointerType.Get(Type.Int8);
        auto memcpy_t = FunctionType.Get(
                Type.Void,
                [BytePtr, BytePtr, Type.Int32, Type.Int32]);
        llvm_memcpy = m.addFunction(memcpy_t, "llvm.memcpy.i32");

        // Adds the LLVM declaration for one function to the module, so the
        // function can be looked up by name while bodies are generated.
        auto declareFunction =
            (FuncDecl fd)
            {
                // Struct and array arguments are passed as pointers,
                // everything else as its plain LLVM type.
                Type[] arg_types;
                foreach (arg; fd.funcArgs)
                {
                    DType arg_t = arg.identifier.type;
                    Type lowered;
                    if (auto st = arg_t.asStruct())
                        lowered = PointerType.Get(llvm(st));
                    else if (auto ar = arg_t.asArray())
                        lowered = PointerType.Get(llvm(ar));
                    else
                        lowered = llvm(arg_t);
                    arg_types ~= lowered;
                }

                // A struct result type is declared as void; a function
                // type is replaced by its return type.
                DType result_t = fd.identifier.type;
                if ((cast(DStruct)result_t) !is null)
                    result_t = DType.Void;
                else if (auto fn_t = cast(DFunction)result_t)
                    result_t = fn_t.returnType;

                auto signature = FunctionType.Get(llvm(result_t), arg_types);
                auto fn = m.addFunction(signature, symbolName(fd));

                // Parameter attributes: the first parameter of an sret
                // function is marked StructRet; other struct parameters and
                // all array parameters are marked ByVal.
                foreach (i, arg; fd.funcArgs)
                {
                    bool sret_slot = (i == 0 && fd.sret);
                    if (sret_slot)
                        fn.addParamAttr(0, ParamAttr.StructRet);

                    DType arg_t = arg.identifier.type;
                    if (auto st = arg_t.asStruct())
                    {
                        if (!sret_slot)
                            fn.addParamAttr(i, ParamAttr.ByVal);
                    }
                    else if (auto ar = arg_t.asArray())
                        fn.addParamAttr(i, ParamAttr.ByVal);
                }
            };
        (new VisitFuncDecls(declareFunction)).visit([mod]);

        // Move the top level var-decls in front of everything else so they
        // are generated first. partition is not required to be stable, but
        // that should not create any problems.
        partition(mod.decls, (Decl d) { return d.declType == DeclType.VarDecl; });

        foreach (root; mod.decls)
            genRootDecl(root);

        table.leaveScope();

//        debug m.verify();

        if (optimize)
            m.optimize(inline);

        m.writeBitcodeToFileHandle(handle);
    }

private:
    /**
      Generate a single top-level Decl
     **/
    void genRootDecl(Decl decl)
    {
        auto kind = decl.declType;

        if (kind == DeclType.FuncDecl)
        {
            FuncDecl fn_decl = cast(FuncDecl)decl;

            // A function without a body gets no code here.
            if (fn_decl.emptyFunction)
                return;

            // The result is not used; the call is kept as in the original.
            llvm(fn_decl.type);

            // Look up the function by name and dig its return type out of
            // the pointer-to-function type of the LLVM value.
            auto fn = m.getNamedFunction(symbolName(fn_decl));
            auto fn_ptr_t = cast(PointerType)fn.type;
            auto fn_t = cast(FunctionType)fn_ptr_t.elementType();
            auto result_t = fn_t.returnType();

            b.positionAtEnd(fn.appendBasicBlock("entry"));

            table.enterScope();
            foreach (i, arg; fn_decl.funcArgs)
            {
                auto name = arg.identifier.get;
                auto param = fn.getParam(i);
                param.name = name;

                if (cast(PointerType)param.type)
                {
                    // Pointer-typed parameters are registered as they are.
                    table[name] = param;
                }
                else
                {
                    // Everything else is copied into a stack slot, and the
                    // slot is what the symbol table refers to.
                    auto slot = b.buildAlloca(param.type, name);
                    b.buildStore(param, slot);
                    table[name] = slot;
                }
            }

            foreach (stmt; fn_decl.statements)
                genStmt(stmt);

            // If the function did not end with a return, add one
            // automatically (return 0 as default for non-void functions).
            if (!b.getInsertBlock().terminated())
            {
                if (result_t is Type.Void)
                    b.buildRetVoid();
                else
                    b.buildRet(ConstantInt.GetS(result_t, 0));
            }

            table.leaveScope();
        }
        else if (kind == DeclType.VarDecl)
        {
            // Global variable, initialised with the integer constant 0.
            auto ident = (cast(VarDecl)decl).identifier;
            auto name = ident.get;
            Type var_t = llvm(ident.type);
            GlobalVariable global = m.addGlobal(var_t, name);
            global.initializer = ConstantInt.GetS(var_t, 0);
            table[name] = global;
        }
        else if (kind == DeclType.ClassDecl)
        {
            auto class_decl = cast(ClassDecl)decl;

            // Generate every method of the class and collect the functions,
            // bitcast to byte pointers, as the entries of the vtable.
            SmallArray!(Constant) entries;
            foreach (member; class_decl.decls)
            {
                if (auto method = cast(FuncDecl)member)
                {
                    genRootDecl(method);
                    auto llvm_method = m.getNamedFunction(symbolName(method));
                    entries ~= Constant.GetBitCast(llvm_method, BytePtr);
                }
            }

            // The table is emitted as a weak, constant global named
            // "<class>_vtable".
            auto table_init = ConstantArray.Get(BytePtr, entries.unsafe());
            auto vtable = m.addGlobal(table_init,
                    class_decl.identifier.get ~ "_vtable");
            vtable.linkage = Linkage.Weak;
            vtable.globalConstant = true;
        }
        // Any other kind of declaration produces no code here.
    }

    /**
      Generate a single local Decl
     **/
    void genDecl(Decl decl)
    {
        auto kind = decl.declType;

        if (kind == DeclType.VarDecl)
        {
            auto var_decl = cast(VarDecl)decl;
            auto ident = var_decl.identifier;
            auto name = ident.get;

            // Reserve a stack slot and register it before the initialiser is
            // generated, so the store below can find the variable by name.
            table[name] = b.buildAlloca(llvm(ident.type), name);

            if (var_decl.init)
            {
                LValue target = genLValue(ident);
                RValue initial = genExpression(var_decl.init);
                storeThroughLValue(target, initial, ident.type);
            }
            return;
        }

        // Nested functions, structs and classes are generated exactly like
        // their top-level counterparts; everything else is ignored.
        if (kind == DeclType.FuncDecl
                || kind == DeclType.StructDecl
                || kind == DeclType.ClassDecl)
            genRootDecl(decl);
    }

    // Remove - do it right (basic/Messages.d)
    struct PE
    {
        /// Message used by genStmt for a bare return in a non-void function.
        static char[] VoidRetInNonVoidFunc =
            "Only void functions can return without an expression";

        /// Message template with two positional placeholders, %0 and %1.
        static char[] NoImplicitConversion =
            "Can't find an implicit conversion between %0 and %1";
    }

    /**
      Generate a single expression.

      This is the most general way of generating expressions and therefore
      returns an RValue.
     **/
    RValue genExpression(Exp exp)
    {
        // Dispatch on the kind of expression; each case below returns the
        // generated value wrapped in an RValue.
        switch(exp.expType)
        {
            case ExpType.Binary:
                // Binary operators are handled by genBinExp.
                return RValue(genBinExp(cast(BinaryExp)exp));
            case ExpType.IntegerLit:
                // Despite the name, this literal node also carries floating
                // point numbers: the Float/Double/Real cases read
                // number.floating instead of number.integer.
                auto integerLit = cast(IntegerLit)exp;
                switch(integerLit.number.type)
                {
                    case NumberType.Int:
                        return RValue(ConstantInt.GetS(Type.Int32, integerLit.number.integer));
                    case NumberType.Long:
                        return RValue(ConstantInt.GetS(Type.Int64, integerLit.number.integer));
                    // NOTE(review): ULong is built with GetS, the same call
                    // as the signed cases - confirm this is intended.
                    case NumberType.ULong:
                        return RValue(ConstantInt.GetS(Type.Int64, integerLit.number.integer));
                    case NumberType.Float:
                        return RValue(ConstantReal.Get(Type.Float, integerLit.number.floating));
                    case NumberType.Double:
                        return RValue(ConstantReal.Get(Type.Double, integerLit.number.floating));
                    case NumberType.Real:
                        return RValue(ConstantReal.Get(Type.X86_FP80, integerLit.number.floating));
                }
            case ExpType.StringExp:
                // A string literal becomes a constant global named "string"
                // with internal linkage; the global itself is the result.
                auto stringExp = cast(StringExp)exp;
                char[] data = cast(char[])stringExp.data;
                auto string_constant = ConstantArray.GetString(data, true);
                auto gv = m.addGlobal(string_constant, "string");
                gv.linkage = Linkage.Internal;
                gv.globalConstant = true;
                return RValue(gv);
            case ExpType.Negate:
                auto negateExp = cast(NegateExp)exp;
                auto target = genExpression(negateExp.exp);
                return RValue(b.buildNeg(target.value, "neg"));
            case ExpType.Deref:
                // Evaluate the operand and load through the resulting value.
                auto derefExp = cast(DerefExp)exp;
                auto target = genExpression(derefExp.exp);
                return RValue(b.buildLoad(target.value, "deref"));
            case ExpType.AssignExp:
                // The destination address is generated before the source
                // expression; the value of the assignment is the source.
                auto AE = cast(AssignExp)exp;
                LValue dst = genLValue(AE.identifier);
                RValue src = genExpression(AE.exp);
                storeThroughLValue(dst, src, AE.exp.type());
                return src;
            case ExpType.Index:
                // indexExp is not used below; the element address comes from
                // genLValue and is then loaded.
                auto indexExp = cast(IndexExp)exp;
                return loadLValue(genLValue(exp));
            case ExpType.NewExp:
                // Allocate storage for the pointee of the class' LLVM type,
                // call the function named by callSym with the constructor
                // arguments, and return the allocated pointer.
                // NOTE(review): only c_args are passed to the call; the
                // freshly allocated pointer is not among them - confirm how
                // the constructor is meant to receive the object.
                auto newExp = cast(NewExp)exp;
                DClass type = newExp.newType.type().asClass();
                auto llvm_type = cast(PointerType)llvm(type);
                auto pointer = b.buildMalloc(llvm_type.elementType(), "new");
                scope args = new Value[newExp.c_args.length];
                foreach (i, arg; newExp.c_args)
                    args[i] = genExpression(arg).value;
                auto f = m.getNamedFunction(newExp.callSym.getMangledFQN());
                b.buildCall(f, args, "");
                return RValue(pointer);
            case ExpType.CallExp:
                auto callExp = cast(CallExp)exp;
                // BUG: Might not be a simple identifier, a.foo(x) is also a
                // valid call - or foo(x)(y) for that matter.

                // if type of exp is DFunction - safe to call getSymbol and FQN
                // if function ptr: do something else
                // if delegate do a third thing
                // if struct/class check for opCall
                DType type = callExp.exp.type;
                assert (type.isFunction(), "Can only call functions");
                // Arguments are generated in order, before the callee.
                scope args = new Value[callExp.args.length];
                foreach (i, arg; callExp.args)
                    args[i] = genExpression(arg).value;
                DFunction ftype = type.asFunction();
                Type llvm_ftype = llvm(ftype);
                Value f = null;
                if (callExp.callSym is null)
                {
                    // Do a virtual function call
                    // The lvalue of the callee is the address of a slot;
                    // load the pointer stored there and cast it to a pointer
                    // to the function type before calling through it.
                    f = genLValue(callExp.exp).getAddress();
                    f = b.buildLoad(f, "func_pointer");
                    f = b.buildBitCast(
                            f,
                            PointerType.Get(llvm_ftype),
                            ftype.name);
                }
                else
                {
                    // Direct call: look the function up by its mangled name.
                    auto sym = callExp.callSym;
                    f = m.getNamedFunction(sym.getMangledFQN());
                }
                // The result of a void call is given an empty name.
                bool isVoid = ftype.returnType is DType.Void;
                auto r = b.buildCall(f, args, isVoid? "" : "call");
                return RValue(r);
            case ExpType.CastExp:
                // From here on exp refers to the expression being cast, not
                // to the cast itself.
                auto castExp = cast(CastExp)exp;
                exp = castExp.exp;
                auto value = genExpression(exp);

                if (!exp.type.hasImplicitConversionTo(castExp.type))
                    assert(0, "Invalid cast");

                return genTypeCast(value, exp.type, castExp.type);

            case ExpType.Identifier:
                // Structs, arrays, static arrays and classes evaluate to the
                // symbol table entry itself (no load); every other type is
                // loaded from the entry.
                auto id = cast(Identifier)exp;
                if (id.type.isStruct()
                        || id.type.isArray()
                        || id.type.isStaticArray()
                        || id.type.isClass())
                    return RValue(table.find(id.get));
                else
                    return RValue(b.buildLoad(table.find(id.get), id.get));
            case ExpType.MemberReference:
                // Compute the member's address, then load from it.
                return loadLValue(genLValue(exp));
        }
        assert(0, "Reached end of switch in genExpression");
        return RValue(null);
    }

    /**
      Generate a binary expression.

      Currently only works for signed and unsigned integers, but is almost
      ready for floats and should be expanded to everything else.
     **/
    Value genBinExp(BinaryExp e)
    {
        // Both operands are generated up front, left before right.
        auto lhs = genExpression(e.left).value;
        auto rhs = genExpression(e.right).value;
        DType lhs_t = e.left.type;
        DType rhs_t = e.right.type;

        // TODO: do usual type promotions on a and b
        // TODO: support floats
        if (!(lhs_t.isArithmetic() && rhs_t.isArithmetic()))
        {
            /*
                Plan for operands that are not both arithmetic:

                if left has op for right's type:
                    a_op = left.op(right)
                if right has usable op_r:
                    b_op_r = right.op_r(left)
                if a_op or b_op_r is set, choose the best one
                else if op is commutative
                    if left has usable op_r
                        a_op_r = left.op_r(right)
                    if right has usable op
                        b_op = right.op(left)
                    choose best one from a_op_r and b_op
                else error
             */
            assert(0, "Not integers?");
        }

        Operation operation = lhs_t.getOperationWith(op2op(e.op), rhs_t);
        assert(operation.isBuiltin(),
                "numbers should only use builtin ops");
        alias BuiltinOperation BO;
        BO builtin = operation.builtinOp();

        // Map the builtin operation to the matching builder call.
        switch (builtin)
        {
            case BO.Add:  return b.buildAdd(lhs, rhs, "add");
            case BO.Sub:  return b.buildSub(lhs, rhs, "sub");
            case BO.Mul:  return b.buildMul(lhs, rhs, "mul");

            case BO.SDiv: return b.buildSDiv(lhs, rhs, "div");
            case BO.UDiv: return b.buildUDiv(lhs, rhs, "div");
            case BO.FDiv: return b.buildFDiv(lhs, rhs, "div");

            case BO.SRem: return b.buildSRem(lhs, rhs, "rem");
            case BO.URem: return b.buildURem(lhs, rhs, "rem");
            case BO.FRem: return b.buildFRem(lhs, rhs, "rem");

            case BO.Shl:  return b.buildShl(lhs, rhs, "shl");
            case BO.LShr: return b.buildLShr(lhs, rhs, "lshr");
            case BO.AShr: return b.buildAShr(lhs, rhs, "ashr");

            case BO.And:  return b.buildAnd(lhs, rhs, "and");
            case BO.Or:   return b.buildOr (lhs, rhs, "or");
            case BO.Xor:  return b.buildXor(lhs, rhs, "xor");

            default:
                break;
        }

        // Everything that is left must be a comparison. For a real left
        // operand, == and != use the OEQ and ONE predicates.
        LLVMPred pred = predFromBI(builtin);
        if (lhs_t.isReal())
        {
            if (builtin == BO.Eq)
                pred = LLVMPred.Real(RealPredicate.OEQ);
            else if (builtin == BO.Ne)
                pred = LLVMPred.Real(RealPredicate.ONE);
        }
        IntPredicate int_pred = pred.intPred;
        RealPredicate real_pred = pred.realPred;
        assert(pred.isValid, "Not a predicate");

        if (pred.isReal)
            return b.buildFCmp(real_pred, lhs, rhs, "cmp");
        return b.buildICmp(int_pred, lhs, rhs, "cmp");
    }

    /**
      Generates one statement
     **/
    // This should be split into specific methods - one per Stmt type?
    void genStmt(Stmt stmt)
    {
        // Emits the code for one statement at the builder's current insert
        // position. Nested statements are generated recursively, so after
        // any recursive call the builder may be positioned in a different
        // basic block than before. Whenever a branch has to be appended
        // "at the end of what was just generated", the check therefore uses
        // b.getInsertBlock() rather than the block the body started in.
        switch(stmt.stmtType)
        {
            case StmtType.Compound:
                auto stmts = cast(CompoundStatement)stmt;
                foreach (s; stmts.statements)
                    genStmt(s);
                break;
            case StmtType.Return:
                auto ret = cast(ReturnStmt)stmt;
                DFunction type = stmt.env.parentFunction().type();
                Type t = llvm(type.returnType);
                // A return without an expression is only valid when the
                // enclosing function returns void.
                if (ret.exp is null)
                    if (t is Type.Void)
                    {
                        b.buildRetVoid();
                        return;
                    }
                    else
                        assert(0, PE.VoidRetInNonVoidFunc);

                RValue v = genExpression(ret.exp);
                b.buildRet(v.value);
                break;
            case StmtType.Decl:
                auto declStmt = cast(DeclStmt)stmt;
                genDecl(declStmt.decl);
                break;
            case StmtType.Exp:
                auto expStmt = cast(ExpStmt)stmt;
                genExpression(expStmt.exp);
                break;
            case StmtType.If:
                auto ifStmt = cast(IfStmt)stmt;
                Value cond = genExpression(ifStmt.cond).value;
                // Conditions that are not already an i1 are compared
                // against zero.
                if (cond.type !is Type.Int1)
                {
                    Value False = ConstantInt.GetS(cond.type, 0);
                    cond = b.buildICmp(IntPredicate.NE, cond, False, ".cond");
                }
                auto func_name = symbolName(stmt.env.parentFunction());
                Function func = m.getNamedFunction(func_name);
                bool has_else = (ifStmt.else_body !is null);

                auto thenBB = func.appendBasicBlock("then");
                auto elseBB = has_else? func.appendBasicBlock("else") : null;
                auto mergeBB = func.appendBasicBlock("merge");

                // Without an else branch a false condition goes straight to
                // the merge block.
                b.buildCondBr(cond, thenBB, has_else? elseBB : mergeBB);
                b.positionAtEnd(thenBB);
                genStmt(ifStmt.then_body);
                thenBB = b.getInsertBlock();
                if (thenBB.terminated() is false)
                    b.buildBr(mergeBB);

                if (has_else)
                {
                    b.positionAtEnd(elseBB);
                    genStmt(ifStmt.else_body);
                    elseBB = b.getInsertBlock();
                    if (elseBB.terminated() is false)
                        b.buildBr(mergeBB);
                }

                b.positionAtEnd(mergeBB);
                break;
            case StmtType.While:
                auto wStmt = cast(WhileStmt)stmt;
                genLoop(stmt.env, wStmt.cond, false, wStmt.whileBody);
                break;
            /+
            case StmtType.DoWhile:
                auto wStmt = cast(DoWhileStmt)stmt;
                genLoop(stmt.env, wStmt.cond, true, wStmt.whileBody);
                break;
            +/
            case StmtType.For:
                // A for loop is its init statement followed by a loop whose
                // body is the for body plus the increment expression.
                auto fStmt = cast(ForStmt)stmt;
                genStmt(fStmt.init);
                scope inc = new ExpStmt(fStmt.incre);
                genLoop(stmt.env, fStmt.cond, false, fStmt.forBody, inc);
                break;
            case StmtType.Switch:
                auto sw = cast(SwitchStmt)stmt;
                Value cond = genExpression(sw.cond).value;

                auto fc = stmt.env.parentFunction();
                Function func = m.getNamedFunction(symbolName(fc));

                // oldBB is where the switch instruction itself goes; it is
                // emitted only after the default block has been generated.
                BasicBlock oldBB = b.getInsertBlock();
                BasicBlock defBB;
                BasicBlock endBB = func.appendBasicBlock("sw.end");
                if (sw.defaultBlock)
                {
                    defBB = Function.InsertBasicBlock(endBB, "sw.def");
                    b.positionAtEnd(defBB);
                    foreach (case_statement; sw.defaultBlock)
                        genStmt(case_statement);
                    // The statements above may have moved the builder into
                    // a new block (e.g. the merge block of a nested if), so
                    // the block to terminate is the current insert block,
                    // not necessarily defBB.
                    if (!b.getInsertBlock().terminated())
                        b.buildBr(endBB);
                    b.positionAtEnd(oldBB);
                }
                else
                    defBB = endBB;
                auto SI = b.buildSwitch(cond, defBB, sw.cases.length);
                foreach (c; sw.cases)
                {
                    // Every value of a case gets its own block. Only the
                    // block of the last value receives the statements; the
                    // earlier (empty) blocks branch on to the next one.
                    BasicBlock prevBB;
                    foreach (i, val; c.values)
                    {
                        auto BB = Function.InsertBasicBlock(defBB, "sw.bb");
                        SI.addCase(ConstantInt.GetS(cond.type, c.values_converted[i]), BB);

                        if (i + 1 == c.values.length)
                        {
                            b.positionAtEnd(BB);
                            foreach (case_statement; c.stmts)
                                genStmt(case_statement);
                            // As for the default block: terminate the block
                            // the builder ended up in, which is not BB when
                            // the case body created further blocks.
                            if (!b.getInsertBlock().terminated())
                                b.buildBr(c.followedByDefault? defBB : endBB);
                        }

                        if (prevBB !is null && !prevBB.terminated())
                        {
                            b.positionAtEnd(prevBB);
                            b.buildBr(BB);
                        }
                        prevBB = BB;
                    }
                }
                b.positionAtEnd(endBB);
                break;
        }
    }

    /**
      Generate a loop.

      Loops while cond is true, executing all statements in stmts every time.

      If skipFirstCond is set, the condition is skipped the first time around,
      like in a do-while loop.
     **/
    void genLoop(Scope env, Exp cond, bool skipFirstCond, Stmt[] stmts...)
    {
        // The loop is laid out as three blocks appended to the enclosing
        // function, in this order: condition, body, exit.
        Function func = m.getNamedFunction(symbolName(env.parentFunction()));
        BasicBlock condBB = func.appendBasicBlock("cond");
        BasicBlock bodyBB = func.appendBasicBlock("body");
        BasicBlock doneBB = func.appendBasicBlock("done");

        // Enter at the body when the first condition check is skipped.
        BasicBlock entryBB = condBB;
        if (skipFirstCond)
            entryBB = bodyBB;
        b.buildBr(entryBB);

        // Condition block: anything that is not an i1 is compared to zero.
        b.positionAtEnd(condBB);
        Value test = genExpression(cond).value;
        if (test.type !is Type.Int1)
        {
            Value zero = ConstantInt.GetS(test.type, 0);
            test = b.buildICmp(IntPredicate.NE, test, zero, ".cond");
        }
        b.buildCondBr(test, bodyBB, doneBB);

        // Body block: jump back to the condition unless the statements
        // already ended the current block themselves.
        b.positionAtEnd(bodyBB);
        foreach (body_stmt; stmts)
            genStmt(body_stmt);
        if (!b.getInsertBlock().terminated())
            b.buildBr(condBB);

        b.positionAtEnd(doneBB);
    }

    /*
       Get the address of an expression - allowing us to modify something in
       memory or on the stack.
     */
    LValue genLValue(Exp exp)
    {
        // Dispatch on the kind of expression; each handled case returns an
        // LValue wrapping the address (or value) that stands for exp.
        switch(exp.expType)
        {
            case ExpType.Identifier:
                // The symbol table entry is used directly as the address.
                auto id = cast(Identifier)exp;
                return LValue(table.find(id.get));
            case ExpType.Deref:
                // LValue(*x): load(x)
                // RValue(*x): load(load(x))
                // This way *x = *x + 1 will work
                // We need an extra load, because we get an i32** rather than
                // i32* since stuff is alloc'd
                auto DE = cast(DerefExp)exp;
                return LValue(genExpression(DE.exp).value);
            case ExpType.Index:
                auto indexExp = cast(IndexExp)exp;
                auto type = indexExp.target.type;
                // The index is generated before the address of the target.
                auto index = genExpression(indexExp.index);
                Value[2] gep_indices;
                gep_indices[0] = ZeroIndex;
                gep_indices[1] = index.value;
                Value res;
                auto target = genLValue(indexExp.target).getAddress();
                // A static array is indexed with both GEP indices (0, index);
                // a pointer only with the index itself.
                if (type.isStaticArray())
                    res = b.buildGEP(target, gep_indices[0 .. 2], "index");
                else if (type.isPointer())
                    res = b.buildGEP(target, gep_indices[1 .. 2], "index");
                else assert(0, "Can only index pointers and arrays");
                return LValue(res);
            case ExpType.MemberReference:
                auto mem = cast(MemberReference)exp;
                // How the member is reached depends on the kind of
                // expression in front of the dot.
                switch (mem.target.expType)
                {
                    case ExpType.Identifier:
                        auto id = cast(Identifier)mem.target;
                        auto child = mem.child;
                        Value v = table.find(id.get);
                        DType t = id.type;
                        if (auto st = t.asStruct)
                        {
                            int i = st.indexOf(child.get);
                            // indexOf gives -1 when the name is not a data
                            // member; it is then looked up as a function by
                            // the mangled name of the reference's symbol,
                            // and the function itself is returned.
                            if (i == -1)
                            {
                                auto fname = mem.getSymbol.getMangledFQN();
                                auto f = m.getNamedFunction(fname);
                                return LValue(f);
                            }

                            // Data member: GEP (0, i) from the struct.
                            Value[2] vals;
                            vals[0] = ZeroIndex;
                            vals[1] = ConstantInt.GetU(IntegerType.Int32, i);

                            Value val = b.buildGEP(v, vals, id.get~"."~child.get);
                            return LValue(val);
                        }
                        else if (auto ct = t.asClass)
                        {
                            int i = ct.indexOf(child.get);
                            Value[2] vals;
                            vals[0] = ZeroIndex;
                            // A normal member
                            if (i != -1)
                            {
                                vals[1] = ConstantInt.GetU(IntegerType.Int32, i);
                                Value val = b.buildGEP(v, vals, id.get~"."~child.get);
                                return LValue(val);
                            }
                            // A method
                            else
                            {
                                // The base of the GEP becomes the global
                                // "<class>_vtable", and the second index is
                                // always 0 here, whichever method is named.
                                vals[1] = ZeroIndex;
                                //vals[1] = ConstantInt.GetU(IntegerType.Int32, 1);
                                auto vtbl_name = ct.name ~ "_vtable";
                                auto vtbl = m.getNamedGlobal(vtbl_name);
                                v = vtbl;
                            }

                            // Only reached for methods: address of the
                            // selected vtable slot.
                            Value val = b.buildGEP(v, vals, id.get~"."~child.get);
                            return LValue(val);
                        }
                        else
                            assert(0, "Can only access members in classes "
                                      "and structs");

                    case ExpType.MemberReference:
                        // Nested reference (a.b.c): take the address of the
                        // inner reference and index into it. The target type
                        // is taken via asStruct() only.
                        auto addr = genLValue(mem.target).getAddress();
                        auto child = mem.child;
                        DStruct t = mem.target.type.asStruct();

                        int i = t.indexOf(child.get);

                        Value[2] vals;   
                        vals[0] = ZeroIndex;
                        vals[1] = ConstantInt.GetU(IntegerType.Int32, i);

                        Value val = b.buildGEP(addr, vals, "."~child.get);
                        return LValue(val);
                }
                break;
        }
        assert(0, "Reached end of switch in getPointer");
        return LValue(null);
    }

    /**
      Store into an lvalue from a rvalue. Both are assumed to have type t.
     **/
    void storeThroughLValue(LValue dst, RValue src, DType t)
    {
        Value target = dst.getAddress();
        Value source = src.value;

        assert((cast(PointerType)target.type) !is null,
                "Can only store through pointers");

        // Structs are copied with memcpy, using genMemcpy's default
        // alignment.
        if (t.asStruct() !is null)
        {
            genMemcpy(target, source, t);
            return;
        }

        // Static arrays are copied with memcpy as well; the byte size of
        // the element type is passed as the alignment.
        auto static_array = t.asStaticArray();
        if (static_array !is null)
        {
            genMemcpy(target, source, t, static_array.arrayOf.byteSize);
            return;
        }

        // Everything else is written with a plain store.
        b.buildStore(source, target);
    }

    /**
      Copy from src into dst. The values are assumed to have the same size,
      and the amount of bytes to copy is taken from t.
     **/
    void genMemcpy(Value dst, Value src, DType t, int alignment = 16)
    {
        // Both pointers are cast to byte pointers first, source before
        // destination.
        Value src_bytes = b.buildBitCast(src, BytePtr, ".copy_from");
        Value dst_bytes = b.buildBitCast(dst, BytePtr, ".copy_to");

        Value byte_count = ConstantInt.GetS(Type.Int32, t.byteSize());
        Value align_value = ConstantInt.GetS(Type.Int32, alignment);

        // Argument order of the call: destination, source, size, alignment.
        Value[4] call_args;
        call_args[0] = dst_bytes;
        call_args[1] = src_bytes;
        call_args[2] = byte_count;
        call_args[3] = align_value;
        b.buildCall(llvm_memcpy, call_args[], null);
    }

    /**
      Generate the statements necessary to convert V, from type 'from' to type
      'to'.
     **/
    RValue genTypeCast(RValue V, DType from, DType to)
    {
        // Picks a widening ("extend") and a narrowing ("trunc") conversion
        // for the pair of types, then applies one of them depending on the
        // byte sizes of 'from' and 'to'.
        //
        // Signedness rules:
        //  - integer -> integer: widening sign- or zero-extends according
        //    to the SOURCE type, since the source decides how the existing
        //    bits are interpreted.
        //  - integer -> real: likewise chosen from the source integer.
        //  - real -> integer: chosen from the DESTINATION integer.
        //
        // Casts between two integers or two reals of the same byte size
        // need no instruction (trunc/fptrunc to a same-width type is not
        // valid LLVM IR), so the value is passed through unchanged.
        Value delegate(Value, Type, char[]) extend, trunc;

        // Set when source and destination are the same kind of number.
        bool sameKind = false;

        if(auto ito = to.asInteger())
        {
            if(from.isReal())
            {
                extend = ito.unsigned? &b.buildFPToUI : &b.buildFPToSI;
                trunc = extend;
            }
            else
            {
                // Use the source's signedness when the source is known to
                // be an integer; otherwise fall back to the destination's.
                auto ifrom = from.asInteger();
                bool srcUnsigned = ito.unsigned;
                if (ifrom !is null)
                    srcUnsigned = ifrom.unsigned;

                extend = srcUnsigned? &b.buildZExt : &b.buildSExt;
                trunc = &b.buildTrunc;
                sameKind = ifrom !is null;
            }
        }
        else if (to.asReal() !is null)
        {
            if(from.isInteger())
            {
                auto ifrom = from.asInteger();
                bool srcUnsigned = false;
                if (ifrom !is null)
                    srcUnsigned = ifrom.unsigned;

                extend = srcUnsigned? &b.buildUIToFP : &b.buildSIToFP;
                trunc = extend;
            }
            else
            {
                extend = &b.buildFPExt;
                trunc = &b.buildFPTrunc;
                sameKind = from.isReal();
            }
        }
        else
            assert(0, "implicit cast need implimentation");

        // Same kind and same size: nothing to convert.
        if (sameKind && from.byteSize() == to.byteSize())
            return RValue(V.value);

        Value res;
        if (from.byteSize() < to.byteSize())
            res = extend(V.value, llvm(to), "ext");
        else
            res = trunc(V.value, llvm(to), "trunc");
        return RValue(res);
    }

    /**
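      Given the address of something, load the value stored there and return
      it as an RValue. The load is named 'name' if given, otherwise after the
      address value, or "load" if that has no name.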
     **/
    RValue loadLValue(LValue addr, char[] name = null)
    {
        Value val = addr.getAddress();
        if (name is null)
            name = val.name.length > 0? val.name : "load";

        auto res = b.buildLoad(val, name);
        return RValue(res);
    }

    /// Get the symbol name of a function: the mangled fully qualified name
    /// for extern(D) functions, the plain name for everything else
    char[] symbolName(FuncDecl f)
    {
        // Anything that is not extern(D) keeps its unmangled name
        if (f.att.getExtern != Extern.D)
            return f.sym.getName;
        return f.sym.getMangledFQN();
    }

    /// Same as symbolName(FuncDecl), but looks the function up through the
    /// symbol attached to an expression
    char[] symbolName(Exp f)
    {
        // Anything that is not extern(D) keeps its unmangled name
        if (f.getSymbol.decl.att.getExtern != Extern.D)
            return f.getSymbol.getName;
        return f.getSymbol.getMangledFQN();
    }

    /**
      Get the LLVM Type corresponding to a DType.

      Types are cached in type_map; a type that is not in the map yet is
      created by llvmCreateNew.

      Currently using the built-in associative array - not sure if it works
      well when the hashes are so uniform.

      Other possibilities would be to find a hash-function that works on
      something as small as 4 bytes or to create a sparse array perhaps.
     */
    Type llvm(DType t)
    {
        auto cached = t in type_map;
        if (cached is null)
            return llvmCreateNew(t);
        return *cached;
    }

    // Build the LLVM type for t, store it in type_map and return it.
    //
    // Handles integers, structs, classes, functions, pointers and static
    // arrays; any other kind of type hits the assert at the end.
    Type llvmCreateNew(DType t)
    {
        Type result;

        if (auto intType = cast(DInteger)t)
        {
            // Integer width in bits
            result = IntegerType.Get(intType.byteSize() * 8);
            type_map[t] = result;
        }
        else if (auto structType = t.asStruct)
        {
            // Order the member types by their index before converting them
            DType[] ordered;
            ordered.length = structType.members.length;
            foreach (member; structType.members)
                ordered[member.index] = member.type;

            SmallArray!(Type, 8) fields;
            foreach (memberType; ordered)
                fields ~= llvm(memberType);

            result = StructType.Get(fields.unsafe());
            type_map[t] = result;
            m.addTypeName("struct." ~ structType.name, result);
        }
        else if (auto classType = t.asClass)
        {
            SmallArray!(Type) fields;
            if (classType.members.length > 0)
            {
                // Same index based ordering as for structs
                DType[] ordered;
                ordered.length = classType.members.length;
                foreach (member; classType.members)
                    ordered[member.index] = member.type;

                foreach (memberType; ordered)
                    fields ~= llvm(memberType);
            }
            else
            {
                // A class without members gets a single Int32 field
                fields ~= Type.Int32;
            }

            // A class type is a pointer to the struct holding its members
            result = PointerType.Get(StructType.Get(fields.unsafe()));
            type_map[t] = result;
            m.addTypeName("class." ~ classType.name, result);
        }
        else if (auto funcType = t.asFunction)
        {
            // We should never have a function returning structs, because of
            // the simplify step
            assert(!funcType.returnType.isStruct(), "Can't return structs");
            Type returnType = llvm(funcType.returnType);

            // Structs and arrays are passed as pointers, the rest by value
            SmallArray!(Type, 8) paramTypes;
            foreach (param; funcType.params)
            {
                if (param.isStruct() || param.isArray())
                    paramTypes ~= PointerType.Get(llvm(param));
                else
                    paramTypes ~= llvm(param);
            }

            result = FunctionType.Get(returnType, paramTypes.unsafe());
            type_map[t] = result;
        }
        else if (auto ptrType = t.asPointer)
        {
            result = PointerType.Get(llvm(ptrType.pointerOf));
            type_map[t] = result;
        }
        else if (auto arrayType = t.asStaticArray)
        {
            result = ArrayType.Get(llvm(arrayType.arrayOf), arrayType.size);
            type_map[t] = result;
        }
        else
            assert(0, "Only integers, structs and functions are supported");

        return result;
    }

    // Pre-populate type_map with every basic type, so llvm() finds them
    // without going through llvmCreateNew. Entries are grouped by the LLVM
    // type they map to.
    void createBasicTypes()
    {
        // No value
        type_map[DType.Void]   = Type.Void;

        // 1 bit
        type_map[DType.Bool]   = Type.Int1;

        // 8 bit
        type_map[DType.Byte]   = Type.Int8;
        type_map[DType.UByte]  = Type.Int8;
        type_map[DType.Char]   = Type.Int8;

        // 16 bit
        type_map[DType.Short]  = Type.Int16;
        type_map[DType.UShort] = Type.Int16;
        type_map[DType.WChar]  = Type.Int16;

        // 32 bit
        type_map[DType.Int]    = Type.Int32;
        type_map[DType.UInt]   = Type.Int32;
        type_map[DType.DChar]  = Type.Int32;

        // 64 bit
        type_map[DType.Long]   = Type.Int64;
        type_map[DType.ULong]  = Type.Int64;

        // Floating point
        type_map[DType.Float]  = Type.Float;
        type_map[DType.Double] = Type.Double;
        type_map[DType.Real]   = Type.X86_FP80;
    }

private:

    // llvm stuff

    // The AST module passed to gen()
    DModule mod;
    // The LLVM module; named types are registered on it by llvmCreateNew
    .llvm.llvm.Module m;
    // Instruction builder, created in the constructor
    Builder b;
    // Function called by genMemcpy to copy bytes between two addresses
    Function llvm_memcpy;
    // 32 bit constant zero, created in the constructor
    ConstantInt ZeroIndex;
    // Type both addresses are bitcast to in genMemcpy
    // (presumably a pointer to i8 - confirm where it is assigned)
    Type BytePtr;
    // Cache of DType -> LLVM Type, filled by createBasicTypes and
    // llvmCreateNew and read by llvm()
    Type[DType] type_map;

    // Scoped name -> Value table, created in the constructor
    SimpleSymbolTable table;
}

/**
  Translate a binary operator from the AST (BinaryExp.Operator) into the
  Operator enum used by this module.

  There is no default case; callers are expected to pass only the operators
  listed here.
 **/
private Operator op2op(BinaryExp.Operator op)
{
    alias BinaryExp.Operator Bin;
    Operator mapped;
    switch (op)
    {
        // Arithmetic
        case Bin.Add:
            mapped = Operator.Add;
            break;
        case Bin.Sub:
            mapped = Operator.Sub;
            break;
        case Bin.Mul:
            mapped = Operator.Mul;
            break;
        case Bin.Div:
            mapped = Operator.Div;
            break;

        // Shifts
        case Bin.LeftShift:
            mapped = Operator.Shl;
            break;
        case Bin.RightShift:
            mapped = Operator.AShr;
            break;
        case Bin.UnsignedRightShift:
            mapped = Operator.LShr;
            break;

        // Bitwise
        case Bin.And:
            mapped = Operator.And;
            break;
        case Bin.Or:
            mapped = Operator.Or;
            break;
        case Bin.Xor:
            mapped = Operator.Xor;
            break;

        // Comparisons
        case Bin.Eq:
            mapped = Operator.Eq;
            break;
        case Bin.Ne:
            mapped = Operator.Ne;
            break;
        case Bin.Lt:
            mapped = Operator.Lt;
            break;
        case Bin.Le:
            mapped = Operator.Le;
            break;
        case Bin.Gt:
            mapped = Operator.Gt;
            break;
        case Bin.Ge:
            mapped = Operator.Ge;
            break;
    }
    return mapped;
}

/**
  A comparison predicate that is either an integer or a real predicate.

  isValid is false for a default constructed value and true for values made
  with Int() or Real(). isReal tells which member of the union is in use.
 **/
private struct LLVMPred
{
    bool isValid = false;
    bool isReal;
    union {
        IntPredicate intPred;
        RealPredicate realPred;
    }

    /// Make a valid predicate holding the integer predicate p
    static LLVMPred Int(IntPredicate p)
    {
        LLVMPred made;
        made.intPred = p;
        made.isReal = false;
        made.isValid = true;
        return made;
    }

    /// Make a valid predicate holding the real predicate p
    static LLVMPred Real(RealPredicate p)
    {
        LLVMPred made;
        made.realPred = p;
        made.isReal = true;
        made.isValid = true;
        return made;
    }
}
/**
  Map a builtin comparison operation to the matching LLVM predicate.

  Eq and Ne map to integer predicates, the S- and U-prefixed operations map
  to the signed and unsigned integer predicates, and the F-prefixed ones map
  to the ordered real predicates. There is no default case.
 **/
private LLVMPred predFromBI(BuiltinOperation op)
{
    alias BuiltinOperation BI;
    LLVMPred found;
    switch (op)
    {
        // Equality
        case BI.Eq:
            found = LLVMPred.Int(IntPredicate.EQ);
            break;
        case BI.Ne:
            found = LLVMPred.Int(IntPredicate.NE);
            break;

        // Less than
        case BI.SLt:
            found = LLVMPred.Int(IntPredicate.SLT);
            break;
        case BI.ULt:
            found = LLVMPred.Int(IntPredicate.ULT);
            break;
        case BI.FLt:
            found = LLVMPred.Real(RealPredicate.OLT);
            break;

        // Less than or equal
        case BI.SLe:
            found = LLVMPred.Int(IntPredicate.SLE);
            break;
        case BI.ULe:
            found = LLVMPred.Int(IntPredicate.ULE);
            break;
        case BI.FLe:
            found = LLVMPred.Real(RealPredicate.OLE);
            break;

        // Greater than
        case BI.SGt:
            found = LLVMPred.Int(IntPredicate.SGT);
            break;
        case BI.UGt:
            found = LLVMPred.Int(IntPredicate.UGT);
            break;
        case BI.FGt:
            found = LLVMPred.Real(RealPredicate.OGT);
            break;

        // Greater than or equal
        case BI.SGe:
            found = LLVMPred.Int(IntPredicate.SGE);
            break;
        case BI.UGe:
            found = LLVMPred.Int(IntPredicate.UGE);
            break;
        case BI.FGe:
            found = LLVMPred.Real(RealPredicate.OGE);
            break;
    }
    return found;
}

/**
  Visitor that hands every FuncDecl it visits to a delegate.
 **/
private class VisitFuncDecls : Visitor!(void)
{
    /// Called once for each visited function declaration
    void delegate(FuncDecl) dg;

    /// Store the delegate to call from visitFuncDecl
    this(void delegate(FuncDecl funcDecl) callback)
    {
        dg = callback;
    }

    /// Forward the declaration to the stored delegate
    override void visitFuncDecl(FuncDecl fd)
    {
        dg(fd);
    }
}

/**
  A stack of name -> Value maps, one map per open scope.

  Values are stored in the innermost scope, and lookups search from the
  innermost scope outwards. enterScope must have been called before put or
  leaveScope is used.
 **/
private class SimpleSymbolTable
{
    /// One map per open scope, the innermost scope last
    Value[char[]][] namedValues;

    /// Open a new innermost scope
    void enterScope()
    {
        // Every new scope starts out with a "__dollar" entry mapped to null
        namedValues ~= cast(Value[char[]])["__dollar":null];
    }

    /// Close the innermost scope, dropping everything stored in it
    void leaveScope()
    {
        // Without this check an empty stack would make the subtraction wrap
        // around, asking for an enormous array length
        assert(namedValues.length > 0,
                "leaveScope called without a matching enterScope");
        namedValues.length = namedValues.length - 1;
    }

    /**
      Store val under key in the innermost scope, replacing any value already
      stored under that key in the same scope.

      Returns: val
     **/
    Value put(Value val, char[] key)
    {
        assert(namedValues.length > 0,
                "put called before any scope was entered");
        namedValues[$ - 1][key] = val;
        return val;
    }

    /**
      Look key up, starting in the innermost scope and moving outwards.

      Returns: the first value found, or null if no scope has the key.
     **/
    Value find(char[] key)
    {
        foreach_reverse (map; namedValues)
            if(auto val_ptr = key in map)
                return *val_ptr;
        return null;
    }

    // Allow table[key] and table[key] = val
    alias find opIndex;
    alias put opIndexAssign;
}