changeset 36:ce17bea8e9bd new_gen

Switch statement support. Can only switch on IntegerLits, but multiple values per case and the default are supported. An error is emitted if a value is used multiple times or if there is more than one default block.
author Anders Halager <halager@gmail.com>
date Sun, 20 Apr 2008 22:39:07 +0200
parents 371e8cfec764
children 858b9805843d
files ast/Stmt.d gen/LLVMGen.d lexer/Keyword.d lexer/Lexer.d lexer/Token.d misc/Error.d parser/Parser.d sema/Visitor.d test.td
diffstat 9 files changed, 251 insertions(+), 18 deletions(-) [+]
line wrap: on
line diff
--- a/ast/Stmt.d	Sun Apr 20 21:33:50 2008 +0200
+++ b/ast/Stmt.d	Sun Apr 20 22:39:07 2008 +0200
@@ -1,9 +1,13 @@
 module ast.Stmt;
 
+import Array = tango.core.Array,
+       Integer = tango.text.convert.Integer;
+
 import ast.Exp,
        ast.Decl;
 
-import sema.SymbolTable;
+import sema.SymbolTable,
+       misc.Error;
 
 enum StmtType
 {
@@ -13,6 +17,7 @@
     Return,
     If,
     While,
+    Switch,
 }
 
 class Stmt
@@ -86,3 +91,71 @@
     Stmt[] stmts;
 }
 
+class SwitchStmt : Stmt // AST node for a switch statement: the switched-on expression, its case blocks and an optional default block
+{
+    this(Exp target) // target is the expression being switched on; it is stored in cond
+    {
+        super(StmtType.Switch);
+        cond = target;
+    }
+
+    void addCase(IntegerLit[] values, Stmt[] stmts) // appends one case block (several values may share stmts); throws Error when the value check below fails
+    {
+        long[] new_values;
+        foreach (lit; values)
+            new_values ~= Integer.parse(lit.token.get); // numeric value of each literal, in source order
+        cases ~= Case(values, stmts, new_values); // the case is recorded before the duplicate check runs
+
+        // Make sure no two cases share the same value
+        // Does it belong here?
+        new_values = new_values.dup; // sort a copy so the array stored in the Case above keeps source order
+        Array.sort(new_values);
+        long[] all_values = Array.unionOf(old_values, new_values); // presumably a set union over sorted inputs -- TODO confirm against tango.core.Array
+        if (all_values.length != old_values.length + new_values.length) // union differs in size from the two inputs combined
+        {
+            // overlap!
+            auto e = new Error(
+                    "Can't have multiple cases with the same value."
+                    " Values appearing in multiple cases: %0");
+            e.loc(values[0].token.location); // report at the first literal of the case being added
+
+            all_values = Array.intersectionOf(old_values, new_values); // reuse all_values for the values found in both arrays
+            char[][] vals;
+            foreach (val; all_values)
+                vals ~= Integer.toString(val);
+            e.arg(vals); // presumably fills the %0 placeholder of the message -- confirm in misc.Error
+            /*
+            foreach (c; cases)
+                foreach (i, v; c.values_converted)
+                    if (Array.bsearch(all_values, v))
+                        e.tok(c.values[i].token);
+            */
+            throw e;
+        }
+        old_values = all_values; // no overlap: remember the combined values for the next addCase call
+    }
+
+    void setDefault(Stmt[] stmts) // sets the default block; throws Error if a non-empty default block is already set
+    {
+        if (defaultBlock.length != 0) // NOTE(review): an earlier default with no statements has length 0 and would not be detected -- confirm
+            throw new Error("Switch statements can't have multiple defaults");
+        defaultBlock = stmts;
+        if (cases.length > 0)
+            cases[$ - 1].followedByDefault = true; // flag the most recently added case as the one directly before the default
+    }
+
+    Exp cond; // expression being switched on
+    Case[] cases; // case blocks, in the order addCase was called
+    Stmt[] defaultBlock; // statements of the default block; empty when none was set
+
+    struct Case // one case block of the switch
+    {
+        IntegerLit[] values; // the case literals as passed to addCase
+        Stmt[] stmts; // statements of the block
+        long[] values_converted; // parsed numeric values, index-aligned with values
+        bool followedByDefault = false; // set by setDefault on the case that was added last before it
+    }
+
+    private long[] old_values; // values accepted by earlier addCase calls (result of the last unionOf)
+}
+
--- a/gen/LLVMGen.d	Sun Apr 20 21:33:50 2008 +0200
+++ b/gen/LLVMGen.d	Sun Apr 20 22:39:07 2008 +0200
@@ -94,7 +94,7 @@
         m.verify(err);
         Stderr(err).newline;
 
