diff dmd/PowExp.d @ 137:09c858522d55

merge
author Trass3r
date Mon, 13 Sep 2010 23:29:00 +0200
parents af1bebfd96a4
children e7769d53e750
line wrap: on
line diff
--- a/dmd/PowExp.d	Mon Sep 13 23:27:38 2010 +0200
+++ b/dmd/PowExp.d	Mon Sep 13 23:29:00 2010 +0200
@@ -12,6 +12,18 @@
 import dmd.DotIdExp;
 import dmd.CallExp;
 import dmd.ErrorExp;
+import dmd.CommaExp;
+import dmd.AndExp;
+import dmd.CondExp;
+import dmd.IntegerExp;
+import dmd.Type;
+import dmd.Lexer;
+import dmd.VarDeclaration;
+import dmd.ExpInitializer;
+import dmd.VarExp;
+import dmd.DeclarationExp;
+import dmd.MulExp;
+import dmd.WANT;
 
 version(DMDV2) {
 
@@ -22,65 +34,121 @@
         super(loc, TOK.TOKpow, PowExp.sizeof, e1, e2);
     }
         
-    Expression semantic(Scope sc)
+    override Expression semantic(Scope sc)
     {
         Expression e;
 
         if (type)
 	        return this;
 
+        //printf("PowExp::semantic() %s\n", toChars());
         BinExp.semanticp(sc);
         e = op_overload(sc);
         if (e)
 	        return e;
 
-        static int importMathChecked = 0;
-        if (!importMathChecked)
-        {
-	        importMathChecked = 1;
-	        for (int i = 0; i < Module.amodules.dim; i++)
-	        {
-                auto mi = cast(Module)Module.amodules.data[i];
-	            //printf("\t[%d] %s\n", i, mi->toChars());
-	            if (mi.ident == Id.math &&
-		        mi.parent.ident == Id.std &&
-		        !mi.parent.parent)
-		        goto L1;
-	        }
-	        error("must import std.math to use ^^ operator");
-
-            L1: ;
-        }
-
         assert(e1.type && e2.type);
         if ( (e1.type.isintegral() || e1.type.isfloating()) &&
 	     (e2.type.isintegral() || e2.type.isfloating()))
         {
-	        // For built-in numeric types, there are three cases:
-	        // x ^^ 1   ----> x
-	        // x ^^ 0.5 ----> sqrt(x)
-	        // x ^^ y   ----> pow(x, y)
+	        // For built-in numeric types, there are several cases.
 	        // TODO: backend support, especially for  e1 ^^ 2.
+            
 	        bool wantSqrt = false;	
+        	e1 = e1.optimize(0);
 	        e2 = e2.optimize(0);
-	        if ((e2.op == TOK.TOKfloat64 && e2.toReal() == 1.0) ||
-	            (e2.op == TOK.TOKint64 && e2.toInteger() == 1))
+	        	
+	        // Replace 1 ^^ x or 1.0^^x by (x, 1)
+	        if ((e1.op == TOK.TOKint64 && e1.toInteger() == 1) ||
+		        (e1.op == TOK.TOKfloat64 && e1.toReal() == 1.0))
+	        {
+	            typeCombine(sc);
+	            e = new CommaExp(loc, e2, e1);
+	            e = e.semantic(sc);
+	            return e;
+ 	        }
+	        // Replace -1 ^^ x by (x&1) ? -1 : 1, where x is integral
+	        if (e2.type.isintegral() && e1.op == TOKint64 && cast(long)e1.toInteger() == -1)
+	        {
+	            typeCombine(sc);
+	            Type resultType = type;
+	            e = new AndExp(loc, e2, new IntegerExp(loc, 1, e2.type));
+	            e = new CondExp(loc, e, new IntegerExp(loc, -1, resultType), new IntegerExp(loc, 1, resultType));
+	            e = e.semantic(sc);
+	            return e;
+	        }
+	        // All other negative integral powers are illegal
+	        if ((e1.type.isintegral()) && (e2.op == TOK.TOKint64) && cast(long)e2.toInteger() < 0)
 	        {
-	            return e1;  // Replace x ^^ 1 with x.
+	            error("cannot raise %s to a negative integer power. Did you mean (cast(real)%s)^^%s ?",
+		        e1.type.toBasetype().toChars(), e1.toChars(), e2.toChars());
+	            return new ErrorExp();
+	        }
+	
+	        // Deal with x^^2, x^^3 immediately, since they are of practical importance.
+	        // Don't bother if x is a literal, since it will be constant-folded anyway.
+	        if ( (  (e2.op == TOK.TOKint64 && (e2.toInteger() == 2 || e2.toInteger() == 3)) 
+	             ||	(e2.op == TOK.TOKfloat64 && (e2.toReal() == 2.0 || e2.toReal() == 3.0))
+	             ) && (e1.op == TOK.TOKint64 || e1.op == TOK.TOKfloat64)
+	           )
+	        {
+	            typeCombine(sc);
+	            // Replace x^^2 with (tmp = x, tmp*tmp)
+	            // Replace x^^3 with (tmp = x, tmp*tmp*tmp) 
+	            Identifier idtmp = Lexer.uniqueId("__tmp");
+	            VarDeclaration tmp = new VarDeclaration(loc, e1.type.toBasetype(), idtmp, new ExpInitializer(Loc(0), e1));
+	            VarExp ve = new VarExp(loc, tmp);
+	            Expression ae = new DeclarationExp(loc, tmp);
+	            Expression me = new MulExp(loc, ve, ve);
+	            if ( (e2.op == TOK.TOKint64 && e2.toInteger() == 3) 
+	              || (e2.op == TOK.TOKfloat64 && e2.toReal() == 3.0))
+		        me = new MulExp(loc, me, ve);
+	            e = new CommaExp(loc, ae, me);
+	            e = e.semantic(sc);
+	            return e;
 	        }
 
-	        e = new IdentifierExp(loc, Id.empty);
-	        e = new DotIdExp(loc, e, Id.std);
-	        e = new DotIdExp(loc, e, Id.math);
-	        if (e2.op == TOKfloat64 && e2.toReal() == 0.5)
-	        {   // Replace e1 ^^ 0.5 with .std.math.sqrt(x)
-	            e = new CallExp(loc, new DotIdExp(loc, e, Id._sqrt), e1);
+	        static int importMathChecked = 0;
+	        if (!importMathChecked)
+	        {
+	            importMathChecked = 1;
+	            for (int i = 0; i < Module.amodules.dim; i++)
+	            {
+                    auto mi = cast(Module)Module.amodules.data[i];
+		            //printf("\t[%d] %s\n", i, mi->toChars());
+		            if (mi.ident == Id.math &&
+		                mi.parent.ident == Id.std &&
+		                !mi.parent.parent)
+		                goto L1;
+	            }
+	            error("must import std.math to use ^^ operator");
+
+	         L1: ;
 	        }
-	        else 
-	        {   // Replace e1 ^^ e2 with .std.math.pow(e1, e2)
- 	            e = new CallExp(loc, new DotIdExp(loc, e, Id._pow), e1, e2);	
-	        }	
-	        e = e.semantic(sc);
+ 
+ 	        e = new IdentifierExp(loc, Id.empty);
+ 	        e = new DotIdExp(loc, e, Id.std);
+ 	        e = new DotIdExp(loc, e, Id.math);
+ 	        if (e2.op == TOK.TOKfloat64 && e2.toReal() == 0.5)
+ 	        {   // Replace e1 ^^ 0.5 with .std.math.sqrt(x)
+	            typeCombine(sc);
+ 	            e = new CallExp(loc, new DotIdExp(loc, e, Id._sqrt), e1);
+ 	        }
+ 	        else 
+	        {
+	            // Replace e1 ^^ e2 with .std.math.pow(e1, e2)
+	            // We don't combine the types if raising to an integer power (because
+	            // integer powers are treated specially by std.math.pow).
+	            if (!e2.type.isintegral())
+		            typeCombine(sc);
+	            e = new CallExp(loc, new DotIdExp(loc, e, Id._pow), e1, e2);	
+ 	        }	
+ 	        e = e.semantic(sc);
+	        // Always constant fold integer powers of literals. This will run the interpreter
+	        // on .std.math.pow
+	        if ((e1.op == TOK.TOKfloat64 || e1.op == TOK.TOKint64) && (e2.op == TOK.TOKint64))
+	            e = e.optimize(WANT.WANTvalue | WANT.WANTinterpret);
+
 	        return e;
         }
         error("%s ^^ %s is not supported", e1.type.toChars(), e2.type.toChars() );
@@ -89,12 +157,12 @@
    
 
     // For operator overloading
-    Identifier opId()
+    override Identifier opId()
     {
         return Id.pow;
     }
     
-    Identifier opId_r()
+    override Identifier opId_r()
     {
         return Id.pow_r;
     }