Skip to content

Commit 5a06b38

Browse files
Merge pull request #156 from JuliaDiffEq/simplified_expr
Improved direct usage
2 parents bdcf035 + 04e4c69 commit 5a06b38

File tree

7 files changed

+186
-8
lines changed

7 files changed

+186
-8
lines changed

src/ModelingToolkit.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@ module ModelingToolkit
33
export Operation, Expression
44
export calculate_jacobian, generate_jacobian, generate_function
55
export independent_variables, dependent_variables, parameters
6-
export @register
6+
export simplified_expr, eval_function
7+
export @register, @I
78
export modelingtoolkitize
89

910

@@ -87,6 +88,7 @@ include("equations.jl")
8788
include("function_registration.jl")
8889
include("simplify.jl")
8990
include("utils.jl")
91+
include("direct.jl")
9092
include("systems/diffeqs/diffeqsystem.jl")
9193
include("systems/diffeqs/first_order_transform.jl")
9294
include("systems/nonlinear/nonlinear_system.jl")

src/direct.jl

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
function gradient(O::Operation, vars::AbstractVector{Operation}; simplify = true)
2+
out = [expand_derivatives(Differential(v)(O)) for v in vars]
3+
simplify ? simplify_constants.(out) : out
4+
end
5+
6+
function jacobian(ops::AbstractVector{Operation}, vars::AbstractVector{Operation}; simplify = true)
7+
out = [expand_derivatives(Differential(v)(O)) for O in ops, v in vars]
8+
simplify ? simplify_constants.(out) : out
9+
end
10+
11+
function hessian(O::Operation, vars::AbstractVector{Operation}; simplify = true)
12+
out = [expand_derivatives(Differential(v2)(Differential(v1)(O))) for v1 in vars, v2 in vars]
13+
simplify ? simplify_constants.(out) : out
14+
end
15+
16+
function simplified_expr(O::Operation)
17+
if O isa Constant
18+
return O.value
19+
elseif isa(O.op, Differential)
20+
return :(derivative($(simplified_expr(O.args[1])),$(simplified_expr(O.op.x))))
21+
elseif isa(O.op, Variable)
22+
isempty(O.args) && return O.op.name
23+
return Expr(:call, Symbol(O.op), simplified_expr.(O.args)...)
24+
end
25+
return Expr(:call, Symbol(O.op), simplified_expr.(O.args)...)
26+
end
27+
28+
function simplified_expr(c::Constant)
29+
c.value
30+
end
31+
32+
function simplified_expr(eq::Equation)
33+
Expr(:(=), simplified_expr(eq.lhs), simplified_expr(eq.rhs))
34+
end
35+
36+
macro I(ex)
37+
name = :ICompile
38+
ret = return quote
39+
macro $(esc(name))()
40+
esc($ex)
41+
end
42+
end
43+
end

src/systems/nonlinear/nonlinear_system.jl

-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
export NonlinearSystem
22

3-
43
struct NLEq
54
rhs::Expression
65
end

src/utils.jl

+9-6
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ function flatten_expr!(x)
3131
x
3232
end
3333

