Skip to content

Commit 8e10644

Browse files
committed
further expand Solve[]'s ability
1. bring domain check back into solve.eval, so that things like Abs() can be evaluated 2. create system symbols such as System`Reals for domain check 3. refactor, moving most logics out of Solve
1 parent 5f32b4d commit 8e10644

File tree

5 files changed

+99
-102
lines changed

5 files changed

+99
-102
lines changed

mathics/builtin/numbers/calculus.py

+24-87
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
"""
1111

1212
from itertools import product
13-
from typing import Optional, Union
13+
from typing import Optional
1414

1515
import numpy as np
1616
import sympy
@@ -71,23 +71,27 @@
7171
from mathics.core.systemsymbols import (
7272
SymbolAnd,
7373
SymbolAutomatic,
74+
SymbolComplex,
7475
SymbolConditionalExpression,
7576
SymbolD,
7677
SymbolDerivative,
7778
SymbolInfinity,
7879
SymbolInfix,
80+
SymbolInteger,
7981
SymbolIntegrate,
8082
SymbolLeft,
8183
SymbolLog,
8284
SymbolNIntegrate,
8385
SymbolO,
86+
SymbolReal,
8487
SymbolRule,
8588
SymbolSequence,
8689
SymbolSeries,
8790
SymbolSeriesData,
8891
SymbolSimplify,
8992
SymbolUndefined,
9093
)
94+
from mathics.eval.calculus import solve_sympy
9195
from mathics.eval.makeboxes import format_element
9296
from mathics.eval.nevaluator import eval_N
9397

@@ -2208,105 +2212,38 @@ class Solve(Builtin):
22082212
messages = {
22092213
"eqf": "`1` is not a well-formed equation.",
22102214
"svars": 'Equations may not give solutions for all "solve" variables.',
2215+
"fulldim": "The solution set contains a full-dimensional component; use Reduce for complete solution information.",
22112216
}
22122217

2213-
# FIXME: the problem with removing the domain parameter from the outside
2214-
# is that the we can't make use of this information inside
2215-
# the evaluation method where it is may be needed.
22162218
rules = {
2217-
"Solve[eqs_, vars_, Complexes]": "Solve[eqs, vars]",
2218-
"Solve[eqs_, vars_, Reals]": (
2219-
"Cases[Solve[eqs, vars], {Rule[x_,y_?RealValuedNumberQ]}]"
2220-
),
2221-
"Solve[eqs_, vars_, Integers]": (
2222-
"Cases[Solve[eqs, vars], {Rule[x_,y_Integer]}]"
2223-
),
2219+
"Solve[eqs_, vars_]": "Solve[eqs, vars, Complexes]"
22242220
}
22252221
summary_text = "find generic solutions for variables"
22262222

2227-
def eval(self, eqs, vars, evaluation: Evaluation):
2228-
"Solve[eqs_, vars_]"
2223+
def eval(self, eqs, vars, domain, evaluation: Evaluation):
2224+
"Solve[eqs_, vars_, domain_]"
22292225

2230-
vars_original = vars
2231-
head_name = vars.get_head_name()
2226+
variables = vars
2227+
head_name = variables.get_head_name()
22322228
if head_name == "System`List":
2233-
vars = vars.elements
2229+
variables = variables.elements
22342230
else:
2235-
vars = [vars]
2236-
for var in vars:
2231+
variables = [variables]
2232+
for var in variables:
22372233
if (
22382234
(isinstance(var, Atom) and not isinstance(var, Symbol)) or
22392235
head_name in ("System`Plus", "System`Times", "System`Power") or # noqa
22402236
A_CONSTANT & var.get_attributes(evaluation.definitions)
22412237
):
22422238

2243-
evaluation.message("Solve", "ivar", vars_original)
2239+
evaluation.message("Solve", "ivar", vars)
22442240
return
22452241

