diff --git a/sentry_sdk/integrations/asgi.py b/sentry_sdk/integrations/asgi.py index 3569336aae..e5bf461cfe 100644 --- a/sentry_sdk/integrations/asgi.py +++ b/sentry_sdk/integrations/asgi.py @@ -145,6 +145,22 @@ def __init__( else: self.__call__ = self._run_asgi2 + def _capture_lifespan_exception(self, exc): + # type: (Exception) -> None + """Capture exceptions raise in application lifespan handlers. + + The separate function is needed to support overriding in derived integrations that use different catching mechanisms. + """ + return _capture_exception(exc=exc, mechanism_type=self.mechanism_type) + + def _capture_request_exception(self, exc): + # type: (Exception) -> None + """Capture exceptions raised in incoming request handlers. + + The separate function is needed to support overriding in derived integrations that use different catching mechanisms. + """ + return _capture_exception(exc=exc, mechanism_type=self.mechanism_type) + def _run_asgi2(self, scope): # type: (Any) -> Any async def inner(receive, send): @@ -158,7 +174,7 @@ async def _run_asgi3(self, scope, receive, send): return await self._run_app(scope, receive, send, asgi_version=3) async def _run_app(self, scope, receive, send, asgi_version): - # type: (Any, Any, Any, Any, int) -> Any + # type: (Any, Any, Any, int) -> Any is_recursive_asgi_middleware = _asgi_middleware_applied.get(False) is_lifespan = scope["type"] == "lifespan" if is_recursive_asgi_middleware or is_lifespan: @@ -169,7 +185,7 @@ async def _run_app(self, scope, receive, send, asgi_version): return await self.app(scope, receive, send) except Exception as exc: - _capture_exception(exc, mechanism_type=self.mechanism_type) + self._capture_lifespan_exception(exc) raise exc from None _asgi_middleware_applied.set(True) @@ -255,7 +271,7 @@ async def _sentry_wrapped_send(event): scope, receive, _sentry_wrapped_send ) except Exception as exc: - _capture_exception(exc, mechanism_type=self.mechanism_type) + self._capture_request_exception(exc) raise exc from None finally: _asgi_middleware_applied.set(False) diff --git a/sentry_sdk/integrations/litestar.py b/sentry_sdk/integrations/litestar.py index 5f0b32b04e..e186222689 100644 --- a/sentry_sdk/integrations/litestar.py +++ b/sentry_sdk/integrations/litestar.py @@ -87,6 +87,15 @@ def __init__(self, app, span_origin=LitestarIntegration.origin): span_origin=span_origin, ) + def _capture_request_exception(self, exc): + # type: (Exception) -> None + """Avoid catching exceptions from request handlers. + + Those exceptions are already han in Litestar.after_exception handler. + We still catch exceptions from application lifespan handlers. + """ + pass + def patch_app_init(): # type: () -> None diff --git a/tests/integrations/litestar/test_litestar.py b/tests/integrations/litestar/test_litestar.py index 4f642479e4..e7979c24c1 100644 --- a/tests/integrations/litestar/test_litestar.py +++ b/tests/integrations/litestar/test_litestar.py @@ -402,7 +402,7 @@ async def __call__(self, scope, receive, send): @parametrize_test_configurable_status_codes -def test_configurable_status_codes( +def test_configurable_status_codes_handler( sentry_init, capture_events, failed_request_status_codes, @@ -427,3 +427,36 @@ async def error() -> None: client.get("/error") assert len(events) == int(expected_error) + + +@parametrize_test_configurable_status_codes +def test_configurable_status_codes_middleware( + sentry_init, + capture_events, + failed_request_status_codes, + status_code, + expected_error, +): + integration_kwargs = ( + {"failed_request_status_codes": failed_request_status_codes} + if failed_request_status_codes is not None + else {} + ) + sentry_init(integrations=[LitestarIntegration(**integration_kwargs)]) + + events = capture_events() + + def create_raising_middleware(app): + async def raising_middleware(scope, receive, send): + raise HTTPException(status_code=status_code) + + return raising_middleware + + @get("/error") + async def error() -> None: ... + + app = Litestar([error], middleware=[create_raising_middleware]) + client = TestClient(app) + client.get("/error") + + assert len(events) == int(expected_error)