view sema/Visitor.d @ 185:7b274cfdc1dc

Added support for array literals. Codegen is broken, though.
author Anders Johnsen <skabet@gmail.com>
date Fri, 25 Jul 2008 12:18:05 +0200
parents dc9bf56b7ace
children 08f68d684047
line wrap: on
line source

module sema.Visitor;

import tango.io.Stdout;

public
import ast.Module,
       ast.Decl,
       ast.Stmt,
       ast.Exp;

import lexer.Token;

/**
 * Result handed back by Visitor's default visit methods for result type T:
 * nothing when T is void, otherwise T.init.
 */
private T defaultResult(T)()
{
    static if (is(T == void))
        return;
    else
        return T.init;
}

/**
 * Generic depth-first walker over the AST.
 *
 * The template parameters pick the return type of each family of visit
 * methods. Each one defaults to the previous parameter (ModuleT = FinalT,
 * DeclT = ModuleT, StmtT = DeclT, ExpT = StmtT), so Visitor!(void) makes
 * every method return void and Visitor!() makes them all return int.
 *
 * visitDecl, visitStmt and visitExp switch on the node's kind tag and forward
 * to the matching specific method; an unhandled tag throws an Exception.
 * The default implementation of every specific method visits the node's
 * children in a fixed order, discards their results and returns
 * defaultResult (T.init, or nothing for void). Override a method to act on
 * that node kind; calling the base version keeps the child traversal.
 */
class Visitor(FinalT = int, ModuleT = FinalT, DeclT = ModuleT, StmtT = DeclT, ExpT = StmtT)
{
public:
    /// Visits every module in array order; per-module results are discarded.
    FinalT visit(Module[] modules)
    {
        foreach (m; modules)
            visitModule(m);
        return defaultResult!(FinalT)();
    }

    /// Visits each top-level declaration of the module in order.
    ModuleT visitModule(Module m)
    {
        foreach (decl; m.decls)
            visitDecl(decl);
        return defaultResult!(ModuleT)();
    }

    /**
     * Dispatches on decl.declType to the matching visitXxxDecl method and
     * returns its result.
     *
     * Throws: Exception("Unknown declaration type") for any other tag.
     */
    DeclT visitDecl(Decl decl)
    {
        switch(decl.declType)
        {
            case DeclType.FuncDecl:
                return visitFuncDecl(cast(FuncDecl)decl);
            case DeclType.VarDecl:
                return visitVarDecl(cast(VarDecl)decl);
            case DeclType.ImportDecl:
                return visitImportDecl(cast(ImportDecl)decl);
            case DeclType.StructDecl:
                return visitStructDecl(cast(StructDecl)decl);
            case DeclType.ClassDecl:
                return visitClassDecl(cast(ClassDecl)decl);
            case DeclType.InterfaceDecl:
                return visitInterfaceDecl(cast(InterfaceDecl)decl);
            default:
                throw new Exception("Unknown declaration type");
        }
    }

    /**
     * Dispatches on stmt.stmtType to the matching visitXxxStmt method and
     * returns its result.
     *
     * Throws: Exception("Unknown statement type") for any other tag.
     */
    StmtT visitStmt(Stmt stmt)
    {
        switch(stmt.stmtType)
        {
            case StmtType.Return:
                return visitReturnStmt(cast(ReturnStmt)stmt);
            case StmtType.Compound:
                return visitCompoundStmt(cast(CompoundStatement)stmt);
            case StmtType.Decl:
                return visitDeclStmt(cast(DeclStmt)stmt);
            case StmtType.Exp:
                return visitExpStmt(cast(ExpStmt)stmt);
            case StmtType.If:
                return visitIfStmt(cast(IfStmt)stmt);
            case StmtType.While:
                return visitWhileStmt(cast(WhileStmt)stmt);
            case StmtType.For:
                return visitForStmt(cast(ForStmt)stmt);
            case StmtType.Switch:
                return visitSwitchStmt(cast(SwitchStmt)stmt);
            default:
                throw new Exception("Unknown statement type");
        }
    }