2246-
vars_sympy = [var.to_sympy() for var in vars]
2247-
if None in vars_sympy:
2242+
sympy_variables = [var.to_sympy() for var in variables]
2243+
if None in sympy_variables:
22482244
evaluation.message("Solve", "ivar")
22492245
return
2250-
all_var_tuples = list(zip(vars, vars_sympy))
2251-
2252-
def cut_var_dimension(expressions: Union[Expression, list[Expression]]):
2253-
'''delete unused variables to avoid SymPy's PolynomialError
2254-
: Not a zero-dimensional system in e.g. Solve[x^2==1&&z^2==-1,{x,y,z}]'''
2255-
if not isinstance(expressions, list):
2256-
expressions = [expressions]
2257-
subset_vars = set()
2258-
subset_vars_sympy = set()
2259-
for var, var_sympy in all_var_tuples:
2260-
pattern = Pattern.create(var)
2261-
for equation in expressions:
2262-
if not equation.is_free(pattern, evaluation):
2263-
subset_vars.add(var)
2264-
subset_vars_sympy.add(var_sympy)
2265-
return subset_vars, subset_vars_sympy
2266-
2267-
def solve_sympy(equations: Union[Expression, list[Expression]]):
2268-
if not isinstance(equations, list):
2269-
equations = [equations]
2270-
equations_sympy = []
2271-
denoms_sympy = []
2272-
subset_vars, subset_vars_sympy = cut_var_dimension(equations)
2273-
for equation in equations:
2274-
if equation is SymbolTrue:
2275-
continue
2276-
elif equation is SymbolFalse:
2277-
return []
2278-
elements = equation.elements
2279-
for left, right in [(elements[index], elements[index + 1]) for index in range(len(elements) - 1)]:
2280-
# ↑ to deal with things like a==b==c==d
2281-
left = left.to_sympy()
2282-
right = right.to_sympy()
2283-
if left is None or right is None:
2284-
return []
2285-
equation_sympy = left - right
2286-
equation_sympy = sympy.together(equation_sympy)
2287-
equation_sympy = sympy.cancel(equation_sympy)
2288-
equations_sympy.append(equation_sympy)
2289-
numer, denom = equation_sympy.as_numer_denom()
2290-
denoms_sympy.append(denom)
2291-
try:
2292-
results = sympy.solve(equations_sympy, subset_vars_sympy, dict=True) # no transform_dict needed with dict=True
2293-
# Filter out results for which denominator is 0
2294-
# (SymPy should actually do that itself, but it doesn't!)
2295-
results = [
2296-
sol
2297-
for sol in results
2298-
if all(sympy.simplify(denom.subs(sol)) != 0 for denom in denoms_sympy)
2299-
]
2300-
return results
2301-
except sympy.PolynomialError:
2302-
# raised for e.g. Solve[x^2==1&&z^2==-1,{x,y,z}] when not deleting
2303-
# unused variables beforehand
2304-
return []
2305-
except NotImplementedError:
2306-
return []
2307-
except TypeError as exc:
2308-
if str(exc).startswith("expected Symbol, Function or Derivative"):
2309-
evaluation.message("Solve", "ivar", vars_original)
2246+
variable_tuples = list(zip(variables, sympy_variables))
23102247