-        m.optimize(true);
+        // m.optimize(true);
 
         m.writeBitcodeToFile("test.bc");
     }
@@ -293,6 +293,7 @@
                         throw error(__LINE__, PE.NoImplicitConversion)
                             .arg(v.type.toString)
                             .arg(t.toString);
+
                     if (v_t.numBits() < i_t.numBits())
                         v = b.buildSExt(v, t, ".cast");
                     else
@@ -370,6 +371,56 @@
 
                 b.positionAtEnd(doneBB);
                 break;
+            case StmtType.Switch: // emits a switch instruction plus one basic block per case value for a SwitchStmt
+                auto sw = cast(SwitchStmt)stmt;
+                Value cond = genExpression(sw.cond);
+
+                auto func_name = stmt.env.parentFunction().id.get; // name of the function containing this statement
+                Function func = m.getNamedFunction(func_name); // looked up by name so new blocks can be appended to it
+
+                BasicBlock oldBB = b.getInsertBlock(); // block the switch instruction itself is emitted into
+                BasicBlock defBB;
+                BasicBlock endBB = func.appendBasicBlock("sw.end"); // builder is positioned here once the switch is done
+                if (sw.defaultBlock) // the default body is generated first, into its own block
+                {
+                    defBB = Function.InsertBasicBlock(endBB, "sw.def"); // presumably inserted before endBB -- TODO confirm
+                    b.positionAtEnd(defBB);
+                    foreach (case_statement; sw.defaultBlock)
+                        genStmt(case_statement);
+                    if (!defBB.terminated()) // NOTE(review): checks defBB, not the builder's current block; nested statements may have moved the builder -- confirm
+                        b.buildBr(endBB);
+                    b.positionAtEnd(oldBB); // return to the original block to emit the switch there
+                }
+                else
+                    defBB = endBB; // no default body: the switch's default destination is the end block
+                auto SI = b.buildSwitch(cond, defBB, sw.cases.length); // NOTE(review): passes the number of case blocks, not of case values; presumably only a size hint -- confirm
+                foreach (c; sw.cases)
+                {
+                    BasicBlock prevBB; // block created for the previous value of this same case
+                    foreach (i, val; c.values)
+                    {
+                        auto BB = Function.InsertBasicBlock(defBB, "sw.bb"); // every case value gets a block of its own
+                        SI.addCase(ConstantInt.GetS(cond.type, c.values_converted[i]), BB);
+
+                        if (i + 1 == c.values.length) // only the block of the last value receives the case body
+                        {
+                            b.positionAtEnd(BB);
+                            foreach (case_statement; c.stmts)
+                                genStmt(case_statement);
+                            if (!BB.terminated()) // NOTE(review): same caveat as for defBB above
+                                b.buildBr(c.followedByDefault? defBB : endBB); // a case directly before the default continues into it; any other case jumps to the end
+                        }
+
+                        if (prevBB !is null && !prevBB.terminated()) // chain the previous value's block into this one so it reaches the shared body
+                        {
+                            b.positionAtEnd(prevBB);
+                            b.buildBr(BB);
+                        }
+                        prevBB = BB;
+                    }
+                }
+                b.positionAtEnd(endBB); // code after the switch is emitted into the end block
+                break;
         }
     }
 
