view gen/CodeGen.d @ 88:eb5b2c719a39 new_gen

Major change to locations, tokens and expressions. A location (now SourceLocation or SLoc) is only 32 bits in size; the disadvantage is that it can't find its own text - you have to go through the new SourceManager to do that. This has caused changes to a lot of code and the removal of DataSource and the old Location. Additionally, Exp has gained some location information, so we can give proper error messages. This is not in Decl and Stmt yet, but that's coming too.
author Anders Halager <halager@gmail.com>
date Sun, 04 May 2008 18:13:46 +0200
parents 9a35a973175a
children a49bb982a7b0
line wrap: on
line source

module gen.CodeGen;

import tango.io.Stdout,
       Int = tango.text.convert.Integer;
import tango.core.Array : find, partition;

import llvm.llvm;

import ast.Decl,
       ast.Stmt,
       ast.Exp;

import misc.Error,
       basic.SmallArray;

import lexer.Token;

import sema.SymbolTableBuilder,
       sema.Visitor;

/**
  Build the source text of a small comparison wrapper, meant to be used with
  mixin() inside a class that has a builder member called `b`.

  For p = "EQ" the generated method is `Value buildEQ(Value l, Value r,
  char[] n)`, which forwards to b.buildICmp with IntPredicate.EQ.

  Params:
    p = name of the IntPredicate member; also used as suffix of the method name
  Returns:
    the D source code of the wrapper method
 */
private char[] genBuildCmp(char[] p)
{
    // The declaration line of the generated method.
    char[] signature = `
        Value build` ~ p ~ `(Value l, Value r, char[] n)`;

    // The generated method body: a single call to buildICmp.
    char[] fnBody = `
        {
            return b.buildICmp(IntPredicate.` ~ p ~ `, l, r, n);
        }`;

    return signature ~ fnBody;
}

/**
  LLVM code generator.

  gen() takes the top-level declarations of a program, builds an LLVM module
  from them and writes the result as bitcode to "out.bc".
 */
class CodeGen
{
private:
    // Each mixin below expands (see genBuildCmp) to a method
    //   Value buildXX(Value l, Value r, char[] n)
    // that calls b.buildICmp with IntPredicate.XX. They give the integer
    // comparisons the same (Value, Value, char[]) shape as the arithmetic
    // builder methods, so both kinds can be stored in opToLLVM.
    mixin(genBuildCmp("EQ"));
    mixin(genBuildCmp("NE"));
    mixin(genBuildCmp("SLT"));
    mixin(genBuildCmp("SLE"));
    mixin(genBuildCmp("SGT"));
    mixin(genBuildCmp("SGE"));

public:
    /**
      Create the instruction builder, the symbol table, the mappings for the
      basic types and the table that maps binary operators to the builder
      method emitting them.
     */
    this()
    {
        alias BinaryExp.Operator Op;

        // The builder has to exist before delegates to its methods are taken.
        b = new Builder;
        table = new SimpleSymbolTable();
        createBasicTypes();

        // Arithmetic goes straight to the builder; comparisons go through
        // the mixed-in buildXX wrappers, which use signed predicates.
        opToLLVM = [
            Op.Add : &b.buildAdd,
            Op.Sub : &b.buildSub,
            Op.Mul : &b.buildMul,
            Op.Div : &b.buildSDiv,
            Op.Eq  : &buildEQ,
            Op.Ne  : &buildNE,
            Op.Lt  : &buildSLT,
            Op.Le  : &buildSLE,
            Op.Gt  : &buildSGT,
            Op.Ge  : &buildSGE
        ];
    }

    // Releases the Builder created in the constructor by calling dispose()
    // on it.
    // NOTE(review): when a destructor is run by the garbage collector, D does
    // not guarantee that other GC-managed objects referenced by members are
    // still valid - confirm that CodeGen instances are destroyed
    // deterministically (scope/delete) or that touching `b` here is safe.
    ~this()
    {
        b.dispose();
    }

