Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions tests/integrations/mcp/streaming_asgi_transport.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import asyncio
from httpx import ASGITransport, Request, Response, AsyncByteStream
import anyio

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from typing import Any, Callable, MutableMapping


class StreamingASGITransport(ASGITransport):
"""
Simple transport whose only purpose is to keep GET request alive in SSE connections, allowing
tests involving SSE interactions to run in-process.
"""

def __init__(
self,
app: "Callable",
keep_sse_alive: "asyncio.Event",
) -> None:
self.keep_sse_alive = keep_sse_alive
super().__init__(app)

async def handle_async_request(self, request: "Request") -> "Response":
scope = {
"type": "http",
"method": request.method,
"headers": [(k.lower(), v) for (k, v) in request.headers.raw],
"path": request.url.path,
"query_string": request.url.query,
}

is_streaming_sse = scope["method"] == "GET" and scope["path"] == "/sse"
if not is_streaming_sse:
return await super().handle_async_request(request)

request_body = b""
if request.content:
request_body = await request.aread()

body_sender, body_receiver = anyio.create_memory_object_stream[bytes](0)

async def receive() -> "dict[str, Any]":
if self.keep_sse_alive.is_set():
return {"type": "http.disconnect"}

await self.keep_sse_alive.wait() # Keep alive :)
return {"type": "http.request", "body": request_body, "more_body": False}

async def send(message: "MutableMapping[str, Any]") -> None:
if message["type"] == "http.response.body":
body = message.get("body", b"")
more_body = message.get("more_body", False)

if body == b"" and not more_body:
return

if body:
await body_sender.send(body)

if not more_body:
await body_sender.aclose()

async def run_app():
await self.app(scope, receive, send)

class StreamingBodyStream(AsyncByteStream):
def __init__(self, receiver):
self.receiver = receiver

async def __aiter__(self):
try:
async for chunk in self.receiver:
yield chunk
except anyio.EndOfStream:
pass

stream = StreamingBodyStream(body_receiver)
response = Response(status_code=200, headers=[], stream=stream)