--- a/lexer/Keyword.d	Sun Apr 20 21:33:50 2008 +0200
+++ b/lexer/Keyword.d	Sun Apr 20 22:39:07 2008 +0200
@@ -8,6 +8,7 @@
 {
     keywords =
     [
+        // types
         "byte"[]    : Tok.Byte,
         "ubyte"     : Tok.Ubyte,
         "short"     : Tok.Short,
@@ -22,10 +23,16 @@
         "float"     : Tok.Float,
         "double"    : Tok.Double,
 
+        // type related
+        "struct"    : Tok.Struct,
+
+        // control flow
         "if"        : Tok.If,
         "else"      : Tok.Else,
         "while"     : Tok.While,
-        "return"    : Tok.Return,
-        "struct"    : Tok.Struct
+        "switch"    : Tok.Switch,
+        "case"      : Tok.Case,
+        "default"   : Tok.Default,
+        "return"    : Tok.Return
     ];
 }
--- a/lexer/Lexer.d	Sun Apr 20 21:33:50 2008 +0200
+++ b/lexer/Lexer.d	Sun Apr 20 22:39:07 2008 +0200
@@ -81,6 +81,8 @@
                 return Token(Tok.CloseBrace, Location(position - 1, this.source), 1);
             case ';':
                 return Token(Tok.Seperator, Location(position - 1, this.source), 1);
+            case ':':
+                return Token(Tok.Colon, Location(position - 1, this.source), 1);
             case '.':
                 return Token(Tok.Dot, Location(position - 1, this.source), 1);
             case ',':
@@ -212,6 +214,7 @@
             case '{':
             case '}':
             case ';':
+            case ':':
             case '.':
             case ',':
             case '=':
--- a/lexer/Token.d	Sun Apr 20 21:33:50 2008 +0200
+++ b/lexer/Token.d	Sun Apr 20 22:39:07 2008 +0200
@@ -58,6 +58,7 @@
     OpenBrace,
     CloseBrace,
     Seperator,
+    Colon,
     Dot,
 
     /* Comparator operators */
@@ -81,6 +82,7 @@
 
     If, Else,
     While,
+    Switch, Case, Default,
     Return,
 
 }
@@ -117,9 +119,13 @@
         Tok.Integer:"Integer",
         Tok.If:"If",
         Tok.While:"While",
+        Tok.Switch:"Switch",
+        Tok.Case:"Case",
+        Tok.Default:"Default",
         Tok.Comma:"Comma",
         Tok.Return:"Return",
         Tok.Struct:"Struct",
+        Tok.Colon:"Colon",
         Tok.Seperator:"Seperator"
     ];
 }
--- a/misc/Error.d	Sun Apr 20 21:33:50 2008 +0200
+++ b/misc/Error.d	Sun Apr 20 22:39:07 2008 +0200
@@ -71,7 +71,8 @@
     Error arg(char[][] s)
     {
         char[] res = s[0 .. $ - 1].join(", ");
-        res ~= " and ";
+        if (s.length > 1) // a single element is emitted alone, without the " and " separator
+            res ~= " and "; // NOTE(review): an empty s still indexes out of bounds on the surrounding lines -- confirm callers never pass one
         res ~= s[$ - 1];
         return arg(res);
     }
--- a/parser/Parser.d	Sun Apr 20 21:33:50 2008 +0200
+++ b/parser/Parser.d	Sun Apr 20 22:39:07 2008 +0200
@@ -70,7 +70,7 @@
                                 return new VarDecl(type, identifier, exp);
                             default:
                                 char[] c = p.getType;
-                                throw error(__LINE__, UnexpextedTokMulti)
+                                throw error(__LINE__, UnexpectedTokMulti)
                                     .tok(p)
                                     .arg(c)
                                     .arg(Tok.OpenParentheses, Tok.Seperator, Tok.Assign);
