diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 0a785aca2e..61997a81b3 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -136,6 +136,7 @@ include("systems/pde/pdesystem.jl") include("systems/discrete_system/discrete_system.jl") include("systems/validation.jl") +include("systems/unitconversion.jl") include("systems/dependency_graphs.jl") include("systems/systemstructure.jl") using .SystemStructures @@ -174,7 +175,7 @@ export Equation, ConstrainedEquation export Term, Sym export SymScope, LocalScope, ParentScope, GlobalScope export independent_variables, independent_variable, states, parameters, equations, controls, observed, structure -export structural_simplify +export structural_simplify, rewrite_units export DiscreteSystem, DiscreteProblem export calculate_jacobian, generate_jacobian, generate_function diff --git a/src/systems/unitconversion.jl b/src/systems/unitconversion.jl new file mode 100644 index 0000000000..3c7a28ea3d --- /dev/null +++ b/src/systems/unitconversion.jl @@ -0,0 +1,180 @@ +"Wrapper for Unitful.convfact that returns a Constant & throws ValidationError instead of DimensionError." +function unitfactor(u, t) + try + cf = Unitful.convfact(u, t) + return cf == 1 ? 1 : Constant(cf*u/t) + catch err + throw(ValidationError("Unable to convert [$t] to [$u]")) + end +end + +"Turn an expression into a Julia function w/ correct units behavior." # mostly for testing +function functionize(pt) + syms = Symbolics.get_variables(pt) + eval(build_function(constructunit(pt), syms, expression = Val{false})) +end + +"Represent a constant as a Symbolic (esp. for lifting units to metadata level)." +struct Constant{T, M} <: SymbolicUtils.Symbolic{T} + val::T + metadata::M +end + +Constant(x) = Constant(x, Dict(VariableUnit => Unitful.unit(x))) +Base.:*(x::Num, y::Unitful.Quantity) = value(x) * y +Base.:*(x::Unitful.Quantity, y::Num) = x * value(y) +Base.show(io::IO, v::Constant) = Base.show(io, v.val) + +Unitless = Union{typeof.([exp, log, sinh, asinh, asin, + cosh, acosh, acos, + tanh, atanh, atan, + coth, acoth, acot, + sech, asech, asec, + csch, acsch, acsc])...} +isunitless(f::Unitless) = true + +#Should run this at the end of @variables and @parameters +set_unitless(x::Vector) = [_has_unit(y) ? y : SymbolicUtils.setmetadata(y,VariableUnit,unitless) for y in x] + +"Convert symbolic expression `x` to have units `u` if possible." +function unitcoerce(u::Unitful.Unitlike, x::Symbolic) + st = _has_unit(x) ? x : constructunit(x) + tu = _get_unit(st) + output = unitfactor(u, tu) * st + return SymbolicUtils.setmetadata(output, VariableUnit, u) +end + +"Convert a set of expressions to a common unit, defined by the first dimensional quantity encountered." +function uniformize(subterms) + newterms = Vector{Any}(undef, size(subterms)) + firstunit = nothing + for (idx, st) in enumerate(subterms) + if !isequal(st, 0) + st = constructunit(st) + tu = _get_unit(st) + if firstunit === nothing + firstunit = tu + end + newterms[idx] = unitfactor(firstunit, tu) * st + else + newterms[idx] = 0 + end + end + return newterms +end + +constructunit(x::Num) = constructunit(value(x)) +function constructunit(x::Unitful.Quantity) + return Constant(x.val, Dict(VariableUnit => Unitful.unit(x))) +end + +function constructunit(x) #This is where it all starts + maybeunit = safe_get_unit(x,"") + if maybeunit !== nothing + return SymbolicUtils.setmetadata(x, VariableUnit, maybeunit) + else # Something needs to be rewritten + op = operation(x) + args = arguments(x) + constructunit(op, args) + end +end + +function constructunit(op, args) # Fallback + if isunitless(op) + try + args = unitcoerce.(unitless, args) + return SymbolicUtils.setmetadata(op(args...), VariableUnit, unitless) + catch err + if err isa Unitful.DimensionError + argunits = get_unit.(args) + throw(ValidationError("Unable to coerce $args to dimensionless from $argunits for function $op.")) + else + rethrow(err) + end + end + else + throw(ValidationError("Unknown function $op supplied with $args with units $argunits")) + end +end + +function constructunit(op::typeof(+), subterms) + newterms = uniformize(subterms) + output = +(newterms...) + return SymbolicUtils.setmetadata(output, VariableUnit, _get_unit(newterms[1])) +end + +function constructunit(op::Conditional, subterms) + newterms = Vector{Any}(undef, 3) + firstunit = nothing + newterms[1] = constructunit(subterms[1]) + newterms[2:3] = uniformize(subterms[2:3]) + output = op(newterms...) + return SymbolicUtils.setmetadata(output, VariableUnit, _get_unit(newterms[2])) +end + +function constructunit(op::Union{Differential,Difference}, subterms) + numerator = constructunit(only(subterms)) + nu = _get_unit(numerator) + denominator = op isa Differential ? constructunit(op.x) : constructunit(op.t) #TODO: make consistent! + du = _get_unit(denominator) + output = op isa Differential ? Differential(denominator)(numerator) : Difference(denominator)(numerator) + return SymbolicUtils.setmetadata(output, VariableUnit, nu/du) +end + +function constructunit(op::typeof(^), subterms) + base, exponent = subterms + base = constructunit(base) + bu = _get_unit(base) + exponent = constructunit(exponent) + exponent = unitfactor(unitless, _get_unit(exponent)) * exponent + output = base^exponent + output_unit = bu == unitless ? unitless : (exponent isa Real ? bu^exponent : (1*bu)^exponent) + return SymbolicUtils.setmetadata(output, VariableUnit, output_unit) +end + +Root = Union{typeof(sqrt),typeof(cbrt)} +function constructunit(op::Root,args) + arg = constructunit(only(args)) + argunit = _get_unit(arg) + return SymbolicUtils.setmetadata(op(arg), VariableUnit, op(argunit)) +end + +function constructunit(op::Comparison, subterms) + newterms = uniformize(subterms) + output = op(newterms...) + return SymbolicUtils.setmetadata(output, VariableUnit, unitless) +end + +function constructunit(op::typeof(*), subterms) + newterms = Vector{Any}(undef, size(subterms)) + pu = unitless + for (idx, st) in enumerate(subterms) + st = constructunit(st) + pu *= _get_unit(st) + newterms[idx] = st + end + output = op(newterms...) + return SymbolicUtils.setmetadata(output, VariableUnit, pu) +end + +function constructunit(eq::ModelingToolkit.Equation) + newterms = uniformize([eq.lhs, eq.rhs]) + return ~(newterms...) + #return SymbolicUtils.setmetadata(output,VariableUnit,firstunit) #Fix this once Symbolics.jl Equations accept units +end + +"Rewrite a set of equations by inserting appropriate unit conversion factors." +function rewrite_units(eqs::Vector{Equation}; debug = false) + output = similar(eqs) + allgood = true + for (idx, eq) in enumerate(eqs) + try + output[idx] = constructunit(eq) + catch err + allgood = false + err isa ValidationError && !debug ? @warn("in eq [$idx], "*err.message) : rethrow(err) + end + end + allgood || throw(ValidationError("Some equations had invalid units. See warnings for details.")) + return output +end diff --git a/src/systems/validation.jl b/src/systems/validation.jl index ccd7fc9816..8aa9ee97b1 100644 --- a/src/systems/validation.jl +++ b/src/systems/validation.jl @@ -7,6 +7,7 @@ end "Throw exception on invalid unit types, otherwise return argument." function screen_unit(result) + result isa Symbolic && return result #For cases like P^γ where base is unitful, exponent is symbolic but dimensionless result isa Unitful.Unitlike || throw(ValidationError("Unit must be a subtype of Unitful.Unitlike, not $(typeof(result)).")) result isa Unitful.ScalarUnits || throw(ValidationError("Non-scalar units such as $result are not supported. Use a scalar unit instead.")) result == u"°" && throw(ValidationError("Degrees are not supported. Use radians instead.")) @@ -33,6 +34,16 @@ Literal = Union{Sym,Symbolics.ArrayOp,Symbolics.Arr,Symbolics.CallWithMetadata} Conditional = Union{typeof(ifelse),typeof(IfElse.ifelse)} Comparison = Union{typeof.([==, !=, ≠, <, <=, ≤, >, >=, ≥])...} +#Underscore methods are 'dumb': they only look at the outermost object to see if it has units, they don't traverse the expression tree. +#_has_unit(x::Equation) = getmetadata(x,VariableUnit) Doesn't work yet, equations don't have metadata. +_has_unit(x::Real) = true +_has_unit(x::Num) = _has_unit(value(x)) +_has_unit(x::Symbolic) = hasmetadata(x,VariableUnit) + +_get_unit(x::Real) = unitless +_get_unit(x::Num) = _get_unit(value(x)) +_get_unit(x::Symbolic) = screen_unit(getmetadata(x,VariableUnit,unitless)) + "Find the unit of a symbolic item." get_unit(x::Real) = unitless get_unit(x::Unitful.Quantity) = screen_unit(Unitful.unit(x)) @@ -42,6 +53,7 @@ get_unit(x::Literal) = screen_unit(getmetadata(x,VariableUnit, unitless)) get_unit(op::Differential, args) = get_unit(args[1]) / get_unit(op.x) get_unit(op::Difference, args) = get_unit(args[1]) / get_unit(op.t) get_unit(op::typeof(getindex),args) = get_unit(args[1]) + function get_unit(op,args) # Fallback result = op(1 .* get_unit.(args)...) try @@ -86,8 +98,8 @@ end function get_unit(op::Conditional, args) terms = get_unit.(args) - terms[1] == unitless || throw(ValidationError(", in $x, [$(terms[1])] is not dimensionless.")) - equivalent(terms[2], terms[3]) || throw(ValidationError(", in $x, units [$(terms[2])] and [$(terms[3])] do not match.")) + terms[1] == unitless || throw(ValidationError(", in $op, [$(terms[1])] is not dimensionless.")) + equivalent(terms[2], terms[3]) || throw(ValidationError(", in $op, units [$(terms[2])] and [$(terms[3])] do not match.")) return terms[2] end @@ -106,6 +118,7 @@ function get_unit(op::Comparison, args) end function get_unit(x::Symbolic) + _has_unit(x) && return _get_unit(x) #Easy out, if the tree has already been traversed by constructunit if SymbolicUtils.istree(x) op = operation(x) if op isa Sym || (op isa Term && operation(op) isa Term) # Dependent variables, not function calls @@ -116,7 +129,7 @@ function get_unit(x::Symbolic) end # Actual function calls: args = arguments(x) return get_unit(op, args) - else # This function should only be reached by Terms, for which `istree` is true + else # This method should only be reached by Terms, for which `istree` is true, so this branch should never happen: throw(ArgumentError("Unsupported value $x.")) end end diff --git a/test/units.jl b/test/units.jl index 140b52ddda..1690b89d0e 100644 --- a/test/units.jl +++ b/test/units.jl @@ -150,3 +150,23 @@ maj1 = MassActionJump(2.0, [0 => 1], [S => 1]) maj2 = MassActionJump(γ, [S => 1], [S => -1]) @named js4 = JumpSystem([maj1, maj2], t, [S], [β, γ]) +# Rewriting +@variables t [unit = u"ms"] P(t) [unit = u"MW"] E(t) [unit = u"J"] +@parameters τ [unit = u"ms"] γ +D = Differential(t) +eqs = [D(E) ~ P - E/τ] +@test_throws MT.ValidationError MT.get_unit(eqs[1].rhs) +neweqs = MT.rewrite_units(eqs) +@named sys = ODESystem(neweqs) +equations(sys) + +@test MT.get_unit(t/τ) == MT._get_unit(MT.constructunit(t/τ)) +@test MT.get_unit(2^(t/τ)) == MT._get_unit(MT.constructunit(2^(t/τ))) +@test MT.equivalent(MT.get_unit(t^γ), MT._get_unit(MT.constructunit(t^γ))) +@test MT.get_unit(sin(γ)) == MT._get_unit(MT.sin(γ)) +@test MT.get_unit(sqrt(E)) == MT._get_unit(MT.constructunit(sqrt(E))) +@test MT.get_unit(exp(γ)) == MT._get_unit(MT.exp(γ)) + +@variables E(t) [unit = u"kJ"] +@test MT.get_unit(IfElse.ifelse(t<τ,E/τ,P)) == MT._get_unit(MT.constructunit(IfElse.ifelse(t<τ,E/τ,P))) +@test_throws MT.ValidationError MT.constructunit(E+τ)