changeset 1213:9430d4959ab4

Fix a bug in nested context code that occured when calling a function nested in the outermost scope with a context frame from a function using a more nested context frame.
author Frits van Bommel <fvbommel wxs.nl>
date Mon, 13 Apr 2009 12:19:18 +0200
parents df2227fdc860
children 7e5547d8e59f
files gen/nested.cpp gen/tocall.cpp tests/mini/nested21.d
diffstat 3 files changed, 51 insertions(+), 5 deletions(-) [+]
line wrap: on
line diff
--- a/gen/nested.cpp	Mon Apr 13 04:09:08 2009 +0200
+++ b/gen/nested.cpp	Mon Apr 13 12:19:18 2009 +0200
@@ -50,6 +50,8 @@
 ////////////////////////////////////////////////////////////////////////////////////////*/
 
 static FuncDeclaration* getParentFunc(Dsymbol* sym) {
+    if (!sym)
+        return NULL;
     Dsymbol* parent = sym->parent;
     assert(parent);
     while (parent && !parent->isFuncDeclaration())
@@ -178,27 +180,54 @@
     LOG_SCOPE;
 
     IrFunction* irfunc = gIR->func();
+    bool fromParent = true;
 
+    LLValue* val;
     // if this func has its own vars that are accessed by nested funcs
     // use its own context
-    if (irfunc->nestedVar)
-        return irfunc->nestedVar;
+    if (irfunc->nestedVar) {
+        val = irfunc->nestedVar;
+        fromParent = false;
+    }
     // otherwise, it may have gotten a context from the caller
     else if (irfunc->nestArg)
-        return irfunc->nestArg;
+        val = irfunc->nestArg;
     // or just have a this argument
     else if (irfunc->thisArg)
     {
         ClassDeclaration* cd = irfunc->decl->isMember2()->isClassDeclaration();
         if (!cd || !cd->vthis)
             return getNullPtr(getVoidPtrType());
-        LLValue* val = DtoLoad(irfunc->thisArg);
-        return DtoLoad(DtoGEPi(val, 0,cd->vthis->ir.irField->index, ".vthis"));
+        val = DtoLoad(irfunc->thisArg);
+        val = DtoLoad(DtoGEPi(val, 0,cd->vthis->ir.irField->index, ".vthis"));
     }
     else
     {
         return getNullPtr(getVoidPtrType());
     }
+    if (nestedCtx == NCHybrid) {
+        // If sym is a nested function, and its parent elided the context list but the
+        // context we got didn't, we need to index to the first frame.
+        if (FuncDeclaration* fd = getParentFunc(sym->isFuncDeclaration())) {
+            Logger::println("For nested function, parent is %s", fd->toChars());
+            FuncDeclaration* ctxfd = irfunc->decl;
+            Logger::println("Current function is %s", ctxfd->toChars());
+            if (fromParent) {
+                ctxfd = getParentFunc(ctxfd);
+                assert(ctxfd && "Context from outer function, but no outer function?");
+            }
+            Logger::println("Context is from %s", ctxfd->toChars());
+            if (fd->ir.irFunc->elidedCtxList && !ctxfd->ir.irFunc->elidedCtxList) {
+                Logger::println("Adjusting to remove context frame list", ctxfd->toChars());
+                val = DtoBitCast(val, LLPointerType::getUnqual(ctxfd->ir.irFunc->framesType));
+                val = DtoGEPi(val, 0, 0);
+                val = DtoAlignedLoad(val, (std::string(".frame.") + fd->toChars()).c_str());
+            }
+        }
+    }
+    Logger::cout() << "result = " << *val << '\n';
+    Logger::cout() << "of type " << *val->getType() << '\n';
+    return val;
 }
 
 void DtoCreateNestedContext(FuncDeclaration* fd) {
--- a/gen/tocall.cpp	Mon Apr 13 04:09:08 2009 +0200
+++ b/gen/tocall.cpp	Mon Apr 13 12:19:18 2009 +0200
@@ -452,6 +452,7 @@
                 if (Logger::enabled())
                 {
                     Logger::cout() << "arg:     " << *arg << '\n';
+                    Logger::cout() << "of type: " << *arg->getType() << '\n';
                     Logger::cout() << "expects: " << *callableTy->getParamType(j) << '\n';
                 }
             #endif
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/tests/mini/nested21.d	Mon Apr 13 12:19:18 2009 +0200
@@ -0,0 +1,16 @@
+module nested21;
+
+extern(C) int printf(char*, ...);
+
+void main() {
+    int i = 42;
+    int foo() { return i; }
+    int bar() {
+        int j = 47;
+        int baz() { return j; }
+        return foo() + baz();
+    }
+    auto result = bar();
+    printf("%d\n", result);
+    assert(result == 42 + 47);
+}