@@ -78,7 +78,7 @@
                         break;
                     default:
                         char[] c = t.getType;
-                        throw error(__LINE__, UnexpextedTok).tok(iden).arg(c);
+                        throw error(__LINE__, UnexpectedTok).tok(iden).arg(c);
                 }
                 break;
             case Tok.Struct:
@@ -96,7 +96,7 @@
                 return null;
             default:
                 char[] c = t.getType;
-                throw error(__LINE__, UnexpextedTok).tok(t).arg(c);
+                throw error(__LINE__, UnexpectedTok).tok(t).arg(c);
         }
     }
 
@@ -139,7 +139,7 @@
                                 return new VarDecl(type, identifier, exp);
                             default:
                                 char[] c = p.getType;
-                                throw error(__LINE__, UnexpextedTokMulti)
+                                throw error(__LINE__, UnexpectedTokMulti)
                                     .tok(p)
                                     .arg(c)
                                     .arg(Tok.OpenParentheses, Tok.Seperator, Tok.Assign);
@@ -147,7 +147,7 @@
                         break;
                     default:
                         char[] c = iden.getType;
-                        throw error(__LINE__, UnexpextedTokSingle)
+                        throw error(__LINE__, UnexpectedTokSingle)
                             .tok(iden)
                             .arg(c)
                             .arg(Tok.Identifier);
@@ -157,7 +157,7 @@
                 return null;
             default:
                 char[] c = t.getType;
-                throw error(__LINE__, UnexpextedTok).arg(c);
+                throw error(__LINE__, UnexpectedTok).arg(c);
         }
     }
 
@@ -245,6 +245,57 @@
                 }
                 break;
 
+            case Tok.Switch: // parses `switch ( exp ) { ... }` into a SwitchStmt
+                lexer.next; // presumably consumes the switch keyword that was only peeked -- confirm against the enclosing switch
+                require(Tok.OpenParentheses);
+                auto target = parseExpression();
+                auto res = new SwitchStmt(target);
+                require(Tok.CloseParentheses);
+                require(Tok.OpenBrace);
+                while (true) // one iteration per `default:` or `case ...:` section
+                {
+                    Stmt[] statements;
+                    if (skip(Tok.Default)) // a default section
+                    {
+                        require(Tok.Colon);
+                        statements.length = 0;
+                        while (lexer.peek.type != Tok.Case // the section runs until the next label or the closing brace
+                                && lexer.peek.type != Tok.Default
+                                && lexer.peek.type != Tok.CloseBrace)
+                            statements ~= parseStatement();
+                        res.setDefault(statements);
+                        continue; // NOTE(review): bypasses the CloseBrace check at the bottom, so a switch ending in default reaches require(Tok.Case) with `}` pending -- confirm
+                    }
+
+                    Token _case = require(Tok.Case); // kept so the error below can point at the case keyword
+
+                    IntegerLit[] literals;
+                    do // comma-separated list of case values
+                    {
+                        Exp e = parseExpression();
+                        IntegerLit lit = cast(IntegerLit)e;
+                        if (lit is null) // cast failed: the value is not an IntegerLit
+                            throw error(__LINE__, CaseValueMustBeInt)
+                                .tok(_case);
+
+                        literals ~= lit;
+                    }
+                    while (skip(Tok.Comma));
+                    require(Tok.Colon);
+
+                    while (lexer.peek.type != Tok.Case // the section runs until the next label or the closing brace
+                            && lexer.peek.type != Tok.Default
+                            && lexer.peek.type != Tok.CloseBrace)
+                        statements ~= parseStatement();
+
+                    res.addCase(literals, statements);
+
+                    if (lexer.peek.type == Tok.CloseBrace) // the only exit from the loop, checked after a case section
+                        break;
+                }
+                require(Tok.CloseBrace);
+                return res;
+
             default:
                 auto decl = new DeclStmt(parseDecl());
                 //require(Tok.Seperator);