23112248
def solve_recur(expression: Expression):
23122249
'''solve And, Or and List within the scope of sympy,
@@ -2334,7 +2271,7 @@ def solve_recur(expression: Expression):
23342271
inequations.append(sub_condition)
23352272
else:
23362273
inequations.append(child.to_sympy())
2337-
solutions.extend(solve_sympy(equations))
2274+
solutions.extend(solve_sympy(evaluation, equations, variables, domain))
23382275
conditions = sympy.And(*inequations)
23392276
result = [sol for sol in solutions if conditions.subs(sol)]
23402277
return result, None if solutions else conditions
@@ -2344,7 +2281,7 @@ def solve_recur(expression: Expression):
23442281
conditions = []
23452282
for child in expression.elements:
23462283
if child.has_form("Equal", 2):
2347-
solutions.extend(solve_sympy(child))
2284+
solutions.extend(solve_sympy(evaluation, child, variables, domain))
23482285
elif child.get_head_name() in ('System`And', 'System`Or'): # I don't believe List would be in here
23492286
sub_solution, sub_condition = solve_recur(child)
23502287
solutions.extend(sub_solution)
@@ -2363,8 +2300,8 @@ def solve_recur(expression: Expression):
23632300
if conditions is not None:
23642301
evaluation.message("Solve", "fulldim")
23652302
else:
2366-
if eqs.has_form("Equal", 2):
2367-
solutions = solve_sympy(eqs)
2303+
if eqs.get_head_name() == "System`Equal":
2304+
solutions = solve_sympy(evaluation, eqs, variables, domain)
23682305
else:
23692306
evaluation.message("Solve", "fulldim")
23702307
return ListExpression(ListExpression())
@@ -2374,7 +2311,7 @@ def solve_recur(expression: Expression):
23742311
return ListExpression(ListExpression())
23752312

23762313
if any(
2377-
sol and any(var not in sol for var in vars_sympy) for sol in solutions
2314+
sol and any(var not in sol for var in sympy_variables) for sol in solutions
23782315
):
23792316
evaluation.message("Solve", "svars")
23802317

@@ -2383,7 +2320,7 @@ def solve_recur(expression: Expression):
23832320
ListExpression(
23842321
*(
23852322
Expression(SymbolRule, var, from_sympy(sol[var_sympy]))
2386-
for var, var_sympy in all_var_tuples
2323+
for var, var_sympy in variable_tuples
23872324
if var_sympy in sol
23882325
),
23892326
)

mathics/core/atoms.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -774,9 +774,9 @@ def get_sort_key(self, pattern_sort=False) -> tuple:
774774
def sameQ(self, other) -> bool:
775775
"""Mathics SameQ"""
776776
return (
777-
isinstance(other, Complex)
778-
and self.real == other.real
779-
and self.imag == other.imag
777+
isinstance(other, Complex) and
778+
self.real == other.real and
779+
self.imag == other.imag
780780
)
781781

782782
def round(self, d=None) -> "Complex":

mathics/core/convert/sympy.py

+43-6
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
Conversion to SymPy is handled directly in BaseElement descendants.
66
"""
77

8+
from collections.abc import Iterable
89
from typing import Optional, Type, Union
910

1011
import sympy
@@ -13,9 +14,6 @@
1314
# Import the singleton class
1415
from sympy.core.numbers import S
1516

16-
BasicSympy = sympy.Expr
17-
18-
1917
from mathics.core.atoms import (
2018
MATHICS3_COMPLEX_I,
2119
Complex,
@@ -40,6 +38,7 @@
4038
)
4139
from mathics.core.list import ListExpression
4240
from mathics.core.number import FP_MANTISA_BINARY_DIGITS
41+
from mathics.core.rules import Pattern
4342
from mathics.core.symbols import (
4443
Symbol,
4544
SymbolFalse,
@@ -62,16 +61,21 @@
6261
SymbolGreater,
6362
SymbolGreaterEqual,
6463
SymbolIndeterminate,
64+
SymbolIntegers,
6565
SymbolLess,
6666
SymbolLessEqual,
6767
SymbolMatrixPower,
6868
SymbolO,
6969
SymbolPi,
7070
SymbolPiecewise,
71+
SymbolReals,
7172
SymbolSlot,
7273
SymbolUnequal,
7374
)
7475

76+
BasicSympy = sympy.Expr
77+
78+
7579
SymbolPrime = Symbol("Prime")
7680
SymbolRoot = Symbol("Root")
7781
SymbolRootSum = Symbol("RootSum")
@@ -130,6 +134,39 @@ def to_sympy_matrix(data, **kwargs) -> Optional[sympy.MutableDenseMatrix]:
130134
return None
131135

132136

137+
def apply_domain_to_symbols(symbols: Iterable[sympy.Symbol], domain) -> dict[sympy.Symbol, sympy.Symbol]:
138+
"""Create new sympy symbols with domain applied.
139+
Return a dict maps old to new.
140+
"""
141+
# FIXME: this substitute solution would break when Solve[Abs[x]==3, x],where x=-3 and x=3.
142+
# However, substituting symbol prior to actual solving would cause sympy to have biased assumption,
143+
# it would refuse to solve Abs() when symbol is in Complexes
144+
result = {}
145+
for symbol in symbols:
146+
if domain == SymbolReals:
147+
new_symbol = sympy.Symbol(repr(symbol), real=True)
148+
elif domain == SymbolIntegers:
149+
new_symbol = sympy.Symbol(repr(symbol), integer=True)
150+
else:
151+
new_symbol = symbol
152+
result[symbol] = new_symbol
153+
return result
154+
155+
156+
def cut_dimension(evaluation, expressions: Union[Expression, list[Expression]], symbols: Iterable[sympy.Symbol]) -> set[sympy.Symbol]:
157+
'''delete unused variables to avoid SymPy's PolynomialError
158+
: Not a zero-dimensional system in e.g. Solve[x^2==1&&z^2==-1,{x,y,z}]'''
159+
if not isinstance(expressions, list):
160+
expressions = [expressions]
161+
subset = set()
162+
for symbol in symbols:
163+
pattern = Pattern.create(symbol)
164+
for equation in expressions:
165+
if not equation.is_free(pattern, evaluation):
166+
subset.add(symbol)
167+
return subset
168+
169+
133170
class SympyExpression(BasicSympy):
134171
is_Function = True
135172
nargs = None
@@ -363,9 +400,9 @@ def old_from_sympy(expr) -> BaseElement:
363400
if is_Cn_expr(name):
364401
return Expression(SymbolC, Integer(int(name[1:])))
365402
if name.startswith(sympy_symbol_prefix):
366-
name = name[len(sympy_symbol_prefix) :]
403+
name = name[len(sympy_symbol_prefix):]
367404
if name.startswith(sympy_slot_prefix):
368-
index = name[len(sympy_slot_prefix) :]
405+
index = name[len(sympy_slot_prefix):]
369406
return Expression(SymbolSlot, Integer(int(index)))
370407
elif expr.is_NumberSymbol:
371408
name = str(expr)
@@ -517,7 +554,7 @@ def old_from_sympy(expr) -> BaseElement:
517554
*[from_sympy(arg) for arg in expr.args]
518555
)
519556
if name.startswith(sympy_symbol_prefix):
520-
name = name[len(sympy_symbol_prefix) :]
557+
name = name[len(sympy_symbol_prefix):]
521558
args = [from_sympy(arg) for arg in expr.args]
522559
builtin = sympy_to_mathics.get(name)
523560
if builtin is not None:

mathics/core/systemsymbols.py

+3
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
SymbolCompile = Symbol("System`Compile")
5757
SymbolCompiledFunction = Symbol("System`CompiledFunction")
5858
SymbolComplex = Symbol("System`Complex")
59+
SymbolComplexes = Symbol("System`Complexes")
5960
SymbolComplexInfinity = Symbol("System`ComplexInfinity")
6061
SymbolCondition = Symbol("System`Condition")
6162
SymbolConditionalExpression = Symbol("System`ConditionalExpression")
@@ -124,6 +125,7 @@
124125
SymbolInfix = Symbol("System`Infix")
125126
SymbolInputForm = Symbol("System`InputForm")
126127
SymbolInteger = Symbol("System`Integer")
128+
SymbolIntegers = Symbol("System`Integers")
127129
SymbolIntegrate = Symbol("System`Integrate")
128130
SymbolLeft = Symbol("System`Left")
129131
SymbolLength = Symbol("System`Length")
@@ -200,6 +202,7 @@
200202
SymbolRational = Symbol("System`Rational")
201203
SymbolRe = Symbol("System`Re")
202204
SymbolReal = Symbol("System`Real")
205+
SymbolReals = Symbol("System`Reals")
203206
SymbolRealAbs = Symbol("System`RealAbs")
204207
SymbolRealDigits = Symbol("System`RealDigits")
205208
SymbolRealSign = Symbol("System`RealSign")

test/builtin/calculus/test_solve.py

+26-6
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,34 @@ def test_solve():
4343
"Issue #1235",
4444
),
4545
(
46-
"Solve[{x^2==4 && x < 0},{x}]",
47-
"{x->-2}",
48-
"",
46+
"Solve[Abs[-2/3*(lambda + 2) + 8/3 + 4] == 4, lambda,Reals]",
47+
"{{lambda -> 2}, {lambda -> 14}}",
48+
"abs()",
4949
),
5050
(
51-
"Solve[{x^2==4 && x < 0 && x > -4},{x}]",
52-
"{x->-2}",
53-
"",
51+
"Solve[q^3 == (20-12)/(4-3), q,Reals]",
52+
"{{q -> 2}}",
53+
"domain check",
54+
),
55+
(
56+
"Solve[x + Pi/3 == 2k*Pi + Pi/6 || x + Pi/3 == 2k*Pi + 5Pi/6, x,Reals]",
57+
"{{x -> -Pi / 6 + 2 k Pi}, {x -> Pi / 2 + 2 k Pi}}",
58+
"logics involved",
59+
),
60+
(
61+
"Solve[m - 1 == 0 && -(m + 1) != 0, m,Reals]",
62+
"{{m -> 1}}",
63+
"logics and constraints",
64+
),
65+
(
66+
"Solve[(lambda + 1)/6 == 1/(mu - 1) == lambda/4, {lambda, mu},Reals]",
67+
"{{lambda -> 2, mu -> 3}}",
68+
"chained equations",
69+
),
70+
(
71+
"Solve[2*x0*Log[x0] + x0 - 2*a*x0 == -1 && x0^2*Log[x0] - a*x0^2 + b == b - x0, {x0, a, b},Reals]",
72+
"{{x0 -> 1, a -> 1}}",
73+
"excess variable b",
5474
),
5575
):
5676
session.evaluate("Clear[h]; Clear[g]; Clear[f];")

0 commit comments

Comments
 (0)