Skip to content

Commit dab9126

Browse files
Merge pull request #159 from saschatimme/st/more-constant-computation
Simplify constants in powers and sums
2 parents 42aafd5 + 02c5397 commit dab9126

File tree

3 files changed

+29
-1
lines changed

3 files changed

+29
-1
lines changed

src/simplify.jl

+16
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,14 @@ function _simplify_constants(O::Operation, shorten_tree)
5858
return O
5959
end
6060

61+
if O.op === (^) && length(O.args) == 2 && iszero(O.args[2])
62+
return Constant(1)
63+
end
64+
65+
if O.op === (^) && length(O.args) == 2 && isone(O.args[2])
66+
return O.args[1]
67+
end
68+
6169
if O.op === (+) && any(iszero, O.args)
6270
# If there are Constant(0)s in a big `+` expression, get rid of them
6371
args = filter(!iszero, O.args)
@@ -67,6 +75,14 @@ function _simplify_constants(O::Operation, shorten_tree)
6775
return Operation(O.op, args)
6876
end
6977

78+
if (O.op === (-) || O.op === (+) || O.op === (*)) && all(is_constant, O.args) && !isempty(O.args)
79+
v = O.args[1].value
80+
for i in 2:length(O.args)
81+
v = O.op(v, O.args[i].value)
82+
end
83+
return Constant(v)
84+
end
85+
7086
(O.op, length(O.args)) === (identity, 1) && return O.args[1]
7187

7288
(O.op, length(O.args)) === (-, 1) && return Operation(*, Expression[-1, O.args[1]])

test/derivatives.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -59,4 +59,4 @@ jac = calculate_jacobian(sys)
5959

6060
@test isequal(expand_derivatives(D(2t)), 2)
6161
@test isequal(expand_derivatives(D(2x)), 2D(x))
62-
@test_broken isequal(expand_derivatives(D(x^2)), simplify_constants(2 * x * D(x)))
62+
@test isequal(expand_derivatives(D(x^2)), simplify_constants(2 * x * D(x)))

test/simplify.jl

+12
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,15 @@ identity_op = Operation(identity,[x])
1616
minus_op = -x
1717
@test isequal(simplify_constants(minus_op), -1*x)
1818
simplify_constants(minus_op)
19+
20+
@variables x
21+
22+
@test simplified_expr(expand_derivatives(Differential(x)((x-2)^2))) == :((x-2) * 2)
23+
@test simplified_expr(expand_derivatives(Differential(x)((x-2)^3))) == :((x-2)^2 * 3)
24+
@test simplified_expr(simplify_constants(x+2+3)) == :(x + 5)
25+
26+
d1 = Differential(x)((x-2)^2)
27+
d2 = Differential(x)(d1)
28+
d3 = Differential(x)(d2)
29+
@test simplified_expr(expand_derivatives(d3)) == :(0)
30+
@test simplified_expr(simplify_constants(x^0)) == :(1)

0 commit comments

Comments
 (0)