Skip to content

Commit 04e4c69

Browse files
allow no parameters
1 parent cdc2635 commit 04e4c69

File tree

3 files changed

+72
-14
lines changed

3 files changed

+72
-14
lines changed

src/ModelingToolkit.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module ModelingToolkit
33
export Operation, Expression
44
export calculate_jacobian, generate_jacobian, generate_function
55
export independent_variables, dependent_variables, parameters
6-
export simplified_expr
6+
export simplified_expr, eval_function
77
export @register, @I
88
export modelingtoolkitize
99

src/utils.jl

+5-4
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ function flatten_expr!(x)
3131
x
3232
end
3333

34-
function build_function(rhss, vs, ps, args = (), conv = simplified_expr; constructor=nothing)
34+
function build_function(rhss, vs, ps = (), args = (), conv = simplified_expr; constructor=nothing)
3535
_vs = map(x-> x isa Operation ? x.op : x, vs)
3636
_ps = map(x-> x isa Operation ? x.op : x, ps)
3737
var_pairs = [(u.name, :(u[$i])) for (i, u) enumerate(_vs)]
@@ -48,12 +48,14 @@ function build_function(rhss, vs, ps, args = (), conv = simplified_expr; constru
4848

4949
sys_expr = build_expr(:tuple, [conv(rhs) for rhs rhss])
5050
let_expr = Expr(:let, var_eqs, sys_expr)
51+
52+
fargs = ps == () ? :(u,$(args...)) : :(u,p,$(args...))
5153
quote
52-
function $fname($X,u,p,$(args...))
54+
function $fname($X,$(fargs.args...))
5355
$ip_let_expr
5456
nothing
5557
end
56-
function $fname(u,p,$(args...))
58+
function $fname($(fargs.args...))
5759
X = $let_expr
5860
T = promote_type(map(typeof,X)...)
5961
convert.(T,X)
@@ -63,7 +65,6 @@ function build_function(rhss, vs, ps, args = (), conv = simplified_expr; constru
6365
end
6466
end
6567

66-
6768
is_constant(::Constant) = true
6869
is_constant(::Any) = false
6970

test/direct.jl

+66-9
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ using ModelingToolkit, StaticArrays, LinearAlgebra
22
using DiffEqBase
33
using Test
44

5-
# Define some variables
5+
# Calculus
66
@parameters t σ ρ β
77
@variables x y z
88

@@ -29,14 +29,33 @@ end
2929
@test all(isequal.(ModelingToolkit.gradient(eqs[1],[x,y,z]),[σ * -1,σ,0]))
3030
@test all(isequal.(ModelingToolkit.hessian(eqs[1],[x,y,z]),0))
3131

32+
# Function building
33+
3234
@parameters σ() ρ() β()
3335
@variables x y z
34-
3536
eqs =*(y-x),
3637
x*-z)-y,
3738
x*y - β*z]
39+
f = eval(ModelingToolkit.build_function(eqs,[x,y,z],[σ,ρ,β]))
40+
out = [1.0,2,3]
41+
o1 = f([1.0,2,3],[1.0,2,3])
42+
f(out,[1.0,2,3],[1.0,2,3])
43+
@test all(o1 .== out)
3844

39-
ModelingToolkit.build_function(eqs,[x,y,z],[σ,ρ,β])
45+
function test_worldage()
46+
@parameters σ() ρ() β()
47+
@variables x y z
48+
eqs =*(y-x),
49+
x*-z)-y,
50+
x*y - β*z]
51+
_f = eval(ModelingToolkit.build_function(eqs,[x,y,z],[σ,ρ,β]))
52+
f(u,p) = ModelingToolkit.fast_invokelatest(_f,typeof(u),u,p)
53+
f(du,u,p) = ModelingToolkit.fast_invokelatest(_f,Nothing,du,u,p)
54+
out = [1.0,2,3]
55+
o1 = f([1.0,2,3],[1.0,2,3])
56+
f(out,[1.0,2,3],[1.0,2,3])
57+
end
58+
test_worldage()
4059

4160
mac = @I begin
4261
@parameters σ() ρ() β()
@@ -45,17 +64,17 @@ mac = @I begin
4564
eqs =*(y-x),
4665
x*-z)-y,
4766
x*y - β*z]
48-
4967
ModelingToolkit.build_function(eqs,[x,y,z],[σ,ρ,β])
5068
end
5169
f = @ICompile
5270
out = [1.0,2,3]
53-
f([1.0,2,3],[1.0,2,3])
71+
o1 = f([1.0,2,3],[1.0,2,3])
5472
f(out,[1.0,2,3],[1.0,2,3])
73+
@test all(o1 .== out)
5574

5675
mac = @I begin
57-
@parameters σ() ρ() β()
58-
@variables x() y() z()
76+
@parameters σ ρ β
77+
@variables x y z
5978

6079
eqs =*(y-x),
6180
x*-z)-y,
@@ -65,6 +84,44 @@ mac = @I begin
6584
end
6685
f = @ICompile
6786
out = zeros(3,3)
68-
f([1.0,2,3],[1.0,2,3])
87+
o1 = f([1.0,2,3],[1.0,2,3])
6988
f(out,[1.0,2,3],[1.0,2,3])
70-
out
89+
@test all(out .== o1)
90+
91+
## No parameters
92+
@variables x y z
93+
eqs = [(y-x)^2,
94+
x*(x-z)-y,
95+
x*y - y*z]
96+
f = eval(ModelingToolkit.build_function(eqs,[x,y,z]))
97+
out = zeros(3)
98+
o1 = f([1.0,2,3])
99+
f(out,[1.0,2,3])
100+
@test all(out .== o1)
101+
102+
function test_worldage()
103+
@variables x y z
104+
eqs = [(y-x)^2,
105+
x*(x-z)-y,
106+
x*y - y*z]
107+
_f = eval(ModelingToolkit.build_function(eqs,[x,y,z]))
108+
f(u) = ModelingToolkit.fast_invokelatest(_f,typeof(u),u)
109+
f(du,u) = ModelingToolkit.fast_invokelatest(_f,Nothing,du,u)
110+
out = zeros(3)
111+
o1 = f([1.0,2,3])
112+
f(out,[1.0,2,3])
113+
end
114+
test_worldage()
115+
116+
mac = @I begin
117+
@variables x y z
118+
eqs = [(y-x)^2,
119+
x*(x-z)-y,
120+
x*y - y*z]
121+
ModelingToolkit.build_function(eqs,[x,y,z])
122+
end
123+
f = @ICompile
124+
out = zeros(3)
125+
o1 = f([1.0,2,3])
126+
f(out,[1.0,2,3])
127+
@test all(out .== o1)

0 commit comments

Comments
 (0)