@@ -300,7 +351,7 @@
                 return new Identifier(identifier);
                 break;
             default:
-                throw error(__LINE__, "Unexpexted token in Identifier parsing. Got %0")
+                throw error(__LINE__, "Unexpected token in Identifier parsing. Got %0")
                     .arg(identifier.getType)
                     .tok(identifier);
         }
@@ -323,7 +374,7 @@
                 break;
             default:
                 char[] c = type.getType;
-                error(__LINE__, "Unexpexted token in Type parsing. Got %0").arg(c);
+                error(__LINE__, "Unexpected token in Type parsing. Got %0").arg(c);
         }
     }
 
@@ -472,13 +523,21 @@
 
 private:
 
-    void require(Tok t)
+    Token require(Tok t) // the next token must be of type t: returns lexer.next() on a match, throws otherwise
     {
         if (lexer.peek().type != t)
-            error(__LINE__, UnexpextedTokSingle)
+            throw error(__LINE__, UnexpectedTokSingle) // the removed line built the Error without a throw
                 .arg(lexer.peek.getType)
                 .arg(t);
+        return lexer.next();
+    }
+
+    bool skip(Tok t) // calls lexer.next() only when the peeked token is of type t; returns whether it did
+    {
+        if (lexer.peek().type != t)
+            return false;
         lexer.next();
+        return true;
     }
 
     Error error(uint line, char[] errMsg, Token* tok = null)
@@ -497,9 +556,12 @@
     }
 
     static char[]
-        UnexpextedTokMulti = "Unexpexted token, got %0 expected one of %1",
-        UnexpextedTokSingle = "Unexpexted token, got %0 expected %1",
-        UnexpextedTok = "Unexpexted token %0";
+        UnexpectedTokMulti = "Unexpected token, got %0 expected one of %1",
+        UnexpectedTokSingle = "Unexpected token, got %0 expected %1",
+        UnexpectedTok = "Unexpected token %0";
+
+    static char[]
+        CaseValueMustBeInt = "Cases can only be integer literals";
 
     Lexer lexer;
 }
--- a/sema/Visitor.d	Sun Apr 20 21:33:50 2008 +0200
+++ b/sema/Visitor.d	Sun Apr 20 22:39:07 2008 +0200
@@ -51,6 +51,8 @@
                 return visitIfStmt(cast(IfStmt)stmt);
             case StmtType.While:
                 return visitWhileStmt(cast(WhileStmt)stmt);
+            case StmtType.Switch:
+                return visitSwitchStmt(cast(SwitchStmt)stmt);
             default:
                 throw new Exception("Unknown statement type");
         }
@@ -164,6 +166,24 @@
             return StmtT.init;
     }
 
+    StmtT visitSwitchStmt(SwitchStmt s) // visits the condition, then the default block, then each case's literals and statements
+    {
+        visitExp(s.cond);
+        foreach(stmt; s.defaultBlock) // the default block is visited before the cases, whatever its position in the source
+            visitStmt(stmt);
+        foreach (c; s.cases)
+        {
+            foreach(lit; c.values)
+                visitIntegerLit(lit);
+            foreach(stmt; c.stmts)
+                visitStmt(stmt);
+        }
+        static if (is(StmtT == void)) // nothing to return when the statement result type is void
+            return;
+        else
+            return StmtT.init; // otherwise return the type's default value
+    }
+
     StmtT visitExpStmt(ExpStmt s)
     {
         visitExp(s.exp);
--- a/test.td	Sun Apr 20 21:33:50 2008 +0200
+++ b/test.td	Sun Apr 20 22:39:07 2008 +0200
@@ -12,6 +12,16 @@
 int main()
 {
     int y = 4;
+    switch (y)
+    {
+        case 2:
+            y = 3;
+        case 3:
+        default:
+            y = 5;
+        case 5, 6, 7:
+            return 1;
+    }
 
     karina k;