diff gen/statements.cpp @ 122:36ab367572df trunk

[svn r126] String switch is now implemented. A few other fixes.
author lindquist
date Tue, 27 Nov 2007 09:19:07 +0100
parents 9c79b61fb638
children 7f9a0a58394b
line wrap: on
line diff
--- a/gen/statements.cpp	Tue Nov 27 03:09:36 2007 +0100
+++ b/gen/statements.cpp	Tue Nov 27 09:19:07 2007 +0100
@@ -526,9 +526,52 @@
 
 //////////////////////////////////////////////////////////////////////////////
 
+// used to build the sorted list of cases
+struct Case : Object
+{
+    StringExp* str;
+    size_t index;
+
+    Case(StringExp* s, size_t i) {
+        str = s;
+        index = i;
+    }
+
+    int compare(Object *obj) {
+        Case* c2 = (Case*)obj;
+        return str->compare(c2->str);
+    }
+};
+
+static llvm::Value* call_string_switch_runtime(llvm::GlobalVariable* table, Expression* e)
+{
+    Type* dt = DtoDType(e->type);
+    Type* dtnext = DtoDType(dt->next);
+    TY ty = dtnext->ty;
+    const char* fname;
+    if (ty == Tchar) {
+        fname = "_d_switch_string";
+    }
+    else if (ty == Twchar) {
+        fname = "_d_switch_ustring";
+    }
+    else if (ty == Tdchar) {
+        fname = "_d_switch_dstring";
+    }
+    else {
+        assert(0 && "not char/wchar/dchar");
+    }
+
+    llvm::Function* fn = LLVM_D_GetRuntimeFunction(gIR->module, fname);
+    std::vector<llvm::Value*> args;
+    args.push_back(table);
+    args.push_back(e->toElem(gIR)->getRVal());
+    return gIR->ir->CreateCall(fn, args.begin(), args.end(), "tmp");
+}
+
 void SwitchStatement::toIR(IRState* p)
 {
-    Logger::println("SwitchStatement::toIR(): %s", toChars());
+    Logger::println("SwitchStatement::toIR()");
     LOG_SCOPE;
 
     llvm::BasicBlock* oldend = gIR->scopeend();
@@ -537,24 +580,32 @@
     typedef std::pair<llvm::BasicBlock*, std::vector<llvm::ConstantInt*> > CasePair;
     std::vector<CasePair> vcases;
     std::vector<Statement*> vbodies;
+    Array caseArray;
     for (int i=0; i<cases->dim; ++i)
     {
         CaseStatement* cs = (CaseStatement*)cases->data[i];
 
-        // create the case bb with a nice label
-        std::string lblname("case"+std::string(cs->exp->toChars()));
+        std::string lblname("case");
         llvm::BasicBlock* bb = new llvm::BasicBlock(lblname, p->topfunc(), oldend);
 
         std::vector<llvm::ConstantInt*> tmp;
         CaseStatement* last;
+        bool first = true;
         do {
-            // get the case value
-            DValue* e = cs->exp->toElem(p);
-            DConstValue* ce = e->isConst();
-            assert(ce);
-            llvm::ConstantInt* ec = isaConstantInt(ce->c);
-            assert(ec);
-            tmp.push_back(ec);
+            // integral case
+            if (cs->exp->type->isintegral()) {
+                llvm::Constant* c = cs->exp->toConstElem(p);
+                tmp.push_back(isaConstantInt(c));
+            }
+            // string case
+            else {
+                assert(cs->exp->op == TOKstring);
+                // for string switches this is unfortunately necessary or there will be duplicates in the list
+                if (first) {
+                    caseArray.push(new Case((StringExp*)cs->exp, i));
+                    first = false;
+                }
+            }
             last = cs;
         }
         while (cs = cs->statement->isCaseStatement());
@@ -563,6 +614,42 @@
         vbodies.push_back(last->statement);
     }
 
+    // string switch?
+    llvm::GlobalVariable* switchTable = 0;
+    if (!condition->type->isintegral())
+    {
+        // first sort it
+        caseArray.sort();
+        // iterate and add indices to cases
+        std::vector<llvm::Constant*> inits;
+        for (size_t i=0; i<caseArray.dim; ++i)
+        {
+            Case* c = (Case*)caseArray.data[i];
+            vcases[c->index].second.push_back(DtoConstUint(i));
+            inits.push_back(c->str->toConstElem(p));
+        }
+        // build static array for ptr or final array
+        const llvm::Type* elemTy = DtoType(condition->type);
+        const llvm::ArrayType* arrTy = llvm::ArrayType::get(elemTy, inits.size());
+        llvm::Constant* arrInit = llvm::ConstantArray::get(arrTy, inits);
+        llvm::GlobalVariable* arr = new llvm::GlobalVariable(arrTy, true, llvm::GlobalValue::InternalLinkage, arrInit, "string_switch_table_data", gIR->module);
+
+        const llvm::Type* elemPtrTy = llvm::PointerType::get(elemTy);
+        llvm::Constant* arrPtr = llvm::ConstantExpr::getBitCast(arr, elemPtrTy);
+
+        // build the static table
+        std::vector<const llvm::Type*> types;
+        types.push_back(DtoSize_t());
+        types.push_back(elemPtrTy);
+        const llvm::StructType* sTy = llvm::StructType::get(types);
+        std::vector<llvm::Constant*> sinits;
+        sinits.push_back(DtoConstSize_t(inits.size()));
+        sinits.push_back(arrPtr);
+        llvm::Constant* sInit = llvm::ConstantStruct::get(sTy, sinits);
+
+        switchTable = new llvm::GlobalVariable(sTy, true, llvm::GlobalValue::InternalLinkage, sInit, "string_switch_table", gIR->module);
+    }
+
     // default
     llvm::BasicBlock* defbb = 0;
     if (!hasNoDefault) {
@@ -573,9 +660,17 @@
     llvm::BasicBlock* endbb = new llvm::BasicBlock("switchend", p->topfunc(), oldend);
 
     // condition var
-    DValue* cond = condition->toElem(p);
-    llvm::SwitchInst* si = new llvm::SwitchInst(cond->getRVal(), defbb ? defbb : endbb, cases->dim, p->scopebb());
-    delete cond;
+    llvm::Value* condVal;
+    // integral switch
+    if (condition->type->isintegral()) {
+        DValue* cond = condition->toElem(p);
+        condVal = cond->getRVal();
+    }
+    // string switch
+    else {
+        condVal = call_string_switch_runtime(switchTable, condition);
+    }
+    llvm::SwitchInst* si = new llvm::SwitchInst(condVal, defbb ? defbb : endbb, cases->dim, p->scopebb());
 
     // add the cases
     size_t n = vcases.size();