    /**
     * Dispatches on exp.expType to the matching visit method and returns its
     * result. Identifier-type expressions are routed to visitIdentifier.
     *
     * Throws: Exception("Unknown expression type") for any other tag.
     */
    ExpT visitExp(Exp exp)
    {
        switch(exp.expType)
        {
            case ExpType.Binary:
                return visitBinaryExp(cast(BinaryExp)exp);
            case ExpType.IntegerLit:
                return visitIntegerLit(cast(IntegerLit)exp);
            case ExpType.Negate:
                return visitNegateExp(cast(NegateExp)exp);
            case ExpType.Deref:
                return visitDerefExp(cast(DerefExp)exp);
            case ExpType.AddressOfExp:
                return visitAddressOfExp(cast(AddressOfExp)exp);
            case ExpType.AssignExp:
                return visitAssignExp(cast(AssignExp)exp);
            case ExpType.CallExp:
                return visitCallExp(cast(CallExp)exp);
            case ExpType.CastExp:
                return visitCastExp(cast(CastExp)exp);
            case ExpType.Identifier:
                return visitIdentifier(cast(Identifier)exp);
            case ExpType.IdentifierTypeExp:
                // No dedicated hook: treated as a plain identifier.
                return visitIdentifier(cast(Identifier)exp);
            case ExpType.PointerTypeExp:
                return visitPointerTypeExp(cast(PointerTypeExp)exp);
            case ExpType.StaticArrayTypeExp:
                return visitStaticArrayTypeExp(cast(StaticArrayTypeExp)exp);
            case ExpType.FunctionTypeExp:
                return visitFunctionTypeExp(cast(FunctionTypeExp)exp);
            case ExpType.StringExp:
                return visitStringExp(cast(StringExp)exp);
            case ExpType.Index:
                return visitIndexExp(cast(IndexExp)exp);
            case ExpType.MemberReference:
                return visitMemberReference(cast(MemberReference)exp);
            case ExpType.NewExp:
                return visitNewExp(cast(NewExp)exp);
            case ExpType.ArrayLiteralExp:
                return visitArrayLiteralExp(cast(ArrayLiteralExp)exp);
            default:
                throw new Exception("Unknown expression type");
        }
    }

    // Declarations:

    /// Visits the variable's type, then its name and initializer when present.
    DeclT visitVarDecl(VarDecl d)
    {
        visitExp(d.varType);
        if(d.identifier)
            visitExp(d.identifier);
        if (d.init)
            visitExp(d.init);
        return defaultResult!(DeclT)();
    }

    /**
     * Visits the identifiers of an import declaration: the imported name,
     * the alias, each package component, then each explicit-symbol pair.
     *
     * The alias and the second identifier of each explicit-symbol pair are
     * skipped when null, like the other optional children in this class, so
     * overriding visitIdentifier never receives null from here.
     * NOTE(review): presumably these are null when the import has no alias /
     * rename; confirm against the parser.
     */
    DeclT visitImportDecl(ImportDecl d)
    {
        visitIdentifier(d.name);
        if (d.aliasedName !is null)
            visitIdentifier(d.aliasedName);
        foreach (id; d.packages)
            visitIdentifier(id);
        foreach (ids; d.explicitSymbols)
        {
            visitIdentifier(ids[0]);
            if (ids[1] !is null)
                visitIdentifier(ids[1]);
        }
        return defaultResult!(DeclT)();
    }

    /// Visits return type, name, each parameter declaration, then the body statements.
    DeclT visitFuncDecl(FuncDecl f)
    {
        visitExp(f.returnType);
        visitExp(f.identifier);
        foreach (arg; f.funcArgs)
            visitDecl(arg);
        foreach (stmt; f.statements)
            visitStmt(stmt);
        return defaultResult!(DeclT)();
    }

    /// Visits the struct's name, then each member declaration.
    DeclT visitStructDecl(StructDecl s)
    {
        visitExp(s.identifier);
        foreach (arg; s.decls)
            visitDecl(arg);
        return defaultResult!(DeclT)();
    }

    /// Visits the class's name, its member declarations, then its base-class expressions.
    DeclT visitClassDecl(ClassDecl s)
    {
        visitExp(s.identifier);
        foreach (arg; s.decls)
            visitDecl(arg);
        foreach (arg; s.baseClasses)
            visitExp(arg);
        return defaultResult!(DeclT)();
    }

    /// Visits the interface's name, its member declarations, then its base expressions.
    DeclT visitInterfaceDecl(InterfaceDecl s)
    {
        visitExp(s.identifier);
        foreach (arg; s.decls)
            visitDecl(arg);
        foreach (arg; s.baseClasses)
            visitExp(arg);
        return defaultResult!(DeclT)();
    }

    // Statements:

    /// Visits the returned expression, if any.
    StmtT visitReturnStmt(ReturnStmt s)
    {
        if (s.exp)
            visitExp(s.exp);
        return defaultResult!(StmtT)();
    }

    /// Visits the wrapped declaration.
    StmtT visitDeclStmt(DeclStmt d)
    {
        visitDecl(d.decl);
        return defaultResult!(StmtT)();
    }

    /// Visits each contained statement in order.
    StmtT visitCompoundStmt(CompoundStatement c)
    {
        foreach (stmt; c.statements)
            visitStmt(stmt);
        return defaultResult!(StmtT)();
    }

    /// Visits the condition, the then-branch and, when present, the else-branch.
    StmtT visitIfStmt(IfStmt s)
    {
        visitExp(s.cond);
        visitStmt(s.then_body);
        if (s.else_body !is null)
            visitStmt(s.else_body);
        return defaultResult!(StmtT)();
    }

    /// Visits the loop condition, then the loop body.
    StmtT visitWhileStmt(WhileStmt s)
    {
        visitExp(s.cond);
        visitStmt(s.whileBody);
        return defaultResult!(StmtT)();
    }

