diff --git a/ld_eventsource/actions.py b/ld_eventsource/actions.py index 276eec1..2508a20 100644 --- a/ld_eventsource/actions.py +++ b/ld_eventsource/actions.py @@ -1,5 +1,7 @@ import json -from typing import Optional +from typing import Any, Dict, Optional + +from ld_eventsource.errors import ExceptionWithHeaders class Action: @@ -110,9 +112,25 @@ class Start(Action): Instances of this class are only available from :attr:`.SSEClient.all`. A ``Start`` is returned for the first successful connection. If the client reconnects after a failure, there will be a :class:`.Fault` followed by a ``Start``. + + Each ``Start`` action may include HTTP response headers from the connection. These headers + are available via the :attr:`headers` property. On reconnection, a new ``Start`` will be + emitted with the headers from the new connection, which may differ from the previous one. """ - pass + def __init__(self, headers: Optional[Dict[str, Any]] = None): + self._headers = headers + + @property + def headers(self) -> Optional[Dict[str, Any]]: + """ + The HTTP response headers from the stream connection, if available. + + The headers dict uses case-insensitive keys (via urllib3's HTTPHeaderDict). + + :return: the response headers, or ``None`` if not available + """ + return self._headers class Fault(Action): @@ -125,6 +143,9 @@ class Fault(Action): connection attempt has failed or an existing connection has been closed. The SSEClient will attempt to reconnect if you either call :meth:`.SSEClient.start()` or simply continue reading events after this point. + + When the error includes HTTP response headers (such as for :class:`.HTTPStatusError` + or :class:`.HTTPContentTypeError`), they are accessible via the :attr:`headers` property. """ def __init__(self, error: Optional[Exception]): @@ -138,3 +159,18 @@ def error(self) -> Optional[Exception]: in an orderly way after sending an EOF chunk as defined by chunked transfer encoding. """ return self.__error + + @property + def headers(self) -> Optional[Dict[str, Any]]: + """ + The HTTP response headers from the failed connection, if available. + + This property returns headers when the error is an exception that includes them, + such as :class:`.HTTPStatusError` or :class:`.HTTPContentTypeError`. For other + error types or when the stream ended normally, this returns ``None``. + + :return: the response headers, or ``None`` if not available + """ + if isinstance(self.__error, ExceptionWithHeaders): + return self.__error.headers + return None diff --git a/ld_eventsource/config/connect_strategy.py b/ld_eventsource/config/connect_strategy.py index 4770831..e723fc5 100644 --- a/ld_eventsource/config/connect_strategy.py +++ b/ld_eventsource/config/connect_strategy.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from logging import Logger -from typing import Callable, Iterator, Optional, Union +from typing import Any, Callable, Dict, Iterator, Optional, Union from urllib3 import PoolManager @@ -96,9 +96,10 @@ class ConnectionResult: The return type of :meth:`ConnectionClient.connect()`. """ - def __init__(self, stream: Iterator[bytes], closer: Optional[Callable]): + def __init__(self, stream: Iterator[bytes], closer: Optional[Callable], headers: Optional[Dict[str, Any]] = None): self.__stream = stream self.__closer = closer + self.__headers = headers @property def stream(self) -> Iterator[bytes]: @@ -107,6 +108,18 @@ def stream(self) -> Iterator[bytes]: """ return self.__stream + @property + def headers(self) -> Optional[Dict[str, Any]]: + """ + The HTTP response headers, if available. + + For HTTP connections, this contains the headers from the SSE stream response. + For non-HTTP connections, this will be ``None``. + + The headers dict uses case-insensitive keys (via urllib3's HTTPHeaderDict). + """ + return self.__headers + def close(self): """ Does whatever is necessary to release the connection. @@ -139,8 +152,8 @@ def __init__(self, params: _HttpConnectParams, logger: Logger): self.__impl = _HttpClientImpl(params, logger) def connect(self, last_event_id: Optional[str]) -> ConnectionResult: - stream, closer = self.__impl.connect(last_event_id) - return ConnectionResult(stream, closer) + stream, closer, headers = self.__impl.connect(last_event_id) + return ConnectionResult(stream, closer, headers) def close(self): self.__impl.close() diff --git a/ld_eventsource/errors.py b/ld_eventsource/errors.py index bb5733c..5e757c9 100644 --- a/ld_eventsource/errors.py +++ b/ld_eventsource/errors.py @@ -1,28 +1,62 @@ +from typing import Any, Dict, Optional, Protocol, runtime_checkable + + +@runtime_checkable +class ExceptionWithHeaders(Protocol): + """ + Protocol for exceptions that include HTTP response headers. + + This allows type-safe access to headers from error responses without + using hasattr checks. + """ + + @property + def headers(self) -> Optional[Dict[str, Any]]: + """The HTTP response headers associated with this exception.""" + raise NotImplementedError + + class HTTPStatusError(Exception): """ This exception indicates that the client was able to connect to the server, but that the HTTP response had an error status. + + When available, the response headers are accessible via the :attr:`headers` property. """ - def __init__(self, status: int): + def __init__(self, status: int, headers: Optional[Dict[str, Any]] = None): super().__init__("HTTP error %d" % status) self._status = status + self._headers = headers @property def status(self) -> int: return self._status + @property + def headers(self) -> Optional[Dict[str, Any]]: + """The HTTP response headers, if available.""" + return self._headers + class HTTPContentTypeError(Exception): """ This exception indicates that the HTTP response did not have the expected content type of `"text/event-stream"`. + + When available, the response headers are accessible via the :attr:`headers` property. """ - def __init__(self, content_type: str): + def __init__(self, content_type: str, headers: Optional[Dict[str, Any]] = None): super().__init__("invalid content type \"%s\"" % content_type) self._content_type = content_type + self._headers = headers @property def content_type(self) -> str: return self._content_type + + @property + def headers(self) -> Optional[Dict[str, Any]]: + """The HTTP response headers, if available.""" + return self._headers diff --git a/ld_eventsource/http.py b/ld_eventsource/http.py index c97ed6d..1058993 100644 --- a/ld_eventsource/http.py +++ b/ld_eventsource/http.py @@ -1,5 +1,5 @@ from logging import Logger -from typing import Callable, Iterator, Optional, Tuple +from typing import Any, Callable, Dict, Iterator, Optional, Tuple, cast from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit from urllib3 import PoolManager @@ -60,7 +60,7 @@ def __init__(self, params: _HttpConnectParams, logger: Logger): self.__should_close_pool = params.pool is not None self.__logger = logger - def connect(self, last_event_id: Optional[str]) -> Tuple[Iterator[bytes], Callable]: + def connect(self, last_event_id: Optional[str]) -> Tuple[Iterator[bytes], Callable, Dict[str, Any]]: url = self.__params.url if self.__params.query_params is not None: qp = self.__params.query_params() @@ -100,13 +100,17 @@ def connect(self, last_event_id: Optional[str]) -> Tuple[Iterator[bytes], Callab reason: Optional[Exception] = e.reason if reason is not None: raise reason # e.reason is the underlying I/O error + + # Capture headers early so they're available for both error and success cases + response_headers = cast(Dict[str, Any], resp.headers) + if resp.status >= 400 or resp.status == 204: - raise HTTPStatusError(resp.status) + raise HTTPStatusError(resp.status, response_headers) content_type = resp.headers.get('Content-Type', None) if content_type is None or not str(content_type).startswith( "text/event-stream" ): - raise HTTPContentTypeError(content_type or '') + raise HTTPContentTypeError(content_type or '', response_headers) stream = resp.stream(_CHUNK_SIZE) @@ -117,7 +121,7 @@ def close(): pass resp.release_conn() - return stream, close + return stream, close, response_headers def close(self): if self.__should_close_pool: diff --git a/ld_eventsource/sse_client.py b/ld_eventsource/sse_client.py index a2d0548..2d50bd0 100644 --- a/ld_eventsource/sse_client.py +++ b/ld_eventsource/sse_client.py @@ -39,6 +39,13 @@ class SSEClient: :meth:`.RetryDelayStrategy.default()`, this delay will double with each subsequent retry, and will also have a pseudo-random jitter subtracted. You can customize this behavior with ``retry_delay_strategy``. + + **HTTP Response Headers:** + When using HTTP-based connections, the response headers from each connection are available + via the :attr:`.Start.headers` property when reading from :attr:`all`. Each time the client + connects or reconnects, a :class:`.Start` action is emitted containing the headers from that + specific connection. This allows you to access server metadata such as rate limits, session + identifiers, or custom headers. """ def __init__( @@ -178,9 +185,10 @@ def all(self) -> Iterable[Action]: # Reading implies starting the stream if it isn't already started. We might also # be restarting since we could have been interrupted at any time. while self.__connection_result is None: - fault = self._try_start(True) + result = self._try_start(True) # return either a Start action or a Fault action - yield Start() if fault is None else fault + if result is not None: + yield result lines = _BufferedLineReader.lines_from(self.__connection_result.stream) reader = _SSEReader(lines, self.__last_event_id, None) @@ -263,7 +271,7 @@ def _compute_next_retry_delay(self): self.__current_retry_delay_strategy.apply(self.__base_retry_delay) ) - def _try_start(self, can_return_fault: bool) -> Optional[Fault]: + def _try_start(self, can_return_fault: bool) -> Union[None, Start, Fault]: if self.__connection_result is not None: return None while True: @@ -297,7 +305,7 @@ def _try_start(self, can_return_fault: bool) -> Optional[Fault]: self._retry_reset_baseline = time.time() self.__current_error_strategy = self.__base_error_strategy self.__interrupted = False - return None + return Start(self.__connection_result.headers) @property def last_event_id(self) -> Optional[str]: diff --git a/ld_eventsource/testing/helpers.py b/ld_eventsource/testing/helpers.py index 5647493..64e7c69 100644 --- a/ld_eventsource/testing/helpers.py +++ b/ld_eventsource/testing/helpers.py @@ -66,16 +66,17 @@ def apply(self) -> ConnectionResult: class RespondWithStream(MockConnectionHandler): - def __init__(self, stream: Iterable[bytes]): + def __init__(self, stream: Iterable[bytes], headers: Optional[dict] = None): self.__stream = stream + self.__headers = headers def apply(self) -> ConnectionResult: - return ConnectionResult(stream=self.__stream.__iter__(), closer=None) + return ConnectionResult(stream=self.__stream.__iter__(), closer=None, headers=self.__headers) class RespondWithData(RespondWithStream): - def __init__(self, data: str): - super().__init__([bytes(data, 'utf-8')]) + def __init__(self, data: str, headers: Optional[dict] = None): + super().__init__([bytes(data, 'utf-8')], headers) class ExpectNoMoreRequests(MockConnectionHandler): diff --git a/ld_eventsource/testing/test_headers.py b/ld_eventsource/testing/test_headers.py new file mode 100644 index 0000000..7766a35 --- /dev/null +++ b/ld_eventsource/testing/test_headers.py @@ -0,0 +1,215 @@ +import pytest + +from ld_eventsource import * +from ld_eventsource.actions import * +from ld_eventsource.config import * +from ld_eventsource.errors import * +from ld_eventsource.testing.helpers import * + + +def test_start_action_with_no_headers(): + """Test that Start action can be created without headers""" + start = Start() + assert start.headers is None + + +def test_start_action_with_headers(): + """Test that Start action can be created with headers""" + headers = {'Content-Type': 'text/event-stream', 'X-Custom': 'value'} + start = Start(headers) + assert start.headers == headers + + +def test_headers_exposed_in_start_action(): + """Test that headers from connection are exposed in Start action""" + headers = {'Content-Type': 'text/event-stream', 'X-Test-Header': 'test-value'} + mock = MockConnectStrategy( + RespondWithData("event: test\ndata: data1\n\n", headers=headers) + ) + + with SSEClient(connect=mock) as client: + all_items = list(client.all) + + # First item should be Start with headers + assert isinstance(all_items[0], Start) + assert all_items[0].headers == headers + + # Second item should be the event + assert isinstance(all_items[1], Event) + assert all_items[1].event == 'test' + + # Third item should be Fault (end of stream) + assert isinstance(all_items[2], Fault) + assert all_items[2].error is None + + +def test_headers_not_visible_in_events_iterator(): + """Test that headers are only visible when using .all, not .events""" + headers = {'X-Custom': 'value'} + mock = MockConnectStrategy( + RespondWithData("event: test\ndata: data1\n\n", headers=headers) + ) + + with SSEClient(connect=mock) as client: + events = list(client.events) + + # Should only get the event, no Start action + assert len(events) == 1 + assert isinstance(events[0], Event) + assert events[0].event == 'test' + + +def test_no_headers_when_not_provided(): + """Test that Start action has None headers when connection doesn't provide them""" + mock = MockConnectStrategy( + RespondWithData("event: test\ndata: data1\n\n") + ) + + with SSEClient(connect=mock) as client: + all_items = list(client.all) + + # First item should be Start with no headers + assert isinstance(all_items[0], Start) + assert all_items[0].headers is None + + +def test_different_headers_on_reconnection(): + """Test that reconnection yields new Start with potentially different headers""" + headers1 = {'X-Connection': 'first'} + headers2 = {'X-Connection': 'second'} + + mock = MockConnectStrategy( + RespondWithData("event: test1\ndata: data1\n\n", headers=headers1), + RespondWithData("event: test2\ndata: data2\n\n", headers=headers2) + ) + + with SSEClient( + connect=mock, + error_strategy=ErrorStrategy.from_lambda(lambda _: (ErrorStrategy.CONTINUE, None)), + retry_delay_strategy=no_delay() + ) as client: + items = [] + for item in client.all: + items.append(item) + # Stop after we get the second Start (from reconnection) + if isinstance(item, Start) and len([i for i in items if isinstance(i, Start)]) == 2: + break + + # Find all Start actions + starts = [item for item in items if isinstance(item, Start)] + assert len(starts) >= 2 + + # First connection should have first headers + assert starts[0].headers == headers1 + + # Second connection should have second headers + assert starts[1].headers == headers2 + + +def test_headers_on_retry_after_error(): + """Test that headers are provided on successful retry after an error""" + error = HTTPStatusError(503) + headers = {'X-Retry': 'success'} + + mock = MockConnectStrategy( + RejectConnection(error), + RespondWithData("event: test\ndata: data1\n\n", headers=headers) + ) + + with SSEClient( + connect=mock, + error_strategy=ErrorStrategy.from_lambda(lambda _: (ErrorStrategy.CONTINUE, None)), + retry_delay_strategy=no_delay() + ) as client: + items = [] + for item in client.all: + items.append(item) + if isinstance(item, Event): + break + + # Should have: Fault (from error), Start (from retry), Event + assert isinstance(items[0], Fault) + assert isinstance(items[0].error, HTTPStatusError) + + assert isinstance(items[1], Start) + assert items[1].headers == headers + + assert isinstance(items[2], Event) + + +def test_connection_result_headers_property(): + """Test that ConnectionResult properly stores and returns headers""" + headers = {'X-Test': 'value'} + result = ConnectionResult(stream=iter([b'data']), closer=None, headers=headers) + assert result.headers == headers + + +def test_connection_result_no_headers(): + """Test that ConnectionResult returns None when no headers provided""" + result = ConnectionResult(stream=iter([b'data']), closer=None) + assert result.headers is None + + +def test_http_status_error_includes_headers(): + """Test that HTTPStatusError can store and expose headers""" + headers = {'Retry-After': '120', 'X-RateLimit-Remaining': '0'} + error = HTTPStatusError(429, headers) + assert error.status == 429 + assert error.headers == headers + assert error.headers.get('Retry-After') == '120' + + +def test_http_content_type_error_includes_headers(): + """Test that HTTPContentTypeError can store and expose headers""" + headers = {'Content-Type': 'text/plain', 'X-Custom': 'value'} + error = HTTPContentTypeError('text/plain', headers) + assert error.content_type == 'text/plain' + assert error.headers == headers + assert error.headers.get('X-Custom') == 'value' + + +def test_fault_exposes_headers_from_http_status_error(): + """Test that Fault.headers delegates to HTTPStatusError.headers""" + headers = {'Retry-After': '60'} + error = HTTPStatusError(503, headers) + fault = Fault(error) + + assert fault.error == error + assert fault.headers == headers + assert fault.headers.get('Retry-After') == '60' + + +def test_fault_exposes_headers_from_http_content_type_error(): + """Test that Fault.headers delegates to HTTPContentTypeError.headers""" + headers = {'Content-Type': 'application/json'} + error = HTTPContentTypeError('application/json', headers) + fault = Fault(error) + + assert fault.error == error + assert fault.headers == headers + + +def test_fault_headers_none_for_non_http_errors(): + """Test that Fault.headers returns None for errors without headers""" + error = RuntimeError("some error") + fault = Fault(error) + + assert fault.error == error + assert fault.headers is None + + +def test_fault_headers_none_when_no_error(): + """Test that Fault.headers returns None when there's no error""" + fault = Fault(None) + + assert fault.error is None + assert fault.headers is None + + +def test_fault_headers_none_when_exception_has_no_headers(): + """Test that Fault.headers returns None when exception doesn't provide headers""" + error = HTTPStatusError(500) # No headers provided + fault = Fault(error) + + assert fault.error == error + assert fault.headers is None diff --git a/ld_eventsource/testing/test_http_connect_strategy.py b/ld_eventsource/testing/test_http_connect_strategy.py index dd50c6f..64a5832 100644 --- a/ld_eventsource/testing/test_http_connect_strategy.py +++ b/ld_eventsource/testing/test_http_connect_strategy.py @@ -2,6 +2,8 @@ from urllib3.exceptions import ProtocolError +from ld_eventsource import * +from ld_eventsource.actions import * from ld_eventsource.config.connect_strategy import * from ld_eventsource.testing.helpers import * from ld_eventsource.testing.http_util import * @@ -133,3 +135,107 @@ def test_sse_client_with_http_connect_strategy(): stream.push("data: data1\n\n") event = next(client.events) assert event.data == 'data1' + + +def test_http_response_headers_captured(): + """Test that HTTP response headers are captured from the connection""" + with start_server() as server: + custom_headers = { + 'Content-Type': 'text/event-stream', + 'X-Custom-Header': 'custom-value', + 'X-Rate-Limit': '100' + } + with ChunkedResponse(custom_headers) as stream: + server.for_path('/', stream) + with ConnectStrategy.http(server.uri).create_client(logger()) as client: + result = client.connect(None) + assert result.headers is not None + assert result.headers.get('X-Custom-Header') == 'custom-value' + assert result.headers.get('X-Rate-Limit') == '100' + # urllib3 should also include Content-Type + assert 'Content-Type' in result.headers + + +def test_http_response_headers_in_sse_client(): + """Test that headers are exposed via Start action in SSEClient""" + with start_server() as server: + custom_headers = { + 'Content-Type': 'text/event-stream', + 'X-Session-Id': 'abc123' + } + with ChunkedResponse(custom_headers) as stream: + server.for_path('/', stream) + with SSEClient(connect=ConnectStrategy.http(server.uri)) as client: + stream.push("event: test\ndata: data1\n\n") + + # Read from .all to get Start action + all_items = [] + for item in client.all: + all_items.append(item) + if isinstance(item, Event): + break + + # First item should be Start with headers + assert isinstance(all_items[0], Start) + assert all_items[0].headers is not None + assert all_items[0].headers.get('X-Session-Id') == 'abc123' + + # Second item should be the event + assert isinstance(all_items[1], Event) + + +def test_http_status_error_includes_headers(): + """Test that HTTPStatusError captures response headers""" + with start_server() as server: + server.for_path('/', BasicResponse(429, None, { + 'Retry-After': '120', + 'X-RateLimit-Remaining': '0', + 'X-RateLimit-Reset': '1234567890' + })) + try: + with ConnectStrategy.http(server.uri).create_client(logger()) as client: + client.connect(None) + raise Exception("expected exception, did not get one") + except HTTPStatusError as e: + assert e.status == 429 + assert e.headers is not None + assert e.headers.get('Retry-After') == '120' + assert e.headers.get('X-RateLimit-Remaining') == '0' + assert e.headers.get('X-RateLimit-Reset') == '1234567890' + + +def test_http_content_type_error_includes_headers(): + """Test that HTTPContentTypeError captures response headers""" + with start_server() as server: + with ChunkedResponse({'Content-Type': 'application/json', 'X-Custom': 'value'}) as stream: + server.for_path('/', stream) + try: + with ConnectStrategy.http(server.uri).create_client(logger()) as client: + client.connect(None) + raise Exception("expected exception, did not get one") + except HTTPContentTypeError as e: + assert e.content_type == "application/json" + assert e.headers is not None + assert e.headers.get('Content-Type') == 'application/json' + assert e.headers.get('X-Custom') == 'value' + + +def test_fault_exposes_headers_from_http_error(): + """Test that Fault.headers exposes headers from HTTP errors""" + with start_server() as server: + server.for_path('/', BasicResponse(503, None, { + 'Retry-After': '60', + 'X-Error-Code': 'SERVICE_UNAVAILABLE' + })) + with SSEClient( + connect=ConnectStrategy.http(server.uri), + error_strategy=ErrorStrategy.always_continue(), + retry_delay_strategy=no_delay() + ) as client: + # Read first item which should be a Fault with the error + fault = next(client.all) + assert isinstance(fault, Fault) + assert isinstance(fault.error, HTTPStatusError) + assert fault.headers is not None + assert fault.headers.get('Retry-After') == '60' + assert fault.headers.get('X-Error-Code') == 'SERVICE_UNAVAILABLE'