diff --git a/tests/integrations/mcp/streaming_asgi_transport.py b/tests/integrations/mcp/streaming_asgi_transport.py new file mode 100644 index 0000000000..681a5bc96e --- /dev/null +++ b/tests/integrations/mcp/streaming_asgi_transport.py @@ -0,0 +1,83 @@ +import asyncio +from httpx import ASGITransport, Request, Response, AsyncByteStream +import anyio + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Any, Callable, MutableMapping + + +class StreamingASGITransport(ASGITransport): + """ + Simple transport whose only purpose is to keep GET request alive in SSE connections, allowing + tests involving SSE interactions to run in-process. + """ + + def __init__( + self, + app: "Callable", + keep_sse_alive: "asyncio.Event", + ) -> None: + self.keep_sse_alive = keep_sse_alive + super().__init__(app) + + async def handle_async_request(self, request: "Request") -> "Response": + scope = { + "type": "http", + "method": request.method, + "headers": [(k.lower(), v) for (k, v) in request.headers.raw], + "path": request.url.path, + "query_string": request.url.query, + } + + is_streaming_sse = scope["method"] == "GET" and scope["path"] == "/sse" + if not is_streaming_sse: + return await super().handle_async_request(request) + + request_body = b"" + if request.content: + request_body = await request.aread() + + body_sender, body_receiver = anyio.create_memory_object_stream[bytes](0) + + async def receive() -> "dict[str, Any]": + if self.keep_sse_alive.is_set(): + return {"type": "http.disconnect"} + + await self.keep_sse_alive.wait() # Keep alive :) + return {"type": "http.request", "body": request_body, "more_body": False} + + async def send(message: "MutableMapping[str, Any]") -> None: + if message["type"] == "http.response.body": + body = message.get("body", b"") + more_body = message.get("more_body", False) + + if body == b"" and not more_body: + return + + if body: + await body_sender.send(body) + + if not more_body: + await body_sender.aclose() + + async def run_app(): + await self.app(scope, receive, send) + + class StreamingBodyStream(AsyncByteStream): + def __init__(self, receiver): + self.receiver = receiver + + async def __aiter__(self): + try: + async for chunk in self.receiver: + yield chunk + except anyio.EndOfStream: + pass + + stream = StreamingBodyStream(body_receiver) + response = Response(status_code=200, headers=[], stream=stream) + + asyncio.create_task(run_app()) + return response diff --git a/tests/integrations/mcp/test_mcp.py b/tests/integrations/mcp/test_mcp.py index 5fe100850f..8569ad18e4 100644 --- a/tests/integrations/mcp/test_mcp.py +++ b/tests/integrations/mcp/test_mcp.py @@ -15,6 +15,12 @@ that the integration properly instruments MCP handlers with Sentry spans. """ +from urllib.parse import urlparse, parse_qs +import anyio +import asyncio +import httpx +from .streaming_asgi_transport import StreamingASGITransport + import pytest import json from unittest import mock @@ -32,9 +38,10 @@ async def __call__(self, *args, **kwargs): from mcp.server.lowlevel.helper_types import ReadResourceContents from mcp.server.lowlevel import Server from mcp.server.lowlevel.server import request_ctx +from mcp.server.sse import SseServerTransport from mcp.server.streamable_http_manager import StreamableHTTPSessionManager -from starlette.routing import Mount +from starlette.routing import Mount, Route, Response from starlette.applications import Starlette try: @@ -66,39 +73,103 @@ def reset_request_ctx(): pass -class MockRequestContext: - """Mock MCP request context""" - - def __init__(self, request_id=None, session_id=None, transport="stdio"): - self.request_id = request_id - if transport in ("http", "sse"): - self.request = MockHTTPRequest(session_id, transport) - else: - self.request = None +class MockTextContent: + """Mock TextContent object""" + def __init__(self, text): + self.text = text -class MockHTTPRequest: - """Mock HTTP request for SSE/StreamableHTTP transport""" - def __init__(self, session_id=None, transport="http"): - self.headers = {} - self.query_params = {} +async def json_rpc_sse( + app, method: str, params, request_id: str, keep_sse_alive: "asyncio.Event" +): + context = {} + + stream_complete = asyncio.Event() + endpoint_parsed = asyncio.Event() + + # https://github.com/Kludex/starlette/issues/104#issuecomment-729087925 + async with httpx.AsyncClient( + transport=StreamingASGITransport(app=app, keep_sse_alive=keep_sse_alive), + base_url="http://test", + ) as client: + + async def parse_stream(): + async with client.stream("GET", "/sse") as stream: + # Read directly from stream.stream instead of aiter_bytes() + async for chunk in stream.stream: + if b"event: endpoint" in chunk: + sse_text = chunk.decode("utf-8") + url = sse_text.split("data: ")[1] + + parsed = urlparse(url) + query_params = parse_qs(parsed.query) + context["session_id"] = query_params["session_id"][0] + endpoint_parsed.set() + continue + + if b"event: message" in chunk and b"structuredContent" in chunk: + sse_text = chunk.decode("utf-8") + + json_str = sse_text.split("data: ")[1] + context["response"] = json.loads(json_str) + break + + stream_complete.set() + + task = asyncio.create_task(parse_stream()) + await endpoint_parsed.wait() + + await client.post( + f"/messages/?session_id={context['session_id']}", + headers={ + "Content-Type": "application/json", + }, + json={ + "jsonrpc": "2.0", + "method": "initialize", + "params": { + "clientInfo": {"name": "test-client", "version": "1.0"}, + "protocolVersion": "2025-11-25", + "capabilities": {}, + }, + "id": request_id, + }, + ) - if transport == "sse": - # SSE transport uses query parameter - if session_id: - self.query_params["session_id"] = session_id - else: - # StreamableHTTP transport uses header - if session_id: - self.headers["mcp-session-id"] = session_id + # Notification response is mandatory. + # https://modelcontextprotocol.io/specification/2025-11-25/basic/lifecycle + await client.post( + f"/messages/?session_id={context['session_id']}", + headers={ + "Content-Type": "application/json", + "mcp-session-id": context["session_id"], + }, + json={ + "jsonrpc": "2.0", + "method": "notifications/initialized", + "params": {}, + }, + ) + await client.post( + f"/messages/?session_id={context['session_id']}", + headers={ + "Content-Type": "application/json", + "mcp-session-id": context["session_id"], + }, + json={ + "jsonrpc": "2.0", + "method": method, + "params": params, + "id": request_id, + }, + ) -class MockTextContent: - """Mock TextContent object""" + await stream_complete.wait() + keep_sse_alive.set() - def __init__(self, text): - self.text = text + return task, context["session_id"], context["response"] def test_integration_patches_server(sentry_init): @@ -985,7 +1056,8 @@ def test_tool_complex(tool_name, arguments): assert span["data"]["mcp.request.argument.number"] == "42" -def test_sse_transport_detection(sentry_init, capture_events): +@pytest.mark.asyncio +async def test_sse_transport_detection(sentry_init, capture_events): """Test that SSE transport is correctly detected via query parameter""" sentry_init( integrations=[MCPIntegration()], @@ -994,29 +1066,67 @@ def test_sse_transport_detection(sentry_init, capture_events): events = capture_events() server = Server("test-server") + sse = SseServerTransport("/messages/") - # Set up mock request context with SSE transport - mock_ctx = MockRequestContext( - request_id="req-sse", session_id="session-sse-123", transport="sse" + sse_connection_closed = asyncio.Event() + + async def handle_sse(request): + async with sse.connect_sse( + request.scope, request.receive, request._send + ) as streams: + async with anyio.create_task_group() as tg: + + async def run_server(): + await server.run( + streams[0], streams[1], server.create_initialization_options() + ) + + tg.start_soon(run_server) + + sse_connection_closed.set() + return Response() + + app = Starlette( + routes=[ + Route("/sse", endpoint=handle_sse, methods=["GET"]), + Mount("/messages/", app=sse.handle_post_message), + ], ) - request_ctx.set(mock_ctx) @server.call_tool() - def test_tool(tool_name, arguments): + async def test_tool(tool_name, arguments): return {"result": "success"} - with start_transaction(name="mcp tx"): - result = test_tool("sse_tool", {}) + keep_sse_alive = asyncio.Event() + app_task, session_id, result = await json_rpc_sse( + app, + method="tools/call", + params={ + "name": "sse_tool", + "arguments": {}, + }, + request_id="req-sse", + keep_sse_alive=keep_sse_alive, + ) - assert result == {"result": "success"} + await sse_connection_closed.wait() + await app_task - (tx,) = events + assert result["result"]["structuredContent"] == {"result": "success"} + + transactions = [ + event + for event in events + if event["type"] == "transaction" and event["transaction"] == "/sse" + ] + assert len(transactions) == 1 + tx = transactions[0] span = tx["spans"][0] # Check that SSE transport is detected assert span["data"][SPANDATA.MCP_TRANSPORT] == "sse" assert span["data"][SPANDATA.NETWORK_TRANSPORT] == "tcp" - assert span["data"][SPANDATA.MCP_SESSION_ID] == "session-sse-123" + assert span["data"][SPANDATA.MCP_SESSION_ID] == session_id def test_streamable_http_transport_detection(