Skip to content
Snippets Groups Projects
python-anyio-553.patch 8.66 KiB
diff --git a/src/anyio/streams/tls.py b/src/anyio/streams/tls.py
index c8e19e2..05a70cb 100644
--- a/src/anyio/streams/tls.py
+++ b/src/anyio/streams/tls.py
@@ -37,8 +37,9 @@ class TLSAttribute(TypedAttributeSet):
     peer_certificate_binary: Optional[bytes] = typed_attribute()
     #: ``True`` if this is the server side of the connection
     server_side: bool = typed_attribute()
-    #: ciphers shared between both ends of the TLS connection
-    shared_ciphers: List[Tuple[str, str, int]] = typed_attribute()
+    #: ciphers shared by the client during the TLS handshake (``None`` if this is the
+    #: client side)
+    shared_ciphers: Optional[List[Tuple[str, str, int]]] = typed_attribute()
     #: the :class:`~ssl.SSLObject` used for encryption
     ssl_object: ssl.SSLObject = typed_attribute()
     #: ``True`` if this stream does (and expects) a closing TLS handshake when the stream is being
@@ -105,7 +106,7 @@ class TLSStream(ByteStream):
 
             # Re-enable detection of unexpected EOFs if it was disabled by Python
             if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"):
-                ssl_context.options ^= ssl.OP_IGNORE_UNEXPECTED_EOF  # type: ignore[attr-defined]
+                ssl_context.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF  # type: ignore[attr-defined]
 
         bio_in = ssl.MemoryBIO()
         bio_out = ssl.MemoryBIO()
@@ -228,7 +229,9 @@ class TLSStream(ByteStream):
                 True
             ),
             TLSAttribute.server_side: lambda: self._ssl_object.server_side,
-            TLSAttribute.shared_ciphers: lambda: self._ssl_object.shared_ciphers(),
+            TLSAttribute.shared_ciphers: lambda: self._ssl_object.shared_ciphers()
+            if self._ssl_object.server_side
+            else None,
             TLSAttribute.standard_compatible: lambda: self.standard_compatible,
             TLSAttribute.ssl_object: lambda: self._ssl_object,
             TLSAttribute.tls_version: self._ssl_object.version,
diff --git a/tests/conftest.py b/tests/conftest.py
index 6ee5a53..5d032be 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -54,7 +54,7 @@ def ca() -> CA:
 def server_context(ca: CA) -> SSLContext:
     server_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
     if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"):
-        server_context.options ^= ssl.OP_IGNORE_UNEXPECTED_EOF  # type: ignore[attr-defined]
+        server_context.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF  # type: ignore[attr-defined]
 
     ca.issue_cert("localhost").configure_cert(server_context)
     return server_context
@@ -64,7 +64,7 @@ def server_context(ca: CA) -> SSLContext:
 def client_context(ca: CA) -> SSLContext:
     client_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
     if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"):
-        client_context.options ^= ssl.OP_IGNORE_UNEXPECTED_EOF  # type: ignore[attr-defined]
+        client_context.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF  # type: ignore[attr-defined]
 
     ca.configure_trust(client_context)
     return client_context
diff --git a/tests/streams/test_tls.py b/tests/streams/test_tls.py
index 73ccfad..4d3a263 100644
--- a/tests/streams/test_tls.py
+++ b/tests/streams/test_tls.py
@@ -5,6 +5,7 @@ from threading import Thread
 from typing import ContextManager, NoReturn
 
 import pytest
+from pytest_mock import MockerFixture
 from trustme import CA
 
 from anyio import (
@@ -12,10 +13,13 @@ from anyio import (
     EndOfStream,
     Event,
     connect_tcp,
+    create_memory_object_stream,
     create_task_group,
     create_tcp_listener,
+    to_thread,
 )
 from anyio.abc import AnyByteStream, SocketAttribute, SocketStream
+from anyio.streams.stapled import StapledObjectStream
 from anyio.streams.tls import TLSAttribute, TLSListener, TLSStream
 
 pytestmark = pytest.mark.anyio
@@ -95,7 +99,7 @@ class TestTLSStream:
                     wrapper.extra(TLSAttribute.peer_certificate_binary), bytes
                 )
                 assert wrapper.extra(TLSAttribute.server_side) is False
