Skip to content

Commit 0ba1a33

Browse files
committed
further expand Solve[]'s ability
1. bring domain check back into solve.eval, so that things like Abs() can be evaluated 2. create system symbols such as System`Reals for domain check 3. refactor, moving most logics out of Solve
1 parent 405cce1 commit 0ba1a33

File tree

5 files changed

+99
-102
lines changed

5 files changed

+99
-102
lines changed

mathics/builtin/numbers/calculus.py

+24-87
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
"""
1111

1212
from itertools import product
13-
from typing import Optional, Union
13+
from typing import Optional
1414

1515
import numpy as np
1616
import sympy
@@ -59,23 +59,27 @@
5959
from mathics.core.systemsymbols import (
6060
SymbolAnd,
6161
SymbolAutomatic,
62+
SymbolComplex,
6263
SymbolConditionalExpression,
6364
SymbolD,
6465
SymbolDerivative,
6566
SymbolInfinity,
6667
SymbolInfix,
68+
SymbolInteger,
6769
SymbolIntegrate,
6870
SymbolLeft,
6971
SymbolLog,
7072
SymbolNIntegrate,
7173
SymbolO,
74+
SymbolReal,
7275
SymbolRule,
7376
SymbolSequence,
7477
SymbolSeries,
7578
SymbolSeriesData,
7679
SymbolSimplify,
7780
SymbolUndefined,
7881
)
82+
from mathics.eval.calculus import solve_sympy
7983
from mathics.eval.makeboxes import format_element
8084
from mathics.eval.nevaluator import eval_N
8185
from mathics.eval.numbers.calculus.integrators import (
@@ -2210,105 +2214,38 @@ class Solve(Builtin):
22102214
messages = {
22112215
"eqf": "`1` is not a well-formed equation.",
22122216
"svars": 'Equations may not give solutions for all "solve" variables.',
2217+
"fulldim": "The solution set contains a full-dimensional component; use Reduce for complete solution information.",
22132218
}
22142219

2215-
# FIXME: the problem with removing the domain parameter from the outside
2216-
# is that the we can't make use of this information inside
2217-
# the evaluation method where it is may be needed.
22182220
rules = {
2219-
"Solve[eqs_, vars_, Complexes]": "Solve[eqs, vars]",
2220-
"Solve[eqs_, vars_, Reals]": (
2221-
"Cases[Solve[eqs, vars], {Rule[x_,y_?RealValuedNumberQ]}]"
2222-
),
2223-
"Solve[eqs_, vars_, Integers]": (
2224-
"Cases[Solve[eqs, vars], {Rule[x_,y_Integer]}]"
2225-
),
2221+
"Solve[eqs_, vars_]": "Solve[eqs, vars, Complexes]"
22262222
}
22272223
summary_text = "find generic solutions for variables"
22282224

2229-
def eval(self, eqs, vars, evaluation: Evaluation):
2230-
"Solve[eqs_, vars_]"
2225+
def eval(self, eqs, vars, domain, evaluation: Evaluation):
2226+
"Solve[eqs_, vars_, domain_]"
22312227

2232-
vars_original = vars
2233-
head_name = vars.get_head_name()
2228+
variables = vars
2229+
head_name = variables.get_head_name()
22342230
if head_name == "System`List":
2235-
vars = vars.elements
2231+
variables = variables.elements
22362232
else:
2237-
vars = [vars]
2238-
for var in vars:
2233+
variables = [variables]
2234+
for var in variables:
22392235
if (
22402236
(isinstance(var, Atom) and not isinstance(var, Symbol)) or
22412237
head_name in ("System`Plus", "System`Times", "System`Power") or # noqa
22422238
A_CONSTANT & var.get_attributes(evaluation.definitions)
22432239
):
22442240

2245-
evaluation.message("Solve", "ivar", vars_original)
2241+
evaluation.message("Solve", "ivar", vars)
22462242
return
22472243