    /**
      Generate an LLVM module for the given declarations and write it as
      bitcode to the file "out.bc".

      The work is done in these steps:
        1. declare llvm.memcpy.i32 (used for struct assignment),
        2. declare every function found by VisitFuncDecls, so calls can be
           generated regardless of declaration order,
        3. move the top-level variable declarations to the front and generate
           code for all declarations with genRootDecl,
        4. optionally optimize, then write the bitcode.

      Params:
        decls    = the top-level declarations; the array is reordered by the
                   call to partition
        optimize = when true, m.optimize is called on the finished module
        inline   = passed on to m.optimize
     */
    void gen(Decl[] decls, bool optimize, bool inline)
    {
        // create module
        m = new Module("main_module");
        scope(exit) m.dispose();

        table.enterScope;

        BytePtr = PointerType.Get(Type.Int8);
        auto temp = FunctionType.Get(Type.Void, [BytePtr, BytePtr, Type.Int32, Type.Int32]);
        llvm_memcpy = m.addFunction(temp, "llvm.memcpy.i32");

        // Adds the LLVM declaration for one function, including the
        // parameter attributes for struct and array arguments.
        auto registerFunc =
            (FuncDecl fd)
            {
                Type[] param_types;
                foreach (p; fd.funcArgs)
                {
                    DType t = p.env.find(p.identifier).type;

                    // Structs and arrays are passed as pointers, everything
                    // else by value. This has to be a single if/else-if/else
                    // chain: exactly one LLVM parameter per argument, since
                    // the attribute loop below indexes parameters by the
                    // argument index.
                    if(auto st = t.asStruct)
                        param_types ~= PointerType.Get(llvm(st));
                    else if(auto ar = t.asArray)
                        param_types ~= PointerType.Get(llvm(ar));
                    else
                        param_types ~= llvm(t);
                }

                // The symbol's type is either a struct (mapped to a void
                // return here), a function type (use its return type) or
                // already the return type itself.
                auto ret_t = fd.env.find(fd.identifier).type;
                if(auto st = cast(DStruct)ret_t)
                    ret_t = DType.Void;
                else if(auto f = cast(DFunction)ret_t)
                    ret_t = f.returnType;
                auto func_t = FunctionType.Get(llvm(ret_t), param_types);
                auto llfunc = m.addFunction(func_t, fd.identifier.get);

                foreach (i, p; fd.funcArgs)
                {
                    // With fd.sret set, the first parameter is marked as the
                    // struct-return slot and must not get ByVal as well.
                    if(i == 0 && fd.sret)
                        llfunc.addParamAttr(0, ParamAttr.StructRet);

                    DType t = p.env.find(p.identifier).type;
                    if(auto st = t.asStruct)
                    {
                        if(i == 0 && fd.sret)
                            continue;
                        llfunc.addParamAttr(i,ParamAttr.ByVal);
                    }
                    else if(auto ar = t.asArray)
                    {
                        llfunc.addParamAttr(i,ParamAttr.ByVal);
                    }
                }
            };
        auto visitor = new VisitFuncDecls(registerFunc);
        visitor.visit(decls);

        // Before beginning we move all top level var-decls to the start
        // and then we generate the var-decls first
        // partition is NOT required to be stable, but that should not create
        // any problems.
        partition(decls, (Decl d) { return d.declType == DeclType.VarDecl; });

        foreach (decl; decls)
            genRootDecl(decl);

        table.leaveScope;

        // Only verified in debug builds.
        debug m.verify();

        if(optimize)
            m.optimize(inline);

        m.writeBitcodeToFile("out.bc");
    }