-                assert isinstance(wrapper.extra(TLSAttribute.shared_ciphers), list)
+                assert wrapper.extra(TLSAttribute.shared_ciphers) is None
                 assert isinstance(wrapper.extra(TLSAttribute.ssl_object), ssl.SSLObject)
                 assert wrapper.extra(TLSAttribute.standard_compatible) is False
                 assert wrapper.extra(TLSAttribute.tls_version).startswith("TLSv")
@@ -368,6 +372,20 @@ class TestTLSStream:
         server_thread.join()
         server_sock.close()
 
+    @pytest.mark.skipif(
+        not hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"),
+        reason="The ssl module does not have the OP_IGNORE_UNEXPECTED_EOF attribute",
+    )
+    async def test_default_context_ignore_unexpected_eof_flag_off(
+        self, mocker: MockerFixture
+    ) -> None:
+        send1, receive1 = create_memory_object_stream()
+        client_stream = StapledObjectStream(send1, receive1)
+        mocker.patch.object(TLSStream, "_call_sslobject_method")
+        tls_stream = await TLSStream.wrap(client_stream)
+        ssl_context = tls_stream.extra(TLSAttribute.ssl_object).context
+        assert not ssl_context.options & ssl.OP_IGNORE_UNEXPECTED_EOF
+
 
 class TestTLSListener:
     async def test_handshake_fail(self, server_context: ssl.SSLContext) -> None:
@@ -399,3 +417,74 @@ class TestTLSListener:
             tg.cancel_scope.cancel()
 
         assert isinstance(exception, BrokenResourceError)
+
+    async def test_extra_attributes(
+        self, client_context: ssl.SSLContext, server_context: ssl.SSLContext, ca: CA
+    ) -> None:
+        def connect_sync(addr: tuple[str, int]) -> None:
+            with socket.create_connection(addr) as plain_sock:
+                plain_sock.settimeout(2)
+                with client_context.wrap_socket(
+                    plain_sock,
+                    server_side=False,
+                    server_hostname="localhost",
+                    suppress_ragged_eofs=False,
+                ) as conn:
+                    conn.recv(1)
+                    conn.unwrap()
+
+        class CustomTLSListener(TLSListener):
+            @staticmethod
+            async def handle_handshake_error(
+                exc: BaseException, stream: AnyByteStream
+            ) -> None:
+                await TLSListener.handle_handshake_error(exc, stream)
+                pytest.fail("TLS handshake failed")
+
+        async def handler(stream: TLSStream) -> None:
+            async with stream:
+                try:
+                    assert stream.extra(TLSAttribute.alpn_protocol) == "h2"
+                    assert isinstance(
+                        stream.extra(TLSAttribute.channel_binding_tls_unique), bytes
+                    )
+                    assert isinstance(stream.extra(TLSAttribute.cipher), tuple)
+                    assert isinstance(stream.extra(TLSAttribute.peer_certificate), dict)
+                    assert isinstance(
+                        stream.extra(TLSAttribute.peer_certificate_binary), bytes
+                    )
+                    assert stream.extra(TLSAttribute.server_side) is True
+                    shared_ciphers = stream.extra(TLSAttribute.shared_ciphers)
+                    assert isinstance(shared_ciphers, list)
+                    assert len(shared_ciphers) > 1
+                    assert isinstance(
+                        stream.extra(TLSAttribute.ssl_object), ssl.SSLObject
+                    )
+                    assert stream.extra(TLSAttribute.standard_compatible) is True
+                    assert stream.extra(TLSAttribute.tls_version).startswith("TLSv")
+                finally:
+                    event.set()
+                    await stream.send(b"\x00")
+
+        # Issue a client certificate and make the server trust it
+        client_cert = ca.issue_cert("dummy-client")
+        client_cert.configure_cert(client_context)
+        ca.configure_trust(server_context)
+        server_context.verify_mode = ssl.CERT_REQUIRED
+
+        event = Event()
+        server_context.set_alpn_protocols(["h2"])
+        client_context.set_alpn_protocols(["h2"])
+        listener = await create_tcp_listener(local_host="127.0.0.1")
+        tls_listener = CustomTLSListener(listener, server_context)
+        async with tls_listener, create_task_group() as tg:
+            assert tls_listener.extra(TLSAttribute.standard_compatible) is True
+            tg.start_soon(tls_listener.serve, handler)
+            client_thread = Thread(
+                target=connect_sync,
+                args=[listener.extra(SocketAttribute.local_address)],
+            )
+            client_thread.start()
+            await event.wait()
+            await to_thread.run_sync(client_thread.join)
+            tg.cancel_scope.cancel()