2248-
vars_sympy = [var.to_sympy() for var in vars]
2249-
if None in vars_sympy:
2244+
sympy_variables = [var.to_sympy() for var in variables]
2245+
if None in sympy_variables:
22502246
evaluation.message("Solve", "ivar")
22512247
return
2252-
all_var_tuples = list(zip(vars, vars_sympy))
2253-
2254-
def cut_var_dimension(expressions: Union[Expression, list[Expression]]):
2255-
'''delete unused variables to avoid SymPy's PolynomialError
2256-
: Not a zero-dimensional system in e.g. Solve[x^2==1&&z^2==-1,{x,y,z}]'''
2257-
if not isinstance(expressions, list):
2258-
expressions = [expressions]
2259-
subset_vars = set()
2260-
subset_vars_sympy = set()
2261-
for var, var_sympy in all_var_tuples:
2262-
pattern = Pattern.create(var)
2263-
for equation in expressions:
2264-
if not equation.is_free(pattern, evaluation):
2265-
subset_vars.add(var)
2266-
subset_vars_sympy.add(var_sympy)
2267-
return subset_vars, subset_vars_sympy
2268-
2269-
def solve_sympy(equations: Union[Expression, list[Expression]]):
2270-
if not isinstance(equations, list):
2271-
equations = [equations]
2272-
equations_sympy = []
2273-
denoms_sympy = []
2274-
subset_vars, subset_vars_sympy = cut_var_dimension(equations)
2275-
for equation in equations:
2276-
if equation is SymbolTrue:
2277-
continue
2278-
elif equation is SymbolFalse:
2279-
return []
2280-
elements = equation.elements
2281-
for left, right in [(elements[index], elements[index + 1]) for index in range(len(elements) - 1)]:
2282-
# ↑ to deal with things like a==b==c==d
2283-
left = left.to_sympy()
2284-
right = right.to_sympy()
2285-
if left is None or right is None:
2286-
return []
2287-
equation_sympy = left - right
2288-
equation_sympy = sympy.together(equation_sympy)
2289-
equation_sympy = sympy.cancel(equation_sympy)
2290-
equations_sympy.append(equation_sympy)
2291-
numer, denom = equation_sympy.as_numer_denom()
2292-
denoms_sympy.append(denom)
2293-
try:
2294-
results = sympy.solve(equations_sympy, subset_vars_sympy, dict=True) # no transform_dict needed with dict=True
2295-
# Filter out results for which denominator is 0
2296-
# (SymPy should actually do that itself, but it doesn't!)
2297-
results = [
2298-
sol
2299-
for sol in results
2300-
if all(sympy.simplify(denom.subs(sol)) != 0 for denom in denoms_sympy)
2301-
]
2302-
return results
2303-
except sympy.PolynomialError:
2304-
# raised for e.g. Solve[x^2==1&&z^2==-1,{x,y,z}] when not deleting
2305-
# unused variables beforehand
2306-
return []
2307-
except NotImplementedError:
2308-
return []
2309-
except TypeError as exc:
2310-
if str(exc).startswith("expected Symbol, Function or Derivative"):
2311-
evaluation.message("Solve", "ivar", vars_original)
2248+
variable_tuples = list(zip(variables, sympy_variables))
23122249

