Skip to content

Commit 71a7558

Browse files
committed
reconstruct Solve logic
alternate old plain solve logic with a recursive one. Now Solve is capable with nested logics
1 parent 8659983 commit 71a7558

File tree

1 file changed

+156
-145
lines changed

1 file changed

+156
-145
lines changed

mathics/builtin/numbers/calculus.py

+156-145
Original file line numberDiff line numberDiff line change
@@ -330,9 +330,9 @@ def summand(element, index):
330330
Expression(
331331
SymbolDerivative,
332332
*(
333-
[Integer0] * (index)
334-
+ [Integer1]
335-
+ [Integer0] * (len(f.elements) - index - 1)
333+
[Integer0] * (index) +
334+
[Integer1] +
335+
[Integer0] * (len(f.elements) - index - 1)
336336
),
337337
),
338338
f.head,
@@ -664,8 +664,8 @@ def eval(self, f, x, x0, evaluation: Evaluation, options: dict):
664664

665665
# Determine the "jacobian"s
666666
if (
667-
method in ("Newton", "Automatic")
668-
and options["System`Jacobian"] is SymbolAutomatic
667+
method in ("Newton", "Automatic") and
668+
options["System`Jacobian"] is SymbolAutomatic
669669
):
670670

671671
def diff(evaluation):
@@ -1323,16 +1323,16 @@ class NIntegrate(Builtin):
13231323
messages = {
13241324
"bdmtd": "The Method option should be a built-in method name.",
13251325
"inumr": (
1326-
"The integrand `1` has evaluated to non-numerical "
1327-
+ "values for all sampling points in the region "
1328-
+ "with boundaries `2`"
1326+
"The integrand `1` has evaluated to non-numerical " +
1327+
"values for all sampling points in the region " +
1328+
"with boundaries `2`"
13291329
),
13301330
"nlim": "`1` = `2` is not a valid limit of integration.",
13311331
"ilim": "Invalid integration variable or limit(s) in `1`.",
13321332
"mtdfail": (
1333-
"The specified method failed to return a "
1334-
+ "number. Falling back into the internal "
1335-
+ "evaluator."
1333+
"The specified method failed to return a " +
1334+
"number. Falling back into the internal " +
1335+
"evaluator."
13361336
),
13371337
"cmpint": ("Integration over a complex domain is not " + "implemented yet"),
13381338
}
@@ -1375,10 +1375,10 @@ class NIntegrate(Builtin):
13751375

13761376
messages.update(
13771377
{
1378-
"bdmtd": "The Method option should be a "
1379-
+ "built-in method name in {`"
1380-
+ "`, `".join(list(methods))
1381-
+ "`}. Using `Automatic`"
1378+
"bdmtd": "The Method option should be a " +
1379+
"built-in method name in {`" +
1380+
"`, `".join(list(methods)) +
1381+
"`}. Using `Automatic`"
13821382
}
13831383
)
13841384

