Skip to content

Commit 23e649e

Browse files
authored
Add Transformer API for in-place manipulation of AST (#99). Fixes #97
* Add perf test for something actually modifying the ast * Add Transformer API for in-place manipulation of AST, deprecate .traverse() * Align APIs of Visitor and Transformer
1 parent f3f9053 commit 23e649e

File tree

3 files changed

+148
-10
lines changed

3 files changed

+148
-10
lines changed

fluent.syntax/fluent/syntax/ast.py

+62-9
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,13 @@ class Visitor(object):
1313
The boolean value of the returned value determines if the visitor
1414
descends into the children of the given AST node.
1515
'''
16-
def visit(self, value):
17-
if isinstance(value, BaseNode):
18-
self.visit_node(value)
19-
if isinstance(value, list):
20-
for node in value:
21-
self.visit(node)
22-
23-
def visit_node(self, node):
16+
def visit(self, node):
17+
if isinstance(node, list):
18+
for child in node:
19+
self.visit(child)
20+
return
21+
if not isinstance(node, BaseNode):
22+
return
2423
nodename = type(node).__name__
2524
visit = getattr(self, 'visit_{}'.format(nodename), self.generic_visit)
2625
should_descend = visit(node)
@@ -33,6 +32,41 @@ def generic_visit(self, node):
3332
return True
3433

3534

35+
class Transformer(Visitor):
36+
'''In-place AST Transformer pattern.
37+
38+
Subclass this to create an in-place modified variant
39+
of the given AST.
40+
If you need to keep the original AST around, pass
41+
a `node.clone()` to the transformer.
42+
'''
43+
def visit(self, node):
44+
if not isinstance(node, BaseNode):
45+
return node
46+
47+
nodename = type(node).__name__
48+
visit = getattr(self, 'visit_{}'.format(nodename), self.generic_visit)
49+
return visit(node)
50+
51+
def generic_visit(self, node):
52+
for propname, propvalue in vars(node).items():
53+
if isinstance(propvalue, list):
54+
new_vals = []
55+
for child in propvalue:
56+
new_val = self.visit(child)
57+
if new_val is not None:
58+
new_vals.append(new_val)
59+
# in-place manipulation
60+
propvalue[:] = new_vals
61+
elif isinstance(propvalue, BaseNode):
62+
new_val = self.visit(propvalue)
63+
if new_val is None:
64+
delattr(node, propname)
65+
else:
66+
setattr(node, propname, new_val)
67+
return node
68+
69+
3670
def to_json(value, fn=None):
3771
if isinstance(value, BaseNode):
3872
return value.to_json(fn)
@@ -79,7 +113,9 @@ class BaseNode(object):
79113
"""
80114

81115
def traverse(self, fun):
82-
"""Postorder-traverse this node and apply `fun` to all child nodes.
116+
"""DEPRECATED. Please use Visitor or Transformer.
117+
118+
Postorder-traverse this node and apply `fun` to all child nodes.
83119
84120
Traverse this node depth-first applying `fun` to subnodes and leaves.
85121
Children are processed before parents (postorder traversal).
@@ -103,6 +139,23 @@ def visit(value):
103139

104140
return fun(node)
105141

142+
def clone(self):
143+
"""Create a deep clone of the current node."""
144+
def visit(value):
145+
"""Clone node and its descendants."""
146+
if isinstance(value, BaseNode):
147+
return value.clone()
148+
if isinstance(value, list):
149+
return [visit(child) for child in value]
150+
if isinstance(value, tuple):
151+
return tuple(visit(child) for child in value)
152+
return value
153+
154+
# Use all attributes found on the node as kwargs to the constructor.
155+
return self.__class__(
156+
**{name: visit(value) for name, value in vars(self).items()}
157+
)
158+
106159
def equals(self, other, ignored_fields=['span']):
107160
"""Compare two nodes.
108161

fluent.syntax/tests/syntax/test_equals.py

+6
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def test_same_simple_message(self):
2626

2727
self.assertTrue(message1.equals(message1))
2828
self.assertTrue(message1.equals(message1.traverse(identity)))
29+
self.assertTrue(message1.equals(message1.clone()))
2930

3031
def test_same_selector_message(self):
3132
message1 = self.parse_ftl_entry("""\
@@ -41,6 +42,7 @@ def test_same_selector_message(self):
4142

4243
self.assertTrue(message1.equals(message1))
4344
self.assertTrue(message1.equals(message1.traverse(identity)))
45+
self.assertTrue(message1.equals(message1.clone()))
4446

4547
def test_same_complex_placeable_message(self):
4648
message1 = self.parse_ftl_entry("""\
@@ -49,6 +51,7 @@ def test_same_complex_placeable_message(self):
4951

5052
self.assertTrue(message1.equals(message1))
5153
self.assertTrue(message1.equals(message1.traverse(identity)))
54+
self.assertTrue(message1.equals(message1.clone()))
5255

5356
def test_same_message_with_attribute(self):
5457
message1 = self.parse_ftl_entry("""\
@@ -58,6 +61,7 @@ def test_same_message_with_attribute(self):
5861

5962
self.assertTrue(message1.equals(message1))
6063
self.assertTrue(message1.equals(message1.traverse(identity)))
64+
self.assertTrue(message1.equals(message1.clone()))
6165

6266
def test_same_message_with_attributes(self):
6367
message1 = self.parse_ftl_entry("""\
@@ -68,6 +72,7 @@ def test_same_message_with_attributes(self):
6872

6973
self.assertTrue(message1.equals(message1))
7074
self.assertTrue(message1.equals(message1.traverse(identity)))
75+
self.assertTrue(message1.equals(message1.clone()))
7176

7277
def test_same_junk(self):
7378
message1 = self.parse_ftl_entry("""\
@@ -76,6 +81,7 @@ def test_same_junk(self):
7681

7782
self.assertTrue(message1.equals(message1))
7883
self.assertTrue(message1.equals(message1.traverse(identity)))
84+
self.assertTrue(message1.equals(message1.clone()))
7985

8086

8187
class TestOrderEquals(unittest.TestCase):

fluent.syntax/tests/syntax/test_visitor.py

+80-1
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,30 @@ def test_resource(self):
5050
)
5151

5252

53+
class TestTransformer(unittest.TestCase):
54+
def test(self):
55+
resource = FluentParser().parse(dedent_ftl('''\
56+
one = Message
57+
two = Messages
58+
three = Has a
59+
.an = Message string in the Attribute
60+
'''))
61+
prior_res_id = id(resource)
62+
prior_msg_id = id(resource.body[1].value)
63+
backup = resource.clone()
64+
transformed = ReplaceTransformer('Message', 'Term').visit(resource)
65+
self.assertEqual(prior_res_id, id(transformed))
66+
self.assertEqual(
67+
prior_msg_id,
68+
id(transformed.body[1].value)
69+
)
70+
self.assertFalse(transformed.equals(backup))
71+
self.assertEqual(
72+
transformed.body[1].value.elements[0].value,
73+
'Terms'
74+
)
75+
76+
5377
class WordCounter(object):
5478
def __init__(self):
5579
self.word_count = 0
@@ -70,6 +94,34 @@ def visit_TextElement(self, node):
7094
return False
7195

7296

97+
class ReplaceText(object):
98+
def __init__(self, before, after):
99+
self.before = before
100+
self.after = after
101+
102+
def __call__(self, node):
103+
"""Perform find and replace on text values only"""
104+
if type(node) == ast.TextElement:
105+
node.value = node.value.replace(self.before, self.after)
106+
return node
107+
108+
109+
class ReplaceTransformer(ast.Transformer):
110+
def __init__(self, before, after):
111+
self.before = before
112+
self.after = after
113+
114+
def generic_visit(self, node):
115+
if isinstance(node, (ast.Span, ast.Annotation)):
116+
return node
117+
return super(ReplaceTransformer, self).generic_visit(node)
118+
119+
def visit_TextElement(self, node):
120+
"""Perform find and replace on text values only"""
121+
node.value = node.value.replace(self.before, self.after)
122+
return node
123+
124+
73125
class TestPerf(unittest.TestCase):
74126
def setUp(self):
75127
parser = FluentParser()
@@ -89,6 +141,27 @@ def test_visitor(self):
89141
counter.visit(self.resource)
90142
self.assertEqual(counter.word_count, 277)
91143

144+
def test_edit_traverse(self):
145+
edited = self.resource.traverse(ReplaceText('Tab', 'Reiter'))
146+
self.assertEqual(
147+
edited.body[4].attributes[0].value.elements[0].value,
148+
'New Reiter'
149+
)
150+
151+
def test_edit_transform(self):
152+
edited = ReplaceTransformer('Tab', 'Reiter').visit(self.resource)
153+
self.assertEqual(
154+
edited.body[4].attributes[0].value.elements[0].value,
155+
'New Reiter'
156+
)
157+
158+
def test_edit_cloned(self):
159+
edited = ReplaceTransformer('Tab', 'Reiter').visit(self.resource.clone())
160+
self.assertEqual(
161+
edited.body[4].attributes[0].value.elements[0].value,
162+
'New Reiter'
163+
)
164+
92165

93166
def gather_stats(method, repeat=10, number=50):
94167
t = timeit.Timer(
@@ -107,7 +180,13 @@ def gather_stats(method, repeat=10, number=50):
107180

108181

109182
if __name__=='__main__':
110-
for m in ('traverse', 'visitor'):
183+
for m in (
184+
'traverse',
185+
'visitor',
186+
'edit_traverse',
187+
'edit_transform',
188+
'edit_cloned',
189+
):
111190
results = gather_stats(m)
112191
try:
113192
import statistics

0 commit comments

Comments
 (0)