34-
function build_function(rhss, vs, ps, args = (), conv = rhs -> convert(Expr, rhs); constructor=nothing)
34+
function build_function(rhss, vs, ps = (), args = (), conv = simplified_expr; constructor=nothing)
3535
_vs = map(x-> x isa Operation ? x.op : x, vs)
3636
_ps = map(x-> x isa Operation ? x.op : x, ps)
3737
var_pairs = [(u.name, :(u[$i])) for (i, u) enumerate(_vs)]
@@ -48,20 +48,23 @@ function build_function(rhss, vs, ps, args = (), conv = rhs -> convert(Expr, rhs
4848

4949
sys_expr = build_expr(:tuple, [conv(rhs) for rhs rhss])
5050
let_expr = Expr(:let, var_eqs, sys_expr)
51+
52+
fargs = ps == () ? :(u,$(args...)) : :(u,p,$(args...))
5153
quote
52-
function $fname($X,u,p,$(args...))
54+
function $fname($X,$(fargs.args...))
5355
$ip_let_expr
5456
nothing
5557
end
56-
function $fname(u,p,$(args...))
58+
function $fname($(fargs.args...))
5759
X = $let_expr
58-
T = $(constructor === nothing ? :(u isa ModelingToolkit.StaticArrays.StaticArray ? ModelingToolkit.StaticArrays.similar_type(typeof(u), eltype(X)) : x->(du=similar(u, eltype(X)); du .= x)) : constructor)
59-
T(X)
60+
T = promote_type(map(typeof,X)...)
61+
convert.(T,X)
62+
construct = $(constructor === nothing ? :(u isa ModelingToolkit.StaticArrays.StaticArray ? ModelingToolkit.StaticArrays.similar_type(typeof(u), eltype(X)) : x->(du=similar(u, T, $(size(rhss)...)); vec(du) .= x; du)) : constructor)
63+
construct(X)
6064
end
6165
end
6266
end
6367

64-
6568
is_constant(::Constant) = true
6669
is_constant(::Any) = false
6770

test/direct.jl

+127
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
using ModelingToolkit, StaticArrays, LinearAlgebra
2+
using DiffEqBase
3+
using Test
4+
5+
# Calculus
6+
@parameters t σ ρ β
7+
@variables x y z
8+
9+
eqs =*(y-x),
10+
x*-z)-y,
11+
x*y - β*z]
12+
13+
simpexpr = [
14+
:(σ * (y - x))
15+
:(x *- z) - y)
16+
:(x * y - β * z)
17+
]
18+
19+
for i in 1:3
20+
@test ModelingToolkit.simplified_expr.(eqs)[i] == simpexpr[i]
21+
@test ModelingToolkit.simplified_expr.(eqs)[i] == simpexpr[i]
22+
end
23+
24+
= ModelingToolkit.jacobian(eqs,[x,y,z])
25+
for i in 1:3
26+
= ModelingToolkit.gradient(eqs[i],[x,y,z])
27+
@test isequal(∂[i,:],∇)
28+
end
29+
@test all(isequal.(ModelingToolkit.gradient(eqs[1],[x,y,z]),[σ * -1,σ,0]))
30+
@test all(isequal.(ModelingToolkit.hessian(eqs[1],[x,y,z]),0))
31+
32+
# Function building
33+
34+
@parameters σ() ρ() β()
35+
@variables x y z
36+
eqs =*(y-x),
37+
x*-z)-y,
38+
x*y - β*z]
39+
f = eval(ModelingToolkit.build_function(eqs,[x,y,z],[σ,ρ,β]))
40+
out = [1.0,2,3]
41+
o1 = f([1.0,2,3],[1.0,2,3])
42+
f(out,[1.0,2,3],[1.0,2,3])
43+
@test all(o1 .== out)
44+
45+
function test_worldage()
46+
@parameters σ() ρ() β()
47+
@variables x y z
48+
eqs =*(y-x),
49+
x*-z)-y,
50+
x*y - β*z]
51+
_f = eval(ModelingToolkit.build_function(eqs,[x,y,z],[σ,ρ,β]))
52+
f(u,p) = ModelingToolkit.fast_invokelatest(_f,typeof(u),u,p)
53+
f(du,u,p) = ModelingToolkit.fast_invokelatest(_f,Nothing,du,u,p)
54+
out = [1.0,2,3]
55+
o1 = f([1.0,2,3],[1.0,2,3])
56+
f(out,[1.0,2,3],[1.0,2,3])
57+
end
58+
test_worldage()
59+
60+
mac = @I begin
61+
@parameters σ() ρ() β()
62+
@variables x() y() z()
63+
64+
eqs =*(y-x),
65+
x*-z)-y,
66+
x*y - β*z]
67+
ModelingToolkit.build_function(eqs,[x,y,z],[σ,ρ,β])
68+
end
69+
f = @ICompile
70+
out = [1.0,2,3]
71+
o1 = f([1.0,2,3],[1.0,2,3])
72+
f(out,[1.0,2,3],[1.0,2,3])
73+
@test all(o1 .== out)
74+
75+
mac = @I begin
76+
@parameters σ ρ β
77+
@variables x y z
78+
79+
eqs =*(y-x),
80+
x*-z)-y,
81+
x*y - β*z]
82+
= ModelingToolkit.jacobian(eqs,[x,y,z])
83+
ModelingToolkit.build_function(∂,[x,y,z],[σ,ρ,β])
84+
end
85+
f = @ICompile
86+
out = zeros(3,3)
87+
o1 = f([1.0,2,3],[1.0,2,3])
88+
f(out,[1.0,2,3],[1.0,2,3])
89+
@test all(out .== o1)
90+
91+
## No parameters
92+
@variables x y z
93+
eqs = [(y-x)^2,
94+
x*(x-z)-y,
95+
x*y - y*z]
96+
f = eval(ModelingToolkit.build_function(eqs,[x,y,z]))
97+
out = zeros(3)
98+
o1 = f([1.0,2,3])
99+
f(out,[1.0,2,3])
100+
@test all(out .== o1)
101+
102+
function test_worldage()
103+
@variables x y z
104+
eqs = [(y-x)^2,
105+
x*(x-z)-y,
106+
x*y - y*z]
107+
_f = eval(ModelingToolkit.build_function(eqs,[x,y,z]))
108+
f(u) = ModelingToolkit.fast_invokelatest(_f,typeof(u),u)
109+
f(du,u) = ModelingToolkit.fast_invokelatest(_f,Nothing,du,u)
110+
out = zeros(3)
111+
o1 = f([1.0,2,3])
112+
f(out,[1.0,2,3])
113+
end
114+
test_worldage()
115+
116+
mac = @I begin
117+
@variables x y z
118+
eqs = [(y-x)^2,
119+
x*(x-z)-y,
120+
x*y - y*z]
121+
ModelingToolkit.build_function(eqs,[x,y,z])
122+
end
123+
f = @ICompile
124+
out = zeros(3)
125+
o1 = f([1.0,2,3])
126+
f(out,[1.0,2,3])
127+
@test all(out .== o1)

test/runtests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@ using ModelingToolkit, Test
33
@testset "Parsing Test" begin include("variable_parsing.jl") end
44
@testset "Differentiation Test" begin include("derivatives.jl") end
55
@testset "Simplify Test" begin include("simplify.jl") end
6+
@testset "Direct Usage Test" begin include("direct.jl") end
67
@testset "System Construction Test" begin include("system_construction.jl") end

test/system_construction.jl

+3
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ end
3030
eqs = [D(x) ~ σ*(y-x),
3131
D(y) ~ x*-z)-y,
3232
D(z) ~ x*y - β*z]
33+
34+
ModelingToolkit.simplified_expr.(eqs)[1]
35+
:(derivative(x(t), t) = σ * (y(t) - x(t))).args
3336
de = ODESystem(eqs)
3437
test_diffeq_inference("standard", de, t, (x, y, z), (σ, ρ, β))
3538
generate_function(de, [x,y,z], [σ,ρ,β])

0 commit comments

Comments
 (0)