Skip to content

Commit c692526

Browse files
Merge pull request #799 from SciML/nloptcons
Add constraints support for NLopt
2 parents 6bb4e7b + 3bf492c commit c692526

File tree

7 files changed

+130
-43
lines changed

7 files changed

+130
-43
lines changed

lib/OptimizationMultistartOptimization/Project.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@ Reexport = "1.2"
1616

1717
[extras]
1818
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
19+
OptimizationNLopt= "4e6fcdb7-1186-4e1f-a706-475e75c168bb"
1920
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
2021
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
2122
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2223

2324
[targets]
24-
test = ["ForwardDiff", "ReverseDiff", "Pkg", "Test"]
25+
test = ["ForwardDiff", "OptimizationNLopt", "ReverseDiff", "Pkg", "Test"]

lib/OptimizationMultistartOptimization/test/runtests.jl

-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
using Pkg;
2-
Pkg.develop(path = joinpath(@__DIR__, "../../", "OptimizationNLopt"));
31
using OptimizationMultistartOptimization, Optimization, ForwardDiff, OptimizationNLopt
42
using Test, ReverseDiff
53

lib/OptimizationNLopt/Project.toml

+4-2
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,19 @@ version = "0.2.2"
66
[deps]
77
NLopt = "76087f3c-5699-56af-9a33-bf431cd00edd"
88
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
9+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
910
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1011

1112
[compat]
12-
NLopt = "0.6, 1"
13+
NLopt = "1.1"
1314
Optimization = "3.21"
1415
Reexport = "1.2"
1516
julia = "1"
1617

1718
[extras]
19+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
1820
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1921
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2022

2123
[targets]
22-
test = ["Test", "Zygote"]
24+
test = ["ReverseDiff", "Test", "Zygote"]

lib/OptimizationNLopt/src/OptimizationNLopt.jl

+70-27
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module OptimizationNLopt
33
using Reexport
44
@reexport using NLopt, Optimization
55
using Optimization.SciMLBase
6+
using Optimization: deduce_retcode
67

78
(f::NLopt.Algorithm)() = f
89

@@ -63,6 +64,38 @@ function SciMLBase.requiresconsjac(opt::Union{NLopt.Algorithm, NLopt.Opt}) #http
6364
end
6465
end
6566