    /**
      Generate code for one top-level declaration.

      - FuncDecl: fills in the body of the function with the same name that
        is already present in the module. Functions flagged emptyFunction
        are skipped.
      - VarDecl: adds a global variable initialized with an integer zero.
      - StructDecl: makes sure the LLVM type of the struct exists.
      - anything else is ignored.
     */
    void genRootDecl(Decl decl)
    {
        switch(decl.declType)
        {
            case DeclType.FuncDecl:
                auto fd = cast(FuncDecl)decl;

                // Nothing to generate for a function without a body.
                if(fd.emptyFunction)
                    return;

                llvm(fd.type);
                auto llfunc = m.getNamedFunction(fd.type.name);

                // The function value has pointer-to-function type; unwrap it
                // to get at the return type.
                auto ptr_t = cast(PointerType)llfunc.type;
                auto fn_t = cast(FunctionType)ptr_t.elementType();
                auto ret_t = fn_t.returnType();

                auto entry = llfunc.appendBasicBlock("entry");
                b.positionAtEnd(entry);

                table.enterScope;
                foreach (i, arg; fd.funcArgs)
                {
                    auto name = arg.identifier.get;
                    auto param = llfunc.getParam(i);
                    param.name = name;

                    if(cast(PointerType)param.type)
                    {
                        // Pointer parameters are registered directly.
                        table[name] = param;
                    }
                    else
                    {
                        // Everything else is copied into a stack slot, so
                        // the symbol table always maps a name to an address.
                        auto slot = b.buildAlloca(param.type, name);
                        b.buildStore(param, slot);
                        table[name] = slot;
                    }
                }

                foreach (stmt; fd.statements)
                    genStmt(stmt);

                // if the function didn't end with a return, we automatically
                // add one (return 0 as default)
                if (b.getInsertBlock().terminated() is false)
                {
                    if (ret_t is Type.Void)
                        b.buildRetVoid();
                    else
                        b.buildRet(ConstantInt.GetS(ret_t, 0));
                }

                table.leaveScope;
                break;

            case DeclType.VarDecl:
                auto vd = cast(VarDecl)decl;
                auto sym = vd.env.find(vd.identifier);
                Type global_t = llvm(sym.type);
                GlobalVariable global = m.addGlobal(global_t, sym.id.get);
                // NOTE(review): the initializer is always an integer constant,
                // which presumably only fits integer-typed globals - confirm.
                global.initializer = ConstantInt.GetS(global_t, 0);
                table[vd.identifier.get] = global;
                break;

            case DeclType.StructDecl:
                auto sd = cast(StructDecl)decl;
                // Called for its side effect: creates and caches the type.
                llvm(sd.type);
                break;

            default:
                break;
        }
    }

    /**
      Generate code for a declaration inside a function body.

      Only variable declarations are handled: a stack slot is allocated, the
      name is registered in the symbol table and, if there is an initializer,
      it is assigned. All other kinds of declarations are ignored.
     */
    void genDecl(Decl decl)
    {
        if (decl.declType != DeclType.VarDecl)
            return;

        auto vd = cast(VarDecl)decl;
        auto sym = vd.env.find(vd.identifier);
        auto name = vd.identifier.get;

        auto slot = b.buildAlloca(llvm(sym.type), name);
        table[name] = slot;

        // The initializer is handled like a normal assignment to the new name.
        if (vd.init)
            buildAssign(vd.identifier, vd.init);
    }

    /// Message texts for the errors thrown by the code generator.
    struct PE
    {
        /// The two placeholders are filled in by the .arg() calls at the
        /// throw sites.
        static char[] NoImplicitConversion = "Can't find an implicit conversion between %0 and %1";

        /// Used for a return statement without a value in a function whose
        /// return type is not void.
        static char[] VoidRetInNonVoidFunc = "Only void functions can return without an expression";
    }

    /**
      Make two values have the same type by sign-extending the narrower one
      to the type of the wider one. Nothing happens if the types already
      match.

      Params:
        left  = first value, replaced by the extended value if it is narrower
        right = second value, replaced by the extended value otherwise
      Throws:
        an Error with PE.NoImplicitConversion if the types differ and one of
        them is not an integer type
     */
    void sextSmallerToLarger(ref Value left, ref Value right)
    {
        if (left.type == right.type)
            return;

        // A conversion can only be found between two integer types (iX).
        auto left_t = cast(IntegerType) left.type;
        auto right_t = cast(IntegerType) right.type;
        if (left_t is null || right_t is null)
            throw error(__LINE__, PE.NoImplicitConversion)
                .arg(left.type.toString)
                .arg(right.type.toString);

        if (left_t.numBits() < right_t.numBits())
            left = b.buildSExt(left, right_t, ".cast");
        else
            right = b.buildSExt(right, left_t, ".cast");
    }