23132250
def solve_recur(expression: Expression):
23142251
'''solve And, Or and List within the scope of sympy,
@@ -2336,7 +2273,7 @@ def solve_recur(expression: Expression):
23362273
inequations.append(sub_condition)
23372274
else:
23382275
inequations.append(child.to_sympy())
2339-
solutions.extend(solve_sympy(equations))
2276+
solutions.extend(solve_sympy(evaluation, equations, variables, domain))
23402277
conditions = sympy.And(*inequations)
23412278
result = [sol for sol in solutions if conditions.subs(sol)]
23422279
return result, None if solutions else conditions
@@ -2346,7 +2283,7 @@ def solve_recur(expression: Expression):
23462283
conditions = []
23472284
for child in expression.elements:
23482285
if child.has_form("Equal", 2):
2349-
solutions.extend(solve_sympy(child))
2286+
solutions.extend(solve_sympy(evaluation, child, variables, domain))
23502287
elif child.get_head_name() in ('System`And', 'System`Or'): # I don't believe List would be in here
23512288
sub_solution, sub_condition = solve_recur(child)
23522289
solutions.extend(sub_solution)
@@ -2365,8 +2302,8 @@ def solve_recur(expression: Expression):
23652302
if conditions is not None:
23662303
evaluation.message("Solve", "fulldim")
23672304
else:
2368-
if eqs.has_form("Equal", 2):
2369-
solutions = solve_sympy(eqs)
2305+
if eqs.get_head_name() == "System`Equal":
2306+
solutions = solve_sympy(evaluation, eqs, variables, domain)
23702307
else:
23712308
evaluation.message("Solve", "fulldim")
23722309
return ListExpression(ListExpression())
@@ -2376,7 +2313,7 @@ def solve_recur(expression: Expression):
23762313
return ListExpression(ListExpression())
23772314

23782315
if any(
2379-
sol and any(var not in sol for var in vars_sympy) for sol in solutions
2316+
sol and any(var not in sol for var in sympy_variables) for sol in solutions
23802317
):
23812318
evaluation.message("Solve", "svars")
23822319

@@ -2385,7 +2322,7 @@ def solve_recur(expression: Expression):
23852322
ListExpression(
23862323
*(
23872324
Expression(SymbolRule, var, from_sympy(sol[var_sympy]))
2388-
for var, var_sympy in all_var_tuples
2325+
for var, var_sympy in variable_tuples
23892326
if var_sympy in sol
23902327
),
23912328
)

mathics/core/atoms.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -774,9 +774,9 @@ def get_sort_key(self, pattern_sort=False) -> tuple:
774774
def sameQ(self, other) -> bool:
775775
"""Mathics SameQ"""
776776
return (
777-
isinstance(other, Complex)
778-
and self.real == other.real
779-
and self.imag == other.imag
777+
isinstance(other, Complex) and
778+
self.real == other.real and
779+
self.imag == other.imag
780780
)
781781

782782
def round(self, d=None) -> "Complex":

mathics/core/convert/sympy.py

+43-6
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
Conversion to SymPy is handled directly in BaseElement descendants.
66
"""
77

8+
from collections.abc import Iterable
89
from typing import Optional, Type, Union
910

1011
import sympy
@@ -13,9 +14,6 @@
1314
# Import the singleton class
1415
from sympy.core.numbers import S
1516

16-
BasicSympy = sympy.Expr
17-
18-
1917
from mathics.core.atoms import (
2018
MATHICS3_COMPLEX_I,
2119
Complex,
@@ -40,6 +38,7 @@
4038
)
4139
from mathics.core.list import ListExpression
4240
from mathics.core.number import FP_MANTISA_BINARY_DIGITS
41+
from mathics.core.rules import Pattern
4342
from mathics.core.symbols import (
4443
Symbol,
4544
SymbolFalse,
@@ -62,16 +61,21 @@
6261
SymbolGreater,
6362
SymbolGreaterEqual,
6463
SymbolIndeterminate,
64+
SymbolIntegers,
6565
SymbolLess,
6666
SymbolLessEqual,
6767
SymbolMatrixPower,
6868
SymbolO,
6969
SymbolPi,
7070
SymbolPiecewise,
71+
SymbolReals,
7172
SymbolSlot,
7273
SymbolUnequal,
7374
)
7475

76+
BasicSympy = sympy.Expr
77+
78+
7579
SymbolPrime = Symbol("Prime")
7680
SymbolRoot = Symbol("Root")
7781
SymbolRootSum = Symbol("RootSum")
@@ -130,6 +134,39 @@ def to_sympy_matrix(data, **kwargs) -> Optional[sympy.MutableDenseMatrix]:
130134
return None
131135

132136

137+
def apply_domain_to_symbols(symbols: Iterable[sympy.Symbol], domain) -> dict[sympy.Symbol, sympy.Symbol]:
138+
"""Create new sympy symbols with domain applied.
139+
Return a dict maps old to new.
140+
"""
141+
# FIXME: this substitute solution would break when Solve[Abs[x]==3, x],where x=-3 and x=3.
142+
# However, substituting symbol prior to actual solving would cause sympy to have biased assumption,
143+
# it would refuse to solve Abs() when symbol is in Complexes
144+
result = {}
145+
for symbol in symbols:
146+
if domain == SymbolReals:
147+
new_symbol = sympy.Symbol(repr(symbol), real=True)
148+
elif domain == SymbolIntegers:
149+
new_symbol = sympy.Symbol(repr(symbol), integer=True)
150+
else:
151+
new_symbol = symbol
152+
result[symbol] = new_symbol
153+
return result
154+
155+
156+
def cut_dimension(evaluation, expressions: Union[Expression, list[Expression]], symbols: Iterable[sympy.Symbol]) -> set[sympy.Symbol]:
157+
'''delete unused variables to avoid SymPy's PolynomialError
158+
: Not a zero-dimensional system in e.g. Solve[x^2==1&&z^2==-1,{x,y,z}]'''
159+
if not isinstance(expressions, list):
160+
expressions = [expressions]
161+
subset = set()
162+
for symbol in symbols:
163+
pattern = Pattern.create(symbol)
164+
for equation in expressions:
165+
if not equation.is_free(pattern, evaluation):
166+
subset.add(symbol)
167+
return subset
168+
169+
133170
class SympyExpression(BasicSympy):
134171
is_Function = True
135172
nargs = None
@@ -363,9 +400,9 @@ def old_from_sympy(expr) -> BaseElement:
363400
if is_Cn_expr(name):
364401
return Expression(SymbolC, Integer(int(name[1:])))
365402
if name.startswith(sympy_symbol_prefix):
366-
name = name[len(sympy_symbol_prefix) :]
403+
name = name[len(sympy_symbol_prefix):]
367404
if name.startswith(sympy_slot_prefix):
368-
index = name[len(sympy_slot_prefix) :]
405+
index = name[len(sympy_slot_prefix):]
369406
return Expression(SymbolSlot, Integer(int(index)))
370407
elif expr.is_NumberSymbol:
371408
name = str(expr)
@@ -517,7 +554,7 @@ def old_from_sympy(expr) -> BaseElement:
517554
*[from_sympy(arg) for arg in expr.args]
518555
)
519556
if name.startswith(sympy_symbol_prefix):
520-
name = name[len(sympy_symbol_prefix) :]
557+
name = name[len(sympy_symbol_prefix):]
521558
args = [from_sympy(arg) for arg in expr.args]
522559
builtin = sympy_to_mathics.get(name)
523560
if builtin is not None:

mathics/core/systemsymbols.py

+3
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
SymbolCompile = Symbol("System`Compile")
5757
SymbolCompiledFunction = Symbol("System`CompiledFunction")
5858
SymbolComplex = Symbol("System`Complex")
59+
SymbolComplexes = Symbol("System`Complexes")
5960
SymbolComplexInfinity = Symbol("System`ComplexInfinity")
6061
SymbolCondition = Symbol("System`Condition")
6162
SymbolConditionalExpression = Symbol("System`ConditionalExpression")
@@ -124,6 +125,7 @@
124125
SymbolInfix = Symbol("System`Infix")
125126
SymbolInputForm = Symbol("System`InputForm")
126127
SymbolInteger = Symbol("System`Integer")
128+
SymbolIntegers = Symbol("System`Integers")
127129
SymbolIntegrate = Symbol("System`Integrate")
128130
SymbolLeft = Symbol("System`Left")
129131
SymbolLength = Symbol("System`Length")
@@ -200,6 +202,7 @@
200202
SymbolRational = Symbol("System`Rational")
201203
SymbolRe = Symbol("System`Re")
202204
SymbolReal = Symbol("System`Real")
205+
SymbolReals = Symbol("System`Reals")
203206
SymbolRealAbs = Symbol("System`RealAbs")
204207
SymbolRealDigits = Symbol("System`RealDigits")
205208
SymbolRealSign = Symbol("System`RealSign")

test/builtin/calculus/test_solve.py

+26-6
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,34 @@ def test_solve():
4343
"Issue #1235",
4444
),
4545
(
46-
"Solve[{x^2==4 && x < 0},{x}]",
47-
"{x->-2}",
48-
"",
46+
"Solve[Abs[-2/3*(lambda + 2) + 8/3 + 4] == 4, lambda,Reals]",
47+
"{{lambda -> 2}, {lambda -> 14}}",
48+
"abs()",
4949
),
5050
(
51-
"Solve[{x^2==4 && x < 0 && x > -4},{x}]",
52-
"{x->-2}",
53-
"",
51+
"Solve[q^3 == (20-12)/(4-3), q,Reals]",
52+
"{{q -> 2}}",
53+
"domain check",
54+
),
55+
(
56+
"Solve[x + Pi/3 == 2k*Pi + Pi/6 || x + Pi/3 == 2k*Pi + 5Pi/6, x,Reals]",
57+
"{{x -> -Pi / 6 + 2 k Pi}, {x -> Pi / 2 + 2 k Pi}}",
58+
"logics involved",
59+
),
60+
(
61+
"Solve[m - 1 == 0 && -(m + 1) != 0, m,Reals]",
62+
"{{m -> 1}}",
63+
"logics and constraints",
64+
),
65+
(
66+
"Solve[(lambda + 1)/6 == 1/(mu - 1) == lambda/4, {lambda, mu},Reals]",
67+
"{{lambda -> 2, mu -> 3}}",
68+
"chained equations",
69+
),
70+
(
71+
"Solve[2*x0*Log[x0] + x0 - 2*a*x0 == -1 && x0^2*Log[x0] - a*x0^2 + b == b - x0, {x0, a, b},Reals]",
72+
"{{x0 -> 1, a -> 1}}",
73+
"excess variable b",
5474
),
5575
):
5676
session.evaluate("Clear[h]; Clear[g]; Clear[f];")

0 commit comments

Comments
 (0)