Skip to content

Commit 05c3a64

Browse files
authored
Implement PEP 302 optional get_code loader method
We implement the optional get_code loader method. This increases compatibility with other tools/libraries that need to manipulate the code object of a module before it is executed.
1 parent 2b40981 commit 05c3a64

File tree

1 file changed

+14
-8
lines changed

1 file changed

+14
-8
lines changed

src/_pytest/assertion/rewrite.py

+14-8
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def __init__(self, config: Config) -> None:
8080
self._basenames_to_check_rewrite = {"conftest"}
8181
self._marked_for_rewrite_cache: dict[str, bool] = {}
8282
self._session_paths_checked = False
83+
self.fn: str | None = None
8384

8485
def set_session(self, session: Session | None) -> None:
8586
self.session = session
@@ -126,7 +127,7 @@ def find_spec(
126127
):
127128
return None
128129
else:
129-
fn = spec.origin
130+
self.fn = fn = spec.origin
130131

131132
if not self._should_rewrite(name, fn, state):
132133
return None
@@ -143,14 +144,11 @@ def create_module(
143144
) -> types.ModuleType | None:
144145
return None # default behaviour is fine
145146

146-
def exec_module(self, module: types.ModuleType) -> None:
147-
assert module.__spec__ is not None
148-
assert module.__spec__.origin is not None
149-
fn = Path(module.__spec__.origin)
147+
def get_code(self, fullname: str) -> types.CodeType
148+
assert self.fn is not None
149+
fn = Path(self.fn)
150150
state = self.config.stash[assertstate_key]
151151

152-
self._rewritten_names[module.__name__] = fn
153-
154152
# The requested module looks like a test file, so rewrite it. This is
155153
# the most magical part of the process: load the source, rewrite the
156154
# asserts, and load the rewritten source. We also cache the rewritten
@@ -183,7 +181,15 @@ def exec_module(self, module: types.ModuleType) -> None:
183181
self._writing_pyc = False
184182
else:
185183
state.trace(f"found cached rewritten pyc for {fn}")
186-
exec(co, module.__dict__)
184+
185+
return co
186+
187+
def exec_module(self, module: types.ModuleType) -> None:
188+
module_name = module.__name__
189+
190+
self._rewritten_names[module_name] = fn
191+
192+
exec(self.get_code(module_name), module.__dict__)
187193

188194
def _early_rewrite_bailout(self, name: str, state: AssertionState) -> bool:
189195
"""A fast way to get out of rewriting modules.

0 commit comments

Comments
 (0)