    /**
      Generate the code that evaluates an expression.

      Params:
        exp = the expression to evaluate
      Returns:
        the value of the expression. For identifiers of struct or array type
        this is the address found in the symbol table, not a loaded value.
     */
    Value genExpression(Exp exp)
    {
        switch(exp.expType)
        {
            case ExpType.Binary:
                auto binaryExp = cast(BinaryExp)exp;

                auto left = genExpression(binaryExp.left);
                auto right = genExpression(binaryExp.right);

                // The operands are used as they are; widening is disabled:
//                sextSmallerToLarger(left, right);

                OpBuilder op = opToLLVM[binaryExp.op];

                return op(left, right, ".");

            case ExpType.IntegerLit:
                // Integer literals always become 32 bit constants.
                auto integerLit = cast(IntegerLit)exp;
                auto val = integerLit.get;
                return ConstantInt.GetS(Type.Int32, Integer.parse(val));
            case ExpType.Negate:
                auto negateExp = cast(NegateExp)exp;
                auto target = genExpression(negateExp.exp);
                return b.buildNeg(target, "neg");
            case ExpType.Deref:
                auto derefExp = cast(DerefExp)exp;
                auto target = genExpression(derefExp.exp);
                return b.buildLoad(target, "deref");
            case ExpType.AssignExp:
                auto assignExp = cast(AssignExp)exp;
                return buildAssign(assignExp.identifier, assignExp.exp);
            case ExpType.Index:
                // getPointer computes the address of the element.
                return b.buildLoad(getPointer(exp), ".");
            case ExpType.CallExp:
                auto callExp = cast(CallExp)exp;
                auto func_sym = exp.env.find(cast(Identifier)callExp.exp);
                Value[] args;
                foreach (arg; callExp.args)
                {
                    Value v = genExpression(arg);
                    args ~= v;

                }
                // BUG: doesn't do implicit type-conversion
                // Calls flagged sret get an empty result name (presumably
                // because they produce no value that could be named).
                if(callExp.sret)
                    return b.buildCall(m.getNamedFunction(func_sym.id.get), args, "");
                return b.buildCall(m.getNamedFunction(func_sym.id.get), args, ".call");
            case ExpType.CastExp:
                auto castExp = cast(CastExp)exp;
                auto value = genExpression(castExp.exp);

                // NOTE(review): this compares castExp.type with itself, so it
                // looks like it can never fail; the intended check was
                // presumably against castExp.exp.type - confirm before
                // changing, since explicit narrowing casts must stay valid.
                if(!castExp.type.hasImplicitConversionTo(castExp.type))
                    assert(0, "Invalid cast");

                Type target_t = llvm(castExp.type);

                // Nothing to convert when the value already has the target
                // LLVM type (e.g. Int and UInt are both mapped to Int32).
                // zext and trunc both need operand and result types of
                // different widths.
                if(value.type is target_t)
                    return value;

                // NOTE(review): widening always zero-extends, also for signed
                // source types - confirm that this is intended.
                if(castExp.exp.type.byteSize <= castExp.type.byteSize)
                    return b.buildZExt(value, target_t, "cast");
                return b.buildTrunc(value, target_t, "cast");

            case ExpType.Identifier:
                auto identifier = cast(Identifier)exp;
                auto sym = exp.env.find(identifier);
                // Structs and arrays are handled by address, everything else
                // is loaded from its slot.
                if(sym.type.isStruct || sym.type.isArray)
                    return table.find(sym.id.get);
                else
                    return b.buildLoad(table.find(sym.id.get), sym.id.get);
            case ExpType.MemberReference:
                auto v = getPointer(exp);
                return b.buildLoad(v, v.name);
        }
        assert(0, "Reached end of switch in genExpression");
        return null;
    }