asyncio.create_task(run_app())
return response
186 changes: 148 additions & 38 deletions tests/integrations/mcp/test_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@
that the integration properly instruments MCP handlers with Sentry spans.
"""

from urllib.parse import urlparse, parse_qs
import anyio
import asyncio
import httpx
from .streaming_asgi_transport import StreamingASGITransport

import pytest
import json
from unittest import mock
Expand All @@ -32,9 +38,10 @@ async def __call__(self, *args, **kwargs):
from mcp.server.lowlevel.helper_types import ReadResourceContents
from mcp.server.lowlevel import Server
from mcp.server.lowlevel.server import request_ctx
from mcp.server.sse import SseServerTransport
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager

from starlette.routing import Mount
from starlette.routing import Mount, Route, Response
from starlette.applications import Starlette

try:
Expand Down Expand Up @@ -66,39 +73,103 @@ def reset_request_ctx():
pass


class MockRequestContext:
"""Mock MCP request context"""

def __init__(self, request_id=None, session_id=None, transport="stdio"):
self.request_id = request_id
if transport in ("http", "sse"):
self.request = MockHTTPRequest(session_id, transport)
else:
self.request = None
class MockTextContent:
"""Mock TextContent object"""

def __init__(self, text):
self.text = text

class MockHTTPRequest:
"""Mock HTTP request for SSE/StreamableHTTP transport"""

def __init__(self, session_id=None, transport="http"):
self.headers = {}
self.query_params = {}
async def json_rpc_sse(
app, method: str, params, request_id: str, keep_sse_alive: "asyncio.Event"
):
context = {}

stream_complete = asyncio.Event()
endpoint_parsed = asyncio.Event()

# https://github.com/Kludex/starlette/issues/104#issuecomment-729087925
async with httpx.AsyncClient(
transport=StreamingASGITransport(app=app, keep_sse_alive=keep_sse_alive),
base_url="http://test",
) as client:

async def parse_stream():
async with client.stream("GET", "/sse") as stream:
# Read directly from stream.stream instead of aiter_bytes()
async for chunk in stream.stream:
if b"event: endpoint" in chunk:
sse_text = chunk.decode("utf-8")
url = sse_text.split("data: ")[1]

parsed = urlparse(url)
query_params = parse_qs(parsed.query)
context["session_id"] = query_params["session_id"][0]
endpoint_parsed.set()
continue

if b"event: message" in chunk and b"structuredContent" in chunk:
sse_text = chunk.decode("utf-8")

json_str = sse_text.split("data: ")[1]
context["response"] = json.loads(json_str)
break

stream_complete.set()

task = asyncio.create_task(parse_stream())
await endpoint_parsed.wait()

await client.post(
f"/messages/?session_id={context['session_id']}",
headers={
"Content-Type": "application/json",
},
json={
"jsonrpc": "2.0",
"method": "initialize",
"params": {
"clientInfo": {"name": "test-client", "version": "1.0"},
"protocolVersion": "2025-11-25",
"capabilities": {},
},
"id": request_id,
},
)

if transport == "sse":
# SSE transport uses query parameter
if session_id:
self.query_params["session_id"] = session_id
else:
# StreamableHTTP transport uses header
if session_id:
self.headers["mcp-session-id"] = session_id
# Notification response is mandatory.
# https://modelcontextprotocol.io/specification/2025-11-25/basic/lifecycle
await client.post(
f"/messages/?session_id={context['session_id']}",
headers={
"Content-Type": "application/json",
"mcp-session-id": context["session_id"],
},
json={
"jsonrpc": "2.0",
"method": "notifications/initialized",
"params": {},
},
)

await client.post(
f"/messages/?session_id={context['session_id']}",
headers={
"Content-Type": "application/json",
"mcp-session-id": context["session_id"],
},
json={
"jsonrpc": "2.0",
"method": method,
"params": params,
"id": request_id,
},
)

class MockTextContent:
"""Mock TextContent object"""
await stream_complete.wait()
keep_sse_alive.set()

def __init__(self, text):
self.text = text
return task, context["session_id"], context["response"]


def test_integration_patches_server(sentry_init):
Expand Down Expand Up @@ -985,7 +1056,8 @@ def test_tool_complex(tool_name, arguments):
assert span["data"]["mcp.request.argument.number"] == "42"


def test_sse_transport_detection(sentry_init, capture_events):
@pytest.mark.asyncio
async def test_sse_transport_detection(sentry_init, capture_events):
"""Test that SSE transport is correctly detected via query parameter"""
sentry_init(
integrations=[MCPIntegration()],
Expand All @@ -994,29 +1066,67 @@ def test_sse_transport_detection(sentry_init, capture_events):
events = capture_events()

server = Server("test-server")
sse = SseServerTransport("/messages/")

# Set up mock request context with SSE transport
mock_ctx = MockRequestContext(
request_id="req-sse", session_id="session-sse-123", transport="sse"
sse_connection_closed = asyncio.Event()

async def handle_sse(request):
async with sse.connect_sse(
request.scope, request.receive, request._send
) as streams:
async with anyio.create_task_group() as tg:

async def run_server():
await server.run(
streams[0], streams[1], server.create_initialization_options()
)

tg.start_soon(run_server)

sse_connection_closed.set()
return Response()

app = Starlette(
routes=[
Route("/sse", endpoint=handle_sse, methods=["GET"]),
Mount("/messages/", app=sse.handle_post_message),
],
)
request_ctx.set(mock_ctx)

@server.call_tool()
def test_tool(tool_name, arguments):
async def test_tool(tool_name, arguments):
return {"result": "success"}

with start_transaction(name="mcp tx"):
result = test_tool("sse_tool", {})
keep_sse_alive = asyncio.Event()
app_task, session_id, result = await json_rpc_sse(
app,
method="tools/call",
params={
"name": "sse_tool",
"arguments": {},
},
request_id="req-sse",
keep_sse_alive=keep_sse_alive,
)

assert result == {"result": "success"}
await sse_connection_closed.wait()
await app_task

(tx,) = events
assert result["result"]["structuredContent"] == {"result": "success"}

transactions = [
event
for event in events
if event["type"] == "transaction" and event["transaction"] == "/sse"
]
assert len(transactions) == 1
tx = transactions[0]
span = tx["spans"][0]

# Check that SSE transport is detected
assert span["data"][SPANDATA.MCP_TRANSPORT] == "sse"
assert span["data"][SPANDATA.NETWORK_TRANSPORT] == "tcp"
assert span["data"][SPANDATA.MCP_SESSION_ID] == "session-sse-123"
assert span["data"][SPANDATA.MCP_SESSION_ID] == session_id


def test_streamable_http_transport_detection(
Expand Down
Loading