changeset 122:36ab367572df trunk

[svn r126] String switch is now implemented. A few other fixes.
author lindquist
date Tue, 27 Nov 2007 09:19:07 +0100
parents 9c79b61fb638
children 7f9a0a58394b
files gen/statements.cpp gen/toir.cpp llvmdc.kdevelop.filelist lphobos/build.sh lphobos/internal/switch.d test/switch3.d
diffstat 6 files changed, 572 insertions(+), 16 deletions(-) [+]
line wrap: on
line diff
--- a/gen/statements.cpp	Tue Nov 27 03:09:36 2007 +0100
+++ b/gen/statements.cpp	Tue Nov 27 09:19:07 2007 +0100
@@ -526,9 +526,52 @@
 
 //////////////////////////////////////////////////////////////////////////////
 
+// used to build the sorted list of cases
+struct Case : Object
+{
+    StringExp* str;
+    size_t index;
+
+    Case(StringExp* s, size_t i) {
+        str = s;
+        index = i;
+    }
+
+    int compare(Object *obj) {
+        Case* c2 = (Case*)obj;
+        return str->compare(c2->str);
+    }
+};
+
+static llvm::Value* call_string_switch_runtime(llvm::GlobalVariable* table, Expression* e)
+{
+    Type* dt = DtoDType(e->type);
+    Type* dtnext = DtoDType(dt->next);
+    TY ty = dtnext->ty;
+    const char* fname;
+    if (ty == Tchar) {
+        fname = "_d_switch_string";
+    }
+    else if (ty == Twchar) {
+        fname = "_d_switch_ustring";
+    }
+    else if (ty == Tdchar) {
+        fname = "_d_switch_dstring";
+    }
+    else {
+        assert(0 && "not char/wchar/dchar");
+    }
+
+    llvm::Function* fn = LLVM_D_GetRuntimeFunction(gIR->module, fname);
+    std::vector<llvm::Value*> args;
+    args.push_back(table);
+    args.push_back(e->toElem(gIR)->getRVal());
+    return gIR->ir->CreateCall(fn, args.begin(), args.end(), "tmp");
+}
+
 void SwitchStatement::toIR(IRState* p)
 {
-    Logger::println("SwitchStatement::toIR(): %s", toChars());
+    Logger::println("SwitchStatement::toIR()");
     LOG_SCOPE;
 
     llvm::BasicBlock* oldend = gIR->scopeend();
@@ -537,24 +580,32 @@
     typedef std::pair<llvm::BasicBlock*, std::vector<llvm::ConstantInt*> > CasePair;
     std::vector<CasePair> vcases;
     std::vector<Statement*> vbodies;
+    Array caseArray;
     for (int i=0; i<cases->dim; ++i)
     {
         CaseStatement* cs = (CaseStatement*)cases->data[i];
 
-        // create the case bb with a nice label
-        std::string lblname("case"+std::string(cs->exp->toChars()));
+        std::string lblname("case");
         llvm::BasicBlock* bb = new llvm::BasicBlock(lblname, p->topfunc(), oldend);
 
         std::vector<llvm::ConstantInt*> tmp;
         CaseStatement* last;
+        bool first = true;
         do {
-            // get the case value
-            DValue* e = cs->exp->toElem(p);
-            DConstValue* ce = e->isConst();
-            assert(ce);
-            llvm::ConstantInt* ec = isaConstantInt(ce->c);
-            assert(ec);
-            tmp.push_back(ec);
+            // integral case
+            if (cs->exp->type->isintegral()) {
+                llvm::Constant* c = cs->exp->toConstElem(p);
+                tmp.push_back(isaConstantInt(c));
+            }
+            // string case
+            else {
+                assert(cs->exp->op == TOKstring);
+                // for string switches this is unfortunately necessary or there will be duplicates in the list
+                if (first) {
+                    caseArray.push(new Case((StringExp*)cs->exp, i));
+                    first = false;
+                }
+            }
             last = cs;
         }
         while (cs = cs->statement->isCaseStatement());
@@ -563,6 +614,42 @@
         vbodies.push_back(last->statement);
     }
 
+    // string switch?
+    llvm::GlobalVariable* switchTable = 0;
+    if (!condition->type->isintegral())
+    {
+        // first sort it
+        caseArray.sort();
+        // iterate and add indices to cases
+        std::vector<llvm::Constant*> inits;
+        for (size_t i=0; i<caseArray.dim; ++i)
+        {
+            Case* c = (Case*)caseArray.data[i];
+            vcases[c->index].second.push_back(DtoConstUint(i));
+            inits.push_back(c->str->toConstElem(p));
+        }
+        // build static array for ptr or final array
+        const llvm::Type* elemTy = DtoType(condition->type);
+        const llvm::ArrayType* arrTy = llvm::ArrayType::get(elemTy, inits.size());
+        llvm::Constant* arrInit = llvm::ConstantArray::get(arrTy, inits);
+        llvm::GlobalVariable* arr = new llvm::GlobalVariable(arrTy, true, llvm::GlobalValue::InternalLinkage, arrInit, "string_switch_table_data", gIR->module);
+
+        const llvm::Type* elemPtrTy = llvm::PointerType::get(elemTy);
+        llvm::Constant* arrPtr = llvm::ConstantExpr::getBitCast(arr, elemPtrTy);
+
+        // build the static table
+        std::vector<const llvm::Type*> types;
+        types.push_back(DtoSize_t());
+        types.push_back(elemPtrTy);
+        const llvm::StructType* sTy = llvm::StructType::get(types);
+        std::vector<llvm::Constant*> sinits;
+        sinits.push_back(DtoConstSize_t(inits.size()));
+        sinits.push_back(arrPtr);
+        llvm::Constant* sInit = llvm::ConstantStruct::get(sTy, sinits);
+
+        switchTable = new llvm::GlobalVariable(sTy, true, llvm::GlobalValue::InternalLinkage, sInit, "string_switch_table", gIR->module);
+    }
+
     // default
     llvm::BasicBlock* defbb = 0;
     if (!hasNoDefault) {
@@ -573,9 +660,17 @@
     llvm::BasicBlock* endbb = new llvm::BasicBlock("switchend", p->topfunc(), oldend);
 
     // condition var
-    DValue* cond = condition->toElem(p);
-    llvm::SwitchInst* si = new llvm::SwitchInst(cond->getRVal(), defbb ? defbb : endbb, cases->dim, p->scopebb());
-    delete cond;
+    llvm::Value* condVal;
+    // integral switch
+    if (condition->type->isintegral()) {
+        DValue* cond = condition->toElem(p);
+        condVal = cond->getRVal();
+    }
+    // string switch
+    else {
+        condVal = call_string_switch_runtime(switchTable, condition);
+    }
+    llvm::SwitchInst* si = new llvm::SwitchInst(condVal, defbb ? defbb : endbb, cases->dim, p->scopebb());
 
     // add the cases
     size_t n = vcases.size();
--- a/gen/toir.cpp	Tue Nov 27 03:09:36 2007 +0100
+++ b/gen/toir.cpp	Tue Nov 27 09:19:07 2007 +0100
@@ -92,11 +92,18 @@
         Logger::println("AliasDeclaration - no work");
         // do nothing
     }
+    // enum
     else if (EnumDeclaration* e = declaration->isEnumDeclaration())
     {
         Logger::println("EnumDeclaration - no work");
         // do nothing
     }
+    // class
+    else if (ClassDeclaration* e = declaration->isClassDeclaration())
+    {
+        Logger::println("ClassDeclaration");
+        DtoForceConstInitDsymbol(e);
+    }
     // unsupported declaration
     else
     {
@@ -364,7 +371,7 @@
     Type* dtype = DtoDType(type);
     Type* cty = DtoDType(dtype->next);
 
-    const llvm::Type* ct = DtoType(dtype->next);
+    const llvm::Type* ct = DtoType(cty);
     //printf("ct = %s\n", type->next->toChars());
     const llvm::ArrayType* at = llvm::ArrayType::get(ct,len+1);
 
--- a/llvmdc.kdevelop.filelist	Tue Nov 27 03:09:36 2007 +0100
+++ b/llvmdc.kdevelop.filelist	Tue Nov 27 09:19:07 2007 +0100
@@ -153,6 +153,7 @@
 lphobos/internal/mem.d
 lphobos/internal/objectimpl.d
 lphobos/internal/qsort2.d
+lphobos/internal/switch.d
 lphobos/llvm
 lphobos/llvm/intrinsic.d
 lphobos/llvm/va_list.d
@@ -327,6 +328,7 @@
 test/bug75.d
 test/bug76.d
 test/bug77.d
+test/bug78.d
 test/bug8.d
 test/bug9.d
 test/c.d
@@ -432,6 +434,7 @@
 test/structs7.d
 test/switch1.d
 test/switch2.d
+test/switch3.d
 test/sync1.d
 test/templ1.d
 test/templ2.d
--- a/lphobos/build.sh	Tue Nov 27 03:09:36 2007 +0100
+++ b/lphobos/build.sh	Tue Nov 27 09:19:07 2007 +0100
@@ -29,10 +29,11 @@
 llvmdc internal/cast.d -c -odobj || exit 1
 llvm-link -f -o=../lib/llvmdcore.bc obj/cast.bc ../lib/llvmdcore.bc || exit 1
 
-echo "compiling string foreach runtime support"
+echo "compiling string foreach/switch runtime support"
 llvmdc internal/aApply.d -c -odobj || exit 1
 llvmdc internal/aApplyR.d -c -odobj || exit 1
-llvm-link -f -o=../lib/llvmdcore.bc obj/aApply.bc obj/aApplyR.bc ../lib/llvmdcore.bc || exit 1
+llvmdc internal/switch.d -c -odobj || exit 1
+llvm-link -f -o=../lib/llvmdcore.bc obj/aApply.bc obj/aApplyR.bc obj/switch.bc ../lib/llvmdcore.bc || exit 1
 
 echo "compiling array runtime support"
 llvmdc internal/qsort2.d -c -odobj || exit 1
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/lphobos/internal/switch.d	Tue Nov 27 09:19:07 2007 +0100
@@ -0,0 +1,426 @@
+/*
+ *  Copyright (C) 2004-2007 by Digital Mars, www.digitalmars.com
+ *  Written by Walter Bright
+ *
+ *  This software is provided 'as-is', without any express or implied
+ *  warranty. In no event will the authors be held liable for any damages
+ *  arising from the use of this software.
+ *
+ *  Permission is granted to anyone to use this software for any purpose,
+ *  including commercial applications, and to alter it and redistribute it
+ *  freely, in both source and binary form, subject to the following
+ *  restrictions:
+ *
+ *  o  The origin of this software must not be misrepresented; you must not
+ *     claim that you wrote the original software. If you use this software
+ *     in a product, an acknowledgment in the product documentation would be
+ *     appreciated but is not required.
+ *  o  Altered source versions must be plainly marked as such, and must not
+ *     be misrepresented as being the original software.
+ *  o  This notice may not be removed or altered from any source
+ *     distribution.
+ */
+
+
+import std.c.stdio;
+import std.c.string;
+import std.string;
+
+/******************************************************
+ * Support for switch statements switching on strings.
+ * Input:
+ *	table[]		sorted array of strings generated by compiler
+ *	ca		string to look up in table
+ * Output:
+ *	result		index of match in table[]
+ *			-1 if not in table
+ */
+
+extern (C):
+
+int _d_switch_string(char[][] table, char[] ca)
+    in
+    {
+	//printf("in _d_switch_string()\n");
+	assert(table.length >= 0);
+	assert(ca.length >= 0);
+
+	// Make sure table[] is sorted correctly
+	int j;
+
+	for (j = 1; j < table.length; j++)
+	{
+	    int len1 = table[j - 1].length;
+	    int len2 = table[j].length;
+
+	    assert(len1 <= len2);
+	    if (len1 == len2)
+	    {
+		int ci;
+
+		ci = memcmp(table[j - 1].ptr, table[j].ptr, len1);
+		assert(ci < 0);	// ci==0 means a duplicate
+	    }
+	}
+    }
+    out (result)
+    {
+	int i;
+	int cj;
+
+	//printf("out _d_switch_string()\n");
+	if (result == -1)
+	{
+	    // Not found
+	    for (i = 0; i < table.length; i++)
+	    {
+		if (table[i].length == ca.length)
+		{   cj = memcmp(table[i].ptr, ca.ptr, ca.length);
+		    assert(cj != 0);
+		}
+	    }
+	}
+	else
+	{
+	    assert(0 <= result && result < table.length);
+	    for (i = 0; 1; i++)
+	    {
+		assert(i < table.length);
+		if (table[i].length == ca.length)
+		{
+		    cj = memcmp(table[i].ptr, ca.ptr, ca.length);
+		    if (cj == 0)
+		    {
+			assert(i == result);
+			break;
+		    }
+		}
+	    }
+	}
+    }
+    body
+    {
+	//printf("body _d_switch_string(%.*s)\n", ca);
+	int low;
+	int high;
+	int mid;
+	int c;
+	char[] pca;
+
+	low = 0;
+	high = table.length;
+
+	version (none)
+	{
+	    // Print table
+	    printf("ca[] = '%s'\n", cast(char *)ca);
+	    for (mid = 0; mid < high; mid++)
+	    {
+		pca = table[mid];
+		printf("table[%d] = %d, '%.*s'\n", mid, pca.length, pca);
+	    }
+	}
+	if (high &&
+	    ca.length >= table[0].length &&
+	    ca.length <= table[high - 1].length)
+	{
+	    // Looking for 0 length string, which would only be at the beginning
+	    if (ca.length == 0)
+		return 0;
+
+	    char c1 = ca[0];
+
+	    // Do binary search
+	    while (low < high)
+	    {
+		mid = (low + high) >> 1;
+		pca = table[mid];
+		c = ca.length - pca.length;
+		if (c == 0)
+		{
+		    c = cast(ubyte)c1 - cast(ubyte)pca[0];
+		    if (c == 0)
+		    {
+			c = memcmp(ca.ptr, pca.ptr, ca.length);
+			if (c == 0)
+			{   //printf("found %d\n", mid);
+			    return mid;
+			}
+		    }
+		}
+		if (c < 0)
+		{
+		    high = mid;
+		}
+		else
+		{
+		    low = mid + 1;
+		}
+	    }
+	}
+
+	//printf("not found\n");
+	return -1;		// not found
+    }
+
+unittest
+{
+    switch (cast(char []) "c")
+    {
+         case "coo":
+         default:
+             break;
+    }
+}
+
+/**********************************
+ * Same thing, but for wide chars.
+ */
+
+int _d_switch_ustring(wchar[][] table, wchar[] ca)
+    in
+    {
+	//printf("in _d_switch_ustring()\n");
+	assert(table.length >= 0);
+	assert(ca.length >= 0);
+
+	// Make sure table[] is sorted correctly
+	int j;
+
+	for (j = 1; j < table.length; j++)
+	{
+	    int len1 = table[j - 1].length;
+	    int len2 = table[j].length;
+
+	    assert(len1 <= len2);
+	    if (len1 == len2)
+	    {
+		int c;
+
+		c = memcmp(table[j - 1].ptr, table[j].ptr, len1 * wchar.sizeof);
+		assert(c < 0);	// c==0 means a duplicate
+	    }
+	}
+    }
+    out (result)
+    {
+	int i;
+	int c;
+
+	//printf("out _d_switch_string()\n");
+	if (result == -1)
+	{
+	    // Not found
+	    for (i = 0; i < table.length; i++)
+	    {
+		if (table[i].length == ca.length)
+		{   c = memcmp(table[i].ptr, ca.ptr, ca.length * wchar.sizeof);
+		    assert(c != 0);
+		}
+	    }
+	}
+	else
+	{
+	    assert(0 <= result && result < table.length);
+	    for (i = 0; 1; i++)
+	    {
+		assert(i < table.length);
+		if (table[i].length == ca.length)
+		{
+		    c = memcmp(table[i].ptr, ca.ptr, ca.length * wchar.sizeof);
+		    if (c == 0)
+		    {
+			assert(i == result);
+			break;
+		    }
+		}
+	    }
+	}
+    }
+    body
+    {
+	//printf("body _d_switch_ustring()\n");
+	int low;
+	int high;
+	int mid;
+	int c;
+	wchar[] pca;
+
+	low = 0;
+	high = table.length;
+
+    /*
+	// Print table
+	wprintf("ca[] = '%.*s'\n", ca);
+	for (mid = 0; mid < high; mid++)
+	{
+	    pca = table[mid];
+	    wprintf("table[%d] = %d, '%.*s'\n", mid, pca.length, pca);
+	}
+    */
+
+	// Do binary search
+	while (low < high)
+	{
+	    mid = (low + high) >> 1;
+	    pca = table[mid];
+	    c = ca.length - pca.length;
+	    if (c == 0)
+	    {
+		c = memcmp(ca.ptr, pca.ptr, ca.length * wchar.sizeof);
+		if (c == 0)
+		{   //printf("found %d\n", mid);
+		    return mid;
+		}
+	    }
+	    if (c < 0)
+	    {
+		high = mid;
+	    }
+	    else
+	    {
+		low = mid + 1;
+	    }
+	}
+	//printf("not found\n");
+	return -1;		// not found
+    }
+
+
+unittest
+{
+    switch (cast(wchar []) "c")
+    {
+         case "coo":
+         default:
+             break;
+    }
+}
+
+
+/**********************************
+ * Same thing, but for wide chars.
+ */
+
+int _d_switch_dstring(dchar[][] table, dchar[] ca)
+    in
+    {
+	//printf("in _d_switch_dstring()\n");
+	assert(table.length >= 0);
+	assert(ca.length >= 0);
+
+	// Make sure table[] is sorted correctly
+	int j;
+
+	for (j = 1; j < table.length; j++)
+	{
+	    int len1 = table[j - 1].length;
+	    int len2 = table[j].length;
+
+	    assert(len1 <= len2);
+	    if (len1 == len2)
+	    {
+		int c;
+
+		c = memcmp(table[j - 1].ptr, table[j].ptr, len1 * dchar.sizeof);
+		assert(c < 0);	// c==0 means a duplicate
+	    }
+	}
+    }
+    out (result)
+    {
+	int i;
+	int c;
+
+	//printf("out _d_switch_string()\n");
+	if (result == -1)
+	{
+	    // Not found
+	    for (i = 0; i < table.length; i++)
+	    {
+		if (table[i].length == ca.length)
+		{   c = memcmp(table[i].ptr, ca.ptr, ca.length * dchar.sizeof);
+		    assert(c != 0);
+		}
+	    }
+	}
+	else
+	{
+	    assert(0 <= result && result < table.length);
+	    for (i = 0; 1; i++)
+	    {
+		assert(i < table.length);
+		if (table[i].length == ca.length)
+		{
+		    c = memcmp(table[i].ptr, ca.ptr, ca.length * dchar.sizeof);
+		    if (c == 0)
+		    {
+			assert(i == result);
+			break;
+		    }
+		}
+	    }
+	}
+    }
+    body
+    {
+	//printf("body _d_switch_ustring()\n");
+	int low;
+	int high;
+	int mid;
+	int c;
+	dchar[] pca;
+
+	low = 0;
+	high = table.length;
+
+    /*
+	// Print table
+	wprintf("ca[] = '%.*s'\n", ca);
+	for (mid = 0; mid < high; mid++)
+	{
+	    pca = table[mid];
+	    wprintf("table[%d] = %d, '%.*s'\n", mid, pca.length, pca);
+	}
+    */
+
+	// Do binary search
+	while (low < high)
+	{
+	    mid = (low + high) >> 1;
+	    pca = table[mid];
+	    c = ca.length - pca.length;
+	    if (c == 0)
+	    {
+		c = memcmp(ca.ptr, pca.ptr, ca.length * dchar.sizeof);
+		if (c == 0)
+		{   //printf("found %d\n", mid);
+		    return mid;
+		}
+	    }
+	    if (c < 0)
+	    {
+		high = mid;
+	    }
+	    else
+	    {
+		low = mid + 1;
+	    }
+	}
+	//printf("not found\n");
+	return -1;		// not found
+    }
+
+
+unittest
+{
+    switch (cast(dchar []) "c")
+    {
+         case "coo":
+         default:
+             break;
+    }
+}
+
+
+
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/test/switch3.d	Tue Nov 27 09:19:07 2007 +0100
@@ -0,0 +1,24 @@
+module switch3;
+
+void main()
+{
+    char[] str = "hello";
+    int i;
+    switch(str)
+    {
+    case "world":
+        i = 1;
+        assert(0);
+    case "hello":
+        i = 2;
+        break;
+    case "a","b","c":
+        i = 3;
+        assert(0);
+    default:
+        i = 4;
+        assert(0);
+    }
+    assert(i == 2);
+    printf("SUCCESS\n");
+}