    /**
      Generate code for a statement at the current insert position of the
      builder.

      If, while and switch statements add new basic blocks to the enclosing
      function; on return the builder is positioned where the code following
      the statement has to go.

      Throws:
        an Error with PE.VoidRetInNonVoidFunc for a return without a value in
        a function whose return type is not void
     */
    void genStmt(Stmt stmt)
    {
        switch(stmt.stmtType)
        {
            case StmtType.Compound:
                auto stmts = cast(CompoundStatement)stmt;
                foreach (s; stmts.statements)
                    genStmt(s);
                break;
            case StmtType.Return:
                auto ret = cast(ReturnStmt)stmt;
                DFunction type = stmt.env.parentFunction().type();
                Type t = llvm(type.returnType);
                if (ret.exp is null)
                {
                    if (t is Type.Void)
                    {
                        b.buildRetVoid();
                        return;
                    }
                    throw error(__LINE__, PE.VoidRetInNonVoidFunc);
                }

                // The value is returned as it is; no conversion to the
                // return type of the function is inserted here.
                Value v = genExpression(ret.exp);
                b.buildRet(v);
                break;
            case StmtType.Decl:
                auto declStmt = cast(DeclStmt)stmt;
                genDecl(declStmt.decl);
                break;
            case StmtType.Exp:
                auto expStmt = cast(ExpStmt)stmt;
                genExpression(expStmt.exp);
                break;
            case StmtType.If:
                auto ifStmt = cast(IfStmt)stmt;
                Value cond = genCondition(ifStmt.cond);
                auto func_name = stmt.env.parentFunction().identifier.get;
                Function func = m.getNamedFunction(func_name);
                bool has_else = (ifStmt.else_body !is null);

                auto thenBB = func.appendBasicBlock("then");
                auto elseBB = has_else? func.appendBasicBlock("else") : null;
                auto mergeBB = func.appendBasicBlock("merge");

                // Without an else branch a false condition goes straight to
                // the merge block.
                b.buildCondBr(cond, thenBB, has_else? elseBB : mergeBB);
                b.positionAtEnd(thenBB);
                genStmt(ifStmt.then_body);
                // The body may have moved the builder to another block, so
                // the block to terminate is the current one, not thenBB.
                if (b.getInsertBlock().terminated() is false)
                    b.buildBr(mergeBB);

                if (has_else)
                {
                    b.positionAtEnd(elseBB);
                    genStmt(ifStmt.else_body);
                    if (b.getInsertBlock().terminated() is false)
                        b.buildBr(mergeBB);
                }

                b.positionAtEnd(mergeBB);
                break;
            case StmtType.While:
                auto wStmt = cast(WhileStmt)stmt;
                auto func_name = stmt.env.parentFunction().identifier.get;
                Function func = m.getNamedFunction(func_name);

                auto condBB = func.appendBasicBlock("cond");
                auto bodyBB = func.appendBasicBlock("body");
                auto doneBB = func.appendBasicBlock("done");

                // The condition lives in its own block, since it is evaluated
                // again after every iteration.
                b.buildBr(condBB);
                b.positionAtEnd(condBB);
                Value cond = genCondition(wStmt.cond);
                b.buildCondBr(cond, bodyBB, doneBB);

                b.positionAtEnd(bodyBB);
                genStmt(wStmt.whileBody);
                if (b.getInsertBlock().terminated() is false)
                    b.buildBr(condBB);

                b.positionAtEnd(doneBB);
                break;
            case StmtType.Switch:
                auto sw = cast(SwitchStmt)stmt;
                Value cond = genExpression(sw.cond);

                auto func_name = stmt.env.parentFunction().identifier.get;
                Function func = m.getNamedFunction(func_name);

                BasicBlock oldBB = b.getInsertBlock();
                BasicBlock defBB;
                BasicBlock endBB = func.appendBasicBlock("sw.end");
                if (sw.defaultBlock)
                {
                    defBB = Function.InsertBasicBlock(endBB, "sw.def");
                    b.positionAtEnd(defBB);
                    foreach (case_statement; sw.defaultBlock)
                        genStmt(case_statement);
                    // Check the block the builder ended up in: statements
                    // with control flow of their own leave defBB terminated
                    // and continue in a new block, which still needs its
                    // branch to the end.
                    if (!b.getInsertBlock().terminated())
                        b.buildBr(endBB);
                    b.positionAtEnd(oldBB);
                }
                else
                    defBB = endBB;
                auto SI = b.buildSwitch(cond, defBB, sw.cases.length);
                foreach (c; sw.cases)
                {
                    BasicBlock prevBB;
                    foreach (i, val; c.values)
                    {
                        // Every value of a case gets a block of its own; only
                        // the block of the last value holds the statements.
                        auto BB = Function.InsertBasicBlock(defBB, "sw.bb");
                        SI.addCase(ConstantInt.GetS(cond.type, c.values_converted[i]), BB);

                        if (i + 1 == c.values.length)
                        {
                            b.positionAtEnd(BB);
                            foreach (case_statement; c.stmts)
                                genStmt(case_statement);
                            // As for the default block: terminate the block
                            // the statements ended in, which need not be BB.
                            if (!b.getInsertBlock().terminated())
                                b.buildBr(c.followedByDefault? defBB : endBB);
                        }

                        // The (empty) block of the previous value falls
                        // through to this one.
                        if (prevBB !is null && !prevBB.terminated())
                        {
                            b.positionAtEnd(prevBB);
                            b.buildBr(BB);
                        }
                        prevBB = BB;
                    }
                }
                b.positionAtEnd(endBB);
                break;
        }
    }

