Skip to content

Commit bfa476a

Browse files
authored
Adhoc handling of custom entries for notations and improvements on plugins (VernacExtend) (#51)
* handle custom entries and (partially) gen args for plugins * hotfix listarg * Add test for the nth locate result in a step context * Remove handling of ListArg for identref, no longer needed * Fix equations and add test for it * Add coq-equations to workflow * bump ocaml version * Add coq-released to workflow
1 parent 226c0ff commit bfa476a

File tree

8 files changed

+130
-44
lines changed

8 files changed

+130
-44
lines changed

.github/workflows/test.yml

+14-4
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@ jobs:
1212
strategy:
1313
matrix:
1414
ocaml-compiler:
15-
- "4.11"
15+
- "5.2.0"
1616
coq-version:
1717
- "8.17.1"
1818
- "8.18.0"
19-
- "8.19.1"
19+
- "8.19.2"
2020

2121
steps:
2222
- name: Checkout
@@ -40,8 +40,8 @@ jobs:
4040
/home/runner/work/coqpyt/coqpyt/_opam/
4141
key: ${{ matrix.ocaml-compiler }}-${{ matrix.coq-version }}-opam
4242

43-
- name: Set-up OCaml ${{ matrix.ocaml-compiler }}
44-
uses: ocaml/setup-ocaml@v2
43+
- name: Set-up OCaml
44+
uses: ocaml/setup-ocaml@v3
4545
with:
4646
ocaml-compiler: ${{ matrix.ocaml-compiler }}
4747

@@ -51,6 +51,16 @@ jobs:
5151
opam pin add coq ${{ matrix.coq-version }}
5252
opam install coq-lsp
5353
54+
- name: Add coq-released
55+
if: steps.cache-opam-restore.outputs.cache-hit != 'true'
56+
run: |
57+
opam repo add coq-released https://coq.inria.fr/opam/released
58+
59+
- name: Install coq-equations
60+
if: steps.cache-opam-restore.outputs.cache-hit != 'true'
61+
run: |
62+
opam install coq-equations
63+
5464
- name: Install coqpyt
5565
run: |
5666
pip install -e .

coqpyt/coq/context.py

+45-18
Original file line numberDiff line numberDiff line change
@@ -107,12 +107,13 @@ def __add_terms(self, step: Step, expr: List):
107107
self.__anonymous_id += 1
108108
self.__add_term(name, step, term_type)
109109
elif term_type == TermType.DERIVE:
110-
name = FileContext.get_ident(expr[2][0])
111-
self.__add_term(name, step, term_type)
112-
if self.__ext_entry(expr[1]) == "Derive":
113-
prop = FileContext.get_ident(expr[2][2])
114-
self.__add_term(prop, step, term_type)
115-
elif term_type == TermType.OBLIGATION:
110+
for arg in expr[2]:
111+
name = FileContext.get_ident(arg)
112+
if name is not None:
113+
self.__add_term(name, step, term_type)
114+
elif term_type in [TermType.OBLIGATION, TermType.EQUATION]:
115+
# FIXME: For Equations, we are unable of getting terms from the AST
116+
# but these commands do generate named terms
116117
self.__last_terms[-1].append(
117118
("", Term(step, term_type, self.__path, self.__segments.modules[:]))
118119
)
@@ -241,12 +242,15 @@ def __term_type(self, expr: List) -> TermType:
241242
return TermType.FIXPOINT
242243
if expr[0] == "VernacScheme":
243244
return TermType.SCHEME
245+
# FIXME: These are plugins and should probably be handled differently
244246
if self.__is_extend(expr, "Obligations"):
245247
return TermType.OBLIGATION
246248
if self.__is_extend(expr, "VernacDeclareTacticDefinition"):
247249
return TermType.TACTIC
248250
if self.__is_extend(expr, "Function"):
249251
return TermType.FUNCTION
252+
if self.__is_extend(expr, "Define_equations", exact=False):
253+
return TermType.EQUATION
250254
if self.__is_extend(expr, "Derive", exact=False):
251255
return TermType.DERIVE
252256
if self.__is_extend(expr, "AddSetoid", exact=False):
@@ -293,6 +297,16 @@ def __get_names(expr: List) -> List[str]:
293297
stack.append(v)
294298
return res
295299

300+
@staticmethod
301+
def is_id(el) -> bool:
302+
return isinstance(el, list) and (len(el) == 3 and el[0] == "Ser_Qualid")
303+
304+
@staticmethod
305+
def is_notation(el) -> bool:
306+
return isinstance(el, list) and (
307+
len(el) == 4 and el[0] == "CNotation" and el[2][1] != ""
308+
)
309+
296310
@staticmethod
297311
def get_id(id: List) -> Optional[str]:
298312
# FIXME: This should be made private once [__step_context] is extracted
@@ -305,18 +319,18 @@ def get_id(id: List) -> Optional[str]:
305319

306320
@staticmethod
307321
def get_ident(el: List) -> Optional[str]:
308-
# FIXME: This should be made private once [__get_program_context] is extracted
309-
# from ProofFile to here.
310-
if (
311-
len(el) == 3
312-
and el[0] == "GenArg"
313-
and el[1][0] == "Rawwit"
314-
and el[1][1][0] == "ExtraArg"
315-
):
316-
if el[1][1][1] == "identref":
317-
return el[2][0][1][1]
318-
elif el[1][1][1] == "ident":
319-
return el[2][1]
322+
# FIXME: This method should be made private once [__get_program_context]
323+
# is extracted from ProofFile to here.
324+
def handle_arg_type(args, ids):
325+
if args[0] == "ExtraArg":
326+
if args[1] == "identref":
327+
return ids[0][1][1]
328+
elif args[1] == "ident":
329+
return ids[1]
330+
return None
331+
332+
if len(el) == 3 and el[0] == "GenArg" and el[1][0] == "Rawwit":
333+
return handle_arg_type(el[1][1], el[2])
320334
return None
321335

322336
@staticmethod
@@ -546,6 +560,19 @@ def get_term(self, name: str) -> Optional[Term]:
546560
return self.__terms[curr_name][-1]
547561
return None
548562

563+
@staticmethod
564+
def get_notation_scope(notation: str) -> str:
565+
"""Get the scope of a notation.
566+
Args:
567+
notation (str): Possibly scoped notation pattern. E.g. "_ + _ : nat_scope".
568+
569+
Returns:
570+
str: The scope of the notation. E.g. "nat_scope".
571+
"""
572+
if notation.split(":")[-1].endswith("_scope"):
573+
return notation.split(":")[-1].strip()
574+
return ""
575+
549576
def get_notation(self, notation: str, scope: str) -> Term:
550577
"""Get a notation from the context.
551578

coqpyt/coq/proof_file.py

+24-22
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ def __init__(
313313
Args:
314314
file_path (str): Path of the Coq file.
315315
library (Optional[str], optional): The library of the file. Defaults to None.
316-
timeout (int, optional): Timeout used in coq-lsp. Defaults to 2.
316+
timeout (int, optional): Timeout used in coq-lsp. Defaults to 30.
317317
workspace (Optional[str], optional): Absolute path for the workspace.
318318
If the workspace is not defined, the workspace is equal to the
319319
path of the file.
@@ -372,42 +372,44 @@ def _handle_exception(self, e):
372372
raise e
373373

374374
def __locate(self, search, line):
375-
nots = self.__aux_file.get_diagnostics("Locate", f'"{search}"', line).split(
376-
"\n"
377-
)
378-
fun = lambda x: x.endswith("(default interpretation)")
379-
return nots[0][:-25] if fun(nots[0]) else nots[0]
375+
located = self.__aux_file.get_diagnostics("Locate", f'"{search}"', line)
376+
trim = lambda x: x[:-25] if x.endswith("(default interpretation)") else x
377+
return list(map(trim, located.split("\n")))
380378

381379
def __step_context(self, step: Step) -> List[Term]:
382380
stack, res = self.context.expr(step)[:0:-1], []
383381
while len(stack) > 0:
384382
el = stack.pop()
385-
if isinstance(el, list) and len(el) == 3 and el[0] == "Ser_Qualid":
383+
if FileContext.is_id(el):
386384
term = self.context.get_term(FileContext.get_id(el))
387385
if term is not None and term not in res:
388386
res.append(term)
389-
elif isinstance(el, list) and len(el) == 4 and el[0] == "CNotation":
387+
elif FileContext.is_notation(el):
388+
stack.append(el[1:])
389+
390+
notation_name = el[2][1]
390391
line = len(self.__aux_file.read().split("\n"))
391-
self.__aux_file.append(f'\nLocate "{el[2][1]}".')
392+
self.__aux_file.append(f'\nLocate "{notation_name}".')
392393
self.__aux_file.didChange()
394+
notations = self.__locate(notation_name, line)
395+
if len(notations) == 1 and notations[0] == "Unknown notation":
396+
continue
393397

394-
notation_name, scope = el[2][1], ""
395-
notation = self.__locate(notation_name, line)
396-
if notation.split(":")[-1].endswith("_scope"):
397-
scope = notation.split(":")[-1].strip()
398-
399-
if notation != "Unknown notation":
398+
for notation in notations:
399+
scope = FileContext.get_notation_scope(notation)
400400
try:
401401
term = self.context.get_notation(notation_name, scope)
402402
if term not in res:
403403
res.append(term)
404-
except NotationNotFoundException as e:
405-
if self.__error_mode == "strict":
406-
raise e
407-
else:
408-
logging.warning(str(e))
409-
410-
stack.append(el[1:])
404+
break
405+
except NotationNotFoundException:
406+
continue
407+
else:
408+
e = NotationNotFoundException(notation_name)
409+
if self.__error_mode == "strict":
410+
raise e
411+
else:
412+
logging.warning(str(e))
411413
elif isinstance(el, list):
412414
for v in reversed(el):
413415
if isinstance(v, (dict, list)):

coqpyt/coq/structs.py

+1
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ class TermType(Enum):
3636
SETOID = 22
3737
FUNCTION = 23
3838
DERIVE = 24
39+
EQUATION = 25
3940
OTHER = 100
4041

4142

coqpyt/tests/proof_file/test_proof_file.py

+22
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,28 @@ def test_unknown_notation(self):
120120
assert self.proof_file.context.get_notation("{ _ }", "")
121121

122122

123+
class TestProofNthLocate(SetupProofFile):
124+
def setup_method(self, method):
125+
self.setup("test_nth_locate.v")
126+
127+
def test_nth_locate(self):
128+
"""Checks if it is able to handle notations that are not the first result
129+
returned by the Locate command.
130+
"""
131+
proof_file = self.proof_file
132+
assert len(proof_file.proofs) == 1
133+
proof = proof_file.proofs[0]
134+
135+
theorem = "Lemma test : <> = <>."
136+
assert proof.text == theorem
137+
138+
statement_context = [
139+
('Notation "x = y" := (eq x y) : type_scope.', TermType.NOTATION, []),
140+
('Notation "<>" := BAnon : binder_scope.', TermType.NOTATION, []),
141+
]
142+
compare_context(statement_context, proof.context)
143+
144+
123145
class TestProofNestedProofs(SetupProofFile):
124146
def setup_method(self, method):
125147
self.setup("test_nested_proofs.v")
+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
From Equations Require Import Equations.
2+
3+
Equations? f (n : nat) : nat :=
4+
f 0 := 42 ;
5+
f (S m) with f m := { f (S m) IH := _ }.
6+
Proof. intros. exact IH. Defined.
+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
Inductive binder := BAnon | BNum :> nat -> binder.
2+
Declare Scope binder_scope.
3+
Notation "<>" := BAnon : binder_scope.
4+
5+
Open Scope binder_scope.
6+
Lemma test : <> = <>.
7+
Proof. reflexivity. Qed.

coqpyt/tests/test_coq_file.py

+11
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,17 @@ def test_derive(setup, teardown):
280280
)
281281

282282

283+
@pytest.mark.parametrize("setup", ["test_equations.v"], indirect=True)
284+
def test_derive(setup, teardown):
285+
coq_file.run()
286+
assert len(coq_file.context.terms) == 0
287+
assert coq_file.context.last_term is not None
288+
assert (
289+
coq_file.context.last_term.text
290+
== "Equations? f (n : nat) : nat := f 0 := 42 ; f (S m) with f m := { f (S m) IH := _ }."
291+
)
292+
293+
283294
def test_space_in_path():
284295
# This test exists because coq-lsp encodes spaces in paths as %20
285296
# This causes the diagnostics to be saved in a different path than the one

0 commit comments

Comments
 (0)