67+
function SciMLBase.allowsconstraints(opt::NLopt.Algorithm)
68+
str_opt = string(opt)
69+
if occursin("AUGLAG", str_opt) || occursin("CCSA", str_opt) ||
70+
occursin("MMA", str_opt) || occursin("COBYLA", str_opt) ||
71+
occursin("ISRES", str_opt) || occursin("AGS", str_opt) ||
72+
occursin("ORIG_DIRECT", str_opt) || occursin("SLSQP", str_opt)
73+
return true
74+
else
75+
return false
76+
end
77+
end
78+
79+
function SciMLBase.requiresconsjac(opt::NLopt.Algorithm)
80+
str_opt = string(opt)
81+
if occursin("AUGLAG", str_opt) || occursin("CCSA", str_opt) ||
82+
occursin("MMA", str_opt) || occursin("COBYLA", str_opt) ||
83+
occursin("ISRES", str_opt) || occursin("AGS", str_opt) ||
84+
occursin("ORIG_DIRECT", str_opt) || occursin("SLSQP", str_opt)
85+
return true
86+
else
87+
return false
88+
end
89+
end
90+
91+
function SciMLBase.__init(prob::SciMLBase.OptimizationProblem, opt::NLopt.Algorithm,
92+
; cons_tol = 1e-6,
93+
callback = (args...) -> (false),
94+
progress = false, kwargs...)
95+
return OptimizationCache(prob, opt; cons_tol, callback, progress,
96+
kwargs...)
97+
end
98+
6699
function __map_optimizer_args!(cache::OptimizationCache, opt::NLopt.Opt;
67100
callback = nothing,
68101
maxiters::Union{Number, Nothing} = nothing,
@@ -103,7 +136,9 @@ function __map_optimizer_args!(cache::OptimizationCache, opt::NLopt.Opt;
103136

104137
# add optimiser options from kwargs
105138
for j in kwargs
106-
eval(Meta.parse("NLopt." * string(j.first) * "!"))(opt, j.second)
139+
if j.first != :cons_tol
140+
eval(Meta.parse("NLopt." * string(j.first) * "!"))(opt, j.second)
141+
end
107142
end
108143

109144
if cache.ub !== nothing
@@ -132,31 +167,6 @@ function __map_optimizer_args!(cache::OptimizationCache, opt::NLopt.Opt;
132167
return nothing
133168
end
134169

135-
function __nlopt_status_to_ReturnCode(status::Symbol)
136-
if status in Symbol.([
137-
NLopt.SUCCESS,
138-
NLopt.STOPVAL_REACHED,
139-
NLopt.FTOL_REACHED,
140-
NLopt.XTOL_REACHED,
141-
NLopt.ROUNDOFF_LIMITED
142-
])
143-
return ReturnCode.Success
144-
elseif status == Symbol(NLopt.MAXEVAL_REACHED)
145-
return ReturnCode.MaxIters
146-
elseif status == Symbol(NLopt.MAXTIME_REACHED)
147-
return ReturnCode.MaxTime
148-
elseif status in Symbol.([
149-
NLopt.OUT_OF_MEMORY,
150-
NLopt.INVALID_ARGS,
151-
NLopt.FAILURE,
152-
NLopt.FORCED_STOP
153-
])
154-
return ReturnCode.Failure
155-
else
156-
return ReturnCode.Default
157-
end
158-
end
159-
160170
function SciMLBase.__solve(cache::OptimizationCache{
161171
F,
162172
RC,
@@ -219,6 +229,39 @@ function SciMLBase.__solve(cache::OptimizationCache{
219229
NLopt.min_objective!(opt_setup, fg!)
220230
end
221231

232+
if cache.f.cons !== nothing
233+
eqinds = map((y) -> y[1] == y[2], zip(cache.lcons, cache.ucons))
234+
ineqinds = map((y) -> y[1] != y[2], zip(cache.lcons, cache.ucons))
235+
if sum(ineqinds) > 0
236+
ineqcons = function (res, θ, J)
237+
cons_cache = zeros(eltype(res), sum(eqinds) + sum(ineqinds))
238+
cache.f.cons(cons_cache, θ)
239+
res .= @view(cons_cache[ineqinds])
240+
if length(J) > 0
241+
Jcache = zeros(eltype(J), sum(ineqinds) + sum(eqinds), length(θ))
242+
cache.f.cons_j(Jcache, θ)
243+
J .= @view(Jcache[ineqinds, :])'
244+
end
245+
end
246+
NLopt.inequality_constraint!(
247+
opt_setup, ineqcons, [cache.solver_args.cons_tol for i in 1:sum(ineqinds)])
248+
end
249+
if sum(eqinds) > 0
250+
eqcons = function (res, θ, J)
251+
cons_cache = zeros(eltype(res), sum(eqinds) + sum(ineqinds))
252+
cache.f.cons(cons_cache, θ)
253+
res .= @view(cons_cache[eqinds])
254+
if length(J) > 0
255+
Jcache = zeros(eltype(res), sum(eqinds) + sum(ineqinds), length(θ))
256+
cache.f.cons_j(Jcache, θ)
257+
J .= @view(Jcache[eqinds, :])'
258+
end
259+
end
260+
NLopt.equality_constraint!(
261+
opt_setup, eqcons, [cache.solver_args.cons_tol for i in 1:sum(eqinds)])
262+
end
263+
end
264+
222265
maxiters = Optimization._check_and_convert_maxiters(cache.solver_args.maxiters)
223266
maxtime = Optimization._check_and_convert_maxtime(cache.solver_args.maxtime)
224267

@@ -229,7 +272,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
229272
t0 = time()
230273
(minf, minx, ret) = NLopt.optimize(opt_setup, cache.u0)
231274
t1 = time()
232-
retcode = __nlopt_status_to_ReturnCode(ret)
275+
retcode = deduce_retcode(ret)
233276

234277
if retcode == ReturnCode.Failure
235278
@warn "NLopt failed to converge: $(ret)"

lib/OptimizationNLopt/test/runtests.jl

+49-7
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
using OptimizationNLopt, Optimization, Zygote
2-
using Test
1+
using OptimizationNLopt, Optimization, Zygote, ReverseDiff
2+
using Test, Random
33

44
@testset "OptimizationNLopt.jl" begin
55
rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
@@ -16,7 +16,7 @@ using Test
1616
optprob = OptimizationFunction(rosenbrock, Optimization.AutoZygote())
1717
prob = OptimizationProblem(optprob, x0, _p)
1818

19-
sol = solve(prob, NLopt.Opt(:LN_BOBYQA, 2))
19+
sol = solve(prob, NLopt.Opt(:LD_LBFGS, 2))
2020
@test sol.retcode == ReturnCode.Success
2121
@test 10 * sol.objective < l1
2222

@@ -26,10 +26,6 @@ using Test
2626
@test sol.retcode == ReturnCode.Success
2727
@test 10 * sol.objective < l1
2828

29-
sol = solve(prob, NLopt.Opt(:LD_LBFGS, 2))
30-
@test sol.retcode == ReturnCode.Success
31-
@test 10 * sol.objective < l1
32-
3329
sol = solve(prob, NLopt.Opt(:G_MLSL_LDS, 2), local_method = NLopt.Opt(:LD_LBFGS, 2),
3430
maxiters = 10000)
3531
@test sol.retcode == ReturnCode.MaxIters
@@ -82,4 +78,50 @@ using Test
8278
#nlopt gives the last best not the one where callback stops
8379
@test sol.objective < 0.8
8480
end
81+
82+
@testset "constrained" begin
83+
cons = (res, x, p) -> res .= [x[1]^2 + x[2]^2 - 1.0]
84+
x0 = zeros(2)
85+
optprob = OptimizationFunction(rosenbrock, Optimization.AutoZygote();
86+
cons = cons)
87+
prob = OptimizationProblem(optprob, x0, _p, lcons = [0.0], ucons = [0.0])
88+
sol = solve(prob, NLopt.LN_COBYLA())
89+
@test sol.retcode == ReturnCode.Success
90+
@test 10 * sol.objective < l1
91+
92+
Random.seed!(1)
93+
prob = OptimizationProblem(optprob, rand(2), _p,
94+
lcons = [0.0], ucons = [0.0])
95+
96+
sol = solve(prob, NLopt.LD_SLSQP())
97+
@test sol.retcode == ReturnCode.Success
98+
@test 10 * sol.objective < l1
99+
100+
Random.seed!(1)
101+
prob = OptimizationProblem(optprob, rand(2), _p,
102+
lcons = [0.0], ucons = [0.0])
103+
sol = solve(prob, NLopt.AUGLAG(), local_method = NLopt.LD_LBFGS())
104+
# @test sol.retcode == ReturnCode.Success
105+
@test 10 * sol.objective < l1
106+
107+
function con2_c(res, x, p)
108+
res .= [x[1]^2 + x[2]^2 - 1.0, x[2] * sin(x[1]) - x[1] - 2.0]
109+
end
110+
111+
optprob = OptimizationFunction(
112+
rosenbrock, Optimization.AutoForwardDiff(); cons = con2_c)
113+
Random.seed!(1)
114+
prob = OptimizationProblem(
115+
optprob, rand(2), _p, lcons = [0.0, -Inf], ucons = [0.0, 0.0])
116+
sol = solve(prob, NLopt.LD_AUGLAG(), local_method = NLopt.LD_LBFGS())
117+
# @test sol.retcode == ReturnCode.Success
118+
@test 10 * sol.objective < l1
119+
120+
Random.seed!(1)
121+
prob = OptimizationProblem(optprob, rand(2), _p, lcons = [-Inf, -Inf],
122+
ucons = [0.0, 0.0], lb = [-1.0, -1.0], ub = [1.0, 1.0])
123+
sol = solve(prob, NLopt.GN_ISRES(), maxiters = 1000)
124+
@test sol.retcode == ReturnCode.MaxIters
125+
@test 10 * sol.objective < l1
126+
end
85127
end

src/utils.jl

+5-3
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,8 @@ const STOP_REASON_MAP = Dict(
7979
r"STOP: XTOL.TOO.SMALL" => ReturnCode.ConvergenceFailure,
8080
r"STOP: TERMINATION" => ReturnCode.Terminated,
8181
r"Optimization completed" => ReturnCode.Success,
82-
r"Convergence achieved" => ReturnCode.Success
82+
r"Convergence achieved" => ReturnCode.Success,
83+
r"ROUNDOFF_LIMITED" => ReturnCode.Success
8384
)
8485

8586
# Function to deduce ReturnCode from a stop_reason string using the dictionary
@@ -99,11 +100,12 @@ function deduce_retcode(retcode::Symbol)
99100
return ReturnCode.Default
100101
elseif retcode == :Success || retcode == :EXACT_SOLUTION_LEFT ||
101102
retcode == :FLOATING_POINT_LIMIT || retcode == :true || retcode == :OPTIMAL ||
102-
retcode == :LOCALLY_SOLVED
103+
retcode == :LOCALLY_SOLVED || retcode == :ROUNDOFF_LIMITED || retcode == :SUCCESS
103104
return ReturnCode.Success
104105
elseif retcode == :Terminated
105106
return ReturnCode.Terminated
106-
elseif retcode == :MaxIters || retcode == :MAXITERS_EXCEED
107+
elseif retcode == :MaxIters || retcode == :MAXITERS_EXCEED ||
108+
retcode == :MAXEVAL_REACHED
107109
return ReturnCode.MaxIters
108110
elseif retcode == :MaxTime || retcode == :TIME_LIMIT
109111
return ReturnCode.MaxTime

test/stdout.txt

-1
This file was deleted.

0 commit comments

Comments
 (0)