    /// Evaluate `exp` for use as a branch condition: unless the value
    /// already has type Int1, it is compared against zero (not equal).
    private Value genCondition(Exp exp)
    {
        Value cond = genExpression(exp);
        if (cond.type !is Type.Int1)
        {
            Value False = ConstantInt.GetS(cond.type, 0);
            cond = b.buildICmp(IntPredicate.NE, cond, False, ".cond");
        }
        return cond;
    }

    /**
      Generate the code that computes the address of an expression.

      Identifiers, dereferences, index expressions and member references
      yield the address of the object they denote. For any other kind of
      expression the value is computed and stored in a new stack slot, and
      the address of that slot is returned.

      Params:
        exp = the expression to take the address of
      Returns:
        a value of pointer type
     */
    Value getPointer(Exp exp)
    {
        switch(exp.expType)
        {
            case ExpType.Identifier:
                auto identifier = cast(Identifier)exp;
                auto sym = exp.env.find(identifier);
                return table.find(sym.id.get);
            case ExpType.Deref:
                // The address of *p is the value stored in p.
                auto derefExp = cast(DerefExp)exp;
                auto target = getPointer(derefExp.exp);
                return b.buildLoad(target, "deref");
            case ExpType.Index:
                auto indexExp = cast(IndexExp)exp;
                auto type = indexExp.target.type;
                auto index = genExpression(indexExp.index);
                Value[2] gep_indices;
                gep_indices[0] = ConstantInt.Get(IntegerType.Int32, 0, false);
                gep_indices[1] = index;
                if (type.isArray())
                {
                    // Pointer to an array: the leading 0 steps through the
                    // pointer, the second index selects the element.
                    auto array = getPointer(indexExp.target);
                    return b.buildGEP(array, gep_indices[0 .. 2], "index");
                }
                else if (type.isPointer())
                {
                    // Plain pointer: the index is applied to the pointer
                    // value itself.
                    auto array = genExpression(indexExp.target);
                    return b.buildGEP(array, gep_indices[1 .. 2], "index");
                }
                else assert(0, "Can only index pointers and arrays");
            case ExpType.MemberReference:
                auto mem = cast(MemberReference)exp;
                auto child = mem.child;

                // Find the address and the type of the struct the member
                // belongs to.
                Value base;
                DType base_t;
                char[] name;
                if (mem.target.expType == ExpType.Identifier)
                {
                    auto identifier = cast(Identifier)mem.target;
                    auto sym = exp.env.find(identifier);
                    base = table.find(sym.id.get);
                    base_t = sym.type;
                    name = sym.id.get~"."~child.get;
                }
                else
                {
                    // Every other target goes through getPointer on the
                    // target. Calling genExpression on `exp` itself here
                    // would come straight back to this case and never end.
                    base = getPointer(mem.target);
                    base_t = mem.target.type;
                    name = "."~child.get;
                }

                auto st = base_t.asStruct;
                assert(st, "Member reference on a non-struct value");

                int i = st.indexOf(child.get);

                // 0 steps through the pointer, i selects the member.
                Value[] vals;
                vals ~= ConstantInt.Get(IntegerType.Int32, 0, false);
                vals ~= ConstantInt.Get(IntegerType.Int32, i, false);

                return b.buildGEP(base, vals, name);
            default:
                // Not an addressable expression: put the value in a
                // temporary and hand out the address of the temporary (the
                // slot, not the store instruction).
                Value val = genExpression(exp);
                auto AI = b.buildAlloca(val.type, ".s");
                b.buildStore(val, AI);
                return AI;
        }
        assert(0, "Reached end of switch in getPointer");
        return null;
    }