    /// Visits the optional init, condition and increment parts, then the body.
    StmtT visitForStmt(ForStmt s)
    {
        if(s.init)
            visitStmt(s.init);
        if(s.cond)
            visitExp(s.cond);
        if(s.incre)
            visitExp(s.incre);
        visitStmt(s.forBody);
        return defaultResult!(StmtT)();
    }

    /**
     * Visits the switch condition, then the default block's statements, then
     * each case's values followed by its statements.
     */
    StmtT visitSwitchStmt(SwitchStmt s)
    {
        visitExp(s.cond);
        foreach(stmt; s.defaultBlock)
            visitStmt(stmt);
        foreach (c; s.cases)
        {
            foreach(lit; c.values)
                visitExp(lit);
            foreach(stmt; c.stmts)
                visitStmt(stmt);
        }
        return defaultResult!(StmtT)();
    }

    /// Visits the wrapped expression.
    StmtT visitExpStmt(ExpStmt s)
    {
        visitExp(s.exp);
        return defaultResult!(StmtT)();
    }

    // Expressions:

    /// Visits the assignment target, then the assigned expression.
    ExpT visitAssignExp(AssignExp exp)
    {
        visitExp(exp.identifier);
        visitExp(exp.exp);
        return defaultResult!(ExpT)();
    }

    /// Visits the left operand, then the right operand.
    ExpT visitBinaryExp(BinaryExp exp)
    {
        visitExp(exp.left);
        visitExp(exp.right);
        return defaultResult!(ExpT)();
    }

    /// Visits the callee expression, then each argument in order.
    ExpT visitCallExp(CallExp exp)
    {
        visitExp(exp.exp);
        foreach (arg; exp.args)
            visitExp(arg);
        return defaultResult!(ExpT)();
    }

    /// Visits the target type, then the expression being cast.
    ExpT visitCastExp(CastExp exp)
    {
        visitExp(exp.castType);
        visitExp(exp.exp);
        return defaultResult!(ExpT)();
    }

    /// Visits the negated operand.
    ExpT visitNegateExp(NegateExp exp)
    {
        visitExp(exp.exp);
        return defaultResult!(ExpT)();
    }

    /// Visits the dereferenced operand.
    ExpT visitDerefExp(DerefExp exp)
    {
        visitExp(exp.exp);
        return defaultResult!(ExpT)();
    }

    /// Visits the operand whose address is taken.
    ExpT visitAddressOfExp(AddressOfExp exp)
    {
        visitExp(exp.exp);
        return defaultResult!(ExpT)();
    }

    /// Leaf node: nothing to traverse.
    ExpT visitIntegerLit(IntegerLit exp)
    {
        return defaultResult!(ExpT)();
    }

    /// Leaf node: nothing to traverse.
    ExpT visitStringExp(StringExp exp)
    {
        return defaultResult!(ExpT)();
    }

    /// Leaf node: nothing to traverse.
    ExpT visitIdentifier(Identifier exp)
    {
        return defaultResult!(ExpT)();
    }

    /// Visits the pointee type expression.
    ExpT visitPointerTypeExp(PointerTypeExp exp)
    {
        visitExp(exp.pointerOf);
        return defaultResult!(ExpT)();
    }

    /// Visits the element type expression.
    ExpT visitStaticArrayTypeExp(StaticArrayTypeExp exp)
    {
        visitExp(exp.arrayOf);
        return defaultResult!(ExpT)();
    }

    /// Visits the return type, then each parameter declaration.
    ExpT visitFunctionTypeExp(FunctionTypeExp exp)
    {
        visitExp(exp.returnType);
        foreach (decl ; exp.decls)
            visitDecl(decl);
        return defaultResult!(ExpT)();
    }

    /// Visits the indexed expression, then the index.
    ExpT visitIndexExp(IndexExp exp)
    {
        visitExp(exp.target);
        visitExp(exp.index);
        return defaultResult!(ExpT)();
    }

    /// Visits the target expression, then the referenced member.
    ExpT visitMemberReference(MemberReference mem)
    {
        visitExp(mem.target);
        visitExp(mem.child);
        return defaultResult!(ExpT)();
    }

    /**
     * Visits the allocated type, then every a_args expression, then every
     * c_args expression (presumably allocator and constructor arguments;
     * verify against ast.Exp).
     */
    ExpT visitNewExp(NewExp n)
    {
        visitExp(n.newType);
        foreach( a ; n.a_args )
            visitExp(a);
        foreach( c ; n.c_args )
            visitExp(c);
        return defaultResult!(ExpT)();
    }

    /// Visits each element expression of the array literal in order.
    ExpT visitArrayLiteralExp(ArrayLiteralExp a)
    {
        foreach( e ; a.exps )
            visitExp(e);
        return defaultResult!(ExpT)();
    }
}