summaryrefslogtreecommitdiffstats
path: root/venv/lib/python3.9/site-packages/trio/tests
diff options
context:
space:
mode:
authornoptuno <repollo.marrero@gmail.com>2023-04-28 02:29:30 +0200
committernoptuno <repollo.marrero@gmail.com>2023-04-28 02:29:30 +0200
commit355dee533bb34a571b9367820a63cccb668cf866 (patch)
tree838af886b4fec07320aeb10f0d1e74ba79e79b5c /venv/lib/python3.9/site-packages/trio/tests
parentadded pyproject.toml file (diff)
downloadgpt4free-355dee533bb34a571b9367820a63cccb668cf866.tar
gpt4free-355dee533bb34a571b9367820a63cccb668cf866.tar.gz
gpt4free-355dee533bb34a571b9367820a63cccb668cf866.tar.bz2
gpt4free-355dee533bb34a571b9367820a63cccb668cf866.tar.lz
gpt4free-355dee533bb34a571b9367820a63cccb668cf866.tar.xz
gpt4free-355dee533bb34a571b9367820a63cccb668cf866.tar.zst
gpt4free-355dee533bb34a571b9367820a63cccb668cf866.zip
Diffstat (limited to 'venv/lib/python3.9/site-packages/trio/tests')
-rw-r--r--venv/lib/python3.9/site-packages/trio/tests/__init__.py0
-rw-r--r--venv/lib/python3.9/site-packages/trio/tests/conftest.py41
-rw-r--r--venv/lib/python3.9/site-packages/trio/tests/module_with_deprecations.py21
-rw-r--r--venv/lib/python3.9/site-packages/trio/tests/test_abc.py49
-rw-r--r--venv/lib/python3.9/site-packages/trio/tests/test_channel.py407
-rw-r--r--venv/lib/python3.9/site-packages/trio/tests/test_contextvars.py52
-rw-r--r--venv/lib/python3.9/site-packages/trio/tests/test_deprecate.py243
-rw-r--r--venv/lib/python3.9/site-packages/trio/tests/test_dtls.py867
-rw-r--r--venv/lib/python3.9/site-packages/trio/tests/test_exports.py145
-rw-r--r--venv/lib/python3.9/site-packages/trio/tests/test_fakenet.py44
-rw-r--r--venv/lib/python3.9/site-packages/trio/tests/test_file_io.py198
-rw-r--r--venv/lib/python3.9/site-packages/trio/tests/test_highlevel_generic.py94
-rw-r--r--venv/lib/python3.9/site-packages/trio/tests/test_highlevel_open_tcp_listeners.py300
-rw-r--r--venv/lib/python3.9/site-packages/trio/tests/test_highlevel_open_tcp_stream.py574
-rw-r--r--venv/lib/python3.9/site-packages/trio/tests/test_highlevel_open_unix_stream.py67
-rw-r--r--venv/lib/python3.9/site-packages/trio/tests/test_highlevel_serve_listeners.py145
-rw-r--r--venv/lib/python3.9/site-packages/trio/tests/test_highlevel_socket.py267
-rw-r--r--venv/lib/python3.9/site-packages/trio/tests/test_highlevel_ssl_helpers.py113
-rw-r--r--venv/lib/python3.9/site-packages/trio/tests/test_path.py262
-rw-r--r--venv/lib/python3.9/site-packages/trio/tests/test_scheduler_determinism.py40
-rw-r--r--venv/lib/python3.9/site-packages/trio/tests/test_signals.py177
-rw-r--r--venv/lib/python3.9/site-packages/trio/tests/test_socket.py1017
-rw-r--r--venv/lib/python3.9/site-packages/trio/tests/test_ssl.py1303
-rw-r--r--venv/lib/python3.9/site-packages/trio/tests/test_subprocess.py602
-rw-r--r--venv/lib/python3.9/site-packages/trio/tests/test_sync.py567
-rw-r--r--venv/lib/python3.9/site-packages/trio/tests/test_testing.py657
-rw-r--r--venv/lib/python3.9/site-packages/trio/tests/test_threads.py752
-rw-r--r--venv/lib/python3.9/site-packages/trio/tests/test_timeouts.py104
-rw-r--r--venv/lib/python3.9/site-packages/trio/tests/test_unix_pipes.py276
-rw-r--r--venv/lib/python3.9/site-packages/trio/tests/test_util.py193
-rw-r--r--venv/lib/python3.9/site-packages/trio/tests/test_wait_for_object.py220
-rw-r--r--venv/lib/python3.9/site-packages/trio/tests/test_windows_pipes.py110
-rw-r--r--venv/lib/python3.9/site-packages/trio/tests/tools/__init__.py0
-rw-r--r--venv/lib/python3.9/site-packages/trio/tests/tools/test_gen_exports.py72
34 files changed, 9979 insertions, 0 deletions
diff --git a/venv/lib/python3.9/site-packages/trio/tests/__init__.py b/venv/lib/python3.9/site-packages/trio/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/trio/tests/__init__.py
diff --git a/venv/lib/python3.9/site-packages/trio/tests/conftest.py b/venv/lib/python3.9/site-packages/trio/tests/conftest.py
new file mode 100644
index 00000000..772486e1
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/trio/tests/conftest.py
@@ -0,0 +1,41 @@
+# XX this does not belong here -- b/c it's here, these things only apply to
+# the tests in trio/_core/tests, not in trio/tests. For now there's some
+# copy-paste...
+#
+# this stuff should become a proper pytest plugin
+
+import pytest
+import inspect
+
+from ..testing import trio_test, MockClock
+
+RUN_SLOW = True
+
+
+def pytest_addoption(parser):
+ parser.addoption("--run-slow", action="store_true", help="run slow tests")
+
+
+def pytest_configure(config):
+ global RUN_SLOW
+ RUN_SLOW = config.getoption("--run-slow", True)
+
+
+@pytest.fixture
+def mock_clock():
+ return MockClock()
+
+
+@pytest.fixture
+def autojump_clock():
+ return MockClock(autojump_threshold=0)
+
+
+# FIXME: split off into a package (or just make part of Trio's public
+# interface?), with config file to enable? and I guess a mark option too; I
+# guess it's useful with the class- and file-level marking machinery (where
+# the raw @trio_test decorator isn't enough).
+@pytest.hookimpl(tryfirst=True)
+def pytest_pyfunc_call(pyfuncitem):
+ if inspect.iscoroutinefunction(pyfuncitem.obj):
+ pyfuncitem.obj = trio_test(pyfuncitem.obj)
diff --git a/venv/lib/python3.9/site-packages/trio/tests/module_with_deprecations.py b/venv/lib/python3.9/site-packages/trio/tests/module_with_deprecations.py
new file mode 100644
index 00000000..73184d11
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/trio/tests/module_with_deprecations.py
@@ -0,0 +1,21 @@
+regular = "hi"
+
+from .. import _deprecate
+
+_deprecate.enable_attribute_deprecations(__name__)
+
+# Make sure that we don't trigger infinite recursion when accessing module
+# attributes in between calling enable_attribute_deprecations and defining
+# __deprecated_attributes__:
+import sys
+
+this_mod = sys.modules[__name__]
+assert this_mod.regular == "hi"
+assert not hasattr(this_mod, "dep1")
+
+__deprecated_attributes__ = {
+ "dep1": _deprecate.DeprecatedAttribute("value1", "1.1", issue=1),
+ "dep2": _deprecate.DeprecatedAttribute(
+ "value2", "1.2", issue=1, instead="instead-string"
+ ),
+}
diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_abc.py b/venv/lib/python3.9/site-packages/trio/tests/test_abc.py
new file mode 100644
index 00000000..c445c971
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/trio/tests/test_abc.py
@@ -0,0 +1,49 @@
+import pytest
+
+import attr
+
+from ..testing import assert_checkpoints
+from .. import abc as tabc
+
+
+async def test_AsyncResource_defaults():
+ @attr.s
+ class MyAR(tabc.AsyncResource):
+ record = attr.ib(factory=list)
+
+ async def aclose(self):
+ self.record.append("ac")
+
+ async with MyAR() as myar:
+ assert isinstance(myar, MyAR)
+ assert myar.record == []
+
+ assert myar.record == ["ac"]
+
+
+def test_abc_generics():
+ # Pythons below 3.5.2 had a typing.Generic that would throw
+ # errors when instantiating or subclassing a parameterized
+ # version of a class with any __slots__. This is why RunVar
+ # (which has slots) is not generic. This tests that
+ # the generic ABCs are fine, because while they are slotted
+ # they don't actually define any slots.
+
+ class SlottedChannel(tabc.SendChannel[tabc.Stream]):
+ __slots__ = ("x",)
+
+ def send_nowait(self, value):
+ raise RuntimeError
+
+ async def send(self, value):
+ raise RuntimeError # pragma: no cover
+
+ def clone(self):
+ raise RuntimeError # pragma: no cover
+
+ async def aclose(self):
+ pass # pragma: no cover
+
+ channel = SlottedChannel()
+ with pytest.raises(RuntimeError):
+ channel.send_nowait(None)
diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_channel.py b/venv/lib/python3.9/site-packages/trio/tests/test_channel.py
new file mode 100644
index 00000000..fd990fb3
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/trio/tests/test_channel.py
@@ -0,0 +1,407 @@
+import pytest
+
+from ..testing import wait_all_tasks_blocked, assert_checkpoints
+import trio
+from trio import open_memory_channel, EndOfChannel
+
+
+async def test_channel():
+ with pytest.raises(TypeError):
+ open_memory_channel(1.0)
+ with pytest.raises(ValueError):
+ open_memory_channel(-1)
+
+ s, r = open_memory_channel(2)
+ repr(s) # smoke test
+ repr(r) # smoke test
+
+ s.send_nowait(1)
+ with assert_checkpoints():
+ await s.send(2)
+ with pytest.raises(trio.WouldBlock):
+ s.send_nowait(None)
+
+ with assert_checkpoints():
+ assert await r.receive() == 1
+ assert r.receive_nowait() == 2
+ with pytest.raises(trio.WouldBlock):
+ r.receive_nowait()
+
+ s.send_nowait("last")
+ await s.aclose()
+ with pytest.raises(trio.ClosedResourceError):
+ await s.send("too late")
+ with pytest.raises(trio.ClosedResourceError):
+ s.send_nowait("too late")
+ with pytest.raises(trio.ClosedResourceError):
+ s.clone()
+ await s.aclose()
+
+ assert r.receive_nowait() == "last"
+ with pytest.raises(EndOfChannel):
+ await r.receive()
+ await r.aclose()
+ with pytest.raises(trio.ClosedResourceError):
+ await r.receive()
+ with pytest.raises(trio.ClosedResourceError):
+ await r.receive_nowait()
+ await r.aclose()
+
+
+async def test_553(autojump_clock):
+ s, r = open_memory_channel(1)
+ with trio.move_on_after(10) as timeout_scope:
+ await r.receive()
+ assert timeout_scope.cancelled_caught
+ await s.send("Test for PR #553")
+
+
+async def test_channel_multiple_producers():
+ async def producer(send_channel, i):
+ # We close our handle when we're done with it
+ async with send_channel:
+ for j in range(3 * i, 3 * (i + 1)):
+ await send_channel.send(j)
+
+ send_channel, receive_channel = open_memory_channel(0)
+ async with trio.open_nursery() as nursery:
+ # We hand out clones to all the new producers, and then close the
+ # original.
+ async with send_channel:
+ for i in range(10):
+ nursery.start_soon(producer, send_channel.clone(), i)
+
+ got = []
+ async for value in receive_channel:
+ got.append(value)
+
+ got.sort()
+ assert got == list(range(30))
+
+
+async def test_channel_multiple_consumers():
+ successful_receivers = set()
+ received = []
+
+ async def consumer(receive_channel, i):
+ async for value in receive_channel:
+ successful_receivers.add(i)
+ received.append(value)
+
+ async with trio.open_nursery() as nursery:
+ send_channel, receive_channel = trio.open_memory_channel(1)
+ async with send_channel:
+ for i in range(5):
+ nursery.start_soon(consumer, receive_channel, i)
+ await wait_all_tasks_blocked()
+ for i in range(10):
+ await send_channel.send(i)
+
+ assert successful_receivers == set(range(5))
+ assert len(received) == 10
+ assert set(received) == set(range(10))
+
+
+async def test_close_basics():
+ async def send_block(s, expect):
+ with pytest.raises(expect):
+ await s.send(None)
+
+ # closing send -> other send gets ClosedResourceError
+ s, r = open_memory_channel(0)
+ async with trio.open_nursery() as nursery:
+ nursery.start_soon(send_block, s, trio.ClosedResourceError)
+ await wait_all_tasks_blocked()
+ await s.aclose()
+
+ # and it's persistent
+ with pytest.raises(trio.ClosedResourceError):
+ s.send_nowait(None)
+ with pytest.raises(trio.ClosedResourceError):
+ await s.send(None)
+
+ # and receive gets EndOfChannel
+ with pytest.raises(EndOfChannel):
+ r.receive_nowait()
+ with pytest.raises(EndOfChannel):
+ await r.receive()
+
+ # closing receive -> send gets BrokenResourceError
+ s, r = open_memory_channel(0)
+ async with trio.open_nursery() as nursery:
+ nursery.start_soon(send_block, s, trio.BrokenResourceError)
+ await wait_all_tasks_blocked()
+ await r.aclose()
+
+ # and it's persistent
+ with pytest.raises(trio.BrokenResourceError):
+ s.send_nowait(None)
+ with pytest.raises(trio.BrokenResourceError):
+ await s.send(None)
+
+ # closing receive -> other receive gets ClosedResourceError
+ async def receive_block(r):
+ with pytest.raises(trio.ClosedResourceError):
+ await r.receive()
+
+ s, r = open_memory_channel(0)
+ async with trio.open_nursery() as nursery:
+ nursery.start_soon(receive_block, r)
+ await wait_all_tasks_blocked()
+ await r.aclose()
+
+ # and it's persistent
+ with pytest.raises(trio.ClosedResourceError):
+ r.receive_nowait()
+ with pytest.raises(trio.ClosedResourceError):
+ await r.receive()
+
+
+async def test_close_sync():
+ async def send_block(s, expect):
+ with pytest.raises(expect):
+ await s.send(None)
+
+ # closing send -> other send gets ClosedResourceError
+ s, r = open_memory_channel(0)
+ async with trio.open_nursery() as nursery:
+ nursery.start_soon(send_block, s, trio.ClosedResourceError)
+ await wait_all_tasks_blocked()
+ s.close()
+
+ # and it's persistent
+ with pytest.raises(trio.ClosedResourceError):
+ s.send_nowait(None)
+ with pytest.raises(trio.ClosedResourceError):
+ await s.send(None)
+
+ # and receive gets EndOfChannel
+ with pytest.raises(EndOfChannel):
+ r.receive_nowait()
+ with pytest.raises(EndOfChannel):
+ await r.receive()
+
+ # closing receive -> send gets BrokenResourceError
+ s, r = open_memory_channel(0)
+ async with trio.open_nursery() as nursery:
+ nursery.start_soon(send_block, s, trio.BrokenResourceError)
+ await wait_all_tasks_blocked()
+ r.close()
+
+ # and it's persistent
+ with pytest.raises(trio.BrokenResourceError):
+ s.send_nowait(None)
+ with pytest.raises(trio.BrokenResourceError):
+ await s.send(None)
+
+ # closing receive -> other receive gets ClosedResourceError
+ async def receive_block(r):
+ with pytest.raises(trio.ClosedResourceError):
+ await r.receive()
+
+ s, r = open_memory_channel(0)
+ async with trio.open_nursery() as nursery:
+ nursery.start_soon(receive_block, r)
+ await wait_all_tasks_blocked()
+ r.close()
+
+ # and it's persistent
+ with pytest.raises(trio.ClosedResourceError):
+ r.receive_nowait()
+ with pytest.raises(trio.ClosedResourceError):
+ await r.receive()
+
+
+async def test_receive_channel_clone_and_close():
+ s, r = open_memory_channel(10)
+
+ r2 = r.clone()
+ r3 = r.clone()
+
+ s.send_nowait(None)
+ await r.aclose()
+ with r2:
+ pass
+
+ with pytest.raises(trio.ClosedResourceError):
+ r.clone()
+
+ with pytest.raises(trio.ClosedResourceError):
+ r2.clone()
+
+ # Can still send, r3 is still open
+ s.send_nowait(None)
+
+ await r3.aclose()
+
+ # But now the receiver is really closed
+ with pytest.raises(trio.BrokenResourceError):
+ s.send_nowait(None)
+
+
+async def test_close_multiple_send_handles():
+ # With multiple send handles, closing one handle only wakes senders on
+ # that handle, but others can continue just fine
+ s1, r = open_memory_channel(0)
+ s2 = s1.clone()
+
+ async def send_will_close():
+ with pytest.raises(trio.ClosedResourceError):
+ await s1.send("nope")
+
+ async def send_will_succeed():
+ await s2.send("ok")
+
+ async with trio.open_nursery() as nursery:
+ nursery.start_soon(send_will_close)
+ nursery.start_soon(send_will_succeed)
+ await wait_all_tasks_blocked()
+ await s1.aclose()
+ assert await r.receive() == "ok"
+
+
+async def test_close_multiple_receive_handles():
+ # With multiple receive handles, closing one handle only wakes receivers on
+ # that handle, but others can continue just fine
+ s, r1 = open_memory_channel(0)
+ r2 = r1.clone()
+
+ async def receive_will_close():
+ with pytest.raises(trio.ClosedResourceError):
+ await r1.receive()
+
+ async def receive_will_succeed():
+ assert await r2.receive() == "ok"
+
+ async with trio.open_nursery() as nursery:
+ nursery.start_soon(receive_will_close)
+ nursery.start_soon(receive_will_succeed)
+ await wait_all_tasks_blocked()
+ await r1.aclose()
+ await s.send("ok")
+
+
+async def test_inf_capacity():
+ s, r = open_memory_channel(float("inf"))
+
+ # It's accepted, and we can send all day without blocking
+ with s:
+ for i in range(10):
+ s.send_nowait(i)
+
+ got = []
+ async for i in r:
+ got.append(i)
+ assert got == list(range(10))
+
+
+async def test_statistics():
+ s, r = open_memory_channel(2)
+
+ assert s.statistics() == r.statistics()
+ stats = s.statistics()
+ assert stats.current_buffer_used == 0
+ assert stats.max_buffer_size == 2
+ assert stats.open_send_channels == 1
+ assert stats.open_receive_channels == 1
+ assert stats.tasks_waiting_send == 0
+ assert stats.tasks_waiting_receive == 0
+
+ s.send_nowait(None)
+ assert s.statistics().current_buffer_used == 1
+
+ s2 = s.clone()
+ assert s.statistics().open_send_channels == 2
+ await s.aclose()
+ assert s2.statistics().open_send_channels == 1
+
+ r2 = r.clone()
+ assert s2.statistics().open_receive_channels == 2
+ await r2.aclose()
+ assert s2.statistics().open_receive_channels == 1
+
+ async with trio.open_nursery() as nursery:
+ s2.send_nowait(None) # fill up the buffer
+ assert s.statistics().current_buffer_used == 2
+ nursery.start_soon(s2.send, None)
+ nursery.start_soon(s2.send, None)
+ await wait_all_tasks_blocked()
+ assert s.statistics().tasks_waiting_send == 2
+ nursery.cancel_scope.cancel()
+ assert s.statistics().tasks_waiting_send == 0
+
+ # empty out the buffer again
+ try:
+ while True:
+ r.receive_nowait()
+ except trio.WouldBlock:
+ pass
+
+ async with trio.open_nursery() as nursery:
+ nursery.start_soon(r.receive)
+ await wait_all_tasks_blocked()
+ assert s.statistics().tasks_waiting_receive == 1
+ nursery.cancel_scope.cancel()
+ assert s.statistics().tasks_waiting_receive == 0
+
+
+async def test_channel_fairness():
+
+ # We can remove an item we just sent, and send an item back in after, if
+ # no-one else is waiting.
+ s, r = open_memory_channel(1)
+ s.send_nowait(1)
+ assert r.receive_nowait() == 1
+ s.send_nowait(2)
+ assert r.receive_nowait() == 2
+
+ # But if someone else is waiting to receive, then they "own" the item we
+ # send, so we can't receive it (even though we run first):
+
+ result = None
+
+ async def do_receive(r):
+ nonlocal result
+ result = await r.receive()
+
+ async with trio.open_nursery() as nursery:
+ nursery.start_soon(do_receive, r)
+ await wait_all_tasks_blocked()
+ s.send_nowait(2)
+ with pytest.raises(trio.WouldBlock):
+ r.receive_nowait()
+ assert result == 2
+
+ # And the analogous situation for send: if we free up a space, we can't
+ # immediately send something in it if someone is already waiting to do
+ # that
+ s, r = open_memory_channel(1)
+ s.send_nowait(1)
+ with pytest.raises(trio.WouldBlock):
+ s.send_nowait(None)
+ async with trio.open_nursery() as nursery:
+ nursery.start_soon(s.send, 2)
+ await wait_all_tasks_blocked()
+ assert r.receive_nowait() == 1
+ with pytest.raises(trio.WouldBlock):
+ s.send_nowait(3)
+ assert (await r.receive()) == 2
+
+
+async def test_unbuffered():
+ s, r = open_memory_channel(0)
+ with pytest.raises(trio.WouldBlock):
+ r.receive_nowait()
+ with pytest.raises(trio.WouldBlock):
+ s.send_nowait(1)
+
+ async def do_send(s, v):
+ with assert_checkpoints():
+ await s.send(v)
+
+ async with trio.open_nursery() as nursery:
+ nursery.start_soon(do_send, s, 1)
+ with assert_checkpoints():
+ assert await r.receive() == 1
+ with pytest.raises(trio.WouldBlock):
+ r.receive_nowait()
diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_contextvars.py b/venv/lib/python3.9/site-packages/trio/tests/test_contextvars.py
new file mode 100644
index 00000000..63853f51
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/trio/tests/test_contextvars.py
@@ -0,0 +1,52 @@
+import contextvars
+
+from .. import _core
+
+trio_testing_contextvar = contextvars.ContextVar("trio_testing_contextvar")
+
+
+async def test_contextvars_default():
+ trio_testing_contextvar.set("main")
+ record = []
+
+ async def child():
+ value = trio_testing_contextvar.get()
+ record.append(value)
+
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(child)
+ assert record == ["main"]
+
+
+async def test_contextvars_set():
+ trio_testing_contextvar.set("main")
+ record = []
+
+ async def child():
+ trio_testing_contextvar.set("child")
+ value = trio_testing_contextvar.get()
+ record.append(value)
+
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(child)
+ value = trio_testing_contextvar.get()
+ assert record == ["child"]
+ assert value == "main"
+
+
+async def test_contextvars_copy():
+ trio_testing_contextvar.set("main")
+ context = contextvars.copy_context()
+ trio_testing_contextvar.set("second_main")
+ record = []
+
+ async def child():
+ value = trio_testing_contextvar.get()
+ record.append(value)
+
+ async with _core.open_nursery() as nursery:
+ context.run(nursery.start_soon, child)
+ nursery.start_soon(child)
+ value = trio_testing_contextvar.get()
+ assert set(record) == {"main", "second_main"}
+ assert value == "second_main"
diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_deprecate.py b/venv/lib/python3.9/site-packages/trio/tests/test_deprecate.py
new file mode 100644
index 00000000..e5e1da8c
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/trio/tests/test_deprecate.py
@@ -0,0 +1,243 @@
+import pytest
+
+import inspect
+import warnings
+
+from .._deprecate import (
+ TrioDeprecationWarning,
+ warn_deprecated,
+ deprecated,
+ deprecated_alias,
+)
+
+from . import module_with_deprecations
+
+
+@pytest.fixture
+def recwarn_always(recwarn):
+ warnings.simplefilter("always")
+ # ResourceWarnings about unclosed sockets can occur nondeterministically
+ # (during GC) which throws off the tests in this file
+ warnings.simplefilter("ignore", ResourceWarning)
+ return recwarn
+
+
+def _here():
+ info = inspect.getframeinfo(inspect.currentframe().f_back)
+ return (info.filename, info.lineno)
+
+
+def test_warn_deprecated(recwarn_always):
+ def deprecated_thing():
+ warn_deprecated("ice", "1.2", issue=1, instead="water")
+
+ deprecated_thing()
+ filename, lineno = _here()
+ assert len(recwarn_always) == 1
+ got = recwarn_always.pop(TrioDeprecationWarning)
+ assert "ice is deprecated" in got.message.args[0]
+ assert "Trio 1.2" in got.message.args[0]
+ assert "water instead" in got.message.args[0]
+ assert "/issues/1" in got.message.args[0]
+ assert got.filename == filename
+ assert got.lineno == lineno - 1
+
+
+def test_warn_deprecated_no_instead_or_issue(recwarn_always):
+ # Explicitly no instead or issue
+ warn_deprecated("water", "1.3", issue=None, instead=None)
+ assert len(recwarn_always) == 1
+ got = recwarn_always.pop(TrioDeprecationWarning)
+ assert "water is deprecated" in got.message.args[0]
+ assert "no replacement" in got.message.args[0]
+ assert "Trio 1.3" in got.message.args[0]
+
+
+def test_warn_deprecated_stacklevel(recwarn_always):
+ def nested1():
+ nested2()
+
+ def nested2():
+ warn_deprecated("x", "1.3", issue=7, instead="y", stacklevel=3)
+
+ filename, lineno = _here()
+ nested1()
+ got = recwarn_always.pop(TrioDeprecationWarning)
+ assert got.filename == filename
+ assert got.lineno == lineno + 1
+
+
+def old(): # pragma: no cover
+ pass
+
+
+def new(): # pragma: no cover
+ pass
+
+
+def test_warn_deprecated_formatting(recwarn_always):
+ warn_deprecated(old, "1.0", issue=1, instead=new)
+ got = recwarn_always.pop(TrioDeprecationWarning)
+ assert "test_deprecate.old is deprecated" in got.message.args[0]
+ assert "test_deprecate.new instead" in got.message.args[0]
+
+
+@deprecated("1.5", issue=123, instead=new)
+def deprecated_old():
+ return 3
+
+
+def test_deprecated_decorator(recwarn_always):
+ assert deprecated_old() == 3
+ got = recwarn_always.pop(TrioDeprecationWarning)
+ assert "test_deprecate.deprecated_old is deprecated" in got.message.args[0]
+ assert "1.5" in got.message.args[0]
+ assert "test_deprecate.new" in got.message.args[0]
+ assert "issues/123" in got.message.args[0]
+
+
+class Foo:
+ @deprecated("1.0", issue=123, instead="crying")
+ def method(self):
+ return 7
+
+
+def test_deprecated_decorator_method(recwarn_always):
+ f = Foo()
+ assert f.method() == 7
+ got = recwarn_always.pop(TrioDeprecationWarning)
+ assert "test_deprecate.Foo.method is deprecated" in got.message.args[0]
+
+
+@deprecated("1.2", thing="the thing", issue=None, instead=None)
+def deprecated_with_thing():
+ return 72
+
+
+def test_deprecated_decorator_with_explicit_thing(recwarn_always):
+ assert deprecated_with_thing() == 72
+ got = recwarn_always.pop(TrioDeprecationWarning)
+ assert "the thing is deprecated" in got.message.args[0]
+
+
+def new_hotness():
+ return "new hotness"
+
+
+old_hotness = deprecated_alias("old_hotness", new_hotness, "1.23", issue=1)
+
+
+def test_deprecated_alias(recwarn_always):
+ assert old_hotness() == "new hotness"
+ got = recwarn_always.pop(TrioDeprecationWarning)
+ assert "test_deprecate.old_hotness is deprecated" in got.message.args[0]
+ assert "1.23" in got.message.args[0]
+ assert "test_deprecate.new_hotness instead" in got.message.args[0]
+ assert "issues/1" in got.message.args[0]
+
+ assert ".. deprecated:: 1.23" in old_hotness.__doc__
+ assert "test_deprecate.new_hotness instead" in old_hotness.__doc__
+ assert "issues/1>`__" in old_hotness.__doc__
+
+
+class Alias:
+ def new_hotness_method(self):
+ return "new hotness method"
+
+ old_hotness_method = deprecated_alias(
+ "Alias.old_hotness_method", new_hotness_method, "3.21", issue=1
+ )
+
+
+def test_deprecated_alias_method(recwarn_always):
+ obj = Alias()
+ assert obj.old_hotness_method() == "new hotness method"
+ got = recwarn_always.pop(TrioDeprecationWarning)
+ msg = got.message.args[0]
+ assert "test_deprecate.Alias.old_hotness_method is deprecated" in msg
+ assert "test_deprecate.Alias.new_hotness_method instead" in msg
+
+
+@deprecated("2.1", issue=1, instead="hi")
+def docstring_test1(): # pragma: no cover
+ """Hello!"""
+
+
+@deprecated("2.1", issue=None, instead="hi")
+def docstring_test2(): # pragma: no cover
+ """Hello!"""
+
+
+@deprecated("2.1", issue=1, instead=None)
+def docstring_test3(): # pragma: no cover
+ """Hello!"""
+
+
+@deprecated("2.1", issue=None, instead=None)
+def docstring_test4(): # pragma: no cover
+ """Hello!"""
+
+
+def test_deprecated_docstring_munging():
+ assert (
+ docstring_test1.__doc__
+ == """Hello!
+
+.. deprecated:: 2.1
+ Use hi instead.
+ For details, see `issue #1 <https://github.com/python-trio/trio/issues/1>`__.
+
+"""
+ )
+
+ assert (
+ docstring_test2.__doc__
+ == """Hello!
+
+.. deprecated:: 2.1
+ Use hi instead.
+
+"""
+ )
+
+ assert (
+ docstring_test3.__doc__
+ == """Hello!
+
+.. deprecated:: 2.1
+ For details, see `issue #1 <https://github.com/python-trio/trio/issues/1>`__.
+
+"""
+ )
+
+ assert (
+ docstring_test4.__doc__
+ == """Hello!
+
+.. deprecated:: 2.1
+
+"""
+ )
+
+
+def test_module_with_deprecations(recwarn_always):
+ assert module_with_deprecations.regular == "hi"
+ assert len(recwarn_always) == 0
+
+ filename, lineno = _here()
+ assert module_with_deprecations.dep1 == "value1"
+ got = recwarn_always.pop(TrioDeprecationWarning)
+ assert got.filename == filename
+ assert got.lineno == lineno + 1
+
+ assert "module_with_deprecations.dep1" in got.message.args[0]
+ assert "Trio 1.1" in got.message.args[0]
+ assert "/issues/1" in got.message.args[0]
+ assert "value1 instead" in got.message.args[0]
+
+ assert module_with_deprecations.dep2 == "value2"
+ got = recwarn_always.pop(TrioDeprecationWarning)
+ assert "instead-string instead" in got.message.args[0]
+
+ with pytest.raises(AttributeError):
+ module_with_deprecations.asdf
diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_dtls.py b/venv/lib/python3.9/site-packages/trio/tests/test_dtls.py
new file mode 100644
index 00000000..8968d9a6
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/trio/tests/test_dtls.py
@@ -0,0 +1,867 @@
+import pytest
+import trio
+import trio.testing
+from trio import DTLSEndpoint
+import random
+import attr
+from async_generator import asynccontextmanager
+from itertools import count
+
+import trustme
+from OpenSSL import SSL
+
+from trio.testing._fake_net import FakeNet
+from .._core.tests.tutil import slow, binds_ipv6, gc_collect_harder
+
+ca = trustme.CA()
+server_cert = ca.issue_cert("example.com")
+
+server_ctx = SSL.Context(SSL.DTLS_METHOD)
+server_cert.configure_cert(server_ctx)
+
+client_ctx = SSL.Context(SSL.DTLS_METHOD)
+ca.configure_trust(client_ctx)
+
+
+parametrize_ipv6 = pytest.mark.parametrize(
+ "ipv6", [False, pytest.param(True, marks=binds_ipv6)], ids=["ipv4", "ipv6"]
+)
+
+
+def endpoint(**kwargs):
+ ipv6 = kwargs.pop("ipv6", False)
+ if ipv6:
+ family = trio.socket.AF_INET6
+ else:
+ family = trio.socket.AF_INET
+ sock = trio.socket.socket(type=trio.socket.SOCK_DGRAM, family=family)
+ return DTLSEndpoint(sock, **kwargs)
+
+
+@asynccontextmanager
+async def dtls_echo_server(*, autocancel=True, mtu=None, ipv6=False):
+ with endpoint(ipv6=ipv6) as server:
+ if ipv6:
+ localhost = "::1"
+ else:
+ localhost = "127.0.0.1"
+ await server.socket.bind((localhost, 0))
+ async with trio.open_nursery() as nursery:
+
+ async def echo_handler(dtls_channel):
+ print(
+ f"echo handler started: "
+ f"server {dtls_channel.endpoint.socket.getsockname()} "
+ f"client {dtls_channel.peer_address}"
+ )
+ if mtu is not None:
+ dtls_channel.set_ciphertext_mtu(mtu)
+ try:
+ print("server starting do_handshake")
+ await dtls_channel.do_handshake()
+ print("server finished do_handshake")
+ async for packet in dtls_channel:
+ print(f"echoing {packet} -> {dtls_channel.peer_address}")
+ await dtls_channel.send(packet)
+ except trio.BrokenResourceError: # pragma: no cover
+ print("echo handler channel broken")
+
+ await nursery.start(server.serve, server_ctx, echo_handler)
+
+ yield server, server.socket.getsockname()
+
+ if autocancel:
+ nursery.cancel_scope.cancel()
+
+
+@parametrize_ipv6
+async def test_smoke(ipv6):
+ async with dtls_echo_server(ipv6=ipv6) as (server_endpoint, address):
+ with endpoint(ipv6=ipv6) as client_endpoint:
+ client_channel = client_endpoint.connect(address, client_ctx)
+ with pytest.raises(trio.NeedHandshakeError):
+ client_channel.get_cleartext_mtu()
+
+ await client_channel.do_handshake()
+ await client_channel.send(b"hello")
+ assert await client_channel.receive() == b"hello"
+ await client_channel.send(b"goodbye")
+ assert await client_channel.receive() == b"goodbye"
+
+ with pytest.raises(ValueError):
+ await client_channel.send(b"")
+
+ client_channel.set_ciphertext_mtu(1234)
+ cleartext_mtu_1234 = client_channel.get_cleartext_mtu()
+ client_channel.set_ciphertext_mtu(4321)
+ assert client_channel.get_cleartext_mtu() > cleartext_mtu_1234
+ client_channel.set_ciphertext_mtu(1234)
+ assert client_channel.get_cleartext_mtu() == cleartext_mtu_1234
+
+
+@slow
+async def test_handshake_over_terrible_network(autojump_clock):
+ HANDSHAKES = 1000
+ r = random.Random(0)
+ fn = FakeNet()
+ fn.enable()
+
+ async with dtls_echo_server() as (_, address):
+ async with trio.open_nursery() as nursery:
+
+ async def route_packet(packet):
+ while True:
+ op = r.choices(
+ ["deliver", "drop", "dupe", "delay"],
+ weights=[0.7, 0.1, 0.1, 0.1],
+ )[0]
+ print(f"{packet.source} -> {packet.destination}: {op}")
+ if op == "drop":
+ return
+ elif op == "dupe":
+ fn.send_packet(packet)
+ elif op == "delay":
+ await trio.sleep(r.random() * 3)
+ # I wanted to test random packet corruption too, but it turns out
+ # openssl has a bug in the following scenario:
+ #
+ # - client sends ClientHello
+ # - server sends HelloVerifyRequest with cookie -- but cookie is
+ # invalid b/c either the ClientHello or HelloVerifyRequest was
+ # corrupted
+ # - client re-sends ClientHello with invalid cookie
+ # - server replies with new HelloVerifyRequest and correct cookie
+ #
+ # At this point, the client *should* switch to the new, valid
+ # cookie. But OpenSSL doesn't; it stubbornly insists on re-sending
+ # the original, invalid cookie over and over. In theory we could
+ # work around this by detecting cookie changes and starting over
+ # with a whole new SSL object, but (a) it doesn't seem worth it, (b)
+ # when I tried then I ran into another issue where OpenSSL got stuck
+ # in an infinite loop sending alerts over and over, which I didn't
+ # dig into because see (a).
+ #
+ # elif op == "distort":
+ # payload = bytearray(packet.payload)
+ # payload[r.randrange(len(payload))] ^= 1 << r.randrange(8)
+ # packet = attr.evolve(packet, payload=payload)
+ else:
+ assert op == "deliver"
+ print(
+ f"{packet.source} -> {packet.destination}: delivered {packet.payload.hex()}"
+ )
+ fn.deliver_packet(packet)
+ break
+
+ def route_packet_wrapper(packet):
+ try:
+ nursery.start_soon(route_packet, packet)
+ except RuntimeError: # pragma: no cover
+ # We're exiting the nursery, so any remaining packets can just get
+ # dropped
+ pass
+
+ fn.route_packet = route_packet_wrapper
+
+ for i in range(HANDSHAKES):
+ print("#" * 80)
+ print("#" * 80)
+ print("#" * 80)
+ with endpoint() as client_endpoint:
+ client = client_endpoint.connect(address, client_ctx)
+ print("client starting do_handshake")
+ await client.do_handshake()
+ print("client finished do_handshake")
+ msg = str(i).encode()
+ # Make multiple attempts to send data, because the network might
+ # drop it
+ while True:
+ with trio.move_on_after(10) as cscope:
+ await client.send(msg)
+ assert await client.receive() == msg
+ if not cscope.cancelled_caught:
+ break
+
+
+async def test_implicit_handshake():
+ async with dtls_echo_server() as (_, address):
+ with endpoint() as client_endpoint:
+ client = client_endpoint.connect(address, client_ctx)
+
+ # Implicit handshake
+ await client.send(b"xyz")
+ assert await client.receive() == b"xyz"
+
+
+async def test_full_duplex():
+ # Tests simultaneous send/receive, and also multiple methods implicitly invoking
+ # do_handshake simultaneously.
+ with endpoint() as server_endpoint, endpoint() as client_endpoint:
+ await server_endpoint.socket.bind(("127.0.0.1", 0))
+ async with trio.open_nursery() as server_nursery:
+
+ async def handler(channel):
+ async with trio.open_nursery() as nursery:
+ nursery.start_soon(channel.send, b"from server")
+ nursery.start_soon(channel.receive)
+
+ await server_nursery.start(server_endpoint.serve, server_ctx, handler)
+
+ client = client_endpoint.connect(
+ server_endpoint.socket.getsockname(), client_ctx
+ )
+ async with trio.open_nursery() as nursery:
+ nursery.start_soon(client.send, b"from client")
+ nursery.start_soon(client.receive)
+
+ server_nursery.cancel_scope.cancel()
+
+
+async def test_channel_closing():
+ async with dtls_echo_server() as (_, address):
+ with endpoint() as client_endpoint:
+ client = client_endpoint.connect(address, client_ctx)
+ await client.do_handshake()
+ client.close()
+
+ with pytest.raises(trio.ClosedResourceError):
+ await client.send(b"abc")
+ with pytest.raises(trio.ClosedResourceError):
+ await client.receive()
+
+ # close is idempotent
+ client.close()
+ # can also aclose
+ await client.aclose()
+
+
+async def test_serve_exits_cleanly_on_close():
+ async with dtls_echo_server(autocancel=False) as (server_endpoint, address):
+ server_endpoint.close()
+ # Testing that the nursery exits even without being cancelled
+ # close is idempotent
+ server_endpoint.close()
+
+
+async def test_client_multiplex():
+ async with dtls_echo_server() as (_, address1), dtls_echo_server() as (_, address2):
+ with endpoint() as client_endpoint:
+ client1 = client_endpoint.connect(address1, client_ctx)
+ client2 = client_endpoint.connect(address2, client_ctx)
+
+ await client1.send(b"abc")
+ await client2.send(b"xyz")
+ assert await client2.receive() == b"xyz"
+ assert await client1.receive() == b"abc"
+
+ client_endpoint.close()
+
+ with pytest.raises(trio.ClosedResourceError):
+ await client1.send("xxx")
+ with pytest.raises(trio.ClosedResourceError):
+ await client2.receive()
+ with pytest.raises(trio.ClosedResourceError):
+ client_endpoint.connect(address1, client_ctx)
+
+ async with trio.open_nursery() as nursery:
+ with pytest.raises(trio.ClosedResourceError):
+
+ async def null_handler(_): # pragma: no cover
+ pass
+
+ await nursery.start(client_endpoint.serve, server_ctx, null_handler)
+
+
+async def test_dtls_over_dgram_only():
+ with trio.socket.socket() as s:
+ with pytest.raises(ValueError):
+ DTLSEndpoint(s)
+
+
+async def test_double_serve():
+ async def null_handler(_): # pragma: no cover
+ pass
+
+ with endpoint() as server_endpoint:
+ await server_endpoint.socket.bind(("127.0.0.1", 0))
+ async with trio.open_nursery() as nursery:
+ await nursery.start(server_endpoint.serve, server_ctx, null_handler)
+ with pytest.raises(trio.BusyResourceError):
+ await nursery.start(server_endpoint.serve, server_ctx, null_handler)
+
+ nursery.cancel_scope.cancel()
+
+ async with trio.open_nursery() as nursery:
+ await nursery.start(server_endpoint.serve, server_ctx, null_handler)
+ nursery.cancel_scope.cancel()
+
+
+async def test_connect_to_non_server(autojump_clock):
+ fn = FakeNet()
+ fn.enable()
+ with endpoint() as client1, endpoint() as client2:
+ await client1.socket.bind(("127.0.0.1", 0))
+ # This should just time out
+ with trio.move_on_after(100) as cscope:
+ channel = client2.connect(client1.socket.getsockname(), client_ctx)
+ await channel.do_handshake()
+ assert cscope.cancelled_caught
+
+
+async def test_incoming_buffer_overflow(autojump_clock):
+ fn = FakeNet()
+ fn.enable()
+ for buffer_size in [10, 20]:
+ async with dtls_echo_server() as (_, address):
+ with endpoint(incoming_packets_buffer=buffer_size) as client_endpoint:
+ assert client_endpoint.incoming_packets_buffer == buffer_size
+ client = client_endpoint.connect(address, client_ctx)
+ for i in range(buffer_size + 15):
+ await client.send(str(i).encode())
+ await trio.sleep(1)
+ stats = client.statistics()
+ assert stats.incoming_packets_dropped_in_trio == 15
+ for i in range(buffer_size):
+ assert await client.receive() == str(i).encode()
+ await client.send(b"buffer clear now")
+ assert await client.receive() == b"buffer clear now"
+
+
+async def test_server_socket_doesnt_crash_on_garbage(autojump_clock):
+ fn = FakeNet()
+ fn.enable()
+
+ from trio._dtls import (
+ Record,
+ encode_record,
+ HandshakeFragment,
+ encode_handshake_fragment,
+ ContentType,
+ HandshakeType,
+ ProtocolVersion,
+ )
+
+ client_hello = encode_record(
+ Record(
+ content_type=ContentType.handshake,
+ version=ProtocolVersion.DTLS10,
+ epoch_seqno=0,
+ payload=encode_handshake_fragment(
+ HandshakeFragment(
+ msg_type=HandshakeType.client_hello,
+ msg_len=10,
+ msg_seq=0,
+ frag_offset=0,
+ frag_len=10,
+ frag=bytes(10),
+ )
+ ),
+ )
+ )
+
+ client_hello_extended = client_hello + b"\x00"
+ client_hello_short = client_hello[:-1]
+ # cuts off in middle of handshake message header
+ client_hello_really_short = client_hello[:14]
+ client_hello_corrupt_record_len = bytearray(client_hello)
+ client_hello_corrupt_record_len[11] = 0xFF
+
+ client_hello_fragmented = encode_record(
+ Record(
+ content_type=ContentType.handshake,
+ version=ProtocolVersion.DTLS10,
+ epoch_seqno=0,
+ payload=encode_handshake_fragment(
+ HandshakeFragment(
+ msg_type=HandshakeType.client_hello,
+ msg_len=20,
+ msg_seq=0,
+ frag_offset=0,
+ frag_len=10,
+ frag=bytes(10),
+ )
+ ),
+ )
+ )
+
+ client_hello_trailing_data_in_record = encode_record(
+ Record(
+ content_type=ContentType.handshake,
+ version=ProtocolVersion.DTLS10,
+ epoch_seqno=0,
+ payload=encode_handshake_fragment(
+ HandshakeFragment(
+ msg_type=HandshakeType.client_hello,
+ msg_len=20,
+ msg_seq=0,
+ frag_offset=0,
+ frag_len=10,
+ frag=bytes(10),
+ )
+ )
+ + b"\x00",
+ )
+ )
+
+ handshake_empty = encode_record(
+ Record(
+ content_type=ContentType.handshake,
+ version=ProtocolVersion.DTLS10,
+ epoch_seqno=0,
+ payload=b"",
+ )
+ )
+
+ client_hello_truncated_in_cookie = encode_record(
+ Record(
+ content_type=ContentType.handshake,
+ version=ProtocolVersion.DTLS10,
+ epoch_seqno=0,
+ payload=bytes(2 + 32 + 1) + b"\xff",
+ )
+ )
+
+ async with dtls_echo_server() as (_, address):
+ with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as sock:
+ for bad_packet in [
+ b"",
+ b"xyz",
+ client_hello_extended,
+ client_hello_short,
+ client_hello_really_short,
+ client_hello_corrupt_record_len,
+ client_hello_fragmented,
+ client_hello_trailing_data_in_record,
+ handshake_empty,
+ client_hello_truncated_in_cookie,
+ ]:
+ await sock.sendto(bad_packet, address)
+ await trio.sleep(1)
+
+
+async def test_invalid_cookie_rejected(autojump_clock):
+ fn = FakeNet()
+ fn.enable()
+
+ from trio._dtls import decode_client_hello_untrusted, BadPacket
+
+ with trio.CancelScope() as cscope:
+
+ # the first 11 bytes of ClientHello aren't protected by the cookie, so only test
+ # corrupting bytes after that.
+ offset_to_corrupt = count(11)
+
+ def route_packet(packet):
+ try:
+ _, cookie, _ = decode_client_hello_untrusted(packet.payload)
+ except BadPacket:
+ pass
+ else:
+ if len(cookie) != 0:
+ # this is a challenge response packet
+ # let's corrupt the next offset so the handshake should fail
+ payload = bytearray(packet.payload)
+ offset = next(offset_to_corrupt)
+ if offset >= len(payload):
+ # We've tried all offsets. Clamp offset to the end of the
+ # payload, and terminate the test.
+ offset = len(payload) - 1
+ cscope.cancel()
+ payload[offset] ^= 0x01
+ packet = attr.evolve(packet, payload=payload)
+
+ fn.deliver_packet(packet)
+
+ fn.route_packet = route_packet
+
+ async with dtls_echo_server() as (_, address):
+ while True:
+ with endpoint() as client:
+ channel = client.connect(address, client_ctx)
+ await channel.do_handshake()
+ assert cscope.cancelled_caught
+
+
+async def test_client_cancels_handshake_and_starts_new_one(autojump_clock):
+ # if a client disappears during the handshake, and then starts a new handshake from
+ # scratch, then the first handler's channel should fail, and a new handler get
+ # started
+ fn = FakeNet()
+ fn.enable()
+
+ with endpoint() as server, endpoint() as client:
+ await server.socket.bind(("127.0.0.1", 0))
+ async with trio.open_nursery() as nursery:
+ first_time = True
+
+ async def handler(channel):
+ nonlocal first_time
+ if first_time:
+ first_time = False
+ print("handler: first time, cancelling connect")
+ connect_cscope.cancel()
+ await trio.sleep(0.5)
+ print("handler: handshake should fail now")
+ with pytest.raises(trio.BrokenResourceError):
+ await channel.do_handshake()
+ else:
+ print("handler: not first time, sending hello")
+ await channel.send(b"hello")
+
+ await nursery.start(server.serve, server_ctx, handler)
+
+ print("client: starting first connect")
+ with trio.CancelScope() as connect_cscope:
+ channel = client.connect(server.socket.getsockname(), client_ctx)
+ await channel.do_handshake()
+ assert connect_cscope.cancelled_caught
+
+ print("client: starting second connect")
+ channel = client.connect(server.socket.getsockname(), client_ctx)
+ assert await channel.receive() == b"hello"
+
+ # Give handlers a chance to finish
+ await trio.sleep(10)
+ nursery.cancel_scope.cancel()
+
+
+async def test_swap_client_server():
+ with endpoint() as a, endpoint() as b:
+ await a.socket.bind(("127.0.0.1", 0))
+ await b.socket.bind(("127.0.0.1", 0))
+
+ async def echo_handler(channel):
+ async for packet in channel:
+ await channel.send(packet)
+
+ async def crashing_echo_handler(channel):
+ with pytest.raises(trio.BrokenResourceError):
+ await echo_handler(channel)
+
+ async with trio.open_nursery() as nursery:
+ await nursery.start(a.serve, server_ctx, crashing_echo_handler)
+ await nursery.start(b.serve, server_ctx, echo_handler)
+
+ b_to_a = b.connect(a.socket.getsockname(), client_ctx)
+ await b_to_a.send(b"b as client")
+ assert await b_to_a.receive() == b"b as client"
+
+ a_to_b = a.connect(b.socket.getsockname(), client_ctx)
+ await a_to_b.do_handshake()
+ with pytest.raises(trio.BrokenResourceError):
+ await b_to_a.send(b"association broken")
+ await a_to_b.send(b"a as client")
+ assert await a_to_b.receive() == b"a as client"
+
+ nursery.cancel_scope.cancel()
+
+
+@slow
+async def test_openssl_retransmit_doesnt_break_stuff():
+ # can't use autojump_clock here, because the point of the test is to wait for
+ # openssl's built-in retransmit timer to expire, which is hard-coded to use
+ # wall-clock time.
+ fn = FakeNet()
+ fn.enable()
+
+ blackholed = True
+
+ def route_packet(packet):
+ if blackholed:
+ print("dropped packet", packet)
+ return
+ print("delivered packet", packet)
+ # packets.append(
+ # scapy.all.IP(
+ # src=packet.source.ip.compressed, dst=packet.destination.ip.compressed
+ # )
+ # / scapy.all.UDP(sport=packet.source.port, dport=packet.destination.port)
+ # / packet.payload
+ # )
+ fn.deliver_packet(packet)
+
+ fn.route_packet = route_packet
+
+ async with dtls_echo_server() as (server_endpoint, address):
+ with endpoint() as client_endpoint:
+ async with trio.open_nursery() as nursery:
+
+ async def connecter():
+ client = client_endpoint.connect(address, client_ctx)
+ await client.do_handshake(initial_retransmit_timeout=1.5)
+ await client.send(b"hi")
+ assert await client.receive() == b"hi"
+
+ nursery.start_soon(connecter)
+
+ # openssl's default timeout is 1 second, so this ensures that it thinks
+ # the timeout has expired
+ await trio.sleep(1.1)
+ # disable blackholing and send a garbage packet to wake up openssl so it
+ # notices the timeout has expired
+ blackholed = False
+ await server_endpoint.socket.sendto(
+ b"xxx", client_endpoint.socket.getsockname()
+ )
+ # now the client task should finish connecting and exit cleanly
+
+ # scapy.all.wrpcap("/tmp/trace.pcap", packets)
+
+
+async def test_initial_retransmit_timeout_configuration(autojump_clock):
+ fn = FakeNet()
+ fn.enable()
+
+ blackholed = True
+
+ def route_packet(packet):
+ nonlocal blackholed
+ if blackholed:
+ blackholed = False
+ else:
+ fn.deliver_packet(packet)
+
+ fn.route_packet = route_packet
+
+ async with dtls_echo_server() as (_, address):
+ for t in [1, 2, 4]:
+ with endpoint() as client:
+ before = trio.current_time()
+ blackholed = True
+ channel = client.connect(address, client_ctx)
+ await channel.do_handshake(initial_retransmit_timeout=t)
+ after = trio.current_time()
+ assert after - before == t
+
+
+async def test_explicit_tiny_mtu_is_respected():
+ # ClientHello is ~240 bytes, and it can't be fragmented, so our mtu has to
+ # be larger than that. (300 is still smaller than any real network though.)
+ MTU = 300
+
+ fn = FakeNet()
+ fn.enable()
+
+ def route_packet(packet):
+ print(f"delivering {packet}")
+ print(f"payload size: {len(packet.payload)}")
+ assert len(packet.payload) <= MTU
+ fn.deliver_packet(packet)
+
+ fn.route_packet = route_packet
+
+ async with dtls_echo_server(mtu=MTU) as (server, address):
+ with endpoint() as client:
+ channel = client.connect(address, client_ctx)
+ channel.set_ciphertext_mtu(MTU)
+ await channel.do_handshake()
+ await channel.send(b"hi")
+ assert await channel.receive() == b"hi"
+
+
+@parametrize_ipv6
+async def test_handshake_handles_minimum_network_mtu(ipv6, autojump_clock):
+ # Fake network that has the minimum allowable MTU for whatever protocol we're using.
+ fn = FakeNet()
+ fn.enable()
+
+ if ipv6:
+ mtu = 1280 - 48
+ else:
+ mtu = 576 - 28
+
+ def route_packet(packet):
+ if len(packet.payload) > mtu:
+ print(f"dropping {packet}")
+ else:
+ print(f"delivering {packet}")
+ fn.deliver_packet(packet)
+
+ fn.route_packet = route_packet
+
+ # See if we can successfully do a handshake -- some of the volleys will get dropped,
+ # and the retransmit logic should detect this and back off the MTU to something
+ # smaller until it succeeds.
+ async with dtls_echo_server(ipv6=ipv6) as (_, address):
+ with endpoint(ipv6=ipv6) as client_endpoint:
+ client = client_endpoint.connect(address, client_ctx)
+ # the handshake mtu backoff shouldn't affect the return value from
+ # get_cleartext_mtu, b/c that's under the user's control via
+ # set_ciphertext_mtu
+ client.set_ciphertext_mtu(9999)
+ await client.send(b"xyz")
+ assert await client.receive() == b"xyz"
+ assert client.get_cleartext_mtu() > 9000 # as vegeta said
+
+
+@pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning")
+async def test_system_task_cleaned_up_on_gc():
+ before_tasks = trio.lowlevel.current_statistics().tasks_living
+
+ # We put this into a sub-function so that everything automatically becomes garbage
+ # when the frame exits. For some reason just doing 'del e' wasn't enough on pypy
+ # with coverage enabled -- I think we were hitting this bug:
+ # https://foss.heptapod.net/pypy/pypy/-/issues/3656
+ async def start_and_forget_endpoint():
+ e = endpoint()
+
+ # This connection/handshake attempt can't succeed. The only purpose is to force
+ # the endpoint to set up a receive loop.
+ with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as s:
+ await s.bind(("127.0.0.1", 0))
+ c = e.connect(s.getsockname(), client_ctx)
+ async with trio.open_nursery() as nursery:
+ nursery.start_soon(c.do_handshake)
+ await trio.testing.wait_all_tasks_blocked()
+ nursery.cancel_scope.cancel()
+
+ during_tasks = trio.lowlevel.current_statistics().tasks_living
+ return during_tasks
+
+ with pytest.warns(ResourceWarning):
+ during_tasks = await start_and_forget_endpoint()
+ await trio.testing.wait_all_tasks_blocked()
+ gc_collect_harder()
+
+ await trio.testing.wait_all_tasks_blocked()
+
+ after_tasks = trio.lowlevel.current_statistics().tasks_living
+ assert before_tasks < during_tasks
+ assert before_tasks == after_tasks
+
+
+@pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning")
+async def test_gc_before_system_task_starts():
+ e = endpoint()
+
+ with pytest.warns(ResourceWarning):
+ del e
+ gc_collect_harder()
+
+ await trio.testing.wait_all_tasks_blocked()
+
+
+@pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning")
+async def test_gc_as_packet_received():
+ fn = FakeNet()
+ fn.enable()
+
+ e = endpoint()
+ await e.socket.bind(("127.0.0.1", 0))
+ e._ensure_receive_loop()
+
+ await trio.testing.wait_all_tasks_blocked()
+
+ with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as s:
+ await s.sendto(b"xxx", e.socket.getsockname())
+ # At this point, the endpoint's receive loop has been marked runnable because it
+ # just received a packet; closing the endpoint socket won't interrupt that. But by
+ # the time it wakes up to process the packet, the endpoint will be gone.
+ with pytest.warns(ResourceWarning):
+ del e
+ gc_collect_harder()
+
+
+@pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning")
+def test_gc_after_trio_exits():
+ async def main():
+ # We use fakenet just to make sure no real sockets can leak out of the test
+ # case - on pypy somehow the socket was outliving the gc_collect_harder call
+ # below. Since the test is just making sure DTLSEndpoint.__del__ doesn't explode
+ # when called after trio exits, it doesn't need a real socket.
+ fn = FakeNet()
+ fn.enable()
+ return endpoint()
+
+ e = trio.run(main)
+ with pytest.warns(ResourceWarning):
+ del e
+ gc_collect_harder()
+
+
+async def test_already_closed_socket_doesnt_crash():
+ with endpoint() as e:
+ # We close the socket before checkpointing, so the socket will already be closed
+ # when the system task starts up
+ e.socket.close()
+ # Now give it a chance to start up, and hopefully not crash
+ await trio.testing.wait_all_tasks_blocked()
+
+
+async def test_socket_closed_while_processing_clienthello(autojump_clock):
+ fn = FakeNet()
+ fn.enable()
+
+ # Check what happens if the socket is discovered to be closed when sending a
+ # HelloVerifyRequest, since that has its own sending logic
+ async with dtls_echo_server() as (server, address):
+
+ def route_packet(packet):
+ fn.deliver_packet(packet)
+ server.socket.close()
+
+ fn.route_packet = route_packet
+
+ with endpoint() as client_endpoint:
+ with trio.move_on_after(10):
+ client = client_endpoint.connect(address, client_ctx)
+ await client.do_handshake()
+
+
+async def test_association_replaced_while_handshake_running(autojump_clock):
+ fn = FakeNet()
+ fn.enable()
+
+ def route_packet(packet):
+ pass
+
+ fn.route_packet = route_packet
+
+ async with dtls_echo_server() as (_, address):
+ with endpoint() as client_endpoint:
+ c1 = client_endpoint.connect(address, client_ctx)
+ async with trio.open_nursery() as nursery:
+
+ async def doomed_handshake():
+ with pytest.raises(trio.BrokenResourceError):
+ await c1.do_handshake()
+
+ nursery.start_soon(doomed_handshake)
+
+ await trio.sleep(10)
+
+ client_endpoint.connect(address, client_ctx)
+
+
+async def test_association_replaced_before_handshake_starts():
+ fn = FakeNet()
+ fn.enable()
+
+ # This test shouldn't send any packets
+ def route_packet(packet): # pragma: no cover
+ assert False
+
+ fn.route_packet = route_packet
+
+ async with dtls_echo_server() as (_, address):
+ with endpoint() as client_endpoint:
+ c1 = client_endpoint.connect(address, client_ctx)
+ client_endpoint.connect(address, client_ctx)
+ with pytest.raises(trio.BrokenResourceError):
+ await c1.do_handshake()
+
+
+async def test_send_to_closed_local_port():
+ # On Windows, sending a UDP packet to a closed local port can cause a weird
+ # ECONNRESET error later, inside the receive task. Make sure we're handling it
+ # properly.
+ async with dtls_echo_server() as (_, address):
+ with endpoint() as client_endpoint:
+ async with trio.open_nursery() as nursery:
+ for i in range(1, 10):
+ channel = client_endpoint.connect(("127.0.0.1", i), client_ctx)
+ nursery.start_soon(channel.do_handshake)
+ channel = client_endpoint.connect(address, client_ctx)
+ await channel.send(b"xxx")
+ assert await channel.receive() == b"xxx"
+ nursery.cancel_scope.cancel()
diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_exports.py b/venv/lib/python3.9/site-packages/trio/tests/test_exports.py
new file mode 100644
index 00000000..8d6b2d61
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/trio/tests/test_exports.py
@@ -0,0 +1,145 @@
+import re
+import sys
+import importlib
+import types
+import inspect
+import enum
+
+import pytest
+
+import trio
+import trio.testing
+
+from .. import _core
+from .. import _util
+
+
+def test_core_is_properly_reexported():
+ # Each export from _core should be re-exported by exactly one of these
+ # three modules:
+ sources = [trio, trio.lowlevel, trio.testing]
+ for symbol in dir(_core):
+ if symbol.startswith("_") or symbol == "tests":
+ continue
+ found = 0
+ for source in sources:
+ if symbol in dir(source) and getattr(source, symbol) is getattr(
+ _core, symbol
+ ):
+ found += 1
+ print(symbol, found)
+ assert found == 1
+
+
+def public_modules(module):
+ yield module
+ for name, class_ in module.__dict__.items():
+ if name.startswith("_"): # pragma: no cover
+ continue
+ if not isinstance(class_, types.ModuleType):
+ continue
+ if not class_.__name__.startswith(module.__name__): # pragma: no cover
+ continue
+ if class_ is module:
+ continue
+ # We should rename the trio.tests module (#274), but until then we use
+ # a special-case hack:
+ if class_.__name__ == "trio.tests":
+ continue
+ yield from public_modules(class_)
+
+
+PUBLIC_MODULES = list(public_modules(trio))
+PUBLIC_MODULE_NAMES = [m.__name__ for m in PUBLIC_MODULES]
+
+
+# It doesn't make sense for downstream redistributors to run this test, since
+# they might be using a newer version of Python with additional symbols which
+# won't be reflected in trio.socket, and this shouldn't cause downstream test
+# runs to start failing.
+@pytest.mark.redistributors_should_skip
+# pylint/jedi often have trouble with alpha releases, where Python's internals
+# are in flux, grammar may not have settled down, etc.
+@pytest.mark.skipif(
+ sys.version_info.releaselevel == "alpha",
+ reason="skip static introspection tools on Python dev/alpha releases",
+)
+@pytest.mark.parametrize("modname", PUBLIC_MODULE_NAMES)
+@pytest.mark.parametrize("tool", ["pylint", "jedi"])
+@pytest.mark.filterwarnings(
+ # https://github.com/pypa/setuptools/issues/3274
+ "ignore:module 'sre_constants' is deprecated:DeprecationWarning",
+)
+def test_static_tool_sees_all_symbols(tool, modname):
+ module = importlib.import_module(modname)
+
+ def no_underscores(symbols):
+ return {symbol for symbol in symbols if not symbol.startswith("_")}
+
+ runtime_names = no_underscores(dir(module))
+
+ # We should rename the trio.tests module (#274), but until then we use a
+ # special-case hack:
+ if modname == "trio":
+ runtime_names.remove("tests")
+
+ if tool == "pylint":
+ from pylint.lint import PyLinter
+
+ linter = PyLinter()
+ ast = linter.get_ast(module.__file__, modname)
+ static_names = no_underscores(ast)
+ elif tool == "jedi":
+ import jedi
+
+ # Simulate typing "import trio; trio.<TAB>"
+ script = jedi.Script("import {}; {}.".format(modname, modname))
+ completions = script.complete()
+ static_names = no_underscores(c.name for c in completions)
+ else: # pragma: no cover
+ assert False
+
+ # It's expected that the static set will contain more names than the
+ # runtime set:
+ # - static tools are sometimes sloppy and include deleted names
+ # - some symbols are platform-specific at runtime, but always show up in
+ # static analysis (e.g. in trio.socket or trio.lowlevel)
+ # So we check that the runtime names are a subset of the static names.
+ missing_names = runtime_names - static_names
+ if missing_names: # pragma: no cover
+ print("{} can't see the following names in {}:".format(tool, modname))
+ print()
+ for name in sorted(missing_names):
+ print(" {}".format(name))
+ assert False
+
+
+def test_classes_are_final():
+ for module in PUBLIC_MODULES:
+ for name, class_ in module.__dict__.items():
+ if not isinstance(class_, type):
+ continue
+ # Deprecated classes are exported with a leading underscore
+ if name.startswith("_"): # pragma: no cover
+ continue
+
+ # Abstract classes can be subclassed, because that's the whole
+ # point of ABCs
+ if inspect.isabstract(class_):
+ continue
+ # Exceptions are allowed to be subclassed, because exception
+ # subclassing isn't used to inherit behavior.
+ if issubclass(class_, BaseException):
+ continue
+ # These are classes that are conceptually abstract, but
+ # inspect.isabstract returns False for boring reasons.
+ if class_ in {trio.abc.Instrument, trio.socket.SocketType}:
+ continue
+ # Enums have their own metaclass, so we can't use our metaclasses.
+ # And I don't think there's a lot of risk from people subclassing
+ # enums...
+ if issubclass(class_, enum.Enum):
+ continue
+ # ... insert other special cases here ...
+
+ assert isinstance(class_, _util.Final)
diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_fakenet.py b/venv/lib/python3.9/site-packages/trio/tests/test_fakenet.py
new file mode 100644
index 00000000..bc691c9d
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/trio/tests/test_fakenet.py
@@ -0,0 +1,44 @@
+import pytest
+
+import trio
+from trio.testing._fake_net import FakeNet
+
+
+def fn():
+ fn = FakeNet()
+ fn.enable()
+ return fn
+
+
+async def test_basic_udp():
+ fn()
+ s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
+ s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
+
+ await s1.bind(("127.0.0.1", 0))
+ ip, port = s1.getsockname()
+ assert ip == "127.0.0.1"
+ assert port != 0
+ await s2.sendto(b"xyz", s1.getsockname())
+ data, addr = await s1.recvfrom(10)
+ assert data == b"xyz"
+ assert addr == s2.getsockname()
+ await s1.sendto(b"abc", s2.getsockname())
+ data, addr = await s2.recvfrom(10)
+ assert data == b"abc"
+ assert addr == s1.getsockname()
+
+
+async def test_msg_trunc():
+ fn()
+ s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
+ s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
+ await s1.bind(("127.0.0.1", 0))
+ await s2.sendto(b"xyz", s1.getsockname())
+ data, addr = await s1.recvfrom(10)
+
+
+async def test_basic_tcp():
+ fn()
+ with pytest.raises(NotImplementedError):
+ trio.socket.socket()
diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_file_io.py b/venv/lib/python3.9/site-packages/trio/tests/test_file_io.py
new file mode 100644
index 00000000..b40f7518
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/trio/tests/test_file_io.py
@@ -0,0 +1,198 @@
+import io
+import os
+
+import pytest
+from unittest import mock
+from unittest.mock import sentinel
+
+import trio
+from trio import _core
+from trio._file_io import AsyncIOWrapper, _FILE_SYNC_ATTRS, _FILE_ASYNC_METHODS
+
+
+@pytest.fixture
+def path(tmpdir):
+ return os.fspath(tmpdir.join("test"))
+
+
+@pytest.fixture
+def wrapped():
+ return mock.Mock(spec_set=io.StringIO)
+
+
+@pytest.fixture
+def async_file(wrapped):
+ return trio.wrap_file(wrapped)
+
+
+def test_wrap_invalid():
+ with pytest.raises(TypeError):
+ trio.wrap_file(str())
+
+
+def test_wrap_non_iobase():
+ class FakeFile:
+ def close(self): # pragma: no cover
+ pass
+
+ def write(self): # pragma: no cover
+ pass
+
+ wrapped = FakeFile()
+ assert not isinstance(wrapped, io.IOBase)
+
+ async_file = trio.wrap_file(wrapped)
+ assert isinstance(async_file, AsyncIOWrapper)
+
+ del FakeFile.write
+
+ with pytest.raises(TypeError):
+ trio.wrap_file(FakeFile())
+
+
+def test_wrapped_property(async_file, wrapped):
+ assert async_file.wrapped is wrapped
+
+
+def test_dir_matches_wrapped(async_file, wrapped):
+ attrs = _FILE_SYNC_ATTRS.union(_FILE_ASYNC_METHODS)
+
+ # all supported attrs in wrapped should be available in async_file
+ assert all(attr in dir(async_file) for attr in attrs if attr in dir(wrapped))
+ # all supported attrs not in wrapped should not be available in async_file
+ assert not any(
+ attr in dir(async_file) for attr in attrs if attr not in dir(wrapped)
+ )
+
+
+def test_unsupported_not_forwarded():
+ class FakeFile(io.RawIOBase):
+ def unsupported_attr(self): # pragma: no cover
+ pass
+
+ async_file = trio.wrap_file(FakeFile())
+
+ assert hasattr(async_file.wrapped, "unsupported_attr")
+
+ with pytest.raises(AttributeError):
+ getattr(async_file, "unsupported_attr")
+
+
+def test_sync_attrs_forwarded(async_file, wrapped):
+ for attr_name in _FILE_SYNC_ATTRS:
+ if attr_name not in dir(async_file):
+ continue
+
+ assert getattr(async_file, attr_name) is getattr(wrapped, attr_name)
+
+
+def test_sync_attrs_match_wrapper(async_file, wrapped):
+ for attr_name in _FILE_SYNC_ATTRS:
+ if attr_name in dir(async_file):
+ continue
+
+ with pytest.raises(AttributeError):
+ getattr(async_file, attr_name)
+
+ with pytest.raises(AttributeError):
+ getattr(wrapped, attr_name)
+
+
+def test_async_methods_generated_once(async_file):
+ for meth_name in _FILE_ASYNC_METHODS:
+ if meth_name not in dir(async_file):
+ continue
+
+ assert getattr(async_file, meth_name) is getattr(async_file, meth_name)
+
+
+def test_async_methods_signature(async_file):
+ # use read as a representative of all async methods
+ assert async_file.read.__name__ == "read"
+ assert async_file.read.__qualname__ == "AsyncIOWrapper.read"
+
+ assert "io.StringIO.read" in async_file.read.__doc__
+
+
+async def test_async_methods_wrap(async_file, wrapped):
+ for meth_name in _FILE_ASYNC_METHODS:
+ if meth_name not in dir(async_file):
+ continue
+
+ meth = getattr(async_file, meth_name)
+ wrapped_meth = getattr(wrapped, meth_name)
+
+ value = await meth(sentinel.argument, keyword=sentinel.keyword)
+
+ wrapped_meth.assert_called_once_with(
+ sentinel.argument, keyword=sentinel.keyword
+ )
+ assert value == wrapped_meth()
+
+ wrapped.reset_mock()
+
+
+async def test_async_methods_match_wrapper(async_file, wrapped):
+ for meth_name in _FILE_ASYNC_METHODS:
+ if meth_name in dir(async_file):
+ continue
+
+ with pytest.raises(AttributeError):
+ getattr(async_file, meth_name)
+
+ with pytest.raises(AttributeError):
+ getattr(wrapped, meth_name)
+
+
+async def test_open(path):
+ f = await trio.open_file(path, "w")
+
+ assert isinstance(f, AsyncIOWrapper)
+
+ await f.aclose()
+
+
+async def test_open_context_manager(path):
+ async with await trio.open_file(path, "w") as f:
+ assert isinstance(f, AsyncIOWrapper)
+ assert not f.closed
+
+ assert f.closed
+
+
+async def test_async_iter():
+ async_file = trio.wrap_file(io.StringIO("test\nfoo\nbar"))
+ expected = list(async_file.wrapped)
+ result = []
+ async_file.wrapped.seek(0)
+
+ async for line in async_file:
+ result.append(line)
+
+ assert result == expected
+
+
+async def test_aclose_cancelled(path):
+ with _core.CancelScope() as cscope:
+ f = await trio.open_file(path, "w")
+ cscope.cancel()
+
+ with pytest.raises(_core.Cancelled):
+ await f.write("a")
+
+ with pytest.raises(_core.Cancelled):
+ await f.aclose()
+
+ assert f.closed
+
+
+async def test_detach_rewraps_asynciobase():
+ raw = io.BytesIO()
+ buffered = io.BufferedReader(raw)
+
+ async_file = trio.wrap_file(buffered)
+
+ detached = await async_file.detach()
+
+ assert isinstance(detached, AsyncIOWrapper)
+ assert detached.wrapped is raw
diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_highlevel_generic.py b/venv/lib/python3.9/site-packages/trio/tests/test_highlevel_generic.py
new file mode 100644
index 00000000..df2b2cec
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/trio/tests/test_highlevel_generic.py
@@ -0,0 +1,94 @@
+import pytest
+
+import attr
+
+from ..abc import SendStream, ReceiveStream
+from .._highlevel_generic import StapledStream
+
+
+@attr.s
+class RecordSendStream(SendStream):
+ record = attr.ib(factory=list)
+
+ async def send_all(self, data):
+ self.record.append(("send_all", data))
+
+ async def wait_send_all_might_not_block(self):
+ self.record.append("wait_send_all_might_not_block")
+
+ async def aclose(self):
+ self.record.append("aclose")
+
+
+@attr.s
+class RecordReceiveStream(ReceiveStream):
+ record = attr.ib(factory=list)
+
+ async def receive_some(self, max_bytes=None):
+ self.record.append(("receive_some", max_bytes))
+
+ async def aclose(self):
+ self.record.append("aclose")
+
+
+async def test_StapledStream():
+ send_stream = RecordSendStream()
+ receive_stream = RecordReceiveStream()
+ stapled = StapledStream(send_stream, receive_stream)
+
+ assert stapled.send_stream is send_stream
+ assert stapled.receive_stream is receive_stream
+
+ await stapled.send_all(b"foo")
+ await stapled.wait_send_all_might_not_block()
+ assert send_stream.record == [
+ ("send_all", b"foo"),
+ "wait_send_all_might_not_block",
+ ]
+ send_stream.record.clear()
+
+ await stapled.send_eof()
+ assert send_stream.record == ["aclose"]
+ send_stream.record.clear()
+
+ async def fake_send_eof():
+ send_stream.record.append("send_eof")
+
+ send_stream.send_eof = fake_send_eof
+ await stapled.send_eof()
+ assert send_stream.record == ["send_eof"]
+
+ send_stream.record.clear()
+ assert receive_stream.record == []
+
+ await stapled.receive_some(1234)
+ assert receive_stream.record == [("receive_some", 1234)]
+ assert send_stream.record == []
+ receive_stream.record.clear()
+
+ await stapled.aclose()
+ assert receive_stream.record == ["aclose"]
+ assert send_stream.record == ["aclose"]
+
+
+async def test_StapledStream_with_erroring_close():
+ # Make sure that if one of the aclose methods errors out, then the other
+ # one still gets called.
+ class BrokenSendStream(RecordSendStream):
+ async def aclose(self):
+ await super().aclose()
+ raise ValueError
+
+ class BrokenReceiveStream(RecordReceiveStream):
+ async def aclose(self):
+ await super().aclose()
+ raise ValueError
+
+ stapled = StapledStream(BrokenSendStream(), BrokenReceiveStream())
+
+ with pytest.raises(ValueError) as excinfo:
+ await stapled.aclose()
+ assert isinstance(excinfo.value.__context__, ValueError)
+
+ assert stapled.send_stream.record == ["aclose"]
+ assert stapled.receive_stream.record == ["aclose"]
diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_highlevel_open_tcp_listeners.py b/venv/lib/python3.9/site-packages/trio/tests/test_highlevel_open_tcp_listeners.py
new file mode 100644
index 00000000..10398473
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/trio/tests/test_highlevel_open_tcp_listeners.py
@@ -0,0 +1,300 @@
+import sys
+
+import pytest
+
+import socket as stdlib_socket
+import errno
+
+import attr
+
+import trio
+from trio import open_tcp_listeners, serve_tcp, SocketListener, open_tcp_stream
+from trio.testing import open_stream_to_socket_listener
+from .. import socket as tsocket
+from .._core.tests.tutil import slow, creates_ipv6, binds_ipv6
+
+if sys.version_info < (3, 11):
+ from exceptiongroup import BaseExceptionGroup
+
+
+async def test_open_tcp_listeners_basic():
+ listeners = await open_tcp_listeners(0)
+ assert isinstance(listeners, list)
+ for obj in listeners:
+ assert isinstance(obj, SocketListener)
+ # Binds to wildcard address by default
+ assert obj.socket.family in [tsocket.AF_INET, tsocket.AF_INET6]
+ assert obj.socket.getsockname()[0] in ["0.0.0.0", "::"]
+
+ listener = listeners[0]
+ # Make sure the backlog is at least 2
+ c1 = await open_stream_to_socket_listener(listener)
+ c2 = await open_stream_to_socket_listener(listener)
+
+ s1 = await listener.accept()
+ s2 = await listener.accept()
+
+ # Note that we don't know which client stream is connected to which server
+ # stream
+ await s1.send_all(b"x")
+ await s2.send_all(b"x")
+ assert await c1.receive_some(1) == b"x"
+ assert await c2.receive_some(1) == b"x"
+
+ for resource in [c1, c2, s1, s2] + listeners:
+ await resource.aclose()
+
+
+async def test_open_tcp_listeners_specific_port_specific_host():
+ # Pick a port
+ sock = tsocket.socket()
+ await sock.bind(("127.0.0.1", 0))
+ host, port = sock.getsockname()
+ sock.close()
+
+ (listener,) = await open_tcp_listeners(port, host=host)
+ async with listener:
+ assert listener.socket.getsockname() == (host, port)
+
+
+@binds_ipv6
+async def test_open_tcp_listeners_ipv6_v6only():
+ # Check IPV6_V6ONLY is working properly
+ (ipv6_listener,) = await open_tcp_listeners(0, host="::1")
+ async with ipv6_listener:
+ _, port, *_ = ipv6_listener.socket.getsockname()
+
+ with pytest.raises(OSError):
+ await open_tcp_stream("127.0.0.1", port)
+
+
+async def test_open_tcp_listeners_rebind():
+ (l1,) = await open_tcp_listeners(0, host="127.0.0.1")
+ sockaddr1 = l1.socket.getsockname()
+
+ # Plain old rebinding while it's still there should fail, even if we have
+ # SO_REUSEADDR set
+ with stdlib_socket.socket() as probe:
+ probe.setsockopt(stdlib_socket.SOL_SOCKET, stdlib_socket.SO_REUSEADDR, 1)
+ with pytest.raises(OSError):
+ probe.bind(sockaddr1)
+
+ # Now use the first listener to set up some connections in various states,
+ # and make sure that they don't create any obstacle to rebinding a second
+ # listener after the first one is closed.
+ c_established = await open_stream_to_socket_listener(l1)
+ s_established = await l1.accept()
+
+ c_time_wait = await open_stream_to_socket_listener(l1)
+ s_time_wait = await l1.accept()
+ # Server-initiated close leaves socket in TIME_WAIT
+ await s_time_wait.aclose()
+
+ await l1.aclose()
+ (l2,) = await open_tcp_listeners(sockaddr1[1], host="127.0.0.1")
+ sockaddr2 = l2.socket.getsockname()
+
+ assert sockaddr1 == sockaddr2
+ assert s_established.socket.getsockname() == sockaddr2
+ assert c_time_wait.socket.getpeername() == sockaddr2
+
+ for resource in [
+ l1,
+ l2,
+ c_established,
+ s_established,
+ c_time_wait,
+ s_time_wait,
+ ]:
+ await resource.aclose()
+
+
+class FakeOSError(OSError):
+ pass
+
+
+@attr.s
+class FakeSocket(tsocket.SocketType):
+ family = attr.ib()
+ type = attr.ib()
+ proto = attr.ib()
+
+ closed = attr.ib(default=False)
+ poison_listen = attr.ib(default=False)
+ backlog = attr.ib(default=None)
+
+ def getsockopt(self, level, option):
+ if (level, option) == (tsocket.SOL_SOCKET, tsocket.SO_ACCEPTCONN):
+ return True
+ assert False # pragma: no cover
+
+ def setsockopt(self, level, option, value):
+ pass
+
+ async def bind(self, sockaddr):
+ pass
+
+ def listen(self, backlog):
+ assert self.backlog is None
+ assert backlog is not None
+ self.backlog = backlog
+ if self.poison_listen:
+ raise FakeOSError("whoops")
+
+ def close(self):
+ self.closed = True
+
+
+@attr.s
+class FakeSocketFactory:
+ poison_after = attr.ib()
+ sockets = attr.ib(factory=list)
+ raise_on_family = attr.ib(factory=dict) # family => errno
+
+ def socket(self, family, type, proto):
+ if family in self.raise_on_family:
+ raise OSError(self.raise_on_family[family], "nope")
+ sock = FakeSocket(family, type, proto)
+ self.poison_after -= 1
+ if self.poison_after == 0:
+ sock.poison_listen = True
+ self.sockets.append(sock)
+ return sock
+
+
+@attr.s
+class FakeHostnameResolver:
+ family_addr_pairs = attr.ib()
+
+ async def getaddrinfo(self, host, port, family, type, proto, flags):
+ return [
+ (family, tsocket.SOCK_STREAM, 0, "", (addr, port))
+ for family, addr in self.family_addr_pairs
+ ]
+
+
+async def test_open_tcp_listeners_multiple_host_cleanup_on_error():
+ # If we were trying to bind to multiple hosts and one of them failed, they
+ # call get cleaned up before returning
+ fsf = FakeSocketFactory(3)
+ tsocket.set_custom_socket_factory(fsf)
+ tsocket.set_custom_hostname_resolver(
+ FakeHostnameResolver(
+ [
+ (tsocket.AF_INET, "1.1.1.1"),
+ (tsocket.AF_INET, "2.2.2.2"),
+ (tsocket.AF_INET, "3.3.3.3"),
+ ]
+ )
+ )
+
+ with pytest.raises(FakeOSError):
+ await open_tcp_listeners(80, host="example.org")
+
+ assert len(fsf.sockets) == 3
+ for sock in fsf.sockets:
+ assert sock.closed
+
+
+async def test_open_tcp_listeners_port_checking():
+ for host in ["127.0.0.1", None]:
+ with pytest.raises(TypeError):
+ await open_tcp_listeners(None, host=host)
+ with pytest.raises(TypeError):
+ await open_tcp_listeners(b"80", host=host)
+ with pytest.raises(TypeError):
+ await open_tcp_listeners("http", host=host)
+
+
+async def test_serve_tcp():
+ async def handler(stream):
+ await stream.send_all(b"x")
+
+ async with trio.open_nursery() as nursery:
+ listeners = await nursery.start(serve_tcp, handler, 0)
+ stream = await open_stream_to_socket_listener(listeners[0])
+ async with stream:
+ await stream.receive_some(1) == b"x"
+ nursery.cancel_scope.cancel()
+
+
+@pytest.mark.parametrize(
+ "try_families",
+ [{tsocket.AF_INET}, {tsocket.AF_INET6}, {tsocket.AF_INET, tsocket.AF_INET6}],
+)
+@pytest.mark.parametrize(
+ "fail_families",
+ [{tsocket.AF_INET}, {tsocket.AF_INET6}, {tsocket.AF_INET, tsocket.AF_INET6}],
+)
+async def test_open_tcp_listeners_some_address_families_unavailable(
+ try_families, fail_families
+):
+ fsf = FakeSocketFactory(
+ 10, raise_on_family={family: errno.EAFNOSUPPORT for family in fail_families}
+ )
+ tsocket.set_custom_socket_factory(fsf)
+ tsocket.set_custom_hostname_resolver(
+ FakeHostnameResolver([(family, "foo") for family in try_families])
+ )
+
+ should_succeed = try_families - fail_families
+
+ if not should_succeed:
+ with pytest.raises(OSError) as exc_info:
+ await open_tcp_listeners(80, host="example.org")
+
+ assert "This system doesn't support" in str(exc_info.value)
+ if isinstance(exc_info.value.__cause__, BaseExceptionGroup):
+ for subexc in exc_info.value.__cause__.exceptions:
+ assert "nope" in str(subexc)
+ else:
+ assert isinstance(exc_info.value.__cause__, OSError)
+ assert "nope" in str(exc_info.value.__cause__)
+ else:
+ listeners = await open_tcp_listeners(80)
+ for listener in listeners:
+ should_succeed.remove(listener.socket.family)
+ assert not should_succeed
+
+
+async def test_open_tcp_listeners_socket_fails_not_afnosupport():
+ fsf = FakeSocketFactory(
+ 10,
+ raise_on_family={
+ tsocket.AF_INET: errno.EAFNOSUPPORT,
+ tsocket.AF_INET6: errno.EINVAL,
+ },
+ )
+ tsocket.set_custom_socket_factory(fsf)
+ tsocket.set_custom_hostname_resolver(
+ FakeHostnameResolver([(tsocket.AF_INET, "foo"), (tsocket.AF_INET6, "bar")])
+ )
+
+ with pytest.raises(OSError) as exc_info:
+ await open_tcp_listeners(80, host="example.org")
+ assert exc_info.value.errno == errno.EINVAL
+ assert exc_info.value.__cause__ is None
+ assert "nope" in str(exc_info.value)
+
+
+# We used to have an elaborate test that opened a real TCP listening socket
+# and then tried to measure its backlog by making connections to it. And most
+# of the time, it worked. But no matter what we tried, it was always fragile,
+# because it had to do things like use timeouts to guess when the listening
+# queue was full, sometimes the CI hosts go into SYN-cookie mode (where there
+# effectively is no backlog), sometimes the host might not be enough resources
+# to give us the full requested backlog... it was a mess. So now we just check
+# that the backlog argument is passed through correctly.
+async def test_open_tcp_listeners_backlog():
+ fsf = FakeSocketFactory(99)
+ tsocket.set_custom_socket_factory(fsf)
+ for (given, expected) in [
+ (None, 0xFFFF),
+ (99999999, 0xFFFF),
+ (10, 10),
+ (1, 1),
+ ]:
+ listeners = await open_tcp_listeners(0, backlog=given)
+ assert listeners
+ for listener in listeners:
+ assert listener.socket.backlog == expected
diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_highlevel_open_tcp_stream.py b/venv/lib/python3.9/site-packages/trio/tests/test_highlevel_open_tcp_stream.py
new file mode 100644
index 00000000..0f3b6a0b
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/trio/tests/test_highlevel_open_tcp_stream.py
@@ -0,0 +1,574 @@
+import pytest
+import sys
+import socket
+
+import attr
+
+import trio
+from trio.socket import AF_INET, AF_INET6, SOCK_STREAM, IPPROTO_TCP
+from trio._highlevel_open_tcp_stream import (
+ reorder_for_rfc_6555_section_5_4,
+ close_all,
+ open_tcp_stream,
+ format_host_port,
+)
+
+if sys.version_info < (3, 11):
+ from exceptiongroup import BaseExceptionGroup
+
+
+def test_close_all():
+ class CloseMe:
+ closed = False
+
+ def close(self):
+ self.closed = True
+
+ class CloseKiller:
+ def close(self):
+ raise OSError
+
+ c = CloseMe()
+ with close_all() as to_close:
+ to_close.add(c)
+ assert c.closed
+
+ c = CloseMe()
+ with pytest.raises(RuntimeError):
+ with close_all() as to_close:
+ to_close.add(c)
+ raise RuntimeError
+ assert c.closed
+
+ c = CloseMe()
+ with pytest.raises(OSError):
+ with close_all() as to_close:
+ to_close.add(CloseKiller())
+ to_close.add(c)
+ assert c.closed
+
+
+def test_reorder_for_rfc_6555_section_5_4():
+ def fake4(i):
+ return (
+ AF_INET,
+ SOCK_STREAM,
+ IPPROTO_TCP,
+ "",
+ ("10.0.0.{}".format(i), 80),
+ )
+
+ def fake6(i):
+ return (AF_INET6, SOCK_STREAM, IPPROTO_TCP, "", ("::{}".format(i), 80))
+
+ for fake in fake4, fake6:
+ # No effect on homogeneous lists
+ targets = [fake(0), fake(1), fake(2)]
+ reorder_for_rfc_6555_section_5_4(targets)
+ assert targets == [fake(0), fake(1), fake(2)]
+
+ # Single item lists also OK
+ targets = [fake(0)]
+ reorder_for_rfc_6555_section_5_4(targets)
+ assert targets == [fake(0)]
+
+ # If the list starts out with different families in positions 0 and 1,
+ # then it's left alone
+ orig = [fake4(0), fake6(0), fake4(1), fake6(1)]
+ targets = list(orig)
+ reorder_for_rfc_6555_section_5_4(targets)
+ assert targets == orig
+
+ # If not, it's reordered
+ targets = [fake4(0), fake4(1), fake4(2), fake6(0), fake6(1)]
+ reorder_for_rfc_6555_section_5_4(targets)
+ assert targets == [fake4(0), fake6(0), fake4(1), fake4(2), fake6(1)]
+
+
+def test_format_host_port():
+ assert format_host_port("127.0.0.1", 80) == "127.0.0.1:80"
+ assert format_host_port(b"127.0.0.1", 80) == "127.0.0.1:80"
+ assert format_host_port("example.com", 443) == "example.com:443"
+ assert format_host_port(b"example.com", 443) == "example.com:443"
+ assert format_host_port("::1", "http") == "[::1]:http"
+ assert format_host_port(b"::1", "http") == "[::1]:http"
+
+
+# Make sure we can connect to localhost using real kernel sockets
+async def test_open_tcp_stream_real_socket_smoketest():
+ listen_sock = trio.socket.socket()
+ await listen_sock.bind(("127.0.0.1", 0))
+ _, listen_port = listen_sock.getsockname()
+ listen_sock.listen(1)
+ client_stream = await open_tcp_stream("127.0.0.1", listen_port)
+ server_sock, _ = await listen_sock.accept()
+ await client_stream.send_all(b"x")
+ assert await server_sock.recv(1) == b"x"
+ await client_stream.aclose()
+ server_sock.close()
+
+ listen_sock.close()
+
+
+async def test_open_tcp_stream_input_validation():
+ with pytest.raises(ValueError):
+ await open_tcp_stream(None, 80)
+ with pytest.raises(TypeError):
+ await open_tcp_stream("127.0.0.1", b"80")
+
+
+def can_bind_127_0_0_2():
+ with socket.socket() as s:
+ try:
+ s.bind(("127.0.0.2", 0))
+ except OSError:
+ return False
+ return s.getsockname()[0] == "127.0.0.2"
+
+
+async def test_local_address_real():
+ with trio.socket.socket() as listener:
+ await listener.bind(("127.0.0.1", 0))
+ listener.listen()
+
+ # It's hard to test local_address properly, because you need multiple
+ # local addresses that you can bind to. Fortunately, on most Linux
+ # systems, you can bind to any 127.*.*.* address, and they all go
+ # through the loopback interface. So we can use a non-standard
+ # loopback address. On other systems, the only address we know for
+ # certain we have is 127.0.0.1, so we can't really test local_address=
+ # properly -- passing local_address=127.0.0.1 is indistinguishable
+ # from not passing local_address= at all. But, we can still do a smoke
+ # test to make sure the local_address= code doesn't crash.
+ if can_bind_127_0_0_2():
+ local_address = "127.0.0.2"
+ else:
+ local_address = "127.0.0.1"
+
+ async with await open_tcp_stream(
+ *listener.getsockname(), local_address=local_address
+ ) as client_stream:
+ assert client_stream.socket.getsockname()[0] == local_address
+ if hasattr(trio.socket, "IP_BIND_ADDRESS_NO_PORT"):
+ assert client_stream.socket.getsockopt(
+ trio.socket.IPPROTO_IP, trio.socket.IP_BIND_ADDRESS_NO_PORT
+ )
+
+ server_sock, remote_addr = await listener.accept()
+ await client_stream.aclose()
+ server_sock.close()
+ assert remote_addr[0] == local_address
+
+ # Trying to connect to an ipv4 address with the ipv6 wildcard
+ # local_address should fail
+ with pytest.raises(OSError):
+ await open_tcp_stream(*listener.getsockname(), local_address="::")
+
+ # But the ipv4 wildcard address should work
+ async with await open_tcp_stream(
+ *listener.getsockname(), local_address="0.0.0.0"
+ ) as client_stream:
+ server_sock, remote_addr = await listener.accept()
+ server_sock.close()
+ assert remote_addr == client_stream.socket.getsockname()
+
+
+# Now, thorough tests using fake sockets
+
+
+@attr.s(eq=False)
+class FakeSocket(trio.socket.SocketType):
+ scenario = attr.ib()
+ family = attr.ib()
+ type = attr.ib()
+ proto = attr.ib()
+
+ ip = attr.ib(default=None)
+ port = attr.ib(default=None)
+ succeeded = attr.ib(default=False)
+ closed = attr.ib(default=False)
+ failing = attr.ib(default=False)
+
+ async def connect(self, sockaddr):
+ self.ip = sockaddr[0]
+ self.port = sockaddr[1]
+ assert self.ip not in self.scenario.sockets
+ self.scenario.sockets[self.ip] = self
+ self.scenario.connect_times[self.ip] = trio.current_time()
+ delay, result = self.scenario.ip_dict[self.ip]
+ await trio.sleep(delay)
+ if result == "error":
+ raise OSError("sorry")
+ if result == "postconnect_fail":
+ self.failing = True
+ self.succeeded = True
+
+ def close(self):
+ self.closed = True
+
+ # called when SocketStream is constructed
+ def setsockopt(self, *args, **kwargs):
+ if self.failing:
+ # raise something that isn't OSError as SocketStream
+ # ignores those
+ raise KeyboardInterrupt
+
+
+class Scenario(trio.abc.SocketFactory, trio.abc.HostnameResolver):
+ def __init__(self, port, ip_list, supported_families):
+ # ip_list have to be unique
+ ip_order = [ip for (ip, _, _) in ip_list]
+ assert len(set(ip_order)) == len(ip_list)
+ ip_dict = {}
+ for ip, delay, result in ip_list:
+ assert 0 <= delay
+ assert result in ["error", "success", "postconnect_fail"]
+ ip_dict[ip] = (delay, result)
+
+ self.port = port
+ self.ip_order = ip_order
+ self.ip_dict = ip_dict
+ self.supported_families = supported_families
+ self.socket_count = 0
+ self.sockets = {}
+ self.connect_times = {}
+
+ def socket(self, family, type, proto):
+ if family not in self.supported_families:
+ raise OSError("pretending not to support this family")
+ self.socket_count += 1
+ return FakeSocket(self, family, type, proto)
+
+ def _ip_to_gai_entry(self, ip):
+ if ":" in ip:
+ family = trio.socket.AF_INET6
+ sockaddr = (ip, self.port, 0, 0)
+ else:
+ family = trio.socket.AF_INET
+ sockaddr = (ip, self.port)
+ return (family, SOCK_STREAM, IPPROTO_TCP, "", sockaddr)
+
+ async def getaddrinfo(self, host, port, family, type, proto, flags):
+ assert host == b"test.example.com"
+ assert port == self.port
+ assert family == trio.socket.AF_UNSPEC
+ assert type == trio.socket.SOCK_STREAM
+ assert proto == 0
+ assert flags == 0
+ return [self._ip_to_gai_entry(ip) for ip in self.ip_order]
+
+ async def getnameinfo(self, sockaddr, flags): # pragma: no cover
+ raise NotImplementedError
+
+ def check(self, succeeded):
+ # sockets only go into self.sockets when connect is called; make sure
+ # all the sockets that were created did in fact go in there.
+ assert self.socket_count == len(self.sockets)
+
+ for ip, socket in self.sockets.items():
+ assert ip in self.ip_dict
+ if socket is not succeeded:
+ assert socket.closed
+ assert socket.port == self.port
+
+
+async def run_scenario(
+ # The port to connect to
+ port,
+ # A list of
+ # (ip, delay, result)
+ # tuples, where delay is in seconds and result is "success" or "error"
+ # The ip's will be returned from getaddrinfo in this order, and then
+ # connect() calls to them will have the given result.
+ ip_list,
+ *,
+ # If False, AF_INET4/6 sockets error out on creation, before connect is
+ # even called.
+ ipv4_supported=True,
+ ipv6_supported=True,
+ # Normally, we return (winning_sock, scenario object)
+ # If this is True, we require there to be an exception, and return
+ # (exception, scenario object)
+ expect_error=(),
+ **kwargs,
+):
+ supported_families = set()
+ if ipv4_supported:
+ supported_families.add(trio.socket.AF_INET)
+ if ipv6_supported:
+ supported_families.add(trio.socket.AF_INET6)
+ scenario = Scenario(port, ip_list, supported_families)
+ trio.socket.set_custom_hostname_resolver(scenario)
+ trio.socket.set_custom_socket_factory(scenario)
+
+ try:
+ stream = await open_tcp_stream("test.example.com", port, **kwargs)
+ assert expect_error == ()
+ scenario.check(stream.socket)
+ return (stream.socket, scenario)
+ except AssertionError: # pragma: no cover
+ raise
+ except expect_error as exc:
+ scenario.check(None)
+ return (exc, scenario)
+
+
+async def test_one_host_quick_success(autojump_clock):
+ sock, scenario = await run_scenario(80, [("1.2.3.4", 0.123, "success")])
+ assert sock.ip == "1.2.3.4"
+ assert trio.current_time() == 0.123
+
+
+async def test_one_host_slow_success(autojump_clock):
+ sock, scenario = await run_scenario(81, [("1.2.3.4", 100, "success")])
+ assert sock.ip == "1.2.3.4"
+ assert trio.current_time() == 100
+
+
+async def test_one_host_quick_fail(autojump_clock):
+ exc, scenario = await run_scenario(
+ 82, [("1.2.3.4", 0.123, "error")], expect_error=OSError
+ )
+ assert isinstance(exc, OSError)
+ assert trio.current_time() == 0.123
+
+
+async def test_one_host_slow_fail(autojump_clock):
+ exc, scenario = await run_scenario(
+ 83, [("1.2.3.4", 100, "error")], expect_error=OSError
+ )
+ assert isinstance(exc, OSError)
+ assert trio.current_time() == 100
+
+
+async def test_one_host_failed_after_connect(autojump_clock):
+ exc, scenario = await run_scenario(
+ 83, [("1.2.3.4", 1, "postconnect_fail")], expect_error=KeyboardInterrupt
+ )
+ assert isinstance(exc, KeyboardInterrupt)
+
+
+# With the default 0.250 second delay, the third attempt will win
+async def test_basic_fallthrough(autojump_clock):
+ sock, scenario = await run_scenario(
+ 80,
+ [
+ ("1.1.1.1", 1, "success"),
+ ("2.2.2.2", 1, "success"),
+ ("3.3.3.3", 0.2, "success"),
+ ],
+ )
+ assert sock.ip == "3.3.3.3"
+ # current time is default time + default time + connection time
+ assert trio.current_time() == (0.250 + 0.250 + 0.2)
+ assert scenario.connect_times == {
+ "1.1.1.1": 0,
+ "2.2.2.2": 0.250,
+ "3.3.3.3": 0.500,
+ }
+
+
+async def test_early_success(autojump_clock):
+ sock, scenario = await run_scenario(
+ 80,
+ [
+ ("1.1.1.1", 1, "success"),
+ ("2.2.2.2", 0.1, "success"),
+ ("3.3.3.3", 0.2, "success"),
+ ],
+ )
+ assert sock.ip == "2.2.2.2"
+ assert trio.current_time() == (0.250 + 0.1)
+ assert scenario.connect_times == {
+ "1.1.1.1": 0,
+ "2.2.2.2": 0.250,
+ # 3.3.3.3 was never even started
+ }
+
+
+# With a 0.450 second delay, the first attempt will win
+async def test_custom_delay(autojump_clock):
+ sock, scenario = await run_scenario(
+ 80,
+ [
+ ("1.1.1.1", 1, "success"),
+ ("2.2.2.2", 1, "success"),
+ ("3.3.3.3", 0.2, "success"),
+ ],
+ happy_eyeballs_delay=0.450,
+ )
+ assert sock.ip == "1.1.1.1"
+ assert trio.current_time() == 1
+ assert scenario.connect_times == {
+ "1.1.1.1": 0,
+ "2.2.2.2": 0.450,
+ "3.3.3.3": 0.900,
+ }
+
+
+async def test_custom_errors_expedite(autojump_clock):
+ sock, scenario = await run_scenario(
+ 80,
+ [
+ ("1.1.1.1", 0.1, "error"),
+ ("2.2.2.2", 0.2, "error"),
+ ("3.3.3.3", 10, "success"),
+ # .25 is the default timeout
+ ("4.4.4.4", 0.25, "success"),
+ ],
+ )
+ assert sock.ip == "4.4.4.4"
+ assert trio.current_time() == (0.1 + 0.2 + 0.25 + 0.25)
+ assert scenario.connect_times == {
+ "1.1.1.1": 0,
+ "2.2.2.2": 0.1,
+ "3.3.3.3": 0.1 + 0.2,
+ "4.4.4.4": 0.1 + 0.2 + 0.25,
+ }
+
+
+async def test_all_fail(autojump_clock):
+ exc, scenario = await run_scenario(
+ 80,
+ [
+ ("1.1.1.1", 0.1, "error"),
+ ("2.2.2.2", 0.2, "error"),
+ ("3.3.3.3", 10, "error"),
+ ("4.4.4.4", 0.250, "error"),
+ ],
+ expect_error=OSError,
+ )
+ assert isinstance(exc, OSError)
+ assert isinstance(exc.__cause__, BaseExceptionGroup)
+ assert len(exc.__cause__.exceptions) == 4
+ assert trio.current_time() == (0.1 + 0.2 + 10)
+ assert scenario.connect_times == {
+ "1.1.1.1": 0,
+ "2.2.2.2": 0.1,
+ "3.3.3.3": 0.1 + 0.2,
+ "4.4.4.4": 0.1 + 0.2 + 0.25,
+ }
+
+
+async def test_multi_success(autojump_clock):
+ sock, scenario = await run_scenario(
+ 80,
+ [
+ ("1.1.1.1", 0.5, "error"),
+ ("2.2.2.2", 10, "success"),
+ ("3.3.3.3", 10 - 1, "success"),
+ ("4.4.4.4", 10 - 2, "success"),
+ ("5.5.5.5", 0.5, "error"),
+ ],
+ happy_eyeballs_delay=1,
+ )
+ assert not scenario.sockets["1.1.1.1"].succeeded
+ assert (
+ scenario.sockets["2.2.2.2"].succeeded
+ or scenario.sockets["3.3.3.3"].succeeded
+ or scenario.sockets["4.4.4.4"].succeeded
+ )
+ assert not scenario.sockets["5.5.5.5"].succeeded
+ assert sock.ip in ["2.2.2.2", "3.3.3.3", "4.4.4.4"]
+ assert trio.current_time() == (0.5 + 10)
+ assert scenario.connect_times == {
+ "1.1.1.1": 0,
+ "2.2.2.2": 0.5,
+ "3.3.3.3": 1.5,
+ "4.4.4.4": 2.5,
+ "5.5.5.5": 3.5,
+ }
+
+
+async def test_does_reorder(autojump_clock):
+ sock, scenario = await run_scenario(
+ 80,
+ [
+ ("1.1.1.1", 10, "error"),
+ # This would win if we tried it first...
+ ("2.2.2.2", 1, "success"),
+ # But in fact we try this first, because of section 5.4
+ ("::3", 0.5, "success"),
+ ],
+ happy_eyeballs_delay=1,
+ )
+ assert sock.ip == "::3"
+ assert trio.current_time() == 1 + 0.5
+ assert scenario.connect_times == {
+ "1.1.1.1": 0,
+ "::3": 1,
+ }
+
+
+async def test_handles_no_ipv4(autojump_clock):
+ sock, scenario = await run_scenario(
+ 80,
+ # Here the ipv6 addresses fail at socket creation time, so the connect
+ # configuration doesn't matter
+ [
+ ("::1", 10, "success"),
+ ("2.2.2.2", 0, "success"),
+ ("::3", 0.1, "success"),
+ ("4.4.4.4", 0, "success"),
+ ],
+ happy_eyeballs_delay=1,
+ ipv4_supported=False,
+ )
+ assert sock.ip == "::3"
+ assert trio.current_time() == 1 + 0.1
+ assert scenario.connect_times == {
+ "::1": 0,
+ "::3": 1.0,
+ }
+
+
+async def test_handles_no_ipv6(autojump_clock):
+ sock, scenario = await run_scenario(
+ 80,
+ # Here the ipv6 addresses fail at socket creation time, so the connect
+ # configuration doesn't matter
+ [
+ ("::1", 0, "success"),
+ ("2.2.2.2", 10, "success"),
+ ("::3", 0, "success"),
+ ("4.4.4.4", 0.1, "success"),
+ ],
+ happy_eyeballs_delay=1,
+ ipv6_supported=False,
+ )
+ assert sock.ip == "4.4.4.4"
+ assert trio.current_time() == 1 + 0.1
+ assert scenario.connect_times == {
+ "2.2.2.2": 0,
+ "4.4.4.4": 1.0,
+ }
+
+
+async def test_no_hosts(autojump_clock):
+ exc, scenario = await run_scenario(80, [], expect_error=OSError)
+ assert "no results found" in str(exc)
+
+
+async def test_cancel(autojump_clock):
+ with trio.move_on_after(5) as cancel_scope:
+ exc, scenario = await run_scenario(
+ 80,
+ [
+ ("1.1.1.1", 10, "success"),
+ ("2.2.2.2", 10, "success"),
+ ("3.3.3.3", 10, "success"),
+ ("4.4.4.4", 10, "success"),
+ ],
+ expect_error=BaseExceptionGroup,
+ )
+ # What comes out should be 1 or more Cancelled errors that all belong
+ # to this cancel_scope; this is the easiest way to check that
+ raise exc
+ assert cancel_scope.cancelled_caught
+
+ assert trio.current_time() == 5
+
+ # This should have been called already, but just to make sure, since the
+ # exception-handling logic in run_scenario is a bit complicated and the
+ # main thing we care about here is that all the sockets were cleaned up.
+ scenario.check(succeeded=False)
diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_highlevel_open_unix_stream.py b/venv/lib/python3.9/site-packages/trio/tests/test_highlevel_open_unix_stream.py
new file mode 100644
index 00000000..211aff3e
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/trio/tests/test_highlevel_open_unix_stream.py
@@ -0,0 +1,67 @@
+import os
+import socket
+import tempfile
+
+import pytest
+
+from trio import open_unix_socket, Path
+from trio._highlevel_open_unix_stream import close_on_error
+
+if not hasattr(socket, "AF_UNIX"):
+ pytestmark = pytest.mark.skip("Needs unix socket support")
+
+
+def test_close_on_error():
+ class CloseMe:
+ closed = False
+
+ def close(self):
+ self.closed = True
+
+ with close_on_error(CloseMe()) as c:
+ pass
+ assert not c.closed
+
+ with pytest.raises(RuntimeError):
+ with close_on_error(CloseMe()) as c:
+ raise RuntimeError
+ assert c.closed
+
+
+@pytest.mark.parametrize("filename", [4, 4.5])
+async def test_open_with_bad_filename_type(filename):
+ with pytest.raises(TypeError):
+ await open_unix_socket(filename)
+
+
+async def test_open_bad_socket():
+ # mktemp is marked as insecure, but that's okay, we don't want the file to
+ # exist
+ name = tempfile.mktemp()
+ with pytest.raises(FileNotFoundError):
+ await open_unix_socket(name)
+
+
+async def test_open_unix_socket():
+ for name_type in [Path, str]:
+ name = tempfile.mktemp()
+ serv_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+ with serv_sock:
+ serv_sock.bind(name)
+ try:
+ serv_sock.listen(1)
+
+ # The actual function we're testing
+ unix_socket = await open_unix_socket(name_type(name))
+
+ async with unix_socket:
+ client, _ = serv_sock.accept()
+ with client:
+ await unix_socket.send_all(b"test")
+ assert client.recv(2048) == b"test"
+
+ client.sendall(b"response")
+ received = await unix_socket.receive_some(2048)
+ assert received == b"response"
+ finally:
+ os.unlink(name)
diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_highlevel_serve_listeners.py b/venv/lib/python3.9/site-packages/trio/tests/test_highlevel_serve_listeners.py
new file mode 100644
index 00000000..b028092e
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/trio/tests/test_highlevel_serve_listeners.py
@@ -0,0 +1,145 @@
+import pytest
+
+from functools import partial
+import errno
+
+import attr
+
+import trio
+from trio.testing import memory_stream_pair, wait_all_tasks_blocked
+
+
+@attr.s(hash=False, eq=False)
+class MemoryListener(trio.abc.Listener):
+ closed = attr.ib(default=False)
+ accepted_streams = attr.ib(factory=list)
+ queued_streams = attr.ib(factory=(lambda: trio.open_memory_channel(1)))
+ accept_hook = attr.ib(default=None)
+
+ async def connect(self):
+ assert not self.closed
+ client, server = memory_stream_pair()
+ await self.queued_streams[0].send(server)
+ return client
+
+ async def accept(self):
+ await trio.lowlevel.checkpoint()
+ assert not self.closed
+ if self.accept_hook is not None:
+ await self.accept_hook()
+ stream = await self.queued_streams[1].receive()
+ self.accepted_streams.append(stream)
+ return stream
+
+ async def aclose(self):
+ self.closed = True
+ await trio.lowlevel.checkpoint()
+
+
+async def test_serve_listeners_basic():
+ listeners = [MemoryListener(), MemoryListener()]
+
+ record = []
+
+ def close_hook():
+ # Make sure this is a forceful close
+ assert trio.current_effective_deadline() == float("-inf")
+ record.append("closed")
+
+ async def handler(stream):
+ await stream.send_all(b"123")
+ assert await stream.receive_some(10) == b"456"
+ stream.send_stream.close_hook = close_hook
+ stream.receive_stream.close_hook = close_hook
+
+ async def client(listener):
+ s = await listener.connect()
+ assert await s.receive_some(10) == b"123"
+ await s.send_all(b"456")
+
+ async def do_tests(parent_nursery):
+ async with trio.open_nursery() as nursery:
+ for listener in listeners:
+ for _ in range(3):
+ nursery.start_soon(client, listener)
+
+ await wait_all_tasks_blocked()
+
+ # verifies that all 6 streams x 2 directions each were closed ok
+ assert len(record) == 12
+
+ parent_nursery.cancel_scope.cancel()
+
+ async with trio.open_nursery() as nursery:
+ l2 = await nursery.start(trio.serve_listeners, handler, listeners)
+ assert l2 == listeners
+ # This is just split into another function because gh-136 isn't
+ # implemented yet
+ nursery.start_soon(do_tests, nursery)
+
+ for listener in listeners:
+ assert listener.closed
+
+
+async def test_serve_listeners_accept_unrecognized_error():
+ for error in [KeyError(), OSError(errno.ECONNABORTED, "ECONNABORTED")]:
+ listener = MemoryListener()
+
+ async def raise_error():
+ raise error
+
+ listener.accept_hook = raise_error
+
+ with pytest.raises(type(error)) as excinfo:
+ await trio.serve_listeners(None, [listener])
+ assert excinfo.value is error
+
+
+async def test_serve_listeners_accept_capacity_error(autojump_clock, caplog):
+ listener = MemoryListener()
+
+ async def raise_EMFILE():
+ raise OSError(errno.EMFILE, "out of file descriptors")
+
+ listener.accept_hook = raise_EMFILE
+
+ # It retries every 100 ms, so in 950 ms it will retry at 0, 100, ..., 900
+ # = 10 times total
+ with trio.move_on_after(0.950):
+ await trio.serve_listeners(None, [listener])
+
+ assert len(caplog.records) == 10
+ for record in caplog.records:
+ assert "retrying" in record.msg
+ assert record.exc_info[1].errno == errno.EMFILE
+
+
+async def test_serve_listeners_connection_nursery(autojump_clock):
+ listener = MemoryListener()
+
+ async def handler(stream):
+ await trio.sleep(1)
+
+ class Done(Exception):
+ pass
+
+ async def connection_watcher(*, task_status=trio.TASK_STATUS_IGNORED):
+ async with trio.open_nursery() as nursery:
+ task_status.started(nursery)
+ await wait_all_tasks_blocked()
+ assert len(nursery.child_tasks) == 10
+ raise Done
+
+ with pytest.raises(Done):
+ async with trio.open_nursery() as nursery:
+ handler_nursery = await nursery.start(connection_watcher)
+ await nursery.start(
+ partial(
+ trio.serve_listeners,
+ handler,
+ [listener],
+ handler_nursery=handler_nursery,
+ )
+ )
+ for _ in range(10):
+ nursery.start_soon(listener.connect)
diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_highlevel_socket.py b/venv/lib/python3.9/site-packages/trio/tests/test_highlevel_socket.py
new file mode 100644
index 00000000..9dcb834d
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/trio/tests/test_highlevel_socket.py
@@ -0,0 +1,267 @@
+import pytest
+
+import sys
+import socket as stdlib_socket
+import errno
+
+from .. import _core
+from ..testing import (
+ check_half_closeable_stream,
+ wait_all_tasks_blocked,
+ assert_checkpoints,
+)
+from .._highlevel_socket import *
+from .. import socket as tsocket
+
+
+async def test_SocketStream_basics():
+ # stdlib socket bad (even if connected)
+ a, b = stdlib_socket.socketpair()
+ with a, b:
+ with pytest.raises(TypeError):
+ SocketStream(a)
+
+ # DGRAM socket bad
+ with tsocket.socket(type=tsocket.SOCK_DGRAM) as sock:
+ with pytest.raises(ValueError):
+ SocketStream(sock)
+
+ a, b = tsocket.socketpair()
+ with a, b:
+ s = SocketStream(a)
+ assert s.socket is a
+
+ # Use a real, connected socket to test socket options, because
+ # socketpair() might give us a unix socket that doesn't support any of
+ # these options
+ with tsocket.socket() as listen_sock:
+ await listen_sock.bind(("127.0.0.1", 0))
+ listen_sock.listen(1)
+ with tsocket.socket() as client_sock:
+ await client_sock.connect(listen_sock.getsockname())
+
+ s = SocketStream(client_sock)
+
+ # TCP_NODELAY enabled by default
+ assert s.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY)
+ # We can disable it though
+ s.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, False)
+ assert not s.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY)
+
+ b = s.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, 1)
+ assert isinstance(b, bytes)
+
+
+async def test_SocketStream_send_all():
+ BIG = 10000000
+
+ a_sock, b_sock = tsocket.socketpair()
+ with a_sock, b_sock:
+ a = SocketStream(a_sock)
+ b = SocketStream(b_sock)
+
+ # Check a send_all that has to be split into multiple parts (on most
+ # platforms... on Windows every send() either succeeds or fails as a
+ # whole)
+ async def sender():
+ data = bytearray(BIG)
+ await a.send_all(data)
+ # send_all uses memoryviews internally, which temporarily "lock"
+ # the object they view. If it doesn't clean them up properly, then
+ # some bytearray operations might raise an error afterwards, which
+ # would be a pretty weird and annoying side-effect to spring on
+ # users. So test that this doesn't happen, by forcing the
+ # bytearray's underlying buffer to be realloc'ed:
+ data += bytes(BIG)
+ # (Note: the above line of code doesn't do a very good job at
+ # testing anything, because:
+ # - on CPython, the refcount GC generally cleans up memoryviews
+ # for us even if we're sloppy.
+ # - on PyPy3, at least as of 5.7.0, the memoryview code and the
+ # bytearray code conspire so that resizing never fails – if
+ # resizing forces the bytearray's internal buffer to move, then
+ # all memoryview references are automagically updated (!!).
+ # See:
+ # https://gist.github.com/njsmith/0ffd38ec05ad8e34004f34a7dc492227
+ # But I'm leaving the test here in hopes that if this ever changes
+ # and we break our implementation of send_all, then we'll get some
+ # early warning...)
+
+ async def receiver():
+ # Make sure the sender fills up the kernel buffers and blocks
+ await wait_all_tasks_blocked()
+ nbytes = 0
+ while nbytes < BIG:
+ nbytes += len(await b.receive_some(BIG))
+ assert nbytes == BIG
+
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(sender)
+ nursery.start_soon(receiver)
+
+ # We know that we received BIG bytes of NULs so far. Make sure that
+ # was all the data in there.
+ await a.send_all(b"e")
+ assert await b.receive_some(10) == b"e"
+ await a.send_eof()
+ assert await b.receive_some(10) == b""
+
+
+async def fill_stream(s):
+ async def sender():
+ while True:
+ await s.send_all(b"x" * 10000)
+
+ async def waiter(nursery):
+ await wait_all_tasks_blocked()
+ nursery.cancel_scope.cancel()
+
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(sender)
+ nursery.start_soon(waiter, nursery)
+
+
+async def test_SocketStream_generic():
+ async def stream_maker():
+ left, right = tsocket.socketpair()
+ return SocketStream(left), SocketStream(right)
+
+ async def clogged_stream_maker():
+ left, right = await stream_maker()
+ await fill_stream(left)
+ await fill_stream(right)
+ return left, right
+
+ await check_half_closeable_stream(stream_maker, clogged_stream_maker)
+
+
+async def test_SocketListener():
+ # Not a Trio socket
+ with stdlib_socket.socket() as s:
+ s.bind(("127.0.0.1", 0))
+ s.listen(10)
+ with pytest.raises(TypeError):
+ SocketListener(s)
+
+ # Not a SOCK_STREAM
+ with tsocket.socket(type=tsocket.SOCK_DGRAM) as s:
+ await s.bind(("127.0.0.1", 0))
+ with pytest.raises(ValueError) as excinfo:
+ SocketListener(s)
+ excinfo.match(r".*SOCK_STREAM")
+
+ # Didn't call .listen()
+ # macOS has no way to check for this, so skip testing it there.
+ if sys.platform != "darwin":
+ with tsocket.socket() as s:
+ await s.bind(("127.0.0.1", 0))
+ with pytest.raises(ValueError) as excinfo:
+ SocketListener(s)
+ excinfo.match(r".*listen")
+
+ listen_sock = tsocket.socket()
+ await listen_sock.bind(("127.0.0.1", 0))
+ listen_sock.listen(10)
+ listener = SocketListener(listen_sock)
+
+ assert listener.socket is listen_sock
+
+ client_sock = tsocket.socket()
+ await client_sock.connect(listen_sock.getsockname())
+ with assert_checkpoints():
+ server_stream = await listener.accept()
+ assert isinstance(server_stream, SocketStream)
+ assert server_stream.socket.getsockname() == listen_sock.getsockname()
+ assert server_stream.socket.getpeername() == client_sock.getsockname()
+
+ with assert_checkpoints():
+ await listener.aclose()
+
+ with assert_checkpoints():
+ await listener.aclose()
+
+ with assert_checkpoints():
+ with pytest.raises(_core.ClosedResourceError):
+ await listener.accept()
+
+ client_sock.close()
+ await server_stream.aclose()
+
+
+async def test_SocketListener_socket_closed_underfoot():
+ listen_sock = tsocket.socket()
+ await listen_sock.bind(("127.0.0.1", 0))
+ listen_sock.listen(10)
+ listener = SocketListener(listen_sock)
+
+ # Close the socket, not the listener
+ listen_sock.close()
+
+ # SocketListener gives correct error
+ with assert_checkpoints():
+ with pytest.raises(_core.ClosedResourceError):
+ await listener.accept()
+
+
+async def test_SocketListener_accept_errors():
+ class FakeSocket(tsocket.SocketType):
+ def __init__(self, events):
+ self._events = iter(events)
+
+ type = tsocket.SOCK_STREAM
+
+ # Fool the check for SO_ACCEPTCONN in SocketListener.__init__
+ def getsockopt(self, level, opt):
+ return True
+
+ def setsockopt(self, level, opt, value):
+ pass
+
+ async def accept(self):
+ await _core.checkpoint()
+ event = next(self._events)
+ if isinstance(event, BaseException):
+ raise event
+ else:
+ return event, None
+
+ fake_server_sock = FakeSocket([])
+
+ fake_listen_sock = FakeSocket(
+ [
+ OSError(errno.ECONNABORTED, "Connection aborted"),
+ OSError(errno.EPERM, "Permission denied"),
+ OSError(errno.EPROTO, "Bad protocol"),
+ fake_server_sock,
+ OSError(errno.EMFILE, "Out of file descriptors"),
+ OSError(errno.EFAULT, "attempt to write to read-only memory"),
+ OSError(errno.ENOBUFS, "out of buffers"),
+ fake_server_sock,
+ ]
+ )
+
+ l = SocketListener(fake_listen_sock)
+
+ with assert_checkpoints():
+ s = await l.accept()
+ assert s.socket is fake_server_sock
+
+ for code in [errno.EMFILE, errno.EFAULT, errno.ENOBUFS]:
+ with assert_checkpoints():
+ with pytest.raises(OSError) as excinfo:
+ await l.accept()
+ assert excinfo.value.errno == code
+
+ with assert_checkpoints():
+ s = await l.accept()
+ assert s.socket is fake_server_sock
+
+
+async def test_socket_stream_works_when_peer_has_already_closed():
+ sock_a, sock_b = tsocket.socketpair()
+ with sock_a, sock_b:
+ await sock_b.send(b"x")
+ sock_b.close()
+ stream = SocketStream(sock_a)
+ assert await stream.receive_some(1) == b"x"
+ assert await stream.receive_some(1) == b""
diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_highlevel_ssl_helpers.py b/venv/lib/python3.9/site-packages/trio/tests/test_highlevel_ssl_helpers.py
new file mode 100644
index 00000000..c00f5dc4
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/trio/tests/test_highlevel_ssl_helpers.py
@@ -0,0 +1,113 @@
+import pytest
+
+from functools import partial
+
+import attr
+
+import trio
+from trio.socket import AF_INET, SOCK_STREAM, IPPROTO_TCP
+import trio.testing
+from .test_ssl import client_ctx, SERVER_CTX
+
+from .._highlevel_ssl_helpers import (
+ open_ssl_over_tcp_stream,
+ open_ssl_over_tcp_listeners,
+ serve_ssl_over_tcp,
+)
+
+
+async def echo_handler(stream):
+ async with stream:
+ try:
+ while True:
+ data = await stream.receive_some(10000)
+ if not data:
+ break
+ await stream.send_all(data)
+ except trio.BrokenResourceError:
+ pass
+
+
+# Resolver that always returns the given sockaddr, no matter what host/port
+# you ask for.
+@attr.s
+class FakeHostnameResolver(trio.abc.HostnameResolver):
+ sockaddr = attr.ib()
+
+ async def getaddrinfo(self, *args):
+ return [(AF_INET, SOCK_STREAM, IPPROTO_TCP, "", self.sockaddr)]
+
+ async def getnameinfo(self, *args): # pragma: no cover
+ raise NotImplementedError
+
+
+# This uses serve_ssl_over_tcp, which uses open_ssl_over_tcp_listeners...
+# noqa is needed because flake8 doesn't understand how pytest fixtures work.
+async def test_open_ssl_over_tcp_stream_and_everything_else(client_ctx): # noqa: F811
+ async with trio.open_nursery() as nursery:
+ (listener,) = await nursery.start(
+ partial(serve_ssl_over_tcp, echo_handler, 0, SERVER_CTX, host="127.0.0.1")
+ )
+ async with listener:
+ sockaddr = listener.transport_listener.socket.getsockname()
+ hostname_resolver = FakeHostnameResolver(sockaddr)
+ trio.socket.set_custom_hostname_resolver(hostname_resolver)
+
+ # We don't have the right trust set up
+ # (checks that ssl_context=None is doing some validation)
+ stream = await open_ssl_over_tcp_stream("trio-test-1.example.org", 80)
+ async with stream:
+ with pytest.raises(trio.BrokenResourceError):
+ await stream.do_handshake()
+
+ # We have the trust but not the hostname
+ # (checks custom ssl_context + hostname checking)
+ stream = await open_ssl_over_tcp_stream(
+ "xyzzy.example.org", 80, ssl_context=client_ctx
+ )
+ async with stream:
+ with pytest.raises(trio.BrokenResourceError):
+ await stream.do_handshake()
+
+ # This one should work!
+ stream = await open_ssl_over_tcp_stream(
+ "trio-test-1.example.org", 80, ssl_context=client_ctx
+ )
+ async with stream:
+ assert isinstance(stream, trio.SSLStream)
+ assert stream.server_hostname == "trio-test-1.example.org"
+ await stream.send_all(b"x")
+ assert await stream.receive_some(1) == b"x"
+
+ # Check https_compatible settings are being passed through
+ assert not stream._https_compatible
+ stream = await open_ssl_over_tcp_stream(
+ "trio-test-1.example.org",
+ 80,
+ ssl_context=client_ctx,
+ https_compatible=True,
+ # also, smoke test happy_eyeballs_delay
+ happy_eyeballs_delay=1,
+ )
+ async with stream:
+ assert stream._https_compatible
+
+ # Stop the echo server
+ nursery.cancel_scope.cancel()
+
+
+async def test_open_ssl_over_tcp_listeners():
+ (listener,) = await open_ssl_over_tcp_listeners(0, SERVER_CTX, host="127.0.0.1")
+ async with listener:
+ assert isinstance(listener, trio.SSLListener)
+ tl = listener.transport_listener
+ assert isinstance(tl, trio.SocketListener)
+ assert tl.socket.getsockname()[0] == "127.0.0.1"
+
+ assert not listener._https_compatible
+
+ (listener,) = await open_ssl_over_tcp_listeners(
+ 0, SERVER_CTX, host="127.0.0.1", https_compatible=True
+ )
+ async with listener:
+ assert listener._https_compatible
diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_path.py b/venv/lib/python3.9/site-packages/trio/tests/test_path.py
new file mode 100644
index 00000000..284bcf82
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/trio/tests/test_path.py
@@ -0,0 +1,262 @@
+import os
+import pathlib
+
+import pytest
+
+import trio
+from trio._path import AsyncAutoWrapperType as Type
+from trio._file_io import AsyncIOWrapper
+
+
+@pytest.fixture
+def path(tmpdir):
+ p = str(tmpdir.join("test"))
+ return trio.Path(p)
+
+
+def method_pair(path, method_name):
+ path = pathlib.Path(path)
+ async_path = trio.Path(path)
+ return getattr(path, method_name), getattr(async_path, method_name)
+
+
+async def test_open_is_async_context_manager(path):
+ async with await path.open("w") as f:
+ assert isinstance(f, AsyncIOWrapper)
+
+ assert f.closed
+
+
+async def test_magic():
+ path = trio.Path("test")
+
+ assert str(path) == "test"
+ assert bytes(path) == b"test"
+
+
+cls_pairs = [
+ (trio.Path, pathlib.Path),
+ (pathlib.Path, trio.Path),
+ (trio.Path, trio.Path),
+]
+
+
+@pytest.mark.parametrize("cls_a,cls_b", cls_pairs)
+async def test_cmp_magic(cls_a, cls_b):
+ a, b = cls_a(""), cls_b("")
+ assert a == b
+ assert not a != b
+
+ a, b = cls_a("a"), cls_b("b")
+ assert a < b
+ assert b > a
+
+ # this is intentionally testing equivalence with none, due to the
+ # other=sentinel logic in _forward_magic
+ assert not a == None # noqa
+ assert not b == None # noqa
+
+
+# upstream python3.8 bug: we should also test (pathlib.Path, trio.Path), but
+# __*div__ does not properly raise NotImplementedError like the other comparison
+# magic, so trio.Path's implementation does not get dispatched
+cls_pairs = [
+ (trio.Path, pathlib.Path),
+ (trio.Path, trio.Path),
+ (trio.Path, str),
+ (str, trio.Path),
+]
+
+
+@pytest.mark.parametrize("cls_a,cls_b", cls_pairs)
+async def test_div_magic(cls_a, cls_b):
+ a, b = cls_a("a"), cls_b("b")
+
+ result = a / b
+ assert isinstance(result, trio.Path)
+ assert str(result) == os.path.join("a", "b")
+
+
+@pytest.mark.parametrize(
+ "cls_a,cls_b", [(trio.Path, pathlib.Path), (trio.Path, trio.Path)]
+)
+@pytest.mark.parametrize("path", ["foo", "foo/bar/baz", "./foo"])
+async def test_hash_magic(cls_a, cls_b, path):
+ a, b = cls_a(path), cls_b(path)
+ assert hash(a) == hash(b)
+
+
+async def test_forwarded_properties(path):
+ # use `name` as a representative of forwarded properties
+
+ assert "name" in dir(path)
+ assert path.name == "test"
+
+
+async def test_async_method_signature(path):
+ # use `resolve` as a representative of wrapped methods
+
+ assert path.resolve.__name__ == "resolve"
+ assert path.resolve.__qualname__ == "Path.resolve"
+
+ assert "pathlib.Path.resolve" in path.resolve.__doc__
+
+
+@pytest.mark.parametrize("method_name", ["is_dir", "is_file"])
+async def test_compare_async_stat_methods(method_name):
+
+ method, async_method = method_pair(".", method_name)
+
+ result = method()
+ async_result = await async_method()
+
+ assert result == async_result
+
+
+async def test_invalid_name_not_wrapped(path):
+ with pytest.raises(AttributeError):
+ getattr(path, "invalid_fake_attr")
+
+
+@pytest.mark.parametrize("method_name", ["absolute", "resolve"])
+async def test_async_methods_rewrap(method_name):
+
+ method, async_method = method_pair(".", method_name)
+
+ result = method()
+ async_result = await async_method()
+
+ assert isinstance(async_result, trio.Path)
+ assert str(result) == str(async_result)
+
+
+async def test_forward_methods_rewrap(path, tmpdir):
+ with_name = path.with_name("foo")
+ with_suffix = path.with_suffix(".py")
+
+ assert isinstance(with_name, trio.Path)
+ assert with_name == tmpdir.join("foo")
+ assert isinstance(with_suffix, trio.Path)
+ assert with_suffix == tmpdir.join("test.py")
+
+
+async def test_forward_properties_rewrap(path):
+ assert isinstance(path.parent, trio.Path)
+
+
+async def test_forward_methods_without_rewrap(path, tmpdir):
+ path = await path.parent.resolve()
+
+ assert path.as_uri().startswith("file:///")
+
+
+async def test_repr():
+ path = trio.Path(".")
+
+ assert repr(path) == "trio.Path('.')"
+
+
+class MockWrapped:
+ unsupported = "unsupported"
+ _private = "private"
+
+
+class MockWrapper:
+ _forwards = MockWrapped
+ _wraps = MockWrapped
+
+
+async def test_type_forwards_unsupported():
+ with pytest.raises(TypeError):
+ Type.generate_forwards(MockWrapper, {})
+
+
+async def test_type_wraps_unsupported():
+ with pytest.raises(TypeError):
+ Type.generate_wraps(MockWrapper, {})
+
+
+async def test_type_forwards_private():
+ Type.generate_forwards(MockWrapper, {"unsupported": None})
+
+ assert not hasattr(MockWrapper, "_private")
+
+
+async def test_type_wraps_private():
+ Type.generate_wraps(MockWrapper, {"unsupported": None})
+
+ assert not hasattr(MockWrapper, "_private")
+
+
+@pytest.mark.parametrize("meth", [trio.Path.__init__, trio.Path.joinpath])
+async def test_path_wraps_path(path, meth):
+ wrapped = await path.absolute()
+ result = meth(path, wrapped)
+ if result is None:
+ result = path
+
+ assert wrapped == result
+
+
+async def test_path_nonpath():
+ with pytest.raises(TypeError):
+ trio.Path(1)
+
+
+async def test_open_file_can_open_path(path):
+ async with await trio.open_file(path, "w") as f:
+ assert f.name == os.fspath(path)
+
+
+async def test_globmethods(path):
+ # Populate a directory tree
+ await path.mkdir()
+ await (path / "foo").mkdir()
+ await (path / "foo" / "_bar.txt").write_bytes(b"")
+ await (path / "bar.txt").write_bytes(b"")
+ await (path / "bar.dat").write_bytes(b"")
+
+ # Path.glob
+ for _pattern, _results in {
+ "*.txt": {"bar.txt"},
+ "**/*.txt": {"_bar.txt", "bar.txt"},
+ }.items():
+ entries = set()
+ for entry in await path.glob(_pattern):
+ assert isinstance(entry, trio.Path)
+ entries.add(entry.name)
+
+ assert entries == _results
+
+ # Path.rglob
+ entries = set()
+ for entry in await path.rglob("*.txt"):
+ assert isinstance(entry, trio.Path)
+ entries.add(entry.name)
+
+ assert entries == {"_bar.txt", "bar.txt"}
+
+
+async def test_iterdir(path):
+ # Populate a directory
+ await path.mkdir()
+ await (path / "foo").mkdir()
+ await (path / "bar.txt").write_bytes(b"")
+
+ entries = set()
+ for entry in await path.iterdir():
+ assert isinstance(entry, trio.Path)
+ entries.add(entry.name)
+
+ assert entries == {"bar.txt", "foo"}
+
+
+async def test_classmethods():
+ assert isinstance(await trio.Path.home(), trio.Path)
+
+ # pathlib.Path has only two classmethods
+ assert str(await trio.Path.home()) == os.path.expanduser("~")
+ assert str(await trio.Path.cwd()) == os.getcwd()
+
+ # Wrapped method has docstring
+ assert trio.Path.home.__doc__
diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_scheduler_determinism.py b/venv/lib/python3.9/site-packages/trio/tests/test_scheduler_determinism.py
new file mode 100644
index 00000000..e2d3167e
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/trio/tests/test_scheduler_determinism.py
@@ -0,0 +1,40 @@
+import trio
+
+
+async def scheduler_trace():
+ """Returns a scheduler-dependent value we can use to check determinism."""
+ trace = []
+
+ async def tracer(name):
+ for i in range(50):
+ trace.append((name, i))
+ await trio.sleep(0)
+
+ async with trio.open_nursery() as nursery:
+ for i in range(5):
+ nursery.start_soon(tracer, i)
+
+ return tuple(trace)
+
+
+def test_the_trio_scheduler_is_not_deterministic():
+ # At least, not yet. See https://github.com/python-trio/trio/issues/32
+ traces = []
+ for _ in range(10):
+ traces.append(trio.run(scheduler_trace))
+ assert len(set(traces)) == len(traces)
+
+
+def test_the_trio_scheduler_is_deterministic_if_seeded(monkeypatch):
+ monkeypatch.setattr(trio._core._run, "_ALLOW_DETERMINISTIC_SCHEDULING", True)
+ traces = []
+ for _ in range(10):
+ state = trio._core._run._r.getstate()
+ try:
+ trio._core._run._r.seed(0)
+ traces.append(trio.run(scheduler_trace))
+ finally:
+ trio._core._run._r.setstate(state)
+
+ assert len(traces) == 10
+ assert len(set(traces)) == 1
diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_signals.py b/venv/lib/python3.9/site-packages/trio/tests/test_signals.py
new file mode 100644
index 00000000..235772f9
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/trio/tests/test_signals.py
@@ -0,0 +1,177 @@
+import signal
+
+import pytest
+
+import trio
+from .. import _core
+from .._util import signal_raise
+from .._signals import open_signal_receiver, _signal_handler
+
+
+async def test_open_signal_receiver():
+ orig = signal.getsignal(signal.SIGILL)
+ with open_signal_receiver(signal.SIGILL) as receiver:
+ # Raise it a few times, to exercise signal coalescing, both at the
+ # call_soon level and at the SignalQueue level
+ signal_raise(signal.SIGILL)
+ signal_raise(signal.SIGILL)
+ await _core.wait_all_tasks_blocked()
+ signal_raise(signal.SIGILL)
+ await _core.wait_all_tasks_blocked()
+ async for signum in receiver: # pragma: no branch
+ assert signum == signal.SIGILL
+ break
+ assert receiver._pending_signal_count() == 0
+ signal_raise(signal.SIGILL)
+ async for signum in receiver: # pragma: no branch
+ assert signum == signal.SIGILL
+ break
+ assert receiver._pending_signal_count() == 0
+ with pytest.raises(RuntimeError):
+ await receiver.__anext__()
+ assert signal.getsignal(signal.SIGILL) is orig
+
+
+async def test_open_signal_receiver_restore_handler_after_one_bad_signal():
+ orig = signal.getsignal(signal.SIGILL)
+ with pytest.raises(ValueError):
+ with open_signal_receiver(signal.SIGILL, 1234567):
+ pass # pragma: no cover
+ # Still restored even if we errored out
+ assert signal.getsignal(signal.SIGILL) is orig
+
+
+async def test_open_signal_receiver_empty_fail():
+ with pytest.raises(TypeError, match="No signals were provided"):
+ with open_signal_receiver():
+ pass
+
+
+async def test_open_signal_receiver_restore_handler_after_duplicate_signal():
+ orig = signal.getsignal(signal.SIGILL)
+ with open_signal_receiver(signal.SIGILL, signal.SIGILL):
+ pass
+ # Still restored correctly
+ assert signal.getsignal(signal.SIGILL) is orig
+
+
+async def test_catch_signals_wrong_thread():
+ async def naughty():
+ with open_signal_receiver(signal.SIGINT):
+ pass # pragma: no cover
+
+ with pytest.raises(RuntimeError):
+ await trio.to_thread.run_sync(trio.run, naughty)
+
+
+async def test_open_signal_receiver_conflict():
+ with pytest.raises(trio.BusyResourceError):
+ with open_signal_receiver(signal.SIGILL) as receiver:
+ async with trio.open_nursery() as nursery:
+ nursery.start_soon(receiver.__anext__)
+ nursery.start_soon(receiver.__anext__)
+
+
+# Blocks until all previous calls to run_sync_soon(idempotent=True) have been
+# processed.
+async def wait_run_sync_soon_idempotent_queue_barrier():
+ ev = trio.Event()
+ token = _core.current_trio_token()
+ token.run_sync_soon(ev.set, idempotent=True)
+ await ev.wait()
+
+
+async def test_open_signal_receiver_no_starvation():
+ # Set up a situation where there are always 2 pending signals available to
+ # report, and make sure that instead of getting the same signal reported
+ # over and over, it alternates between reporting both of them.
+ with open_signal_receiver(signal.SIGILL, signal.SIGFPE) as receiver:
+ try:
+ print(signal.getsignal(signal.SIGILL))
+ previous = None
+ for _ in range(10):
+ signal_raise(signal.SIGILL)
+ signal_raise(signal.SIGFPE)
+ await wait_run_sync_soon_idempotent_queue_barrier()
+ if previous is None:
+ previous = await receiver.__anext__()
+ else:
+ got = await receiver.__anext__()
+ assert got in [signal.SIGILL, signal.SIGFPE]
+ assert got != previous
+ previous = got
+ # Clear out the last signal so it doesn't get redelivered
+ while receiver._pending_signal_count() != 0:
+ await receiver.__anext__()
+ except: # pragma: no cover
+ # If there's an unhandled exception above, then exiting the
+ # open_signal_receiver block might cause the signal to be
+ # redelivered and give us a core dump instead of a traceback...
+ import traceback
+
+ traceback.print_exc()
+
+
+async def test_catch_signals_race_condition_on_exit():
+ delivered_directly = set()
+
+ def direct_handler(signo, frame):
+ delivered_directly.add(signo)
+
+ print(1)
+ # Test the version where the call_soon *doesn't* have a chance to run
+ # before we exit the with block:
+ with _signal_handler({signal.SIGILL, signal.SIGFPE}, direct_handler):
+ with open_signal_receiver(signal.SIGILL, signal.SIGFPE) as receiver:
+ signal_raise(signal.SIGILL)
+ signal_raise(signal.SIGFPE)
+ await wait_run_sync_soon_idempotent_queue_barrier()
+ assert delivered_directly == {signal.SIGILL, signal.SIGFPE}
+ delivered_directly.clear()
+
+ print(2)
+ # Test the version where the call_soon *does* have a chance to run before
+ # we exit the with block:
+ with _signal_handler({signal.SIGILL, signal.SIGFPE}, direct_handler):
+ with open_signal_receiver(signal.SIGILL, signal.SIGFPE) as receiver:
+ signal_raise(signal.SIGILL)
+ signal_raise(signal.SIGFPE)
+ await wait_run_sync_soon_idempotent_queue_barrier()
+ assert receiver._pending_signal_count() == 2
+ assert delivered_directly == {signal.SIGILL, signal.SIGFPE}
+ delivered_directly.clear()
+
+ # Again, but with a SIG_IGN signal:
+
+ print(3)
+ with _signal_handler({signal.SIGILL}, signal.SIG_IGN):
+ with open_signal_receiver(signal.SIGILL) as receiver:
+ signal_raise(signal.SIGILL)
+ await wait_run_sync_soon_idempotent_queue_barrier()
+ # test passes if the process reaches this point without dying
+
+ print(4)
+ with _signal_handler({signal.SIGILL}, signal.SIG_IGN):
+ with open_signal_receiver(signal.SIGILL) as receiver:
+ signal_raise(signal.SIGILL)
+ await wait_run_sync_soon_idempotent_queue_barrier()
+ assert receiver._pending_signal_count() == 1
+ # test passes if the process reaches this point without dying
+
+ # Check exception chaining if there are multiple exception-raising
+ # handlers
+ def raise_handler(signum, _):
+ raise RuntimeError(signum)
+
+ with _signal_handler({signal.SIGILL, signal.SIGFPE}, raise_handler):
+ with pytest.raises(RuntimeError) as excinfo:
+ with open_signal_receiver(signal.SIGILL, signal.SIGFPE) as receiver:
+ signal_raise(signal.SIGILL)
+ signal_raise(signal.SIGFPE)
+ await wait_run_sync_soon_idempotent_queue_barrier()
+ assert receiver._pending_signal_count() == 2
+ exc = excinfo.value
+ signums = {exc.args[0]}
+ assert isinstance(exc.__context__, RuntimeError)
+ signums.add(exc.__context__.args[0])
+ assert signums == {signal.SIGILL, signal.SIGFPE}
diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_socket.py b/venv/lib/python3.9/site-packages/trio/tests/test_socket.py
new file mode 100644
index 00000000..1fa3721f
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/trio/tests/test_socket.py
@@ -0,0 +1,1017 @@
+import errno
+
+import pytest
+import attr
+
+import os
+import socket as stdlib_socket
+import inspect
+import tempfile
+import sys as _sys
+from .._core.tests.tutil import creates_ipv6, binds_ipv6
+from .. import _core
+from .. import _socket as _tsocket
+from .. import socket as tsocket
+from .._socket import _NUMERIC_ONLY, _try_sync
+from ..testing import assert_checkpoints, wait_all_tasks_blocked
+
+################################################################
+# utils
+################################################################
+
+
+class MonkeypatchedGAI:
+ def __init__(self, orig_getaddrinfo):
+ self._orig_getaddrinfo = orig_getaddrinfo
+ self._responses = {}
+ self.record = []
+
+ # get a normalized getaddrinfo argument tuple
+ def _frozenbind(self, *args, **kwargs):
+ sig = inspect.signature(self._orig_getaddrinfo)
+ bound = sig.bind(*args, **kwargs)
+ bound.apply_defaults()
+ frozenbound = bound.args
+ assert not bound.kwargs
+ return frozenbound
+
+ def set(self, response, *args, **kwargs):
+ self._responses[self._frozenbind(*args, **kwargs)] = response
+
+ def getaddrinfo(self, *args, **kwargs):
+ bound = self._frozenbind(*args, **kwargs)
+ self.record.append(bound)
+ if bound in self._responses:
+ return self._responses[bound]
+ elif bound[-1] & stdlib_socket.AI_NUMERICHOST:
+ return self._orig_getaddrinfo(*args, **kwargs)
+ else:
+ raise RuntimeError("gai called with unexpected arguments {}".format(bound))
+
+
+@pytest.fixture
+def monkeygai(monkeypatch):
+ controller = MonkeypatchedGAI(stdlib_socket.getaddrinfo)
+ monkeypatch.setattr(stdlib_socket, "getaddrinfo", controller.getaddrinfo)
+ return controller
+
+
+async def test__try_sync():
+ with assert_checkpoints():
+ async with _try_sync():
+ pass
+
+ with assert_checkpoints():
+ with pytest.raises(KeyError):
+ async with _try_sync():
+ raise KeyError
+
+ async with _try_sync():
+ raise BlockingIOError
+
+ def _is_ValueError(exc):
+ return isinstance(exc, ValueError)
+
+ async with _try_sync(_is_ValueError):
+ raise ValueError
+
+ with assert_checkpoints():
+ with pytest.raises(BlockingIOError):
+ async with _try_sync(_is_ValueError):
+ raise BlockingIOError
+
+
+################################################################
+# basic re-exports
+################################################################
+
+
+def test_socket_has_some_reexports():
+ assert tsocket.SOL_SOCKET == stdlib_socket.SOL_SOCKET
+ assert tsocket.TCP_NODELAY == stdlib_socket.TCP_NODELAY
+ assert tsocket.gaierror == stdlib_socket.gaierror
+ assert tsocket.ntohs == stdlib_socket.ntohs
+
+
+################################################################
+# name resolution
+################################################################
+
+
+async def test_getaddrinfo(monkeygai):
+ def check(got, expected):
+ # win32 returns 0 for the proto field
+ # musl and glibc have inconsistent handling of the canonical name
+ # field (https://github.com/python-trio/trio/issues/1499)
+ # Neither field gets used much and there isn't much opportunity for us
+ # to mess them up, so we don't bother checking them here
+ def interesting_fields(gai_tup):
+ # (family, type, proto, canonname, sockaddr)
+ family, type, proto, canonname, sockaddr = gai_tup
+ return (family, type, sockaddr)
+
+ def filtered(gai_list):
+ return [interesting_fields(gai_tup) for gai_tup in gai_list]
+
+ assert filtered(got) == filtered(expected)
+
+ # Simple non-blocking non-error cases, ipv4 and ipv6:
+ with assert_checkpoints():
+ res = await tsocket.getaddrinfo("127.0.0.1", "12345", type=tsocket.SOCK_STREAM)
+
+ check(
+ res,
+ [
+ (
+ tsocket.AF_INET, # 127.0.0.1 is ipv4
+ tsocket.SOCK_STREAM,
+ tsocket.IPPROTO_TCP,
+ "",
+ ("127.0.0.1", 12345),
+ ),
+ ],
+ )
+
+ with assert_checkpoints():
+ res = await tsocket.getaddrinfo("::1", "12345", type=tsocket.SOCK_DGRAM)
+ check(
+ res,
+ [
+ (
+ tsocket.AF_INET6,
+ tsocket.SOCK_DGRAM,
+ tsocket.IPPROTO_UDP,
+ "",
+ ("::1", 12345, 0, 0),
+ ),
+ ],
+ )
+
+ monkeygai.set("x", b"host", "port", family=0, type=0, proto=0, flags=0)
+ with assert_checkpoints():
+ res = await tsocket.getaddrinfo("host", "port")
+ assert res == "x"
+ assert monkeygai.record[-1] == (b"host", "port", 0, 0, 0, 0)
+
+ # check raising an error from a non-blocking getaddrinfo
+ with assert_checkpoints():
+ with pytest.raises(tsocket.gaierror) as excinfo:
+ await tsocket.getaddrinfo("::1", "12345", type=-1)
+ # Linux + glibc, Windows
+ expected_errnos = {tsocket.EAI_SOCKTYPE}
+ # Linux + musl
+ expected_errnos.add(tsocket.EAI_SERVICE)
+ # macOS
+ if hasattr(tsocket, "EAI_BADHINTS"):
+ expected_errnos.add(tsocket.EAI_BADHINTS)
+ assert excinfo.value.errno in expected_errnos
+
+ # check raising an error from a blocking getaddrinfo (exploits the fact
+ # that monkeygai raises if it gets a non-numeric request it hasn't been
+ # given an answer for)
+ with assert_checkpoints():
+ with pytest.raises(RuntimeError):
+ await tsocket.getaddrinfo("asdf", "12345")
+
+
+async def test_getnameinfo():
+ # Trivial test:
+ ni_numeric = stdlib_socket.NI_NUMERICHOST | stdlib_socket.NI_NUMERICSERV
+ with assert_checkpoints():
+ got = await tsocket.getnameinfo(("127.0.0.1", 1234), ni_numeric)
+ assert got == ("127.0.0.1", "1234")
+
+ # getnameinfo requires a numeric address as input:
+ with assert_checkpoints():
+ with pytest.raises(tsocket.gaierror):
+ await tsocket.getnameinfo(("google.com", 80), 0)
+
+ with assert_checkpoints():
+ with pytest.raises(tsocket.gaierror):
+ await tsocket.getnameinfo(("localhost", 80), 0)
+
+ # Blocking call to get expected values:
+ host, service = stdlib_socket.getnameinfo(("127.0.0.1", 80), 0)
+
+ # Some working calls:
+ got = await tsocket.getnameinfo(("127.0.0.1", 80), 0)
+ assert got == (host, service)
+
+ got = await tsocket.getnameinfo(("127.0.0.1", 80), tsocket.NI_NUMERICHOST)
+ assert got == ("127.0.0.1", service)
+
+ got = await tsocket.getnameinfo(("127.0.0.1", 80), tsocket.NI_NUMERICSERV)
+ assert got == (host, "80")
+
+
+################################################################
+# constructors
+################################################################
+
+
+async def test_from_stdlib_socket():
+ sa, sb = stdlib_socket.socketpair()
+ assert not isinstance(sa, tsocket.SocketType)
+ with sa, sb:
+ ta = tsocket.from_stdlib_socket(sa)
+ assert isinstance(ta, tsocket.SocketType)
+ assert sa.fileno() == ta.fileno()
+ await ta.send(b"x")
+ assert sb.recv(1) == b"x"
+
+ # rejects other types
+ with pytest.raises(TypeError):
+ tsocket.from_stdlib_socket(1)
+
+ class MySocket(stdlib_socket.socket):
+ pass
+
+ with MySocket() as mysock:
+ with pytest.raises(TypeError):
+ tsocket.from_stdlib_socket(mysock)
+
+
+async def test_from_fd():
+ sa, sb = stdlib_socket.socketpair()
+ ta = tsocket.fromfd(sa.fileno(), sa.family, sa.type, sa.proto)
+ with sa, sb, ta:
+ assert ta.fileno() != sa.fileno()
+ await ta.send(b"x")
+ assert sb.recv(3) == b"x"
+
+
+async def test_socketpair_simple():
+ async def child(sock):
+ print("sending hello")
+ await sock.send(b"h")
+ assert await sock.recv(1) == b"h"
+
+ a, b = tsocket.socketpair()
+ with a, b:
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(child, a)
+ nursery.start_soon(child, b)
+
+
+@pytest.mark.skipif(not hasattr(tsocket, "fromshare"), reason="windows only")
+async def test_fromshare():
+ a, b = tsocket.socketpair()
+ with a, b:
+ # share with ourselves
+ shared = a.share(os.getpid())
+ a2 = tsocket.fromshare(shared)
+ with a2:
+ assert a.fileno() != a2.fileno()
+ await a2.send(b"x")
+ assert await b.recv(1) == b"x"
+
+
+async def test_socket():
+ with tsocket.socket() as s:
+ assert isinstance(s, tsocket.SocketType)
+ assert s.family == tsocket.AF_INET
+
+
+@creates_ipv6
+async def test_socket_v6():
+ with tsocket.socket(tsocket.AF_INET6, tsocket.SOCK_DGRAM) as s:
+ assert isinstance(s, tsocket.SocketType)
+ assert s.family == tsocket.AF_INET6
+
+
+@pytest.mark.skipif(not _sys.platform == "linux", reason="linux only")
+async def test_sniff_sockopts():
+ from socket import AF_INET, AF_INET6, SOCK_DGRAM, SOCK_STREAM
+
+ # generate the combinations of families/types we're testing:
+ sockets = []
+ for family in [AF_INET, AF_INET6]:
+ for type in [SOCK_DGRAM, SOCK_STREAM]:
+ sockets.append(stdlib_socket.socket(family, type))
+ for socket in sockets:
+ # regular Trio socket constructor
+ tsocket_socket = tsocket.socket(fileno=socket.fileno())
+ # check family / type for correctness:
+ assert tsocket_socket.family == socket.family
+ assert tsocket_socket.type == socket.type
+ tsocket_socket.detach()
+
+ # fromfd constructor
+ tsocket_from_fd = tsocket.fromfd(socket.fileno(), AF_INET, SOCK_STREAM)
+ # check family / type for correctness:
+ assert tsocket_from_fd.family == socket.family
+ assert tsocket_from_fd.type == socket.type
+ tsocket_from_fd.close()
+
+ socket.close()
+
+
+################################################################
+# _SocketType
+################################################################
+
+
+async def test_SocketType_basics():
+ sock = tsocket.socket()
+ with sock as cm_enter_value:
+ assert cm_enter_value is sock
+ assert isinstance(sock.fileno(), int)
+ assert not sock.get_inheritable()
+ sock.set_inheritable(True)
+ assert sock.get_inheritable()
+
+ sock.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, False)
+ assert not sock.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY)
+ sock.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, True)
+ assert sock.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY)
+ # closed sockets have fileno() == -1
+ assert sock.fileno() == -1
+
+ # smoke test
+ repr(sock)
+
+ # detach
+ with tsocket.socket() as sock:
+ fd = sock.fileno()
+ assert sock.detach() == fd
+ assert sock.fileno() == -1
+
+ # close
+ sock = tsocket.socket()
+ assert sock.fileno() >= 0
+ sock.close()
+ assert sock.fileno() == -1
+
+ # share was tested above together with fromshare
+
+ # check __dir__
+ assert "family" in dir(sock)
+ assert "recv" in dir(sock)
+ assert "setsockopt" in dir(sock)
+
+ # our __getattr__ handles unknown names
+ with pytest.raises(AttributeError):
+ sock.asdf
+
+ # type family proto
+ stdlib_sock = stdlib_socket.socket()
+ sock = tsocket.from_stdlib_socket(stdlib_sock)
+ assert sock.type == _tsocket.real_socket_type(stdlib_sock.type)
+ assert sock.family == stdlib_sock.family
+ assert sock.proto == stdlib_sock.proto
+ sock.close()
+
+
+async def test_SocketType_dup():
+ a, b = tsocket.socketpair()
+ with a, b:
+ a2 = a.dup()
+ with a2:
+ assert isinstance(a2, tsocket.SocketType)
+ assert a2.fileno() != a.fileno()
+ a.close()
+ await a2.send(b"x")
+ assert await b.recv(1) == b"x"
+
+
+async def test_SocketType_shutdown():
+ a, b = tsocket.socketpair()
+ with a, b:
+ await a.send(b"x")
+ assert await b.recv(1) == b"x"
+ assert not a.did_shutdown_SHUT_WR
+ assert not b.did_shutdown_SHUT_WR
+ a.shutdown(tsocket.SHUT_WR)
+ assert a.did_shutdown_SHUT_WR
+ assert not b.did_shutdown_SHUT_WR
+ assert await b.recv(1) == b""
+ await b.send(b"y")
+ assert await a.recv(1) == b"y"
+
+ a, b = tsocket.socketpair()
+ with a, b:
+ assert not a.did_shutdown_SHUT_WR
+ a.shutdown(tsocket.SHUT_RD)
+ assert not a.did_shutdown_SHUT_WR
+
+ a, b = tsocket.socketpair()
+ with a, b:
+ assert not a.did_shutdown_SHUT_WR
+ a.shutdown(tsocket.SHUT_RDWR)
+ assert a.did_shutdown_SHUT_WR
+
+
+@pytest.mark.parametrize(
+ "address, socket_type",
+ [
+ ("127.0.0.1", tsocket.AF_INET),
+ pytest.param("::1", tsocket.AF_INET6, marks=binds_ipv6),
+ ],
+)
+async def test_SocketType_simple_server(address, socket_type):
+ # listen, bind, accept, connect, getpeername, getsockname
+ listener = tsocket.socket(socket_type)
+ client = tsocket.socket(socket_type)
+ with listener, client:
+ await listener.bind((address, 0))
+ listener.listen(20)
+ addr = listener.getsockname()[:2]
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(client.connect, addr)
+ server, client_addr = await listener.accept()
+ with server:
+ assert client_addr == server.getpeername() == client.getsockname()
+ await server.send(b"x")
+ assert await client.recv(1) == b"x"
+
+
+async def test_SocketType_is_readable():
+ a, b = tsocket.socketpair()
+ with a, b:
+ assert not a.is_readable()
+ await b.send(b"x")
+ await _core.wait_readable(a)
+ assert a.is_readable()
+ assert await a.recv(1) == b"x"
+ assert not a.is_readable()
+
+
+# On some macOS systems, getaddrinfo likes to return V4-mapped addresses even
+# when we *don't* pass AI_V4MAPPED.
+# https://github.com/python-trio/trio/issues/580
+def gai_without_v4mapped_is_buggy(): # pragma: no cover
+ try:
+ stdlib_socket.getaddrinfo("1.2.3.4", 0, family=stdlib_socket.AF_INET6)
+ except stdlib_socket.gaierror:
+ return False
+ else:
+ return True
+
+
+@attr.s
+class Addresses:
+ bind_all = attr.ib()
+ localhost = attr.ib()
+ arbitrary = attr.ib()
+ broadcast = attr.ib()
+
+
+# Direct thorough tests of the implicit resolver helpers
+@pytest.mark.parametrize(
+ "socket_type, addrs",
+ [
+ (
+ tsocket.AF_INET,
+ Addresses(
+ bind_all="0.0.0.0",
+ localhost="127.0.0.1",
+ arbitrary="1.2.3.4",
+ broadcast="255.255.255.255",
+ ),
+ ),
+ pytest.param(
+ tsocket.AF_INET6,
+ Addresses(
+ bind_all="::",
+ localhost="::1",
+ arbitrary="1::2",
+ broadcast="::ffff:255.255.255.255",
+ ),
+ marks=creates_ipv6,
+ ),
+ ],
+)
+async def test_SocketType_resolve(socket_type, addrs):
+ v6 = socket_type == tsocket.AF_INET6
+
+ def pad(addr):
+ if v6:
+ while len(addr) < 4:
+ addr += (0,)
+ return addr
+
+ def assert_eq(actual, expected):
+ assert pad(expected) == pad(actual)
+
+ with tsocket.socket(family=socket_type) as sock:
+ # For some reason the stdlib special-cases "" to pass NULL to
+ # getaddrinfo. They also error out on None, but whatever, None is much
+ # more consistent, so we accept it too.
+ for null in [None, ""]:
+ got = await sock._resolve_address_nocp((null, 80), local=True)
+ assert_eq(got, (addrs.bind_all, 80))
+ got = await sock._resolve_address_nocp((null, 80), local=False)
+ assert_eq(got, (addrs.localhost, 80))
+
+ # AI_PASSIVE only affects the wildcard address, so for everything else
+ # local=True/local=False should work the same:
+ for local in [False, True]:
+
+ async def res(*args):
+ return await sock._resolve_address_nocp(*args, local=local)
+
+ assert_eq(await res((addrs.arbitrary, "http")), (addrs.arbitrary, 80))
+ if v6:
+ # Check handling of different length ipv6 address tuples
+ assert_eq(await res(("1::2", 80)), ("1::2", 80, 0, 0))
+ assert_eq(await res(("1::2", 80, 0)), ("1::2", 80, 0, 0))
+ assert_eq(await res(("1::2", 80, 0, 0)), ("1::2", 80, 0, 0))
+ # Non-zero flowinfo/scopeid get passed through
+ assert_eq(await res(("1::2", 80, 1)), ("1::2", 80, 1, 0))
+ assert_eq(await res(("1::2", 80, 1, 2)), ("1::2", 80, 1, 2))
+
+ # And again with a string port, as a trick to avoid the
+ # already-resolved address fastpath and make sure we call
+ # getaddrinfo
+ assert_eq(await res(("1::2", "80")), ("1::2", 80, 0, 0))
+ assert_eq(await res(("1::2", "80", 0)), ("1::2", 80, 0, 0))
+ assert_eq(await res(("1::2", "80", 0, 0)), ("1::2", 80, 0, 0))
+ assert_eq(await res(("1::2", "80", 1)), ("1::2", 80, 1, 0))
+ assert_eq(await res(("1::2", "80", 1, 2)), ("1::2", 80, 1, 2))
+
+ # V4 mapped addresses resolved if V6ONLY is False
+ sock.setsockopt(tsocket.IPPROTO_IPV6, tsocket.IPV6_V6ONLY, False)
+ assert_eq(await res(("1.2.3.4", "http")), ("::ffff:1.2.3.4", 80))
+
+ # Check the <broadcast> special case, because why not
+ assert_eq(await res(("<broadcast>", 123)), (addrs.broadcast, 123))
+
+ # But not if it's true (at least on systems where getaddrinfo works
+ # correctly)
+ if v6 and not gai_without_v4mapped_is_buggy():
+ sock.setsockopt(tsocket.IPPROTO_IPV6, tsocket.IPV6_V6ONLY, True)
+ with pytest.raises(tsocket.gaierror) as excinfo:
+ await res(("1.2.3.4", 80))
+ # Windows, macOS
+ expected_errnos = {tsocket.EAI_NONAME}
+ # Linux
+ if hasattr(tsocket, "EAI_ADDRFAMILY"):
+ expected_errnos.add(tsocket.EAI_ADDRFAMILY)
+ assert excinfo.value.errno in expected_errnos
+
+ # A family where we know nothing about the addresses, so should just
+ # pass them through. This should work on Linux, which is enough to
+ # smoke test the basic functionality...
+ try:
+ netlink_sock = tsocket.socket(
+ family=tsocket.AF_NETLINK, type=tsocket.SOCK_DGRAM
+ )
+ except (AttributeError, OSError):
+ pass
+ else:
+ assert (
+ await netlink_sock._resolve_address_nocp("asdf", local=local)
+ == "asdf"
+ )
+ netlink_sock.close()
+
+ with pytest.raises(ValueError):
+ await res("1.2.3.4")
+ with pytest.raises(ValueError):
+ await res(("1.2.3.4",))
+ with pytest.raises(ValueError):
+ if v6:
+ await res(("1.2.3.4", 80, 0, 0, 0))
+ else:
+ await res(("1.2.3.4", 80, 0, 0))
+
+
+async def test_SocketType_unresolved_names():
+ with tsocket.socket() as sock:
+ await sock.bind(("localhost", 0))
+ assert sock.getsockname()[0] == "127.0.0.1"
+ sock.listen(10)
+
+ with tsocket.socket() as sock2:
+ await sock2.connect(("localhost", sock.getsockname()[1]))
+ assert sock2.getpeername() == sock.getsockname()
+
+ # check gaierror propagates out
+ with tsocket.socket() as sock:
+ with pytest.raises(tsocket.gaierror):
+ # definitely not a valid request
+ await sock.bind(("1.2:3", -1))
+
+
+# This tests all the complicated paths through _nonblocking_helper, using recv
+# as a stand-in for all the methods that use _nonblocking_helper.
+async def test_SocketType_non_blocking_paths():
+ a, b = stdlib_socket.socketpair()
+ with a, b:
+ ta = tsocket.from_stdlib_socket(a)
+ b.setblocking(False)
+
+ # cancel before even calling
+ b.send(b"1")
+ with _core.CancelScope() as cscope:
+ cscope.cancel()
+ with assert_checkpoints():
+ with pytest.raises(_core.Cancelled):
+ await ta.recv(10)
+ # immediate success (also checks that the previous attempt didn't
+ # actually read anything)
+ with assert_checkpoints():
+ await ta.recv(10) == b"1"
+ # immediate failure
+ with assert_checkpoints():
+ with pytest.raises(TypeError):
+ await ta.recv("haha")
+ # block then succeed
+
+ async def do_successful_blocking_recv():
+ with assert_checkpoints():
+ assert await ta.recv(10) == b"2"
+
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(do_successful_blocking_recv)
+ await wait_all_tasks_blocked()
+ b.send(b"2")
+ # block then cancelled
+
+ async def do_cancelled_blocking_recv():
+ with assert_checkpoints():
+ with pytest.raises(_core.Cancelled):
+ await ta.recv(10)
+
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(do_cancelled_blocking_recv)
+ await wait_all_tasks_blocked()
+ nursery.cancel_scope.cancel()
+ # Okay, here's the trickiest one: we want to exercise the path where
+ # the task is signaled to wake, goes to recv, but then the recv fails,
+ # so it has to go back to sleep and try again. Strategy: have two
+ # tasks waiting on two sockets (to work around the rule against having
+ # two tasks waiting on the same socket), wake them both up at the same
+ # time, and whichever one runs first "steals" the data from the
+ # other:
+ tb = tsocket.from_stdlib_socket(b)
+
+ async def t1():
+ with assert_checkpoints():
+ assert await ta.recv(1) == b"a"
+ with assert_checkpoints():
+ assert await tb.recv(1) == b"b"
+
+ async def t2():
+ with assert_checkpoints():
+ assert await tb.recv(1) == b"b"
+ with assert_checkpoints():
+ assert await ta.recv(1) == b"a"
+
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(t1)
+ nursery.start_soon(t2)
+ await wait_all_tasks_blocked()
+ a.send(b"b")
+ b.send(b"a")
+ await wait_all_tasks_blocked()
+ a.send(b"b")
+ b.send(b"a")
+
+
+# This tests the complicated paths through connect
+async def test_SocketType_connect_paths():
+ with tsocket.socket() as sock:
+ with pytest.raises(ValueError):
+ # Should be a tuple
+ await sock.connect("localhost")
+
+ # cancelled before we start
+ with tsocket.socket() as sock:
+ with _core.CancelScope() as cancel_scope:
+ cancel_scope.cancel()
+ with pytest.raises(_core.Cancelled):
+ await sock.connect(("127.0.0.1", 80))
+
+ # Cancelled in between the connect() call and the connect completing
+ with _core.CancelScope() as cancel_scope:
+ with tsocket.socket() as sock, tsocket.socket() as listener:
+ await listener.bind(("127.0.0.1", 0))
+ listener.listen()
+
+ # Swap in our weird subclass under the trio.socket._SocketType's
+ # nose -- and then swap it back out again before we hit
+ # wait_socket_writable, which insists on a real socket.
+ class CancelSocket(stdlib_socket.socket):
+ def connect(self, *args, **kwargs):
+ cancel_scope.cancel()
+ sock._sock = stdlib_socket.fromfd(
+ self.detach(), self.family, self.type
+ )
+ sock._sock.connect(*args, **kwargs)
+ # If connect *doesn't* raise, then pretend it did
+ raise BlockingIOError # pragma: no cover
+
+ sock._sock.close()
+ sock._sock = CancelSocket()
+
+ with assert_checkpoints():
+ with pytest.raises(_core.Cancelled):
+ await sock.connect(listener.getsockname())
+ assert sock.fileno() == -1
+
+ # Failed connect (hopefully after raising BlockingIOError)
+ with tsocket.socket() as sock:
+ with pytest.raises(OSError):
+ # TCP port 2 is not assigned. Pretty sure nothing will be
+ # listening there. (We used to bind a port and then *not* call
+ # listen() to ensure nothing was listening there, but it turns
+ # out on macOS if you do this it takes 30 seconds for the
+ # connect to fail. Really. Also if you use a non-routable
+ # address. This way fails instantly though. As long as nothing
+ # is listening on port 2.)
+ await sock.connect(("127.0.0.1", 2))
+
+
+async def test_resolve_address_exception_in_connect_closes_socket():
+ # Here we are testing issue 247, any cancellation will leave the socket closed
+ with _core.CancelScope() as cancel_scope:
+ with tsocket.socket() as sock:
+
+ async def _resolve_address_nocp(self, *args, **kwargs):
+ cancel_scope.cancel()
+ await _core.checkpoint()
+
+ sock._resolve_address_nocp = _resolve_address_nocp
+ with assert_checkpoints():
+ with pytest.raises(_core.Cancelled):
+ await sock.connect("")
+ assert sock.fileno() == -1
+
+
+async def test_send_recv_variants():
+ a, b = tsocket.socketpair()
+ with a, b:
+ # recv, including with flags
+ assert await a.send(b"x") == 1
+ assert await b.recv(10, tsocket.MSG_PEEK) == b"x"
+ assert await b.recv(10) == b"x"
+
+ # recv_into
+ await a.send(b"x")
+ buf = bytearray(10)
+ await b.recv_into(buf)
+ assert buf == b"x" + b"\x00" * 9
+
+ if hasattr(a, "sendmsg"):
+ assert await a.sendmsg([b"xxx"], []) == 3
+ assert await b.recv(10) == b"xxx"
+
+ a = tsocket.socket(type=tsocket.SOCK_DGRAM)
+ b = tsocket.socket(type=tsocket.SOCK_DGRAM)
+ with a, b:
+ await a.bind(("127.0.0.1", 0))
+ await b.bind(("127.0.0.1", 0))
+
+ targets = [b.getsockname(), ("localhost", b.getsockname()[1])]
+
+ # recvfrom + sendto, with and without names
+ for target in targets:
+ assert await a.sendto(b"xxx", target) == 3
+ (data, addr) = await b.recvfrom(10)
+ assert data == b"xxx"
+ assert addr == a.getsockname()
+
+ # sendto + flags
+ #
+ # I can't find any flags that send() accepts... on Linux at least
+ # passing MSG_MORE to send_some on a connected UDP socket seems to
+ # just be ignored.
+ #
+ # But there's no MSG_MORE on Windows or macOS. I guess send_some flags
+ # are really not very useful, but at least this tests them a bit.
+ if hasattr(tsocket, "MSG_MORE"):
+ await a.sendto(b"xxx", tsocket.MSG_MORE, b.getsockname())
+ await a.sendto(b"yyy", tsocket.MSG_MORE, b.getsockname())
+ await a.sendto(b"zzz", b.getsockname())
+ (data, addr) = await b.recvfrom(10)
+ assert data == b"xxxyyyzzz"
+ assert addr == a.getsockname()
+
+ # recvfrom_into
+ assert await a.sendto(b"xxx", b.getsockname()) == 3
+ buf = bytearray(10)
+ (nbytes, addr) = await b.recvfrom_into(buf)
+ assert nbytes == 3
+ assert buf == b"xxx" + b"\x00" * 7
+ assert addr == a.getsockname()
+
+ if hasattr(b, "recvmsg"):
+ assert await a.sendto(b"xxx", b.getsockname()) == 3
+ (data, ancdata, msg_flags, addr) = await b.recvmsg(10)
+ assert data == b"xxx"
+ assert ancdata == []
+ assert msg_flags == 0
+ assert addr == a.getsockname()
+
+ if hasattr(b, "recvmsg_into"):
+ assert await a.sendto(b"xyzw", b.getsockname()) == 4
+ buf1 = bytearray(2)
+ buf2 = bytearray(3)
+ ret = await b.recvmsg_into([buf1, buf2])
+ (nbytes, ancdata, msg_flags, addr) = ret
+ assert nbytes == 4
+ assert buf1 == b"xy"
+ assert buf2 == b"zw" + b"\x00"
+ assert ancdata == []
+ assert msg_flags == 0
+ assert addr == a.getsockname()
+
+ if hasattr(a, "sendmsg"):
+ for target in targets:
+ assert await a.sendmsg([b"x", b"yz"], [], 0, target) == 3
+ assert await b.recvfrom(10) == (b"xyz", a.getsockname())
+
+ a = tsocket.socket(type=tsocket.SOCK_DGRAM)
+ b = tsocket.socket(type=tsocket.SOCK_DGRAM)
+ with a, b:
+ await b.bind(("127.0.0.1", 0))
+ await a.connect(b.getsockname())
+ # send on a connected udp socket; each call creates a separate
+ # datagram
+ await a.send(b"xxx")
+ await a.send(b"yyy")
+ assert await b.recv(10) == b"xxx"
+ assert await b.recv(10) == b"yyy"
+
+
+async def test_idna(monkeygai):
+ # This is the encoding for "faß.de", which uses one of the characters that
+ # IDNA 2003 handles incorrectly:
+ monkeygai.set("ok faß.de", b"xn--fa-hia.de", 80)
+ monkeygai.set("ok ::1", "::1", 80, flags=_NUMERIC_ONLY)
+ monkeygai.set("ok ::1", b"::1", 80, flags=_NUMERIC_ONLY)
+ # Some things that should not reach the underlying socket.getaddrinfo:
+ monkeygai.set("bad", "fass.de", 80)
+ # We always call socket.getaddrinfo with bytes objects:
+ monkeygai.set("bad", "xn--fa-hia.de", 80)
+
+ assert "ok ::1" == await tsocket.getaddrinfo("::1", 80)
+ assert "ok ::1" == await tsocket.getaddrinfo(b"::1", 80)
+ assert "ok faß.de" == await tsocket.getaddrinfo("faß.de", 80)
+ assert "ok faß.de" == await tsocket.getaddrinfo("xn--fa-hia.de", 80)
+ assert "ok faß.de" == await tsocket.getaddrinfo(b"xn--fa-hia.de", 80)
+
+
+async def test_getprotobyname():
+ # These are the constants used in IP header fields, so the numeric values
+ # had *better* be stable across systems...
+ assert await tsocket.getprotobyname("udp") == 17
+ assert await tsocket.getprotobyname("tcp") == 6
+
+
+async def test_custom_hostname_resolver(monkeygai):
+ class CustomResolver:
+ async def getaddrinfo(self, host, port, family, type, proto, flags):
+ return ("custom_gai", host, port, family, type, proto, flags)
+
+ async def getnameinfo(self, sockaddr, flags):
+ return ("custom_gni", sockaddr, flags)
+
+ cr = CustomResolver()
+
+ assert tsocket.set_custom_hostname_resolver(cr) is None
+
+ # Check that the arguments are all getting passed through.
+ # We have to use valid calls to avoid making the underlying system
+ # getaddrinfo cranky when it's used for NUMERIC checks.
+ for vals in [
+ (tsocket.AF_INET, 0, 0, 0),
+ (0, tsocket.SOCK_STREAM, 0, 0),
+ (0, 0, tsocket.IPPROTO_TCP, 0),
+ (0, 0, 0, tsocket.AI_CANONNAME),
+ ]:
+ assert await tsocket.getaddrinfo("localhost", "foo", *vals) == (
+ "custom_gai",
+ b"localhost",
+ "foo",
+ *vals,
+ )
+
+ # IDNA encoding is handled before calling the special object
+ got = await tsocket.getaddrinfo("föö", "foo")
+ expected = ("custom_gai", b"xn--f-1gaa", "foo", 0, 0, 0, 0)
+ assert got == expected
+
+ assert await tsocket.getnameinfo("a", 0) == ("custom_gni", "a", 0)
+
+ # We can set it back to None
+ assert tsocket.set_custom_hostname_resolver(None) is cr
+
+ # And now Trio switches back to calling socket.getaddrinfo (specifically
+ # our monkeypatched version of socket.getaddrinfo)
+ monkeygai.set("x", b"host", "port", family=0, type=0, proto=0, flags=0)
+ assert await tsocket.getaddrinfo("host", "port") == "x"
+
+
+async def test_custom_socket_factory():
+ class CustomSocketFactory:
+ def socket(self, family, type, proto):
+ return ("hi", family, type, proto)
+
+ csf = CustomSocketFactory()
+
+ assert tsocket.set_custom_socket_factory(csf) is None
+
+ assert tsocket.socket() == ("hi", tsocket.AF_INET, tsocket.SOCK_STREAM, 0)
+ assert tsocket.socket(1, 2, 3) == ("hi", 1, 2, 3)
+
+ # socket with fileno= doesn't call our custom method
+ fd = stdlib_socket.socket().detach()
+ wrapped = tsocket.socket(fileno=fd)
+ assert hasattr(wrapped, "bind")
+ wrapped.close()
+
+ # Likewise for socketpair
+ a, b = tsocket.socketpair()
+ with a, b:
+ assert hasattr(a, "bind")
+ assert hasattr(b, "bind")
+
+ assert tsocket.set_custom_socket_factory(None) is csf
+
+
+async def test_SocketType_is_abstract():
+ with pytest.raises(TypeError):
+ tsocket.SocketType()
+
+
+@pytest.mark.skipif(not hasattr(tsocket, "AF_UNIX"), reason="no unix domain sockets")
+async def test_unix_domain_socket():
+ # Bind has a special branch to use a thread, since it has to do filesystem
+ # traversal. Maybe connect should too? Not sure.
+
+ async def check_AF_UNIX(path):
+ with tsocket.socket(family=tsocket.AF_UNIX) as lsock:
+ await lsock.bind(path)
+ lsock.listen(10)
+ with tsocket.socket(family=tsocket.AF_UNIX) as csock:
+ await csock.connect(path)
+ ssock, _ = await lsock.accept()
+ with ssock:
+ await csock.send(b"x")
+ assert await ssock.recv(1) == b"x"
+
+ # Can't use tmpdir fixture, because we can exceed the maximum AF_UNIX path
+ # length on macOS.
+ with tempfile.TemporaryDirectory() as tmpdir:
+ path = "{}/sock".format(tmpdir)
+ await check_AF_UNIX(path)
+
+ try:
+ cookie = os.urandom(20).hex().encode("ascii")
+ await check_AF_UNIX(b"\x00trio-test-" + cookie)
+ except FileNotFoundError:
+ # macOS doesn't support abstract filenames with the leading NUL byte
+ pass
+
+
+async def test_interrupted_by_close():
+ a_stdlib, b_stdlib = stdlib_socket.socketpair()
+ with a_stdlib, b_stdlib:
+ a_stdlib.setblocking(False)
+
+ data = b"x" * 99999
+
+ try:
+ while True:
+ a_stdlib.send(data)
+ except BlockingIOError:
+ pass
+
+ a = tsocket.from_stdlib_socket(a_stdlib)
+
+ async def sender():
+ with pytest.raises(_core.ClosedResourceError):
+ await a.send(data)
+
+ async def receiver():
+ with pytest.raises(_core.ClosedResourceError):
+ await a.recv(1)
+
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(sender)
+ nursery.start_soon(receiver)
+ await wait_all_tasks_blocked()
+ a.close()
+
+
+async def test_many_sockets():
+ total = 5000 # Must be more than MAX_AFD_GROUP_SIZE
+ sockets = []
+ for x in range(total // 2):
+ try:
+ a, b = stdlib_socket.socketpair()
+ except OSError as e: # pragma: no cover
+ assert e.errno in (errno.EMFILE, errno.ENFILE)
+ break
+ sockets += [a, b]
+ async with _core.open_nursery() as nursery:
+ for s in sockets:
+ nursery.start_soon(_core.wait_readable, s)
+ await _core.wait_all_tasks_blocked()
+ nursery.cancel_scope.cancel()
+ for sock in sockets:
+ sock.close()
+ if x != total // 2 - 1: # pragma: no cover
+ print(f"Unable to open more than {(x-1)*2} sockets.")
diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_ssl.py b/venv/lib/python3.9/site-packages/trio/tests/test_ssl.py
new file mode 100644
index 00000000..7825e319
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/trio/tests/test_ssl.py
@@ -0,0 +1,1303 @@
+import os
+import re
+import sys
+
+import pytest
+
+import threading
+import socket as stdlib_socket
+import ssl
+from contextlib import contextmanager
+from functools import partial
+
+from OpenSSL import SSL
+import trustme
+from async_generator import asynccontextmanager
+
+import trio
+from .. import _core
+from .._highlevel_socket import SocketStream, SocketListener
+from .._highlevel_generic import aclose_forcefully
+from .._core import ClosedResourceError, BrokenResourceError
+from .._highlevel_open_tcp_stream import open_tcp_stream
+from .. import socket as tsocket
+from .._ssl import SSLStream, SSLListener, NeedHandshakeError, _is_eof
+from .._util import ConflictDetector
+
+from .._core.tests.tutil import slow
+
+from ..testing import (
+ assert_checkpoints,
+ Sequencer,
+ memory_stream_pair,
+ lockstep_stream_pair,
+ check_two_way_stream,
+)
+
+# We have two different kinds of echo server fixtures we use for testing. The
+# first is a real server written using the stdlib ssl module and blocking
+# sockets. It runs in a thread and we talk to it over a real socketpair(), to
+# validate interoperability in a semi-realistic setting.
+#
+# The second is a very weird virtual echo server that lives inside a custom
+# Stream class. It lives entirely inside the Python object space; there are no
+# operating system calls in it at all. No threads, no I/O, nothing. It's
+# 'send_all' call takes encrypted data from a client and feeds it directly into
+# the server-side TLS state engine to decrypt, then takes that data, feeds it
+# back through to get the encrypted response, and returns it from 'receive_some'. This
+# gives us full control and reproducibility. This server is written using
+# PyOpenSSL, so that we can trigger renegotiations on demand. It also allows
+# us to insert random (virtual) delays, to really exercise all the weird paths
+# in SSLStream's state engine.
+#
+# Both present a certificate for "trio-test-1.example.org".
+
+TRIO_TEST_CA = trustme.CA()
+TRIO_TEST_1_CERT = TRIO_TEST_CA.issue_server_cert("trio-test-1.example.org")
+
+SERVER_CTX = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
+if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"):
+ SERVER_CTX.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF
+
+TRIO_TEST_1_CERT.configure_cert(SERVER_CTX)
+
+
+# TLS 1.3 has a lot of changes from previous versions. So we want to run tests
+# with both TLS 1.3, and TLS 1.2.
+# "tls13" means that we're willing to negotiate TLS 1.3. Usually that's
+# what will happen, but the renegotiation tests explicitly force a
+# downgrade on the server side. "tls12" means we refuse to negotiate TLS
+# 1.3, so we'll almost certainly use TLS 1.2.
+@pytest.fixture(scope="module", params=["tls13", "tls12"])
+def client_ctx(request):
+ ctx = ssl.create_default_context()
+
+ if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"):
+ ctx.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF
+
+ TRIO_TEST_CA.configure_trust(ctx)
+ if request.param in ["default", "tls13"]:
+ return ctx
+ elif request.param == "tls12":
+ ctx.maximum_version = ssl.TLSVersion.TLSv1_2
+ return ctx
+ else: # pragma: no cover
+ assert False
+
+
+# The blocking socket server.
+def ssl_echo_serve_sync(sock, *, expect_fail=False):
+ try:
+ wrapped = SERVER_CTX.wrap_socket(
+ sock, server_side=True, suppress_ragged_eofs=False
+ )
+ with wrapped:
+ wrapped.do_handshake()
+ while True:
+ data = wrapped.recv(4096)
+ if not data:
+ # other side has initiated a graceful shutdown; we try to
+ # respond in kind but it's legal for them to have already
+ # gone away.
+ exceptions = (BrokenPipeError, ssl.SSLZeroReturnError)
+ try:
+ wrapped.unwrap()
+ except exceptions:
+ pass
+ except ssl.SSLWantWriteError: # pragma: no cover
+ # Under unclear conditions, CPython sometimes raises
+ # SSLWantWriteError here. This is a bug (bpo-32219),
+ # but it's not our bug. Christian Heimes thinks
+ # it's fixed in 'recent' CPython versions so we fail
+ # the test for those and ignore it for earlier
+ # versions.
+ if (
+ sys.implementation.name != "cpython"
+ or sys.version_info >= (3, 8)
+ ):
+ pytest.fail(
+ "still an issue on recent python versions "
+ "add a comment to "
+ "https://bugs.python.org/issue32219"
+ )
+ return
+ wrapped.sendall(data)
+ # This is an obscure workaround for an openssl bug. In server mode, in
+ # some versions, openssl sends some extra data at the end of do_handshake
+ # that it shouldn't send. Normally this is harmless, but, if the other
+ # side shuts down the connection before it reads that data, it might cause
+ # the OS to report a ECONNREST or even ECONNABORTED (which is just wrong,
+ # since ECONNABORTED is supposed to mean that connect() failed, but what
+ # can you do). In this case the other side did nothing wrong, but there's
+ # no way to recover, so we let it pass, and just cross our fingers its not
+ # hiding any (other) real bugs. For more details see:
+ #
+ # https://github.com/python-trio/trio/issues/1293
+ #
+ # Also, this happens frequently but non-deterministically, so we have to
+ # 'no cover' it to avoid coverage flapping.
+ except (ConnectionResetError, ConnectionAbortedError): # pragma: no cover
+ return
+ except Exception as exc:
+ if expect_fail:
+ print("ssl_echo_serve_sync got error as expected:", exc)
+ else: # pragma: no cover
+ print("ssl_echo_serve_sync got unexpected error:", exc)
+ raise
+ else:
+ if expect_fail: # pragma: no cover
+ raise RuntimeError("failed to fail?")
+ finally:
+ sock.close()
+
+
+# Fixture that gives a raw socket connected to a trio-test-1 echo server
+# (running in a thread). Useful for testing making connections with different
+# SSLContexts.
+@asynccontextmanager
+async def ssl_echo_server_raw(**kwargs):
+ a, b = stdlib_socket.socketpair()
+ async with trio.open_nursery() as nursery:
+ # Exiting the 'with a, b' context manager closes the sockets, which
+ # causes the thread to exit (possibly with an error), which allows the
+ # nursery context manager to exit too.
+ with a, b:
+ nursery.start_soon(
+ trio.to_thread.run_sync, partial(ssl_echo_serve_sync, b, **kwargs)
+ )
+
+ yield SocketStream(tsocket.from_stdlib_socket(a))
+
+
+# Fixture that gives a properly set up SSLStream connected to a trio-test-1
+# echo server (running in a thread)
+@asynccontextmanager
+async def ssl_echo_server(client_ctx, **kwargs):
+ async with ssl_echo_server_raw(**kwargs) as sock:
+ yield SSLStream(sock, client_ctx, server_hostname="trio-test-1.example.org")
+
+
+# The weird in-memory server ... thing.
+# Doesn't inherit from Stream because I left out the methods that we don't
+# actually need.
+class PyOpenSSLEchoStream:
+ def __init__(self, sleeper=None):
+ ctx = SSL.Context(SSL.SSLv23_METHOD)
+ # TLS 1.3 removes renegotiation support. Which is great for them, but
+ # we still have to support versions before that, and that means we
+ # need to test renegotiation support, which means we need to force this
+ # to use a lower version where this test server can trigger
+ # renegotiations. Of course TLS 1.3 support isn't released yet, but
+ # I'm told that this will work once it is. (And once it is we can
+ # remove the pragma: no cover too.) Alternatively, we could switch to
+ # using TLSv1_2_METHOD.
+ #
+ # Discussion: https://github.com/pyca/pyopenssl/issues/624
+
+ # This is the right way, but we can't use it until this PR is in a
+ # released:
+ # https://github.com/pyca/pyopenssl/pull/861
+ #
+ # if hasattr(SSL, "OP_NO_TLSv1_3"):
+ # ctx.set_options(SSL.OP_NO_TLSv1_3)
+ #
+ # Fortunately pyopenssl uses cryptography under the hood, so we can be
+ # confident that they're using the same version of openssl
+ from cryptography.hazmat.bindings.openssl.binding import Binding
+
+ b = Binding()
+ if hasattr(b.lib, "SSL_OP_NO_TLSv1_3"):
+ ctx.set_options(b.lib.SSL_OP_NO_TLSv1_3)
+
+ # Unfortunately there's currently no way to say "use 1.3 or worse", we
+ # can only disable specific versions. And if the two sides start
+ # negotiating 1.4 at some point in the future, it *might* mean that
+ # our tests silently stop working properly. So the next line is a
+ # tripwire to remind us we need to revisit this stuff in 5 years or
+ # whatever when the next TLS version is released:
+ assert not hasattr(SSL, "OP_NO_TLSv1_4")
+ TRIO_TEST_1_CERT.configure_cert(ctx)
+ self._conn = SSL.Connection(ctx, None)
+ self._conn.set_accept_state()
+ self._lot = _core.ParkingLot()
+ self._pending_cleartext = bytearray()
+
+ self._send_all_conflict_detector = ConflictDetector(
+ "simultaneous calls to PyOpenSSLEchoStream.send_all"
+ )
+ self._receive_some_conflict_detector = ConflictDetector(
+ "simultaneous calls to PyOpenSSLEchoStream.receive_some"
+ )
+
+ if sleeper is None:
+
+ async def no_op_sleeper(_):
+ return
+
+ self.sleeper = no_op_sleeper
+ else:
+ self.sleeper = sleeper
+
+ async def aclose(self):
+ self._conn.bio_shutdown()
+
+ def renegotiate_pending(self):
+ return self._conn.renegotiate_pending()
+
+ def renegotiate(self):
+ # Returns false if a renegotiation is already in progress, meaning
+ # nothing happens.
+ assert self._conn.renegotiate()
+
+ async def wait_send_all_might_not_block(self):
+ with self._send_all_conflict_detector:
+ await _core.checkpoint()
+ await _core.checkpoint()
+ await self.sleeper("wait_send_all_might_not_block")
+
+ async def send_all(self, data):
+ print(" --> transport_stream.send_all")
+ with self._send_all_conflict_detector:
+ await _core.checkpoint()
+ await _core.checkpoint()
+ await self.sleeper("send_all")
+ self._conn.bio_write(data)
+ while True:
+ await self.sleeper("send_all")
+ try:
+ data = self._conn.recv(1)
+ except SSL.ZeroReturnError:
+ self._conn.shutdown()
+ print("renegotiations:", self._conn.total_renegotiations())
+ break
+ except SSL.WantReadError:
+ break
+ else:
+ self._pending_cleartext += data
+ self._lot.unpark_all()
+ await self.sleeper("send_all")
+ print(" <-- transport_stream.send_all finished")
+
+ async def receive_some(self, nbytes=None):
+ print(" --> transport_stream.receive_some")
+ if nbytes is None:
+ nbytes = 65536 # arbitrary
+ with self._receive_some_conflict_detector:
+ try:
+ await _core.checkpoint()
+ await _core.checkpoint()
+ while True:
+ await self.sleeper("receive_some")
+ try:
+ return self._conn.bio_read(nbytes)
+ except SSL.WantReadError:
+ # No data in our ciphertext buffer; try to generate
+ # some.
+ if self._pending_cleartext:
+ # We have some cleartext; maybe we can encrypt it
+ # and then return it.
+ print(" trying", self._pending_cleartext)
+ try:
+ # PyOpenSSL bug: doesn't accept bytearray
+ # https://github.com/pyca/pyopenssl/issues/621
+ next_byte = self._pending_cleartext[0:1]
+ self._conn.send(bytes(next_byte))
+ # Apparently this next bit never gets hit in the
+ # test suite, but it's not an interesting omission
+ # so let's pragma it.
+ except SSL.WantReadError: # pragma: no cover
+ # We didn't manage to send the cleartext (and
+ # in particular we better leave it there to
+ # try again, due to openssl's retry
+ # semantics), but it's possible we pushed a
+ # renegotiation forward and *now* we have data
+ # to send.
+ try:
+ return self._conn.bio_read(nbytes)
+ except SSL.WantReadError:
+ # Nope. We're just going to have to wait
+ # for someone to call send_all() to give
+ # use more data.
+ print("parking (a)")
+ await self._lot.park()
+ else:
+ # We successfully sent that byte, so we don't
+ # have to again.
+ del self._pending_cleartext[0:1]
+ else:
+ # no pending cleartext; nothing to do but wait for
+ # someone to call send_all
+ print("parking (b)")
+ await self._lot.park()
+ finally:
+ await self.sleeper("receive_some")
+ print(" <-- transport_stream.receive_some finished")
+
+
+async def test_PyOpenSSLEchoStream_gives_resource_busy_errors():
+ # Make sure that PyOpenSSLEchoStream complains if two tasks call send_all
+ # at the same time, or ditto for receive_some. The tricky cases where SSLStream
+ # might accidentally do this are during renegotiation, which we test using
+ # PyOpenSSLEchoStream, so this makes sure that if we do have a bug then
+ # PyOpenSSLEchoStream will notice and complain.
+
+ s = PyOpenSSLEchoStream()
+ with pytest.raises(_core.BusyResourceError) as excinfo:
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(s.send_all, b"x")
+ nursery.start_soon(s.send_all, b"x")
+ assert "simultaneous" in str(excinfo.value)
+
+ s = PyOpenSSLEchoStream()
+ with pytest.raises(_core.BusyResourceError) as excinfo:
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(s.send_all, b"x")
+ nursery.start_soon(s.wait_send_all_might_not_block)
+ assert "simultaneous" in str(excinfo.value)
+
+ s = PyOpenSSLEchoStream()
+ with pytest.raises(_core.BusyResourceError) as excinfo:
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(s.wait_send_all_might_not_block)
+ nursery.start_soon(s.wait_send_all_might_not_block)
+ assert "simultaneous" in str(excinfo.value)
+
+ s = PyOpenSSLEchoStream()
+ with pytest.raises(_core.BusyResourceError) as excinfo:
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(s.receive_some, 1)
+ nursery.start_soon(s.receive_some, 1)
+ assert "simultaneous" in str(excinfo.value)
+
+
+@contextmanager
+def virtual_ssl_echo_server(client_ctx, **kwargs):
+ fakesock = PyOpenSSLEchoStream(**kwargs)
+ yield SSLStream(fakesock, client_ctx, server_hostname="trio-test-1.example.org")
+
+
+def ssl_wrap_pair(
+ client_ctx,
+ client_transport,
+ server_transport,
+ *,
+ client_kwargs={},
+ server_kwargs={},
+):
+ client_ssl = SSLStream(
+ client_transport,
+ client_ctx,
+ server_hostname="trio-test-1.example.org",
+ **client_kwargs,
+ )
+ server_ssl = SSLStream(
+ server_transport, SERVER_CTX, server_side=True, **server_kwargs
+ )
+ return client_ssl, server_ssl
+
+
+def ssl_memory_stream_pair(client_ctx, **kwargs):
+ client_transport, server_transport = memory_stream_pair()
+ return ssl_wrap_pair(client_ctx, client_transport, server_transport, **kwargs)
+
+
+def ssl_lockstep_stream_pair(client_ctx, **kwargs):
+ client_transport, server_transport = lockstep_stream_pair()
+ return ssl_wrap_pair(client_ctx, client_transport, server_transport, **kwargs)
+
+
+# Simple smoke test for handshake/send/receive/shutdown talking to a
+# synchronous server, plus make sure that we do the bare minimum of
+# certificate checking (even though this is really Python's responsibility)
+async def test_ssl_client_basics(client_ctx):
+ # Everything OK
+ async with ssl_echo_server(client_ctx) as s:
+ assert not s.server_side
+ await s.send_all(b"x")
+ assert await s.receive_some(1) == b"x"
+ await s.aclose()
+
+ # Didn't configure the CA file, should fail
+ async with ssl_echo_server_raw(expect_fail=True) as sock:
+ bad_client_ctx = ssl.create_default_context()
+ s = SSLStream(sock, bad_client_ctx, server_hostname="trio-test-1.example.org")
+ assert not s.server_side
+ with pytest.raises(BrokenResourceError) as excinfo:
+ await s.send_all(b"x")
+ assert isinstance(excinfo.value.__cause__, ssl.SSLError)
+
+ # Trusted CA, but wrong host name
+ async with ssl_echo_server_raw(expect_fail=True) as sock:
+ s = SSLStream(sock, client_ctx, server_hostname="trio-test-2.example.org")
+ assert not s.server_side
+ with pytest.raises(BrokenResourceError) as excinfo:
+ await s.send_all(b"x")
+ assert isinstance(excinfo.value.__cause__, ssl.CertificateError)
+
+
+async def test_ssl_server_basics(client_ctx):
+ a, b = stdlib_socket.socketpair()
+ with a, b:
+ server_sock = tsocket.from_stdlib_socket(b)
+ server_transport = SSLStream(
+ SocketStream(server_sock), SERVER_CTX, server_side=True
+ )
+ assert server_transport.server_side
+
+ def client():
+ with client_ctx.wrap_socket(
+ a, server_hostname="trio-test-1.example.org"
+ ) as client_sock:
+ client_sock.sendall(b"x")
+ assert client_sock.recv(1) == b"y"
+ client_sock.sendall(b"z")
+ client_sock.unwrap()
+
+ t = threading.Thread(target=client)
+ t.start()
+
+ assert await server_transport.receive_some(1) == b"x"
+ await server_transport.send_all(b"y")
+ assert await server_transport.receive_some(1) == b"z"
+ assert await server_transport.receive_some(1) == b""
+ await server_transport.aclose()
+
+ t.join()
+
+
+async def test_attributes(client_ctx):
+ async with ssl_echo_server_raw(expect_fail=True) as sock:
+ good_ctx = client_ctx
+ bad_ctx = ssl.create_default_context()
+ s = SSLStream(sock, good_ctx, server_hostname="trio-test-1.example.org")
+
+ assert s.transport_stream is sock
+
+ # Forwarded attribute getting
+ assert s.context is good_ctx
+ assert s.server_side == False # noqa
+ assert s.server_hostname == "trio-test-1.example.org"
+ with pytest.raises(AttributeError):
+ s.asfdasdfsa
+
+ # __dir__
+ assert "transport_stream" in dir(s)
+ assert "context" in dir(s)
+
+ # Setting the attribute goes through to the underlying object
+
+ # most attributes on SSLObject are read-only
+ with pytest.raises(AttributeError):
+ s.server_side = True
+ with pytest.raises(AttributeError):
+ s.server_hostname = "asdf"
+
+ # but .context is *not*. Check that we forward attribute setting by
+ # making sure that after we set the bad context our handshake indeed
+ # fails:
+ s.context = bad_ctx
+ assert s.context is bad_ctx
+ with pytest.raises(BrokenResourceError) as excinfo:
+ await s.do_handshake()
+ assert isinstance(excinfo.value.__cause__, ssl.SSLError)
+
+
+# Note: this test fails horribly if we force TLS 1.2 and trigger a
+# renegotiation at the beginning (e.g. by switching to the pyopenssl
+# server). Usually the client crashes in SSLObject.write with "UNEXPECTED
+# RECORD"; sometimes we get something more exotic like a SyscallError. This is
+# odd because openssl isn't doing any syscalls, but so it goes. After lots of
+# websearching I'm pretty sure this is due to a bug in OpenSSL, where it just
+# can't reliably handle full-duplex communication combined with
+# renegotiation. Nice, eh?
+#
+# https://rt.openssl.org/Ticket/Display.html?id=3712
+# https://rt.openssl.org/Ticket/Display.html?id=2481
+# http://openssl.6102.n7.nabble.com/TLS-renegotiation-failure-on-receiving-application-data-during-handshake-td48127.html
+# https://stackoverflow.com/questions/18728355/ssl-renegotiation-with-full-duplex-socket-communication
+#
+# In some variants of this test (maybe only against the java server?) I've
+# also seen cases where our send_all blocks waiting to write, and then our receive_some
+# also blocks waiting to write, and they never wake up again. It looks like
+# some kind of deadlock. I suspect there may be an issue where we've filled up
+# the send buffers, and the remote side is trying to handle the renegotiation
+# from inside a write() call, so it has a problem: there's all this application
+# data clogging up the pipe, but it can't process and return it to the
+# application because it's in write(), and it doesn't want to buffer infinite
+# amounts of data, and... actually I guess those are the only two choices.
+#
+# NSS even documents that you shouldn't try to do a renegotiation except when
+# the connection is idle:
+#
+# https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/SSL_functions/sslfnc.html#1061582
+#
+# I begin to see why HTTP/2 forbids renegotiation and TLS 1.3 removes it...
+
+
+async def test_full_duplex_basics(client_ctx):
+ CHUNKS = 30
+ CHUNK_SIZE = 32768
+ EXPECTED = CHUNKS * CHUNK_SIZE
+
+ sent = bytearray()
+ received = bytearray()
+
+ async def sender(s):
+ nonlocal sent
+ for i in range(CHUNKS):
+ print(i)
+ chunk = bytes([i] * CHUNK_SIZE)
+ sent += chunk
+ await s.send_all(chunk)
+
+ async def receiver(s):
+ nonlocal received
+ while len(received) < EXPECTED:
+ chunk = await s.receive_some(CHUNK_SIZE // 2)
+ received += chunk
+
+ async with ssl_echo_server(client_ctx) as s:
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(sender, s)
+ nursery.start_soon(receiver, s)
+ # And let's have some doing handshakes too, everyone
+ # simultaneously
+ nursery.start_soon(s.do_handshake)
+ nursery.start_soon(s.do_handshake)
+
+ await s.aclose()
+
+ assert len(sent) == len(received) == EXPECTED
+ assert sent == received
+
+
+async def test_renegotiation_simple(client_ctx):
+ with virtual_ssl_echo_server(client_ctx) as s:
+ await s.do_handshake()
+
+ s.transport_stream.renegotiate()
+ await s.send_all(b"a")
+ assert await s.receive_some(1) == b"a"
+
+ # Have to send some more data back and forth to make sure the
+ # renegotiation is finished before shutting down the
+ # connection... otherwise openssl raises an error. I think this is a
+ # bug in openssl but what can ya do.
+ await s.send_all(b"b")
+ assert await s.receive_some(1) == b"b"
+
+ await s.aclose()
+
+
+@slow
+async def test_renegotiation_randomized(mock_clock, client_ctx):
+ # The only blocking things in this function are our random sleeps, so 0 is
+ # a good threshold.
+ mock_clock.autojump_threshold = 0
+
+ import random
+
+ r = random.Random(0)
+
+ async def sleeper(_):
+ await trio.sleep(r.uniform(0, 10))
+
+ async def clear():
+ while s.transport_stream.renegotiate_pending():
+ with assert_checkpoints():
+ await send(b"-")
+ with assert_checkpoints():
+ await expect(b"-")
+ print("-- clear --")
+
+ async def send(byte):
+ await s.transport_stream.sleeper("outer send")
+ print("calling SSLStream.send_all", byte)
+ with assert_checkpoints():
+ await s.send_all(byte)
+
+ async def expect(expected):
+ await s.transport_stream.sleeper("expect")
+ print("calling SSLStream.receive_some, expecting", expected)
+ assert len(expected) == 1
+ with assert_checkpoints():
+ assert await s.receive_some(1) == expected
+
+ with virtual_ssl_echo_server(client_ctx, sleeper=sleeper) as s:
+ await s.do_handshake()
+
+ await send(b"a")
+ s.transport_stream.renegotiate()
+ await expect(b"a")
+
+ await clear()
+
+ for i in range(100):
+ b1 = bytes([i % 0xFF])
+ b2 = bytes([(2 * i) % 0xFF])
+ s.transport_stream.renegotiate()
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(send, b1)
+ nursery.start_soon(expect, b1)
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(expect, b2)
+ nursery.start_soon(send, b2)
+ await clear()
+
+ for i in range(100):
+ b1 = bytes([i % 0xFF])
+ b2 = bytes([(2 * i) % 0xFF])
+ await send(b1)
+ s.transport_stream.renegotiate()
+ await expect(b1)
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(expect, b2)
+ nursery.start_soon(send, b2)
+ await clear()
+
+ # Checking that wait_send_all_might_not_block and receive_some don't
+ # conflict:
+
+ # 1) Set up a situation where expect (receive_some) is blocked sending,
+ # and wait_send_all_might_not_block comes in.
+
+ # Our receive_some() call will get stuck when it hits send_all
+ async def sleeper_with_slow_send_all(method):
+ if method == "send_all":
+ await trio.sleep(100000)
+
+ # And our wait_send_all_might_not_block call will give it time to get
+ # stuck, and then start
+ async def sleep_then_wait_writable():
+ await trio.sleep(1000)
+ await s.wait_send_all_might_not_block()
+
+ with virtual_ssl_echo_server(client_ctx, sleeper=sleeper_with_slow_send_all) as s:
+ await send(b"x")
+ s.transport_stream.renegotiate()
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(expect, b"x")
+ nursery.start_soon(sleep_then_wait_writable)
+
+ await clear()
+
+ await s.aclose()
+
+ # 2) Same, but now wait_send_all_might_not_block is stuck when
+ # receive_some tries to send.
+
+ async def sleeper_with_slow_wait_writable_and_expect(method):
+ if method == "wait_send_all_might_not_block":
+ await trio.sleep(100000)
+ elif method == "expect":
+ await trio.sleep(1000)
+
+ with virtual_ssl_echo_server(
+ client_ctx, sleeper=sleeper_with_slow_wait_writable_and_expect
+ ) as s:
+ await send(b"x")
+ s.transport_stream.renegotiate()
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(expect, b"x")
+ nursery.start_soon(s.wait_send_all_might_not_block)
+
+ await clear()
+
+ await s.aclose()
+
+
+async def test_resource_busy_errors(client_ctx):
+ async def do_send_all():
+ with assert_checkpoints():
+ await s.send_all(b"x")
+
+ async def do_receive_some():
+ with assert_checkpoints():
+ await s.receive_some(1)
+
+ async def do_wait_send_all_might_not_block():
+ with assert_checkpoints():
+ await s.wait_send_all_might_not_block()
+
+ s, _ = ssl_lockstep_stream_pair(client_ctx)
+ with pytest.raises(_core.BusyResourceError) as excinfo:
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(do_send_all)
+ nursery.start_soon(do_send_all)
+ assert "another task" in str(excinfo.value)
+
+ s, _ = ssl_lockstep_stream_pair(client_ctx)
+ with pytest.raises(_core.BusyResourceError) as excinfo:
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(do_receive_some)
+ nursery.start_soon(do_receive_some)
+ assert "another task" in str(excinfo.value)
+
+ s, _ = ssl_lockstep_stream_pair(client_ctx)
+ with pytest.raises(_core.BusyResourceError) as excinfo:
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(do_send_all)
+ nursery.start_soon(do_wait_send_all_might_not_block)
+ assert "another task" in str(excinfo.value)
+
+ s, _ = ssl_lockstep_stream_pair(client_ctx)
+ with pytest.raises(_core.BusyResourceError) as excinfo:
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(do_wait_send_all_might_not_block)
+ nursery.start_soon(do_wait_send_all_might_not_block)
+ assert "another task" in str(excinfo.value)
+
+
+async def test_wait_writable_calls_underlying_wait_writable():
+ record = []
+
+ class NotAStream:
+ async def wait_send_all_might_not_block(self):
+ record.append("ok")
+
+ ctx = ssl.create_default_context()
+ s = SSLStream(NotAStream(), ctx, server_hostname="x")
+ await s.wait_send_all_might_not_block()
+ assert record == ["ok"]
+
+
+@pytest.mark.skipif(
+ os.name == "nt" and sys.version_info >= (3, 10),
+ reason="frequently fails on Windows + Python 3.10",
+)
+async def test_checkpoints(client_ctx):
+ async with ssl_echo_server(client_ctx) as s:
+ with assert_checkpoints():
+ await s.do_handshake()
+ with assert_checkpoints():
+ await s.do_handshake()
+ with assert_checkpoints():
+ await s.wait_send_all_might_not_block()
+ with assert_checkpoints():
+ await s.send_all(b"xxx")
+ with assert_checkpoints():
+ await s.receive_some(1)
+ # These receive_some's in theory could return immediately, because the
+ # "xxx" was sent in a single record and after the first
+ # receive_some(1) the rest are sitting inside the SSLObject's internal
+ # buffers.
+ with assert_checkpoints():
+ await s.receive_some(1)
+ with assert_checkpoints():
+ await s.receive_some(1)
+ with assert_checkpoints():
+ await s.unwrap()
+
+ async with ssl_echo_server(client_ctx) as s:
+ await s.do_handshake()
+ with assert_checkpoints():
+ await s.aclose()
+
+
+async def test_send_all_empty_string(client_ctx):
+ async with ssl_echo_server(client_ctx) as s:
+ await s.do_handshake()
+
+ # underlying SSLObject interprets writing b"" as indicating an EOF,
+ # for some reason. Make sure we don't inherit this.
+ with assert_checkpoints():
+ await s.send_all(b"")
+ with assert_checkpoints():
+ await s.send_all(b"")
+ await s.send_all(b"x")
+ assert await s.receive_some(1) == b"x"
+
+ await s.aclose()
+
+
+@pytest.mark.parametrize("https_compatible", [False, True])
+async def test_SSLStream_generic(client_ctx, https_compatible):
+ async def stream_maker():
+ return ssl_memory_stream_pair(
+ client_ctx,
+ client_kwargs={"https_compatible": https_compatible},
+ server_kwargs={"https_compatible": https_compatible},
+ )
+
+ async def clogged_stream_maker():
+ client, server = ssl_lockstep_stream_pair(client_ctx)
+ # If we don't do handshakes up front, then we run into a problem in
+ # the following situation:
+ # - server does wait_send_all_might_not_block
+ # - client does receive_some to unclog it
+ # Then the client's receive_some will actually send some data to start
+ # the handshake, and itself get stuck.
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(client.do_handshake)
+ nursery.start_soon(server.do_handshake)
+ return client, server
+
+ await check_two_way_stream(stream_maker, clogged_stream_maker)
+
+
+async def test_unwrap(client_ctx):
+ client_ssl, server_ssl = ssl_memory_stream_pair(client_ctx)
+ client_transport = client_ssl.transport_stream
+ server_transport = server_ssl.transport_stream
+
+ seq = Sequencer()
+
+ async def client():
+ await client_ssl.do_handshake()
+ await client_ssl.send_all(b"x")
+ assert await client_ssl.receive_some(1) == b"y"
+ await client_ssl.send_all(b"z")
+
+ # After sending that, disable outgoing data from our end, to make
+ # sure the server doesn't see our EOF until after we've sent some
+ # trailing data
+ async with seq(0):
+ send_all_hook = client_transport.send_stream.send_all_hook
+ client_transport.send_stream.send_all_hook = None
+
+ assert await client_ssl.receive_some(1) == b""
+ assert client_ssl.transport_stream is client_transport
+ # We just received EOF. Unwrap the connection and send some more.
+ raw, trailing = await client_ssl.unwrap()
+ assert raw is client_transport
+ assert trailing == b""
+ assert client_ssl.transport_stream is None
+ await raw.send_all(b"trailing")
+
+ # Reconnect the streams. Now the server will receive both our shutdown
+ # acknowledgement + the trailing data in a single lump.
+ client_transport.send_stream.send_all_hook = send_all_hook
+ await client_transport.send_stream.send_all_hook()
+
+ async def server():
+ await server_ssl.do_handshake()
+ assert await server_ssl.receive_some(1) == b"x"
+ await server_ssl.send_all(b"y")
+ assert await server_ssl.receive_some(1) == b"z"
+ # Now client is blocked waiting for us to send something, but
+ # instead we close the TLS connection (with sequencer to make sure
+ # that the client won't see and automatically respond before we've had
+ # a chance to disable the client->server transport)
+ async with seq(1):
+ raw, trailing = await server_ssl.unwrap()
+ assert raw is server_transport
+ assert trailing == b"trailing"
+ assert server_ssl.transport_stream is None
+
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(client)
+ nursery.start_soon(server)
+
+
+async def test_closing_nice_case(client_ctx):
+ # the nice case: graceful closes all around
+
+ client_ssl, server_ssl = ssl_memory_stream_pair(client_ctx)
+ client_transport = client_ssl.transport_stream
+
+ # Both the handshake and the close require back-and-forth discussion, so
+ # we need to run them concurrently
+ async def client_closer():
+ with assert_checkpoints():
+ await client_ssl.aclose()
+
+ async def server_closer():
+ assert await server_ssl.receive_some(10) == b""
+ assert await server_ssl.receive_some(10) == b""
+ with assert_checkpoints():
+ await server_ssl.aclose()
+
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(client_closer)
+ nursery.start_soon(server_closer)
+
+ # closing the SSLStream also closes its transport
+ with pytest.raises(ClosedResourceError):
+ await client_transport.send_all(b"123")
+
+ # once closed, it's OK to close again
+ with assert_checkpoints():
+ await client_ssl.aclose()
+ with assert_checkpoints():
+ await client_ssl.aclose()
+
+ # Trying to send more data does not work
+ with pytest.raises(ClosedResourceError):
+ await server_ssl.send_all(b"123")
+
+ # And once the connection is has been closed *locally*, then instead of
+ # getting empty bytestrings we get a proper error
+ with pytest.raises(ClosedResourceError):
+ await client_ssl.receive_some(10) == b""
+
+ with pytest.raises(ClosedResourceError):
+ await client_ssl.unwrap()
+
+ with pytest.raises(ClosedResourceError):
+ await client_ssl.do_handshake()
+
+ # Check that a graceful close *before* handshaking gives a clean EOF on
+ # the other side
+ client_ssl, server_ssl = ssl_memory_stream_pair(client_ctx)
+
+ async def expect_eof_server():
+ with assert_checkpoints():
+ assert await server_ssl.receive_some(10) == b""
+ with assert_checkpoints():
+ await server_ssl.aclose()
+
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(client_ssl.aclose)
+ nursery.start_soon(expect_eof_server)
+
+
+async def test_send_all_fails_in_the_middle(client_ctx):
+ client, server = ssl_memory_stream_pair(client_ctx)
+
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(client.do_handshake)
+ nursery.start_soon(server.do_handshake)
+
+ async def bad_hook():
+ raise KeyError
+
+ client.transport_stream.send_stream.send_all_hook = bad_hook
+
+ with pytest.raises(KeyError):
+ await client.send_all(b"x")
+
+ with pytest.raises(BrokenResourceError):
+ await client.wait_send_all_might_not_block()
+
+ closed = 0
+
+ def close_hook():
+ nonlocal closed
+ closed += 1
+
+ client.transport_stream.send_stream.close_hook = close_hook
+ client.transport_stream.receive_stream.close_hook = close_hook
+ await client.aclose()
+
+ assert closed == 2
+
+
+async def test_ssl_over_ssl(client_ctx):
+ client_0, server_0 = memory_stream_pair()
+
+ client_1 = SSLStream(
+ client_0, client_ctx, server_hostname="trio-test-1.example.org"
+ )
+ server_1 = SSLStream(server_0, SERVER_CTX, server_side=True)
+
+ client_2 = SSLStream(
+ client_1, client_ctx, server_hostname="trio-test-1.example.org"
+ )
+ server_2 = SSLStream(server_1, SERVER_CTX, server_side=True)
+
+ async def client():
+ await client_2.send_all(b"hi")
+ assert await client_2.receive_some(10) == b"bye"
+
+ async def server():
+ assert await server_2.receive_some(10) == b"hi"
+ await server_2.send_all(b"bye")
+
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(client)
+ nursery.start_soon(server)
+
+
+async def test_ssl_bad_shutdown(client_ctx):
+ client, server = ssl_memory_stream_pair(client_ctx)
+
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(client.do_handshake)
+ nursery.start_soon(server.do_handshake)
+
+ await trio.aclose_forcefully(client)
+ # now the server sees a broken stream
+ with pytest.raises(BrokenResourceError):
+ await server.receive_some(10)
+ with pytest.raises(BrokenResourceError):
+ await server.send_all(b"x" * 10)
+
+ await server.aclose()
+
+
+async def test_ssl_bad_shutdown_but_its_ok(client_ctx):
+ client, server = ssl_memory_stream_pair(
+ client_ctx,
+ server_kwargs={"https_compatible": True},
+ client_kwargs={"https_compatible": True},
+ )
+
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(client.do_handshake)
+ nursery.start_soon(server.do_handshake)
+
+ await trio.aclose_forcefully(client)
+ # the server sees that as a clean shutdown
+ assert await server.receive_some(10) == b""
+ with pytest.raises(BrokenResourceError):
+ await server.send_all(b"x" * 10)
+
+ await server.aclose()
+
+
+async def test_ssl_handshake_failure_during_aclose():
+ # Weird scenario: aclose() triggers an automatic handshake, and this
+ # fails. This also exercises a bit of code in aclose() that was otherwise
+ # uncovered, for re-raising exceptions after calling aclose_forcefully on
+ # the underlying transport.
+ async with ssl_echo_server_raw(expect_fail=True) as sock:
+ # Don't configure trust correctly
+ client_ctx = ssl.create_default_context()
+ s = SSLStream(sock, client_ctx, server_hostname="trio-test-1.example.org")
+ # It's a little unclear here whether aclose should swallow the error
+ # or let it escape. We *do* swallow the error if it arrives when we're
+ # sending close_notify, because both sides closing the connection
+ # simultaneously is allowed. But I guess when https_compatible=False
+ # then it's bad if we can get through a whole connection with a peer
+ # that has no valid certificate, and never raise an error.
+ with pytest.raises(BrokenResourceError):
+ await s.aclose()
+
+
+async def test_ssl_only_closes_stream_once(client_ctx):
+ # We used to have a bug where if transport_stream.aclose() raised an
+ # error, we would call it again. This checks that that's fixed.
+ client, server = ssl_memory_stream_pair(client_ctx)
+
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(client.do_handshake)
+ nursery.start_soon(server.do_handshake)
+
+ client_orig_close_hook = client.transport_stream.send_stream.close_hook
+ transport_close_count = 0
+
+ def close_hook():
+ nonlocal transport_close_count
+ client_orig_close_hook()
+ transport_close_count += 1
+ raise KeyError
+
+ client.transport_stream.send_stream.close_hook = close_hook
+
+ with pytest.raises(KeyError):
+ await client.aclose()
+ assert transport_close_count == 1
+
+
+async def test_ssl_https_compatibility_disagreement(client_ctx):
+ client, server = ssl_memory_stream_pair(
+ client_ctx,
+ server_kwargs={"https_compatible": False},
+ client_kwargs={"https_compatible": True},
+ )
+
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(client.do_handshake)
+ nursery.start_soon(server.do_handshake)
+
+ # client is in HTTPS-mode, server is not
+ # so client doing graceful_shutdown causes an error on server
+ async def receive_and_expect_error():
+ with pytest.raises(BrokenResourceError) as excinfo:
+ await server.receive_some(10)
+
+ assert _is_eof(excinfo.value.__cause__)
+
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(client.aclose)
+ nursery.start_soon(receive_and_expect_error)
+
+
+async def test_https_mode_eof_before_handshake(client_ctx):
+ client, server = ssl_memory_stream_pair(
+ client_ctx,
+ server_kwargs={"https_compatible": True},
+ client_kwargs={"https_compatible": True},
+ )
+
+ async def server_expect_clean_eof():
+ assert await server.receive_some(10) == b""
+
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(client.aclose)
+ nursery.start_soon(server_expect_clean_eof)
+
+
+async def test_send_error_during_handshake(client_ctx):
+ client, server = ssl_memory_stream_pair(client_ctx)
+
+ async def bad_hook():
+ raise KeyError
+
+ client.transport_stream.send_stream.send_all_hook = bad_hook
+
+ with pytest.raises(KeyError):
+ with assert_checkpoints():
+ await client.do_handshake()
+
+ with pytest.raises(BrokenResourceError):
+ with assert_checkpoints():
+ await client.do_handshake()
+
+
+async def test_receive_error_during_handshake(client_ctx):
+ client, server = ssl_memory_stream_pair(client_ctx)
+
+ async def bad_hook():
+ raise KeyError
+
+ client.transport_stream.receive_stream.receive_some_hook = bad_hook
+
+ async def client_side(cancel_scope):
+ with pytest.raises(KeyError):
+ with assert_checkpoints():
+ await client.do_handshake()
+ cancel_scope.cancel()
+
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(client_side, nursery.cancel_scope)
+ nursery.start_soon(server.do_handshake)
+
+ with pytest.raises(BrokenResourceError):
+ with assert_checkpoints():
+ await client.do_handshake()
+
+
+async def test_selected_alpn_protocol_before_handshake(client_ctx):
+ client, server = ssl_memory_stream_pair(client_ctx)
+
+ with pytest.raises(NeedHandshakeError):
+ client.selected_alpn_protocol()
+
+ with pytest.raises(NeedHandshakeError):
+ server.selected_alpn_protocol()
+
+
+async def test_selected_alpn_protocol_when_not_set(client_ctx):
+ # ALPN protocol still returns None when it's not set,
+ # instead of raising an exception
+ client, server = ssl_memory_stream_pair(client_ctx)
+
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(client.do_handshake)
+ nursery.start_soon(server.do_handshake)
+
+ assert client.selected_alpn_protocol() is None
+ assert server.selected_alpn_protocol() is None
+
+ assert client.selected_alpn_protocol() == server.selected_alpn_protocol()
+
+
+async def test_selected_npn_protocol_before_handshake(client_ctx):
+ client, server = ssl_memory_stream_pair(client_ctx)
+
+ with pytest.raises(NeedHandshakeError):
+ client.selected_npn_protocol()
+
+ with pytest.raises(NeedHandshakeError):
+ server.selected_npn_protocol()
+
+
+@pytest.mark.filterwarnings(
+ r"ignore: ssl module. NPN is deprecated, use ALPN instead:UserWarning",
+ r"ignore:ssl NPN is deprecated, use ALPN instead:DeprecationWarning",
+)
+async def test_selected_npn_protocol_when_not_set(client_ctx):
+ # NPN protocol still returns None when it's not set,
+ # instead of raising an exception
+ client, server = ssl_memory_stream_pair(client_ctx)
+
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(client.do_handshake)
+ nursery.start_soon(server.do_handshake)
+
+ assert client.selected_npn_protocol() is None
+ assert server.selected_npn_protocol() is None
+
+ assert client.selected_npn_protocol() == server.selected_npn_protocol()
+
+
+async def test_get_channel_binding_before_handshake(client_ctx):
+ client, server = ssl_memory_stream_pair(client_ctx)
+
+ with pytest.raises(NeedHandshakeError):
+ client.get_channel_binding()
+
+ with pytest.raises(NeedHandshakeError):
+ server.get_channel_binding()
+
+
+async def test_get_channel_binding_after_handshake(client_ctx):
+ client, server = ssl_memory_stream_pair(client_ctx)
+
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(client.do_handshake)
+ nursery.start_soon(server.do_handshake)
+
+ assert client.get_channel_binding() is not None
+ assert server.get_channel_binding() is not None
+
+ assert client.get_channel_binding() == server.get_channel_binding()
+
+
+async def test_getpeercert(client_ctx):
+ # Make sure we're not affected by https://bugs.python.org/issue29334
+ client, server = ssl_memory_stream_pair(client_ctx)
+
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(client.do_handshake)
+ nursery.start_soon(server.do_handshake)
+
+ assert server.getpeercert() is None
+ print(client.getpeercert())
+ assert ("DNS", "trio-test-1.example.org") in client.getpeercert()["subjectAltName"]
+
+
+async def test_SSLListener(client_ctx):
+ async def setup(**kwargs):
+ listen_sock = tsocket.socket()
+ await listen_sock.bind(("127.0.0.1", 0))
+ listen_sock.listen(1)
+ socket_listener = SocketListener(listen_sock)
+ ssl_listener = SSLListener(socket_listener, SERVER_CTX, **kwargs)
+
+ transport_client = await open_tcp_stream(*listen_sock.getsockname())
+ ssl_client = SSLStream(
+ transport_client, client_ctx, server_hostname="trio-test-1.example.org"
+ )
+ return listen_sock, ssl_listener, ssl_client
+
+ listen_sock, ssl_listener, ssl_client = await setup()
+
+ async with ssl_client:
+ ssl_server = await ssl_listener.accept()
+
+ async with ssl_server:
+ assert not ssl_server._https_compatible
+
+ # Make sure the connection works
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(ssl_client.do_handshake)
+ nursery.start_soon(ssl_server.do_handshake)
+
+ # Test SSLListener.aclose
+ await ssl_listener.aclose()
+ assert listen_sock.fileno() == -1
+
+ ################
+
+ # Test https_compatible
+ _, ssl_listener, ssl_client = await setup(https_compatible=True)
+
+ ssl_server = await ssl_listener.accept()
+
+ assert ssl_server._https_compatible
+
+ await aclose_forcefully(ssl_listener)
+ await aclose_forcefully(ssl_client)
+ await aclose_forcefully(ssl_server)
diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_subprocess.py b/venv/lib/python3.9/site-packages/trio/tests/test_subprocess.py
new file mode 100644
index 00000000..061a7151
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/trio/tests/test_subprocess.py
@@ -0,0 +1,602 @@
+import os
+import random
+import signal
+import subprocess
+import sys
+from functools import partial
+from pathlib import Path as SyncPath
+
+import pytest
+from async_generator import asynccontextmanager
+
+from .. import (
+ ClosedResourceError,
+ Event,
+ Process,
+ _core,
+ fail_after,
+ move_on_after,
+ run_process,
+ sleep,
+ sleep_forever,
+)
+from .._core.tests.tutil import skip_if_fbsd_pipes_broken, slow
+from ..lowlevel import open_process
+from ..testing import assert_no_checkpoints, wait_all_tasks_blocked
+
+posix = os.name == "posix"
+if posix:
+ from signal import SIGKILL, SIGTERM, SIGUSR1
+else:
+ SIGKILL, SIGTERM, SIGUSR1 = None, None, None
+
+
+# Since Windows has very few command-line utilities generally available,
+# all of our subprocesses are Python processes running short bits of
+# (mostly) cross-platform code.
+def python(code):
+ return [sys.executable, "-u", "-c", "import sys; " + code]
+
+
+EXIT_TRUE = python("sys.exit(0)")
+EXIT_FALSE = python("sys.exit(1)")
+CAT = python("sys.stdout.buffer.write(sys.stdin.buffer.read())")
+
+if posix:
+ SLEEP = lambda seconds: ["/bin/sleep", str(seconds)]
+else:
+ SLEEP = lambda seconds: python("import time; time.sleep({})".format(seconds))
+
+
+def got_signal(proc, sig):
+ if posix:
+ return proc.returncode == -sig
+ else:
+ return proc.returncode != 0
+
+
+@asynccontextmanager
+async def open_process_then_kill(*args, **kwargs):
+ proc = await open_process(*args, **kwargs)
+ try:
+ yield proc
+ finally:
+ proc.kill()
+ await proc.wait()
+
+
+@asynccontextmanager
+async def run_process_in_nursery(*args, **kwargs):
+ async with _core.open_nursery() as nursery:
+ kwargs.setdefault("check", False)
+ proc = await nursery.start(partial(run_process, *args, **kwargs))
+ yield proc
+ nursery.cancel_scope.cancel()
+
+
+background_process_param = pytest.mark.parametrize(
+ "background_process",
+ [open_process_then_kill, run_process_in_nursery],
+ ids=["open_process", "run_process in nursery"],
+)
+
+
+@background_process_param
+async def test_basic(background_process):
+ async with background_process(EXIT_TRUE) as proc:
+ await proc.wait()
+ assert isinstance(proc, Process)
+ assert proc._pidfd is None
+ assert proc.returncode == 0
+ assert repr(proc) == f"<trio.Process {EXIT_TRUE}: exited with status 0>"
+
+ async with background_process(EXIT_FALSE) as proc:
+ await proc.wait()
+ assert proc.returncode == 1
+ assert repr(proc) == "<trio.Process {!r}: {}>".format(
+ EXIT_FALSE, "exited with status 1"
+ )
+
+
+@background_process_param
+async def test_auto_update_returncode(background_process):
+ async with background_process(SLEEP(9999)) as p:
+ assert p.returncode is None
+ assert "running" in repr(p)
+ p.kill()
+ p._proc.wait()
+ assert p.returncode is not None
+ assert "exited" in repr(p)
+ assert p._pidfd is None
+ assert p.returncode is not None
+
+
+@background_process_param
+async def test_multi_wait(background_process):
+ async with background_process(SLEEP(10)) as proc:
+ # Check that wait (including multi-wait) tolerates being cancelled
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(proc.wait)
+ nursery.start_soon(proc.wait)
+ nursery.start_soon(proc.wait)
+ await wait_all_tasks_blocked()
+ nursery.cancel_scope.cancel()
+
+ # Now try waiting for real
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(proc.wait)
+ nursery.start_soon(proc.wait)
+ nursery.start_soon(proc.wait)
+ await wait_all_tasks_blocked()
+ proc.kill()
+
+
+# Test for deprecated 'async with process:' semantics
+async def test_async_with_basics_deprecated(recwarn):
+ async with await open_process(
+ CAT, stdin=subprocess.PIPE, stdout=subprocess.PIPE
+ ) as proc:
+ pass
+ assert proc.returncode is not None
+ with pytest.raises(ClosedResourceError):
+ await proc.stdin.send_all(b"x")
+ with pytest.raises(ClosedResourceError):
+ await proc.stdout.receive_some()
+
+
+# Test for deprecated 'async with process:' semantics
+async def test_kill_when_context_cancelled(recwarn):
+ with move_on_after(100) as scope:
+ async with await open_process(SLEEP(10)) as proc:
+ assert proc.poll() is None
+ scope.cancel()
+ await sleep_forever()
+ assert scope.cancelled_caught
+ assert got_signal(proc, SIGKILL)
+ assert repr(proc) == "<trio.Process {!r}: {}>".format(
+ SLEEP(10), "exited with signal 9" if posix else "exited with status 1"
+ )
+
+
+COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR = python(
+ "data = sys.stdin.buffer.read(); "
+ "sys.stdout.buffer.write(data); "
+ "sys.stderr.buffer.write(data[::-1])"
+)
+
+
+@background_process_param
+async def test_pipes(background_process):
+ async with background_process(
+ COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR,
+ stdin=subprocess.PIPE,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ ) as proc:
+ msg = b"the quick brown fox jumps over the lazy dog"
+
+ async def feed_input():
+ await proc.stdin.send_all(msg)
+ await proc.stdin.aclose()
+
+ async def check_output(stream, expected):
+ seen = bytearray()
+ async for chunk in stream:
+ seen += chunk
+ assert seen == expected
+
+ async with _core.open_nursery() as nursery:
+ # fail eventually if something is broken
+ nursery.cancel_scope.deadline = _core.current_time() + 30.0
+ nursery.start_soon(feed_input)
+ nursery.start_soon(check_output, proc.stdout, msg)
+ nursery.start_soon(check_output, proc.stderr, msg[::-1])
+
+ assert not nursery.cancel_scope.cancelled_caught
+ assert 0 == await proc.wait()
+
+
+@background_process_param
+async def test_interactive(background_process):
+ # Test some back-and-forth with a subprocess. This one works like so:
+ # in: 32\n
+ # out: 0000...0000\n (32 zeroes)
+ # err: 1111...1111\n (64 ones)
+ # in: 10\n
+ # out: 2222222222\n (10 twos)
+ # err: 3333....3333\n (20 threes)
+ # in: EOF
+ # out: EOF
+ # err: EOF
+
+ async with background_process(
+ python(
+ "idx = 0\n"
+ "while True:\n"
+ " line = sys.stdin.readline()\n"
+ " if line == '': break\n"
+ " request = int(line.strip())\n"
+ " print(str(idx * 2) * request)\n"
+ " print(str(idx * 2 + 1) * request * 2, file=sys.stderr)\n"
+ " idx += 1\n"
+ ),
+ stdin=subprocess.PIPE,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ ) as proc:
+
+ newline = b"\n" if posix else b"\r\n"
+
+ async def expect(idx, request):
+ async with _core.open_nursery() as nursery:
+
+ async def drain_one(stream, count, digit):
+ while count > 0:
+ result = await stream.receive_some(count)
+ assert result == (
+ "{}".format(digit).encode("utf-8") * len(result)
+ )
+ count -= len(result)
+ assert count == 0
+ assert await stream.receive_some(len(newline)) == newline
+
+ nursery.start_soon(drain_one, proc.stdout, request, idx * 2)
+ nursery.start_soon(drain_one, proc.stderr, request * 2, idx * 2 + 1)
+
+ with fail_after(5):
+ await proc.stdin.send_all(b"12")
+ await sleep(0.1)
+ await proc.stdin.send_all(b"345" + newline)
+ await expect(0, 12345)
+ await proc.stdin.send_all(b"100" + newline + b"200" + newline)
+ await expect(1, 100)
+ await expect(2, 200)
+ await proc.stdin.send_all(b"0" + newline)
+ await expect(3, 0)
+ await proc.stdin.send_all(b"999999")
+ with move_on_after(0.1) as scope:
+ await expect(4, 0)
+ assert scope.cancelled_caught
+ await proc.stdin.send_all(newline)
+ await expect(4, 999999)
+ await proc.stdin.aclose()
+ assert await proc.stdout.receive_some(1) == b""
+ assert await proc.stderr.receive_some(1) == b""
+ await proc.wait()
+
+ assert proc.returncode == 0
+
+
+async def test_run():
+ data = bytes(random.randint(0, 255) for _ in range(2**18))
+
+ result = await run_process(
+ CAT, stdin=data, capture_stdout=True, capture_stderr=True
+ )
+ assert result.args == CAT
+ assert result.returncode == 0
+ assert result.stdout == data
+ assert result.stderr == b""
+
+ result = await run_process(CAT, capture_stdout=True)
+ assert result.args == CAT
+ assert result.returncode == 0
+ assert result.stdout == b""
+ assert result.stderr is None
+
+ result = await run_process(
+ COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR,
+ stdin=data,
+ capture_stdout=True,
+ capture_stderr=True,
+ )
+ assert result.args == COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR
+ assert result.returncode == 0
+ assert result.stdout == data
+ assert result.stderr == data[::-1]
+
+ # invalid combinations
+ with pytest.raises(UnicodeError):
+ await run_process(CAT, stdin="oh no, it's text")
+ with pytest.raises(ValueError):
+ await run_process(CAT, stdin=subprocess.PIPE)
+ with pytest.raises(ValueError):
+ await run_process(CAT, stdout=subprocess.PIPE)
+ with pytest.raises(ValueError):
+ await run_process(CAT, stderr=subprocess.PIPE)
+ with pytest.raises(ValueError):
+ await run_process(CAT, capture_stdout=True, stdout=subprocess.DEVNULL)
+ with pytest.raises(ValueError):
+ await run_process(CAT, capture_stderr=True, stderr=None)
+
+
+async def test_run_check():
+ cmd = python("sys.stderr.buffer.write(b'test\\n'); sys.exit(1)")
+ with pytest.raises(subprocess.CalledProcessError) as excinfo:
+ await run_process(cmd, stdin=subprocess.DEVNULL, capture_stderr=True)
+ assert excinfo.value.cmd == cmd
+ assert excinfo.value.returncode == 1
+ assert excinfo.value.stderr == b"test\n"
+ assert excinfo.value.stdout is None
+
+ result = await run_process(
+ cmd, capture_stdout=True, capture_stderr=True, check=False
+ )
+ assert result.args == cmd
+ assert result.stdout == b""
+ assert result.stderr == b"test\n"
+ assert result.returncode == 1
+
+
+@skip_if_fbsd_pipes_broken
+async def test_run_with_broken_pipe():
+ result = await run_process(
+ [sys.executable, "-c", "import sys; sys.stdin.close()"], stdin=b"x" * 131072
+ )
+ assert result.returncode == 0
+ assert result.stdout is result.stderr is None
+
+
+@background_process_param
+async def test_stderr_stdout(background_process):
+ async with background_process(
+ COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR,
+ stdin=subprocess.PIPE,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ ) as proc:
+ assert proc.stdout is not None
+ assert proc.stderr is None
+ await proc.stdio.send_all(b"1234")
+ await proc.stdio.send_eof()
+
+ output = []
+ while True:
+ chunk = await proc.stdio.receive_some(16)
+ if chunk == b"":
+ break
+ output.append(chunk)
+ assert b"".join(output) == b"12344321"
+ assert proc.returncode == 0
+
+ # equivalent test with run_process()
+ result = await run_process(
+ COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR,
+ stdin=b"1234",
+ capture_stdout=True,
+ stderr=subprocess.STDOUT,
+ )
+ assert result.returncode == 0
+ assert result.stdout == b"12344321"
+ assert result.stderr is None
+
+ # this one hits the branch where stderr=STDOUT but stdout
+ # is not redirected
+ async with background_process(
+ CAT, stdin=subprocess.PIPE, stderr=subprocess.STDOUT
+ ) as proc:
+ assert proc.stdout is None
+ assert proc.stderr is None
+ await proc.stdin.aclose()
+ await proc.wait()
+ assert proc.returncode == 0
+
+ if posix:
+ try:
+ r, w = os.pipe()
+
+ async with background_process(
+ COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR,
+ stdin=subprocess.PIPE,
+ stdout=w,
+ stderr=subprocess.STDOUT,
+ ) as proc:
+ os.close(w)
+ assert proc.stdio is None
+ assert proc.stdout is None
+ assert proc.stderr is None
+ await proc.stdin.send_all(b"1234")
+ await proc.stdin.aclose()
+ assert await proc.wait() == 0
+ assert os.read(r, 4096) == b"12344321"
+ assert os.read(r, 4096) == b""
+ finally:
+ os.close(r)
+
+
+async def test_errors():
+ with pytest.raises(TypeError) as excinfo:
+ await open_process(["ls"], encoding="utf-8")
+ assert "unbuffered byte streams" in str(excinfo.value)
+ assert "the 'encoding' option is not supported" in str(excinfo.value)
+
+ if posix:
+ with pytest.raises(TypeError) as excinfo:
+ await open_process(["ls"], shell=True)
+ with pytest.raises(TypeError) as excinfo:
+ await open_process("ls", shell=False)
+
+
+@background_process_param
+async def test_signals(background_process):
+ async def test_one_signal(send_it, signum):
+ with move_on_after(1.0) as scope:
+ async with background_process(SLEEP(3600)) as proc:
+ send_it(proc)
+ await proc.wait()
+ assert not scope.cancelled_caught
+ if posix:
+ assert proc.returncode == -signum
+ else:
+ assert proc.returncode != 0
+
+ await test_one_signal(Process.kill, SIGKILL)
+ await test_one_signal(Process.terminate, SIGTERM)
+ # Test that we can send arbitrary signals.
+ #
+ # We used to use SIGINT here, but it turns out that the Python interpreter
+ # has race conditions that can cause it to explode in weird ways if it
+ # tries to handle SIGINT during startup. SIGUSR1's default disposition is
+ # to terminate the target process, and Python doesn't try to do anything
+ # clever to handle it.
+ if posix:
+ await test_one_signal(lambda proc: proc.send_signal(SIGUSR1), SIGUSR1)
+
+
+@pytest.mark.skipif(not posix, reason="POSIX specific")
+@background_process_param
+async def test_wait_reapable_fails(background_process):
+ old_sigchld = signal.signal(signal.SIGCHLD, signal.SIG_IGN)
+ try:
+ # With SIGCHLD disabled, the wait() syscall will wait for the
+ # process to exit but then fail with ECHILD. Make sure we
+ # support this case as the stdlib subprocess module does.
+ async with background_process(SLEEP(3600)) as proc:
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(proc.wait)
+ await wait_all_tasks_blocked()
+ proc.kill()
+ nursery.cancel_scope.deadline = _core.current_time() + 1.0
+ assert not nursery.cancel_scope.cancelled_caught
+ assert proc.returncode == 0 # exit status unknowable, so...
+ finally:
+ signal.signal(signal.SIGCHLD, old_sigchld)
+
+
+@slow
+def test_waitid_eintr():
+ # This only matters on PyPy (where we're coding EINTR handling
+ # ourselves) but the test works on all waitid platforms.
+ from .._subprocess_platform import wait_child_exiting
+
+ if not wait_child_exiting.__module__.endswith("waitid"):
+ pytest.skip("waitid only")
+ from .._subprocess_platform.waitid import sync_wait_reapable
+
+ got_alarm = False
+ sleeper = subprocess.Popen(["sleep", "3600"])
+
+ def on_alarm(sig, frame):
+ nonlocal got_alarm
+ got_alarm = True
+ sleeper.kill()
+
+ old_sigalrm = signal.signal(signal.SIGALRM, on_alarm)
+ try:
+ signal.alarm(1)
+ sync_wait_reapable(sleeper.pid)
+ assert sleeper.wait(timeout=1) == -9
+ finally:
+ if sleeper.returncode is None: # pragma: no cover
+ # We only get here if something fails in the above;
+ # if the test passes, wait() will reap the process
+ sleeper.kill()
+ sleeper.wait()
+ signal.signal(signal.SIGALRM, old_sigalrm)
+
+
+async def test_custom_deliver_cancel():
+ custom_deliver_cancel_called = False
+
+ async def custom_deliver_cancel(proc):
+ nonlocal custom_deliver_cancel_called
+ custom_deliver_cancel_called = True
+ proc.terminate()
+ # Make sure this does get cancelled when the process exits, and that
+ # the process really exited.
+ try:
+ await sleep_forever()
+ finally:
+ assert proc.returncode is not None
+
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(
+ partial(run_process, SLEEP(9999), deliver_cancel=custom_deliver_cancel)
+ )
+ await wait_all_tasks_blocked()
+ nursery.cancel_scope.cancel()
+
+ assert custom_deliver_cancel_called
+
+
+async def test_warn_on_failed_cancel_terminate(monkeypatch):
+ original_terminate = Process.terminate
+
+ def broken_terminate(self):
+ original_terminate(self)
+ raise OSError("whoops")
+
+ monkeypatch.setattr(Process, "terminate", broken_terminate)
+
+ with pytest.warns(RuntimeWarning, match=".*whoops.*"):
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(run_process, SLEEP(9999))
+ await wait_all_tasks_blocked()
+ nursery.cancel_scope.cancel()
+
+
+@pytest.mark.skipif(not posix, reason="posix only")
+async def test_warn_on_cancel_SIGKILL_escalation(autojump_clock, monkeypatch):
+ monkeypatch.setattr(Process, "terminate", lambda *args: None)
+
+ with pytest.warns(RuntimeWarning, match=".*ignored SIGTERM.*"):
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(run_process, SLEEP(9999))
+ await wait_all_tasks_blocked()
+ nursery.cancel_scope.cancel()
+
+
+# the background_process_param exercises a lot of run_process cases, but it uses
+# check=False, so lets have a test that uses check=True as well
+async def test_run_process_background_fail():
+ with pytest.raises(subprocess.CalledProcessError):
+ async with _core.open_nursery() as nursery:
+ proc = await nursery.start(run_process, EXIT_FALSE)
+ assert proc.returncode == 1
+
+
+@pytest.mark.skipif(
+ not SyncPath("/dev/fd").exists(),
+ reason="requires a way to iterate through open files",
+)
+async def test_for_leaking_fds():
+ starting_fds = set(SyncPath("/dev/fd").iterdir())
+ await run_process(EXIT_TRUE)
+ assert set(SyncPath("/dev/fd").iterdir()) == starting_fds
+
+ with pytest.raises(subprocess.CalledProcessError):
+ await run_process(EXIT_FALSE)
+ assert set(SyncPath("/dev/fd").iterdir()) == starting_fds
+
+ with pytest.raises(PermissionError):
+ await run_process(["/dev/fd/0"])
+ assert set(SyncPath("/dev/fd").iterdir()) == starting_fds
+
+
+# regression test for #2209
+async def test_subprocess_pidfd_unnotified():
+ noticed_exit = None
+
+ async def wait_and_tell(proc) -> None:
+ nonlocal noticed_exit
+ noticed_exit = Event()
+ await proc.wait()
+ noticed_exit.set()
+
+ proc = await open_process(SLEEP(9999))
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(wait_and_tell, proc)
+ await wait_all_tasks_blocked()
+ assert isinstance(noticed_exit, Event)
+ proc.terminate()
+ # without giving trio a chance to do so,
+ with assert_no_checkpoints():
+ # wait until the process has actually exited;
+ proc._proc.wait()
+ # force a call to poll (that closes the pidfd on linux)
+ proc.poll()
+ with move_on_after(5):
+ # Some platforms use threads to wait for exit, so it might take a bit
+ # for everything to notice
+ await noticed_exit.wait()
+ assert noticed_exit.is_set(), "child task wasn't woken after poll, DEADLOCK"
diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_sync.py b/venv/lib/python3.9/site-packages/trio/tests/test_sync.py
new file mode 100644
index 00000000..33f79c4d
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/trio/tests/test_sync.py
@@ -0,0 +1,567 @@
+import pytest
+
+import weakref
+
+from ..testing import wait_all_tasks_blocked, assert_checkpoints
+
+from .. import _core
+from .. import _timeouts
+from .._timeouts import sleep_forever, move_on_after
+from .._sync import *
+
+
+async def test_Event():
+ e = Event()
+ assert not e.is_set()
+ assert e.statistics().tasks_waiting == 0
+
+ e.set()
+ assert e.is_set()
+ with assert_checkpoints():
+ await e.wait()
+
+ e = Event()
+
+ record = []
+
+ async def child():
+ record.append("sleeping")
+ await e.wait()
+ record.append("woken")
+
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(child)
+ nursery.start_soon(child)
+ await wait_all_tasks_blocked()
+ assert record == ["sleeping", "sleeping"]
+ assert e.statistics().tasks_waiting == 2
+ e.set()
+ await wait_all_tasks_blocked()
+ assert record == ["sleeping", "sleeping", "woken", "woken"]
+
+
+async def test_CapacityLimiter():
+ with pytest.raises(TypeError):
+ CapacityLimiter(1.0)
+ with pytest.raises(ValueError):
+ CapacityLimiter(-1)
+ c = CapacityLimiter(2)
+ repr(c) # smoke test
+ assert c.total_tokens == 2
+ assert c.borrowed_tokens == 0
+ assert c.available_tokens == 2
+ with pytest.raises(RuntimeError):
+ c.release()
+ assert c.borrowed_tokens == 0
+ c.acquire_nowait()
+ assert c.borrowed_tokens == 1
+ assert c.available_tokens == 1
+
+ stats = c.statistics()
+ assert stats.borrowed_tokens == 1
+ assert stats.total_tokens == 2
+ assert stats.borrowers == [_core.current_task()]
+ assert stats.tasks_waiting == 0
+
+ # Can't re-acquire when we already have it
+ with pytest.raises(RuntimeError):
+ c.acquire_nowait()
+ assert c.borrowed_tokens == 1
+ with pytest.raises(RuntimeError):
+ await c.acquire()
+ assert c.borrowed_tokens == 1
+
+ # We can acquire on behalf of someone else though
+ with assert_checkpoints():
+ await c.acquire_on_behalf_of("someone")
+
+ # But then we've run out of capacity
+ assert c.borrowed_tokens == 2
+ with pytest.raises(_core.WouldBlock):
+ c.acquire_on_behalf_of_nowait("third party")
+
+ assert set(c.statistics().borrowers) == {_core.current_task(), "someone"}
+
+ # Until we release one
+ c.release_on_behalf_of(_core.current_task())
+ assert c.statistics().borrowers == ["someone"]
+
+ c.release_on_behalf_of("someone")
+ assert c.borrowed_tokens == 0
+ with assert_checkpoints():
+ async with c:
+ assert c.borrowed_tokens == 1
+
+ async with _core.open_nursery() as nursery:
+ await c.acquire_on_behalf_of("value 1")
+ await c.acquire_on_behalf_of("value 2")
+ nursery.start_soon(c.acquire_on_behalf_of, "value 3")
+ await wait_all_tasks_blocked()
+ assert c.borrowed_tokens == 2
+ assert c.statistics().tasks_waiting == 1
+ c.release_on_behalf_of("value 2")
+ # Fairness:
+ assert c.borrowed_tokens == 2
+ with pytest.raises(_core.WouldBlock):
+ c.acquire_nowait()
+
+ c.release_on_behalf_of("value 3")
+ c.release_on_behalf_of("value 1")
+
+
+async def test_CapacityLimiter_inf():
+ from math import inf
+
+ c = CapacityLimiter(inf)
+ repr(c) # smoke test
+ assert c.total_tokens == inf
+ assert c.borrowed_tokens == 0
+ assert c.available_tokens == inf
+ with pytest.raises(RuntimeError):
+ c.release()
+ assert c.borrowed_tokens == 0
+ c.acquire_nowait()
+ assert c.borrowed_tokens == 1
+ assert c.available_tokens == inf
+
+
+async def test_CapacityLimiter_change_total_tokens():
+ c = CapacityLimiter(2)
+
+ with pytest.raises(TypeError):
+ c.total_tokens = 1.0
+
+ with pytest.raises(ValueError):
+ c.total_tokens = 0
+
+ with pytest.raises(ValueError):
+ c.total_tokens = -10
+
+ assert c.total_tokens == 2
+
+ async with _core.open_nursery() as nursery:
+ for i in range(5):
+ nursery.start_soon(c.acquire_on_behalf_of, i)
+ await wait_all_tasks_blocked()
+ assert set(c.statistics().borrowers) == {0, 1}
+ assert c.statistics().tasks_waiting == 3
+ c.total_tokens += 2
+ assert set(c.statistics().borrowers) == {0, 1, 2, 3}
+ assert c.statistics().tasks_waiting == 1
+ c.total_tokens -= 3
+ assert c.borrowed_tokens == 4
+ assert c.total_tokens == 1
+ c.release_on_behalf_of(0)
+ c.release_on_behalf_of(1)
+ c.release_on_behalf_of(2)
+ assert set(c.statistics().borrowers) == {3}
+ assert c.statistics().tasks_waiting == 1
+ c.release_on_behalf_of(3)
+ assert set(c.statistics().borrowers) == {4}
+ assert c.statistics().tasks_waiting == 0
+
+
+# regression test for issue #548
+async def test_CapacityLimiter_memleak_548():
+ limiter = CapacityLimiter(total_tokens=1)
+ await limiter.acquire()
+
+ async with _core.open_nursery() as n:
+ n.start_soon(limiter.acquire)
+ await wait_all_tasks_blocked() # give it a chance to run the task
+ n.cancel_scope.cancel()
+
+ # if this is 1, the acquire call (despite being killed) is still there in the task, and will
+ # leak memory all the while the limiter is active
+ assert len(limiter._pending_borrowers) == 0
+
+
+async def test_Semaphore():
+ with pytest.raises(TypeError):
+ Semaphore(1.0)
+ with pytest.raises(ValueError):
+ Semaphore(-1)
+ s = Semaphore(1)
+ repr(s) # smoke test
+ assert s.value == 1
+ assert s.max_value is None
+ s.release()
+ assert s.value == 2
+ assert s.statistics().tasks_waiting == 0
+ s.acquire_nowait()
+ assert s.value == 1
+ with assert_checkpoints():
+ await s.acquire()
+ assert s.value == 0
+ with pytest.raises(_core.WouldBlock):
+ s.acquire_nowait()
+
+ s.release()
+ assert s.value == 1
+ with assert_checkpoints():
+ async with s:
+ assert s.value == 0
+ assert s.value == 1
+ s.acquire_nowait()
+
+ record = []
+
+ async def do_acquire(s):
+ record.append("started")
+ await s.acquire()
+ record.append("finished")
+
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(do_acquire, s)
+ await wait_all_tasks_blocked()
+ assert record == ["started"]
+ assert s.value == 0
+ s.release()
+ # Fairness:
+ assert s.value == 0
+ with pytest.raises(_core.WouldBlock):
+ s.acquire_nowait()
+ assert record == ["started", "finished"]
+
+
+async def test_Semaphore_bounded():
+ with pytest.raises(TypeError):
+ Semaphore(1, max_value=1.0)
+ with pytest.raises(ValueError):
+ Semaphore(2, max_value=1)
+ bs = Semaphore(1, max_value=1)
+ assert bs.max_value == 1
+ repr(bs) # smoke test
+ with pytest.raises(ValueError):
+ bs.release()
+ assert bs.value == 1
+ bs.acquire_nowait()
+ assert bs.value == 0
+ bs.release()
+ assert bs.value == 1
+
+
+@pytest.mark.parametrize("lockcls", [Lock, StrictFIFOLock], ids=lambda fn: fn.__name__)
+async def test_Lock_and_StrictFIFOLock(lockcls):
+ l = lockcls() # noqa
+ assert not l.locked()
+
+ # make sure locks can be weakref'ed (gh-331)
+ r = weakref.ref(l)
+ assert r() is l
+
+ repr(l) # smoke test
+ # make sure repr uses the right name for subclasses
+ assert lockcls.__name__ in repr(l)
+ with assert_checkpoints():
+ async with l:
+ assert l.locked()
+ repr(l) # smoke test (repr branches on locked/unlocked)
+ assert not l.locked()
+ l.acquire_nowait()
+ assert l.locked()
+ l.release()
+ assert not l.locked()
+ with assert_checkpoints():
+ await l.acquire()
+ assert l.locked()
+ l.release()
+ assert not l.locked()
+
+ l.acquire_nowait()
+ with pytest.raises(RuntimeError):
+ # Error out if we already own the lock
+ l.acquire_nowait()
+ l.release()
+ with pytest.raises(RuntimeError):
+ # Error out if we don't own the lock
+ l.release()
+
+ holder_task = None
+
+ async def holder():
+ nonlocal holder_task
+ holder_task = _core.current_task()
+ async with l:
+ await sleep_forever()
+
+ async with _core.open_nursery() as nursery:
+ assert not l.locked()
+ nursery.start_soon(holder)
+ await wait_all_tasks_blocked()
+ assert l.locked()
+ # WouldBlock if someone else holds the lock
+ with pytest.raises(_core.WouldBlock):
+ l.acquire_nowait()
+ # Can't release a lock someone else holds
+ with pytest.raises(RuntimeError):
+ l.release()
+
+ statistics = l.statistics()
+ print(statistics)
+ assert statistics.locked
+ assert statistics.owner is holder_task
+ assert statistics.tasks_waiting == 0
+
+ nursery.start_soon(holder)
+ await wait_all_tasks_blocked()
+ statistics = l.statistics()
+ print(statistics)
+ assert statistics.tasks_waiting == 1
+
+ nursery.cancel_scope.cancel()
+
+ statistics = l.statistics()
+ assert not statistics.locked
+ assert statistics.owner is None
+ assert statistics.tasks_waiting == 0
+
+
+async def test_Condition():
+ with pytest.raises(TypeError):
+ Condition(Semaphore(1))
+ with pytest.raises(TypeError):
+ Condition(StrictFIFOLock)
+ l = Lock() # noqa
+ c = Condition(l)
+ assert not l.locked()
+ assert not c.locked()
+ with assert_checkpoints():
+ await c.acquire()
+ assert l.locked()
+ assert c.locked()
+
+ c = Condition()
+ assert not c.locked()
+ c.acquire_nowait()
+ assert c.locked()
+ with pytest.raises(RuntimeError):
+ c.acquire_nowait()
+ c.release()
+
+ with pytest.raises(RuntimeError):
+ # Can't wait without holding the lock
+ await c.wait()
+ with pytest.raises(RuntimeError):
+ # Can't notify without holding the lock
+ c.notify()
+ with pytest.raises(RuntimeError):
+ # Can't notify without holding the lock
+ c.notify_all()
+
+ finished_waiters = set()
+
+ async def waiter(i):
+ async with c:
+ await c.wait()
+ finished_waiters.add(i)
+
+ async with _core.open_nursery() as nursery:
+ for i in range(3):
+ nursery.start_soon(waiter, i)
+ await wait_all_tasks_blocked()
+ async with c:
+ c.notify()
+ assert c.locked()
+ await wait_all_tasks_blocked()
+ assert finished_waiters == {0}
+ async with c:
+ c.notify_all()
+ await wait_all_tasks_blocked()
+ assert finished_waiters == {0, 1, 2}
+
+ finished_waiters = set()
+ async with _core.open_nursery() as nursery:
+ for i in range(3):
+ nursery.start_soon(waiter, i)
+ await wait_all_tasks_blocked()
+ async with c:
+ c.notify(2)
+ statistics = c.statistics()
+ print(statistics)
+ assert statistics.tasks_waiting == 1
+ assert statistics.lock_statistics.tasks_waiting == 2
+ # exiting the context manager hands off the lock to the first task
+ assert c.statistics().lock_statistics.tasks_waiting == 1
+
+ await wait_all_tasks_blocked()
+ assert finished_waiters == {0, 1}
+
+ async with c:
+ c.notify_all()
+
+ # After being cancelled still hold the lock (!)
+ # (Note that c.__aexit__ checks that we hold the lock as well)
+ with _core.CancelScope() as scope:
+ async with c:
+ scope.cancel()
+ try:
+ await c.wait()
+ finally:
+ assert c.locked()
+
+
+from .._sync import AsyncContextManagerMixin
+from .._channel import open_memory_channel
+
+# Three ways of implementing a Lock in terms of a channel. Used to let us put
+# the channel through the generic lock tests.
+
+
+class ChannelLock1(AsyncContextManagerMixin):
+ def __init__(self, capacity):
+ self.s, self.r = open_memory_channel(capacity)
+ for _ in range(capacity - 1):
+ self.s.send_nowait(None)
+
+ def acquire_nowait(self):
+ self.s.send_nowait(None)
+
+ async def acquire(self):
+ await self.s.send(None)
+
+ def release(self):
+ self.r.receive_nowait()
+
+
+class ChannelLock2(AsyncContextManagerMixin):
+ def __init__(self):
+ self.s, self.r = open_memory_channel(10)
+ self.s.send_nowait(None)
+
+ def acquire_nowait(self):
+ self.r.receive_nowait()
+
+ async def acquire(self):
+ await self.r.receive()
+
+ def release(self):
+ self.s.send_nowait(None)
+
+
+class ChannelLock3(AsyncContextManagerMixin):
+ def __init__(self):
+ self.s, self.r = open_memory_channel(0)
+ # self.acquired is true when one task acquires the lock and
+ # only becomes false when it's released and no tasks are
+ # waiting to acquire.
+ self.acquired = False
+
+ def acquire_nowait(self):
+ assert not self.acquired
+ self.acquired = True
+
+ async def acquire(self):
+ if self.acquired:
+ await self.s.send(None)
+ else:
+ self.acquired = True
+ await _core.checkpoint()
+
+ def release(self):
+ try:
+ self.r.receive_nowait()
+ except _core.WouldBlock:
+ assert self.acquired
+ self.acquired = False
+
+
+lock_factories = [
+ lambda: CapacityLimiter(1),
+ lambda: Semaphore(1),
+ Lock,
+ StrictFIFOLock,
+ lambda: ChannelLock1(10),
+ lambda: ChannelLock1(1),
+ ChannelLock2,
+ ChannelLock3,
+]
+lock_factory_names = [
+ "CapacityLimiter(1)",
+ "Semaphore(1)",
+ "Lock",
+ "StrictFIFOLock",
+ "ChannelLock1(10)",
+ "ChannelLock1(1)",
+ "ChannelLock2",
+ "ChannelLock3",
+]
+
+generic_lock_test = pytest.mark.parametrize(
+ "lock_factory", lock_factories, ids=lock_factory_names
+)
+
+
+# Spawn a bunch of workers that take a lock and then yield; make sure that
+# only one worker is ever in the critical section at a time.
+@generic_lock_test
+async def test_generic_lock_exclusion(lock_factory):
+ LOOPS = 10
+ WORKERS = 5
+ in_critical_section = False
+ acquires = 0
+
+ async def worker(lock_like):
+ nonlocal in_critical_section, acquires
+ for _ in range(LOOPS):
+ async with lock_like:
+ acquires += 1
+ assert not in_critical_section
+ in_critical_section = True
+ await _core.checkpoint()
+ await _core.checkpoint()
+ assert in_critical_section
+ in_critical_section = False
+
+ async with _core.open_nursery() as nursery:
+ lock_like = lock_factory()
+ for _ in range(WORKERS):
+ nursery.start_soon(worker, lock_like)
+ assert not in_critical_section
+ assert acquires == LOOPS * WORKERS
+
+
+# Several workers queue on the same lock; make sure they each get it, in
+# order.
+@generic_lock_test
+async def test_generic_lock_fifo_fairness(lock_factory):
+ initial_order = []
+ record = []
+ LOOPS = 5
+
+ async def loopy(name, lock_like):
+ # Record the order each task was initially scheduled in
+ initial_order.append(name)
+ for _ in range(LOOPS):
+ async with lock_like:
+ record.append(name)
+
+ lock_like = lock_factory()
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(loopy, 1, lock_like)
+ nursery.start_soon(loopy, 2, lock_like)
+ nursery.start_soon(loopy, 3, lock_like)
+ # The first three could be in any order due to scheduling randomness,
+ # but after that they should repeat in the same order
+ for i in range(LOOPS):
+ assert record[3 * i : 3 * (i + 1)] == initial_order
+
+
+@generic_lock_test
+async def test_generic_lock_acquire_nowait_blocks_acquire(lock_factory):
+ lock_like = lock_factory()
+
+ record = []
+
+ async def lock_taker():
+ record.append("started")
+ async with lock_like:
+ pass
+ record.append("finished")
+
+ async with _core.open_nursery() as nursery:
+ lock_like.acquire_nowait()
+ nursery.start_soon(lock_taker)
+ await wait_all_tasks_blocked()
+ assert record == ["started"]
+ lock_like.release()
diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_testing.py b/venv/lib/python3.9/site-packages/trio/tests/test_testing.py
new file mode 100644
index 00000000..0b10ae71
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/trio/tests/test_testing.py
@@ -0,0 +1,657 @@
+# XX this should get broken up, like testing.py did
+
+import tempfile
+
+import pytest
+
+from .._core.tests.tutil import can_bind_ipv6
+from .. import sleep
+from .. import _core
+from .._highlevel_generic import aclose_forcefully
+from ..testing import *
+from ..testing._check_streams import _assert_raises
+from ..testing._memory_streams import _UnboundedByteQueue
+from .. import socket as tsocket
+from .._highlevel_socket import SocketListener
+
+
+async def test_wait_all_tasks_blocked():
+ record = []
+
+ async def busy_bee():
+ for _ in range(10):
+ await _core.checkpoint()
+ record.append("busy bee exhausted")
+
+ async def waiting_for_bee_to_leave():
+ await wait_all_tasks_blocked()
+ record.append("quiet at last!")
+
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(busy_bee)
+ nursery.start_soon(waiting_for_bee_to_leave)
+ nursery.start_soon(waiting_for_bee_to_leave)
+
+ # check cancellation
+ record = []
+
+ async def cancelled_while_waiting():
+ try:
+ await wait_all_tasks_blocked()
+ except _core.Cancelled:
+ record.append("ok")
+
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(cancelled_while_waiting)
+ nursery.cancel_scope.cancel()
+ assert record == ["ok"]
+
+
+async def test_wait_all_tasks_blocked_with_timeouts(mock_clock):
+ record = []
+
+ async def timeout_task():
+ record.append("tt start")
+ await sleep(5)
+ record.append("tt finished")
+
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(timeout_task)
+ await wait_all_tasks_blocked()
+ assert record == ["tt start"]
+ mock_clock.jump(10)
+ await wait_all_tasks_blocked()
+ assert record == ["tt start", "tt finished"]
+
+
+async def test_wait_all_tasks_blocked_with_cushion():
+ record = []
+
+ async def blink():
+ record.append("blink start")
+ await sleep(0.01)
+ await sleep(0.01)
+ await sleep(0.01)
+ record.append("blink end")
+
+ async def wait_no_cushion():
+ await wait_all_tasks_blocked()
+ record.append("wait_no_cushion end")
+
+ async def wait_small_cushion():
+ await wait_all_tasks_blocked(0.02)
+ record.append("wait_small_cushion end")
+
+ async def wait_big_cushion():
+ await wait_all_tasks_blocked(0.03)
+ record.append("wait_big_cushion end")
+
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(blink)
+ nursery.start_soon(wait_no_cushion)
+ nursery.start_soon(wait_small_cushion)
+ nursery.start_soon(wait_small_cushion)
+ nursery.start_soon(wait_big_cushion)
+
+ assert record == [
+ "blink start",
+ "wait_no_cushion end",
+ "blink end",
+ "wait_small_cushion end",
+ "wait_small_cushion end",
+ "wait_big_cushion end",
+ ]
+
+
+################################################################
+
+
+async def test_assert_checkpoints(recwarn):
+ with assert_checkpoints():
+ await _core.checkpoint()
+
+ with pytest.raises(AssertionError):
+ with assert_checkpoints():
+ 1 + 1
+
+ # partial yield cases
+ # if you have a schedule point but not a cancel point, or vice-versa, then
+ # that's not a checkpoint.
+ for partial_yield in [
+ _core.checkpoint_if_cancelled,
+ _core.cancel_shielded_checkpoint,
+ ]:
+ print(partial_yield)
+ with pytest.raises(AssertionError):
+ with assert_checkpoints():
+ await partial_yield()
+
+ # But both together count as a checkpoint
+ with assert_checkpoints():
+ await _core.checkpoint_if_cancelled()
+ await _core.cancel_shielded_checkpoint()
+
+
+async def test_assert_no_checkpoints(recwarn):
+ with assert_no_checkpoints():
+ 1 + 1
+
+ with pytest.raises(AssertionError):
+ with assert_no_checkpoints():
+ await _core.checkpoint()
+
+ # partial yield cases
+ # if you have a schedule point but not a cancel point, or vice-versa, then
+ # that doesn't make *either* version of assert_{no_,}yields happy.
+ for partial_yield in [
+ _core.checkpoint_if_cancelled,
+ _core.cancel_shielded_checkpoint,
+ ]:
+ print(partial_yield)
+ with pytest.raises(AssertionError):
+ with assert_no_checkpoints():
+ await partial_yield()
+
+ # And both together also count as a checkpoint
+ with pytest.raises(AssertionError):
+ with assert_no_checkpoints():
+ await _core.checkpoint_if_cancelled()
+ await _core.cancel_shielded_checkpoint()
+
+
+################################################################
+
+
+async def test_Sequencer():
+ record = []
+
+ def t(val):
+ print(val)
+ record.append(val)
+
+ async def f1(seq):
+ async with seq(1):
+ t(("f1", 1))
+ async with seq(3):
+ t(("f1", 3))
+ async with seq(4):
+ t(("f1", 4))
+
+ async def f2(seq):
+ async with seq(0):
+ t(("f2", 0))
+ async with seq(2):
+ t(("f2", 2))
+
+ seq = Sequencer()
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(f1, seq)
+ nursery.start_soon(f2, seq)
+ async with seq(5):
+ await wait_all_tasks_blocked()
+ assert record == [("f2", 0), ("f1", 1), ("f2", 2), ("f1", 3), ("f1", 4)]
+
+ seq = Sequencer()
+ # Catches us if we try to re-use a sequence point:
+ async with seq(0):
+ pass
+ with pytest.raises(RuntimeError):
+ async with seq(0):
+ pass # pragma: no cover
+
+
+async def test_Sequencer_cancel():
+ # Killing a blocked task makes everything blow up
+ record = []
+ seq = Sequencer()
+
+ async def child(i):
+ with _core.CancelScope() as scope:
+ if i == 1:
+ scope.cancel()
+ try:
+ async with seq(i):
+ pass # pragma: no cover
+ except RuntimeError:
+ record.append("seq({}) RuntimeError".format(i))
+
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(child, 1)
+ nursery.start_soon(child, 2)
+ async with seq(0):
+ pass # pragma: no cover
+
+ assert record == ["seq(1) RuntimeError", "seq(2) RuntimeError"]
+
+ # Late arrivals also get errors
+ with pytest.raises(RuntimeError):
+ async with seq(3):
+ pass # pragma: no cover
+
+
+################################################################
+
+
+async def test__assert_raises():
+ with pytest.raises(AssertionError):
+ with _assert_raises(RuntimeError):
+ 1 + 1
+
+ with pytest.raises(TypeError):
+ with _assert_raises(RuntimeError):
+ "foo" + 1
+
+ with _assert_raises(RuntimeError):
+ raise RuntimeError
+
+
+# This is a private implementation detail, but it's complex enough to be worth
+# testing directly
+async def test__UnboundeByteQueue():
+ ubq = _UnboundedByteQueue()
+
+ ubq.put(b"123")
+ ubq.put(b"456")
+ assert ubq.get_nowait(1) == b"1"
+ assert ubq.get_nowait(10) == b"23456"
+ ubq.put(b"789")
+ assert ubq.get_nowait() == b"789"
+
+ with pytest.raises(_core.WouldBlock):
+ ubq.get_nowait(10)
+ with pytest.raises(_core.WouldBlock):
+ ubq.get_nowait()
+
+ with pytest.raises(TypeError):
+ ubq.put("string")
+
+ ubq.put(b"abc")
+ with assert_checkpoints():
+ assert await ubq.get(10) == b"abc"
+ ubq.put(b"def")
+ ubq.put(b"ghi")
+ with assert_checkpoints():
+ assert await ubq.get(1) == b"d"
+ with assert_checkpoints():
+ assert await ubq.get() == b"efghi"
+
+ async def putter(data):
+ await wait_all_tasks_blocked()
+ ubq.put(data)
+
+ async def getter(expect):
+ with assert_checkpoints():
+ assert await ubq.get() == expect
+
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(getter, b"xyz")
+ nursery.start_soon(putter, b"xyz")
+
+ # Two gets at the same time -> BusyResourceError
+ with pytest.raises(_core.BusyResourceError):
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(getter, b"asdf")
+ nursery.start_soon(getter, b"asdf")
+
+ # Closing
+
+ ubq.close()
+ with pytest.raises(_core.ClosedResourceError):
+ ubq.put(b"---")
+
+ assert ubq.get_nowait(10) == b""
+ assert ubq.get_nowait() == b""
+ assert await ubq.get(10) == b""
+ assert await ubq.get() == b""
+
+ # close is idempotent
+ ubq.close()
+
+ # close wakes up blocked getters
+ ubq2 = _UnboundedByteQueue()
+
+ async def closer():
+ await wait_all_tasks_blocked()
+ ubq2.close()
+
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(getter, b"")
+ nursery.start_soon(closer)
+
+
+async def test_MemorySendStream():
+ mss = MemorySendStream()
+
+ async def do_send_all(data):
+ with assert_checkpoints():
+ await mss.send_all(data)
+
+ await do_send_all(b"123")
+ assert mss.get_data_nowait(1) == b"1"
+ assert mss.get_data_nowait() == b"23"
+
+ with assert_checkpoints():
+ await mss.wait_send_all_might_not_block()
+
+ with pytest.raises(_core.WouldBlock):
+ mss.get_data_nowait()
+ with pytest.raises(_core.WouldBlock):
+ mss.get_data_nowait(10)
+
+ await do_send_all(b"456")
+ with assert_checkpoints():
+ assert await mss.get_data() == b"456"
+
+ # Call send_all twice at once; one should get BusyResourceError and one
+ # should succeed. But we can't let the error propagate, because it might
+ # cause the other to be cancelled before it can finish doing its thing,
+ # and we don't know which one will get the error.
+ resource_busy_count = 0
+
+ async def do_send_all_count_resourcebusy():
+ nonlocal resource_busy_count
+ try:
+ await do_send_all(b"xxx")
+ except _core.BusyResourceError:
+ resource_busy_count += 1
+
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(do_send_all_count_resourcebusy)
+ nursery.start_soon(do_send_all_count_resourcebusy)
+
+ assert resource_busy_count == 1
+
+ with assert_checkpoints():
+ await mss.aclose()
+
+ assert await mss.get_data() == b"xxx"
+ assert await mss.get_data() == b""
+ with pytest.raises(_core.ClosedResourceError):
+ await do_send_all(b"---")
+
+ # hooks
+
+ assert mss.send_all_hook is None
+ assert mss.wait_send_all_might_not_block_hook is None
+ assert mss.close_hook is None
+
+ record = []
+
+ async def send_all_hook():
+ # hook runs after send_all does its work (can pull data out)
+ assert mss2.get_data_nowait() == b"abc"
+ record.append("send_all_hook")
+
+ async def wait_send_all_might_not_block_hook():
+ record.append("wait_send_all_might_not_block_hook")
+
+ def close_hook():
+ record.append("close_hook")
+
+ mss2 = MemorySendStream(
+ send_all_hook, wait_send_all_might_not_block_hook, close_hook
+ )
+
+ assert mss2.send_all_hook is send_all_hook
+ assert mss2.wait_send_all_might_not_block_hook is wait_send_all_might_not_block_hook
+ assert mss2.close_hook is close_hook
+
+ await mss2.send_all(b"abc")
+ await mss2.wait_send_all_might_not_block()
+ await aclose_forcefully(mss2)
+ mss2.close()
+
+ assert record == [
+ "send_all_hook",
+ "wait_send_all_might_not_block_hook",
+ "close_hook",
+ "close_hook",
+ ]
+
+
+async def test_MemoryReceiveStream():
+ mrs = MemoryReceiveStream()
+
+ async def do_receive_some(max_bytes):
+ with assert_checkpoints():
+ return await mrs.receive_some(max_bytes)
+
+ mrs.put_data(b"abc")
+ assert await do_receive_some(1) == b"a"
+ assert await do_receive_some(10) == b"bc"
+ mrs.put_data(b"abc")
+ assert await do_receive_some(None) == b"abc"
+
+ with pytest.raises(_core.BusyResourceError):
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(do_receive_some, 10)
+ nursery.start_soon(do_receive_some, 10)
+
+ assert mrs.receive_some_hook is None
+
+ mrs.put_data(b"def")
+ mrs.put_eof()
+ mrs.put_eof()
+
+ assert await do_receive_some(10) == b"def"
+ assert await do_receive_some(10) == b""
+ assert await do_receive_some(10) == b""
+
+ with pytest.raises(_core.ClosedResourceError):
+ mrs.put_data(b"---")
+
+ async def receive_some_hook():
+ mrs2.put_data(b"xxx")
+
+ record = []
+
+ def close_hook():
+ record.append("closed")
+
+ mrs2 = MemoryReceiveStream(receive_some_hook, close_hook)
+ assert mrs2.receive_some_hook is receive_some_hook
+ assert mrs2.close_hook is close_hook
+
+ mrs2.put_data(b"yyy")
+ assert await mrs2.receive_some(10) == b"yyyxxx"
+ assert await mrs2.receive_some(10) == b"xxx"
+ assert await mrs2.receive_some(10) == b"xxx"
+
+ mrs2.put_data(b"zzz")
+ mrs2.receive_some_hook = None
+ assert await mrs2.receive_some(10) == b"zzz"
+
+ mrs2.put_data(b"lost on close")
+ with assert_checkpoints():
+ await mrs2.aclose()
+ assert record == ["closed"]
+
+ with pytest.raises(_core.ClosedResourceError):
+ await mrs2.receive_some(10)
+
+
+async def test_MemoryRecvStream_closing():
+ mrs = MemoryReceiveStream()
+ # close with no pending data
+ mrs.close()
+ with pytest.raises(_core.ClosedResourceError):
+ assert await mrs.receive_some(10) == b""
+ # repeated closes ok
+ mrs.close()
+ # put_data now fails
+ with pytest.raises(_core.ClosedResourceError):
+ mrs.put_data(b"123")
+
+ mrs2 = MemoryReceiveStream()
+ # close with pending data
+ mrs2.put_data(b"xyz")
+ mrs2.close()
+ with pytest.raises(_core.ClosedResourceError):
+ await mrs2.receive_some(10)
+
+
+async def test_memory_stream_pump():
+ mss = MemorySendStream()
+ mrs = MemoryReceiveStream()
+
+ # no-op if no data present
+ memory_stream_pump(mss, mrs)
+
+ await mss.send_all(b"123")
+ memory_stream_pump(mss, mrs)
+ assert await mrs.receive_some(10) == b"123"
+
+ await mss.send_all(b"456")
+ assert memory_stream_pump(mss, mrs, max_bytes=1)
+ assert await mrs.receive_some(10) == b"4"
+ assert memory_stream_pump(mss, mrs, max_bytes=1)
+ assert memory_stream_pump(mss, mrs, max_bytes=1)
+ assert not memory_stream_pump(mss, mrs, max_bytes=1)
+ assert await mrs.receive_some(10) == b"56"
+
+ mss.close()
+ memory_stream_pump(mss, mrs)
+ assert await mrs.receive_some(10) == b""
+
+
+async def test_memory_stream_one_way_pair():
+ s, r = memory_stream_one_way_pair()
+ assert s.send_all_hook is not None
+ assert s.wait_send_all_might_not_block_hook is None
+ assert s.close_hook is not None
+ assert r.receive_some_hook is None
+ await s.send_all(b"123")
+ assert await r.receive_some(10) == b"123"
+
+ async def receiver(expected):
+ assert await r.receive_some(10) == expected
+
+ # This fails if we pump on r.receive_some_hook; we need to pump on s.send_all_hook
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(receiver, b"abc")
+ await wait_all_tasks_blocked()
+ await s.send_all(b"abc")
+
+ # And this fails if we don't pump from close_hook
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(receiver, b"")
+ await wait_all_tasks_blocked()
+ await s.aclose()
+
+ s, r = memory_stream_one_way_pair()
+
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(receiver, b"")
+ await wait_all_tasks_blocked()
+ s.close()
+
+ s, r = memory_stream_one_way_pair()
+
+ old = s.send_all_hook
+ s.send_all_hook = None
+ await s.send_all(b"456")
+
+ async def cancel_after_idle(nursery):
+ await wait_all_tasks_blocked()
+ nursery.cancel_scope.cancel()
+
+ async def check_for_cancel():
+ with pytest.raises(_core.Cancelled):
+ # This should block forever... or until cancelled. Even though we
+ # sent some data on the send stream.
+ await r.receive_some(10)
+
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(cancel_after_idle, nursery)
+ nursery.start_soon(check_for_cancel)
+
+ s.send_all_hook = old
+ await s.send_all(b"789")
+ assert await r.receive_some(10) == b"456789"
+
+
+async def test_memory_stream_pair():
+ a, b = memory_stream_pair()
+ await a.send_all(b"123")
+ await b.send_all(b"abc")
+ assert await b.receive_some(10) == b"123"
+ assert await a.receive_some(10) == b"abc"
+
+ await a.send_eof()
+ assert await b.receive_some(10) == b""
+
+ async def sender():
+ await wait_all_tasks_blocked()
+ await b.send_all(b"xyz")
+
+ async def receiver():
+ assert await a.receive_some(10) == b"xyz"
+
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(receiver)
+ nursery.start_soon(sender)
+
+
+async def test_memory_streams_with_generic_tests():
+ async def one_way_stream_maker():
+ return memory_stream_one_way_pair()
+
+ await check_one_way_stream(one_way_stream_maker, None)
+
+ async def half_closeable_stream_maker():
+ return memory_stream_pair()
+
+ await check_half_closeable_stream(half_closeable_stream_maker, None)
+
+
+async def test_lockstep_streams_with_generic_tests():
+ async def one_way_stream_maker():
+ return lockstep_stream_one_way_pair()
+
+ await check_one_way_stream(one_way_stream_maker, one_way_stream_maker)
+
+ async def two_way_stream_maker():
+ return lockstep_stream_pair()
+
+ await check_two_way_stream(two_way_stream_maker, two_way_stream_maker)
+
+
+async def test_open_stream_to_socket_listener():
+ async def check(listener):
+ async with listener:
+ client_stream = await open_stream_to_socket_listener(listener)
+ async with client_stream:
+ server_stream = await listener.accept()
+ async with server_stream:
+ await client_stream.send_all(b"x")
+ await server_stream.receive_some(1) == b"x"
+
+ # Listener bound to localhost
+ sock = tsocket.socket()
+ await sock.bind(("127.0.0.1", 0))
+ sock.listen(10)
+ await check(SocketListener(sock))
+
+ # Listener bound to IPv4 wildcard (needs special handling)
+ sock = tsocket.socket()
+ await sock.bind(("0.0.0.0", 0))
+ sock.listen(10)
+ await check(SocketListener(sock))
+
+ if can_bind_ipv6:
+ # Listener bound to IPv6 wildcard (needs special handling)
+ sock = tsocket.socket(family=tsocket.AF_INET6)
+ await sock.bind(("::", 0))
+ sock.listen(10)
+ await check(SocketListener(sock))
+
+ if hasattr(tsocket, "AF_UNIX"):
+ # Listener bound to Unix-domain socket
+ sock = tsocket.socket(family=tsocket.AF_UNIX)
+ # can't use pytest's tmpdir; if we try then macOS says "OSError:
+ # AF_UNIX path too long"
+ with tempfile.TemporaryDirectory() as tmpdir:
+ path = "{}/sock".format(tmpdir)
+ await sock.bind(path)
+ sock.listen(10)
+ await check(SocketListener(sock))
diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_threads.py b/venv/lib/python3.9/site-packages/trio/tests/test_threads.py
new file mode 100644
index 00000000..baff1827
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/trio/tests/test_threads.py
@@ -0,0 +1,752 @@
+import contextvars
+import threading
+import queue as stdlib_queue
+import time
+import weakref
+
+import pytest
+from sniffio import current_async_library_cvar
+from trio._core import TrioToken, current_trio_token
+
+from .. import _core
+from .. import Event, CapacityLimiter, sleep
+from ..testing import wait_all_tasks_blocked
+from .._core.tests.tutil import buggy_pypy_asyncgens
+from .._threads import (
+ to_thread_run_sync,
+ current_default_thread_limiter,
+ from_thread_run,
+ from_thread_run_sync,
+)
+
+from .._core.tests.test_ki import ki_self
+
+
+async def test_do_in_trio_thread():
+ trio_thread = threading.current_thread()
+
+ async def check_case(do_in_trio_thread, fn, expected, trio_token=None):
+ record = []
+
+ def threadfn():
+ try:
+ record.append(("start", threading.current_thread()))
+ x = do_in_trio_thread(fn, record, trio_token=trio_token)
+ record.append(("got", x))
+ except BaseException as exc:
+ print(exc)
+ record.append(("error", type(exc)))
+
+ child_thread = threading.Thread(target=threadfn, daemon=True)
+ child_thread.start()
+ while child_thread.is_alive():
+ print("yawn")
+ await sleep(0.01)
+ assert record == [("start", child_thread), ("f", trio_thread), expected]
+
+ token = _core.current_trio_token()
+
+ def f(record):
+ assert not _core.currently_ki_protected()
+ record.append(("f", threading.current_thread()))
+ return 2
+
+ await check_case(from_thread_run_sync, f, ("got", 2), trio_token=token)
+
+ def f(record):
+ assert not _core.currently_ki_protected()
+ record.append(("f", threading.current_thread()))
+ raise ValueError
+
+ await check_case(from_thread_run_sync, f, ("error", ValueError), trio_token=token)
+
+ async def f(record):
+ assert not _core.currently_ki_protected()
+ await _core.checkpoint()
+ record.append(("f", threading.current_thread()))
+ return 3
+
+ await check_case(from_thread_run, f, ("got", 3), trio_token=token)
+
+ async def f(record):
+ assert not _core.currently_ki_protected()
+ await _core.checkpoint()
+ record.append(("f", threading.current_thread()))
+ raise KeyError
+
+ await check_case(from_thread_run, f, ("error", KeyError), trio_token=token)
+
+
+async def test_do_in_trio_thread_from_trio_thread():
+ with pytest.raises(RuntimeError):
+ from_thread_run_sync(lambda: None) # pragma: no branch
+
+ async def foo(): # pragma: no cover
+ pass
+
+ with pytest.raises(RuntimeError):
+ from_thread_run(foo)
+
+
+def test_run_in_trio_thread_ki():
+ # if we get a control-C during a run_in_trio_thread, then it propagates
+ # back to the caller (slick!)
+ record = set()
+
+ async def check_run_in_trio_thread():
+ token = _core.current_trio_token()
+
+ def trio_thread_fn():
+ print("in Trio thread")
+ assert not _core.currently_ki_protected()
+ print("ki_self")
+ try:
+ ki_self()
+ finally:
+ import sys
+
+ print("finally", sys.exc_info())
+
+ async def trio_thread_afn():
+ trio_thread_fn()
+
+ def external_thread_fn():
+ try:
+ print("running")
+ from_thread_run_sync(trio_thread_fn, trio_token=token)
+ except KeyboardInterrupt:
+ print("ok1")
+ record.add("ok1")
+ try:
+ from_thread_run(trio_thread_afn, trio_token=token)
+ except KeyboardInterrupt:
+ print("ok2")
+ record.add("ok2")
+
+ thread = threading.Thread(target=external_thread_fn)
+ thread.start()
+ print("waiting")
+ while thread.is_alive():
+ await sleep(0.01)
+ print("waited, joining")
+ thread.join()
+ print("done")
+
+ _core.run(check_run_in_trio_thread)
+ assert record == {"ok1", "ok2"}
+
+
+def test_await_in_trio_thread_while_main_exits():
+ record = []
+ ev = Event()
+
+ async def trio_fn():
+ record.append("sleeping")
+ ev.set()
+ await _core.wait_task_rescheduled(lambda _: _core.Abort.SUCCEEDED)
+
+ def thread_fn(token):
+ try:
+ from_thread_run(trio_fn, trio_token=token)
+ except _core.Cancelled:
+ record.append("cancelled")
+
+ async def main():
+ token = _core.current_trio_token()
+ thread = threading.Thread(target=thread_fn, args=(token,))
+ thread.start()
+ await ev.wait()
+ assert record == ["sleeping"]
+ return thread
+
+ thread = _core.run(main)
+ thread.join()
+ assert record == ["sleeping", "cancelled"]
+
+
+async def test_run_in_worker_thread():
+ trio_thread = threading.current_thread()
+
+ def f(x):
+ return (x, threading.current_thread())
+
+ x, child_thread = await to_thread_run_sync(f, 1)
+ assert x == 1
+ assert child_thread != trio_thread
+
+ def g():
+ raise ValueError(threading.current_thread())
+
+ with pytest.raises(ValueError) as excinfo:
+ await to_thread_run_sync(g)
+ print(excinfo.value.args)
+ assert excinfo.value.args[0] != trio_thread
+
+
+async def test_run_in_worker_thread_cancellation():
+ register = [None]
+
+ def f(q):
+ # Make the thread block for a controlled amount of time
+ register[0] = "blocking"
+ q.get()
+ register[0] = "finished"
+
+ async def child(q, cancellable):
+ record.append("start")
+ try:
+ return await to_thread_run_sync(f, q, cancellable=cancellable)
+ finally:
+ record.append("exit")
+
+ record = []
+ q = stdlib_queue.Queue()
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(child, q, True)
+ # Give it a chance to get started. (This is important because
+ # to_thread_run_sync does a checkpoint_if_cancelled before
+ # blocking on the thread, and we don't want to trigger this.)
+ await wait_all_tasks_blocked()
+ assert record == ["start"]
+ # Then cancel it.
+ nursery.cancel_scope.cancel()
+ # The task exited, but the thread didn't:
+ assert register[0] != "finished"
+ # Put the thread out of its misery:
+ q.put(None)
+ while register[0] != "finished":
+ time.sleep(0.01)
+
+ # This one can't be cancelled
+ record = []
+ register[0] = None
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(child, q, False)
+ await wait_all_tasks_blocked()
+ nursery.cancel_scope.cancel()
+ with _core.CancelScope(shield=True):
+ for _ in range(10):
+ await _core.checkpoint()
+ # It's still running
+ assert record == ["start"]
+ q.put(None)
+ # Now it exits
+
+ # But if we cancel *before* it enters, the entry is itself a cancellation
+ # point
+ with _core.CancelScope() as scope:
+ scope.cancel()
+ await child(q, False)
+ assert scope.cancelled_caught
+
+
+# Make sure that if trio.run exits, and then the thread finishes, then that's
+# handled gracefully. (Requires that the thread result machinery be prepared
+# for call_soon to raise RunFinishedError.)
+def test_run_in_worker_thread_abandoned(capfd, monkeypatch):
+ monkeypatch.setattr(_core._thread_cache, "IDLE_TIMEOUT", 0.01)
+
+ q1 = stdlib_queue.Queue()
+ q2 = stdlib_queue.Queue()
+
+ def thread_fn():
+ q1.get()
+ q2.put(threading.current_thread())
+
+ async def main():
+ async def child():
+ await to_thread_run_sync(thread_fn, cancellable=True)
+
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(child)
+ await wait_all_tasks_blocked()
+ nursery.cancel_scope.cancel()
+
+ _core.run(main)
+
+ q1.put(None)
+ # This makes sure:
+ # - the thread actually ran
+ # - that thread has finished before we check for its output
+ thread = q2.get()
+ while thread.is_alive():
+ time.sleep(0.01) # pragma: no cover
+
+ # Make sure we don't have a "Exception in thread ..." dump to the console:
+ out, err = capfd.readouterr()
+ assert "Exception in thread" not in out
+ assert "Exception in thread" not in err
+
+
+@pytest.mark.parametrize("MAX", [3, 5, 10])
+@pytest.mark.parametrize("cancel", [False, True])
+@pytest.mark.parametrize("use_default_limiter", [False, True])
+async def test_run_in_worker_thread_limiter(MAX, cancel, use_default_limiter):
+ # This test is a bit tricky. The goal is to make sure that if we set
+ # limiter=CapacityLimiter(MAX), then in fact only MAX threads are ever
+ # running at a time, even if there are more concurrent calls to
+ # to_thread_run_sync, and even if some of those are cancelled. And
+ # also to make sure that the default limiter actually limits.
+ COUNT = 2 * MAX
+ gate = threading.Event()
+ lock = threading.Lock()
+ if use_default_limiter:
+ c = current_default_thread_limiter()
+ orig_total_tokens = c.total_tokens
+ c.total_tokens = MAX
+ limiter_arg = None
+ else:
+ c = CapacityLimiter(MAX)
+ orig_total_tokens = MAX
+ limiter_arg = c
+ try:
+ # We used to use regular variables and 'nonlocal' here, but it turns
+ # out that it's not safe to assign to closed-over variables that are
+ # visible in multiple threads, at least as of CPython 3.10 and PyPy
+ # 7.3:
+ #
+ # https://bugs.python.org/issue30744
+ # https://bitbucket.org/pypy/pypy/issues/2591/
+ #
+ # Mutating them in-place is OK though (as long as you use proper
+ # locking etc.).
+ class state:
+ pass
+
+ state.ran = 0
+ state.high_water = 0
+ state.running = 0
+ state.parked = 0
+
+ token = _core.current_trio_token()
+
+ def thread_fn(cancel_scope):
+ print("thread_fn start")
+ from_thread_run_sync(cancel_scope.cancel, trio_token=token)
+ with lock:
+ state.ran += 1
+ state.running += 1
+ state.high_water = max(state.high_water, state.running)
+ # The Trio thread below watches this value and uses it as a
+ # signal that all the stats calculations have finished.
+ state.parked += 1
+ gate.wait()
+ with lock:
+ state.parked -= 1
+ state.running -= 1
+ print("thread_fn exiting")
+
+ async def run_thread(event):
+ with _core.CancelScope() as cancel_scope:
+ await to_thread_run_sync(
+ thread_fn, cancel_scope, limiter=limiter_arg, cancellable=cancel
+ )
+ print("run_thread finished, cancelled:", cancel_scope.cancelled_caught)
+ event.set()
+
+ async with _core.open_nursery() as nursery:
+ print("spawning")
+ events = []
+ for i in range(COUNT):
+ events.append(Event())
+ nursery.start_soon(run_thread, events[-1])
+ await wait_all_tasks_blocked()
+ # In the cancel case, we in particular want to make sure that the
+ # cancelled tasks don't release the semaphore. So let's wait until
+ # at least one of them has exited, and that everything has had a
+ # chance to settle down from this, before we check that everyone
+ # who's supposed to be waiting is waiting:
+ if cancel:
+ print("waiting for first cancellation to clear")
+ await events[0].wait()
+ await wait_all_tasks_blocked()
+ # Then wait until the first MAX threads are parked in gate.wait(),
+ # and the next MAX threads are parked on the semaphore, to make
+ # sure no-one is sneaking past, and to make sure the high_water
+ # check below won't fail due to scheduling issues. (It could still
+ # fail if too many threads are let through here.)
+ while state.parked != MAX or c.statistics().tasks_waiting != MAX:
+ await sleep(0.01) # pragma: no cover
+ # Then release the threads
+ gate.set()
+
+ assert state.high_water == MAX
+
+ if cancel:
+ # Some threads might still be running; need to wait to them to
+ # finish before checking that all threads ran. We can do this
+ # using the CapacityLimiter.
+ while c.borrowed_tokens > 0:
+ await sleep(0.01) # pragma: no cover
+
+ assert state.ran == COUNT
+ assert state.running == 0
+ finally:
+ c.total_tokens = orig_total_tokens
+
+
+async def test_run_in_worker_thread_custom_limiter():
+ # Basically just checking that we only call acquire_on_behalf_of and
+ # release_on_behalf_of, since that's part of our documented API.
+ record = []
+
+ class CustomLimiter:
+ async def acquire_on_behalf_of(self, borrower):
+ record.append("acquire")
+ self._borrower = borrower
+
+ def release_on_behalf_of(self, borrower):
+ record.append("release")
+ assert borrower == self._borrower
+
+ await to_thread_run_sync(lambda: None, limiter=CustomLimiter())
+ assert record == ["acquire", "release"]
+
+
+async def test_run_in_worker_thread_limiter_error():
+ record = []
+
+ class BadCapacityLimiter:
+ async def acquire_on_behalf_of(self, borrower):
+ record.append("acquire")
+
+ def release_on_behalf_of(self, borrower):
+ record.append("release")
+ raise ValueError
+
+ bs = BadCapacityLimiter()
+
+ with pytest.raises(ValueError) as excinfo:
+ await to_thread_run_sync(lambda: None, limiter=bs)
+ assert excinfo.value.__context__ is None
+ assert record == ["acquire", "release"]
+ record = []
+
+ # If the original function raised an error, then the semaphore error
+ # chains with it
+ d = {}
+ with pytest.raises(ValueError) as excinfo:
+ await to_thread_run_sync(lambda: d["x"], limiter=bs)
+ assert isinstance(excinfo.value.__context__, KeyError)
+ assert record == ["acquire", "release"]
+
+
+async def test_run_in_worker_thread_fail_to_spawn(monkeypatch):
+ # Test the unlikely but possible case where trying to spawn a thread fails
+ def bad_start(self, *args):
+ raise RuntimeError("the engines canna take it captain")
+
+ monkeypatch.setattr(_core._thread_cache.ThreadCache, "start_thread_soon", bad_start)
+
+ limiter = current_default_thread_limiter()
+ assert limiter.borrowed_tokens == 0
+
+ # We get an appropriate error, and the limiter is cleanly released
+ with pytest.raises(RuntimeError) as excinfo:
+ await to_thread_run_sync(lambda: None) # pragma: no cover
+ assert "engines" in str(excinfo.value)
+
+ assert limiter.borrowed_tokens == 0
+
+
+async def test_trio_to_thread_run_sync_token():
+ # Test that to_thread_run_sync automatically injects the current trio token
+ # into a spawned thread
+ def thread_fn():
+ callee_token = from_thread_run_sync(_core.current_trio_token)
+ return callee_token
+
+ caller_token = _core.current_trio_token()
+ callee_token = await to_thread_run_sync(thread_fn)
+ assert callee_token == caller_token
+
+
+async def test_trio_to_thread_run_sync_expected_error():
+ # Test correct error when passed async function
+ async def async_fn(): # pragma: no cover
+ pass
+
+ with pytest.raises(TypeError, match="expected a sync function"):
+ await to_thread_run_sync(async_fn)
+
+
+trio_test_contextvar = contextvars.ContextVar("trio_test_contextvar")
+
+
+async def test_trio_to_thread_run_sync_contextvars():
+ trio_thread = threading.current_thread()
+ trio_test_contextvar.set("main")
+
+ def f():
+ value = trio_test_contextvar.get()
+ sniffio_cvar_value = current_async_library_cvar.get()
+ return (value, sniffio_cvar_value, threading.current_thread())
+
+ value, sniffio_cvar_value, child_thread = await to_thread_run_sync(f)
+ assert value == "main"
+ assert sniffio_cvar_value == None
+ assert child_thread != trio_thread
+
+ def g():
+ parent_value = trio_test_contextvar.get()
+ trio_test_contextvar.set("worker")
+ inner_value = trio_test_contextvar.get()
+ sniffio_cvar_value = current_async_library_cvar.get()
+ return (
+ parent_value,
+ inner_value,
+ sniffio_cvar_value,
+ threading.current_thread(),
+ )
+
+ (
+ parent_value,
+ inner_value,
+ sniffio_cvar_value,
+ child_thread,
+ ) = await to_thread_run_sync(g)
+ current_value = trio_test_contextvar.get()
+ sniffio_outer_value = current_async_library_cvar.get()
+ assert parent_value == "main"
+ assert inner_value == "worker"
+ assert (
+ current_value == "main"
+ ), "The contextvar value set on the worker would not propagate back to the main thread"
+ assert sniffio_cvar_value is None
+ assert sniffio_outer_value == "trio"
+
+
+async def test_trio_from_thread_run_sync():
+ # Test that to_thread_run_sync correctly "hands off" the trio token to
+ # trio.from_thread.run_sync()
+ def thread_fn():
+ trio_time = from_thread_run_sync(_core.current_time)
+ return trio_time
+
+ trio_time = await to_thread_run_sync(thread_fn)
+ assert isinstance(trio_time, float)
+
+ # Test correct error when passed async function
+ async def async_fn(): # pragma: no cover
+ pass
+
+ def thread_fn():
+ from_thread_run_sync(async_fn)
+
+ with pytest.raises(TypeError, match="expected a sync function"):
+ await to_thread_run_sync(thread_fn)
+
+
+async def test_trio_from_thread_run():
+ # Test that to_thread_run_sync correctly "hands off" the trio token to
+ # trio.from_thread.run()
+ record = []
+
+ async def back_in_trio_fn():
+ _core.current_time() # implicitly checks that we're in trio
+ record.append("back in trio")
+
+ def thread_fn():
+ record.append("in thread")
+ from_thread_run(back_in_trio_fn)
+
+ await to_thread_run_sync(thread_fn)
+ assert record == ["in thread", "back in trio"]
+
+ # Test correct error when passed sync function
+ def sync_fn(): # pragma: no cover
+ pass
+
+ with pytest.raises(TypeError, match="appears to be synchronous"):
+ await to_thread_run_sync(from_thread_run, sync_fn)
+
+
+async def test_trio_from_thread_token():
+ # Test that to_thread_run_sync and spawned trio.from_thread.run_sync()
+ # share the same Trio token
+ def thread_fn():
+ callee_token = from_thread_run_sync(_core.current_trio_token)
+ return callee_token
+
+ caller_token = _core.current_trio_token()
+ callee_token = await to_thread_run_sync(thread_fn)
+ assert callee_token == caller_token
+
+
+async def test_trio_from_thread_token_kwarg():
+ # Test that to_thread_run_sync and spawned trio.from_thread.run_sync() can
+ # use an explicitly defined token
+ def thread_fn(token):
+ callee_token = from_thread_run_sync(_core.current_trio_token, trio_token=token)
+ return callee_token
+
+ caller_token = _core.current_trio_token()
+ callee_token = await to_thread_run_sync(thread_fn, caller_token)
+ assert callee_token == caller_token
+
+
+async def test_from_thread_no_token():
+ # Test that a "raw call" to trio.from_thread.run() fails because no token
+ # has been provided
+
+ with pytest.raises(RuntimeError):
+ from_thread_run_sync(_core.current_time)
+
+
+async def test_trio_from_thread_run_sync_contextvars():
+ trio_test_contextvar.set("main")
+
+ def thread_fn():
+ thread_parent_value = trio_test_contextvar.get()
+ trio_test_contextvar.set("worker")
+ thread_current_value = trio_test_contextvar.get()
+ sniffio_cvar_thread_pre_value = current_async_library_cvar.get()
+
+ def back_in_main():
+ back_parent_value = trio_test_contextvar.get()
+ trio_test_contextvar.set("back_in_main")
+ back_current_value = trio_test_contextvar.get()
+ sniffio_cvar_back_value = current_async_library_cvar.get()
+ return back_parent_value, back_current_value, sniffio_cvar_back_value
+
+ (
+ back_parent_value,
+ back_current_value,
+ sniffio_cvar_back_value,
+ ) = from_thread_run_sync(back_in_main)
+ thread_after_value = trio_test_contextvar.get()
+ sniffio_cvar_thread_after_value = current_async_library_cvar.get()
+ return (
+ thread_parent_value,
+ thread_current_value,
+ thread_after_value,
+ sniffio_cvar_thread_pre_value,
+ sniffio_cvar_thread_after_value,
+ back_parent_value,
+ back_current_value,
+ sniffio_cvar_back_value,
+ )
+
+ (
+ thread_parent_value,
+ thread_current_value,
+ thread_after_value,
+ sniffio_cvar_thread_pre_value,
+ sniffio_cvar_thread_after_value,
+ back_parent_value,
+ back_current_value,
+ sniffio_cvar_back_value,
+ ) = await to_thread_run_sync(thread_fn)
+ current_value = trio_test_contextvar.get()
+ sniffio_cvar_out_value = current_async_library_cvar.get()
+ assert current_value == thread_parent_value == "main"
+ assert thread_current_value == back_parent_value == thread_after_value == "worker"
+ assert back_current_value == "back_in_main"
+ assert sniffio_cvar_out_value == sniffio_cvar_back_value == "trio"
+ assert sniffio_cvar_thread_pre_value == sniffio_cvar_thread_after_value == None
+
+
+async def test_trio_from_thread_run_contextvars():
+ trio_test_contextvar.set("main")
+
+ def thread_fn():
+ thread_parent_value = trio_test_contextvar.get()
+ trio_test_contextvar.set("worker")
+ thread_current_value = trio_test_contextvar.get()
+ sniffio_cvar_thread_pre_value = current_async_library_cvar.get()
+
+ async def async_back_in_main():
+ back_parent_value = trio_test_contextvar.get()
+ trio_test_contextvar.set("back_in_main")
+ back_current_value = trio_test_contextvar.get()
+ sniffio_cvar_back_value = current_async_library_cvar.get()
+ return back_parent_value, back_current_value, sniffio_cvar_back_value
+
+ (
+ back_parent_value,
+ back_current_value,
+ sniffio_cvar_back_value,
+ ) = from_thread_run(async_back_in_main)
+ thread_after_value = trio_test_contextvar.get()
+ sniffio_cvar_thread_after_value = current_async_library_cvar.get()
+ return (
+ thread_parent_value,
+ thread_current_value,
+ thread_after_value,
+ sniffio_cvar_thread_pre_value,
+ sniffio_cvar_thread_after_value,
+ back_parent_value,
+ back_current_value,
+ sniffio_cvar_back_value,
+ )
+
+ (
+ thread_parent_value,
+ thread_current_value,
+ thread_after_value,
+ sniffio_cvar_thread_pre_value,
+ sniffio_cvar_thread_after_value,
+ back_parent_value,
+ back_current_value,
+ sniffio_cvar_back_value,
+ ) = await to_thread_run_sync(thread_fn)
+ current_value = trio_test_contextvar.get()
+ assert current_value == thread_parent_value == "main"
+ assert thread_current_value == back_parent_value == thread_after_value == "worker"
+ assert back_current_value == "back_in_main"
+ assert sniffio_cvar_thread_pre_value == sniffio_cvar_thread_after_value == None
+ assert sniffio_cvar_back_value == "trio"
+
+
+def test_run_fn_as_system_task_catched_badly_typed_token():
+ with pytest.raises(RuntimeError):
+ from_thread_run_sync(_core.current_time, trio_token="Not TrioTokentype")
+
+
+async def test_from_thread_inside_trio_thread():
+ def not_called(): # pragma: no cover
+ assert False
+
+ trio_token = _core.current_trio_token()
+ with pytest.raises(RuntimeError):
+ from_thread_run_sync(not_called, trio_token=trio_token)
+
+
+@pytest.mark.skipif(buggy_pypy_asyncgens, reason="pypy 7.2.0 is buggy")
+def test_from_thread_run_during_shutdown():
+ save = []
+ record = []
+
+ async def agen():
+ try:
+ yield
+ finally:
+ with pytest.raises(_core.RunFinishedError), _core.CancelScope(shield=True):
+ await to_thread_run_sync(from_thread_run, sleep, 0)
+ record.append("ok")
+
+ async def main():
+ save.append(agen())
+ await save[-1].asend(None)
+
+ _core.run(main)
+ assert record == ["ok"]
+
+
+async def test_trio_token_weak_referenceable():
+ token = current_trio_token()
+ assert isinstance(token, TrioToken)
+ weak_reference = weakref.ref(token)
+ assert token is weak_reference()
+
+
+async def test_unsafe_cancellable_kwarg():
+
+ # This is a stand in for a numpy ndarray or other objects
+ # that (maybe surprisingly) lack a notion of truthiness
+ class BadBool:
+ def __bool__(self):
+ raise NotImplementedError
+
+ with pytest.raises(NotImplementedError):
+ await to_thread_run_sync(int, cancellable=BadBool())
diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_timeouts.py b/venv/lib/python3.9/site-packages/trio/tests/test_timeouts.py
new file mode 100644
index 00000000..382c015b
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/trio/tests/test_timeouts.py
@@ -0,0 +1,104 @@
+import outcome
+import pytest
+import time
+
+from .._core.tests.tutil import slow
+from .. import _core
+from ..testing import assert_checkpoints
+from .._timeouts import *
+
+
+async def check_takes_about(f, expected_dur):
+ start = time.perf_counter()
+ result = await outcome.acapture(f)
+ dur = time.perf_counter() - start
+ print(dur / expected_dur)
+ # 1.5 is an arbitrary fudge factor because there's always some delay
+ # between when we become eligible to wake up and when we actually do. We
+ # used to sleep for 0.05, and regularly observed overruns of 1.6x on
+ # Appveyor, and then started seeing overruns of 2.3x on Travis's macOS, so
+ # now we bumped up the sleep to 1 second, marked the tests as slow, and
+ # hopefully now the proportional error will be less huge.
+ #
+ # We also also for durations that are a hair shorter than expected. For
+ # example, here's a run on Windows where a 1.0 second sleep was measured
+ # to take 0.9999999999999858 seconds:
+ # https://ci.appveyor.com/project/njsmith/trio/build/1.0.768/job/3lbdyxl63q3h9s21
+ # I believe that what happened here is that Windows's low clock resolution
+ # meant that our calls to time.monotonic() returned exactly the same
+ # values as the calls inside the actual run loop, but the two subtractions
+ # returned slightly different values because the run loop's clock adds a
+ # random floating point offset to both times, which should cancel out, but
+ # lol floating point we got slightly different rounding errors. (That
+ # value above is exactly 128 ULPs below 1.0, which would make sense if it
+ # started as a 1 ULP error at a different dynamic range.)
+ assert (1 - 1e-8) <= (dur / expected_dur) < 1.5
+ return result.unwrap()
+
+
+# How long to (attempt to) sleep for when testing. Smaller numbers make the
+# test suite go faster.
+TARGET = 1.0
+
+
+@slow
+async def test_sleep():
+ async def sleep_1():
+ await sleep_until(_core.current_time() + TARGET)
+
+ await check_takes_about(sleep_1, TARGET)
+
+ async def sleep_2():
+ await sleep(TARGET)
+
+ await check_takes_about(sleep_2, TARGET)
+
+ with pytest.raises(ValueError):
+ await sleep(-1)
+
+ with assert_checkpoints():
+ await sleep(0)
+ # This also serves as a test of the trivial move_on_at
+ with move_on_at(_core.current_time()):
+ with pytest.raises(_core.Cancelled):
+ await sleep(0)
+
+
+@slow
+async def test_move_on_after():
+ with pytest.raises(ValueError):
+ with move_on_after(-1):
+ pass # pragma: no cover
+
+ async def sleep_3():
+ with move_on_after(TARGET):
+ await sleep(100)
+
+ await check_takes_about(sleep_3, TARGET)
+
+
+@slow
+async def test_fail():
+ async def sleep_4():
+ with fail_at(_core.current_time() + TARGET):
+ await sleep(100)
+
+ with pytest.raises(TooSlowError):
+ await check_takes_about(sleep_4, TARGET)
+
+ with fail_at(_core.current_time() + 100):
+ await sleep(0)
+
+ async def sleep_5():
+ with fail_after(TARGET):
+ await sleep(100)
+
+ with pytest.raises(TooSlowError):
+ await check_takes_about(sleep_5, TARGET)
+
+ with fail_after(100):
+ await sleep(0)
+
+ with pytest.raises(ValueError):
+ with fail_after(-1):
+ pass # pragma: no cover
diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_unix_pipes.py b/venv/lib/python3.9/site-packages/trio/tests/test_unix_pipes.py
new file mode 100644
index 00000000..cf98942e
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/trio/tests/test_unix_pipes.py
@@ -0,0 +1,276 @@
+import errno
+import select
+import os
+import tempfile
+import sys
+
+import pytest
+
+from .._core.tests.tutil import gc_collect_harder, skip_if_fbsd_pipes_broken
+from .. import _core, move_on_after
+from ..testing import wait_all_tasks_blocked, check_one_way_stream
+
+posix = os.name == "posix"
+pytestmark = pytest.mark.skipif(not posix, reason="posix only")
+if posix:
+ from .._unix_pipes import FdStream
+else:
+ with pytest.raises(ImportError):
+ from .._unix_pipes import FdStream
+
+
+# Have to use quoted types so import doesn't crash on windows
+async def make_pipe() -> "Tuple[FdStream, FdStream]":
+ """Makes a new pair of pipes."""
+ (r, w) = os.pipe()
+ return FdStream(w), FdStream(r)
+
+
+async def make_clogged_pipe():
+ s, r = await make_pipe()
+ try:
+ while True:
+ # We want to totally fill up the pipe buffer.
+ # This requires working around a weird feature that POSIX pipes
+ # have.
+ # If you do a write of <= PIPE_BUF bytes, then it's guaranteed
+ # to either complete entirely, or not at all. So if we tried to
+ # write PIPE_BUF bytes, and the buffer's free space is only
+ # PIPE_BUF/2, then the write will raise BlockingIOError... even
+ # though a smaller write could still succeed! To avoid this,
+ # make sure to write >PIPE_BUF bytes each time, which disables
+ # the special behavior.
+ # For details, search for PIPE_BUF here:
+ # http://pubs.opengroup.org/onlinepubs/9699919799/functions/write.html
+
+ # for the getattr:
+ # https://bitbucket.org/pypy/pypy/issues/2876/selectpipe_buf-is-missing-on-pypy3
+ buf_size = getattr(select, "PIPE_BUF", 8192)
+ os.write(s.fileno(), b"x" * buf_size * 2)
+ except BlockingIOError:
+ pass
+ return s, r
+
+
+async def test_send_pipe():
+ r, w = os.pipe()
+ async with FdStream(w) as send:
+ assert send.fileno() == w
+ await send.send_all(b"123")
+ assert (os.read(r, 8)) == b"123"
+
+ os.close(r)
+
+
+async def test_receive_pipe():
+ r, w = os.pipe()
+ async with FdStream(r) as recv:
+ assert (recv.fileno()) == r
+ os.write(w, b"123")
+ assert (await recv.receive_some(8)) == b"123"
+
+ os.close(w)
+
+
+async def test_pipes_combined():
+ write, read = await make_pipe()
+ count = 2**20
+
+ async def sender():
+ big = bytearray(count)
+ await write.send_all(big)
+
+ async def reader():
+ await wait_all_tasks_blocked()
+ received = 0
+ while received < count:
+ received += len(await read.receive_some(4096))
+
+ assert received == count
+
+ async with _core.open_nursery() as n:
+ n.start_soon(sender)
+ n.start_soon(reader)
+
+ await read.aclose()
+ await write.aclose()
+
+
+async def test_pipe_errors():
+ with pytest.raises(TypeError):
+ FdStream(None)
+
+ r, w = os.pipe()
+ os.close(w)
+ async with FdStream(r) as s:
+ with pytest.raises(ValueError):
+ await s.receive_some(0)
+
+
+async def test_del():
+ w, r = await make_pipe()
+ f1, f2 = w.fileno(), r.fileno()
+ del w, r
+ gc_collect_harder()
+
+ with pytest.raises(OSError) as excinfo:
+ os.close(f1)
+ assert excinfo.value.errno == errno.EBADF
+
+ with pytest.raises(OSError) as excinfo:
+ os.close(f2)
+ assert excinfo.value.errno == errno.EBADF
+
+
+async def test_async_with():
+ w, r = await make_pipe()
+ async with w, r:
+ pass
+
+ assert w.fileno() == -1
+ assert r.fileno() == -1
+
+ with pytest.raises(OSError) as excinfo:
+ os.close(w.fileno())
+ assert excinfo.value.errno == errno.EBADF
+
+ with pytest.raises(OSError) as excinfo:
+ os.close(r.fileno())
+ assert excinfo.value.errno == errno.EBADF
+
+
+async def test_misdirected_aclose_regression():
+ # https://github.com/python-trio/trio/issues/661#issuecomment-456582356
+ w, r = await make_pipe()
+ old_r_fd = r.fileno()
+
+ # Close the original objects
+ await w.aclose()
+ await r.aclose()
+
+ # Do a little dance to get a new pipe whose receive handle matches the old
+ # receive handle.
+ r2_fd, w2_fd = os.pipe()
+ if r2_fd != old_r_fd: # pragma: no cover
+ os.dup2(r2_fd, old_r_fd)
+ os.close(r2_fd)
+ async with FdStream(old_r_fd) as r2:
+ assert r2.fileno() == old_r_fd
+
+ # And now set up a background task that's working on the new receive
+ # handle
+ async def expect_eof():
+ assert await r2.receive_some(10) == b""
+
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(expect_eof)
+ await wait_all_tasks_blocked()
+
+ # Here's the key test: does calling aclose() again on the *old*
+ # handle, cause the task blocked on the *new* handle to raise
+ # ClosedResourceError?
+ await r.aclose()
+ await wait_all_tasks_blocked()
+
+ # Guess we survived! Close the new write handle so that the task
+ # gets an EOF and can exit cleanly.
+ os.close(w2_fd)
+
+
+async def test_close_at_bad_time_for_receive_some(monkeypatch):
+ # We used to have race conditions where if one task was using the pipe,
+ # and another closed it at *just* the wrong moment, it would give an
+ # unexpected error instead of ClosedResourceError:
+ # https://github.com/python-trio/trio/issues/661
+ #
+ # This tests what happens if the pipe gets closed in the moment *between*
+ # when receive_some wakes up, and when it tries to call os.read
+ async def expect_closedresourceerror():
+ with pytest.raises(_core.ClosedResourceError):
+ await r.receive_some(10)
+
+ orig_wait_readable = _core._run.TheIOManager.wait_readable
+
+ async def patched_wait_readable(*args, **kwargs):
+ await orig_wait_readable(*args, **kwargs)
+ await r.aclose()
+
+ monkeypatch.setattr(_core._run.TheIOManager, "wait_readable", patched_wait_readable)
+ s, r = await make_pipe()
+ async with s, r:
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(expect_closedresourceerror)
+ await wait_all_tasks_blocked()
+ # Trigger everything by waking up the receiver
+ await s.send_all(b"x")
+
+
+async def test_close_at_bad_time_for_send_all(monkeypatch):
+ # We used to have race conditions where if one task was using the pipe,
+ # and another closed it at *just* the wrong moment, it would give an
+ # unexpected error instead of ClosedResourceError:
+ # https://github.com/python-trio/trio/issues/661
+ #
+ # This tests what happens if the pipe gets closed in the moment *between*
+ # when send_all wakes up, and when it tries to call os.write
+ async def expect_closedresourceerror():
+ with pytest.raises(_core.ClosedResourceError):
+ await s.send_all(b"x" * 100)
+
+ orig_wait_writable = _core._run.TheIOManager.wait_writable
+
+ async def patched_wait_writable(*args, **kwargs):
+ await orig_wait_writable(*args, **kwargs)
+ await s.aclose()
+
+ monkeypatch.setattr(_core._run.TheIOManager, "wait_writable", patched_wait_writable)
+ s, r = await make_clogged_pipe()
+ async with s, r:
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(expect_closedresourceerror)
+ await wait_all_tasks_blocked()
+ # Trigger everything by waking up the sender. On ppc64el, PIPE_BUF
+ # is 8192 but make_clogged_pipe() ends up writing a total of
+ # 1048576 bytes before the pipe is full, and then a subsequent
+ # receive_some(10000) isn't sufficient for orig_wait_writable() to
+ # return for our subsequent aclose() call. It's necessary to empty
+ # the pipe further before this happens. So we loop here until the
+ # pipe is empty to make sure that the sender wakes up even in this
+ # case. Otherwise patched_wait_writable() never gets to the
+ # aclose(), so expect_closedresourceerror() never returns, the
+ # nursery never finishes all tasks and this test hangs.
+ received_data = await r.receive_some(10000)
+ while received_data:
+ received_data = await r.receive_some(10000)
+
+
+# On FreeBSD, directories are readable, and we haven't found any other trick
+# for making an unreadable fd, so there's no way to run this test. Fortunately
+# the logic this is testing doesn't depend on the platform, so testing on
+# other platforms is probably good enough.
+@pytest.mark.skipif(
+ sys.platform.startswith("freebsd"),
+ reason="no way to make read() return a bizarro error on FreeBSD",
+)
+async def test_bizarro_OSError_from_receive():
+ # Make sure that if the read syscall returns some bizarro error, then we
+ # get a BrokenResourceError. This is incredibly unlikely; there's almost
+ # no way to trigger a failure here intentionally (except for EBADF, but we
+ # exploit that to detect file closure, so it takes a different path). So
+ # we set up a strange scenario where the pipe fd somehow transmutes into a
+ # directory fd, causing os.read to raise IsADirectoryError (yes, that's a
+ # real built-in exception type).
+ s, r = await make_pipe()
+ async with s, r:
+ dir_fd = os.open("/", os.O_DIRECTORY, 0)
+ try:
+ os.dup2(dir_fd, r.fileno())
+ with pytest.raises(_core.BrokenResourceError):
+ await r.receive_some(10)
+ finally:
+ os.close(dir_fd)
+
+
+@skip_if_fbsd_pipes_broken
+async def test_pipe_fully():
+ await check_one_way_stream(make_pipe, make_clogged_pipe)
diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_util.py b/venv/lib/python3.9/site-packages/trio/tests/test_util.py
new file mode 100644
index 00000000..15ab09a8
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/trio/tests/test_util.py
@@ -0,0 +1,193 @@
+import signal
+import sys
+
+import pytest
+
+import trio
+from .. import _core
+from .._core.tests.tutil import (
+ ignore_coroutine_never_awaited_warnings,
+ create_asyncio_future_in_new_loop,
+)
+from .._util import (
+ signal_raise,
+ ConflictDetector,
+ is_main_thread,
+ coroutine_or_error,
+ generic_function,
+ Final,
+ NoPublicConstructor,
+)
+from ..testing import wait_all_tasks_blocked
+
+
+def test_signal_raise():
+ record = []
+
+ def handler(signum, _):
+ record.append(signum)
+
+ old = signal.signal(signal.SIGFPE, handler)
+ try:
+ signal_raise(signal.SIGFPE)
+ finally:
+ signal.signal(signal.SIGFPE, old)
+ assert record == [signal.SIGFPE]
+
+
+async def test_ConflictDetector():
+ ul1 = ConflictDetector("ul1")
+ ul2 = ConflictDetector("ul2")
+
+ with ul1:
+ with ul2:
+ print("ok")
+
+ with pytest.raises(_core.BusyResourceError) as excinfo:
+ with ul1:
+ with ul1:
+ pass # pragma: no cover
+ assert "ul1" in str(excinfo.value)
+
+ async def wait_with_ul1():
+ with ul1:
+ await wait_all_tasks_blocked()
+
+ with pytest.raises(_core.BusyResourceError) as excinfo:
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(wait_with_ul1)
+ nursery.start_soon(wait_with_ul1)
+ assert "ul1" in str(excinfo.value)
+
+
+def test_module_metadata_is_fixed_up():
+ import trio
+ import trio.testing
+
+ assert trio.Cancelled.__module__ == "trio"
+ assert trio.open_nursery.__module__ == "trio"
+ assert trio.abc.Stream.__module__ == "trio.abc"
+ assert trio.lowlevel.wait_task_rescheduled.__module__ == "trio.lowlevel"
+ assert trio.testing.trio_test.__module__ == "trio.testing"
+
+ # Also check methods
+ assert trio.lowlevel.ParkingLot.__init__.__module__ == "trio.lowlevel"
+ assert trio.abc.Stream.send_all.__module__ == "trio.abc"
+
+ # And names
+ assert trio.Cancelled.__name__ == "Cancelled"
+ assert trio.Cancelled.__qualname__ == "Cancelled"
+ assert trio.abc.SendStream.send_all.__name__ == "send_all"
+ assert trio.abc.SendStream.send_all.__qualname__ == "SendStream.send_all"
+ assert trio.to_thread.__name__ == "trio.to_thread"
+ assert trio.to_thread.run_sync.__name__ == "run_sync"
+ assert trio.to_thread.run_sync.__qualname__ == "run_sync"
+
+
+async def test_is_main_thread():
+ assert is_main_thread()
+
+ def not_main_thread():
+ assert not is_main_thread()
+
+ await trio.to_thread.run_sync(not_main_thread)
+
+
+# @coroutine is deprecated since python 3.8, which is fine with us.
+@pytest.mark.filterwarnings("ignore:.*@coroutine.*:DeprecationWarning")
+def test_coroutine_or_error():
+ class Deferred:
+ "Just kidding"
+
+ with ignore_coroutine_never_awaited_warnings():
+
+ async def f(): # pragma: no cover
+ pass
+
+ with pytest.raises(TypeError) as excinfo:
+ coroutine_or_error(f())
+ assert "expecting an async function" in str(excinfo.value)
+
+ import asyncio
+
+ if sys.version_info < (3, 11):
+
+ @asyncio.coroutine
+ def generator_based_coro(): # pragma: no cover
+ yield from asyncio.sleep(1)
+
+ with pytest.raises(TypeError) as excinfo:
+ coroutine_or_error(generator_based_coro())
+ assert "asyncio" in str(excinfo.value)
+
+ with pytest.raises(TypeError) as excinfo:
+ coroutine_or_error(create_asyncio_future_in_new_loop())
+ assert "asyncio" in str(excinfo.value)
+
+ with pytest.raises(TypeError) as excinfo:
+ coroutine_or_error(create_asyncio_future_in_new_loop)
+ assert "asyncio" in str(excinfo.value)
+
+ with pytest.raises(TypeError) as excinfo:
+ coroutine_or_error(Deferred())
+ assert "twisted" in str(excinfo.value)
+
+ with pytest.raises(TypeError) as excinfo:
+ coroutine_or_error(lambda: Deferred())
+ assert "twisted" in str(excinfo.value)
+
+ with pytest.raises(TypeError) as excinfo:
+ coroutine_or_error(len, [[1, 2, 3]])
+
+ assert "appears to be synchronous" in str(excinfo.value)
+
+ async def async_gen(arg): # pragma: no cover
+ yield
+
+ with pytest.raises(TypeError) as excinfo:
+ coroutine_or_error(async_gen, [0])
+ msg = "expected an async function but got an async generator"
+ assert msg in str(excinfo.value)
+
+ # Make sure no references are kept around to keep anything alive
+ del excinfo
+
+
+def test_generic_function():
+ @generic_function
+ def test_func(arg):
+ """Look, a docstring!"""
+ return arg
+
+ assert test_func is test_func[int] is test_func[int, str]
+ assert test_func(42) == test_func[int](42) == 42
+ assert test_func.__doc__ == "Look, a docstring!"
+ assert test_func.__qualname__ == "test_generic_function.<locals>.test_func"
+ assert test_func.__name__ == "test_func"
+ assert test_func.__module__ == __name__
+
+
+def test_final_metaclass():
+ class FinalClass(metaclass=Final):
+ pass
+
+ with pytest.raises(TypeError):
+
+ class SubClass(FinalClass):
+ pass
+
+
+def test_no_public_constructor_metaclass():
+ class SpecialClass(metaclass=NoPublicConstructor):
+ pass
+
+ with pytest.raises(TypeError):
+ SpecialClass()
+
+ with pytest.raises(TypeError):
+
+ class SubClass(SpecialClass):
+ pass
+
+ # Private constructor should not raise
+ assert isinstance(SpecialClass._create(), SpecialClass)
diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_wait_for_object.py b/venv/lib/python3.9/site-packages/trio/tests/test_wait_for_object.py
new file mode 100644
index 00000000..38acfa80
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/trio/tests/test_wait_for_object.py
@@ -0,0 +1,220 @@
+import os
+
+import pytest
+
+on_windows = os.name == "nt"
+# Mark all the tests in this file as being windows-only
+pytestmark = pytest.mark.skipif(not on_windows, reason="windows only")
+
+from .._core.tests.tutil import slow
+import trio
+from .. import _core
+from .. import _timeouts
+
+if on_windows:
+ from .._core._windows_cffi import ffi, kernel32
+ from .._wait_for_object import (
+ WaitForSingleObject,
+ WaitForMultipleObjects_sync,
+ )
+
+
+async def test_WaitForMultipleObjects_sync():
+ # This does a series of tests where we set/close the handle before
+ # initiating the waiting for it.
+ #
+ # Note that closing the handle (not signaling) will cause the
+ # *initiation* of a wait to return immediately. But closing a handle
+ # that is already being waited on will not stop whatever is waiting
+ # for it.
+
+ # One handle
+ handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
+ kernel32.SetEvent(handle1)
+ WaitForMultipleObjects_sync(handle1)
+ kernel32.CloseHandle(handle1)
+ print("test_WaitForMultipleObjects_sync one OK")
+
+ # Two handles, signal first
+ handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
+ handle2 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
+ kernel32.SetEvent(handle1)
+ WaitForMultipleObjects_sync(handle1, handle2)
+ kernel32.CloseHandle(handle1)
+ kernel32.CloseHandle(handle2)
+ print("test_WaitForMultipleObjects_sync set first OK")
+
+ # Two handles, signal second
+ handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
+ handle2 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
+ kernel32.SetEvent(handle2)
+ WaitForMultipleObjects_sync(handle1, handle2)
+ kernel32.CloseHandle(handle1)
+ kernel32.CloseHandle(handle2)
+ print("test_WaitForMultipleObjects_sync set second OK")
+
+ # Two handles, close first
+ handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
+ handle2 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
+ kernel32.CloseHandle(handle1)
+ with pytest.raises(OSError):
+ WaitForMultipleObjects_sync(handle1, handle2)
+ kernel32.CloseHandle(handle2)
+ print("test_WaitForMultipleObjects_sync close first OK")
+
+ # Two handles, close second
+ handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
+ handle2 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
+ kernel32.CloseHandle(handle2)
+ with pytest.raises(OSError):
+ WaitForMultipleObjects_sync(handle1, handle2)
+ kernel32.CloseHandle(handle1)
+ print("test_WaitForMultipleObjects_sync close second OK")
+
+
+@slow
+async def test_WaitForMultipleObjects_sync_slow():
+ # This does a series of test in which the main thread sync-waits for
+ # handles, while we spawn a thread to set the handles after a short while.
+
+ TIMEOUT = 0.3
+
+ # One handle
+ handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
+ t0 = _core.current_time()
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(
+ trio.to_thread.run_sync, WaitForMultipleObjects_sync, handle1
+ )
+ await _timeouts.sleep(TIMEOUT)
+ # If we would comment the line below, the above thread will be stuck,
+ # and Trio won't exit this scope
+ kernel32.SetEvent(handle1)
+ t1 = _core.current_time()
+ assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT
+ kernel32.CloseHandle(handle1)
+ print("test_WaitForMultipleObjects_sync_slow one OK")
+
+ # Two handles, signal first
+ handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
+ handle2 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
+ t0 = _core.current_time()
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(
+ trio.to_thread.run_sync, WaitForMultipleObjects_sync, handle1, handle2
+ )
+ await _timeouts.sleep(TIMEOUT)
+ kernel32.SetEvent(handle1)
+ t1 = _core.current_time()
+ assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT
+ kernel32.CloseHandle(handle1)
+ kernel32.CloseHandle(handle2)
+ print("test_WaitForMultipleObjects_sync_slow thread-set first OK")
+
+ # Two handles, signal second
+ handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
+ handle2 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
+ t0 = _core.current_time()
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(
+ trio.to_thread.run_sync, WaitForMultipleObjects_sync, handle1, handle2
+ )
+ await _timeouts.sleep(TIMEOUT)
+ kernel32.SetEvent(handle2)
+ t1 = _core.current_time()
+ assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT
+ kernel32.CloseHandle(handle1)
+ kernel32.CloseHandle(handle2)
+ print("test_WaitForMultipleObjects_sync_slow thread-set second OK")
+
+
+async def test_WaitForSingleObject():
+ # This does a series of test for setting/closing the handle before
+ # initiating the wait.
+
+ # Test already set
+ handle = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
+ kernel32.SetEvent(handle)
+ await WaitForSingleObject(handle) # should return at once
+ kernel32.CloseHandle(handle)
+ print("test_WaitForSingleObject already set OK")
+
+ # Test already set, as int
+ handle = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
+ handle_int = int(ffi.cast("intptr_t", handle))
+ kernel32.SetEvent(handle)
+ await WaitForSingleObject(handle_int) # should return at once
+ kernel32.CloseHandle(handle)
+ print("test_WaitForSingleObject already set OK")
+
+ # Test already closed
+ handle = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
+ kernel32.CloseHandle(handle)
+ with pytest.raises(OSError):
+ await WaitForSingleObject(handle) # should return at once
+ print("test_WaitForSingleObject already closed OK")
+
+ # Not a handle
+ with pytest.raises(TypeError):
+ await WaitForSingleObject("not a handle") # Wrong type
+ # with pytest.raises(OSError):
+ # await WaitForSingleObject(99) # If you're unlucky, it actually IS a handle :(
+ print("test_WaitForSingleObject not a handle OK")
+
+
+@slow
+async def test_WaitForSingleObject_slow():
+ # This does a series of test for setting the handle in another task,
+ # and cancelling the wait task.
+
+ # Set the timeout used in the tests. We test the waiting time against
+ # the timeout with a certain margin.
+ TIMEOUT = 0.3
+
+ async def signal_soon_async(handle):
+ await _timeouts.sleep(TIMEOUT)
+ kernel32.SetEvent(handle)
+
+ # Test handle is SET after TIMEOUT in separate coroutine
+
+ handle = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
+ t0 = _core.current_time()
+
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(WaitForSingleObject, handle)
+ nursery.start_soon(signal_soon_async, handle)
+
+ kernel32.CloseHandle(handle)
+ t1 = _core.current_time()
+ assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT
+ print("test_WaitForSingleObject_slow set from task OK")
+
+ # Test handle is SET after TIMEOUT in separate coroutine, as int
+
+ handle = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
+ handle_int = int(ffi.cast("intptr_t", handle))
+ t0 = _core.current_time()
+
+ async with _core.open_nursery() as nursery:
+ nursery.start_soon(WaitForSingleObject, handle_int)
+ nursery.start_soon(signal_soon_async, handle)
+
+ kernel32.CloseHandle(handle)
+ t1 = _core.current_time()
+ assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT
+ print("test_WaitForSingleObject_slow set from task as int OK")
+
+ # Test handle is CLOSED after 1 sec - NOPE see comment above
+
+ # Test cancellation
+
+ handle = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
+ t0 = _core.current_time()
+
+ with _timeouts.move_on_after(TIMEOUT):
+ await WaitForSingleObject(handle)
+
+ kernel32.CloseHandle(handle)
+ t1 = _core.current_time()
+ assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT
+ print("test_WaitForSingleObject_slow cancellation OK")
diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_windows_pipes.py b/venv/lib/python3.9/site-packages/trio/tests/test_windows_pipes.py
new file mode 100644
index 00000000..0a6b3516
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/trio/tests/test_windows_pipes.py
@@ -0,0 +1,110 @@
+import errno
+import select
+
+import os
+import sys
+import pytest
+
+from .._core.tests.tutil import gc_collect_harder
+from .. import _core, move_on_after
+from ..testing import wait_all_tasks_blocked, check_one_way_stream
+
+if sys.platform == "win32":
+ from .._windows_pipes import PipeSendStream, PipeReceiveStream
+ from .._core._windows_cffi import _handle, kernel32
+ from asyncio.windows_utils import pipe
+else:
+ pytestmark = pytest.mark.skip(reason="windows only")
+ pipe = None # type: Any
+ PipeSendStream = None # type: Any
+ PipeReceiveStream = None # type: Any
+
+
+async def make_pipe() -> "Tuple[PipeSendStream, PipeReceiveStream]":
+ """Makes a new pair of pipes."""
+ (r, w) = pipe()
+ return PipeSendStream(w), PipeReceiveStream(r)
+
+
+async def test_pipe_typecheck():
+ with pytest.raises(TypeError):
+ PipeSendStream(1.0)
+ with pytest.raises(TypeError):
+ PipeReceiveStream(None)
+
+
+async def test_pipe_error_on_close():
+ # Make sure we correctly handle a failure from kernel32.CloseHandle
+ r, w = pipe()
+
+ send_stream = PipeSendStream(w)
+ receive_stream = PipeReceiveStream(r)
+
+ assert kernel32.CloseHandle(_handle(r))
+ assert kernel32.CloseHandle(_handle(w))
+
+ with pytest.raises(OSError):
+ await send_stream.aclose()
+ with pytest.raises(OSError):
+ await receive_stream.aclose()
+
+
+async def test_pipes_combined():
+ write, read = await make_pipe()
+ count = 2**20
+ replicas = 3
+
+ async def sender():
+ async with write:
+ big = bytearray(count)
+ for _ in range(replicas):
+ await write.send_all(big)
+
+ async def reader():
+ async with read:
+ await wait_all_tasks_blocked()
+ total_received = 0
+ while True:
+ # 5000 is chosen because it doesn't evenly divide 2**20
+ received = len(await read.receive_some(5000))
+ if not received:
+ break
+ total_received += received
+
+ assert total_received == count * replicas
+
+ async with _core.open_nursery() as n:
+ n.start_soon(sender)
+ n.start_soon(reader)
+
+
+async def test_async_with():
+ w, r = await make_pipe()
+ async with w, r:
+ pass
+
+ with pytest.raises(_core.ClosedResourceError):
+ await w.send_all(b"")
+ with pytest.raises(_core.ClosedResourceError):
+ await r.receive_some(10)
+
+
+async def test_close_during_write():
+ w, r = await make_pipe()
+ async with _core.open_nursery() as nursery:
+
+ async def write_forever():
+ with pytest.raises(_core.ClosedResourceError) as excinfo:
+ while True:
+ await w.send_all(b"x" * 4096)
+ assert "another task" in str(excinfo.value)
+
+ nursery.start_soon(write_forever)
+ await wait_all_tasks_blocked(0.1)
+ await w.aclose()
+
+
+async def test_pipe_fully():
+ # passing make_clogged_pipe tests wait_send_all_might_not_block, and we
+ # can't implement that on Windows
+ await check_one_way_stream(make_pipe, None)
diff --git a/venv/lib/python3.9/site-packages/trio/tests/tools/__init__.py b/venv/lib/python3.9/site-packages/trio/tests/tools/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/trio/tests/tools/__init__.py
diff --git a/venv/lib/python3.9/site-packages/trio/tests/tools/test_gen_exports.py b/venv/lib/python3.9/site-packages/trio/tests/tools/test_gen_exports.py
new file mode 100644
index 00000000..e4e388c2
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/trio/tests/tools/test_gen_exports.py
@@ -0,0 +1,72 @@
+import ast
+import astor
+import pytest
+import os
+import sys
+
+from shutil import copyfile
+from trio._tools.gen_exports import (
+ get_public_methods,
+ create_passthrough_args,
+ process,
+)
+
+SOURCE = '''from _run import _public
+
+class Test:
+ @_public
+ def public_func(self):
+ """With doc string"""
+
+ @ignore_this
+ @_public
+ @another_decorator
+ async def public_async_func(self):
+ pass # no doc string
+
+ def not_public(self):
+ pass
+
+ async def not_public_async(self):
+ pass
+'''
+
+
+def test_get_public_methods():
+ methods = list(get_public_methods(ast.parse(SOURCE)))
+ assert {m.name for m in methods} == {"public_func", "public_async_func"}
+
+
+def test_create_pass_through_args():
+ testcases = [
+ ("def f()", "()"),
+ ("def f(one)", "(one)"),
+ ("def f(one, two)", "(one, two)"),
+ ("def f(one, *args)", "(one, *args)"),
+ (
+ "def f(one, *args, kw1, kw2=None, **kwargs)",
+ "(one, *args, kw1=kw1, kw2=kw2, **kwargs)",
+ ),
+ ]
+
+ for (funcdef, expected) in testcases:
+ func_node = ast.parse(funcdef + ":\n pass").body[0]
+ assert isinstance(func_node, ast.FunctionDef)
+ assert create_passthrough_args(func_node) == expected
+
+
+def test_process(tmp_path):
+ modpath = tmp_path / "_module.py"
+ genpath = tmp_path / "_generated_module.py"
+ modpath.write_text(SOURCE, encoding="utf-8")
+ assert not genpath.exists()
+ with pytest.raises(SystemExit) as excinfo:
+ process([(str(modpath), "runner")], do_test=True)
+ assert excinfo.value.code == 1
+ process([(str(modpath), "runner")], do_test=False)
+ assert genpath.exists()
+ process([(str(modpath), "runner")], do_test=True)
+ # But if we change the lookup path it notices
+ with pytest.raises(SystemExit) as excinfo:
+ process([(str(modpath), "runner.io_manager")], do_test=True)
+ assert excinfo.value.code == 1