@@ -1398,7 +1398,7 @@ def eval_with_func_domain(
13981398
elif isinstance(method, Symbol):
13991399
method = method.get_name()
14001400
# strip context
1401-
method = method[method.rindex("`") + 1 :]
1401+
method = method[method.rindex("`") + 1:]
14021402
else:
14031403
evaluation.message("NIntegrate", "bdmtd", method)
14041404
return
@@ -2237,146 +2237,157 @@ def eval(self, eqs, vars, evaluation: Evaluation):
22372237
vars = [vars]
22382238
for var in vars:
22392239
if (
2240-
(isinstance(var, Atom) and not isinstance(var, Symbol))
2241-
or head_name in ("System`Plus", "System`Times", "System`Power") # noqa
2242-
or A_CONSTANT & var.get_attributes(evaluation.definitions)
2240+
(isinstance(var, Atom) and not isinstance(var, Symbol)) or
2241+
head_name in ("System`Plus", "System`Times", "System`Power") or # noqa
2242+
A_CONSTANT & var.get_attributes(evaluation.definitions)
22432243
):
22442244

22452245
evaluation.message("Solve", "ivar", vars_original)
22462246
return
2247-
if eqs.get_head_name() in ("System`List", "System`And"):
2248-
eq_list = eqs.elements
2249-
else:
2250-
eq_list = [eqs]
2251-
sympy_conditions = []
2252-
sympy_eqs = []
2253-
sympy_denoms = []
2254-
for eq in eq_list:
2255-
if eq is SymbolTrue:
2256-
pass
2257-
elif eq is SymbolFalse:
2258-
return ListExpression()
2259-
elif not eq.has_form("Equal", 2):
2260-
sympy_conditions.append(eq.to_sympy())
2261-
else:
2262-
left, right = eq.elements
2263-
left = left.to_sympy()
2264-
right = right.to_sympy()
2265-
if left is None or right is None:
2266-
return
2267-
eq = left - right
2268-
eq = sympy.together(eq)
2269-
eq = sympy.cancel(eq)
2270-
sympy_eqs.append(eq)
2271-
numer, denom = eq.as_numer_denom()
2272-
sympy_denoms.append(denom)
2273-
2274-
if not sympy_eqs:
2275-
evaluation.message("Solve", "eqf", eqs)
2276-
return
22772247

22782248
vars_sympy = [var.to_sympy() for var in vars]
22792249
if None in vars_sympy:
2250+
evaluation.message("Solve", "ivar")
22802251
return
2281-
2282-
# delete unused variables to avoid SymPy's
2283-
# PolynomialError: Not a zero-dimensional system
2284-
# in e.g. Solve[x^2==1&&z^2==-1,{x,y,z}]
2285-
all_vars = vars[:]
2286-
all_vars_sympy = vars_sympy[:]
2287-
vars = []
2288-
vars_sympy = []
2289-
for var, var_sympy in zip(all_vars, all_vars_sympy):
2290-
pattern = Pattern.create(var)
2291-
if not eqs.is_free(pattern, evaluation):
2292-
vars.append(var)
2293-
vars_sympy.append(var_sympy)
2294-
2295-
def transform_dict(sols):
2296-
if not sols:
2297-
yield sols
2298-
for var, sol in sols.items():
2299-
rest = sols.copy()
2300-
del rest[var]
2301-
rest = transform_dict(rest)
2302-
if not isinstance(sol, (tuple, list)):
2303-
sol = [sol]
2304-
if not sol:
2305-
for r in rest:
2306-
yield r
2307-
else:
2308-
for r in rest:
2309-
for item in sol:
2310-
new_sols = r.copy()
2311-
new_sols[var] = item
2312-
yield new_sols
2313-
break
2314-
2315-
def transform_solution(sol):
2316-
if not isinstance(sol, dict):
2317-
if not isinstance(sol, (list, tuple)):
2318-
sol = [sol]
2319-
sol = dict(list(zip(vars_sympy, sol)))
2320-
return transform_dict(sol)
2321-
2322-
if not sympy_eqs:
2323-
sympy_eqs = True
2324-
elif len(sympy_eqs) == 1:
2325-
sympy_eqs = sympy_eqs[0]
2326-
2327-
try:
2328-
if isinstance(sympy_eqs, bool):
2329-
result = sympy_eqs
2252+
all_var_tuples = list(zip(vars, vars_sympy))
2253+
2254+
def cut_var_dimension(expressions: Expression | list[Expression]):
2255+
'''delete unused variables to avoid SymPy's PolynomialError
2256+
: Not a zero-dimensional system in e.g. Solve[x^2==1&&z^2==-1,{x,y,z}]'''
2257+
if not isinstance(expressions, list):
2258+
expressions = [expressions]
2259+
subset_vars = set()
2260+
subset_vars_sympy = set()
2261+
for var, var_sympy in all_var_tuples:
2262+
pattern = Pattern.create(var)
2263+
for equation in expressions:
2264+
if not equation.is_free(pattern, evaluation):
2265+
subset_vars.add(var)
2266+
subset_vars_sympy.add(var_sympy)
2267+
return subset_vars, subset_vars_sympy
2268+
2269+
def solve_sympy(equations: Expression | list[Expression]):
2270+
if not isinstance(equations, list):
2271+
equations = [equations]
2272+
equations_sympy = []
2273+
denoms_sympy = []
2274+
subset_vars, subset_vars_sympy = cut_var_dimension(equations)
2275+
for equation in equations:
2276+
if equation is SymbolTrue:
2277+
continue
2278+
elif equation is SymbolFalse:
2279+
return []
2280+
elements = equation.elements
2281+
for left, right in [(elements[index], elements[index + 1]) for index in range(len(elements) - 1)]:
2282+
# ↑ to deal with things like a==b==c==d
2283+
left = left.to_sympy()
2284+
right = right.to_sympy()
2285+
if left is None or right is None:
2286+
return []
2287+
equation_sympy = left - right
2288+
equation_sympy = sympy.together(equation_sympy)
2289+
equation_sympy = sympy.cancel(equation_sympy)
2290+
numer, denom = equation_sympy.as_numer_denom()
2291+
denoms_sympy.append(denom)
2292+
try:
2293+
results = sympy.solve(equations_sympy, subset_vars_sympy, dict=True) # no transform needed with dict=True
2294+
# Filter out results for which denominator is 0
2295+
# (SymPy should actually do that itself, but it doesn't!)
2296+
results = [
2297+
sol
2298+
for sol in results
2299+
if all(sympy.simplify(denom.subs(sol)) != 0 for denom in denoms_sympy)
2300+
]
2301+
return results
2302+
except sympy.PolynomialError:
2303+
# raised for e.g. Solve[x^2==1&&z^2==-1,{x,y,z}] when not deleting
2304+
# unused variables beforehand
2305+
return []
2306+
except NotImplementedError:
2307+
return []
2308+
except TypeError as exc:
2309+
if str(exc).startswith("expected Symbol, Function or Derivative"):
2310+
evaluation.message("Solve", "ivar", vars_original)
2311+
2312+
def solve_recur(expression: Expression):
2313+
'''solve And, Or and List within the scope of sympy,
2314+
but including the translation from Mathics to sympy
2315+
2316+
returns:
2317+
solutions: a list of sympy solution dictionaries
2318+
conditions: a sympy condition object
2319+
2320+
note:
2321+
for And and List, should always return either (solutions, None) or ([], conditions)
2322+
for Or, all combinations are possible. if Or is root, should be handled outside'''
2323+
head = expression.get_head_name()
2324+
if head in ("System`And", "System`List"):
2325+
solutions = []
2326+
equations: list[Expression] = []
2327+
inequations = []
2328+
for child in expression.elements:
2329+
if child.has_form("Equal", 2):
2330+
equations.append(child)
2331+
elif child.get_head_name() in ("System`And", "System`Or"):
2332+
sub_solution, sub_condition = solve_recur(child)
2333+
solutions.extend(sub_solution)
2334+
if sub_condition is not None:
2335+
inequations.append(sub_condition)
2336+
else:
2337+
inequations.append(child.to_sympy())
2338+
solutions.extend(solve_sympy(equations))
2339+
conditions = sympy.And(*inequations)
2340+
result = [sol for sol in solutions if conditions.subs(sol)]
2341+
return result, None if solutions else conditions
2342+
else: # should be System`Or then
2343+
assert head == "System`Or"
2344+
solutions = []
2345+
conditions = []
2346+
for child in expression.elements:
2347+
if child.has_form("Equal", 2):
2348+
solutions.extend(solve_sympy(child))
2349+
elif child.get_head_name() in ("System`And", "System`Or"): # List wouldn't be in here
2350+
sub_solution, sub_condition = solve_recur(child)
2351+
solutions.extend(sub_solution)
2352+
if sub_condition is not None:
2353+
conditions.append(sub_condition)
2354+
else:
2355+
# SymbolTrue and SymbolFalse are allowed here since it's subtree context
2356+
# FIXME: None is not allowed, not sure what to do here
2357+
conditions.append(child.to_sympy())
2358+
conditions = sympy.Or(*conditions)
2359+
return solutions, conditions
2360+
2361+
if eqs.get_head_name() in ("System`List", "System`And", "System`Or"):
2362+
solutions, conditions = solve_recur(eqs)
2363+
# non True conditions are only accepted in subtrees, not root
2364+
if conditions is not None:
2365+
evaluation.message("Solve", "fulldim")
2366+
return ListExpression(ListExpression())
2367+
else:
2368+
if eqs.has_form("Equal", 2):
2369+
solutions = solve_sympy(eqs)
23302370
else:
2331-
result = sympy.solve(sympy_eqs, vars_sympy)
2332-
if not isinstance(result, list):
2333-
result = [result]
2334-
if isinstance(result, list) and len(result) == 1 and result[0] is True:
2371+
evaluation.message("Solve", "fulldim")
23352372
return ListExpression(ListExpression())
2336-
if result == [None]:
2337-
return ListExpression()
2338-
results = []
2339-
for sol in result:
2340-
results.extend(transform_solution(sol))
2341-
result = results
2342-
# filter with conditions before further translation
2343-
conditions = sympy.And(*sympy_conditions)
2344-
result = [sol for sol in result if conditions.subs(sol)]
2345-
2346-
if any(
2347-
sol and any(var not in sol for var in all_vars_sympy) for sol in result
2348-
):
2349-
evaluation.message("Solve", "svars")
23502373

2351-
# Filter out results for which denominator is 0
2352-
# (SymPy should actually do that itself, but it doesn't!)
2353-
result = [
2354-
sol
2355-
for sol in result
2356-
if all(sympy.simplify(denom.subs(sol)) != 0 for denom in sympy_denoms)
2357-
]
2358-
2359-
return ListExpression(
2360-
*(
2361-
ListExpression(
2362-
*(
2363-
Expression(SymbolRule, var, from_sympy(sol[var_sympy]))
2364-
for var, var_sympy in zip(vars, vars_sympy)
2365-
if var_sympy in sol
2366-
),
2367-
)
2368-
for sol in result
2369-
),
2370-
)
2371-
except sympy.PolynomialError:
2372-
# raised for e.g. Solve[x^2==1&&z^2==-1,{x,y,z}] when not deleting
2373-
# unused variables beforehand
2374-
pass
2375-
except NotImplementedError:
2376-
pass
2377-
except TypeError as exc:
2378-
if str(exc).startswith("expected Symbol, Function or Derivative"):
2379-
evaluation.message("Solve", "ivar", vars_original)
2374+
if any(
2375+
sol and any(var not in sol for var in vars_sympy) for sol in solutions
2376+
):
2377+
evaluation.message("Solve", "svars")
2378+
2379+
return ListExpression(
2380+
*(
2381+
ListExpression(
2382+
*(
2383+
Expression(SymbolRule, var, from_sympy(sol[var_sympy]))
2384+
for var, var_sympy in zip(vars, all_var_tuples)
2385+
if var_sympy in sol
2386+
),
2387+
)
2388+
for sol in solutions
2389+
),
2390+
)
23802391

23812392

23822393
# Auxiliary routines. Maybe should be moved to another module.

0 commit comments

Comments
 (0)