    /**
      Generate the code for `target = exp`.

      If the value is a pointer to a struct of the same type as the
      destination, the struct is copied with llvm.memcpy and the destination
      address is returned. Otherwise the value is stored, after being
      sign-extended or truncated if source and destination are integer types
      of different width.

      Params:
        target = expression denoting the destination; its address is obtained
                 with getPointer
        exp    = expression giving the value to assign
      Returns:
        the destination address for struct copies, else the result of
        buildStore
      Throws:
        an Error with PE.NoImplicitConversion if the types differ and one of
        them is not an integer type
     */
    private Value buildAssign(Exp target, Exp exp)
    {
        Value dest = getPointer(target);
        Value src = genExpression(exp);

        auto dest_t = cast(PointerType)dest.type;

        assert(dest_t, "Assing to type have to be of type PointerType");

        Type src_t = src.type;
        auto src_ptr = cast(PointerType)src.type;
        if (src_ptr !is null)
        {
            // For pointer values the comparison below is done on the
            // pointed-to type.
            src_t = src_ptr.elementType;

            bool same_struct = dest_t.elementType is src_t
                && cast(StructType)src_t !is null;
            if (same_struct)
            {
                // llvm.memcpy works on i8*, so both sides are cast first.
                Value from = b.buildBitCast(src, BytePtr, ".copy_from");
                Value to = b.buildBitCast(dest, BytePtr, ".copy_to");

                // NOTE(review): the size is a fixed 4 and the alignment a
                // fixed 32, independent of the struct being copied - this
                // looks wrong for structs that are not 4 bytes; confirm and
                // pass the real size.
                Value size = ConstantInt.GetS(Type.Int32, 4);
                Value alignment = ConstantInt.GetS(Type.Int32, 32);
                b.buildCall(llvm_memcpy, [to, from, size, alignment], null);

                return dest;
            }
        }

        if (src_t != dest_t.elementType)
        {
            auto from_int = cast(IntegerType) src_t;
            auto to_int = cast(IntegerType) dest_t.elementType;
            if (from_int is null || to_int is null)
                throw error(__LINE__, PE.NoImplicitConversion)
                    .arg(dest_t.elementType.toString)
                    .arg(src.type.toString);

            if (from_int.numBits() < to_int.numBits())
                src = b.buildSExt(src, dest_t.elementType, ".cast");
            else
                src = b.buildTrunc(src, dest_t.elementType, ".cast");
        }
        return b.buildStore(src, dest);
    }

    /**
      Create the Error object for a code generation error.

      Params:
        line = line of the throw site (callers pass __LINE__); not used yet
        msg  = the message of the error
      Returns:
        a new Error, ready to be thrown by the caller
     */
    Error error(uint line, char[] msg)
    {
        auto err = new Error(msg);
        return err;
    }

    /**
      Look up the LLVM type that belongs to a DType, creating it on first
      use.

      The cache is the built-in associative array type_map. It is not clear
      how well that works when the hashes are as uniform as they are here;
      alternatives would be a hash function suited for keys as small as
      4 bytes, or perhaps a sparse array.
     */
    Type llvm(DType t)
    {
        auto cached = t in type_map;
        if (cached is null)
            return llvmCreateNew(t);
        return *cached;
    }

    /**
      Create the LLVM type for a DType, store it in type_map and return it.

      Handles integers, structs, functions, pointers and arrays; component
      types are resolved through llvm(), so they are created on demand too.
      Struct types are additionally registered by name in the module.
     */
    Type llvmCreateNew(DType t)
    {
        Type res;

        if (auto i = cast(DInteger)t)
        {
            res = IntegerType.Get(i.byteSize() * 8);
        }
        else if (auto s = t.asStruct)
        {
            // Put the member types in the order given by their index.
            DType[] ordered;
            ordered.length = s.members.length;
            foreach(member; s.members)
                ordered[member.index] = member.type;

            SmallArray!(Type, 8) members;
            foreach(member_t; ordered)
                members ~= llvm(member_t);

            res = StructType.Get(members.unsafe());
            m.addTypeName(s.name, res);
        }
        else if (auto f = t.asFunction)
        {
            // We should never have a function returning structs, because of
            // the simplify step
            assert(f.returnType.isStruct() == false, "Can't return structs");
            Type ret_t = llvm(f.returnType);

            // Structs and arrays are passed as pointers. The function is
            // not added to the module and no parameter attributes are set
            // here.
            SmallArray!(Type, 8) params;
            foreach(param; f.params)
            {
                Type param_t = llvm(param);
                if (param.isStruct || param.isArray)
                    param_t = PointerType.Get(param_t);
                params ~= param_t;
            }

            res = FunctionType.Get(ret_t, params.unsafe());
        }
        else if (auto p = t.asPointer)
        {
            res = PointerType.Get(llvm(p.pointerOf));
        }
        else if (auto a = t.asArray)
        {
            res = ArrayType.Get(llvm(a.arrayOf), a.size);
        }
        else
            assert(0, "Only integers, structs and functions are supported");

        type_map[t] = res;
        return res;
    }

    /**
      Fill type_map with the built-in types, so they never have to be
      created on demand. The signed and the unsigned variant of an integer
      type map to the same LLVM integer type.
     */
    void createBasicTypes()
    {
        // non-integers
        type_map[DType.Void]   = Type.Void;
        type_map[DType.Bool]   = Type.Int1;

        // signed integers
        type_map[DType.Byte]   = Type.Int8;
        type_map[DType.Short]  = Type.Int16;
        type_map[DType.Int]    = Type.Int32;
        type_map[DType.Long]   = Type.Int64;

        // unsigned integers
        type_map[DType.UByte]  = Type.Int8;
        type_map[DType.UShort] = Type.Int16;
        type_map[DType.UInt]   = Type.Int32;
        type_map[DType.ULong]  = Type.Int64;
    }

private:

    // llvm stuff

    // The module being generated; created (and disposed of) by gen().
    Module m;
    // Builder used to emit all instructions; created in the constructor.
    Builder b;
    // Declaration of llvm.memcpy.i32, set up by gen().
    Function llvm_memcpy;
    // The type i8*, set up by gen().
    Type BytePtr;
    // Cache of the LLVM types created for DTypes, see llvm().
    Type[DType] type_map;

    FuncDecl[char[]] functions;

    // Maps the names that are in scope to the address of their storage.
    SimpleSymbolTable table;

    // Signature shared by the builder methods for binary operators.
    alias Value delegate(Value, Value, char[]) OpBuilder;

    // Maps a binary operator to the method that emits it. The delegates are
    // bound to this instance and its builder, so the table has to be a
    // per-instance member: as a static it would be overwritten by every new
    // CodeGen and make older instances emit through the newest builder.
    OpBuilder[BinaryExp.Operator] opToLLVM;
}

/**
  Visitor that calls a delegate for every function declaration in a list
  of declarations. Declarations of other kinds are skipped.
 */
private class VisitFuncDecls : Visitor!(void)
{
    /// Called once for each FuncDecl that is found.
    void delegate(FuncDecl) dg;

    this(void delegate(FuncDecl funcDecl) dg)
    {
        this.dg = dg;
    }

    /// Go through `decls` in order and hand every FuncDecl to the delegate.
    override void visit(Decl[] decls)
    {
        foreach (decl; decls)
        {
            auto func = cast(FuncDecl)decl;
            if (func !is null)
                dg(func);
        }
    }
}

/**
  A stack of name -> Value maps, one map per scope.

  New names go into the innermost scope; lookups go from the innermost scope
  outwards. The table can be used with index syntax: `table[name] = value`
  and `table[name]`.
 */
private class SimpleSymbolTable
{
    /// The scopes, outermost first.
    Value[char[]][] namedValues;

    /// Open a new innermost scope. The map starts out with the placeholder
    /// entry "__dollar" (mapped to null).
    void enterScope()
    {
        Value[char[]] fresh;
        fresh["__dollar"] = null;
        namedValues ~= fresh;
    }

    /// Close the innermost scope, dropping all names registered in it.
    void leaveScope()
    {
        auto depth = namedValues.length;
        namedValues.length = depth - 1;
    }

    /// Register `val` under `key` in the innermost scope and return `val`.
    Value put(Value val, char[] key)
    {
        auto innermost = namedValues.length - 1;
        namedValues[innermost][key] = val;
        return val;
    }

    /// Return the value registered for `key` in the innermost scope that
    /// has one, or null if no scope knows the name.
    Value find(char[] key)
    {
        for (size_t i = namedValues.length; i > 0; i--)
        {
            auto hit = key in namedValues[i - 1];
            if (hit !is null)
                return *hit;
        }
        return null;
    }

    alias find opIndex;
    alias put opIndexAssign;
}