From 355dee533bb34a571b9367820a63cccb668cf866 Mon Sep 17 00:00:00 2001 From: noptuno Date: Thu, 27 Apr 2023 20:29:30 -0400 Subject: Merging PR_218 openai_rev package with new streamlit chat app --- .../python3.9/site-packages/trio/tests/__init__.py | 0 .../python3.9/site-packages/trio/tests/conftest.py | 41 + .../trio/tests/module_with_deprecations.py | 21 + .../python3.9/site-packages/trio/tests/test_abc.py | 49 + .../site-packages/trio/tests/test_channel.py | 407 ++++++ .../site-packages/trio/tests/test_contextvars.py | 52 + .../site-packages/trio/tests/test_deprecate.py | 243 ++++ .../site-packages/trio/tests/test_dtls.py | 867 +++++++++++++ .../site-packages/trio/tests/test_exports.py | 145 +++ .../site-packages/trio/tests/test_fakenet.py | 44 + .../site-packages/trio/tests/test_file_io.py | 198 +++ .../trio/tests/test_highlevel_generic.py | 94 ++ .../tests/test_highlevel_open_tcp_listeners.py | 300 +++++ .../trio/tests/test_highlevel_open_tcp_stream.py | 574 +++++++++ .../trio/tests/test_highlevel_open_unix_stream.py | 67 + .../trio/tests/test_highlevel_serve_listeners.py | 145 +++ .../trio/tests/test_highlevel_socket.py | 267 ++++ .../trio/tests/test_highlevel_ssl_helpers.py | 113 ++ .../site-packages/trio/tests/test_path.py | 262 ++++ .../trio/tests/test_scheduler_determinism.py | 40 + .../site-packages/trio/tests/test_signals.py | 177 +++ .../site-packages/trio/tests/test_socket.py | 1017 +++++++++++++++ .../python3.9/site-packages/trio/tests/test_ssl.py | 1303 ++++++++++++++++++++ .../site-packages/trio/tests/test_subprocess.py | 602 +++++++++ .../site-packages/trio/tests/test_sync.py | 567 +++++++++ .../site-packages/trio/tests/test_testing.py | 657 ++++++++++ .../site-packages/trio/tests/test_threads.py | 752 +++++++++++ .../site-packages/trio/tests/test_timeouts.py | 104 ++ .../site-packages/trio/tests/test_unix_pipes.py | 276 +++++ .../site-packages/trio/tests/test_util.py | 193 +++ .../trio/tests/test_wait_for_object.py | 220 ++++ .../site-packages/trio/tests/test_windows_pipes.py | 110 ++ .../site-packages/trio/tests/tools/__init__.py | 0 .../trio/tests/tools/test_gen_exports.py | 72 ++ 34 files changed, 9979 insertions(+) create mode 100644 venv/lib/python3.9/site-packages/trio/tests/__init__.py create mode 100644 venv/lib/python3.9/site-packages/trio/tests/conftest.py create mode 100644 venv/lib/python3.9/site-packages/trio/tests/module_with_deprecations.py create mode 100644 venv/lib/python3.9/site-packages/trio/tests/test_abc.py create mode 100644 venv/lib/python3.9/site-packages/trio/tests/test_channel.py create mode 100644 venv/lib/python3.9/site-packages/trio/tests/test_contextvars.py create mode 100644 venv/lib/python3.9/site-packages/trio/tests/test_deprecate.py create mode 100644 venv/lib/python3.9/site-packages/trio/tests/test_dtls.py create mode 100644 venv/lib/python3.9/site-packages/trio/tests/test_exports.py create mode 100644 venv/lib/python3.9/site-packages/trio/tests/test_fakenet.py create mode 100644 venv/lib/python3.9/site-packages/trio/tests/test_file_io.py create mode 100644 venv/lib/python3.9/site-packages/trio/tests/test_highlevel_generic.py create mode 100644 venv/lib/python3.9/site-packages/trio/tests/test_highlevel_open_tcp_listeners.py create mode 100644 venv/lib/python3.9/site-packages/trio/tests/test_highlevel_open_tcp_stream.py create mode 100644 venv/lib/python3.9/site-packages/trio/tests/test_highlevel_open_unix_stream.py create mode 100644 venv/lib/python3.9/site-packages/trio/tests/test_highlevel_serve_listeners.py create mode 100644 venv/lib/python3.9/site-packages/trio/tests/test_highlevel_socket.py create mode 100644 venv/lib/python3.9/site-packages/trio/tests/test_highlevel_ssl_helpers.py create mode 100644 venv/lib/python3.9/site-packages/trio/tests/test_path.py create mode 100644 venv/lib/python3.9/site-packages/trio/tests/test_scheduler_determinism.py create mode 100644 venv/lib/python3.9/site-packages/trio/tests/test_signals.py create mode 100644 venv/lib/python3.9/site-packages/trio/tests/test_socket.py create mode 100644 venv/lib/python3.9/site-packages/trio/tests/test_ssl.py create mode 100644 venv/lib/python3.9/site-packages/trio/tests/test_subprocess.py create mode 100644 venv/lib/python3.9/site-packages/trio/tests/test_sync.py create mode 100644 venv/lib/python3.9/site-packages/trio/tests/test_testing.py create mode 100644 venv/lib/python3.9/site-packages/trio/tests/test_threads.py create mode 100644 venv/lib/python3.9/site-packages/trio/tests/test_timeouts.py create mode 100644 venv/lib/python3.9/site-packages/trio/tests/test_unix_pipes.py create mode 100644 venv/lib/python3.9/site-packages/trio/tests/test_util.py create mode 100644 venv/lib/python3.9/site-packages/trio/tests/test_wait_for_object.py create mode 100644 venv/lib/python3.9/site-packages/trio/tests/test_windows_pipes.py create mode 100644 venv/lib/python3.9/site-packages/trio/tests/tools/__init__.py create mode 100644 venv/lib/python3.9/site-packages/trio/tests/tools/test_gen_exports.py (limited to 'venv/lib/python3.9/site-packages/trio/tests') diff --git a/venv/lib/python3.9/site-packages/trio/tests/__init__.py b/venv/lib/python3.9/site-packages/trio/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/venv/lib/python3.9/site-packages/trio/tests/conftest.py b/venv/lib/python3.9/site-packages/trio/tests/conftest.py new file mode 100644 index 00000000..772486e1 --- /dev/null +++ b/venv/lib/python3.9/site-packages/trio/tests/conftest.py @@ -0,0 +1,41 @@ +# XX this does not belong here -- b/c it's here, these things only apply to +# the tests in trio/_core/tests, not in trio/tests. For now there's some +# copy-paste... +# +# this stuff should become a proper pytest plugin + +import pytest +import inspect + +from ..testing import trio_test, MockClock + +RUN_SLOW = True + + +def pytest_addoption(parser): + parser.addoption("--run-slow", action="store_true", help="run slow tests") + + +def pytest_configure(config): + global RUN_SLOW + RUN_SLOW = config.getoption("--run-slow", True) + + +@pytest.fixture +def mock_clock(): + return MockClock() + + +@pytest.fixture +def autojump_clock(): + return MockClock(autojump_threshold=0) + + +# FIXME: split off into a package (or just make part of Trio's public +# interface?), with config file to enable? and I guess a mark option too; I +# guess it's useful with the class- and file-level marking machinery (where +# the raw @trio_test decorator isn't enough). +@pytest.hookimpl(tryfirst=True) +def pytest_pyfunc_call(pyfuncitem): + if inspect.iscoroutinefunction(pyfuncitem.obj): + pyfuncitem.obj = trio_test(pyfuncitem.obj) diff --git a/venv/lib/python3.9/site-packages/trio/tests/module_with_deprecations.py b/venv/lib/python3.9/site-packages/trio/tests/module_with_deprecations.py new file mode 100644 index 00000000..73184d11 --- /dev/null +++ b/venv/lib/python3.9/site-packages/trio/tests/module_with_deprecations.py @@ -0,0 +1,21 @@ +regular = "hi" + +from .. import _deprecate + +_deprecate.enable_attribute_deprecations(__name__) + +# Make sure that we don't trigger infinite recursion when accessing module +# attributes in between calling enable_attribute_deprecations and defining +# __deprecated_attributes__: +import sys + +this_mod = sys.modules[__name__] +assert this_mod.regular == "hi" +assert not hasattr(this_mod, "dep1") + +__deprecated_attributes__ = { + "dep1": _deprecate.DeprecatedAttribute("value1", "1.1", issue=1), + "dep2": _deprecate.DeprecatedAttribute( + "value2", "1.2", issue=1, instead="instead-string" + ), +} diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_abc.py b/venv/lib/python3.9/site-packages/trio/tests/test_abc.py new file mode 100644 index 00000000..c445c971 --- /dev/null +++ b/venv/lib/python3.9/site-packages/trio/tests/test_abc.py @@ -0,0 +1,49 @@ +import pytest + +import attr + +from ..testing import assert_checkpoints +from .. import abc as tabc + + +async def test_AsyncResource_defaults(): + @attr.s + class MyAR(tabc.AsyncResource): + record = attr.ib(factory=list) + + async def aclose(self): + self.record.append("ac") + + async with MyAR() as myar: + assert isinstance(myar, MyAR) + assert myar.record == [] + + assert myar.record == ["ac"] + + +def test_abc_generics(): + # Pythons below 3.5.2 had a typing.Generic that would throw + # errors when instantiating or subclassing a parameterized + # version of a class with any __slots__. This is why RunVar + # (which has slots) is not generic. This tests that + # the generic ABCs are fine, because while they are slotted + # they don't actually define any slots. + + class SlottedChannel(tabc.SendChannel[tabc.Stream]): + __slots__ = ("x",) + + def send_nowait(self, value): + raise RuntimeError + + async def send(self, value): + raise RuntimeError # pragma: no cover + + def clone(self): + raise RuntimeError # pragma: no cover + + async def aclose(self): + pass # pragma: no cover + + channel = SlottedChannel() + with pytest.raises(RuntimeError): + channel.send_nowait(None) diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_channel.py b/venv/lib/python3.9/site-packages/trio/tests/test_channel.py new file mode 100644 index 00000000..fd990fb3 --- /dev/null +++ b/venv/lib/python3.9/site-packages/trio/tests/test_channel.py @@ -0,0 +1,407 @@ +import pytest + +from ..testing import wait_all_tasks_blocked, assert_checkpoints +import trio +from trio import open_memory_channel, EndOfChannel + + +async def test_channel(): + with pytest.raises(TypeError): + open_memory_channel(1.0) + with pytest.raises(ValueError): + open_memory_channel(-1) + + s, r = open_memory_channel(2) + repr(s) # smoke test + repr(r) # smoke test + + s.send_nowait(1) + with assert_checkpoints(): + await s.send(2) + with pytest.raises(trio.WouldBlock): + s.send_nowait(None) + + with assert_checkpoints(): + assert await r.receive() == 1 + assert r.receive_nowait() == 2 + with pytest.raises(trio.WouldBlock): + r.receive_nowait() + + s.send_nowait("last") + await s.aclose() + with pytest.raises(trio.ClosedResourceError): + await s.send("too late") + with pytest.raises(trio.ClosedResourceError): + s.send_nowait("too late") + with pytest.raises(trio.ClosedResourceError): + s.clone() + await s.aclose() + + assert r.receive_nowait() == "last" + with pytest.raises(EndOfChannel): + await r.receive() + await r.aclose() + with pytest.raises(trio.ClosedResourceError): + await r.receive() + with pytest.raises(trio.ClosedResourceError): + await r.receive_nowait() + await r.aclose() + + +async def test_553(autojump_clock): + s, r = open_memory_channel(1) + with trio.move_on_after(10) as timeout_scope: + await r.receive() + assert timeout_scope.cancelled_caught + await s.send("Test for PR #553") + + +async def test_channel_multiple_producers(): + async def producer(send_channel, i): + # We close our handle when we're done with it + async with send_channel: + for j in range(3 * i, 3 * (i + 1)): + await send_channel.send(j) + + send_channel, receive_channel = open_memory_channel(0) + async with trio.open_nursery() as nursery: + # We hand out clones to all the new producers, and then close the + # original. + async with send_channel: + for i in range(10): + nursery.start_soon(producer, send_channel.clone(), i) + + got = [] + async for value in receive_channel: + got.append(value) + + got.sort() + assert got == list(range(30)) + + +async def test_channel_multiple_consumers(): + successful_receivers = set() + received = [] + + async def consumer(receive_channel, i): + async for value in receive_channel: + successful_receivers.add(i) + received.append(value) + + async with trio.open_nursery() as nursery: + send_channel, receive_channel = trio.open_memory_channel(1) + async with send_channel: + for i in range(5): + nursery.start_soon(consumer, receive_channel, i) + await wait_all_tasks_blocked() + for i in range(10): + await send_channel.send(i) + + assert successful_receivers == set(range(5)) + assert len(received) == 10 + assert set(received) == set(range(10)) + + +async def test_close_basics(): + async def send_block(s, expect): + with pytest.raises(expect): + await s.send(None) + + # closing send -> other send gets ClosedResourceError + s, r = open_memory_channel(0) + async with trio.open_nursery() as nursery: + nursery.start_soon(send_block, s, trio.ClosedResourceError) + await wait_all_tasks_blocked() + await s.aclose() + + # and it's persistent + with pytest.raises(trio.ClosedResourceError): + s.send_nowait(None) + with pytest.raises(trio.ClosedResourceError): + await s.send(None) + + # and receive gets EndOfChannel + with pytest.raises(EndOfChannel): + r.receive_nowait() + with pytest.raises(EndOfChannel): + await r.receive() + + # closing receive -> send gets BrokenResourceError + s, r = open_memory_channel(0) + async with trio.open_nursery() as nursery: + nursery.start_soon(send_block, s, trio.BrokenResourceError) + await wait_all_tasks_blocked() + await r.aclose() + + # and it's persistent + with pytest.raises(trio.BrokenResourceError): + s.send_nowait(None) + with pytest.raises(trio.BrokenResourceError): + await s.send(None) + + # closing receive -> other receive gets ClosedResourceError + async def receive_block(r): + with pytest.raises(trio.ClosedResourceError): + await r.receive() + + s, r = open_memory_channel(0) + async with trio.open_nursery() as nursery: + nursery.start_soon(receive_block, r) + await wait_all_tasks_blocked() + await r.aclose() + + # and it's persistent + with pytest.raises(trio.ClosedResourceError): + r.receive_nowait() + with pytest.raises(trio.ClosedResourceError): + await r.receive() + + +async def test_close_sync(): + async def send_block(s, expect): + with pytest.raises(expect): + await s.send(None) + + # closing send -> other send gets ClosedResourceError + s, r = open_memory_channel(0) + async with trio.open_nursery() as nursery: + nursery.start_soon(send_block, s, trio.ClosedResourceError) + await wait_all_tasks_blocked() + s.close() + + # and it's persistent + with pytest.raises(trio.ClosedResourceError): + s.send_nowait(None) + with pytest.raises(trio.ClosedResourceError): + await s.send(None) + + # and receive gets EndOfChannel + with pytest.raises(EndOfChannel): + r.receive_nowait() + with pytest.raises(EndOfChannel): + await r.receive() + + # closing receive -> send gets BrokenResourceError + s, r = open_memory_channel(0) + async with trio.open_nursery() as nursery: + nursery.start_soon(send_block, s, trio.BrokenResourceError) + await wait_all_tasks_blocked() + r.close() + + # and it's persistent + with pytest.raises(trio.BrokenResourceError): + s.send_nowait(None) + with pytest.raises(trio.BrokenResourceError): + await s.send(None) + + # closing receive -> other receive gets ClosedResourceError + async def receive_block(r): + with pytest.raises(trio.ClosedResourceError): + await r.receive() + + s, r = open_memory_channel(0) + async with trio.open_nursery() as nursery: + nursery.start_soon(receive_block, r) + await wait_all_tasks_blocked() + r.close() + + # and it's persistent + with pytest.raises(trio.ClosedResourceError): + r.receive_nowait() + with pytest.raises(trio.ClosedResourceError): + await r.receive() + + +async def test_receive_channel_clone_and_close(): + s, r = open_memory_channel(10) + + r2 = r.clone() + r3 = r.clone() + + s.send_nowait(None) + await r.aclose() + with r2: + pass + + with pytest.raises(trio.ClosedResourceError): + r.clone() + + with pytest.raises(trio.ClosedResourceError): + r2.clone() + + # Can still send, r3 is still open + s.send_nowait(None) + + await r3.aclose() + + # But now the receiver is really closed + with pytest.raises(trio.BrokenResourceError): + s.send_nowait(None) + + +async def test_close_multiple_send_handles(): + # With multiple send handles, closing one handle only wakes senders on + # that handle, but others can continue just fine + s1, r = open_memory_channel(0) + s2 = s1.clone() + + async def send_will_close(): + with pytest.raises(trio.ClosedResourceError): + await s1.send("nope") + + async def send_will_succeed(): + await s2.send("ok") + + async with trio.open_nursery() as nursery: + nursery.start_soon(send_will_close) + nursery.start_soon(send_will_succeed) + await wait_all_tasks_blocked() + await s1.aclose() + assert await r.receive() == "ok" + + +async def test_close_multiple_receive_handles(): + # With multiple receive handles, closing one handle only wakes receivers on + # that handle, but others can continue just fine + s, r1 = open_memory_channel(0) + r2 = r1.clone() + + async def receive_will_close(): + with pytest.raises(trio.ClosedResourceError): + await r1.receive() + + async def receive_will_succeed(): + assert await r2.receive() == "ok" + + async with trio.open_nursery() as nursery: + nursery.start_soon(receive_will_close) + nursery.start_soon(receive_will_succeed) + await wait_all_tasks_blocked() + await r1.aclose() + await s.send("ok") + + +async def test_inf_capacity(): + s, r = open_memory_channel(float("inf")) + + # It's accepted, and we can send all day without blocking + with s: + for i in range(10): + s.send_nowait(i) + + got = [] + async for i in r: + got.append(i) + assert got == list(range(10)) + + +async def test_statistics(): + s, r = open_memory_channel(2) + + assert s.statistics() == r.statistics() + stats = s.statistics() + assert stats.current_buffer_used == 0 + assert stats.max_buffer_size == 2 + assert stats.open_send_channels == 1 + assert stats.open_receive_channels == 1 + assert stats.tasks_waiting_send == 0 + assert stats.tasks_waiting_receive == 0 + + s.send_nowait(None) + assert s.statistics().current_buffer_used == 1 + + s2 = s.clone() + assert s.statistics().open_send_channels == 2 + await s.aclose() + assert s2.statistics().open_send_channels == 1 + + r2 = r.clone() + assert s2.statistics().open_receive_channels == 2 + await r2.aclose() + assert s2.statistics().open_receive_channels == 1 + + async with trio.open_nursery() as nursery: + s2.send_nowait(None) # fill up the buffer + assert s.statistics().current_buffer_used == 2 + nursery.start_soon(s2.send, None) + nursery.start_soon(s2.send, None) + await wait_all_tasks_blocked() + assert s.statistics().tasks_waiting_send == 2 + nursery.cancel_scope.cancel() + assert s.statistics().tasks_waiting_send == 0 + + # empty out the buffer again + try: + while True: + r.receive_nowait() + except trio.WouldBlock: + pass + + async with trio.open_nursery() as nursery: + nursery.start_soon(r.receive) + await wait_all_tasks_blocked() + assert s.statistics().tasks_waiting_receive == 1 + nursery.cancel_scope.cancel() + assert s.statistics().tasks_waiting_receive == 0 + + +async def test_channel_fairness(): + + # We can remove an item we just sent, and send an item back in after, if + # no-one else is waiting. + s, r = open_memory_channel(1) + s.send_nowait(1) + assert r.receive_nowait() == 1 + s.send_nowait(2) + assert r.receive_nowait() == 2 + + # But if someone else is waiting to receive, then they "own" the item we + # send, so we can't receive it (even though we run first): + + result = None + + async def do_receive(r): + nonlocal result + result = await r.receive() + + async with trio.open_nursery() as nursery: + nursery.start_soon(do_receive, r) + await wait_all_tasks_blocked() + s.send_nowait(2) + with pytest.raises(trio.WouldBlock): + r.receive_nowait() + assert result == 2 + + # And the analogous situation for send: if we free up a space, we can't + # immediately send something in it if someone is already waiting to do + # that + s, r = open_memory_channel(1) + s.send_nowait(1) + with pytest.raises(trio.WouldBlock): + s.send_nowait(None) + async with trio.open_nursery() as nursery: + nursery.start_soon(s.send, 2) + await wait_all_tasks_blocked() + assert r.receive_nowait() == 1 + with pytest.raises(trio.WouldBlock): + s.send_nowait(3) + assert (await r.receive()) == 2 + + +async def test_unbuffered(): + s, r = open_memory_channel(0) + with pytest.raises(trio.WouldBlock): + r.receive_nowait() + with pytest.raises(trio.WouldBlock): + s.send_nowait(1) + + async def do_send(s, v): + with assert_checkpoints(): + await s.send(v) + + async with trio.open_nursery() as nursery: + nursery.start_soon(do_send, s, 1) + with assert_checkpoints(): + assert await r.receive() == 1 + with pytest.raises(trio.WouldBlock): + r.receive_nowait() diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_contextvars.py b/venv/lib/python3.9/site-packages/trio/tests/test_contextvars.py new file mode 100644 index 00000000..63853f51 --- /dev/null +++ b/venv/lib/python3.9/site-packages/trio/tests/test_contextvars.py @@ -0,0 +1,52 @@ +import contextvars + +from .. import _core + +trio_testing_contextvar = contextvars.ContextVar("trio_testing_contextvar") + + +async def test_contextvars_default(): + trio_testing_contextvar.set("main") + record = [] + + async def child(): + value = trio_testing_contextvar.get() + record.append(value) + + async with _core.open_nursery() as nursery: + nursery.start_soon(child) + assert record == ["main"] + + +async def test_contextvars_set(): + trio_testing_contextvar.set("main") + record = [] + + async def child(): + trio_testing_contextvar.set("child") + value = trio_testing_contextvar.get() + record.append(value) + + async with _core.open_nursery() as nursery: + nursery.start_soon(child) + value = trio_testing_contextvar.get() + assert record == ["child"] + assert value == "main" + + +async def test_contextvars_copy(): + trio_testing_contextvar.set("main") + context = contextvars.copy_context() + trio_testing_contextvar.set("second_main") + record = [] + + async def child(): + value = trio_testing_contextvar.get() + record.append(value) + + async with _core.open_nursery() as nursery: + context.run(nursery.start_soon, child) + nursery.start_soon(child) + value = trio_testing_contextvar.get() + assert set(record) == {"main", "second_main"} + assert value == "second_main" diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_deprecate.py b/venv/lib/python3.9/site-packages/trio/tests/test_deprecate.py new file mode 100644 index 00000000..e5e1da8c --- /dev/null +++ b/venv/lib/python3.9/site-packages/trio/tests/test_deprecate.py @@ -0,0 +1,243 @@ +import pytest + +import inspect +import warnings + +from .._deprecate import ( + TrioDeprecationWarning, + warn_deprecated, + deprecated, + deprecated_alias, +) + +from . import module_with_deprecations + + +@pytest.fixture +def recwarn_always(recwarn): + warnings.simplefilter("always") + # ResourceWarnings about unclosed sockets can occur nondeterministically + # (during GC) which throws off the tests in this file + warnings.simplefilter("ignore", ResourceWarning) + return recwarn + + +def _here(): + info = inspect.getframeinfo(inspect.currentframe().f_back) + return (info.filename, info.lineno) + + +def test_warn_deprecated(recwarn_always): + def deprecated_thing(): + warn_deprecated("ice", "1.2", issue=1, instead="water") + + deprecated_thing() + filename, lineno = _here() + assert len(recwarn_always) == 1 + got = recwarn_always.pop(TrioDeprecationWarning) + assert "ice is deprecated" in got.message.args[0] + assert "Trio 1.2" in got.message.args[0] + assert "water instead" in got.message.args[0] + assert "/issues/1" in got.message.args[0] + assert got.filename == filename + assert got.lineno == lineno - 1 + + +def test_warn_deprecated_no_instead_or_issue(recwarn_always): + # Explicitly no instead or issue + warn_deprecated("water", "1.3", issue=None, instead=None) + assert len(recwarn_always) == 1 + got = recwarn_always.pop(TrioDeprecationWarning) + assert "water is deprecated" in got.message.args[0] + assert "no replacement" in got.message.args[0] + assert "Trio 1.3" in got.message.args[0] + + +def test_warn_deprecated_stacklevel(recwarn_always): + def nested1(): + nested2() + + def nested2(): + warn_deprecated("x", "1.3", issue=7, instead="y", stacklevel=3) + + filename, lineno = _here() + nested1() + got = recwarn_always.pop(TrioDeprecationWarning) + assert got.filename == filename + assert got.lineno == lineno + 1 + + +def old(): # pragma: no cover + pass + + +def new(): # pragma: no cover + pass + + +def test_warn_deprecated_formatting(recwarn_always): + warn_deprecated(old, "1.0", issue=1, instead=new) + got = recwarn_always.pop(TrioDeprecationWarning) + assert "test_deprecate.old is deprecated" in got.message.args[0] + assert "test_deprecate.new instead" in got.message.args[0] + + +@deprecated("1.5", issue=123, instead=new) +def deprecated_old(): + return 3 + + +def test_deprecated_decorator(recwarn_always): + assert deprecated_old() == 3 + got = recwarn_always.pop(TrioDeprecationWarning) + assert "test_deprecate.deprecated_old is deprecated" in got.message.args[0] + assert "1.5" in got.message.args[0] + assert "test_deprecate.new" in got.message.args[0] + assert "issues/123" in got.message.args[0] + + +class Foo: + @deprecated("1.0", issue=123, instead="crying") + def method(self): + return 7 + + +def test_deprecated_decorator_method(recwarn_always): + f = Foo() + assert f.method() == 7 + got = recwarn_always.pop(TrioDeprecationWarning) + assert "test_deprecate.Foo.method is deprecated" in got.message.args[0] + + +@deprecated("1.2", thing="the thing", issue=None, instead=None) +def deprecated_with_thing(): + return 72 + + +def test_deprecated_decorator_with_explicit_thing(recwarn_always): + assert deprecated_with_thing() == 72 + got = recwarn_always.pop(TrioDeprecationWarning) + assert "the thing is deprecated" in got.message.args[0] + + +def new_hotness(): + return "new hotness" + + +old_hotness = deprecated_alias("old_hotness", new_hotness, "1.23", issue=1) + + +def test_deprecated_alias(recwarn_always): + assert old_hotness() == "new hotness" + got = recwarn_always.pop(TrioDeprecationWarning) + assert "test_deprecate.old_hotness is deprecated" in got.message.args[0] + assert "1.23" in got.message.args[0] + assert "test_deprecate.new_hotness instead" in got.message.args[0] + assert "issues/1" in got.message.args[0] + + assert ".. deprecated:: 1.23" in old_hotness.__doc__ + assert "test_deprecate.new_hotness instead" in old_hotness.__doc__ + assert "issues/1>`__" in old_hotness.__doc__ + + +class Alias: + def new_hotness_method(self): + return "new hotness method" + + old_hotness_method = deprecated_alias( + "Alias.old_hotness_method", new_hotness_method, "3.21", issue=1 + ) + + +def test_deprecated_alias_method(recwarn_always): + obj = Alias() + assert obj.old_hotness_method() == "new hotness method" + got = recwarn_always.pop(TrioDeprecationWarning) + msg = got.message.args[0] + assert "test_deprecate.Alias.old_hotness_method is deprecated" in msg + assert "test_deprecate.Alias.new_hotness_method instead" in msg + + +@deprecated("2.1", issue=1, instead="hi") +def docstring_test1(): # pragma: no cover + """Hello!""" + + +@deprecated("2.1", issue=None, instead="hi") +def docstring_test2(): # pragma: no cover + """Hello!""" + + +@deprecated("2.1", issue=1, instead=None) +def docstring_test3(): # pragma: no cover + """Hello!""" + + +@deprecated("2.1", issue=None, instead=None) +def docstring_test4(): # pragma: no cover + """Hello!""" + + +def test_deprecated_docstring_munging(): + assert ( + docstring_test1.__doc__ + == """Hello! + +.. deprecated:: 2.1 + Use hi instead. + For details, see `issue #1 `__. + +""" + ) + + assert ( + docstring_test2.__doc__ + == """Hello! + +.. deprecated:: 2.1 + Use hi instead. + +""" + ) + + assert ( + docstring_test3.__doc__ + == """Hello! + +.. deprecated:: 2.1 + For details, see `issue #1 `__. + +""" + ) + + assert ( + docstring_test4.__doc__ + == """Hello! + +.. deprecated:: 2.1 + +""" + ) + + +def test_module_with_deprecations(recwarn_always): + assert module_with_deprecations.regular == "hi" + assert len(recwarn_always) == 0 + + filename, lineno = _here() + assert module_with_deprecations.dep1 == "value1" + got = recwarn_always.pop(TrioDeprecationWarning) + assert got.filename == filename + assert got.lineno == lineno + 1 + + assert "module_with_deprecations.dep1" in got.message.args[0] + assert "Trio 1.1" in got.message.args[0] + assert "/issues/1" in got.message.args[0] + assert "value1 instead" in got.message.args[0] + + assert module_with_deprecations.dep2 == "value2" + got = recwarn_always.pop(TrioDeprecationWarning) + assert "instead-string instead" in got.message.args[0] + + with pytest.raises(AttributeError): + module_with_deprecations.asdf diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_dtls.py b/venv/lib/python3.9/site-packages/trio/tests/test_dtls.py new file mode 100644 index 00000000..8968d9a6 --- /dev/null +++ b/venv/lib/python3.9/site-packages/trio/tests/test_dtls.py @@ -0,0 +1,867 @@ +import pytest +import trio +import trio.testing +from trio import DTLSEndpoint +import random +import attr +from async_generator import asynccontextmanager +from itertools import count + +import trustme +from OpenSSL import SSL + +from trio.testing._fake_net import FakeNet +from .._core.tests.tutil import slow, binds_ipv6, gc_collect_harder + +ca = trustme.CA() +server_cert = ca.issue_cert("example.com") + +server_ctx = SSL.Context(SSL.DTLS_METHOD) +server_cert.configure_cert(server_ctx) + +client_ctx = SSL.Context(SSL.DTLS_METHOD) +ca.configure_trust(client_ctx) + + +parametrize_ipv6 = pytest.mark.parametrize( + "ipv6", [False, pytest.param(True, marks=binds_ipv6)], ids=["ipv4", "ipv6"] +) + + +def endpoint(**kwargs): + ipv6 = kwargs.pop("ipv6", False) + if ipv6: + family = trio.socket.AF_INET6 + else: + family = trio.socket.AF_INET + sock = trio.socket.socket(type=trio.socket.SOCK_DGRAM, family=family) + return DTLSEndpoint(sock, **kwargs) + + +@asynccontextmanager +async def dtls_echo_server(*, autocancel=True, mtu=None, ipv6=False): + with endpoint(ipv6=ipv6) as server: + if ipv6: + localhost = "::1" + else: + localhost = "127.0.0.1" + await server.socket.bind((localhost, 0)) + async with trio.open_nursery() as nursery: + + async def echo_handler(dtls_channel): + print( + f"echo handler started: " + f"server {dtls_channel.endpoint.socket.getsockname()} " + f"client {dtls_channel.peer_address}" + ) + if mtu is not None: + dtls_channel.set_ciphertext_mtu(mtu) + try: + print("server starting do_handshake") + await dtls_channel.do_handshake() + print("server finished do_handshake") + async for packet in dtls_channel: + print(f"echoing {packet} -> {dtls_channel.peer_address}") + await dtls_channel.send(packet) + except trio.BrokenResourceError: # pragma: no cover + print("echo handler channel broken") + + await nursery.start(server.serve, server_ctx, echo_handler) + + yield server, server.socket.getsockname() + + if autocancel: + nursery.cancel_scope.cancel() + + +@parametrize_ipv6 +async def test_smoke(ipv6): + async with dtls_echo_server(ipv6=ipv6) as (server_endpoint, address): + with endpoint(ipv6=ipv6) as client_endpoint: + client_channel = client_endpoint.connect(address, client_ctx) + with pytest.raises(trio.NeedHandshakeError): + client_channel.get_cleartext_mtu() + + await client_channel.do_handshake() + await client_channel.send(b"hello") + assert await client_channel.receive() == b"hello" + await client_channel.send(b"goodbye") + assert await client_channel.receive() == b"goodbye" + + with pytest.raises(ValueError): + await client_channel.send(b"") + + client_channel.set_ciphertext_mtu(1234) + cleartext_mtu_1234 = client_channel.get_cleartext_mtu() + client_channel.set_ciphertext_mtu(4321) + assert client_channel.get_cleartext_mtu() > cleartext_mtu_1234 + client_channel.set_ciphertext_mtu(1234) + assert client_channel.get_cleartext_mtu() == cleartext_mtu_1234 + + +@slow +async def test_handshake_over_terrible_network(autojump_clock): + HANDSHAKES = 1000 + r = random.Random(0) + fn = FakeNet() + fn.enable() + + async with dtls_echo_server() as (_, address): + async with trio.open_nursery() as nursery: + + async def route_packet(packet): + while True: + op = r.choices( + ["deliver", "drop", "dupe", "delay"], + weights=[0.7, 0.1, 0.1, 0.1], + )[0] + print(f"{packet.source} -> {packet.destination}: {op}") + if op == "drop": + return + elif op == "dupe": + fn.send_packet(packet) + elif op == "delay": + await trio.sleep(r.random() * 3) + # I wanted to test random packet corruption too, but it turns out + # openssl has a bug in the following scenario: + # + # - client sends ClientHello + # - server sends HelloVerifyRequest with cookie -- but cookie is + # invalid b/c either the ClientHello or HelloVerifyRequest was + # corrupted + # - client re-sends ClientHello with invalid cookie + # - server replies with new HelloVerifyRequest and correct cookie + # + # At this point, the client *should* switch to the new, valid + # cookie. But OpenSSL doesn't; it stubbornly insists on re-sending + # the original, invalid cookie over and over. In theory we could + # work around this by detecting cookie changes and starting over + # with a whole new SSL object, but (a) it doesn't seem worth it, (b) + # when I tried then I ran into another issue where OpenSSL got stuck + # in an infinite loop sending alerts over and over, which I didn't + # dig into because see (a). + # + # elif op == "distort": + # payload = bytearray(packet.payload) + # payload[r.randrange(len(payload))] ^= 1 << r.randrange(8) + # packet = attr.evolve(packet, payload=payload) + else: + assert op == "deliver" + print( + f"{packet.source} -> {packet.destination}: delivered {packet.payload.hex()}" + ) + fn.deliver_packet(packet) + break + + def route_packet_wrapper(packet): + try: + nursery.start_soon(route_packet, packet) + except RuntimeError: # pragma: no cover + # We're exiting the nursery, so any remaining packets can just get + # dropped + pass + + fn.route_packet = route_packet_wrapper + + for i in range(HANDSHAKES): + print("#" * 80) + print("#" * 80) + print("#" * 80) + with endpoint() as client_endpoint: + client = client_endpoint.connect(address, client_ctx) + print("client starting do_handshake") + await client.do_handshake() + print("client finished do_handshake") + msg = str(i).encode() + # Make multiple attempts to send data, because the network might + # drop it + while True: + with trio.move_on_after(10) as cscope: + await client.send(msg) + assert await client.receive() == msg + if not cscope.cancelled_caught: + break + + +async def test_implicit_handshake(): + async with dtls_echo_server() as (_, address): + with endpoint() as client_endpoint: + client = client_endpoint.connect(address, client_ctx) + + # Implicit handshake + await client.send(b"xyz") + assert await client.receive() == b"xyz" + + +async def test_full_duplex(): + # Tests simultaneous send/receive, and also multiple methods implicitly invoking + # do_handshake simultaneously. + with endpoint() as server_endpoint, endpoint() as client_endpoint: + await server_endpoint.socket.bind(("127.0.0.1", 0)) + async with trio.open_nursery() as server_nursery: + + async def handler(channel): + async with trio.open_nursery() as nursery: + nursery.start_soon(channel.send, b"from server") + nursery.start_soon(channel.receive) + + await server_nursery.start(server_endpoint.serve, server_ctx, handler) + + client = client_endpoint.connect( + server_endpoint.socket.getsockname(), client_ctx + ) + async with trio.open_nursery() as nursery: + nursery.start_soon(client.send, b"from client") + nursery.start_soon(client.receive) + + server_nursery.cancel_scope.cancel() + + +async def test_channel_closing(): + async with dtls_echo_server() as (_, address): + with endpoint() as client_endpoint: + client = client_endpoint.connect(address, client_ctx) + await client.do_handshake() + client.close() + + with pytest.raises(trio.ClosedResourceError): + await client.send(b"abc") + with pytest.raises(trio.ClosedResourceError): + await client.receive() + + # close is idempotent + client.close() + # can also aclose + await client.aclose() + + +async def test_serve_exits_cleanly_on_close(): + async with dtls_echo_server(autocancel=False) as (server_endpoint, address): + server_endpoint.close() + # Testing that the nursery exits even without being cancelled + # close is idempotent + server_endpoint.close() + + +async def test_client_multiplex(): + async with dtls_echo_server() as (_, address1), dtls_echo_server() as (_, address2): + with endpoint() as client_endpoint: + client1 = client_endpoint.connect(address1, client_ctx) + client2 = client_endpoint.connect(address2, client_ctx) + + await client1.send(b"abc") + await client2.send(b"xyz") + assert await client2.receive() == b"xyz" + assert await client1.receive() == b"abc" + + client_endpoint.close() + + with pytest.raises(trio.ClosedResourceError): + await client1.send("xxx") + with pytest.raises(trio.ClosedResourceError): + await client2.receive() + with pytest.raises(trio.ClosedResourceError): + client_endpoint.connect(address1, client_ctx) + + async with trio.open_nursery() as nursery: + with pytest.raises(trio.ClosedResourceError): + + async def null_handler(_): # pragma: no cover + pass + + await nursery.start(client_endpoint.serve, server_ctx, null_handler) + + +async def test_dtls_over_dgram_only(): + with trio.socket.socket() as s: + with pytest.raises(ValueError): + DTLSEndpoint(s) + + +async def test_double_serve(): + async def null_handler(_): # pragma: no cover + pass + + with endpoint() as server_endpoint: + await server_endpoint.socket.bind(("127.0.0.1", 0)) + async with trio.open_nursery() as nursery: + await nursery.start(server_endpoint.serve, server_ctx, null_handler) + with pytest.raises(trio.BusyResourceError): + await nursery.start(server_endpoint.serve, server_ctx, null_handler) + + nursery.cancel_scope.cancel() + + async with trio.open_nursery() as nursery: + await nursery.start(server_endpoint.serve, server_ctx, null_handler) + nursery.cancel_scope.cancel() + + +async def test_connect_to_non_server(autojump_clock): + fn = FakeNet() + fn.enable() + with endpoint() as client1, endpoint() as client2: + await client1.socket.bind(("127.0.0.1", 0)) + # This should just time out + with trio.move_on_after(100) as cscope: + channel = client2.connect(client1.socket.getsockname(), client_ctx) + await channel.do_handshake() + assert cscope.cancelled_caught + + +async def test_incoming_buffer_overflow(autojump_clock): + fn = FakeNet() + fn.enable() + for buffer_size in [10, 20]: + async with dtls_echo_server() as (_, address): + with endpoint(incoming_packets_buffer=buffer_size) as client_endpoint: + assert client_endpoint.incoming_packets_buffer == buffer_size + client = client_endpoint.connect(address, client_ctx) + for i in range(buffer_size + 15): + await client.send(str(i).encode()) + await trio.sleep(1) + stats = client.statistics() + assert stats.incoming_packets_dropped_in_trio == 15 + for i in range(buffer_size): + assert await client.receive() == str(i).encode() + await client.send(b"buffer clear now") + assert await client.receive() == b"buffer clear now" + + +async def test_server_socket_doesnt_crash_on_garbage(autojump_clock): + fn = FakeNet() + fn.enable() + + from trio._dtls import ( + Record, + encode_record, + HandshakeFragment, + encode_handshake_fragment, + ContentType, + HandshakeType, + ProtocolVersion, + ) + + client_hello = encode_record( + Record( + content_type=ContentType.handshake, + version=ProtocolVersion.DTLS10, + epoch_seqno=0, + payload=encode_handshake_fragment( + HandshakeFragment( + msg_type=HandshakeType.client_hello, + msg_len=10, + msg_seq=0, + frag_offset=0, + frag_len=10, + frag=bytes(10), + ) + ), + ) + ) + + client_hello_extended = client_hello + b"\x00" + client_hello_short = client_hello[:-1] + # cuts off in middle of handshake message header + client_hello_really_short = client_hello[:14] + client_hello_corrupt_record_len = bytearray(client_hello) + client_hello_corrupt_record_len[11] = 0xFF + + client_hello_fragmented = encode_record( + Record( + content_type=ContentType.handshake, + version=ProtocolVersion.DTLS10, + epoch_seqno=0, + payload=encode_handshake_fragment( + HandshakeFragment( + msg_type=HandshakeType.client_hello, + msg_len=20, + msg_seq=0, + frag_offset=0, + frag_len=10, + frag=bytes(10), + ) + ), + ) + ) + + client_hello_trailing_data_in_record = encode_record( + Record( + content_type=ContentType.handshake, + version=ProtocolVersion.DTLS10, + epoch_seqno=0, + payload=encode_handshake_fragment( + HandshakeFragment( + msg_type=HandshakeType.client_hello, + msg_len=20, + msg_seq=0, + frag_offset=0, + frag_len=10, + frag=bytes(10), + ) + ) + + b"\x00", + ) + ) + + handshake_empty = encode_record( + Record( + content_type=ContentType.handshake, + version=ProtocolVersion.DTLS10, + epoch_seqno=0, + payload=b"", + ) + ) + + client_hello_truncated_in_cookie = encode_record( + Record( + content_type=ContentType.handshake, + version=ProtocolVersion.DTLS10, + epoch_seqno=0, + payload=bytes(2 + 32 + 1) + b"\xff", + ) + ) + + async with dtls_echo_server() as (_, address): + with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as sock: + for bad_packet in [ + b"", + b"xyz", + client_hello_extended, + client_hello_short, + client_hello_really_short, + client_hello_corrupt_record_len, + client_hello_fragmented, + client_hello_trailing_data_in_record, + handshake_empty, + client_hello_truncated_in_cookie, + ]: + await sock.sendto(bad_packet, address) + await trio.sleep(1) + + +async def test_invalid_cookie_rejected(autojump_clock): + fn = FakeNet() + fn.enable() + + from trio._dtls import decode_client_hello_untrusted, BadPacket + + with trio.CancelScope() as cscope: + + # the first 11 bytes of ClientHello aren't protected by the cookie, so only test + # corrupting bytes after that. + offset_to_corrupt = count(11) + + def route_packet(packet): + try: + _, cookie, _ = decode_client_hello_untrusted(packet.payload) + except BadPacket: + pass + else: + if len(cookie) != 0: + # this is a challenge response packet + # let's corrupt the next offset so the handshake should fail + payload = bytearray(packet.payload) + offset = next(offset_to_corrupt) + if offset >= len(payload): + # We've tried all offsets. Clamp offset to the end of the + # payload, and terminate the test. + offset = len(payload) - 1 + cscope.cancel() + payload[offset] ^= 0x01 + packet = attr.evolve(packet, payload=payload) + + fn.deliver_packet(packet) + + fn.route_packet = route_packet + + async with dtls_echo_server() as (_, address): + while True: + with endpoint() as client: + channel = client.connect(address, client_ctx) + await channel.do_handshake() + assert cscope.cancelled_caught + + +async def test_client_cancels_handshake_and_starts_new_one(autojump_clock): + # if a client disappears during the handshake, and then starts a new handshake from + # scratch, then the first handler's channel should fail, and a new handler get + # started + fn = FakeNet() + fn.enable() + + with endpoint() as server, endpoint() as client: + await server.socket.bind(("127.0.0.1", 0)) + async with trio.open_nursery() as nursery: + first_time = True + + async def handler(channel): + nonlocal first_time + if first_time: + first_time = False + print("handler: first time, cancelling connect") + connect_cscope.cancel() + await trio.sleep(0.5) + print("handler: handshake should fail now") + with pytest.raises(trio.BrokenResourceError): + await channel.do_handshake() + else: + print("handler: not first time, sending hello") + await channel.send(b"hello") + + await nursery.start(server.serve, server_ctx, handler) + + print("client: starting first connect") + with trio.CancelScope() as connect_cscope: + channel = client.connect(server.socket.getsockname(), client_ctx) + await channel.do_handshake() + assert connect_cscope.cancelled_caught + + print("client: starting second connect") + channel = client.connect(server.socket.getsockname(), client_ctx) + assert await channel.receive() == b"hello" + + # Give handlers a chance to finish + await trio.sleep(10) + nursery.cancel_scope.cancel() + + +async def test_swap_client_server(): + with endpoint() as a, endpoint() as b: + await a.socket.bind(("127.0.0.1", 0)) + await b.socket.bind(("127.0.0.1", 0)) + + async def echo_handler(channel): + async for packet in channel: + await channel.send(packet) + + async def crashing_echo_handler(channel): + with pytest.raises(trio.BrokenResourceError): + await echo_handler(channel) + + async with trio.open_nursery() as nursery: + await nursery.start(a.serve, server_ctx, crashing_echo_handler) + await nursery.start(b.serve, server_ctx, echo_handler) + + b_to_a = b.connect(a.socket.getsockname(), client_ctx) + await b_to_a.send(b"b as client") + assert await b_to_a.receive() == b"b as client" + + a_to_b = a.connect(b.socket.getsockname(), client_ctx) + await a_to_b.do_handshake() + with pytest.raises(trio.BrokenResourceError): + await b_to_a.send(b"association broken") + await a_to_b.send(b"a as client") + assert await a_to_b.receive() == b"a as client" + + nursery.cancel_scope.cancel() + + +@slow +async def test_openssl_retransmit_doesnt_break_stuff(): + # can't use autojump_clock here, because the point of the test is to wait for + # openssl's built-in retransmit timer to expire, which is hard-coded to use + # wall-clock time. + fn = FakeNet() + fn.enable() + + blackholed = True + + def route_packet(packet): + if blackholed: + print("dropped packet", packet) + return + print("delivered packet", packet) + # packets.append( + # scapy.all.IP( + # src=packet.source.ip.compressed, dst=packet.destination.ip.compressed + # ) + # / scapy.all.UDP(sport=packet.source.port, dport=packet.destination.port) + # / packet.payload + # ) + fn.deliver_packet(packet) + + fn.route_packet = route_packet + + async with dtls_echo_server() as (server_endpoint, address): + with endpoint() as client_endpoint: + async with trio.open_nursery() as nursery: + + async def connecter(): + client = client_endpoint.connect(address, client_ctx) + await client.do_handshake(initial_retransmit_timeout=1.5) + await client.send(b"hi") + assert await client.receive() == b"hi" + + nursery.start_soon(connecter) + + # openssl's default timeout is 1 second, so this ensures that it thinks + # the timeout has expired + await trio.sleep(1.1) + # disable blackholing and send a garbage packet to wake up openssl so it + # notices the timeout has expired + blackholed = False + await server_endpoint.socket.sendto( + b"xxx", client_endpoint.socket.getsockname() + ) + # now the client task should finish connecting and exit cleanly + + # scapy.all.wrpcap("/tmp/trace.pcap", packets) + + +async def test_initial_retransmit_timeout_configuration(autojump_clock): + fn = FakeNet() + fn.enable() + + blackholed = True + + def route_packet(packet): + nonlocal blackholed + if blackholed: + blackholed = False + else: + fn.deliver_packet(packet) + + fn.route_packet = route_packet + + async with dtls_echo_server() as (_, address): + for t in [1, 2, 4]: + with endpoint() as client: + before = trio.current_time() + blackholed = True + channel = client.connect(address, client_ctx) + await channel.do_handshake(initial_retransmit_timeout=t) + after = trio.current_time() + assert after - before == t + + +async def test_explicit_tiny_mtu_is_respected(): + # ClientHello is ~240 bytes, and it can't be fragmented, so our mtu has to + # be larger than that. (300 is still smaller than any real network though.) + MTU = 300 + + fn = FakeNet() + fn.enable() + + def route_packet(packet): + print(f"delivering {packet}") + print(f"payload size: {len(packet.payload)}") + assert len(packet.payload) <= MTU + fn.deliver_packet(packet) + + fn.route_packet = route_packet + + async with dtls_echo_server(mtu=MTU) as (server, address): + with endpoint() as client: + channel = client.connect(address, client_ctx) + channel.set_ciphertext_mtu(MTU) + await channel.do_handshake() + await channel.send(b"hi") + assert await channel.receive() == b"hi" + + +@parametrize_ipv6 +async def test_handshake_handles_minimum_network_mtu(ipv6, autojump_clock): + # Fake network that has the minimum allowable MTU for whatever protocol we're using. + fn = FakeNet() + fn.enable() + + if ipv6: + mtu = 1280 - 48 + else: + mtu = 576 - 28 + + def route_packet(packet): + if len(packet.payload) > mtu: + print(f"dropping {packet}") + else: + print(f"delivering {packet}") + fn.deliver_packet(packet) + + fn.route_packet = route_packet + + # See if we can successfully do a handshake -- some of the volleys will get dropped, + # and the retransmit logic should detect this and back off the MTU to something + # smaller until it succeeds. + async with dtls_echo_server(ipv6=ipv6) as (_, address): + with endpoint(ipv6=ipv6) as client_endpoint: + client = client_endpoint.connect(address, client_ctx) + # the handshake mtu backoff shouldn't affect the return value from + # get_cleartext_mtu, b/c that's under the user's control via + # set_ciphertext_mtu + client.set_ciphertext_mtu(9999) + await client.send(b"xyz") + assert await client.receive() == b"xyz" + assert client.get_cleartext_mtu() > 9000 # as vegeta said + + +@pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning") +async def test_system_task_cleaned_up_on_gc(): + before_tasks = trio.lowlevel.current_statistics().tasks_living + + # We put this into a sub-function so that everything automatically becomes garbage + # when the frame exits. For some reason just doing 'del e' wasn't enough on pypy + # with coverage enabled -- I think we were hitting this bug: + # https://foss.heptapod.net/pypy/pypy/-/issues/3656 + async def start_and_forget_endpoint(): + e = endpoint() + + # This connection/handshake attempt can't succeed. The only purpose is to force + # the endpoint to set up a receive loop. + with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as s: + await s.bind(("127.0.0.1", 0)) + c = e.connect(s.getsockname(), client_ctx) + async with trio.open_nursery() as nursery: + nursery.start_soon(c.do_handshake) + await trio.testing.wait_all_tasks_blocked() + nursery.cancel_scope.cancel() + + during_tasks = trio.lowlevel.current_statistics().tasks_living + return during_tasks + + with pytest.warns(ResourceWarning): + during_tasks = await start_and_forget_endpoint() + await trio.testing.wait_all_tasks_blocked() + gc_collect_harder() + + await trio.testing.wait_all_tasks_blocked() + + after_tasks = trio.lowlevel.current_statistics().tasks_living + assert before_tasks < during_tasks + assert before_tasks == after_tasks + + +@pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning") +async def test_gc_before_system_task_starts(): + e = endpoint() + + with pytest.warns(ResourceWarning): + del e + gc_collect_harder() + + await trio.testing.wait_all_tasks_blocked() + + +@pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning") +async def test_gc_as_packet_received(): + fn = FakeNet() + fn.enable() + + e = endpoint() + await e.socket.bind(("127.0.0.1", 0)) + e._ensure_receive_loop() + + await trio.testing.wait_all_tasks_blocked() + + with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as s: + await s.sendto(b"xxx", e.socket.getsockname()) + # At this point, the endpoint's receive loop has been marked runnable because it + # just received a packet; closing the endpoint socket won't interrupt that. But by + # the time it wakes up to process the packet, the endpoint will be gone. + with pytest.warns(ResourceWarning): + del e + gc_collect_harder() + + +@pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning") +def test_gc_after_trio_exits(): + async def main(): + # We use fakenet just to make sure no real sockets can leak out of the test + # case - on pypy somehow the socket was outliving the gc_collect_harder call + # below. Since the test is just making sure DTLSEndpoint.__del__ doesn't explode + # when called after trio exits, it doesn't need a real socket. + fn = FakeNet() + fn.enable() + return endpoint() + + e = trio.run(main) + with pytest.warns(ResourceWarning): + del e + gc_collect_harder() + + +async def test_already_closed_socket_doesnt_crash(): + with endpoint() as e: + # We close the socket before checkpointing, so the socket will already be closed + # when the system task starts up + e.socket.close() + # Now give it a chance to start up, and hopefully not crash + await trio.testing.wait_all_tasks_blocked() + + +async def test_socket_closed_while_processing_clienthello(autojump_clock): + fn = FakeNet() + fn.enable() + + # Check what happens if the socket is discovered to be closed when sending a + # HelloVerifyRequest, since that has its own sending logic + async with dtls_echo_server() as (server, address): + + def route_packet(packet): + fn.deliver_packet(packet) + server.socket.close() + + fn.route_packet = route_packet + + with endpoint() as client_endpoint: + with trio.move_on_after(10): + client = client_endpoint.connect(address, client_ctx) + await client.do_handshake() + + +async def test_association_replaced_while_handshake_running(autojump_clock): + fn = FakeNet() + fn.enable() + + def route_packet(packet): + pass + + fn.route_packet = route_packet + + async with dtls_echo_server() as (_, address): + with endpoint() as client_endpoint: + c1 = client_endpoint.connect(address, client_ctx) + async with trio.open_nursery() as nursery: + + async def doomed_handshake(): + with pytest.raises(trio.BrokenResourceError): + await c1.do_handshake() + + nursery.start_soon(doomed_handshake) + + await trio.sleep(10) + + client_endpoint.connect(address, client_ctx) + + +async def test_association_replaced_before_handshake_starts(): + fn = FakeNet() + fn.enable() + + # This test shouldn't send any packets + def route_packet(packet): # pragma: no cover + assert False + + fn.route_packet = route_packet + + async with dtls_echo_server() as (_, address): + with endpoint() as client_endpoint: + c1 = client_endpoint.connect(address, client_ctx) + client_endpoint.connect(address, client_ctx) + with pytest.raises(trio.BrokenResourceError): + await c1.do_handshake() + + +async def test_send_to_closed_local_port(): + # On Windows, sending a UDP packet to a closed local port can cause a weird + # ECONNRESET error later, inside the receive task. Make sure we're handling it + # properly. + async with dtls_echo_server() as (_, address): + with endpoint() as client_endpoint: + async with trio.open_nursery() as nursery: + for i in range(1, 10): + channel = client_endpoint.connect(("127.0.0.1", i), client_ctx) + nursery.start_soon(channel.do_handshake) + channel = client_endpoint.connect(address, client_ctx) + await channel.send(b"xxx") + assert await channel.receive() == b"xxx" + nursery.cancel_scope.cancel() diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_exports.py b/venv/lib/python3.9/site-packages/trio/tests/test_exports.py new file mode 100644 index 00000000..8d6b2d61 --- /dev/null +++ b/venv/lib/python3.9/site-packages/trio/tests/test_exports.py @@ -0,0 +1,145 @@ +import re +import sys +import importlib +import types +import inspect +import enum + +import pytest + +import trio +import trio.testing + +from .. import _core +from .. import _util + + +def test_core_is_properly_reexported(): + # Each export from _core should be re-exported by exactly one of these + # three modules: + sources = [trio, trio.lowlevel, trio.testing] + for symbol in dir(_core): + if symbol.startswith("_") or symbol == "tests": + continue + found = 0 + for source in sources: + if symbol in dir(source) and getattr(source, symbol) is getattr( + _core, symbol + ): + found += 1 + print(symbol, found) + assert found == 1 + + +def public_modules(module): + yield module + for name, class_ in module.__dict__.items(): + if name.startswith("_"): # pragma: no cover + continue + if not isinstance(class_, types.ModuleType): + continue + if not class_.__name__.startswith(module.__name__): # pragma: no cover + continue + if class_ is module: + continue + # We should rename the trio.tests module (#274), but until then we use + # a special-case hack: + if class_.__name__ == "trio.tests": + continue + yield from public_modules(class_) + + +PUBLIC_MODULES = list(public_modules(trio)) +PUBLIC_MODULE_NAMES = [m.__name__ for m in PUBLIC_MODULES] + + +# It doesn't make sense for downstream redistributors to run this test, since +# they might be using a newer version of Python with additional symbols which +# won't be reflected in trio.socket, and this shouldn't cause downstream test +# runs to start failing. +@pytest.mark.redistributors_should_skip +# pylint/jedi often have trouble with alpha releases, where Python's internals +# are in flux, grammar may not have settled down, etc. +@pytest.mark.skipif( + sys.version_info.releaselevel == "alpha", + reason="skip static introspection tools on Python dev/alpha releases", +) +@pytest.mark.parametrize("modname", PUBLIC_MODULE_NAMES) +@pytest.mark.parametrize("tool", ["pylint", "jedi"]) +@pytest.mark.filterwarnings( + # https://github.com/pypa/setuptools/issues/3274 + "ignore:module 'sre_constants' is deprecated:DeprecationWarning", +) +def test_static_tool_sees_all_symbols(tool, modname): + module = importlib.import_module(modname) + + def no_underscores(symbols): + return {symbol for symbol in symbols if not symbol.startswith("_")} + + runtime_names = no_underscores(dir(module)) + + # We should rename the trio.tests module (#274), but until then we use a + # special-case hack: + if modname == "trio": + runtime_names.remove("tests") + + if tool == "pylint": + from pylint.lint import PyLinter + + linter = PyLinter() + ast = linter.get_ast(module.__file__, modname) + static_names = no_underscores(ast) + elif tool == "jedi": + import jedi + + # Simulate typing "import trio; trio." + script = jedi.Script("import {}; {}.".format(modname, modname)) + completions = script.complete() + static_names = no_underscores(c.name for c in completions) + else: # pragma: no cover + assert False + + # It's expected that the static set will contain more names than the + # runtime set: + # - static tools are sometimes sloppy and include deleted names + # - some symbols are platform-specific at runtime, but always show up in + # static analysis (e.g. in trio.socket or trio.lowlevel) + # So we check that the runtime names are a subset of the static names. + missing_names = runtime_names - static_names + if missing_names: # pragma: no cover + print("{} can't see the following names in {}:".format(tool, modname)) + print() + for name in sorted(missing_names): + print(" {}".format(name)) + assert False + + +def test_classes_are_final(): + for module in PUBLIC_MODULES: + for name, class_ in module.__dict__.items(): + if not isinstance(class_, type): + continue + # Deprecated classes are exported with a leading underscore + if name.startswith("_"): # pragma: no cover + continue + + # Abstract classes can be subclassed, because that's the whole + # point of ABCs + if inspect.isabstract(class_): + continue + # Exceptions are allowed to be subclassed, because exception + # subclassing isn't used to inherit behavior. + if issubclass(class_, BaseException): + continue + # These are classes that are conceptually abstract, but + # inspect.isabstract returns False for boring reasons. + if class_ in {trio.abc.Instrument, trio.socket.SocketType}: + continue + # Enums have their own metaclass, so we can't use our metaclasses. + # And I don't think there's a lot of risk from people subclassing + # enums... + if issubclass(class_, enum.Enum): + continue + # ... insert other special cases here ... + + assert isinstance(class_, _util.Final) diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_fakenet.py b/venv/lib/python3.9/site-packages/trio/tests/test_fakenet.py new file mode 100644 index 00000000..bc691c9d --- /dev/null +++ b/venv/lib/python3.9/site-packages/trio/tests/test_fakenet.py @@ -0,0 +1,44 @@ +import pytest + +import trio +from trio.testing._fake_net import FakeNet + + +def fn(): + fn = FakeNet() + fn.enable() + return fn + + +async def test_basic_udp(): + fn() + s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) + s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) + + await s1.bind(("127.0.0.1", 0)) + ip, port = s1.getsockname() + assert ip == "127.0.0.1" + assert port != 0 + await s2.sendto(b"xyz", s1.getsockname()) + data, addr = await s1.recvfrom(10) + assert data == b"xyz" + assert addr == s2.getsockname() + await s1.sendto(b"abc", s2.getsockname()) + data, addr = await s2.recvfrom(10) + assert data == b"abc" + assert addr == s1.getsockname() + + +async def test_msg_trunc(): + fn() + s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) + s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) + await s1.bind(("127.0.0.1", 0)) + await s2.sendto(b"xyz", s1.getsockname()) + data, addr = await s1.recvfrom(10) + + +async def test_basic_tcp(): + fn() + with pytest.raises(NotImplementedError): + trio.socket.socket() diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_file_io.py b/venv/lib/python3.9/site-packages/trio/tests/test_file_io.py new file mode 100644 index 00000000..b40f7518 --- /dev/null +++ b/venv/lib/python3.9/site-packages/trio/tests/test_file_io.py @@ -0,0 +1,198 @@ +import io +import os + +import pytest +from unittest import mock +from unittest.mock import sentinel + +import trio +from trio import _core +from trio._file_io import AsyncIOWrapper, _FILE_SYNC_ATTRS, _FILE_ASYNC_METHODS + + +@pytest.fixture +def path(tmpdir): + return os.fspath(tmpdir.join("test")) + + +@pytest.fixture +def wrapped(): + return mock.Mock(spec_set=io.StringIO) + + +@pytest.fixture +def async_file(wrapped): + return trio.wrap_file(wrapped) + + +def test_wrap_invalid(): + with pytest.raises(TypeError): + trio.wrap_file(str()) + + +def test_wrap_non_iobase(): + class FakeFile: + def close(self): # pragma: no cover + pass + + def write(self): # pragma: no cover + pass + + wrapped = FakeFile() + assert not isinstance(wrapped, io.IOBase) + + async_file = trio.wrap_file(wrapped) + assert isinstance(async_file, AsyncIOWrapper) + + del FakeFile.write + + with pytest.raises(TypeError): + trio.wrap_file(FakeFile()) + + +def test_wrapped_property(async_file, wrapped): + assert async_file.wrapped is wrapped + + +def test_dir_matches_wrapped(async_file, wrapped): + attrs = _FILE_SYNC_ATTRS.union(_FILE_ASYNC_METHODS) + + # all supported attrs in wrapped should be available in async_file + assert all(attr in dir(async_file) for attr in attrs if attr in dir(wrapped)) + # all supported attrs not in wrapped should not be available in async_file + assert not any( + attr in dir(async_file) for attr in attrs if attr not in dir(wrapped) + ) + + +def test_unsupported_not_forwarded(): + class FakeFile(io.RawIOBase): + def unsupported_attr(self): # pragma: no cover + pass + + async_file = trio.wrap_file(FakeFile()) + + assert hasattr(async_file.wrapped, "unsupported_attr") + + with pytest.raises(AttributeError): + getattr(async_file, "unsupported_attr") + + +def test_sync_attrs_forwarded(async_file, wrapped): + for attr_name in _FILE_SYNC_ATTRS: + if attr_name not in dir(async_file): + continue + + assert getattr(async_file, attr_name) is getattr(wrapped, attr_name) + + +def test_sync_attrs_match_wrapper(async_file, wrapped): + for attr_name in _FILE_SYNC_ATTRS: + if attr_name in dir(async_file): + continue + + with pytest.raises(AttributeError): + getattr(async_file, attr_name) + + with pytest.raises(AttributeError): + getattr(wrapped, attr_name) + + +def test_async_methods_generated_once(async_file): + for meth_name in _FILE_ASYNC_METHODS: + if meth_name not in dir(async_file): + continue + + assert getattr(async_file, meth_name) is getattr(async_file, meth_name) + + +def test_async_methods_signature(async_file): + # use read as a representative of all async methods + assert async_file.read.__name__ == "read" + assert async_file.read.__qualname__ == "AsyncIOWrapper.read" + + assert "io.StringIO.read" in async_file.read.__doc__ + + +async def test_async_methods_wrap(async_file, wrapped): + for meth_name in _FILE_ASYNC_METHODS: + if meth_name not in dir(async_file): + continue + + meth = getattr(async_file, meth_name) + wrapped_meth = getattr(wrapped, meth_name) + + value = await meth(sentinel.argument, keyword=sentinel.keyword) + + wrapped_meth.assert_called_once_with( + sentinel.argument, keyword=sentinel.keyword + ) + assert value == wrapped_meth() + + wrapped.reset_mock() + + +async def test_async_methods_match_wrapper(async_file, wrapped): + for meth_name in _FILE_ASYNC_METHODS: + if meth_name in dir(async_file): + continue + + with pytest.raises(AttributeError): + getattr(async_file, meth_name) + + with pytest.raises(AttributeError): + getattr(wrapped, meth_name) + + +async def test_open(path): + f = await trio.open_file(path, "w") + + assert isinstance(f, AsyncIOWrapper) + + await f.aclose() + + +async def test_open_context_manager(path): + async with await trio.open_file(path, "w") as f: + assert isinstance(f, AsyncIOWrapper) + assert not f.closed + + assert f.closed + + +async def test_async_iter(): + async_file = trio.wrap_file(io.StringIO("test\nfoo\nbar")) + expected = list(async_file.wrapped) + result = [] + async_file.wrapped.seek(0) + + async for line in async_file: + result.append(line) + + assert result == expected + + +async def test_aclose_cancelled(path): + with _core.CancelScope() as cscope: + f = await trio.open_file(path, "w") + cscope.cancel() + + with pytest.raises(_core.Cancelled): + await f.write("a") + + with pytest.raises(_core.Cancelled): + await f.aclose() + + assert f.closed + + +async def test_detach_rewraps_asynciobase(): + raw = io.BytesIO() + buffered = io.BufferedReader(raw) + + async_file = trio.wrap_file(buffered) + + detached = await async_file.detach() + + assert isinstance(detached, AsyncIOWrapper) + assert detached.wrapped is raw diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_highlevel_generic.py b/venv/lib/python3.9/site-packages/trio/tests/test_highlevel_generic.py new file mode 100644 index 00000000..df2b2cec --- /dev/null +++ b/venv/lib/python3.9/site-packages/trio/tests/test_highlevel_generic.py @@ -0,0 +1,94 @@ +import pytest + +import attr + +from ..abc import SendStream, ReceiveStream +from .._highlevel_generic import StapledStream + + +@attr.s +class RecordSendStream(SendStream): + record = attr.ib(factory=list) + + async def send_all(self, data): + self.record.append(("send_all", data)) + + async def wait_send_all_might_not_block(self): + self.record.append("wait_send_all_might_not_block") + + async def aclose(self): + self.record.append("aclose") + + +@attr.s +class RecordReceiveStream(ReceiveStream): + record = attr.ib(factory=list) + + async def receive_some(self, max_bytes=None): + self.record.append(("receive_some", max_bytes)) + + async def aclose(self): + self.record.append("aclose") + + +async def test_StapledStream(): + send_stream = RecordSendStream() + receive_stream = RecordReceiveStream() + stapled = StapledStream(send_stream, receive_stream) + + assert stapled.send_stream is send_stream + assert stapled.receive_stream is receive_stream + + await stapled.send_all(b"foo") + await stapled.wait_send_all_might_not_block() + assert send_stream.record == [ + ("send_all", b"foo"), + "wait_send_all_might_not_block", + ] + send_stream.record.clear() + + await stapled.send_eof() + assert send_stream.record == ["aclose"] + send_stream.record.clear() + + async def fake_send_eof(): + send_stream.record.append("send_eof") + + send_stream.send_eof = fake_send_eof + await stapled.send_eof() + assert send_stream.record == ["send_eof"] + + send_stream.record.clear() + assert receive_stream.record == [] + + await stapled.receive_some(1234) + assert receive_stream.record == [("receive_some", 1234)] + assert send_stream.record == [] + receive_stream.record.clear() + + await stapled.aclose() + assert receive_stream.record == ["aclose"] + assert send_stream.record == ["aclose"] + + +async def test_StapledStream_with_erroring_close(): + # Make sure that if one of the aclose methods errors out, then the other + # one still gets called. + class BrokenSendStream(RecordSendStream): + async def aclose(self): + await super().aclose() + raise ValueError + + class BrokenReceiveStream(RecordReceiveStream): + async def aclose(self): + await super().aclose() + raise ValueError + + stapled = StapledStream(BrokenSendStream(), BrokenReceiveStream()) + + with pytest.raises(ValueError) as excinfo: + await stapled.aclose() + assert isinstance(excinfo.value.__context__, ValueError) + + assert stapled.send_stream.record == ["aclose"] + assert stapled.receive_stream.record == ["aclose"] diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_highlevel_open_tcp_listeners.py b/venv/lib/python3.9/site-packages/trio/tests/test_highlevel_open_tcp_listeners.py new file mode 100644 index 00000000..10398473 --- /dev/null +++ b/venv/lib/python3.9/site-packages/trio/tests/test_highlevel_open_tcp_listeners.py @@ -0,0 +1,300 @@ +import sys + +import pytest + +import socket as stdlib_socket +import errno + +import attr + +import trio +from trio import open_tcp_listeners, serve_tcp, SocketListener, open_tcp_stream +from trio.testing import open_stream_to_socket_listener +from .. import socket as tsocket +from .._core.tests.tutil import slow, creates_ipv6, binds_ipv6 + +if sys.version_info < (3, 11): + from exceptiongroup import BaseExceptionGroup + + +async def test_open_tcp_listeners_basic(): + listeners = await open_tcp_listeners(0) + assert isinstance(listeners, list) + for obj in listeners: + assert isinstance(obj, SocketListener) + # Binds to wildcard address by default + assert obj.socket.family in [tsocket.AF_INET, tsocket.AF_INET6] + assert obj.socket.getsockname()[0] in ["0.0.0.0", "::"] + + listener = listeners[0] + # Make sure the backlog is at least 2 + c1 = await open_stream_to_socket_listener(listener) + c2 = await open_stream_to_socket_listener(listener) + + s1 = await listener.accept() + s2 = await listener.accept() + + # Note that we don't know which client stream is connected to which server + # stream + await s1.send_all(b"x") + await s2.send_all(b"x") + assert await c1.receive_some(1) == b"x" + assert await c2.receive_some(1) == b"x" + + for resource in [c1, c2, s1, s2] + listeners: + await resource.aclose() + + +async def test_open_tcp_listeners_specific_port_specific_host(): + # Pick a port + sock = tsocket.socket() + await sock.bind(("127.0.0.1", 0)) + host, port = sock.getsockname() + sock.close() + + (listener,) = await open_tcp_listeners(port, host=host) + async with listener: + assert listener.socket.getsockname() == (host, port) + + +@binds_ipv6 +async def test_open_tcp_listeners_ipv6_v6only(): + # Check IPV6_V6ONLY is working properly + (ipv6_listener,) = await open_tcp_listeners(0, host="::1") + async with ipv6_listener: + _, port, *_ = ipv6_listener.socket.getsockname() + + with pytest.raises(OSError): + await open_tcp_stream("127.0.0.1", port) + + +async def test_open_tcp_listeners_rebind(): + (l1,) = await open_tcp_listeners(0, host="127.0.0.1") + sockaddr1 = l1.socket.getsockname() + + # Plain old rebinding while it's still there should fail, even if we have + # SO_REUSEADDR set + with stdlib_socket.socket() as probe: + probe.setsockopt(stdlib_socket.SOL_SOCKET, stdlib_socket.SO_REUSEADDR, 1) + with pytest.raises(OSError): + probe.bind(sockaddr1) + + # Now use the first listener to set up some connections in various states, + # and make sure that they don't create any obstacle to rebinding a second + # listener after the first one is closed. + c_established = await open_stream_to_socket_listener(l1) + s_established = await l1.accept() + + c_time_wait = await open_stream_to_socket_listener(l1) + s_time_wait = await l1.accept() + # Server-initiated close leaves socket in TIME_WAIT + await s_time_wait.aclose() + + await l1.aclose() + (l2,) = await open_tcp_listeners(sockaddr1[1], host="127.0.0.1") + sockaddr2 = l2.socket.getsockname() + + assert sockaddr1 == sockaddr2 + assert s_established.socket.getsockname() == sockaddr2 + assert c_time_wait.socket.getpeername() == sockaddr2 + + for resource in [ + l1, + l2, + c_established, + s_established, + c_time_wait, + s_time_wait, + ]: + await resource.aclose() + + +class FakeOSError(OSError): + pass + + +@attr.s +class FakeSocket(tsocket.SocketType): + family = attr.ib() + type = attr.ib() + proto = attr.ib() + + closed = attr.ib(default=False) + poison_listen = attr.ib(default=False) + backlog = attr.ib(default=None) + + def getsockopt(self, level, option): + if (level, option) == (tsocket.SOL_SOCKET, tsocket.SO_ACCEPTCONN): + return True + assert False # pragma: no cover + + def setsockopt(self, level, option, value): + pass + + async def bind(self, sockaddr): + pass + + def listen(self, backlog): + assert self.backlog is None + assert backlog is not None + self.backlog = backlog + if self.poison_listen: + raise FakeOSError("whoops") + + def close(self): + self.closed = True + + +@attr.s +class FakeSocketFactory: + poison_after = attr.ib() + sockets = attr.ib(factory=list) + raise_on_family = attr.ib(factory=dict) # family => errno + + def socket(self, family, type, proto): + if family in self.raise_on_family: + raise OSError(self.raise_on_family[family], "nope") + sock = FakeSocket(family, type, proto) + self.poison_after -= 1 + if self.poison_after == 0: + sock.poison_listen = True + self.sockets.append(sock) + return sock + + +@attr.s +class FakeHostnameResolver: + family_addr_pairs = attr.ib() + + async def getaddrinfo(self, host, port, family, type, proto, flags): + return [ + (family, tsocket.SOCK_STREAM, 0, "", (addr, port)) + for family, addr in self.family_addr_pairs + ] + + +async def test_open_tcp_listeners_multiple_host_cleanup_on_error(): + # If we were trying to bind to multiple hosts and one of them failed, they + # call get cleaned up before returning + fsf = FakeSocketFactory(3) + tsocket.set_custom_socket_factory(fsf) + tsocket.set_custom_hostname_resolver( + FakeHostnameResolver( + [ + (tsocket.AF_INET, "1.1.1.1"), + (tsocket.AF_INET, "2.2.2.2"), + (tsocket.AF_INET, "3.3.3.3"), + ] + ) + ) + + with pytest.raises(FakeOSError): + await open_tcp_listeners(80, host="example.org") + + assert len(fsf.sockets) == 3 + for sock in fsf.sockets: + assert sock.closed + + +async def test_open_tcp_listeners_port_checking(): + for host in ["127.0.0.1", None]: + with pytest.raises(TypeError): + await open_tcp_listeners(None, host=host) + with pytest.raises(TypeError): + await open_tcp_listeners(b"80", host=host) + with pytest.raises(TypeError): + await open_tcp_listeners("http", host=host) + + +async def test_serve_tcp(): + async def handler(stream): + await stream.send_all(b"x") + + async with trio.open_nursery() as nursery: + listeners = await nursery.start(serve_tcp, handler, 0) + stream = await open_stream_to_socket_listener(listeners[0]) + async with stream: + await stream.receive_some(1) == b"x" + nursery.cancel_scope.cancel() + + +@pytest.mark.parametrize( + "try_families", + [{tsocket.AF_INET}, {tsocket.AF_INET6}, {tsocket.AF_INET, tsocket.AF_INET6}], +) +@pytest.mark.parametrize( + "fail_families", + [{tsocket.AF_INET}, {tsocket.AF_INET6}, {tsocket.AF_INET, tsocket.AF_INET6}], +) +async def test_open_tcp_listeners_some_address_families_unavailable( + try_families, fail_families +): + fsf = FakeSocketFactory( + 10, raise_on_family={family: errno.EAFNOSUPPORT for family in fail_families} + ) + tsocket.set_custom_socket_factory(fsf) + tsocket.set_custom_hostname_resolver( + FakeHostnameResolver([(family, "foo") for family in try_families]) + ) + + should_succeed = try_families - fail_families + + if not should_succeed: + with pytest.raises(OSError) as exc_info: + await open_tcp_listeners(80, host="example.org") + + assert "This system doesn't support" in str(exc_info.value) + if isinstance(exc_info.value.__cause__, BaseExceptionGroup): + for subexc in exc_info.value.__cause__.exceptions: + assert "nope" in str(subexc) + else: + assert isinstance(exc_info.value.__cause__, OSError) + assert "nope" in str(exc_info.value.__cause__) + else: + listeners = await open_tcp_listeners(80) + for listener in listeners: + should_succeed.remove(listener.socket.family) + assert not should_succeed + + +async def test_open_tcp_listeners_socket_fails_not_afnosupport(): + fsf = FakeSocketFactory( + 10, + raise_on_family={ + tsocket.AF_INET: errno.EAFNOSUPPORT, + tsocket.AF_INET6: errno.EINVAL, + }, + ) + tsocket.set_custom_socket_factory(fsf) + tsocket.set_custom_hostname_resolver( + FakeHostnameResolver([(tsocket.AF_INET, "foo"), (tsocket.AF_INET6, "bar")]) + ) + + with pytest.raises(OSError) as exc_info: + await open_tcp_listeners(80, host="example.org") + assert exc_info.value.errno == errno.EINVAL + assert exc_info.value.__cause__ is None + assert "nope" in str(exc_info.value) + + +# We used to have an elaborate test that opened a real TCP listening socket +# and then tried to measure its backlog by making connections to it. And most +# of the time, it worked. But no matter what we tried, it was always fragile, +# because it had to do things like use timeouts to guess when the listening +# queue was full, sometimes the CI hosts go into SYN-cookie mode (where there +# effectively is no backlog), sometimes the host might not be enough resources +# to give us the full requested backlog... it was a mess. So now we just check +# that the backlog argument is passed through correctly. +async def test_open_tcp_listeners_backlog(): + fsf = FakeSocketFactory(99) + tsocket.set_custom_socket_factory(fsf) + for (given, expected) in [ + (None, 0xFFFF), + (99999999, 0xFFFF), + (10, 10), + (1, 1), + ]: + listeners = await open_tcp_listeners(0, backlog=given) + assert listeners + for listener in listeners: + assert listener.socket.backlog == expected diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_highlevel_open_tcp_stream.py b/venv/lib/python3.9/site-packages/trio/tests/test_highlevel_open_tcp_stream.py new file mode 100644 index 00000000..0f3b6a0b --- /dev/null +++ b/venv/lib/python3.9/site-packages/trio/tests/test_highlevel_open_tcp_stream.py @@ -0,0 +1,574 @@ +import pytest +import sys +import socket + +import attr + +import trio +from trio.socket import AF_INET, AF_INET6, SOCK_STREAM, IPPROTO_TCP +from trio._highlevel_open_tcp_stream import ( + reorder_for_rfc_6555_section_5_4, + close_all, + open_tcp_stream, + format_host_port, +) + +if sys.version_info < (3, 11): + from exceptiongroup import BaseExceptionGroup + + +def test_close_all(): + class CloseMe: + closed = False + + def close(self): + self.closed = True + + class CloseKiller: + def close(self): + raise OSError + + c = CloseMe() + with close_all() as to_close: + to_close.add(c) + assert c.closed + + c = CloseMe() + with pytest.raises(RuntimeError): + with close_all() as to_close: + to_close.add(c) + raise RuntimeError + assert c.closed + + c = CloseMe() + with pytest.raises(OSError): + with close_all() as to_close: + to_close.add(CloseKiller()) + to_close.add(c) + assert c.closed + + +def test_reorder_for_rfc_6555_section_5_4(): + def fake4(i): + return ( + AF_INET, + SOCK_STREAM, + IPPROTO_TCP, + "", + ("10.0.0.{}".format(i), 80), + ) + + def fake6(i): + return (AF_INET6, SOCK_STREAM, IPPROTO_TCP, "", ("::{}".format(i), 80)) + + for fake in fake4, fake6: + # No effect on homogeneous lists + targets = [fake(0), fake(1), fake(2)] + reorder_for_rfc_6555_section_5_4(targets) + assert targets == [fake(0), fake(1), fake(2)] + + # Single item lists also OK + targets = [fake(0)] + reorder_for_rfc_6555_section_5_4(targets) + assert targets == [fake(0)] + + # If the list starts out with different families in positions 0 and 1, + # then it's left alone + orig = [fake4(0), fake6(0), fake4(1), fake6(1)] + targets = list(orig) + reorder_for_rfc_6555_section_5_4(targets) + assert targets == orig + + # If not, it's reordered + targets = [fake4(0), fake4(1), fake4(2), fake6(0), fake6(1)] + reorder_for_rfc_6555_section_5_4(targets) + assert targets == [fake4(0), fake6(0), fake4(1), fake4(2), fake6(1)] + + +def test_format_host_port(): + assert format_host_port("127.0.0.1", 80) == "127.0.0.1:80" + assert format_host_port(b"127.0.0.1", 80) == "127.0.0.1:80" + assert format_host_port("example.com", 443) == "example.com:443" + assert format_host_port(b"example.com", 443) == "example.com:443" + assert format_host_port("::1", "http") == "[::1]:http" + assert format_host_port(b"::1", "http") == "[::1]:http" + + +# Make sure we can connect to localhost using real kernel sockets +async def test_open_tcp_stream_real_socket_smoketest(): + listen_sock = trio.socket.socket() + await listen_sock.bind(("127.0.0.1", 0)) + _, listen_port = listen_sock.getsockname() + listen_sock.listen(1) + client_stream = await open_tcp_stream("127.0.0.1", listen_port) + server_sock, _ = await listen_sock.accept() + await client_stream.send_all(b"x") + assert await server_sock.recv(1) == b"x" + await client_stream.aclose() + server_sock.close() + + listen_sock.close() + + +async def test_open_tcp_stream_input_validation(): + with pytest.raises(ValueError): + await open_tcp_stream(None, 80) + with pytest.raises(TypeError): + await open_tcp_stream("127.0.0.1", b"80") + + +def can_bind_127_0_0_2(): + with socket.socket() as s: + try: + s.bind(("127.0.0.2", 0)) + except OSError: + return False + return s.getsockname()[0] == "127.0.0.2" + + +async def test_local_address_real(): + with trio.socket.socket() as listener: + await listener.bind(("127.0.0.1", 0)) + listener.listen() + + # It's hard to test local_address properly, because you need multiple + # local addresses that you can bind to. Fortunately, on most Linux + # systems, you can bind to any 127.*.*.* address, and they all go + # through the loopback interface. So we can use a non-standard + # loopback address. On other systems, the only address we know for + # certain we have is 127.0.0.1, so we can't really test local_address= + # properly -- passing local_address=127.0.0.1 is indistinguishable + # from not passing local_address= at all. But, we can still do a smoke + # test to make sure the local_address= code doesn't crash. + if can_bind_127_0_0_2(): + local_address = "127.0.0.2" + else: + local_address = "127.0.0.1" + + async with await open_tcp_stream( + *listener.getsockname(), local_address=local_address + ) as client_stream: + assert client_stream.socket.getsockname()[0] == local_address + if hasattr(trio.socket, "IP_BIND_ADDRESS_NO_PORT"): + assert client_stream.socket.getsockopt( + trio.socket.IPPROTO_IP, trio.socket.IP_BIND_ADDRESS_NO_PORT + ) + + server_sock, remote_addr = await listener.accept() + await client_stream.aclose() + server_sock.close() + assert remote_addr[0] == local_address + + # Trying to connect to an ipv4 address with the ipv6 wildcard + # local_address should fail + with pytest.raises(OSError): + await open_tcp_stream(*listener.getsockname(), local_address="::") + + # But the ipv4 wildcard address should work + async with await open_tcp_stream( + *listener.getsockname(), local_address="0.0.0.0" + ) as client_stream: + server_sock, remote_addr = await listener.accept() + server_sock.close() + assert remote_addr == client_stream.socket.getsockname() + + +# Now, thorough tests using fake sockets + + +@attr.s(eq=False) +class FakeSocket(trio.socket.SocketType): + scenario = attr.ib() + family = attr.ib() + type = attr.ib() + proto = attr.ib() + + ip = attr.ib(default=None) + port = attr.ib(default=None) + succeeded = attr.ib(default=False) + closed = attr.ib(default=False) + failing = attr.ib(default=False) + + async def connect(self, sockaddr): + self.ip = sockaddr[0] + self.port = sockaddr[1] + assert self.ip not in self.scenario.sockets + self.scenario.sockets[self.ip] = self + self.scenario.connect_times[self.ip] = trio.current_time() + delay, result = self.scenario.ip_dict[self.ip] + await trio.sleep(delay) + if result == "error": + raise OSError("sorry") + if result == "postconnect_fail": + self.failing = True + self.succeeded = True + + def close(self): + self.closed = True + + # called when SocketStream is constructed + def setsockopt(self, *args, **kwargs): + if self.failing: + # raise something that isn't OSError as SocketStream + # ignores those + raise KeyboardInterrupt + + +class Scenario(trio.abc.SocketFactory, trio.abc.HostnameResolver): + def __init__(self, port, ip_list, supported_families): + # ip_list have to be unique + ip_order = [ip for (ip, _, _) in ip_list] + assert len(set(ip_order)) == len(ip_list) + ip_dict = {} + for ip, delay, result in ip_list: + assert 0 <= delay + assert result in ["error", "success", "postconnect_fail"] + ip_dict[ip] = (delay, result) + + self.port = port + self.ip_order = ip_order + self.ip_dict = ip_dict + self.supported_families = supported_families + self.socket_count = 0 + self.sockets = {} + self.connect_times = {} + + def socket(self, family, type, proto): + if family not in self.supported_families: + raise OSError("pretending not to support this family") + self.socket_count += 1 + return FakeSocket(self, family, type, proto) + + def _ip_to_gai_entry(self, ip): + if ":" in ip: + family = trio.socket.AF_INET6 + sockaddr = (ip, self.port, 0, 0) + else: + family = trio.socket.AF_INET + sockaddr = (ip, self.port) + return (family, SOCK_STREAM, IPPROTO_TCP, "", sockaddr) + + async def getaddrinfo(self, host, port, family, type, proto, flags): + assert host == b"test.example.com" + assert port == self.port + assert family == trio.socket.AF_UNSPEC + assert type == trio.socket.SOCK_STREAM + assert proto == 0 + assert flags == 0 + return [self._ip_to_gai_entry(ip) for ip in self.ip_order] + + async def getnameinfo(self, sockaddr, flags): # pragma: no cover + raise NotImplementedError + + def check(self, succeeded): + # sockets only go into self.sockets when connect is called; make sure + # all the sockets that were created did in fact go in there. + assert self.socket_count == len(self.sockets) + + for ip, socket in self.sockets.items(): + assert ip in self.ip_dict + if socket is not succeeded: + assert socket.closed + assert socket.port == self.port + + +async def run_scenario( + # The port to connect to + port, + # A list of + # (ip, delay, result) + # tuples, where delay is in seconds and result is "success" or "error" + # The ip's will be returned from getaddrinfo in this order, and then + # connect() calls to them will have the given result. + ip_list, + *, + # If False, AF_INET4/6 sockets error out on creation, before connect is + # even called. + ipv4_supported=True, + ipv6_supported=True, + # Normally, we return (winning_sock, scenario object) + # If this is True, we require there to be an exception, and return + # (exception, scenario object) + expect_error=(), + **kwargs, +): + supported_families = set() + if ipv4_supported: + supported_families.add(trio.socket.AF_INET) + if ipv6_supported: + supported_families.add(trio.socket.AF_INET6) + scenario = Scenario(port, ip_list, supported_families) + trio.socket.set_custom_hostname_resolver(scenario) + trio.socket.set_custom_socket_factory(scenario) + + try: + stream = await open_tcp_stream("test.example.com", port, **kwargs) + assert expect_error == () + scenario.check(stream.socket) + return (stream.socket, scenario) + except AssertionError: # pragma: no cover + raise + except expect_error as exc: + scenario.check(None) + return (exc, scenario) + + +async def test_one_host_quick_success(autojump_clock): + sock, scenario = await run_scenario(80, [("1.2.3.4", 0.123, "success")]) + assert sock.ip == "1.2.3.4" + assert trio.current_time() == 0.123 + + +async def test_one_host_slow_success(autojump_clock): + sock, scenario = await run_scenario(81, [("1.2.3.4", 100, "success")]) + assert sock.ip == "1.2.3.4" + assert trio.current_time() == 100 + + +async def test_one_host_quick_fail(autojump_clock): + exc, scenario = await run_scenario( + 82, [("1.2.3.4", 0.123, "error")], expect_error=OSError + ) + assert isinstance(exc, OSError) + assert trio.current_time() == 0.123 + + +async def test_one_host_slow_fail(autojump_clock): + exc, scenario = await run_scenario( + 83, [("1.2.3.4", 100, "error")], expect_error=OSError + ) + assert isinstance(exc, OSError) + assert trio.current_time() == 100 + + +async def test_one_host_failed_after_connect(autojump_clock): + exc, scenario = await run_scenario( + 83, [("1.2.3.4", 1, "postconnect_fail")], expect_error=KeyboardInterrupt + ) + assert isinstance(exc, KeyboardInterrupt) + + +# With the default 0.250 second delay, the third attempt will win +async def test_basic_fallthrough(autojump_clock): + sock, scenario = await run_scenario( + 80, + [ + ("1.1.1.1", 1, "success"), + ("2.2.2.2", 1, "success"), + ("3.3.3.3", 0.2, "success"), + ], + ) + assert sock.ip == "3.3.3.3" + # current time is default time + default time + connection time + assert trio.current_time() == (0.250 + 0.250 + 0.2) + assert scenario.connect_times == { + "1.1.1.1": 0, + "2.2.2.2": 0.250, + "3.3.3.3": 0.500, + } + + +async def test_early_success(autojump_clock): + sock, scenario = await run_scenario( + 80, + [ + ("1.1.1.1", 1, "success"), + ("2.2.2.2", 0.1, "success"), + ("3.3.3.3", 0.2, "success"), + ], + ) + assert sock.ip == "2.2.2.2" + assert trio.current_time() == (0.250 + 0.1) + assert scenario.connect_times == { + "1.1.1.1": 0, + "2.2.2.2": 0.250, + # 3.3.3.3 was never even started + } + + +# With a 0.450 second delay, the first attempt will win +async def test_custom_delay(autojump_clock): + sock, scenario = await run_scenario( + 80, + [ + ("1.1.1.1", 1, "success"), + ("2.2.2.2", 1, "success"), + ("3.3.3.3", 0.2, "success"), + ], + happy_eyeballs_delay=0.450, + ) + assert sock.ip == "1.1.1.1" + assert trio.current_time() == 1 + assert scenario.connect_times == { + "1.1.1.1": 0, + "2.2.2.2": 0.450, + "3.3.3.3": 0.900, + } + + +async def test_custom_errors_expedite(autojump_clock): + sock, scenario = await run_scenario( + 80, + [ + ("1.1.1.1", 0.1, "error"), + ("2.2.2.2", 0.2, "error"), + ("3.3.3.3", 10, "success"), + # .25 is the default timeout + ("4.4.4.4", 0.25, "success"), + ], + ) + assert sock.ip == "4.4.4.4" + assert trio.current_time() == (0.1 + 0.2 + 0.25 + 0.25) + assert scenario.connect_times == { + "1.1.1.1": 0, + "2.2.2.2": 0.1, + "3.3.3.3": 0.1 + 0.2, + "4.4.4.4": 0.1 + 0.2 + 0.25, + } + + +async def test_all_fail(autojump_clock): + exc, scenario = await run_scenario( + 80, + [ + ("1.1.1.1", 0.1, "error"), + ("2.2.2.2", 0.2, "error"), + ("3.3.3.3", 10, "error"), + ("4.4.4.4", 0.250, "error"), + ], + expect_error=OSError, + ) + assert isinstance(exc, OSError) + assert isinstance(exc.__cause__, BaseExceptionGroup) + assert len(exc.__cause__.exceptions) == 4 + assert trio.current_time() == (0.1 + 0.2 + 10) + assert scenario.connect_times == { + "1.1.1.1": 0, + "2.2.2.2": 0.1, + "3.3.3.3": 0.1 + 0.2, + "4.4.4.4": 0.1 + 0.2 + 0.25, + } + + +async def test_multi_success(autojump_clock): + sock, scenario = await run_scenario( + 80, + [ + ("1.1.1.1", 0.5, "error"), + ("2.2.2.2", 10, "success"), + ("3.3.3.3", 10 - 1, "success"), + ("4.4.4.4", 10 - 2, "success"), + ("5.5.5.5", 0.5, "error"), + ], + happy_eyeballs_delay=1, + ) + assert not scenario.sockets["1.1.1.1"].succeeded + assert ( + scenario.sockets["2.2.2.2"].succeeded + or scenario.sockets["3.3.3.3"].succeeded + or scenario.sockets["4.4.4.4"].succeeded + ) + assert not scenario.sockets["5.5.5.5"].succeeded + assert sock.ip in ["2.2.2.2", "3.3.3.3", "4.4.4.4"] + assert trio.current_time() == (0.5 + 10) + assert scenario.connect_times == { + "1.1.1.1": 0, + "2.2.2.2": 0.5, + "3.3.3.3": 1.5, + "4.4.4.4": 2.5, + "5.5.5.5": 3.5, + } + + +async def test_does_reorder(autojump_clock): + sock, scenario = await run_scenario( + 80, + [ + ("1.1.1.1", 10, "error"), + # This would win if we tried it first... + ("2.2.2.2", 1, "success"), + # But in fact we try this first, because of section 5.4 + ("::3", 0.5, "success"), + ], + happy_eyeballs_delay=1, + ) + assert sock.ip == "::3" + assert trio.current_time() == 1 + 0.5 + assert scenario.connect_times == { + "1.1.1.1": 0, + "::3": 1, + } + + +async def test_handles_no_ipv4(autojump_clock): + sock, scenario = await run_scenario( + 80, + # Here the ipv6 addresses fail at socket creation time, so the connect + # configuration doesn't matter + [ + ("::1", 10, "success"), + ("2.2.2.2", 0, "success"), + ("::3", 0.1, "success"), + ("4.4.4.4", 0, "success"), + ], + happy_eyeballs_delay=1, + ipv4_supported=False, + ) + assert sock.ip == "::3" + assert trio.current_time() == 1 + 0.1 + assert scenario.connect_times == { + "::1": 0, + "::3": 1.0, + } + + +async def test_handles_no_ipv6(autojump_clock): + sock, scenario = await run_scenario( + 80, + # Here the ipv6 addresses fail at socket creation time, so the connect + # configuration doesn't matter + [ + ("::1", 0, "success"), + ("2.2.2.2", 10, "success"), + ("::3", 0, "success"), + ("4.4.4.4", 0.1, "success"), + ], + happy_eyeballs_delay=1, + ipv6_supported=False, + ) + assert sock.ip == "4.4.4.4" + assert trio.current_time() == 1 + 0.1 + assert scenario.connect_times == { + "2.2.2.2": 0, + "4.4.4.4": 1.0, + } + + +async def test_no_hosts(autojump_clock): + exc, scenario = await run_scenario(80, [], expect_error=OSError) + assert "no results found" in str(exc) + + +async def test_cancel(autojump_clock): + with trio.move_on_after(5) as cancel_scope: + exc, scenario = await run_scenario( + 80, + [ + ("1.1.1.1", 10, "success"), + ("2.2.2.2", 10, "success"), + ("3.3.3.3", 10, "success"), + ("4.4.4.4", 10, "success"), + ], + expect_error=BaseExceptionGroup, + ) + # What comes out should be 1 or more Cancelled errors that all belong + # to this cancel_scope; this is the easiest way to check that + raise exc + assert cancel_scope.cancelled_caught + + assert trio.current_time() == 5 + + # This should have been called already, but just to make sure, since the + # exception-handling logic in run_scenario is a bit complicated and the + # main thing we care about here is that all the sockets were cleaned up. + scenario.check(succeeded=False) diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_highlevel_open_unix_stream.py b/venv/lib/python3.9/site-packages/trio/tests/test_highlevel_open_unix_stream.py new file mode 100644 index 00000000..211aff3e --- /dev/null +++ b/venv/lib/python3.9/site-packages/trio/tests/test_highlevel_open_unix_stream.py @@ -0,0 +1,67 @@ +import os +import socket +import tempfile + +import pytest + +from trio import open_unix_socket, Path +from trio._highlevel_open_unix_stream import close_on_error + +if not hasattr(socket, "AF_UNIX"): + pytestmark = pytest.mark.skip("Needs unix socket support") + + +def test_close_on_error(): + class CloseMe: + closed = False + + def close(self): + self.closed = True + + with close_on_error(CloseMe()) as c: + pass + assert not c.closed + + with pytest.raises(RuntimeError): + with close_on_error(CloseMe()) as c: + raise RuntimeError + assert c.closed + + +@pytest.mark.parametrize("filename", [4, 4.5]) +async def test_open_with_bad_filename_type(filename): + with pytest.raises(TypeError): + await open_unix_socket(filename) + + +async def test_open_bad_socket(): + # mktemp is marked as insecure, but that's okay, we don't want the file to + # exist + name = tempfile.mktemp() + with pytest.raises(FileNotFoundError): + await open_unix_socket(name) + + +async def test_open_unix_socket(): + for name_type in [Path, str]: + name = tempfile.mktemp() + serv_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + with serv_sock: + serv_sock.bind(name) + try: + serv_sock.listen(1) + + # The actual function we're testing + unix_socket = await open_unix_socket(name_type(name)) + + async with unix_socket: + client, _ = serv_sock.accept() + with client: + await unix_socket.send_all(b"test") + assert client.recv(2048) == b"test" + + client.sendall(b"response") + received = await unix_socket.receive_some(2048) + assert received == b"response" + finally: + os.unlink(name) diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_highlevel_serve_listeners.py b/venv/lib/python3.9/site-packages/trio/tests/test_highlevel_serve_listeners.py new file mode 100644 index 00000000..b028092e --- /dev/null +++ b/venv/lib/python3.9/site-packages/trio/tests/test_highlevel_serve_listeners.py @@ -0,0 +1,145 @@ +import pytest + +from functools import partial +import errno + +import attr + +import trio +from trio.testing import memory_stream_pair, wait_all_tasks_blocked + + +@attr.s(hash=False, eq=False) +class MemoryListener(trio.abc.Listener): + closed = attr.ib(default=False) + accepted_streams = attr.ib(factory=list) + queued_streams = attr.ib(factory=(lambda: trio.open_memory_channel(1))) + accept_hook = attr.ib(default=None) + + async def connect(self): + assert not self.closed + client, server = memory_stream_pair() + await self.queued_streams[0].send(server) + return client + + async def accept(self): + await trio.lowlevel.checkpoint() + assert not self.closed + if self.accept_hook is not None: + await self.accept_hook() + stream = await self.queued_streams[1].receive() + self.accepted_streams.append(stream) + return stream + + async def aclose(self): + self.closed = True + await trio.lowlevel.checkpoint() + + +async def test_serve_listeners_basic(): + listeners = [MemoryListener(), MemoryListener()] + + record = [] + + def close_hook(): + # Make sure this is a forceful close + assert trio.current_effective_deadline() == float("-inf") + record.append("closed") + + async def handler(stream): + await stream.send_all(b"123") + assert await stream.receive_some(10) == b"456" + stream.send_stream.close_hook = close_hook + stream.receive_stream.close_hook = close_hook + + async def client(listener): + s = await listener.connect() + assert await s.receive_some(10) == b"123" + await s.send_all(b"456") + + async def do_tests(parent_nursery): + async with trio.open_nursery() as nursery: + for listener in listeners: + for _ in range(3): + nursery.start_soon(client, listener) + + await wait_all_tasks_blocked() + + # verifies that all 6 streams x 2 directions each were closed ok + assert len(record) == 12 + + parent_nursery.cancel_scope.cancel() + + async with trio.open_nursery() as nursery: + l2 = await nursery.start(trio.serve_listeners, handler, listeners) + assert l2 == listeners + # This is just split into another function because gh-136 isn't + # implemented yet + nursery.start_soon(do_tests, nursery) + + for listener in listeners: + assert listener.closed + + +async def test_serve_listeners_accept_unrecognized_error(): + for error in [KeyError(), OSError(errno.ECONNABORTED, "ECONNABORTED")]: + listener = MemoryListener() + + async def raise_error(): + raise error + + listener.accept_hook = raise_error + + with pytest.raises(type(error)) as excinfo: + await trio.serve_listeners(None, [listener]) + assert excinfo.value is error + + +async def test_serve_listeners_accept_capacity_error(autojump_clock, caplog): + listener = MemoryListener() + + async def raise_EMFILE(): + raise OSError(errno.EMFILE, "out of file descriptors") + + listener.accept_hook = raise_EMFILE + + # It retries every 100 ms, so in 950 ms it will retry at 0, 100, ..., 900 + # = 10 times total + with trio.move_on_after(0.950): + await trio.serve_listeners(None, [listener]) + + assert len(caplog.records) == 10 + for record in caplog.records: + assert "retrying" in record.msg + assert record.exc_info[1].errno == errno.EMFILE + + +async def test_serve_listeners_connection_nursery(autojump_clock): + listener = MemoryListener() + + async def handler(stream): + await trio.sleep(1) + + class Done(Exception): + pass + + async def connection_watcher(*, task_status=trio.TASK_STATUS_IGNORED): + async with trio.open_nursery() as nursery: + task_status.started(nursery) + await wait_all_tasks_blocked() + assert len(nursery.child_tasks) == 10 + raise Done + + with pytest.raises(Done): + async with trio.open_nursery() as nursery: + handler_nursery = await nursery.start(connection_watcher) + await nursery.start( + partial( + trio.serve_listeners, + handler, + [listener], + handler_nursery=handler_nursery, + ) + ) + for _ in range(10): + nursery.start_soon(listener.connect) diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_highlevel_socket.py b/venv/lib/python3.9/site-packages/trio/tests/test_highlevel_socket.py new file mode 100644 index 00000000..9dcb834d --- /dev/null +++ b/venv/lib/python3.9/site-packages/trio/tests/test_highlevel_socket.py @@ -0,0 +1,267 @@ +import pytest + +import sys +import socket as stdlib_socket +import errno + +from .. import _core +from ..testing import ( + check_half_closeable_stream, + wait_all_tasks_blocked, + assert_checkpoints, +) +from .._highlevel_socket import * +from .. import socket as tsocket + + +async def test_SocketStream_basics(): + # stdlib socket bad (even if connected) + a, b = stdlib_socket.socketpair() + with a, b: + with pytest.raises(TypeError): + SocketStream(a) + + # DGRAM socket bad + with tsocket.socket(type=tsocket.SOCK_DGRAM) as sock: + with pytest.raises(ValueError): + SocketStream(sock) + + a, b = tsocket.socketpair() + with a, b: + s = SocketStream(a) + assert s.socket is a + + # Use a real, connected socket to test socket options, because + # socketpair() might give us a unix socket that doesn't support any of + # these options + with tsocket.socket() as listen_sock: + await listen_sock.bind(("127.0.0.1", 0)) + listen_sock.listen(1) + with tsocket.socket() as client_sock: + await client_sock.connect(listen_sock.getsockname()) + + s = SocketStream(client_sock) + + # TCP_NODELAY enabled by default + assert s.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY) + # We can disable it though + s.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, False) + assert not s.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY) + + b = s.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, 1) + assert isinstance(b, bytes) + + +async def test_SocketStream_send_all(): + BIG = 10000000 + + a_sock, b_sock = tsocket.socketpair() + with a_sock, b_sock: + a = SocketStream(a_sock) + b = SocketStream(b_sock) + + # Check a send_all that has to be split into multiple parts (on most + # platforms... on Windows every send() either succeeds or fails as a + # whole) + async def sender(): + data = bytearray(BIG) + await a.send_all(data) + # send_all uses memoryviews internally, which temporarily "lock" + # the object they view. If it doesn't clean them up properly, then + # some bytearray operations might raise an error afterwards, which + # would be a pretty weird and annoying side-effect to spring on + # users. So test that this doesn't happen, by forcing the + # bytearray's underlying buffer to be realloc'ed: + data += bytes(BIG) + # (Note: the above line of code doesn't do a very good job at + # testing anything, because: + # - on CPython, the refcount GC generally cleans up memoryviews + # for us even if we're sloppy. + # - on PyPy3, at least as of 5.7.0, the memoryview code and the + # bytearray code conspire so that resizing never fails – if + # resizing forces the bytearray's internal buffer to move, then + # all memoryview references are automagically updated (!!). + # See: + # https://gist.github.com/njsmith/0ffd38ec05ad8e34004f34a7dc492227 + # But I'm leaving the test here in hopes that if this ever changes + # and we break our implementation of send_all, then we'll get some + # early warning...) + + async def receiver(): + # Make sure the sender fills up the kernel buffers and blocks + await wait_all_tasks_blocked() + nbytes = 0 + while nbytes < BIG: + nbytes += len(await b.receive_some(BIG)) + assert nbytes == BIG + + async with _core.open_nursery() as nursery: + nursery.start_soon(sender) + nursery.start_soon(receiver) + + # We know that we received BIG bytes of NULs so far. Make sure that + # was all the data in there. + await a.send_all(b"e") + assert await b.receive_some(10) == b"e" + await a.send_eof() + assert await b.receive_some(10) == b"" + + +async def fill_stream(s): + async def sender(): + while True: + await s.send_all(b"x" * 10000) + + async def waiter(nursery): + await wait_all_tasks_blocked() + nursery.cancel_scope.cancel() + + async with _core.open_nursery() as nursery: + nursery.start_soon(sender) + nursery.start_soon(waiter, nursery) + + +async def test_SocketStream_generic(): + async def stream_maker(): + left, right = tsocket.socketpair() + return SocketStream(left), SocketStream(right) + + async def clogged_stream_maker(): + left, right = await stream_maker() + await fill_stream(left) + await fill_stream(right) + return left, right + + await check_half_closeable_stream(stream_maker, clogged_stream_maker) + + +async def test_SocketListener(): + # Not a Trio socket + with stdlib_socket.socket() as s: + s.bind(("127.0.0.1", 0)) + s.listen(10) + with pytest.raises(TypeError): + SocketListener(s) + + # Not a SOCK_STREAM + with tsocket.socket(type=tsocket.SOCK_DGRAM) as s: + await s.bind(("127.0.0.1", 0)) + with pytest.raises(ValueError) as excinfo: + SocketListener(s) + excinfo.match(r".*SOCK_STREAM") + + # Didn't call .listen() + # macOS has no way to check for this, so skip testing it there. + if sys.platform != "darwin": + with tsocket.socket() as s: + await s.bind(("127.0.0.1", 0)) + with pytest.raises(ValueError) as excinfo: + SocketListener(s) + excinfo.match(r".*listen") + + listen_sock = tsocket.socket() + await listen_sock.bind(("127.0.0.1", 0)) + listen_sock.listen(10) + listener = SocketListener(listen_sock) + + assert listener.socket is listen_sock + + client_sock = tsocket.socket() + await client_sock.connect(listen_sock.getsockname()) + with assert_checkpoints(): + server_stream = await listener.accept() + assert isinstance(server_stream, SocketStream) + assert server_stream.socket.getsockname() == listen_sock.getsockname() + assert server_stream.socket.getpeername() == client_sock.getsockname() + + with assert_checkpoints(): + await listener.aclose() + + with assert_checkpoints(): + await listener.aclose() + + with assert_checkpoints(): + with pytest.raises(_core.ClosedResourceError): + await listener.accept() + + client_sock.close() + await server_stream.aclose() + + +async def test_SocketListener_socket_closed_underfoot(): + listen_sock = tsocket.socket() + await listen_sock.bind(("127.0.0.1", 0)) + listen_sock.listen(10) + listener = SocketListener(listen_sock) + + # Close the socket, not the listener + listen_sock.close() + + # SocketListener gives correct error + with assert_checkpoints(): + with pytest.raises(_core.ClosedResourceError): + await listener.accept() + + +async def test_SocketListener_accept_errors(): + class FakeSocket(tsocket.SocketType): + def __init__(self, events): + self._events = iter(events) + + type = tsocket.SOCK_STREAM + + # Fool the check for SO_ACCEPTCONN in SocketListener.__init__ + def getsockopt(self, level, opt): + return True + + def setsockopt(self, level, opt, value): + pass + + async def accept(self): + await _core.checkpoint() + event = next(self._events) + if isinstance(event, BaseException): + raise event + else: + return event, None + + fake_server_sock = FakeSocket([]) + + fake_listen_sock = FakeSocket( + [ + OSError(errno.ECONNABORTED, "Connection aborted"), + OSError(errno.EPERM, "Permission denied"), + OSError(errno.EPROTO, "Bad protocol"), + fake_server_sock, + OSError(errno.EMFILE, "Out of file descriptors"), + OSError(errno.EFAULT, "attempt to write to read-only memory"), + OSError(errno.ENOBUFS, "out of buffers"), + fake_server_sock, + ] + ) + + l = SocketListener(fake_listen_sock) + + with assert_checkpoints(): + s = await l.accept() + assert s.socket is fake_server_sock + + for code in [errno.EMFILE, errno.EFAULT, errno.ENOBUFS]: + with assert_checkpoints(): + with pytest.raises(OSError) as excinfo: + await l.accept() + assert excinfo.value.errno == code + + with assert_checkpoints(): + s = await l.accept() + assert s.socket is fake_server_sock + + +async def test_socket_stream_works_when_peer_has_already_closed(): + sock_a, sock_b = tsocket.socketpair() + with sock_a, sock_b: + await sock_b.send(b"x") + sock_b.close() + stream = SocketStream(sock_a) + assert await stream.receive_some(1) == b"x" + assert await stream.receive_some(1) == b"" diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_highlevel_ssl_helpers.py b/venv/lib/python3.9/site-packages/trio/tests/test_highlevel_ssl_helpers.py new file mode 100644 index 00000000..c00f5dc4 --- /dev/null +++ b/venv/lib/python3.9/site-packages/trio/tests/test_highlevel_ssl_helpers.py @@ -0,0 +1,113 @@ +import pytest + +from functools import partial + +import attr + +import trio +from trio.socket import AF_INET, SOCK_STREAM, IPPROTO_TCP +import trio.testing +from .test_ssl import client_ctx, SERVER_CTX + +from .._highlevel_ssl_helpers import ( + open_ssl_over_tcp_stream, + open_ssl_over_tcp_listeners, + serve_ssl_over_tcp, +) + + +async def echo_handler(stream): + async with stream: + try: + while True: + data = await stream.receive_some(10000) + if not data: + break + await stream.send_all(data) + except trio.BrokenResourceError: + pass + + +# Resolver that always returns the given sockaddr, no matter what host/port +# you ask for. +@attr.s +class FakeHostnameResolver(trio.abc.HostnameResolver): + sockaddr = attr.ib() + + async def getaddrinfo(self, *args): + return [(AF_INET, SOCK_STREAM, IPPROTO_TCP, "", self.sockaddr)] + + async def getnameinfo(self, *args): # pragma: no cover + raise NotImplementedError + + +# This uses serve_ssl_over_tcp, which uses open_ssl_over_tcp_listeners... +# noqa is needed because flake8 doesn't understand how pytest fixtures work. +async def test_open_ssl_over_tcp_stream_and_everything_else(client_ctx): # noqa: F811 + async with trio.open_nursery() as nursery: + (listener,) = await nursery.start( + partial(serve_ssl_over_tcp, echo_handler, 0, SERVER_CTX, host="127.0.0.1") + ) + async with listener: + sockaddr = listener.transport_listener.socket.getsockname() + hostname_resolver = FakeHostnameResolver(sockaddr) + trio.socket.set_custom_hostname_resolver(hostname_resolver) + + # We don't have the right trust set up + # (checks that ssl_context=None is doing some validation) + stream = await open_ssl_over_tcp_stream("trio-test-1.example.org", 80) + async with stream: + with pytest.raises(trio.BrokenResourceError): + await stream.do_handshake() + + # We have the trust but not the hostname + # (checks custom ssl_context + hostname checking) + stream = await open_ssl_over_tcp_stream( + "xyzzy.example.org", 80, ssl_context=client_ctx + ) + async with stream: + with pytest.raises(trio.BrokenResourceError): + await stream.do_handshake() + + # This one should work! + stream = await open_ssl_over_tcp_stream( + "trio-test-1.example.org", 80, ssl_context=client_ctx + ) + async with stream: + assert isinstance(stream, trio.SSLStream) + assert stream.server_hostname == "trio-test-1.example.org" + await stream.send_all(b"x") + assert await stream.receive_some(1) == b"x" + + # Check https_compatible settings are being passed through + assert not stream._https_compatible + stream = await open_ssl_over_tcp_stream( + "trio-test-1.example.org", + 80, + ssl_context=client_ctx, + https_compatible=True, + # also, smoke test happy_eyeballs_delay + happy_eyeballs_delay=1, + ) + async with stream: + assert stream._https_compatible + + # Stop the echo server + nursery.cancel_scope.cancel() + + +async def test_open_ssl_over_tcp_listeners(): + (listener,) = await open_ssl_over_tcp_listeners(0, SERVER_CTX, host="127.0.0.1") + async with listener: + assert isinstance(listener, trio.SSLListener) + tl = listener.transport_listener + assert isinstance(tl, trio.SocketListener) + assert tl.socket.getsockname()[0] == "127.0.0.1" + + assert not listener._https_compatible + + (listener,) = await open_ssl_over_tcp_listeners( + 0, SERVER_CTX, host="127.0.0.1", https_compatible=True + ) + async with listener: + assert listener._https_compatible diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_path.py b/venv/lib/python3.9/site-packages/trio/tests/test_path.py new file mode 100644 index 00000000..284bcf82 --- /dev/null +++ b/venv/lib/python3.9/site-packages/trio/tests/test_path.py @@ -0,0 +1,262 @@ +import os +import pathlib + +import pytest + +import trio +from trio._path import AsyncAutoWrapperType as Type +from trio._file_io import AsyncIOWrapper + + +@pytest.fixture +def path(tmpdir): + p = str(tmpdir.join("test")) + return trio.Path(p) + + +def method_pair(path, method_name): + path = pathlib.Path(path) + async_path = trio.Path(path) + return getattr(path, method_name), getattr(async_path, method_name) + + +async def test_open_is_async_context_manager(path): + async with await path.open("w") as f: + assert isinstance(f, AsyncIOWrapper) + + assert f.closed + + +async def test_magic(): + path = trio.Path("test") + + assert str(path) == "test" + assert bytes(path) == b"test" + + +cls_pairs = [ + (trio.Path, pathlib.Path), + (pathlib.Path, trio.Path), + (trio.Path, trio.Path), +] + + +@pytest.mark.parametrize("cls_a,cls_b", cls_pairs) +async def test_cmp_magic(cls_a, cls_b): + a, b = cls_a(""), cls_b("") + assert a == b + assert not a != b + + a, b = cls_a("a"), cls_b("b") + assert a < b + assert b > a + + # this is intentionally testing equivalence with none, due to the + # other=sentinel logic in _forward_magic + assert not a == None # noqa + assert not b == None # noqa + + +# upstream python3.8 bug: we should also test (pathlib.Path, trio.Path), but +# __*div__ does not properly raise NotImplementedError like the other comparison +# magic, so trio.Path's implementation does not get dispatched +cls_pairs = [ + (trio.Path, pathlib.Path), + (trio.Path, trio.Path), + (trio.Path, str), + (str, trio.Path), +] + + +@pytest.mark.parametrize("cls_a,cls_b", cls_pairs) +async def test_div_magic(cls_a, cls_b): + a, b = cls_a("a"), cls_b("b") + + result = a / b + assert isinstance(result, trio.Path) + assert str(result) == os.path.join("a", "b") + + +@pytest.mark.parametrize( + "cls_a,cls_b", [(trio.Path, pathlib.Path), (trio.Path, trio.Path)] +) +@pytest.mark.parametrize("path", ["foo", "foo/bar/baz", "./foo"]) +async def test_hash_magic(cls_a, cls_b, path): + a, b = cls_a(path), cls_b(path) + assert hash(a) == hash(b) + + +async def test_forwarded_properties(path): + # use `name` as a representative of forwarded properties + + assert "name" in dir(path) + assert path.name == "test" + + +async def test_async_method_signature(path): + # use `resolve` as a representative of wrapped methods + + assert path.resolve.__name__ == "resolve" + assert path.resolve.__qualname__ == "Path.resolve" + + assert "pathlib.Path.resolve" in path.resolve.__doc__ + + +@pytest.mark.parametrize("method_name", ["is_dir", "is_file"]) +async def test_compare_async_stat_methods(method_name): + + method, async_method = method_pair(".", method_name) + + result = method() + async_result = await async_method() + + assert result == async_result + + +async def test_invalid_name_not_wrapped(path): + with pytest.raises(AttributeError): + getattr(path, "invalid_fake_attr") + + +@pytest.mark.parametrize("method_name", ["absolute", "resolve"]) +async def test_async_methods_rewrap(method_name): + + method, async_method = method_pair(".", method_name) + + result = method() + async_result = await async_method() + + assert isinstance(async_result, trio.Path) + assert str(result) == str(async_result) + + +async def test_forward_methods_rewrap(path, tmpdir): + with_name = path.with_name("foo") + with_suffix = path.with_suffix(".py") + + assert isinstance(with_name, trio.Path) + assert with_name == tmpdir.join("foo") + assert isinstance(with_suffix, trio.Path) + assert with_suffix == tmpdir.join("test.py") + + +async def test_forward_properties_rewrap(path): + assert isinstance(path.parent, trio.Path) + + +async def test_forward_methods_without_rewrap(path, tmpdir): + path = await path.parent.resolve() + + assert path.as_uri().startswith("file:///") + + +async def test_repr(): + path = trio.Path(".") + + assert repr(path) == "trio.Path('.')" + + +class MockWrapped: + unsupported = "unsupported" + _private = "private" + + +class MockWrapper: + _forwards = MockWrapped + _wraps = MockWrapped + + +async def test_type_forwards_unsupported(): + with pytest.raises(TypeError): + Type.generate_forwards(MockWrapper, {}) + + +async def test_type_wraps_unsupported(): + with pytest.raises(TypeError): + Type.generate_wraps(MockWrapper, {}) + + +async def test_type_forwards_private(): + Type.generate_forwards(MockWrapper, {"unsupported": None}) + + assert not hasattr(MockWrapper, "_private") + + +async def test_type_wraps_private(): + Type.generate_wraps(MockWrapper, {"unsupported": None}) + + assert not hasattr(MockWrapper, "_private") + + +@pytest.mark.parametrize("meth", [trio.Path.__init__, trio.Path.joinpath]) +async def test_path_wraps_path(path, meth): + wrapped = await path.absolute() + result = meth(path, wrapped) + if result is None: + result = path + + assert wrapped == result + + +async def test_path_nonpath(): + with pytest.raises(TypeError): + trio.Path(1) + + +async def test_open_file_can_open_path(path): + async with await trio.open_file(path, "w") as f: + assert f.name == os.fspath(path) + + +async def test_globmethods(path): + # Populate a directory tree + await path.mkdir() + await (path / "foo").mkdir() + await (path / "foo" / "_bar.txt").write_bytes(b"") + await (path / "bar.txt").write_bytes(b"") + await (path / "bar.dat").write_bytes(b"") + + # Path.glob + for _pattern, _results in { + "*.txt": {"bar.txt"}, + "**/*.txt": {"_bar.txt", "bar.txt"}, + }.items(): + entries = set() + for entry in await path.glob(_pattern): + assert isinstance(entry, trio.Path) + entries.add(entry.name) + + assert entries == _results + + # Path.rglob + entries = set() + for entry in await path.rglob("*.txt"): + assert isinstance(entry, trio.Path) + entries.add(entry.name) + + assert entries == {"_bar.txt", "bar.txt"} + + +async def test_iterdir(path): + # Populate a directory + await path.mkdir() + await (path / "foo").mkdir() + await (path / "bar.txt").write_bytes(b"") + + entries = set() + for entry in await path.iterdir(): + assert isinstance(entry, trio.Path) + entries.add(entry.name) + + assert entries == {"bar.txt", "foo"} + + +async def test_classmethods(): + assert isinstance(await trio.Path.home(), trio.Path) + + # pathlib.Path has only two classmethods + assert str(await trio.Path.home()) == os.path.expanduser("~") + assert str(await trio.Path.cwd()) == os.getcwd() + + # Wrapped method has docstring + assert trio.Path.home.__doc__ diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_scheduler_determinism.py b/venv/lib/python3.9/site-packages/trio/tests/test_scheduler_determinism.py new file mode 100644 index 00000000..e2d3167e --- /dev/null +++ b/venv/lib/python3.9/site-packages/trio/tests/test_scheduler_determinism.py @@ -0,0 +1,40 @@ +import trio + + +async def scheduler_trace(): + """Returns a scheduler-dependent value we can use to check determinism.""" + trace = [] + + async def tracer(name): + for i in range(50): + trace.append((name, i)) + await trio.sleep(0) + + async with trio.open_nursery() as nursery: + for i in range(5): + nursery.start_soon(tracer, i) + + return tuple(trace) + + +def test_the_trio_scheduler_is_not_deterministic(): + # At least, not yet. See https://github.com/python-trio/trio/issues/32 + traces = [] + for _ in range(10): + traces.append(trio.run(scheduler_trace)) + assert len(set(traces)) == len(traces) + + +def test_the_trio_scheduler_is_deterministic_if_seeded(monkeypatch): + monkeypatch.setattr(trio._core._run, "_ALLOW_DETERMINISTIC_SCHEDULING", True) + traces = [] + for _ in range(10): + state = trio._core._run._r.getstate() + try: + trio._core._run._r.seed(0) + traces.append(trio.run(scheduler_trace)) + finally: + trio._core._run._r.setstate(state) + + assert len(traces) == 10 + assert len(set(traces)) == 1 diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_signals.py b/venv/lib/python3.9/site-packages/trio/tests/test_signals.py new file mode 100644 index 00000000..235772f9 --- /dev/null +++ b/venv/lib/python3.9/site-packages/trio/tests/test_signals.py @@ -0,0 +1,177 @@ +import signal + +import pytest + +import trio +from .. import _core +from .._util import signal_raise +from .._signals import open_signal_receiver, _signal_handler + + +async def test_open_signal_receiver(): + orig = signal.getsignal(signal.SIGILL) + with open_signal_receiver(signal.SIGILL) as receiver: + # Raise it a few times, to exercise signal coalescing, both at the + # call_soon level and at the SignalQueue level + signal_raise(signal.SIGILL) + signal_raise(signal.SIGILL) + await _core.wait_all_tasks_blocked() + signal_raise(signal.SIGILL) + await _core.wait_all_tasks_blocked() + async for signum in receiver: # pragma: no branch + assert signum == signal.SIGILL + break + assert receiver._pending_signal_count() == 0 + signal_raise(signal.SIGILL) + async for signum in receiver: # pragma: no branch + assert signum == signal.SIGILL + break + assert receiver._pending_signal_count() == 0 + with pytest.raises(RuntimeError): + await receiver.__anext__() + assert signal.getsignal(signal.SIGILL) is orig + + +async def test_open_signal_receiver_restore_handler_after_one_bad_signal(): + orig = signal.getsignal(signal.SIGILL) + with pytest.raises(ValueError): + with open_signal_receiver(signal.SIGILL, 1234567): + pass # pragma: no cover + # Still restored even if we errored out + assert signal.getsignal(signal.SIGILL) is orig + + +async def test_open_signal_receiver_empty_fail(): + with pytest.raises(TypeError, match="No signals were provided"): + with open_signal_receiver(): + pass + + +async def test_open_signal_receiver_restore_handler_after_duplicate_signal(): + orig = signal.getsignal(signal.SIGILL) + with open_signal_receiver(signal.SIGILL, signal.SIGILL): + pass + # Still restored correctly + assert signal.getsignal(signal.SIGILL) is orig + + +async def test_catch_signals_wrong_thread(): + async def naughty(): + with open_signal_receiver(signal.SIGINT): + pass # pragma: no cover + + with pytest.raises(RuntimeError): + await trio.to_thread.run_sync(trio.run, naughty) + + +async def test_open_signal_receiver_conflict(): + with pytest.raises(trio.BusyResourceError): + with open_signal_receiver(signal.SIGILL) as receiver: + async with trio.open_nursery() as nursery: + nursery.start_soon(receiver.__anext__) + nursery.start_soon(receiver.__anext__) + + +# Blocks until all previous calls to run_sync_soon(idempotent=True) have been +# processed. +async def wait_run_sync_soon_idempotent_queue_barrier(): + ev = trio.Event() + token = _core.current_trio_token() + token.run_sync_soon(ev.set, idempotent=True) + await ev.wait() + + +async def test_open_signal_receiver_no_starvation(): + # Set up a situation where there are always 2 pending signals available to + # report, and make sure that instead of getting the same signal reported + # over and over, it alternates between reporting both of them. + with open_signal_receiver(signal.SIGILL, signal.SIGFPE) as receiver: + try: + print(signal.getsignal(signal.SIGILL)) + previous = None + for _ in range(10): + signal_raise(signal.SIGILL) + signal_raise(signal.SIGFPE) + await wait_run_sync_soon_idempotent_queue_barrier() + if previous is None: + previous = await receiver.__anext__() + else: + got = await receiver.__anext__() + assert got in [signal.SIGILL, signal.SIGFPE] + assert got != previous + previous = got + # Clear out the last signal so it doesn't get redelivered + while receiver._pending_signal_count() != 0: + await receiver.__anext__() + except: # pragma: no cover + # If there's an unhandled exception above, then exiting the + # open_signal_receiver block might cause the signal to be + # redelivered and give us a core dump instead of a traceback... + import traceback + + traceback.print_exc() + + +async def test_catch_signals_race_condition_on_exit(): + delivered_directly = set() + + def direct_handler(signo, frame): + delivered_directly.add(signo) + + print(1) + # Test the version where the call_soon *doesn't* have a chance to run + # before we exit the with block: + with _signal_handler({signal.SIGILL, signal.SIGFPE}, direct_handler): + with open_signal_receiver(signal.SIGILL, signal.SIGFPE) as receiver: + signal_raise(signal.SIGILL) + signal_raise(signal.SIGFPE) + await wait_run_sync_soon_idempotent_queue_barrier() + assert delivered_directly == {signal.SIGILL, signal.SIGFPE} + delivered_directly.clear() + + print(2) + # Test the version where the call_soon *does* have a chance to run before + # we exit the with block: + with _signal_handler({signal.SIGILL, signal.SIGFPE}, direct_handler): + with open_signal_receiver(signal.SIGILL, signal.SIGFPE) as receiver: + signal_raise(signal.SIGILL) + signal_raise(signal.SIGFPE) + await wait_run_sync_soon_idempotent_queue_barrier() + assert receiver._pending_signal_count() == 2 + assert delivered_directly == {signal.SIGILL, signal.SIGFPE} + delivered_directly.clear() + + # Again, but with a SIG_IGN signal: + + print(3) + with _signal_handler({signal.SIGILL}, signal.SIG_IGN): + with open_signal_receiver(signal.SIGILL) as receiver: + signal_raise(signal.SIGILL) + await wait_run_sync_soon_idempotent_queue_barrier() + # test passes if the process reaches this point without dying + + print(4) + with _signal_handler({signal.SIGILL}, signal.SIG_IGN): + with open_signal_receiver(signal.SIGILL) as receiver: + signal_raise(signal.SIGILL) + await wait_run_sync_soon_idempotent_queue_barrier() + assert receiver._pending_signal_count() == 1 + # test passes if the process reaches this point without dying + + # Check exception chaining if there are multiple exception-raising + # handlers + def raise_handler(signum, _): + raise RuntimeError(signum) + + with _signal_handler({signal.SIGILL, signal.SIGFPE}, raise_handler): + with pytest.raises(RuntimeError) as excinfo: + with open_signal_receiver(signal.SIGILL, signal.SIGFPE) as receiver: + signal_raise(signal.SIGILL) + signal_raise(signal.SIGFPE) + await wait_run_sync_soon_idempotent_queue_barrier() + assert receiver._pending_signal_count() == 2 + exc = excinfo.value + signums = {exc.args[0]} + assert isinstance(exc.__context__, RuntimeError) + signums.add(exc.__context__.args[0]) + assert signums == {signal.SIGILL, signal.SIGFPE} diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_socket.py b/venv/lib/python3.9/site-packages/trio/tests/test_socket.py new file mode 100644 index 00000000..1fa3721f --- /dev/null +++ b/venv/lib/python3.9/site-packages/trio/tests/test_socket.py @@ -0,0 +1,1017 @@ +import errno + +import pytest +import attr + +import os +import socket as stdlib_socket +import inspect +import tempfile +import sys as _sys +from .._core.tests.tutil import creates_ipv6, binds_ipv6 +from .. import _core +from .. import _socket as _tsocket +from .. import socket as tsocket +from .._socket import _NUMERIC_ONLY, _try_sync +from ..testing import assert_checkpoints, wait_all_tasks_blocked + +################################################################ +# utils +################################################################ + + +class MonkeypatchedGAI: + def __init__(self, orig_getaddrinfo): + self._orig_getaddrinfo = orig_getaddrinfo + self._responses = {} + self.record = [] + + # get a normalized getaddrinfo argument tuple + def _frozenbind(self, *args, **kwargs): + sig = inspect.signature(self._orig_getaddrinfo) + bound = sig.bind(*args, **kwargs) + bound.apply_defaults() + frozenbound = bound.args + assert not bound.kwargs + return frozenbound + + def set(self, response, *args, **kwargs): + self._responses[self._frozenbind(*args, **kwargs)] = response + + def getaddrinfo(self, *args, **kwargs): + bound = self._frozenbind(*args, **kwargs) + self.record.append(bound) + if bound in self._responses: + return self._responses[bound] + elif bound[-1] & stdlib_socket.AI_NUMERICHOST: + return self._orig_getaddrinfo(*args, **kwargs) + else: + raise RuntimeError("gai called with unexpected arguments {}".format(bound)) + + +@pytest.fixture +def monkeygai(monkeypatch): + controller = MonkeypatchedGAI(stdlib_socket.getaddrinfo) + monkeypatch.setattr(stdlib_socket, "getaddrinfo", controller.getaddrinfo) + return controller + + +async def test__try_sync(): + with assert_checkpoints(): + async with _try_sync(): + pass + + with assert_checkpoints(): + with pytest.raises(KeyError): + async with _try_sync(): + raise KeyError + + async with _try_sync(): + raise BlockingIOError + + def _is_ValueError(exc): + return isinstance(exc, ValueError) + + async with _try_sync(_is_ValueError): + raise ValueError + + with assert_checkpoints(): + with pytest.raises(BlockingIOError): + async with _try_sync(_is_ValueError): + raise BlockingIOError + + +################################################################ +# basic re-exports +################################################################ + + +def test_socket_has_some_reexports(): + assert tsocket.SOL_SOCKET == stdlib_socket.SOL_SOCKET + assert tsocket.TCP_NODELAY == stdlib_socket.TCP_NODELAY + assert tsocket.gaierror == stdlib_socket.gaierror + assert tsocket.ntohs == stdlib_socket.ntohs + + +################################################################ +# name resolution +################################################################ + + +async def test_getaddrinfo(monkeygai): + def check(got, expected): + # win32 returns 0 for the proto field + # musl and glibc have inconsistent handling of the canonical name + # field (https://github.com/python-trio/trio/issues/1499) + # Neither field gets used much and there isn't much opportunity for us + # to mess them up, so we don't bother checking them here + def interesting_fields(gai_tup): + # (family, type, proto, canonname, sockaddr) + family, type, proto, canonname, sockaddr = gai_tup + return (family, type, sockaddr) + + def filtered(gai_list): + return [interesting_fields(gai_tup) for gai_tup in gai_list] + + assert filtered(got) == filtered(expected) + + # Simple non-blocking non-error cases, ipv4 and ipv6: + with assert_checkpoints(): + res = await tsocket.getaddrinfo("127.0.0.1", "12345", type=tsocket.SOCK_STREAM) + + check( + res, + [ + ( + tsocket.AF_INET, # 127.0.0.1 is ipv4 + tsocket.SOCK_STREAM, + tsocket.IPPROTO_TCP, + "", + ("127.0.0.1", 12345), + ), + ], + ) + + with assert_checkpoints(): + res = await tsocket.getaddrinfo("::1", "12345", type=tsocket.SOCK_DGRAM) + check( + res, + [ + ( + tsocket.AF_INET6, + tsocket.SOCK_DGRAM, + tsocket.IPPROTO_UDP, + "", + ("::1", 12345, 0, 0), + ), + ], + ) + + monkeygai.set("x", b"host", "port", family=0, type=0, proto=0, flags=0) + with assert_checkpoints(): + res = await tsocket.getaddrinfo("host", "port") + assert res == "x" + assert monkeygai.record[-1] == (b"host", "port", 0, 0, 0, 0) + + # check raising an error from a non-blocking getaddrinfo + with assert_checkpoints(): + with pytest.raises(tsocket.gaierror) as excinfo: + await tsocket.getaddrinfo("::1", "12345", type=-1) + # Linux + glibc, Windows + expected_errnos = {tsocket.EAI_SOCKTYPE} + # Linux + musl + expected_errnos.add(tsocket.EAI_SERVICE) + # macOS + if hasattr(tsocket, "EAI_BADHINTS"): + expected_errnos.add(tsocket.EAI_BADHINTS) + assert excinfo.value.errno in expected_errnos + + # check raising an error from a blocking getaddrinfo (exploits the fact + # that monkeygai raises if it gets a non-numeric request it hasn't been + # given an answer for) + with assert_checkpoints(): + with pytest.raises(RuntimeError): + await tsocket.getaddrinfo("asdf", "12345") + + +async def test_getnameinfo(): + # Trivial test: + ni_numeric = stdlib_socket.NI_NUMERICHOST | stdlib_socket.NI_NUMERICSERV + with assert_checkpoints(): + got = await tsocket.getnameinfo(("127.0.0.1", 1234), ni_numeric) + assert got == ("127.0.0.1", "1234") + + # getnameinfo requires a numeric address as input: + with assert_checkpoints(): + with pytest.raises(tsocket.gaierror): + await tsocket.getnameinfo(("google.com", 80), 0) + + with assert_checkpoints(): + with pytest.raises(tsocket.gaierror): + await tsocket.getnameinfo(("localhost", 80), 0) + + # Blocking call to get expected values: + host, service = stdlib_socket.getnameinfo(("127.0.0.1", 80), 0) + + # Some working calls: + got = await tsocket.getnameinfo(("127.0.0.1", 80), 0) + assert got == (host, service) + + got = await tsocket.getnameinfo(("127.0.0.1", 80), tsocket.NI_NUMERICHOST) + assert got == ("127.0.0.1", service) + + got = await tsocket.getnameinfo(("127.0.0.1", 80), tsocket.NI_NUMERICSERV) + assert got == (host, "80") + + +################################################################ +# constructors +################################################################ + + +async def test_from_stdlib_socket(): + sa, sb = stdlib_socket.socketpair() + assert not isinstance(sa, tsocket.SocketType) + with sa, sb: + ta = tsocket.from_stdlib_socket(sa) + assert isinstance(ta, tsocket.SocketType) + assert sa.fileno() == ta.fileno() + await ta.send(b"x") + assert sb.recv(1) == b"x" + + # rejects other types + with pytest.raises(TypeError): + tsocket.from_stdlib_socket(1) + + class MySocket(stdlib_socket.socket): + pass + + with MySocket() as mysock: + with pytest.raises(TypeError): + tsocket.from_stdlib_socket(mysock) + + +async def test_from_fd(): + sa, sb = stdlib_socket.socketpair() + ta = tsocket.fromfd(sa.fileno(), sa.family, sa.type, sa.proto) + with sa, sb, ta: + assert ta.fileno() != sa.fileno() + await ta.send(b"x") + assert sb.recv(3) == b"x" + + +async def test_socketpair_simple(): + async def child(sock): + print("sending hello") + await sock.send(b"h") + assert await sock.recv(1) == b"h" + + a, b = tsocket.socketpair() + with a, b: + async with _core.open_nursery() as nursery: + nursery.start_soon(child, a) + nursery.start_soon(child, b) + + +@pytest.mark.skipif(not hasattr(tsocket, "fromshare"), reason="windows only") +async def test_fromshare(): + a, b = tsocket.socketpair() + with a, b: + # share with ourselves + shared = a.share(os.getpid()) + a2 = tsocket.fromshare(shared) + with a2: + assert a.fileno() != a2.fileno() + await a2.send(b"x") + assert await b.recv(1) == b"x" + + +async def test_socket(): + with tsocket.socket() as s: + assert isinstance(s, tsocket.SocketType) + assert s.family == tsocket.AF_INET + + +@creates_ipv6 +async def test_socket_v6(): + with tsocket.socket(tsocket.AF_INET6, tsocket.SOCK_DGRAM) as s: + assert isinstance(s, tsocket.SocketType) + assert s.family == tsocket.AF_INET6 + + +@pytest.mark.skipif(not _sys.platform == "linux", reason="linux only") +async def test_sniff_sockopts(): + from socket import AF_INET, AF_INET6, SOCK_DGRAM, SOCK_STREAM + + # generate the combinations of families/types we're testing: + sockets = [] + for family in [AF_INET, AF_INET6]: + for type in [SOCK_DGRAM, SOCK_STREAM]: + sockets.append(stdlib_socket.socket(family, type)) + for socket in sockets: + # regular Trio socket constructor + tsocket_socket = tsocket.socket(fileno=socket.fileno()) + # check family / type for correctness: + assert tsocket_socket.family == socket.family + assert tsocket_socket.type == socket.type + tsocket_socket.detach() + + # fromfd constructor + tsocket_from_fd = tsocket.fromfd(socket.fileno(), AF_INET, SOCK_STREAM) + # check family / type for correctness: + assert tsocket_from_fd.family == socket.family + assert tsocket_from_fd.type == socket.type + tsocket_from_fd.close() + + socket.close() + + +################################################################ +# _SocketType +################################################################ + + +async def test_SocketType_basics(): + sock = tsocket.socket() + with sock as cm_enter_value: + assert cm_enter_value is sock + assert isinstance(sock.fileno(), int) + assert not sock.get_inheritable() + sock.set_inheritable(True) + assert sock.get_inheritable() + + sock.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, False) + assert not sock.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY) + sock.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, True) + assert sock.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY) + # closed sockets have fileno() == -1 + assert sock.fileno() == -1 + + # smoke test + repr(sock) + + # detach + with tsocket.socket() as sock: + fd = sock.fileno() + assert sock.detach() == fd + assert sock.fileno() == -1 + + # close + sock = tsocket.socket() + assert sock.fileno() >= 0 + sock.close() + assert sock.fileno() == -1 + + # share was tested above together with fromshare + + # check __dir__ + assert "family" in dir(sock) + assert "recv" in dir(sock) + assert "setsockopt" in dir(sock) + + # our __getattr__ handles unknown names + with pytest.raises(AttributeError): + sock.asdf + + # type family proto + stdlib_sock = stdlib_socket.socket() + sock = tsocket.from_stdlib_socket(stdlib_sock) + assert sock.type == _tsocket.real_socket_type(stdlib_sock.type) + assert sock.family == stdlib_sock.family + assert sock.proto == stdlib_sock.proto + sock.close() + + +async def test_SocketType_dup(): + a, b = tsocket.socketpair() + with a, b: + a2 = a.dup() + with a2: + assert isinstance(a2, tsocket.SocketType) + assert a2.fileno() != a.fileno() + a.close() + await a2.send(b"x") + assert await b.recv(1) == b"x" + + +async def test_SocketType_shutdown(): + a, b = tsocket.socketpair() + with a, b: + await a.send(b"x") + assert await b.recv(1) == b"x" + assert not a.did_shutdown_SHUT_WR + assert not b.did_shutdown_SHUT_WR + a.shutdown(tsocket.SHUT_WR) + assert a.did_shutdown_SHUT_WR + assert not b.did_shutdown_SHUT_WR + assert await b.recv(1) == b"" + await b.send(b"y") + assert await a.recv(1) == b"y" + + a, b = tsocket.socketpair() + with a, b: + assert not a.did_shutdown_SHUT_WR + a.shutdown(tsocket.SHUT_RD) + assert not a.did_shutdown_SHUT_WR + + a, b = tsocket.socketpair() + with a, b: + assert not a.did_shutdown_SHUT_WR + a.shutdown(tsocket.SHUT_RDWR) + assert a.did_shutdown_SHUT_WR + + +@pytest.mark.parametrize( + "address, socket_type", + [ + ("127.0.0.1", tsocket.AF_INET), + pytest.param("::1", tsocket.AF_INET6, marks=binds_ipv6), + ], +) +async def test_SocketType_simple_server(address, socket_type): + # listen, bind, accept, connect, getpeername, getsockname + listener = tsocket.socket(socket_type) + client = tsocket.socket(socket_type) + with listener, client: + await listener.bind((address, 0)) + listener.listen(20) + addr = listener.getsockname()[:2] + async with _core.open_nursery() as nursery: + nursery.start_soon(client.connect, addr) + server, client_addr = await listener.accept() + with server: + assert client_addr == server.getpeername() == client.getsockname() + await server.send(b"x") + assert await client.recv(1) == b"x" + + +async def test_SocketType_is_readable(): + a, b = tsocket.socketpair() + with a, b: + assert not a.is_readable() + await b.send(b"x") + await _core.wait_readable(a) + assert a.is_readable() + assert await a.recv(1) == b"x" + assert not a.is_readable() + + +# On some macOS systems, getaddrinfo likes to return V4-mapped addresses even +# when we *don't* pass AI_V4MAPPED. +# https://github.com/python-trio/trio/issues/580 +def gai_without_v4mapped_is_buggy(): # pragma: no cover + try: + stdlib_socket.getaddrinfo("1.2.3.4", 0, family=stdlib_socket.AF_INET6) + except stdlib_socket.gaierror: + return False + else: + return True + + +@attr.s +class Addresses: + bind_all = attr.ib() + localhost = attr.ib() + arbitrary = attr.ib() + broadcast = attr.ib() + + +# Direct thorough tests of the implicit resolver helpers +@pytest.mark.parametrize( + "socket_type, addrs", + [ + ( + tsocket.AF_INET, + Addresses( + bind_all="0.0.0.0", + localhost="127.0.0.1", + arbitrary="1.2.3.4", + broadcast="255.255.255.255", + ), + ), + pytest.param( + tsocket.AF_INET6, + Addresses( + bind_all="::", + localhost="::1", + arbitrary="1::2", + broadcast="::ffff:255.255.255.255", + ), + marks=creates_ipv6, + ), + ], +) +async def test_SocketType_resolve(socket_type, addrs): + v6 = socket_type == tsocket.AF_INET6 + + def pad(addr): + if v6: + while len(addr) < 4: + addr += (0,) + return addr + + def assert_eq(actual, expected): + assert pad(expected) == pad(actual) + + with tsocket.socket(family=socket_type) as sock: + # For some reason the stdlib special-cases "" to pass NULL to + # getaddrinfo. They also error out on None, but whatever, None is much + # more consistent, so we accept it too. + for null in [None, ""]: + got = await sock._resolve_address_nocp((null, 80), local=True) + assert_eq(got, (addrs.bind_all, 80)) + got = await sock._resolve_address_nocp((null, 80), local=False) + assert_eq(got, (addrs.localhost, 80)) + + # AI_PASSIVE only affects the wildcard address, so for everything else + # local=True/local=False should work the same: + for local in [False, True]: + + async def res(*args): + return await sock._resolve_address_nocp(*args, local=local) + + assert_eq(await res((addrs.arbitrary, "http")), (addrs.arbitrary, 80)) + if v6: + # Check handling of different length ipv6 address tuples + assert_eq(await res(("1::2", 80)), ("1::2", 80, 0, 0)) + assert_eq(await res(("1::2", 80, 0)), ("1::2", 80, 0, 0)) + assert_eq(await res(("1::2", 80, 0, 0)), ("1::2", 80, 0, 0)) + # Non-zero flowinfo/scopeid get passed through + assert_eq(await res(("1::2", 80, 1)), ("1::2", 80, 1, 0)) + assert_eq(await res(("1::2", 80, 1, 2)), ("1::2", 80, 1, 2)) + + # And again with a string port, as a trick to avoid the + # already-resolved address fastpath and make sure we call + # getaddrinfo + assert_eq(await res(("1::2", "80")), ("1::2", 80, 0, 0)) + assert_eq(await res(("1::2", "80", 0)), ("1::2", 80, 0, 0)) + assert_eq(await res(("1::2", "80", 0, 0)), ("1::2", 80, 0, 0)) + assert_eq(await res(("1::2", "80", 1)), ("1::2", 80, 1, 0)) + assert_eq(await res(("1::2", "80", 1, 2)), ("1::2", 80, 1, 2)) + + # V4 mapped addresses resolved if V6ONLY is False + sock.setsockopt(tsocket.IPPROTO_IPV6, tsocket.IPV6_V6ONLY, False) + assert_eq(await res(("1.2.3.4", "http")), ("::ffff:1.2.3.4", 80)) + + # Check the special case, because why not + assert_eq(await res(("", 123)), (addrs.broadcast, 123)) + + # But not if it's true (at least on systems where getaddrinfo works + # correctly) + if v6 and not gai_without_v4mapped_is_buggy(): + sock.setsockopt(tsocket.IPPROTO_IPV6, tsocket.IPV6_V6ONLY, True) + with pytest.raises(tsocket.gaierror) as excinfo: + await res(("1.2.3.4", 80)) + # Windows, macOS + expected_errnos = {tsocket.EAI_NONAME} + # Linux + if hasattr(tsocket, "EAI_ADDRFAMILY"): + expected_errnos.add(tsocket.EAI_ADDRFAMILY) + assert excinfo.value.errno in expected_errnos + + # A family where we know nothing about the addresses, so should just + # pass them through. This should work on Linux, which is enough to + # smoke test the basic functionality... + try: + netlink_sock = tsocket.socket( + family=tsocket.AF_NETLINK, type=tsocket.SOCK_DGRAM + ) + except (AttributeError, OSError): + pass + else: + assert ( + await netlink_sock._resolve_address_nocp("asdf", local=local) + == "asdf" + ) + netlink_sock.close() + + with pytest.raises(ValueError): + await res("1.2.3.4") + with pytest.raises(ValueError): + await res(("1.2.3.4",)) + with pytest.raises(ValueError): + if v6: + await res(("1.2.3.4", 80, 0, 0, 0)) + else: + await res(("1.2.3.4", 80, 0, 0)) + + +async def test_SocketType_unresolved_names(): + with tsocket.socket() as sock: + await sock.bind(("localhost", 0)) + assert sock.getsockname()[0] == "127.0.0.1" + sock.listen(10) + + with tsocket.socket() as sock2: + await sock2.connect(("localhost", sock.getsockname()[1])) + assert sock2.getpeername() == sock.getsockname() + + # check gaierror propagates out + with tsocket.socket() as sock: + with pytest.raises(tsocket.gaierror): + # definitely not a valid request + await sock.bind(("1.2:3", -1)) + + +# This tests all the complicated paths through _nonblocking_helper, using recv +# as a stand-in for all the methods that use _nonblocking_helper. +async def test_SocketType_non_blocking_paths(): + a, b = stdlib_socket.socketpair() + with a, b: + ta = tsocket.from_stdlib_socket(a) + b.setblocking(False) + + # cancel before even calling + b.send(b"1") + with _core.CancelScope() as cscope: + cscope.cancel() + with assert_checkpoints(): + with pytest.raises(_core.Cancelled): + await ta.recv(10) + # immediate success (also checks that the previous attempt didn't + # actually read anything) + with assert_checkpoints(): + await ta.recv(10) == b"1" + # immediate failure + with assert_checkpoints(): + with pytest.raises(TypeError): + await ta.recv("haha") + # block then succeed + + async def do_successful_blocking_recv(): + with assert_checkpoints(): + assert await ta.recv(10) == b"2" + + async with _core.open_nursery() as nursery: + nursery.start_soon(do_successful_blocking_recv) + await wait_all_tasks_blocked() + b.send(b"2") + # block then cancelled + + async def do_cancelled_blocking_recv(): + with assert_checkpoints(): + with pytest.raises(_core.Cancelled): + await ta.recv(10) + + async with _core.open_nursery() as nursery: + nursery.start_soon(do_cancelled_blocking_recv) + await wait_all_tasks_blocked() + nursery.cancel_scope.cancel() + # Okay, here's the trickiest one: we want to exercise the path where + # the task is signaled to wake, goes to recv, but then the recv fails, + # so it has to go back to sleep and try again. Strategy: have two + # tasks waiting on two sockets (to work around the rule against having + # two tasks waiting on the same socket), wake them both up at the same + # time, and whichever one runs first "steals" the data from the + # other: + tb = tsocket.from_stdlib_socket(b) + + async def t1(): + with assert_checkpoints(): + assert await ta.recv(1) == b"a" + with assert_checkpoints(): + assert await tb.recv(1) == b"b" + + async def t2(): + with assert_checkpoints(): + assert await tb.recv(1) == b"b" + with assert_checkpoints(): + assert await ta.recv(1) == b"a" + + async with _core.open_nursery() as nursery: + nursery.start_soon(t1) + nursery.start_soon(t2) + await wait_all_tasks_blocked() + a.send(b"b") + b.send(b"a") + await wait_all_tasks_blocked() + a.send(b"b") + b.send(b"a") + + +# This tests the complicated paths through connect +async def test_SocketType_connect_paths(): + with tsocket.socket() as sock: + with pytest.raises(ValueError): + # Should be a tuple + await sock.connect("localhost") + + # cancelled before we start + with tsocket.socket() as sock: + with _core.CancelScope() as cancel_scope: + cancel_scope.cancel() + with pytest.raises(_core.Cancelled): + await sock.connect(("127.0.0.1", 80)) + + # Cancelled in between the connect() call and the connect completing + with _core.CancelScope() as cancel_scope: + with tsocket.socket() as sock, tsocket.socket() as listener: + await listener.bind(("127.0.0.1", 0)) + listener.listen() + + # Swap in our weird subclass under the trio.socket._SocketType's + # nose -- and then swap it back out again before we hit + # wait_socket_writable, which insists on a real socket. + class CancelSocket(stdlib_socket.socket): + def connect(self, *args, **kwargs): + cancel_scope.cancel() + sock._sock = stdlib_socket.fromfd( + self.detach(), self.family, self.type + ) + sock._sock.connect(*args, **kwargs) + # If connect *doesn't* raise, then pretend it did + raise BlockingIOError # pragma: no cover + + sock._sock.close() + sock._sock = CancelSocket() + + with assert_checkpoints(): + with pytest.raises(_core.Cancelled): + await sock.connect(listener.getsockname()) + assert sock.fileno() == -1 + + # Failed connect (hopefully after raising BlockingIOError) + with tsocket.socket() as sock: + with pytest.raises(OSError): + # TCP port 2 is not assigned. Pretty sure nothing will be + # listening there. (We used to bind a port and then *not* call + # listen() to ensure nothing was listening there, but it turns + # out on macOS if you do this it takes 30 seconds for the + # connect to fail. Really. Also if you use a non-routable + # address. This way fails instantly though. As long as nothing + # is listening on port 2.) + await sock.connect(("127.0.0.1", 2)) + + +async def test_resolve_address_exception_in_connect_closes_socket(): + # Here we are testing issue 247, any cancellation will leave the socket closed + with _core.CancelScope() as cancel_scope: + with tsocket.socket() as sock: + + async def _resolve_address_nocp(self, *args, **kwargs): + cancel_scope.cancel() + await _core.checkpoint() + + sock._resolve_address_nocp = _resolve_address_nocp + with assert_checkpoints(): + with pytest.raises(_core.Cancelled): + await sock.connect("") + assert sock.fileno() == -1 + + +async def test_send_recv_variants(): + a, b = tsocket.socketpair() + with a, b: + # recv, including with flags + assert await a.send(b"x") == 1 + assert await b.recv(10, tsocket.MSG_PEEK) == b"x" + assert await b.recv(10) == b"x" + + # recv_into + await a.send(b"x") + buf = bytearray(10) + await b.recv_into(buf) + assert buf == b"x" + b"\x00" * 9 + + if hasattr(a, "sendmsg"): + assert await a.sendmsg([b"xxx"], []) == 3 + assert await b.recv(10) == b"xxx" + + a = tsocket.socket(type=tsocket.SOCK_DGRAM) + b = tsocket.socket(type=tsocket.SOCK_DGRAM) + with a, b: + await a.bind(("127.0.0.1", 0)) + await b.bind(("127.0.0.1", 0)) + + targets = [b.getsockname(), ("localhost", b.getsockname()[1])] + + # recvfrom + sendto, with and without names + for target in targets: + assert await a.sendto(b"xxx", target) == 3 + (data, addr) = await b.recvfrom(10) + assert data == b"xxx" + assert addr == a.getsockname() + + # sendto + flags + # + # I can't find any flags that send() accepts... on Linux at least + # passing MSG_MORE to send_some on a connected UDP socket seems to + # just be ignored. + # + # But there's no MSG_MORE on Windows or macOS. I guess send_some flags + # are really not very useful, but at least this tests them a bit. + if hasattr(tsocket, "MSG_MORE"): + await a.sendto(b"xxx", tsocket.MSG_MORE, b.getsockname()) + await a.sendto(b"yyy", tsocket.MSG_MORE, b.getsockname()) + await a.sendto(b"zzz", b.getsockname()) + (data, addr) = await b.recvfrom(10) + assert data == b"xxxyyyzzz" + assert addr == a.getsockname() + + # recvfrom_into + assert await a.sendto(b"xxx", b.getsockname()) == 3 + buf = bytearray(10) + (nbytes, addr) = await b.recvfrom_into(buf) + assert nbytes == 3 + assert buf == b"xxx" + b"\x00" * 7 + assert addr == a.getsockname() + + if hasattr(b, "recvmsg"): + assert await a.sendto(b"xxx", b.getsockname()) == 3 + (data, ancdata, msg_flags, addr) = await b.recvmsg(10) + assert data == b"xxx" + assert ancdata == [] + assert msg_flags == 0 + assert addr == a.getsockname() + + if hasattr(b, "recvmsg_into"): + assert await a.sendto(b"xyzw", b.getsockname()) == 4 + buf1 = bytearray(2) + buf2 = bytearray(3) + ret = await b.recvmsg_into([buf1, buf2]) + (nbytes, ancdata, msg_flags, addr) = ret + assert nbytes == 4 + assert buf1 == b"xy" + assert buf2 == b"zw" + b"\x00" + assert ancdata == [] + assert msg_flags == 0 + assert addr == a.getsockname() + + if hasattr(a, "sendmsg"): + for target in targets: + assert await a.sendmsg([b"x", b"yz"], [], 0, target) == 3 + assert await b.recvfrom(10) == (b"xyz", a.getsockname()) + + a = tsocket.socket(type=tsocket.SOCK_DGRAM) + b = tsocket.socket(type=tsocket.SOCK_DGRAM) + with a, b: + await b.bind(("127.0.0.1", 0)) + await a.connect(b.getsockname()) + # send on a connected udp socket; each call creates a separate + # datagram + await a.send(b"xxx") + await a.send(b"yyy") + assert await b.recv(10) == b"xxx" + assert await b.recv(10) == b"yyy" + + +async def test_idna(monkeygai): + # This is the encoding for "faß.de", which uses one of the characters that + # IDNA 2003 handles incorrectly: + monkeygai.set("ok faß.de", b"xn--fa-hia.de", 80) + monkeygai.set("ok ::1", "::1", 80, flags=_NUMERIC_ONLY) + monkeygai.set("ok ::1", b"::1", 80, flags=_NUMERIC_ONLY) + # Some things that should not reach the underlying socket.getaddrinfo: + monkeygai.set("bad", "fass.de", 80) + # We always call socket.getaddrinfo with bytes objects: + monkeygai.set("bad", "xn--fa-hia.de", 80) + + assert "ok ::1" == await tsocket.getaddrinfo("::1", 80) + assert "ok ::1" == await tsocket.getaddrinfo(b"::1", 80) + assert "ok faß.de" == await tsocket.getaddrinfo("faß.de", 80) + assert "ok faß.de" == await tsocket.getaddrinfo("xn--fa-hia.de", 80) + assert "ok faß.de" == await tsocket.getaddrinfo(b"xn--fa-hia.de", 80) + + +async def test_getprotobyname(): + # These are the constants used in IP header fields, so the numeric values + # had *better* be stable across systems... + assert await tsocket.getprotobyname("udp") == 17 + assert await tsocket.getprotobyname("tcp") == 6 + + +async def test_custom_hostname_resolver(monkeygai): + class CustomResolver: + async def getaddrinfo(self, host, port, family, type, proto, flags): + return ("custom_gai", host, port, family, type, proto, flags) + + async def getnameinfo(self, sockaddr, flags): + return ("custom_gni", sockaddr, flags) + + cr = CustomResolver() + + assert tsocket.set_custom_hostname_resolver(cr) is None + + # Check that the arguments are all getting passed through. + # We have to use valid calls to avoid making the underlying system + # getaddrinfo cranky when it's used for NUMERIC checks. + for vals in [ + (tsocket.AF_INET, 0, 0, 0), + (0, tsocket.SOCK_STREAM, 0, 0), + (0, 0, tsocket.IPPROTO_TCP, 0), + (0, 0, 0, tsocket.AI_CANONNAME), + ]: + assert await tsocket.getaddrinfo("localhost", "foo", *vals) == ( + "custom_gai", + b"localhost", + "foo", + *vals, + ) + + # IDNA encoding is handled before calling the special object + got = await tsocket.getaddrinfo("föö", "foo") + expected = ("custom_gai", b"xn--f-1gaa", "foo", 0, 0, 0, 0) + assert got == expected + + assert await tsocket.getnameinfo("a", 0) == ("custom_gni", "a", 0) + + # We can set it back to None + assert tsocket.set_custom_hostname_resolver(None) is cr + + # And now Trio switches back to calling socket.getaddrinfo (specifically + # our monkeypatched version of socket.getaddrinfo) + monkeygai.set("x", b"host", "port", family=0, type=0, proto=0, flags=0) + assert await tsocket.getaddrinfo("host", "port") == "x" + + +async def test_custom_socket_factory(): + class CustomSocketFactory: + def socket(self, family, type, proto): + return ("hi", family, type, proto) + + csf = CustomSocketFactory() + + assert tsocket.set_custom_socket_factory(csf) is None + + assert tsocket.socket() == ("hi", tsocket.AF_INET, tsocket.SOCK_STREAM, 0) + assert tsocket.socket(1, 2, 3) == ("hi", 1, 2, 3) + + # socket with fileno= doesn't call our custom method + fd = stdlib_socket.socket().detach() + wrapped = tsocket.socket(fileno=fd) + assert hasattr(wrapped, "bind") + wrapped.close() + + # Likewise for socketpair + a, b = tsocket.socketpair() + with a, b: + assert hasattr(a, "bind") + assert hasattr(b, "bind") + + assert tsocket.set_custom_socket_factory(None) is csf + + +async def test_SocketType_is_abstract(): + with pytest.raises(TypeError): + tsocket.SocketType() + + +@pytest.mark.skipif(not hasattr(tsocket, "AF_UNIX"), reason="no unix domain sockets") +async def test_unix_domain_socket(): + # Bind has a special branch to use a thread, since it has to do filesystem + # traversal. Maybe connect should too? Not sure. + + async def check_AF_UNIX(path): + with tsocket.socket(family=tsocket.AF_UNIX) as lsock: + await lsock.bind(path) + lsock.listen(10) + with tsocket.socket(family=tsocket.AF_UNIX) as csock: + await csock.connect(path) + ssock, _ = await lsock.accept() + with ssock: + await csock.send(b"x") + assert await ssock.recv(1) == b"x" + + # Can't use tmpdir fixture, because we can exceed the maximum AF_UNIX path + # length on macOS. + with tempfile.TemporaryDirectory() as tmpdir: + path = "{}/sock".format(tmpdir) + await check_AF_UNIX(path) + + try: + cookie = os.urandom(20).hex().encode("ascii") + await check_AF_UNIX(b"\x00trio-test-" + cookie) + except FileNotFoundError: + # macOS doesn't support abstract filenames with the leading NUL byte + pass + + +async def test_interrupted_by_close(): + a_stdlib, b_stdlib = stdlib_socket.socketpair() + with a_stdlib, b_stdlib: + a_stdlib.setblocking(False) + + data = b"x" * 99999 + + try: + while True: + a_stdlib.send(data) + except BlockingIOError: + pass + + a = tsocket.from_stdlib_socket(a_stdlib) + + async def sender(): + with pytest.raises(_core.ClosedResourceError): + await a.send(data) + + async def receiver(): + with pytest.raises(_core.ClosedResourceError): + await a.recv(1) + + async with _core.open_nursery() as nursery: + nursery.start_soon(sender) + nursery.start_soon(receiver) + await wait_all_tasks_blocked() + a.close() + + +async def test_many_sockets(): + total = 5000 # Must be more than MAX_AFD_GROUP_SIZE + sockets = [] + for x in range(total // 2): + try: + a, b = stdlib_socket.socketpair() + except OSError as e: # pragma: no cover + assert e.errno in (errno.EMFILE, errno.ENFILE) + break + sockets += [a, b] + async with _core.open_nursery() as nursery: + for s in sockets: + nursery.start_soon(_core.wait_readable, s) + await _core.wait_all_tasks_blocked() + nursery.cancel_scope.cancel() + for sock in sockets: + sock.close() + if x != total // 2 - 1: # pragma: no cover + print(f"Unable to open more than {(x-1)*2} sockets.") diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_ssl.py b/venv/lib/python3.9/site-packages/trio/tests/test_ssl.py new file mode 100644 index 00000000..7825e319 --- /dev/null +++ b/venv/lib/python3.9/site-packages/trio/tests/test_ssl.py @@ -0,0 +1,1303 @@ +import os +import re +import sys + +import pytest + +import threading +import socket as stdlib_socket +import ssl +from contextlib import contextmanager +from functools import partial + +from OpenSSL import SSL +import trustme +from async_generator import asynccontextmanager + +import trio +from .. import _core +from .._highlevel_socket import SocketStream, SocketListener +from .._highlevel_generic import aclose_forcefully +from .._core import ClosedResourceError, BrokenResourceError +from .._highlevel_open_tcp_stream import open_tcp_stream +from .. import socket as tsocket +from .._ssl import SSLStream, SSLListener, NeedHandshakeError, _is_eof +from .._util import ConflictDetector + +from .._core.tests.tutil import slow + +from ..testing import ( + assert_checkpoints, + Sequencer, + memory_stream_pair, + lockstep_stream_pair, + check_two_way_stream, +) + +# We have two different kinds of echo server fixtures we use for testing. The +# first is a real server written using the stdlib ssl module and blocking +# sockets. It runs in a thread and we talk to it over a real socketpair(), to +# validate interoperability in a semi-realistic setting. +# +# The second is a very weird virtual echo server that lives inside a custom +# Stream class. It lives entirely inside the Python object space; there are no +# operating system calls in it at all. No threads, no I/O, nothing. It's +# 'send_all' call takes encrypted data from a client and feeds it directly into +# the server-side TLS state engine to decrypt, then takes that data, feeds it +# back through to get the encrypted response, and returns it from 'receive_some'. This +# gives us full control and reproducibility. This server is written using +# PyOpenSSL, so that we can trigger renegotiations on demand. It also allows +# us to insert random (virtual) delays, to really exercise all the weird paths +# in SSLStream's state engine. +# +# Both present a certificate for "trio-test-1.example.org". + +TRIO_TEST_CA = trustme.CA() +TRIO_TEST_1_CERT = TRIO_TEST_CA.issue_server_cert("trio-test-1.example.org") + +SERVER_CTX = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) +if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"): + SERVER_CTX.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF + +TRIO_TEST_1_CERT.configure_cert(SERVER_CTX) + + +# TLS 1.3 has a lot of changes from previous versions. So we want to run tests +# with both TLS 1.3, and TLS 1.2. +# "tls13" means that we're willing to negotiate TLS 1.3. Usually that's +# what will happen, but the renegotiation tests explicitly force a +# downgrade on the server side. "tls12" means we refuse to negotiate TLS +# 1.3, so we'll almost certainly use TLS 1.2. +@pytest.fixture(scope="module", params=["tls13", "tls12"]) +def client_ctx(request): + ctx = ssl.create_default_context() + + if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"): + ctx.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF + + TRIO_TEST_CA.configure_trust(ctx) + if request.param in ["default", "tls13"]: + return ctx + elif request.param == "tls12": + ctx.maximum_version = ssl.TLSVersion.TLSv1_2 + return ctx + else: # pragma: no cover + assert False + + +# The blocking socket server. +def ssl_echo_serve_sync(sock, *, expect_fail=False): + try: + wrapped = SERVER_CTX.wrap_socket( + sock, server_side=True, suppress_ragged_eofs=False + ) + with wrapped: + wrapped.do_handshake() + while True: + data = wrapped.recv(4096) + if not data: + # other side has initiated a graceful shutdown; we try to + # respond in kind but it's legal for them to have already + # gone away. + exceptions = (BrokenPipeError, ssl.SSLZeroReturnError) + try: + wrapped.unwrap() + except exceptions: + pass + except ssl.SSLWantWriteError: # pragma: no cover + # Under unclear conditions, CPython sometimes raises + # SSLWantWriteError here. This is a bug (bpo-32219), + # but it's not our bug. Christian Heimes thinks + # it's fixed in 'recent' CPython versions so we fail + # the test for those and ignore it for earlier + # versions. + if ( + sys.implementation.name != "cpython" + or sys.version_info >= (3, 8) + ): + pytest.fail( + "still an issue on recent python versions " + "add a comment to " + "https://bugs.python.org/issue32219" + ) + return + wrapped.sendall(data) + # This is an obscure workaround for an openssl bug. In server mode, in + # some versions, openssl sends some extra data at the end of do_handshake + # that it shouldn't send. Normally this is harmless, but, if the other + # side shuts down the connection before it reads that data, it might cause + # the OS to report a ECONNREST or even ECONNABORTED (which is just wrong, + # since ECONNABORTED is supposed to mean that connect() failed, but what + # can you do). In this case the other side did nothing wrong, but there's + # no way to recover, so we let it pass, and just cross our fingers its not + # hiding any (other) real bugs. For more details see: + # + # https://github.com/python-trio/trio/issues/1293 + # + # Also, this happens frequently but non-deterministically, so we have to + # 'no cover' it to avoid coverage flapping. + except (ConnectionResetError, ConnectionAbortedError): # pragma: no cover + return + except Exception as exc: + if expect_fail: + print("ssl_echo_serve_sync got error as expected:", exc) + else: # pragma: no cover + print("ssl_echo_serve_sync got unexpected error:", exc) + raise + else: + if expect_fail: # pragma: no cover + raise RuntimeError("failed to fail?") + finally: + sock.close() + + +# Fixture that gives a raw socket connected to a trio-test-1 echo server +# (running in a thread). Useful for testing making connections with different +# SSLContexts. +@asynccontextmanager +async def ssl_echo_server_raw(**kwargs): + a, b = stdlib_socket.socketpair() + async with trio.open_nursery() as nursery: + # Exiting the 'with a, b' context manager closes the sockets, which + # causes the thread to exit (possibly with an error), which allows the + # nursery context manager to exit too. + with a, b: + nursery.start_soon( + trio.to_thread.run_sync, partial(ssl_echo_serve_sync, b, **kwargs) + ) + + yield SocketStream(tsocket.from_stdlib_socket(a)) + + +# Fixture that gives a properly set up SSLStream connected to a trio-test-1 +# echo server (running in a thread) +@asynccontextmanager +async def ssl_echo_server(client_ctx, **kwargs): + async with ssl_echo_server_raw(**kwargs) as sock: + yield SSLStream(sock, client_ctx, server_hostname="trio-test-1.example.org") + + +# The weird in-memory server ... thing. +# Doesn't inherit from Stream because I left out the methods that we don't +# actually need. +class PyOpenSSLEchoStream: + def __init__(self, sleeper=None): + ctx = SSL.Context(SSL.SSLv23_METHOD) + # TLS 1.3 removes renegotiation support. Which is great for them, but + # we still have to support versions before that, and that means we + # need to test renegotiation support, which means we need to force this + # to use a lower version where this test server can trigger + # renegotiations. Of course TLS 1.3 support isn't released yet, but + # I'm told that this will work once it is. (And once it is we can + # remove the pragma: no cover too.) Alternatively, we could switch to + # using TLSv1_2_METHOD. + # + # Discussion: https://github.com/pyca/pyopenssl/issues/624 + + # This is the right way, but we can't use it until this PR is in a + # released: + # https://github.com/pyca/pyopenssl/pull/861 + # + # if hasattr(SSL, "OP_NO_TLSv1_3"): + # ctx.set_options(SSL.OP_NO_TLSv1_3) + # + # Fortunately pyopenssl uses cryptography under the hood, so we can be + # confident that they're using the same version of openssl + from cryptography.hazmat.bindings.openssl.binding import Binding + + b = Binding() + if hasattr(b.lib, "SSL_OP_NO_TLSv1_3"): + ctx.set_options(b.lib.SSL_OP_NO_TLSv1_3) + + # Unfortunately there's currently no way to say "use 1.3 or worse", we + # can only disable specific versions. And if the two sides start + # negotiating 1.4 at some point in the future, it *might* mean that + # our tests silently stop working properly. So the next line is a + # tripwire to remind us we need to revisit this stuff in 5 years or + # whatever when the next TLS version is released: + assert not hasattr(SSL, "OP_NO_TLSv1_4") + TRIO_TEST_1_CERT.configure_cert(ctx) + self._conn = SSL.Connection(ctx, None) + self._conn.set_accept_state() + self._lot = _core.ParkingLot() + self._pending_cleartext = bytearray() + + self._send_all_conflict_detector = ConflictDetector( + "simultaneous calls to PyOpenSSLEchoStream.send_all" + ) + self._receive_some_conflict_detector = ConflictDetector( + "simultaneous calls to PyOpenSSLEchoStream.receive_some" + ) + + if sleeper is None: + + async def no_op_sleeper(_): + return + + self.sleeper = no_op_sleeper + else: + self.sleeper = sleeper + + async def aclose(self): + self._conn.bio_shutdown() + + def renegotiate_pending(self): + return self._conn.renegotiate_pending() + + def renegotiate(self): + # Returns false if a renegotiation is already in progress, meaning + # nothing happens. + assert self._conn.renegotiate() + + async def wait_send_all_might_not_block(self): + with self._send_all_conflict_detector: + await _core.checkpoint() + await _core.checkpoint() + await self.sleeper("wait_send_all_might_not_block") + + async def send_all(self, data): + print(" --> transport_stream.send_all") + with self._send_all_conflict_detector: + await _core.checkpoint() + await _core.checkpoint() + await self.sleeper("send_all") + self._conn.bio_write(data) + while True: + await self.sleeper("send_all") + try: + data = self._conn.recv(1) + except SSL.ZeroReturnError: + self._conn.shutdown() + print("renegotiations:", self._conn.total_renegotiations()) + break + except SSL.WantReadError: + break + else: + self._pending_cleartext += data + self._lot.unpark_all() + await self.sleeper("send_all") + print(" <-- transport_stream.send_all finished") + + async def receive_some(self, nbytes=None): + print(" --> transport_stream.receive_some") + if nbytes is None: + nbytes = 65536 # arbitrary + with self._receive_some_conflict_detector: + try: + await _core.checkpoint() + await _core.checkpoint() + while True: + await self.sleeper("receive_some") + try: + return self._conn.bio_read(nbytes) + except SSL.WantReadError: + # No data in our ciphertext buffer; try to generate + # some. + if self._pending_cleartext: + # We have some cleartext; maybe we can encrypt it + # and then return it. + print(" trying", self._pending_cleartext) + try: + # PyOpenSSL bug: doesn't accept bytearray + # https://github.com/pyca/pyopenssl/issues/621 + next_byte = self._pending_cleartext[0:1] + self._conn.send(bytes(next_byte)) + # Apparently this next bit never gets hit in the + # test suite, but it's not an interesting omission + # so let's pragma it. + except SSL.WantReadError: # pragma: no cover + # We didn't manage to send the cleartext (and + # in particular we better leave it there to + # try again, due to openssl's retry + # semantics), but it's possible we pushed a + # renegotiation forward and *now* we have data + # to send. + try: + return self._conn.bio_read(nbytes) + except SSL.WantReadError: + # Nope. We're just going to have to wait + # for someone to call send_all() to give + # use more data. + print("parking (a)") + await self._lot.park() + else: + # We successfully sent that byte, so we don't + # have to again. + del self._pending_cleartext[0:1] + else: + # no pending cleartext; nothing to do but wait for + # someone to call send_all + print("parking (b)") + await self._lot.park() + finally: + await self.sleeper("receive_some") + print(" <-- transport_stream.receive_some finished") + + +async def test_PyOpenSSLEchoStream_gives_resource_busy_errors(): + # Make sure that PyOpenSSLEchoStream complains if two tasks call send_all + # at the same time, or ditto for receive_some. The tricky cases where SSLStream + # might accidentally do this are during renegotiation, which we test using + # PyOpenSSLEchoStream, so this makes sure that if we do have a bug then + # PyOpenSSLEchoStream will notice and complain. + + s = PyOpenSSLEchoStream() + with pytest.raises(_core.BusyResourceError) as excinfo: + async with _core.open_nursery() as nursery: + nursery.start_soon(s.send_all, b"x") + nursery.start_soon(s.send_all, b"x") + assert "simultaneous" in str(excinfo.value) + + s = PyOpenSSLEchoStream() + with pytest.raises(_core.BusyResourceError) as excinfo: + async with _core.open_nursery() as nursery: + nursery.start_soon(s.send_all, b"x") + nursery.start_soon(s.wait_send_all_might_not_block) + assert "simultaneous" in str(excinfo.value) + + s = PyOpenSSLEchoStream() + with pytest.raises(_core.BusyResourceError) as excinfo: + async with _core.open_nursery() as nursery: + nursery.start_soon(s.wait_send_all_might_not_block) + nursery.start_soon(s.wait_send_all_might_not_block) + assert "simultaneous" in str(excinfo.value) + + s = PyOpenSSLEchoStream() + with pytest.raises(_core.BusyResourceError) as excinfo: + async with _core.open_nursery() as nursery: + nursery.start_soon(s.receive_some, 1) + nursery.start_soon(s.receive_some, 1) + assert "simultaneous" in str(excinfo.value) + + +@contextmanager +def virtual_ssl_echo_server(client_ctx, **kwargs): + fakesock = PyOpenSSLEchoStream(**kwargs) + yield SSLStream(fakesock, client_ctx, server_hostname="trio-test-1.example.org") + + +def ssl_wrap_pair( + client_ctx, + client_transport, + server_transport, + *, + client_kwargs={}, + server_kwargs={}, +): + client_ssl = SSLStream( + client_transport, + client_ctx, + server_hostname="trio-test-1.example.org", + **client_kwargs, + ) + server_ssl = SSLStream( + server_transport, SERVER_CTX, server_side=True, **server_kwargs + ) + return client_ssl, server_ssl + + +def ssl_memory_stream_pair(client_ctx, **kwargs): + client_transport, server_transport = memory_stream_pair() + return ssl_wrap_pair(client_ctx, client_transport, server_transport, **kwargs) + + +def ssl_lockstep_stream_pair(client_ctx, **kwargs): + client_transport, server_transport = lockstep_stream_pair() + return ssl_wrap_pair(client_ctx, client_transport, server_transport, **kwargs) + + +# Simple smoke test for handshake/send/receive/shutdown talking to a +# synchronous server, plus make sure that we do the bare minimum of +# certificate checking (even though this is really Python's responsibility) +async def test_ssl_client_basics(client_ctx): + # Everything OK + async with ssl_echo_server(client_ctx) as s: + assert not s.server_side + await s.send_all(b"x") + assert await s.receive_some(1) == b"x" + await s.aclose() + + # Didn't configure the CA file, should fail + async with ssl_echo_server_raw(expect_fail=True) as sock: + bad_client_ctx = ssl.create_default_context() + s = SSLStream(sock, bad_client_ctx, server_hostname="trio-test-1.example.org") + assert not s.server_side + with pytest.raises(BrokenResourceError) as excinfo: + await s.send_all(b"x") + assert isinstance(excinfo.value.__cause__, ssl.SSLError) + + # Trusted CA, but wrong host name + async with ssl_echo_server_raw(expect_fail=True) as sock: + s = SSLStream(sock, client_ctx, server_hostname="trio-test-2.example.org") + assert not s.server_side + with pytest.raises(BrokenResourceError) as excinfo: + await s.send_all(b"x") + assert isinstance(excinfo.value.__cause__, ssl.CertificateError) + + +async def test_ssl_server_basics(client_ctx): + a, b = stdlib_socket.socketpair() + with a, b: + server_sock = tsocket.from_stdlib_socket(b) + server_transport = SSLStream( + SocketStream(server_sock), SERVER_CTX, server_side=True + ) + assert server_transport.server_side + + def client(): + with client_ctx.wrap_socket( + a, server_hostname="trio-test-1.example.org" + ) as client_sock: + client_sock.sendall(b"x") + assert client_sock.recv(1) == b"y" + client_sock.sendall(b"z") + client_sock.unwrap() + + t = threading.Thread(target=client) + t.start() + + assert await server_transport.receive_some(1) == b"x" + await server_transport.send_all(b"y") + assert await server_transport.receive_some(1) == b"z" + assert await server_transport.receive_some(1) == b"" + await server_transport.aclose() + + t.join() + + +async def test_attributes(client_ctx): + async with ssl_echo_server_raw(expect_fail=True) as sock: + good_ctx = client_ctx + bad_ctx = ssl.create_default_context() + s = SSLStream(sock, good_ctx, server_hostname="trio-test-1.example.org") + + assert s.transport_stream is sock + + # Forwarded attribute getting + assert s.context is good_ctx + assert s.server_side == False # noqa + assert s.server_hostname == "trio-test-1.example.org" + with pytest.raises(AttributeError): + s.asfdasdfsa + + # __dir__ + assert "transport_stream" in dir(s) + assert "context" in dir(s) + + # Setting the attribute goes through to the underlying object + + # most attributes on SSLObject are read-only + with pytest.raises(AttributeError): + s.server_side = True + with pytest.raises(AttributeError): + s.server_hostname = "asdf" + + # but .context is *not*. Check that we forward attribute setting by + # making sure that after we set the bad context our handshake indeed + # fails: + s.context = bad_ctx + assert s.context is bad_ctx + with pytest.raises(BrokenResourceError) as excinfo: + await s.do_handshake() + assert isinstance(excinfo.value.__cause__, ssl.SSLError) + + +# Note: this test fails horribly if we force TLS 1.2 and trigger a +# renegotiation at the beginning (e.g. by switching to the pyopenssl +# server). Usually the client crashes in SSLObject.write with "UNEXPECTED +# RECORD"; sometimes we get something more exotic like a SyscallError. This is +# odd because openssl isn't doing any syscalls, but so it goes. After lots of +# websearching I'm pretty sure this is due to a bug in OpenSSL, where it just +# can't reliably handle full-duplex communication combined with +# renegotiation. Nice, eh? +# +# https://rt.openssl.org/Ticket/Display.html?id=3712 +# https://rt.openssl.org/Ticket/Display.html?id=2481 +# http://openssl.6102.n7.nabble.com/TLS-renegotiation-failure-on-receiving-application-data-during-handshake-td48127.html +# https://stackoverflow.com/questions/18728355/ssl-renegotiation-with-full-duplex-socket-communication +# +# In some variants of this test (maybe only against the java server?) I've +# also seen cases where our send_all blocks waiting to write, and then our receive_some +# also blocks waiting to write, and they never wake up again. It looks like +# some kind of deadlock. I suspect there may be an issue where we've filled up +# the send buffers, and the remote side is trying to handle the renegotiation +# from inside a write() call, so it has a problem: there's all this application +# data clogging up the pipe, but it can't process and return it to the +# application because it's in write(), and it doesn't want to buffer infinite +# amounts of data, and... actually I guess those are the only two choices. +# +# NSS even documents that you shouldn't try to do a renegotiation except when +# the connection is idle: +# +# https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/SSL_functions/sslfnc.html#1061582 +# +# I begin to see why HTTP/2 forbids renegotiation and TLS 1.3 removes it... + + +async def test_full_duplex_basics(client_ctx): + CHUNKS = 30 + CHUNK_SIZE = 32768 + EXPECTED = CHUNKS * CHUNK_SIZE + + sent = bytearray() + received = bytearray() + + async def sender(s): + nonlocal sent + for i in range(CHUNKS): + print(i) + chunk = bytes([i] * CHUNK_SIZE) + sent += chunk + await s.send_all(chunk) + + async def receiver(s): + nonlocal received + while len(received) < EXPECTED: + chunk = await s.receive_some(CHUNK_SIZE // 2) + received += chunk + + async with ssl_echo_server(client_ctx) as s: + async with _core.open_nursery() as nursery: + nursery.start_soon(sender, s) + nursery.start_soon(receiver, s) + # And let's have some doing handshakes too, everyone + # simultaneously + nursery.start_soon(s.do_handshake) + nursery.start_soon(s.do_handshake) + + await s.aclose() + + assert len(sent) == len(received) == EXPECTED + assert sent == received + + +async def test_renegotiation_simple(client_ctx): + with virtual_ssl_echo_server(client_ctx) as s: + await s.do_handshake() + + s.transport_stream.renegotiate() + await s.send_all(b"a") + assert await s.receive_some(1) == b"a" + + # Have to send some more data back and forth to make sure the + # renegotiation is finished before shutting down the + # connection... otherwise openssl raises an error. I think this is a + # bug in openssl but what can ya do. + await s.send_all(b"b") + assert await s.receive_some(1) == b"b" + + await s.aclose() + + +@slow +async def test_renegotiation_randomized(mock_clock, client_ctx): + # The only blocking things in this function are our random sleeps, so 0 is + # a good threshold. + mock_clock.autojump_threshold = 0 + + import random + + r = random.Random(0) + + async def sleeper(_): + await trio.sleep(r.uniform(0, 10)) + + async def clear(): + while s.transport_stream.renegotiate_pending(): + with assert_checkpoints(): + await send(b"-") + with assert_checkpoints(): + await expect(b"-") + print("-- clear --") + + async def send(byte): + await s.transport_stream.sleeper("outer send") + print("calling SSLStream.send_all", byte) + with assert_checkpoints(): + await s.send_all(byte) + + async def expect(expected): + await s.transport_stream.sleeper("expect") + print("calling SSLStream.receive_some, expecting", expected) + assert len(expected) == 1 + with assert_checkpoints(): + assert await s.receive_some(1) == expected + + with virtual_ssl_echo_server(client_ctx, sleeper=sleeper) as s: + await s.do_handshake() + + await send(b"a") + s.transport_stream.renegotiate() + await expect(b"a") + + await clear() + + for i in range(100): + b1 = bytes([i % 0xFF]) + b2 = bytes([(2 * i) % 0xFF]) + s.transport_stream.renegotiate() + async with _core.open_nursery() as nursery: + nursery.start_soon(send, b1) + nursery.start_soon(expect, b1) + async with _core.open_nursery() as nursery: + nursery.start_soon(expect, b2) + nursery.start_soon(send, b2) + await clear() + + for i in range(100): + b1 = bytes([i % 0xFF]) + b2 = bytes([(2 * i) % 0xFF]) + await send(b1) + s.transport_stream.renegotiate() + await expect(b1) + async with _core.open_nursery() as nursery: + nursery.start_soon(expect, b2) + nursery.start_soon(send, b2) + await clear() + + # Checking that wait_send_all_might_not_block and receive_some don't + # conflict: + + # 1) Set up a situation where expect (receive_some) is blocked sending, + # and wait_send_all_might_not_block comes in. + + # Our receive_some() call will get stuck when it hits send_all + async def sleeper_with_slow_send_all(method): + if method == "send_all": + await trio.sleep(100000) + + # And our wait_send_all_might_not_block call will give it time to get + # stuck, and then start + async def sleep_then_wait_writable(): + await trio.sleep(1000) + await s.wait_send_all_might_not_block() + + with virtual_ssl_echo_server(client_ctx, sleeper=sleeper_with_slow_send_all) as s: + await send(b"x") + s.transport_stream.renegotiate() + async with _core.open_nursery() as nursery: + nursery.start_soon(expect, b"x") + nursery.start_soon(sleep_then_wait_writable) + + await clear() + + await s.aclose() + + # 2) Same, but now wait_send_all_might_not_block is stuck when + # receive_some tries to send. + + async def sleeper_with_slow_wait_writable_and_expect(method): + if method == "wait_send_all_might_not_block": + await trio.sleep(100000) + elif method == "expect": + await trio.sleep(1000) + + with virtual_ssl_echo_server( + client_ctx, sleeper=sleeper_with_slow_wait_writable_and_expect + ) as s: + await send(b"x") + s.transport_stream.renegotiate() + async with _core.open_nursery() as nursery: + nursery.start_soon(expect, b"x") + nursery.start_soon(s.wait_send_all_might_not_block) + + await clear() + + await s.aclose() + + +async def test_resource_busy_errors(client_ctx): + async def do_send_all(): + with assert_checkpoints(): + await s.send_all(b"x") + + async def do_receive_some(): + with assert_checkpoints(): + await s.receive_some(1) + + async def do_wait_send_all_might_not_block(): + with assert_checkpoints(): + await s.wait_send_all_might_not_block() + + s, _ = ssl_lockstep_stream_pair(client_ctx) + with pytest.raises(_core.BusyResourceError) as excinfo: + async with _core.open_nursery() as nursery: + nursery.start_soon(do_send_all) + nursery.start_soon(do_send_all) + assert "another task" in str(excinfo.value) + + s, _ = ssl_lockstep_stream_pair(client_ctx) + with pytest.raises(_core.BusyResourceError) as excinfo: + async with _core.open_nursery() as nursery: + nursery.start_soon(do_receive_some) + nursery.start_soon(do_receive_some) + assert "another task" in str(excinfo.value) + + s, _ = ssl_lockstep_stream_pair(client_ctx) + with pytest.raises(_core.BusyResourceError) as excinfo: + async with _core.open_nursery() as nursery: + nursery.start_soon(do_send_all) + nursery.start_soon(do_wait_send_all_might_not_block) + assert "another task" in str(excinfo.value) + + s, _ = ssl_lockstep_stream_pair(client_ctx) + with pytest.raises(_core.BusyResourceError) as excinfo: + async with _core.open_nursery() as nursery: + nursery.start_soon(do_wait_send_all_might_not_block) + nursery.start_soon(do_wait_send_all_might_not_block) + assert "another task" in str(excinfo.value) + + +async def test_wait_writable_calls_underlying_wait_writable(): + record = [] + + class NotAStream: + async def wait_send_all_might_not_block(self): + record.append("ok") + + ctx = ssl.create_default_context() + s = SSLStream(NotAStream(), ctx, server_hostname="x") + await s.wait_send_all_might_not_block() + assert record == ["ok"] + + +@pytest.mark.skipif( + os.name == "nt" and sys.version_info >= (3, 10), + reason="frequently fails on Windows + Python 3.10", +) +async def test_checkpoints(client_ctx): + async with ssl_echo_server(client_ctx) as s: + with assert_checkpoints(): + await s.do_handshake() + with assert_checkpoints(): + await s.do_handshake() + with assert_checkpoints(): + await s.wait_send_all_might_not_block() + with assert_checkpoints(): + await s.send_all(b"xxx") + with assert_checkpoints(): + await s.receive_some(1) + # These receive_some's in theory could return immediately, because the + # "xxx" was sent in a single record and after the first + # receive_some(1) the rest are sitting inside the SSLObject's internal + # buffers. + with assert_checkpoints(): + await s.receive_some(1) + with assert_checkpoints(): + await s.receive_some(1) + with assert_checkpoints(): + await s.unwrap() + + async with ssl_echo_server(client_ctx) as s: + await s.do_handshake() + with assert_checkpoints(): + await s.aclose() + + +async def test_send_all_empty_string(client_ctx): + async with ssl_echo_server(client_ctx) as s: + await s.do_handshake() + + # underlying SSLObject interprets writing b"" as indicating an EOF, + # for some reason. Make sure we don't inherit this. + with assert_checkpoints(): + await s.send_all(b"") + with assert_checkpoints(): + await s.send_all(b"") + await s.send_all(b"x") + assert await s.receive_some(1) == b"x" + + await s.aclose() + + +@pytest.mark.parametrize("https_compatible", [False, True]) +async def test_SSLStream_generic(client_ctx, https_compatible): + async def stream_maker(): + return ssl_memory_stream_pair( + client_ctx, + client_kwargs={"https_compatible": https_compatible}, + server_kwargs={"https_compatible": https_compatible}, + ) + + async def clogged_stream_maker(): + client, server = ssl_lockstep_stream_pair(client_ctx) + # If we don't do handshakes up front, then we run into a problem in + # the following situation: + # - server does wait_send_all_might_not_block + # - client does receive_some to unclog it + # Then the client's receive_some will actually send some data to start + # the handshake, and itself get stuck. + async with _core.open_nursery() as nursery: + nursery.start_soon(client.do_handshake) + nursery.start_soon(server.do_handshake) + return client, server + + await check_two_way_stream(stream_maker, clogged_stream_maker) + + +async def test_unwrap(client_ctx): + client_ssl, server_ssl = ssl_memory_stream_pair(client_ctx) + client_transport = client_ssl.transport_stream + server_transport = server_ssl.transport_stream + + seq = Sequencer() + + async def client(): + await client_ssl.do_handshake() + await client_ssl.send_all(b"x") + assert await client_ssl.receive_some(1) == b"y" + await client_ssl.send_all(b"z") + + # After sending that, disable outgoing data from our end, to make + # sure the server doesn't see our EOF until after we've sent some + # trailing data + async with seq(0): + send_all_hook = client_transport.send_stream.send_all_hook + client_transport.send_stream.send_all_hook = None + + assert await client_ssl.receive_some(1) == b"" + assert client_ssl.transport_stream is client_transport + # We just received EOF. Unwrap the connection and send some more. + raw, trailing = await client_ssl.unwrap() + assert raw is client_transport + assert trailing == b"" + assert client_ssl.transport_stream is None + await raw.send_all(b"trailing") + + # Reconnect the streams. Now the server will receive both our shutdown + # acknowledgement + the trailing data in a single lump. + client_transport.send_stream.send_all_hook = send_all_hook + await client_transport.send_stream.send_all_hook() + + async def server(): + await server_ssl.do_handshake() + assert await server_ssl.receive_some(1) == b"x" + await server_ssl.send_all(b"y") + assert await server_ssl.receive_some(1) == b"z" + # Now client is blocked waiting for us to send something, but + # instead we close the TLS connection (with sequencer to make sure + # that the client won't see and automatically respond before we've had + # a chance to disable the client->server transport) + async with seq(1): + raw, trailing = await server_ssl.unwrap() + assert raw is server_transport + assert trailing == b"trailing" + assert server_ssl.transport_stream is None + + async with _core.open_nursery() as nursery: + nursery.start_soon(client) + nursery.start_soon(server) + + +async def test_closing_nice_case(client_ctx): + # the nice case: graceful closes all around + + client_ssl, server_ssl = ssl_memory_stream_pair(client_ctx) + client_transport = client_ssl.transport_stream + + # Both the handshake and the close require back-and-forth discussion, so + # we need to run them concurrently + async def client_closer(): + with assert_checkpoints(): + await client_ssl.aclose() + + async def server_closer(): + assert await server_ssl.receive_some(10) == b"" + assert await server_ssl.receive_some(10) == b"" + with assert_checkpoints(): + await server_ssl.aclose() + + async with _core.open_nursery() as nursery: + nursery.start_soon(client_closer) + nursery.start_soon(server_closer) + + # closing the SSLStream also closes its transport + with pytest.raises(ClosedResourceError): + await client_transport.send_all(b"123") + + # once closed, it's OK to close again + with assert_checkpoints(): + await client_ssl.aclose() + with assert_checkpoints(): + await client_ssl.aclose() + + # Trying to send more data does not work + with pytest.raises(ClosedResourceError): + await server_ssl.send_all(b"123") + + # And once the connection is has been closed *locally*, then instead of + # getting empty bytestrings we get a proper error + with pytest.raises(ClosedResourceError): + await client_ssl.receive_some(10) == b"" + + with pytest.raises(ClosedResourceError): + await client_ssl.unwrap() + + with pytest.raises(ClosedResourceError): + await client_ssl.do_handshake() + + # Check that a graceful close *before* handshaking gives a clean EOF on + # the other side + client_ssl, server_ssl = ssl_memory_stream_pair(client_ctx) + + async def expect_eof_server(): + with assert_checkpoints(): + assert await server_ssl.receive_some(10) == b"" + with assert_checkpoints(): + await server_ssl.aclose() + + async with _core.open_nursery() as nursery: + nursery.start_soon(client_ssl.aclose) + nursery.start_soon(expect_eof_server) + + +async def test_send_all_fails_in_the_middle(client_ctx): + client, server = ssl_memory_stream_pair(client_ctx) + + async with _core.open_nursery() as nursery: + nursery.start_soon(client.do_handshake) + nursery.start_soon(server.do_handshake) + + async def bad_hook(): + raise KeyError + + client.transport_stream.send_stream.send_all_hook = bad_hook + + with pytest.raises(KeyError): + await client.send_all(b"x") + + with pytest.raises(BrokenResourceError): + await client.wait_send_all_might_not_block() + + closed = 0 + + def close_hook(): + nonlocal closed + closed += 1 + + client.transport_stream.send_stream.close_hook = close_hook + client.transport_stream.receive_stream.close_hook = close_hook + await client.aclose() + + assert closed == 2 + + +async def test_ssl_over_ssl(client_ctx): + client_0, server_0 = memory_stream_pair() + + client_1 = SSLStream( + client_0, client_ctx, server_hostname="trio-test-1.example.org" + ) + server_1 = SSLStream(server_0, SERVER_CTX, server_side=True) + + client_2 = SSLStream( + client_1, client_ctx, server_hostname="trio-test-1.example.org" + ) + server_2 = SSLStream(server_1, SERVER_CTX, server_side=True) + + async def client(): + await client_2.send_all(b"hi") + assert await client_2.receive_some(10) == b"bye" + + async def server(): + assert await server_2.receive_some(10) == b"hi" + await server_2.send_all(b"bye") + + async with _core.open_nursery() as nursery: + nursery.start_soon(client) + nursery.start_soon(server) + + +async def test_ssl_bad_shutdown(client_ctx): + client, server = ssl_memory_stream_pair(client_ctx) + + async with _core.open_nursery() as nursery: + nursery.start_soon(client.do_handshake) + nursery.start_soon(server.do_handshake) + + await trio.aclose_forcefully(client) + # now the server sees a broken stream + with pytest.raises(BrokenResourceError): + await server.receive_some(10) + with pytest.raises(BrokenResourceError): + await server.send_all(b"x" * 10) + + await server.aclose() + + +async def test_ssl_bad_shutdown_but_its_ok(client_ctx): + client, server = ssl_memory_stream_pair( + client_ctx, + server_kwargs={"https_compatible": True}, + client_kwargs={"https_compatible": True}, + ) + + async with _core.open_nursery() as nursery: + nursery.start_soon(client.do_handshake) + nursery.start_soon(server.do_handshake) + + await trio.aclose_forcefully(client) + # the server sees that as a clean shutdown + assert await server.receive_some(10) == b"" + with pytest.raises(BrokenResourceError): + await server.send_all(b"x" * 10) + + await server.aclose() + + +async def test_ssl_handshake_failure_during_aclose(): + # Weird scenario: aclose() triggers an automatic handshake, and this + # fails. This also exercises a bit of code in aclose() that was otherwise + # uncovered, for re-raising exceptions after calling aclose_forcefully on + # the underlying transport. + async with ssl_echo_server_raw(expect_fail=True) as sock: + # Don't configure trust correctly + client_ctx = ssl.create_default_context() + s = SSLStream(sock, client_ctx, server_hostname="trio-test-1.example.org") + # It's a little unclear here whether aclose should swallow the error + # or let it escape. We *do* swallow the error if it arrives when we're + # sending close_notify, because both sides closing the connection + # simultaneously is allowed. But I guess when https_compatible=False + # then it's bad if we can get through a whole connection with a peer + # that has no valid certificate, and never raise an error. + with pytest.raises(BrokenResourceError): + await s.aclose() + + +async def test_ssl_only_closes_stream_once(client_ctx): + # We used to have a bug where if transport_stream.aclose() raised an + # error, we would call it again. This checks that that's fixed. + client, server = ssl_memory_stream_pair(client_ctx) + + async with _core.open_nursery() as nursery: + nursery.start_soon(client.do_handshake) + nursery.start_soon(server.do_handshake) + + client_orig_close_hook = client.transport_stream.send_stream.close_hook + transport_close_count = 0 + + def close_hook(): + nonlocal transport_close_count + client_orig_close_hook() + transport_close_count += 1 + raise KeyError + + client.transport_stream.send_stream.close_hook = close_hook + + with pytest.raises(KeyError): + await client.aclose() + assert transport_close_count == 1 + + +async def test_ssl_https_compatibility_disagreement(client_ctx): + client, server = ssl_memory_stream_pair( + client_ctx, + server_kwargs={"https_compatible": False}, + client_kwargs={"https_compatible": True}, + ) + + async with _core.open_nursery() as nursery: + nursery.start_soon(client.do_handshake) + nursery.start_soon(server.do_handshake) + + # client is in HTTPS-mode, server is not + # so client doing graceful_shutdown causes an error on server + async def receive_and_expect_error(): + with pytest.raises(BrokenResourceError) as excinfo: + await server.receive_some(10) + + assert _is_eof(excinfo.value.__cause__) + + async with _core.open_nursery() as nursery: + nursery.start_soon(client.aclose) + nursery.start_soon(receive_and_expect_error) + + +async def test_https_mode_eof_before_handshake(client_ctx): + client, server = ssl_memory_stream_pair( + client_ctx, + server_kwargs={"https_compatible": True}, + client_kwargs={"https_compatible": True}, + ) + + async def server_expect_clean_eof(): + assert await server.receive_some(10) == b"" + + async with _core.open_nursery() as nursery: + nursery.start_soon(client.aclose) + nursery.start_soon(server_expect_clean_eof) + + +async def test_send_error_during_handshake(client_ctx): + client, server = ssl_memory_stream_pair(client_ctx) + + async def bad_hook(): + raise KeyError + + client.transport_stream.send_stream.send_all_hook = bad_hook + + with pytest.raises(KeyError): + with assert_checkpoints(): + await client.do_handshake() + + with pytest.raises(BrokenResourceError): + with assert_checkpoints(): + await client.do_handshake() + + +async def test_receive_error_during_handshake(client_ctx): + client, server = ssl_memory_stream_pair(client_ctx) + + async def bad_hook(): + raise KeyError + + client.transport_stream.receive_stream.receive_some_hook = bad_hook + + async def client_side(cancel_scope): + with pytest.raises(KeyError): + with assert_checkpoints(): + await client.do_handshake() + cancel_scope.cancel() + + async with _core.open_nursery() as nursery: + nursery.start_soon(client_side, nursery.cancel_scope) + nursery.start_soon(server.do_handshake) + + with pytest.raises(BrokenResourceError): + with assert_checkpoints(): + await client.do_handshake() + + +async def test_selected_alpn_protocol_before_handshake(client_ctx): + client, server = ssl_memory_stream_pair(client_ctx) + + with pytest.raises(NeedHandshakeError): + client.selected_alpn_protocol() + + with pytest.raises(NeedHandshakeError): + server.selected_alpn_protocol() + + +async def test_selected_alpn_protocol_when_not_set(client_ctx): + # ALPN protocol still returns None when it's not set, + # instead of raising an exception + client, server = ssl_memory_stream_pair(client_ctx) + + async with _core.open_nursery() as nursery: + nursery.start_soon(client.do_handshake) + nursery.start_soon(server.do_handshake) + + assert client.selected_alpn_protocol() is None + assert server.selected_alpn_protocol() is None + + assert client.selected_alpn_protocol() == server.selected_alpn_protocol() + + +async def test_selected_npn_protocol_before_handshake(client_ctx): + client, server = ssl_memory_stream_pair(client_ctx) + + with pytest.raises(NeedHandshakeError): + client.selected_npn_protocol() + + with pytest.raises(NeedHandshakeError): + server.selected_npn_protocol() + + +@pytest.mark.filterwarnings( + r"ignore: ssl module. NPN is deprecated, use ALPN instead:UserWarning", + r"ignore:ssl NPN is deprecated, use ALPN instead:DeprecationWarning", +) +async def test_selected_npn_protocol_when_not_set(client_ctx): + # NPN protocol still returns None when it's not set, + # instead of raising an exception + client, server = ssl_memory_stream_pair(client_ctx) + + async with _core.open_nursery() as nursery: + nursery.start_soon(client.do_handshake) + nursery.start_soon(server.do_handshake) + + assert client.selected_npn_protocol() is None + assert server.selected_npn_protocol() is None + + assert client.selected_npn_protocol() == server.selected_npn_protocol() + + +async def test_get_channel_binding_before_handshake(client_ctx): + client, server = ssl_memory_stream_pair(client_ctx) + + with pytest.raises(NeedHandshakeError): + client.get_channel_binding() + + with pytest.raises(NeedHandshakeError): + server.get_channel_binding() + + +async def test_get_channel_binding_after_handshake(client_ctx): + client, server = ssl_memory_stream_pair(client_ctx) + + async with _core.open_nursery() as nursery: + nursery.start_soon(client.do_handshake) + nursery.start_soon(server.do_handshake) + + assert client.get_channel_binding() is not None + assert server.get_channel_binding() is not None + + assert client.get_channel_binding() == server.get_channel_binding() + + +async def test_getpeercert(client_ctx): + # Make sure we're not affected by https://bugs.python.org/issue29334 + client, server = ssl_memory_stream_pair(client_ctx) + + async with _core.open_nursery() as nursery: + nursery.start_soon(client.do_handshake) + nursery.start_soon(server.do_handshake) + + assert server.getpeercert() is None + print(client.getpeercert()) + assert ("DNS", "trio-test-1.example.org") in client.getpeercert()["subjectAltName"] + + +async def test_SSLListener(client_ctx): + async def setup(**kwargs): + listen_sock = tsocket.socket() + await listen_sock.bind(("127.0.0.1", 0)) + listen_sock.listen(1) + socket_listener = SocketListener(listen_sock) + ssl_listener = SSLListener(socket_listener, SERVER_CTX, **kwargs) + + transport_client = await open_tcp_stream(*listen_sock.getsockname()) + ssl_client = SSLStream( + transport_client, client_ctx, server_hostname="trio-test-1.example.org" + ) + return listen_sock, ssl_listener, ssl_client + + listen_sock, ssl_listener, ssl_client = await setup() + + async with ssl_client: + ssl_server = await ssl_listener.accept() + + async with ssl_server: + assert not ssl_server._https_compatible + + # Make sure the connection works + async with _core.open_nursery() as nursery: + nursery.start_soon(ssl_client.do_handshake) + nursery.start_soon(ssl_server.do_handshake) + + # Test SSLListener.aclose + await ssl_listener.aclose() + assert listen_sock.fileno() == -1 + + ################ + + # Test https_compatible + _, ssl_listener, ssl_client = await setup(https_compatible=True) + + ssl_server = await ssl_listener.accept() + + assert ssl_server._https_compatible + + await aclose_forcefully(ssl_listener) + await aclose_forcefully(ssl_client) + await aclose_forcefully(ssl_server) diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_subprocess.py b/venv/lib/python3.9/site-packages/trio/tests/test_subprocess.py new file mode 100644 index 00000000..061a7151 --- /dev/null +++ b/venv/lib/python3.9/site-packages/trio/tests/test_subprocess.py @@ -0,0 +1,602 @@ +import os +import random +import signal +import subprocess +import sys +from functools import partial +from pathlib import Path as SyncPath + +import pytest +from async_generator import asynccontextmanager + +from .. import ( + ClosedResourceError, + Event, + Process, + _core, + fail_after, + move_on_after, + run_process, + sleep, + sleep_forever, +) +from .._core.tests.tutil import skip_if_fbsd_pipes_broken, slow +from ..lowlevel import open_process +from ..testing import assert_no_checkpoints, wait_all_tasks_blocked + +posix = os.name == "posix" +if posix: + from signal import SIGKILL, SIGTERM, SIGUSR1 +else: + SIGKILL, SIGTERM, SIGUSR1 = None, None, None + + +# Since Windows has very few command-line utilities generally available, +# all of our subprocesses are Python processes running short bits of +# (mostly) cross-platform code. +def python(code): + return [sys.executable, "-u", "-c", "import sys; " + code] + + +EXIT_TRUE = python("sys.exit(0)") +EXIT_FALSE = python("sys.exit(1)") +CAT = python("sys.stdout.buffer.write(sys.stdin.buffer.read())") + +if posix: + SLEEP = lambda seconds: ["/bin/sleep", str(seconds)] +else: + SLEEP = lambda seconds: python("import time; time.sleep({})".format(seconds)) + + +def got_signal(proc, sig): + if posix: + return proc.returncode == -sig + else: + return proc.returncode != 0 + + +@asynccontextmanager +async def open_process_then_kill(*args, **kwargs): + proc = await open_process(*args, **kwargs) + try: + yield proc + finally: + proc.kill() + await proc.wait() + + +@asynccontextmanager +async def run_process_in_nursery(*args, **kwargs): + async with _core.open_nursery() as nursery: + kwargs.setdefault("check", False) + proc = await nursery.start(partial(run_process, *args, **kwargs)) + yield proc + nursery.cancel_scope.cancel() + + +background_process_param = pytest.mark.parametrize( + "background_process", + [open_process_then_kill, run_process_in_nursery], + ids=["open_process", "run_process in nursery"], +) + + +@background_process_param +async def test_basic(background_process): + async with background_process(EXIT_TRUE) as proc: + await proc.wait() + assert isinstance(proc, Process) + assert proc._pidfd is None + assert proc.returncode == 0 + assert repr(proc) == f"" + + async with background_process(EXIT_FALSE) as proc: + await proc.wait() + assert proc.returncode == 1 + assert repr(proc) == "".format( + EXIT_FALSE, "exited with status 1" + ) + + +@background_process_param +async def test_auto_update_returncode(background_process): + async with background_process(SLEEP(9999)) as p: + assert p.returncode is None + assert "running" in repr(p) + p.kill() + p._proc.wait() + assert p.returncode is not None + assert "exited" in repr(p) + assert p._pidfd is None + assert p.returncode is not None + + +@background_process_param +async def test_multi_wait(background_process): + async with background_process(SLEEP(10)) as proc: + # Check that wait (including multi-wait) tolerates being cancelled + async with _core.open_nursery() as nursery: + nursery.start_soon(proc.wait) + nursery.start_soon(proc.wait) + nursery.start_soon(proc.wait) + await wait_all_tasks_blocked() + nursery.cancel_scope.cancel() + + # Now try waiting for real + async with _core.open_nursery() as nursery: + nursery.start_soon(proc.wait) + nursery.start_soon(proc.wait) + nursery.start_soon(proc.wait) + await wait_all_tasks_blocked() + proc.kill() + + +# Test for deprecated 'async with process:' semantics +async def test_async_with_basics_deprecated(recwarn): + async with await open_process( + CAT, stdin=subprocess.PIPE, stdout=subprocess.PIPE + ) as proc: + pass + assert proc.returncode is not None + with pytest.raises(ClosedResourceError): + await proc.stdin.send_all(b"x") + with pytest.raises(ClosedResourceError): + await proc.stdout.receive_some() + + +# Test for deprecated 'async with process:' semantics +async def test_kill_when_context_cancelled(recwarn): + with move_on_after(100) as scope: + async with await open_process(SLEEP(10)) as proc: + assert proc.poll() is None + scope.cancel() + await sleep_forever() + assert scope.cancelled_caught + assert got_signal(proc, SIGKILL) + assert repr(proc) == "".format( + SLEEP(10), "exited with signal 9" if posix else "exited with status 1" + ) + + +COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR = python( + "data = sys.stdin.buffer.read(); " + "sys.stdout.buffer.write(data); " + "sys.stderr.buffer.write(data[::-1])" +) + + +@background_process_param +async def test_pipes(background_process): + async with background_process( + COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) as proc: + msg = b"the quick brown fox jumps over the lazy dog" + + async def feed_input(): + await proc.stdin.send_all(msg) + await proc.stdin.aclose() + + async def check_output(stream, expected): + seen = bytearray() + async for chunk in stream: + seen += chunk + assert seen == expected + + async with _core.open_nursery() as nursery: + # fail eventually if something is broken + nursery.cancel_scope.deadline = _core.current_time() + 30.0 + nursery.start_soon(feed_input) + nursery.start_soon(check_output, proc.stdout, msg) + nursery.start_soon(check_output, proc.stderr, msg[::-1]) + + assert not nursery.cancel_scope.cancelled_caught + assert 0 == await proc.wait() + + +@background_process_param +async def test_interactive(background_process): + # Test some back-and-forth with a subprocess. This one works like so: + # in: 32\n + # out: 0000...0000\n (32 zeroes) + # err: 1111...1111\n (64 ones) + # in: 10\n + # out: 2222222222\n (10 twos) + # err: 3333....3333\n (20 threes) + # in: EOF + # out: EOF + # err: EOF + + async with background_process( + python( + "idx = 0\n" + "while True:\n" + " line = sys.stdin.readline()\n" + " if line == '': break\n" + " request = int(line.strip())\n" + " print(str(idx * 2) * request)\n" + " print(str(idx * 2 + 1) * request * 2, file=sys.stderr)\n" + " idx += 1\n" + ), + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) as proc: + + newline = b"\n" if posix else b"\r\n" + + async def expect(idx, request): + async with _core.open_nursery() as nursery: + + async def drain_one(stream, count, digit): + while count > 0: + result = await stream.receive_some(count) + assert result == ( + "{}".format(digit).encode("utf-8") * len(result) + ) + count -= len(result) + assert count == 0 + assert await stream.receive_some(len(newline)) == newline + + nursery.start_soon(drain_one, proc.stdout, request, idx * 2) + nursery.start_soon(drain_one, proc.stderr, request * 2, idx * 2 + 1) + + with fail_after(5): + await proc.stdin.send_all(b"12") + await sleep(0.1) + await proc.stdin.send_all(b"345" + newline) + await expect(0, 12345) + await proc.stdin.send_all(b"100" + newline + b"200" + newline) + await expect(1, 100) + await expect(2, 200) + await proc.stdin.send_all(b"0" + newline) + await expect(3, 0) + await proc.stdin.send_all(b"999999") + with move_on_after(0.1) as scope: + await expect(4, 0) + assert scope.cancelled_caught + await proc.stdin.send_all(newline) + await expect(4, 999999) + await proc.stdin.aclose() + assert await proc.stdout.receive_some(1) == b"" + assert await proc.stderr.receive_some(1) == b"" + await proc.wait() + + assert proc.returncode == 0 + + +async def test_run(): + data = bytes(random.randint(0, 255) for _ in range(2**18)) + + result = await run_process( + CAT, stdin=data, capture_stdout=True, capture_stderr=True + ) + assert result.args == CAT + assert result.returncode == 0 + assert result.stdout == data + assert result.stderr == b"" + + result = await run_process(CAT, capture_stdout=True) + assert result.args == CAT + assert result.returncode == 0 + assert result.stdout == b"" + assert result.stderr is None + + result = await run_process( + COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR, + stdin=data, + capture_stdout=True, + capture_stderr=True, + ) + assert result.args == COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR + assert result.returncode == 0 + assert result.stdout == data + assert result.stderr == data[::-1] + + # invalid combinations + with pytest.raises(UnicodeError): + await run_process(CAT, stdin="oh no, it's text") + with pytest.raises(ValueError): + await run_process(CAT, stdin=subprocess.PIPE) + with pytest.raises(ValueError): + await run_process(CAT, stdout=subprocess.PIPE) + with pytest.raises(ValueError): + await run_process(CAT, stderr=subprocess.PIPE) + with pytest.raises(ValueError): + await run_process(CAT, capture_stdout=True, stdout=subprocess.DEVNULL) + with pytest.raises(ValueError): + await run_process(CAT, capture_stderr=True, stderr=None) + + +async def test_run_check(): + cmd = python("sys.stderr.buffer.write(b'test\\n'); sys.exit(1)") + with pytest.raises(subprocess.CalledProcessError) as excinfo: + await run_process(cmd, stdin=subprocess.DEVNULL, capture_stderr=True) + assert excinfo.value.cmd == cmd + assert excinfo.value.returncode == 1 + assert excinfo.value.stderr == b"test\n" + assert excinfo.value.stdout is None + + result = await run_process( + cmd, capture_stdout=True, capture_stderr=True, check=False + ) + assert result.args == cmd + assert result.stdout == b"" + assert result.stderr == b"test\n" + assert result.returncode == 1 + + +@skip_if_fbsd_pipes_broken +async def test_run_with_broken_pipe(): + result = await run_process( + [sys.executable, "-c", "import sys; sys.stdin.close()"], stdin=b"x" * 131072 + ) + assert result.returncode == 0 + assert result.stdout is result.stderr is None + + +@background_process_param +async def test_stderr_stdout(background_process): + async with background_process( + COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) as proc: + assert proc.stdout is not None + assert proc.stderr is None + await proc.stdio.send_all(b"1234") + await proc.stdio.send_eof() + + output = [] + while True: + chunk = await proc.stdio.receive_some(16) + if chunk == b"": + break + output.append(chunk) + assert b"".join(output) == b"12344321" + assert proc.returncode == 0 + + # equivalent test with run_process() + result = await run_process( + COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR, + stdin=b"1234", + capture_stdout=True, + stderr=subprocess.STDOUT, + ) + assert result.returncode == 0 + assert result.stdout == b"12344321" + assert result.stderr is None + + # this one hits the branch where stderr=STDOUT but stdout + # is not redirected + async with background_process( + CAT, stdin=subprocess.PIPE, stderr=subprocess.STDOUT + ) as proc: + assert proc.stdout is None + assert proc.stderr is None + await proc.stdin.aclose() + await proc.wait() + assert proc.returncode == 0 + + if posix: + try: + r, w = os.pipe() + + async with background_process( + COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR, + stdin=subprocess.PIPE, + stdout=w, + stderr=subprocess.STDOUT, + ) as proc: + os.close(w) + assert proc.stdio is None + assert proc.stdout is None + assert proc.stderr is None + await proc.stdin.send_all(b"1234") + await proc.stdin.aclose() + assert await proc.wait() == 0 + assert os.read(r, 4096) == b"12344321" + assert os.read(r, 4096) == b"" + finally: + os.close(r) + + +async def test_errors(): + with pytest.raises(TypeError) as excinfo: + await open_process(["ls"], encoding="utf-8") + assert "unbuffered byte streams" in str(excinfo.value) + assert "the 'encoding' option is not supported" in str(excinfo.value) + + if posix: + with pytest.raises(TypeError) as excinfo: + await open_process(["ls"], shell=True) + with pytest.raises(TypeError) as excinfo: + await open_process("ls", shell=False) + + +@background_process_param +async def test_signals(background_process): + async def test_one_signal(send_it, signum): + with move_on_after(1.0) as scope: + async with background_process(SLEEP(3600)) as proc: + send_it(proc) + await proc.wait() + assert not scope.cancelled_caught + if posix: + assert proc.returncode == -signum + else: + assert proc.returncode != 0 + + await test_one_signal(Process.kill, SIGKILL) + await test_one_signal(Process.terminate, SIGTERM) + # Test that we can send arbitrary signals. + # + # We used to use SIGINT here, but it turns out that the Python interpreter + # has race conditions that can cause it to explode in weird ways if it + # tries to handle SIGINT during startup. SIGUSR1's default disposition is + # to terminate the target process, and Python doesn't try to do anything + # clever to handle it. + if posix: + await test_one_signal(lambda proc: proc.send_signal(SIGUSR1), SIGUSR1) + + +@pytest.mark.skipif(not posix, reason="POSIX specific") +@background_process_param +async def test_wait_reapable_fails(background_process): + old_sigchld = signal.signal(signal.SIGCHLD, signal.SIG_IGN) + try: + # With SIGCHLD disabled, the wait() syscall will wait for the + # process to exit but then fail with ECHILD. Make sure we + # support this case as the stdlib subprocess module does. + async with background_process(SLEEP(3600)) as proc: + async with _core.open_nursery() as nursery: + nursery.start_soon(proc.wait) + await wait_all_tasks_blocked() + proc.kill() + nursery.cancel_scope.deadline = _core.current_time() + 1.0 + assert not nursery.cancel_scope.cancelled_caught + assert proc.returncode == 0 # exit status unknowable, so... + finally: + signal.signal(signal.SIGCHLD, old_sigchld) + + +@slow +def test_waitid_eintr(): + # This only matters on PyPy (where we're coding EINTR handling + # ourselves) but the test works on all waitid platforms. + from .._subprocess_platform import wait_child_exiting + + if not wait_child_exiting.__module__.endswith("waitid"): + pytest.skip("waitid only") + from .._subprocess_platform.waitid import sync_wait_reapable + + got_alarm = False + sleeper = subprocess.Popen(["sleep", "3600"]) + + def on_alarm(sig, frame): + nonlocal got_alarm + got_alarm = True + sleeper.kill() + + old_sigalrm = signal.signal(signal.SIGALRM, on_alarm) + try: + signal.alarm(1) + sync_wait_reapable(sleeper.pid) + assert sleeper.wait(timeout=1) == -9 + finally: + if sleeper.returncode is None: # pragma: no cover + # We only get here if something fails in the above; + # if the test passes, wait() will reap the process + sleeper.kill() + sleeper.wait() + signal.signal(signal.SIGALRM, old_sigalrm) + + +async def test_custom_deliver_cancel(): + custom_deliver_cancel_called = False + + async def custom_deliver_cancel(proc): + nonlocal custom_deliver_cancel_called + custom_deliver_cancel_called = True + proc.terminate() + # Make sure this does get cancelled when the process exits, and that + # the process really exited. + try: + await sleep_forever() + finally: + assert proc.returncode is not None + + async with _core.open_nursery() as nursery: + nursery.start_soon( + partial(run_process, SLEEP(9999), deliver_cancel=custom_deliver_cancel) + ) + await wait_all_tasks_blocked() + nursery.cancel_scope.cancel() + + assert custom_deliver_cancel_called + + +async def test_warn_on_failed_cancel_terminate(monkeypatch): + original_terminate = Process.terminate + + def broken_terminate(self): + original_terminate(self) + raise OSError("whoops") + + monkeypatch.setattr(Process, "terminate", broken_terminate) + + with pytest.warns(RuntimeWarning, match=".*whoops.*"): + async with _core.open_nursery() as nursery: + nursery.start_soon(run_process, SLEEP(9999)) + await wait_all_tasks_blocked() + nursery.cancel_scope.cancel() + + +@pytest.mark.skipif(not posix, reason="posix only") +async def test_warn_on_cancel_SIGKILL_escalation(autojump_clock, monkeypatch): + monkeypatch.setattr(Process, "terminate", lambda *args: None) + + with pytest.warns(RuntimeWarning, match=".*ignored SIGTERM.*"): + async with _core.open_nursery() as nursery: + nursery.start_soon(run_process, SLEEP(9999)) + await wait_all_tasks_blocked() + nursery.cancel_scope.cancel() + + +# the background_process_param exercises a lot of run_process cases, but it uses +# check=False, so lets have a test that uses check=True as well +async def test_run_process_background_fail(): + with pytest.raises(subprocess.CalledProcessError): + async with _core.open_nursery() as nursery: + proc = await nursery.start(run_process, EXIT_FALSE) + assert proc.returncode == 1 + + +@pytest.mark.skipif( + not SyncPath("/dev/fd").exists(), + reason="requires a way to iterate through open files", +) +async def test_for_leaking_fds(): + starting_fds = set(SyncPath("/dev/fd").iterdir()) + await run_process(EXIT_TRUE) + assert set(SyncPath("/dev/fd").iterdir()) == starting_fds + + with pytest.raises(subprocess.CalledProcessError): + await run_process(EXIT_FALSE) + assert set(SyncPath("/dev/fd").iterdir()) == starting_fds + + with pytest.raises(PermissionError): + await run_process(["/dev/fd/0"]) + assert set(SyncPath("/dev/fd").iterdir()) == starting_fds + + +# regression test for #2209 +async def test_subprocess_pidfd_unnotified(): + noticed_exit = None + + async def wait_and_tell(proc) -> None: + nonlocal noticed_exit + noticed_exit = Event() + await proc.wait() + noticed_exit.set() + + proc = await open_process(SLEEP(9999)) + async with _core.open_nursery() as nursery: + nursery.start_soon(wait_and_tell, proc) + await wait_all_tasks_blocked() + assert isinstance(noticed_exit, Event) + proc.terminate() + # without giving trio a chance to do so, + with assert_no_checkpoints(): + # wait until the process has actually exited; + proc._proc.wait() + # force a call to poll (that closes the pidfd on linux) + proc.poll() + with move_on_after(5): + # Some platforms use threads to wait for exit, so it might take a bit + # for everything to notice + await noticed_exit.wait() + assert noticed_exit.is_set(), "child task wasn't woken after poll, DEADLOCK" diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_sync.py b/venv/lib/python3.9/site-packages/trio/tests/test_sync.py new file mode 100644 index 00000000..33f79c4d --- /dev/null +++ b/venv/lib/python3.9/site-packages/trio/tests/test_sync.py @@ -0,0 +1,567 @@ +import pytest + +import weakref + +from ..testing import wait_all_tasks_blocked, assert_checkpoints + +from .. import _core +from .. import _timeouts +from .._timeouts import sleep_forever, move_on_after +from .._sync import * + + +async def test_Event(): + e = Event() + assert not e.is_set() + assert e.statistics().tasks_waiting == 0 + + e.set() + assert e.is_set() + with assert_checkpoints(): + await e.wait() + + e = Event() + + record = [] + + async def child(): + record.append("sleeping") + await e.wait() + record.append("woken") + + async with _core.open_nursery() as nursery: + nursery.start_soon(child) + nursery.start_soon(child) + await wait_all_tasks_blocked() + assert record == ["sleeping", "sleeping"] + assert e.statistics().tasks_waiting == 2 + e.set() + await wait_all_tasks_blocked() + assert record == ["sleeping", "sleeping", "woken", "woken"] + + +async def test_CapacityLimiter(): + with pytest.raises(TypeError): + CapacityLimiter(1.0) + with pytest.raises(ValueError): + CapacityLimiter(-1) + c = CapacityLimiter(2) + repr(c) # smoke test + assert c.total_tokens == 2 + assert c.borrowed_tokens == 0 + assert c.available_tokens == 2 + with pytest.raises(RuntimeError): + c.release() + assert c.borrowed_tokens == 0 + c.acquire_nowait() + assert c.borrowed_tokens == 1 + assert c.available_tokens == 1 + + stats = c.statistics() + assert stats.borrowed_tokens == 1 + assert stats.total_tokens == 2 + assert stats.borrowers == [_core.current_task()] + assert stats.tasks_waiting == 0 + + # Can't re-acquire when we already have it + with pytest.raises(RuntimeError): + c.acquire_nowait() + assert c.borrowed_tokens == 1 + with pytest.raises(RuntimeError): + await c.acquire() + assert c.borrowed_tokens == 1 + + # We can acquire on behalf of someone else though + with assert_checkpoints(): + await c.acquire_on_behalf_of("someone") + + # But then we've run out of capacity + assert c.borrowed_tokens == 2 + with pytest.raises(_core.WouldBlock): + c.acquire_on_behalf_of_nowait("third party") + + assert set(c.statistics().borrowers) == {_core.current_task(), "someone"} + + # Until we release one + c.release_on_behalf_of(_core.current_task()) + assert c.statistics().borrowers == ["someone"] + + c.release_on_behalf_of("someone") + assert c.borrowed_tokens == 0 + with assert_checkpoints(): + async with c: + assert c.borrowed_tokens == 1 + + async with _core.open_nursery() as nursery: + await c.acquire_on_behalf_of("value 1") + await c.acquire_on_behalf_of("value 2") + nursery.start_soon(c.acquire_on_behalf_of, "value 3") + await wait_all_tasks_blocked() + assert c.borrowed_tokens == 2 + assert c.statistics().tasks_waiting == 1 + c.release_on_behalf_of("value 2") + # Fairness: + assert c.borrowed_tokens == 2 + with pytest.raises(_core.WouldBlock): + c.acquire_nowait() + + c.release_on_behalf_of("value 3") + c.release_on_behalf_of("value 1") + + +async def test_CapacityLimiter_inf(): + from math import inf + + c = CapacityLimiter(inf) + repr(c) # smoke test + assert c.total_tokens == inf + assert c.borrowed_tokens == 0 + assert c.available_tokens == inf + with pytest.raises(RuntimeError): + c.release() + assert c.borrowed_tokens == 0 + c.acquire_nowait() + assert c.borrowed_tokens == 1 + assert c.available_tokens == inf + + +async def test_CapacityLimiter_change_total_tokens(): + c = CapacityLimiter(2) + + with pytest.raises(TypeError): + c.total_tokens = 1.0 + + with pytest.raises(ValueError): + c.total_tokens = 0 + + with pytest.raises(ValueError): + c.total_tokens = -10 + + assert c.total_tokens == 2 + + async with _core.open_nursery() as nursery: + for i in range(5): + nursery.start_soon(c.acquire_on_behalf_of, i) + await wait_all_tasks_blocked() + assert set(c.statistics().borrowers) == {0, 1} + assert c.statistics().tasks_waiting == 3 + c.total_tokens += 2 + assert set(c.statistics().borrowers) == {0, 1, 2, 3} + assert c.statistics().tasks_waiting == 1 + c.total_tokens -= 3 + assert c.borrowed_tokens == 4 + assert c.total_tokens == 1 + c.release_on_behalf_of(0) + c.release_on_behalf_of(1) + c.release_on_behalf_of(2) + assert set(c.statistics().borrowers) == {3} + assert c.statistics().tasks_waiting == 1 + c.release_on_behalf_of(3) + assert set(c.statistics().borrowers) == {4} + assert c.statistics().tasks_waiting == 0 + + +# regression test for issue #548 +async def test_CapacityLimiter_memleak_548(): + limiter = CapacityLimiter(total_tokens=1) + await limiter.acquire() + + async with _core.open_nursery() as n: + n.start_soon(limiter.acquire) + await wait_all_tasks_blocked() # give it a chance to run the task + n.cancel_scope.cancel() + + # if this is 1, the acquire call (despite being killed) is still there in the task, and will + # leak memory all the while the limiter is active + assert len(limiter._pending_borrowers) == 0 + + +async def test_Semaphore(): + with pytest.raises(TypeError): + Semaphore(1.0) + with pytest.raises(ValueError): + Semaphore(-1) + s = Semaphore(1) + repr(s) # smoke test + assert s.value == 1 + assert s.max_value is None + s.release() + assert s.value == 2 + assert s.statistics().tasks_waiting == 0 + s.acquire_nowait() + assert s.value == 1 + with assert_checkpoints(): + await s.acquire() + assert s.value == 0 + with pytest.raises(_core.WouldBlock): + s.acquire_nowait() + + s.release() + assert s.value == 1 + with assert_checkpoints(): + async with s: + assert s.value == 0 + assert s.value == 1 + s.acquire_nowait() + + record = [] + + async def do_acquire(s): + record.append("started") + await s.acquire() + record.append("finished") + + async with _core.open_nursery() as nursery: + nursery.start_soon(do_acquire, s) + await wait_all_tasks_blocked() + assert record == ["started"] + assert s.value == 0 + s.release() + # Fairness: + assert s.value == 0 + with pytest.raises(_core.WouldBlock): + s.acquire_nowait() + assert record == ["started", "finished"] + + +async def test_Semaphore_bounded(): + with pytest.raises(TypeError): + Semaphore(1, max_value=1.0) + with pytest.raises(ValueError): + Semaphore(2, max_value=1) + bs = Semaphore(1, max_value=1) + assert bs.max_value == 1 + repr(bs) # smoke test + with pytest.raises(ValueError): + bs.release() + assert bs.value == 1 + bs.acquire_nowait() + assert bs.value == 0 + bs.release() + assert bs.value == 1 + + +@pytest.mark.parametrize("lockcls", [Lock, StrictFIFOLock], ids=lambda fn: fn.__name__) +async def test_Lock_and_StrictFIFOLock(lockcls): + l = lockcls() # noqa + assert not l.locked() + + # make sure locks can be weakref'ed (gh-331) + r = weakref.ref(l) + assert r() is l + + repr(l) # smoke test + # make sure repr uses the right name for subclasses + assert lockcls.__name__ in repr(l) + with assert_checkpoints(): + async with l: + assert l.locked() + repr(l) # smoke test (repr branches on locked/unlocked) + assert not l.locked() + l.acquire_nowait() + assert l.locked() + l.release() + assert not l.locked() + with assert_checkpoints(): + await l.acquire() + assert l.locked() + l.release() + assert not l.locked() + + l.acquire_nowait() + with pytest.raises(RuntimeError): + # Error out if we already own the lock + l.acquire_nowait() + l.release() + with pytest.raises(RuntimeError): + # Error out if we don't own the lock + l.release() + + holder_task = None + + async def holder(): + nonlocal holder_task + holder_task = _core.current_task() + async with l: + await sleep_forever() + + async with _core.open_nursery() as nursery: + assert not l.locked() + nursery.start_soon(holder) + await wait_all_tasks_blocked() + assert l.locked() + # WouldBlock if someone else holds the lock + with pytest.raises(_core.WouldBlock): + l.acquire_nowait() + # Can't release a lock someone else holds + with pytest.raises(RuntimeError): + l.release() + + statistics = l.statistics() + print(statistics) + assert statistics.locked + assert statistics.owner is holder_task + assert statistics.tasks_waiting == 0 + + nursery.start_soon(holder) + await wait_all_tasks_blocked() + statistics = l.statistics() + print(statistics) + assert statistics.tasks_waiting == 1 + + nursery.cancel_scope.cancel() + + statistics = l.statistics() + assert not statistics.locked + assert statistics.owner is None + assert statistics.tasks_waiting == 0 + + +async def test_Condition(): + with pytest.raises(TypeError): + Condition(Semaphore(1)) + with pytest.raises(TypeError): + Condition(StrictFIFOLock) + l = Lock() # noqa + c = Condition(l) + assert not l.locked() + assert not c.locked() + with assert_checkpoints(): + await c.acquire() + assert l.locked() + assert c.locked() + + c = Condition() + assert not c.locked() + c.acquire_nowait() + assert c.locked() + with pytest.raises(RuntimeError): + c.acquire_nowait() + c.release() + + with pytest.raises(RuntimeError): + # Can't wait without holding the lock + await c.wait() + with pytest.raises(RuntimeError): + # Can't notify without holding the lock + c.notify() + with pytest.raises(RuntimeError): + # Can't notify without holding the lock + c.notify_all() + + finished_waiters = set() + + async def waiter(i): + async with c: + await c.wait() + finished_waiters.add(i) + + async with _core.open_nursery() as nursery: + for i in range(3): + nursery.start_soon(waiter, i) + await wait_all_tasks_blocked() + async with c: + c.notify() + assert c.locked() + await wait_all_tasks_blocked() + assert finished_waiters == {0} + async with c: + c.notify_all() + await wait_all_tasks_blocked() + assert finished_waiters == {0, 1, 2} + + finished_waiters = set() + async with _core.open_nursery() as nursery: + for i in range(3): + nursery.start_soon(waiter, i) + await wait_all_tasks_blocked() + async with c: + c.notify(2) + statistics = c.statistics() + print(statistics) + assert statistics.tasks_waiting == 1 + assert statistics.lock_statistics.tasks_waiting == 2 + # exiting the context manager hands off the lock to the first task + assert c.statistics().lock_statistics.tasks_waiting == 1 + + await wait_all_tasks_blocked() + assert finished_waiters == {0, 1} + + async with c: + c.notify_all() + + # After being cancelled still hold the lock (!) + # (Note that c.__aexit__ checks that we hold the lock as well) + with _core.CancelScope() as scope: + async with c: + scope.cancel() + try: + await c.wait() + finally: + assert c.locked() + + +from .._sync import AsyncContextManagerMixin +from .._channel import open_memory_channel + +# Three ways of implementing a Lock in terms of a channel. Used to let us put +# the channel through the generic lock tests. + + +class ChannelLock1(AsyncContextManagerMixin): + def __init__(self, capacity): + self.s, self.r = open_memory_channel(capacity) + for _ in range(capacity - 1): + self.s.send_nowait(None) + + def acquire_nowait(self): + self.s.send_nowait(None) + + async def acquire(self): + await self.s.send(None) + + def release(self): + self.r.receive_nowait() + + +class ChannelLock2(AsyncContextManagerMixin): + def __init__(self): + self.s, self.r = open_memory_channel(10) + self.s.send_nowait(None) + + def acquire_nowait(self): + self.r.receive_nowait() + + async def acquire(self): + await self.r.receive() + + def release(self): + self.s.send_nowait(None) + + +class ChannelLock3(AsyncContextManagerMixin): + def __init__(self): + self.s, self.r = open_memory_channel(0) + # self.acquired is true when one task acquires the lock and + # only becomes false when it's released and no tasks are + # waiting to acquire. + self.acquired = False + + def acquire_nowait(self): + assert not self.acquired + self.acquired = True + + async def acquire(self): + if self.acquired: + await self.s.send(None) + else: + self.acquired = True + await _core.checkpoint() + + def release(self): + try: + self.r.receive_nowait() + except _core.WouldBlock: + assert self.acquired + self.acquired = False + + +lock_factories = [ + lambda: CapacityLimiter(1), + lambda: Semaphore(1), + Lock, + StrictFIFOLock, + lambda: ChannelLock1(10), + lambda: ChannelLock1(1), + ChannelLock2, + ChannelLock3, +] +lock_factory_names = [ + "CapacityLimiter(1)", + "Semaphore(1)", + "Lock", + "StrictFIFOLock", + "ChannelLock1(10)", + "ChannelLock1(1)", + "ChannelLock2", + "ChannelLock3", +] + +generic_lock_test = pytest.mark.parametrize( + "lock_factory", lock_factories, ids=lock_factory_names +) + + +# Spawn a bunch of workers that take a lock and then yield; make sure that +# only one worker is ever in the critical section at a time. +@generic_lock_test +async def test_generic_lock_exclusion(lock_factory): + LOOPS = 10 + WORKERS = 5 + in_critical_section = False + acquires = 0 + + async def worker(lock_like): + nonlocal in_critical_section, acquires + for _ in range(LOOPS): + async with lock_like: + acquires += 1 + assert not in_critical_section + in_critical_section = True + await _core.checkpoint() + await _core.checkpoint() + assert in_critical_section + in_critical_section = False + + async with _core.open_nursery() as nursery: + lock_like = lock_factory() + for _ in range(WORKERS): + nursery.start_soon(worker, lock_like) + assert not in_critical_section + assert acquires == LOOPS * WORKERS + + +# Several workers queue on the same lock; make sure they each get it, in +# order. +@generic_lock_test +async def test_generic_lock_fifo_fairness(lock_factory): + initial_order = [] + record = [] + LOOPS = 5 + + async def loopy(name, lock_like): + # Record the order each task was initially scheduled in + initial_order.append(name) + for _ in range(LOOPS): + async with lock_like: + record.append(name) + + lock_like = lock_factory() + async with _core.open_nursery() as nursery: + nursery.start_soon(loopy, 1, lock_like) + nursery.start_soon(loopy, 2, lock_like) + nursery.start_soon(loopy, 3, lock_like) + # The first three could be in any order due to scheduling randomness, + # but after that they should repeat in the same order + for i in range(LOOPS): + assert record[3 * i : 3 * (i + 1)] == initial_order + + +@generic_lock_test +async def test_generic_lock_acquire_nowait_blocks_acquire(lock_factory): + lock_like = lock_factory() + + record = [] + + async def lock_taker(): + record.append("started") + async with lock_like: + pass + record.append("finished") + + async with _core.open_nursery() as nursery: + lock_like.acquire_nowait() + nursery.start_soon(lock_taker) + await wait_all_tasks_blocked() + assert record == ["started"] + lock_like.release() diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_testing.py b/venv/lib/python3.9/site-packages/trio/tests/test_testing.py new file mode 100644 index 00000000..0b10ae71 --- /dev/null +++ b/venv/lib/python3.9/site-packages/trio/tests/test_testing.py @@ -0,0 +1,657 @@ +# XX this should get broken up, like testing.py did + +import tempfile + +import pytest + +from .._core.tests.tutil import can_bind_ipv6 +from .. import sleep +from .. import _core +from .._highlevel_generic import aclose_forcefully +from ..testing import * +from ..testing._check_streams import _assert_raises +from ..testing._memory_streams import _UnboundedByteQueue +from .. import socket as tsocket +from .._highlevel_socket import SocketListener + + +async def test_wait_all_tasks_blocked(): + record = [] + + async def busy_bee(): + for _ in range(10): + await _core.checkpoint() + record.append("busy bee exhausted") + + async def waiting_for_bee_to_leave(): + await wait_all_tasks_blocked() + record.append("quiet at last!") + + async with _core.open_nursery() as nursery: + nursery.start_soon(busy_bee) + nursery.start_soon(waiting_for_bee_to_leave) + nursery.start_soon(waiting_for_bee_to_leave) + + # check cancellation + record = [] + + async def cancelled_while_waiting(): + try: + await wait_all_tasks_blocked() + except _core.Cancelled: + record.append("ok") + + async with _core.open_nursery() as nursery: + nursery.start_soon(cancelled_while_waiting) + nursery.cancel_scope.cancel() + assert record == ["ok"] + + +async def test_wait_all_tasks_blocked_with_timeouts(mock_clock): + record = [] + + async def timeout_task(): + record.append("tt start") + await sleep(5) + record.append("tt finished") + + async with _core.open_nursery() as nursery: + nursery.start_soon(timeout_task) + await wait_all_tasks_blocked() + assert record == ["tt start"] + mock_clock.jump(10) + await wait_all_tasks_blocked() + assert record == ["tt start", "tt finished"] + + +async def test_wait_all_tasks_blocked_with_cushion(): + record = [] + + async def blink(): + record.append("blink start") + await sleep(0.01) + await sleep(0.01) + await sleep(0.01) + record.append("blink end") + + async def wait_no_cushion(): + await wait_all_tasks_blocked() + record.append("wait_no_cushion end") + + async def wait_small_cushion(): + await wait_all_tasks_blocked(0.02) + record.append("wait_small_cushion end") + + async def wait_big_cushion(): + await wait_all_tasks_blocked(0.03) + record.append("wait_big_cushion end") + + async with _core.open_nursery() as nursery: + nursery.start_soon(blink) + nursery.start_soon(wait_no_cushion) + nursery.start_soon(wait_small_cushion) + nursery.start_soon(wait_small_cushion) + nursery.start_soon(wait_big_cushion) + + assert record == [ + "blink start", + "wait_no_cushion end", + "blink end", + "wait_small_cushion end", + "wait_small_cushion end", + "wait_big_cushion end", + ] + + +################################################################ + + +async def test_assert_checkpoints(recwarn): + with assert_checkpoints(): + await _core.checkpoint() + + with pytest.raises(AssertionError): + with assert_checkpoints(): + 1 + 1 + + # partial yield cases + # if you have a schedule point but not a cancel point, or vice-versa, then + # that's not a checkpoint. + for partial_yield in [ + _core.checkpoint_if_cancelled, + _core.cancel_shielded_checkpoint, + ]: + print(partial_yield) + with pytest.raises(AssertionError): + with assert_checkpoints(): + await partial_yield() + + # But both together count as a checkpoint + with assert_checkpoints(): + await _core.checkpoint_if_cancelled() + await _core.cancel_shielded_checkpoint() + + +async def test_assert_no_checkpoints(recwarn): + with assert_no_checkpoints(): + 1 + 1 + + with pytest.raises(AssertionError): + with assert_no_checkpoints(): + await _core.checkpoint() + + # partial yield cases + # if you have a schedule point but not a cancel point, or vice-versa, then + # that doesn't make *either* version of assert_{no_,}yields happy. + for partial_yield in [ + _core.checkpoint_if_cancelled, + _core.cancel_shielded_checkpoint, + ]: + print(partial_yield) + with pytest.raises(AssertionError): + with assert_no_checkpoints(): + await partial_yield() + + # And both together also count as a checkpoint + with pytest.raises(AssertionError): + with assert_no_checkpoints(): + await _core.checkpoint_if_cancelled() + await _core.cancel_shielded_checkpoint() + + +################################################################ + + +async def test_Sequencer(): + record = [] + + def t(val): + print(val) + record.append(val) + + async def f1(seq): + async with seq(1): + t(("f1", 1)) + async with seq(3): + t(("f1", 3)) + async with seq(4): + t(("f1", 4)) + + async def f2(seq): + async with seq(0): + t(("f2", 0)) + async with seq(2): + t(("f2", 2)) + + seq = Sequencer() + async with _core.open_nursery() as nursery: + nursery.start_soon(f1, seq) + nursery.start_soon(f2, seq) + async with seq(5): + await wait_all_tasks_blocked() + assert record == [("f2", 0), ("f1", 1), ("f2", 2), ("f1", 3), ("f1", 4)] + + seq = Sequencer() + # Catches us if we try to re-use a sequence point: + async with seq(0): + pass + with pytest.raises(RuntimeError): + async with seq(0): + pass # pragma: no cover + + +async def test_Sequencer_cancel(): + # Killing a blocked task makes everything blow up + record = [] + seq = Sequencer() + + async def child(i): + with _core.CancelScope() as scope: + if i == 1: + scope.cancel() + try: + async with seq(i): + pass # pragma: no cover + except RuntimeError: + record.append("seq({}) RuntimeError".format(i)) + + async with _core.open_nursery() as nursery: + nursery.start_soon(child, 1) + nursery.start_soon(child, 2) + async with seq(0): + pass # pragma: no cover + + assert record == ["seq(1) RuntimeError", "seq(2) RuntimeError"] + + # Late arrivals also get errors + with pytest.raises(RuntimeError): + async with seq(3): + pass # pragma: no cover + + +################################################################ + + +async def test__assert_raises(): + with pytest.raises(AssertionError): + with _assert_raises(RuntimeError): + 1 + 1 + + with pytest.raises(TypeError): + with _assert_raises(RuntimeError): + "foo" + 1 + + with _assert_raises(RuntimeError): + raise RuntimeError + + +# This is a private implementation detail, but it's complex enough to be worth +# testing directly +async def test__UnboundeByteQueue(): + ubq = _UnboundedByteQueue() + + ubq.put(b"123") + ubq.put(b"456") + assert ubq.get_nowait(1) == b"1" + assert ubq.get_nowait(10) == b"23456" + ubq.put(b"789") + assert ubq.get_nowait() == b"789" + + with pytest.raises(_core.WouldBlock): + ubq.get_nowait(10) + with pytest.raises(_core.WouldBlock): + ubq.get_nowait() + + with pytest.raises(TypeError): + ubq.put("string") + + ubq.put(b"abc") + with assert_checkpoints(): + assert await ubq.get(10) == b"abc" + ubq.put(b"def") + ubq.put(b"ghi") + with assert_checkpoints(): + assert await ubq.get(1) == b"d" + with assert_checkpoints(): + assert await ubq.get() == b"efghi" + + async def putter(data): + await wait_all_tasks_blocked() + ubq.put(data) + + async def getter(expect): + with assert_checkpoints(): + assert await ubq.get() == expect + + async with _core.open_nursery() as nursery: + nursery.start_soon(getter, b"xyz") + nursery.start_soon(putter, b"xyz") + + # Two gets at the same time -> BusyResourceError + with pytest.raises(_core.BusyResourceError): + async with _core.open_nursery() as nursery: + nursery.start_soon(getter, b"asdf") + nursery.start_soon(getter, b"asdf") + + # Closing + + ubq.close() + with pytest.raises(_core.ClosedResourceError): + ubq.put(b"---") + + assert ubq.get_nowait(10) == b"" + assert ubq.get_nowait() == b"" + assert await ubq.get(10) == b"" + assert await ubq.get() == b"" + + # close is idempotent + ubq.close() + + # close wakes up blocked getters + ubq2 = _UnboundedByteQueue() + + async def closer(): + await wait_all_tasks_blocked() + ubq2.close() + + async with _core.open_nursery() as nursery: + nursery.start_soon(getter, b"") + nursery.start_soon(closer) + + +async def test_MemorySendStream(): + mss = MemorySendStream() + + async def do_send_all(data): + with assert_checkpoints(): + await mss.send_all(data) + + await do_send_all(b"123") + assert mss.get_data_nowait(1) == b"1" + assert mss.get_data_nowait() == b"23" + + with assert_checkpoints(): + await mss.wait_send_all_might_not_block() + + with pytest.raises(_core.WouldBlock): + mss.get_data_nowait() + with pytest.raises(_core.WouldBlock): + mss.get_data_nowait(10) + + await do_send_all(b"456") + with assert_checkpoints(): + assert await mss.get_data() == b"456" + + # Call send_all twice at once; one should get BusyResourceError and one + # should succeed. But we can't let the error propagate, because it might + # cause the other to be cancelled before it can finish doing its thing, + # and we don't know which one will get the error. + resource_busy_count = 0 + + async def do_send_all_count_resourcebusy(): + nonlocal resource_busy_count + try: + await do_send_all(b"xxx") + except _core.BusyResourceError: + resource_busy_count += 1 + + async with _core.open_nursery() as nursery: + nursery.start_soon(do_send_all_count_resourcebusy) + nursery.start_soon(do_send_all_count_resourcebusy) + + assert resource_busy_count == 1 + + with assert_checkpoints(): + await mss.aclose() + + assert await mss.get_data() == b"xxx" + assert await mss.get_data() == b"" + with pytest.raises(_core.ClosedResourceError): + await do_send_all(b"---") + + # hooks + + assert mss.send_all_hook is None + assert mss.wait_send_all_might_not_block_hook is None + assert mss.close_hook is None + + record = [] + + async def send_all_hook(): + # hook runs after send_all does its work (can pull data out) + assert mss2.get_data_nowait() == b"abc" + record.append("send_all_hook") + + async def wait_send_all_might_not_block_hook(): + record.append("wait_send_all_might_not_block_hook") + + def close_hook(): + record.append("close_hook") + + mss2 = MemorySendStream( + send_all_hook, wait_send_all_might_not_block_hook, close_hook + ) + + assert mss2.send_all_hook is send_all_hook + assert mss2.wait_send_all_might_not_block_hook is wait_send_all_might_not_block_hook + assert mss2.close_hook is close_hook + + await mss2.send_all(b"abc") + await mss2.wait_send_all_might_not_block() + await aclose_forcefully(mss2) + mss2.close() + + assert record == [ + "send_all_hook", + "wait_send_all_might_not_block_hook", + "close_hook", + "close_hook", + ] + + +async def test_MemoryReceiveStream(): + mrs = MemoryReceiveStream() + + async def do_receive_some(max_bytes): + with assert_checkpoints(): + return await mrs.receive_some(max_bytes) + + mrs.put_data(b"abc") + assert await do_receive_some(1) == b"a" + assert await do_receive_some(10) == b"bc" + mrs.put_data(b"abc") + assert await do_receive_some(None) == b"abc" + + with pytest.raises(_core.BusyResourceError): + async with _core.open_nursery() as nursery: + nursery.start_soon(do_receive_some, 10) + nursery.start_soon(do_receive_some, 10) + + assert mrs.receive_some_hook is None + + mrs.put_data(b"def") + mrs.put_eof() + mrs.put_eof() + + assert await do_receive_some(10) == b"def" + assert await do_receive_some(10) == b"" + assert await do_receive_some(10) == b"" + + with pytest.raises(_core.ClosedResourceError): + mrs.put_data(b"---") + + async def receive_some_hook(): + mrs2.put_data(b"xxx") + + record = [] + + def close_hook(): + record.append("closed") + + mrs2 = MemoryReceiveStream(receive_some_hook, close_hook) + assert mrs2.receive_some_hook is receive_some_hook + assert mrs2.close_hook is close_hook + + mrs2.put_data(b"yyy") + assert await mrs2.receive_some(10) == b"yyyxxx" + assert await mrs2.receive_some(10) == b"xxx" + assert await mrs2.receive_some(10) == b"xxx" + + mrs2.put_data(b"zzz") + mrs2.receive_some_hook = None + assert await mrs2.receive_some(10) == b"zzz" + + mrs2.put_data(b"lost on close") + with assert_checkpoints(): + await mrs2.aclose() + assert record == ["closed"] + + with pytest.raises(_core.ClosedResourceError): + await mrs2.receive_some(10) + + +async def test_MemoryRecvStream_closing(): + mrs = MemoryReceiveStream() + # close with no pending data + mrs.close() + with pytest.raises(_core.ClosedResourceError): + assert await mrs.receive_some(10) == b"" + # repeated closes ok + mrs.close() + # put_data now fails + with pytest.raises(_core.ClosedResourceError): + mrs.put_data(b"123") + + mrs2 = MemoryReceiveStream() + # close with pending data + mrs2.put_data(b"xyz") + mrs2.close() + with pytest.raises(_core.ClosedResourceError): + await mrs2.receive_some(10) + + +async def test_memory_stream_pump(): + mss = MemorySendStream() + mrs = MemoryReceiveStream() + + # no-op if no data present + memory_stream_pump(mss, mrs) + + await mss.send_all(b"123") + memory_stream_pump(mss, mrs) + assert await mrs.receive_some(10) == b"123" + + await mss.send_all(b"456") + assert memory_stream_pump(mss, mrs, max_bytes=1) + assert await mrs.receive_some(10) == b"4" + assert memory_stream_pump(mss, mrs, max_bytes=1) + assert memory_stream_pump(mss, mrs, max_bytes=1) + assert not memory_stream_pump(mss, mrs, max_bytes=1) + assert await mrs.receive_some(10) == b"56" + + mss.close() + memory_stream_pump(mss, mrs) + assert await mrs.receive_some(10) == b"" + + +async def test_memory_stream_one_way_pair(): + s, r = memory_stream_one_way_pair() + assert s.send_all_hook is not None + assert s.wait_send_all_might_not_block_hook is None + assert s.close_hook is not None + assert r.receive_some_hook is None + await s.send_all(b"123") + assert await r.receive_some(10) == b"123" + + async def receiver(expected): + assert await r.receive_some(10) == expected + + # This fails if we pump on r.receive_some_hook; we need to pump on s.send_all_hook + async with _core.open_nursery() as nursery: + nursery.start_soon(receiver, b"abc") + await wait_all_tasks_blocked() + await s.send_all(b"abc") + + # And this fails if we don't pump from close_hook + async with _core.open_nursery() as nursery: + nursery.start_soon(receiver, b"") + await wait_all_tasks_blocked() + await s.aclose() + + s, r = memory_stream_one_way_pair() + + async with _core.open_nursery() as nursery: + nursery.start_soon(receiver, b"") + await wait_all_tasks_blocked() + s.close() + + s, r = memory_stream_one_way_pair() + + old = s.send_all_hook + s.send_all_hook = None + await s.send_all(b"456") + + async def cancel_after_idle(nursery): + await wait_all_tasks_blocked() + nursery.cancel_scope.cancel() + + async def check_for_cancel(): + with pytest.raises(_core.Cancelled): + # This should block forever... or until cancelled. Even though we + # sent some data on the send stream. + await r.receive_some(10) + + async with _core.open_nursery() as nursery: + nursery.start_soon(cancel_after_idle, nursery) + nursery.start_soon(check_for_cancel) + + s.send_all_hook = old + await s.send_all(b"789") + assert await r.receive_some(10) == b"456789" + + +async def test_memory_stream_pair(): + a, b = memory_stream_pair() + await a.send_all(b"123") + await b.send_all(b"abc") + assert await b.receive_some(10) == b"123" + assert await a.receive_some(10) == b"abc" + + await a.send_eof() + assert await b.receive_some(10) == b"" + + async def sender(): + await wait_all_tasks_blocked() + await b.send_all(b"xyz") + + async def receiver(): + assert await a.receive_some(10) == b"xyz" + + async with _core.open_nursery() as nursery: + nursery.start_soon(receiver) + nursery.start_soon(sender) + + +async def test_memory_streams_with_generic_tests(): + async def one_way_stream_maker(): + return memory_stream_one_way_pair() + + await check_one_way_stream(one_way_stream_maker, None) + + async def half_closeable_stream_maker(): + return memory_stream_pair() + + await check_half_closeable_stream(half_closeable_stream_maker, None) + + +async def test_lockstep_streams_with_generic_tests(): + async def one_way_stream_maker(): + return lockstep_stream_one_way_pair() + + await check_one_way_stream(one_way_stream_maker, one_way_stream_maker) + + async def two_way_stream_maker(): + return lockstep_stream_pair() + + await check_two_way_stream(two_way_stream_maker, two_way_stream_maker) + + +async def test_open_stream_to_socket_listener(): + async def check(listener): + async with listener: + client_stream = await open_stream_to_socket_listener(listener) + async with client_stream: + server_stream = await listener.accept() + async with server_stream: + await client_stream.send_all(b"x") + await server_stream.receive_some(1) == b"x" + + # Listener bound to localhost + sock = tsocket.socket() + await sock.bind(("127.0.0.1", 0)) + sock.listen(10) + await check(SocketListener(sock)) + + # Listener bound to IPv4 wildcard (needs special handling) + sock = tsocket.socket() + await sock.bind(("0.0.0.0", 0)) + sock.listen(10) + await check(SocketListener(sock)) + + if can_bind_ipv6: + # Listener bound to IPv6 wildcard (needs special handling) + sock = tsocket.socket(family=tsocket.AF_INET6) + await sock.bind(("::", 0)) + sock.listen(10) + await check(SocketListener(sock)) + + if hasattr(tsocket, "AF_UNIX"): + # Listener bound to Unix-domain socket + sock = tsocket.socket(family=tsocket.AF_UNIX) + # can't use pytest's tmpdir; if we try then macOS says "OSError: + # AF_UNIX path too long" + with tempfile.TemporaryDirectory() as tmpdir: + path = "{}/sock".format(tmpdir) + await sock.bind(path) + sock.listen(10) + await check(SocketListener(sock)) diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_threads.py b/venv/lib/python3.9/site-packages/trio/tests/test_threads.py new file mode 100644 index 00000000..baff1827 --- /dev/null +++ b/venv/lib/python3.9/site-packages/trio/tests/test_threads.py @@ -0,0 +1,752 @@ +import contextvars +import threading +import queue as stdlib_queue +import time +import weakref + +import pytest +from sniffio import current_async_library_cvar +from trio._core import TrioToken, current_trio_token + +from .. import _core +from .. import Event, CapacityLimiter, sleep +from ..testing import wait_all_tasks_blocked +from .._core.tests.tutil import buggy_pypy_asyncgens +from .._threads import ( + to_thread_run_sync, + current_default_thread_limiter, + from_thread_run, + from_thread_run_sync, +) + +from .._core.tests.test_ki import ki_self + + +async def test_do_in_trio_thread(): + trio_thread = threading.current_thread() + + async def check_case(do_in_trio_thread, fn, expected, trio_token=None): + record = [] + + def threadfn(): + try: + record.append(("start", threading.current_thread())) + x = do_in_trio_thread(fn, record, trio_token=trio_token) + record.append(("got", x)) + except BaseException as exc: + print(exc) + record.append(("error", type(exc))) + + child_thread = threading.Thread(target=threadfn, daemon=True) + child_thread.start() + while child_thread.is_alive(): + print("yawn") + await sleep(0.01) + assert record == [("start", child_thread), ("f", trio_thread), expected] + + token = _core.current_trio_token() + + def f(record): + assert not _core.currently_ki_protected() + record.append(("f", threading.current_thread())) + return 2 + + await check_case(from_thread_run_sync, f, ("got", 2), trio_token=token) + + def f(record): + assert not _core.currently_ki_protected() + record.append(("f", threading.current_thread())) + raise ValueError + + await check_case(from_thread_run_sync, f, ("error", ValueError), trio_token=token) + + async def f(record): + assert not _core.currently_ki_protected() + await _core.checkpoint() + record.append(("f", threading.current_thread())) + return 3 + + await check_case(from_thread_run, f, ("got", 3), trio_token=token) + + async def f(record): + assert not _core.currently_ki_protected() + await _core.checkpoint() + record.append(("f", threading.current_thread())) + raise KeyError + + await check_case(from_thread_run, f, ("error", KeyError), trio_token=token) + + +async def test_do_in_trio_thread_from_trio_thread(): + with pytest.raises(RuntimeError): + from_thread_run_sync(lambda: None) # pragma: no branch + + async def foo(): # pragma: no cover + pass + + with pytest.raises(RuntimeError): + from_thread_run(foo) + + +def test_run_in_trio_thread_ki(): + # if we get a control-C during a run_in_trio_thread, then it propagates + # back to the caller (slick!) + record = set() + + async def check_run_in_trio_thread(): + token = _core.current_trio_token() + + def trio_thread_fn(): + print("in Trio thread") + assert not _core.currently_ki_protected() + print("ki_self") + try: + ki_self() + finally: + import sys + + print("finally", sys.exc_info()) + + async def trio_thread_afn(): + trio_thread_fn() + + def external_thread_fn(): + try: + print("running") + from_thread_run_sync(trio_thread_fn, trio_token=token) + except KeyboardInterrupt: + print("ok1") + record.add("ok1") + try: + from_thread_run(trio_thread_afn, trio_token=token) + except KeyboardInterrupt: + print("ok2") + record.add("ok2") + + thread = threading.Thread(target=external_thread_fn) + thread.start() + print("waiting") + while thread.is_alive(): + await sleep(0.01) + print("waited, joining") + thread.join() + print("done") + + _core.run(check_run_in_trio_thread) + assert record == {"ok1", "ok2"} + + +def test_await_in_trio_thread_while_main_exits(): + record = [] + ev = Event() + + async def trio_fn(): + record.append("sleeping") + ev.set() + await _core.wait_task_rescheduled(lambda _: _core.Abort.SUCCEEDED) + + def thread_fn(token): + try: + from_thread_run(trio_fn, trio_token=token) + except _core.Cancelled: + record.append("cancelled") + + async def main(): + token = _core.current_trio_token() + thread = threading.Thread(target=thread_fn, args=(token,)) + thread.start() + await ev.wait() + assert record == ["sleeping"] + return thread + + thread = _core.run(main) + thread.join() + assert record == ["sleeping", "cancelled"] + + +async def test_run_in_worker_thread(): + trio_thread = threading.current_thread() + + def f(x): + return (x, threading.current_thread()) + + x, child_thread = await to_thread_run_sync(f, 1) + assert x == 1 + assert child_thread != trio_thread + + def g(): + raise ValueError(threading.current_thread()) + + with pytest.raises(ValueError) as excinfo: + await to_thread_run_sync(g) + print(excinfo.value.args) + assert excinfo.value.args[0] != trio_thread + + +async def test_run_in_worker_thread_cancellation(): + register = [None] + + def f(q): + # Make the thread block for a controlled amount of time + register[0] = "blocking" + q.get() + register[0] = "finished" + + async def child(q, cancellable): + record.append("start") + try: + return await to_thread_run_sync(f, q, cancellable=cancellable) + finally: + record.append("exit") + + record = [] + q = stdlib_queue.Queue() + async with _core.open_nursery() as nursery: + nursery.start_soon(child, q, True) + # Give it a chance to get started. (This is important because + # to_thread_run_sync does a checkpoint_if_cancelled before + # blocking on the thread, and we don't want to trigger this.) + await wait_all_tasks_blocked() + assert record == ["start"] + # Then cancel it. + nursery.cancel_scope.cancel() + # The task exited, but the thread didn't: + assert register[0] != "finished" + # Put the thread out of its misery: + q.put(None) + while register[0] != "finished": + time.sleep(0.01) + + # This one can't be cancelled + record = [] + register[0] = None + async with _core.open_nursery() as nursery: + nursery.start_soon(child, q, False) + await wait_all_tasks_blocked() + nursery.cancel_scope.cancel() + with _core.CancelScope(shield=True): + for _ in range(10): + await _core.checkpoint() + # It's still running + assert record == ["start"] + q.put(None) + # Now it exits + + # But if we cancel *before* it enters, the entry is itself a cancellation + # point + with _core.CancelScope() as scope: + scope.cancel() + await child(q, False) + assert scope.cancelled_caught + + +# Make sure that if trio.run exits, and then the thread finishes, then that's +# handled gracefully. (Requires that the thread result machinery be prepared +# for call_soon to raise RunFinishedError.) +def test_run_in_worker_thread_abandoned(capfd, monkeypatch): + monkeypatch.setattr(_core._thread_cache, "IDLE_TIMEOUT", 0.01) + + q1 = stdlib_queue.Queue() + q2 = stdlib_queue.Queue() + + def thread_fn(): + q1.get() + q2.put(threading.current_thread()) + + async def main(): + async def child(): + await to_thread_run_sync(thread_fn, cancellable=True) + + async with _core.open_nursery() as nursery: + nursery.start_soon(child) + await wait_all_tasks_blocked() + nursery.cancel_scope.cancel() + + _core.run(main) + + q1.put(None) + # This makes sure: + # - the thread actually ran + # - that thread has finished before we check for its output + thread = q2.get() + while thread.is_alive(): + time.sleep(0.01) # pragma: no cover + + # Make sure we don't have a "Exception in thread ..." dump to the console: + out, err = capfd.readouterr() + assert "Exception in thread" not in out + assert "Exception in thread" not in err + + +@pytest.mark.parametrize("MAX", [3, 5, 10]) +@pytest.mark.parametrize("cancel", [False, True]) +@pytest.mark.parametrize("use_default_limiter", [False, True]) +async def test_run_in_worker_thread_limiter(MAX, cancel, use_default_limiter): + # This test is a bit tricky. The goal is to make sure that if we set + # limiter=CapacityLimiter(MAX), then in fact only MAX threads are ever + # running at a time, even if there are more concurrent calls to + # to_thread_run_sync, and even if some of those are cancelled. And + # also to make sure that the default limiter actually limits. + COUNT = 2 * MAX + gate = threading.Event() + lock = threading.Lock() + if use_default_limiter: + c = current_default_thread_limiter() + orig_total_tokens = c.total_tokens + c.total_tokens = MAX + limiter_arg = None + else: + c = CapacityLimiter(MAX) + orig_total_tokens = MAX + limiter_arg = c + try: + # We used to use regular variables and 'nonlocal' here, but it turns + # out that it's not safe to assign to closed-over variables that are + # visible in multiple threads, at least as of CPython 3.10 and PyPy + # 7.3: + # + # https://bugs.python.org/issue30744 + # https://bitbucket.org/pypy/pypy/issues/2591/ + # + # Mutating them in-place is OK though (as long as you use proper + # locking etc.). + class state: + pass + + state.ran = 0 + state.high_water = 0 + state.running = 0 + state.parked = 0 + + token = _core.current_trio_token() + + def thread_fn(cancel_scope): + print("thread_fn start") + from_thread_run_sync(cancel_scope.cancel, trio_token=token) + with lock: + state.ran += 1 + state.running += 1 + state.high_water = max(state.high_water, state.running) + # The Trio thread below watches this value and uses it as a + # signal that all the stats calculations have finished. + state.parked += 1 + gate.wait() + with lock: + state.parked -= 1 + state.running -= 1 + print("thread_fn exiting") + + async def run_thread(event): + with _core.CancelScope() as cancel_scope: + await to_thread_run_sync( + thread_fn, cancel_scope, limiter=limiter_arg, cancellable=cancel + ) + print("run_thread finished, cancelled:", cancel_scope.cancelled_caught) + event.set() + + async with _core.open_nursery() as nursery: + print("spawning") + events = [] + for i in range(COUNT): + events.append(Event()) + nursery.start_soon(run_thread, events[-1]) + await wait_all_tasks_blocked() + # In the cancel case, we in particular want to make sure that the + # cancelled tasks don't release the semaphore. So let's wait until + # at least one of them has exited, and that everything has had a + # chance to settle down from this, before we check that everyone + # who's supposed to be waiting is waiting: + if cancel: + print("waiting for first cancellation to clear") + await events[0].wait() + await wait_all_tasks_blocked() + # Then wait until the first MAX threads are parked in gate.wait(), + # and the next MAX threads are parked on the semaphore, to make + # sure no-one is sneaking past, and to make sure the high_water + # check below won't fail due to scheduling issues. (It could still + # fail if too many threads are let through here.) + while state.parked != MAX or c.statistics().tasks_waiting != MAX: + await sleep(0.01) # pragma: no cover + # Then release the threads + gate.set() + + assert state.high_water == MAX + + if cancel: + # Some threads might still be running; need to wait to them to + # finish before checking that all threads ran. We can do this + # using the CapacityLimiter. + while c.borrowed_tokens > 0: + await sleep(0.01) # pragma: no cover + + assert state.ran == COUNT + assert state.running == 0 + finally: + c.total_tokens = orig_total_tokens + + +async def test_run_in_worker_thread_custom_limiter(): + # Basically just checking that we only call acquire_on_behalf_of and + # release_on_behalf_of, since that's part of our documented API. + record = [] + + class CustomLimiter: + async def acquire_on_behalf_of(self, borrower): + record.append("acquire") + self._borrower = borrower + + def release_on_behalf_of(self, borrower): + record.append("release") + assert borrower == self._borrower + + await to_thread_run_sync(lambda: None, limiter=CustomLimiter()) + assert record == ["acquire", "release"] + + +async def test_run_in_worker_thread_limiter_error(): + record = [] + + class BadCapacityLimiter: + async def acquire_on_behalf_of(self, borrower): + record.append("acquire") + + def release_on_behalf_of(self, borrower): + record.append("release") + raise ValueError + + bs = BadCapacityLimiter() + + with pytest.raises(ValueError) as excinfo: + await to_thread_run_sync(lambda: None, limiter=bs) + assert excinfo.value.__context__ is None + assert record == ["acquire", "release"] + record = [] + + # If the original function raised an error, then the semaphore error + # chains with it + d = {} + with pytest.raises(ValueError) as excinfo: + await to_thread_run_sync(lambda: d["x"], limiter=bs) + assert isinstance(excinfo.value.__context__, KeyError) + assert record == ["acquire", "release"] + + +async def test_run_in_worker_thread_fail_to_spawn(monkeypatch): + # Test the unlikely but possible case where trying to spawn a thread fails + def bad_start(self, *args): + raise RuntimeError("the engines canna take it captain") + + monkeypatch.setattr(_core._thread_cache.ThreadCache, "start_thread_soon", bad_start) + + limiter = current_default_thread_limiter() + assert limiter.borrowed_tokens == 0 + + # We get an appropriate error, and the limiter is cleanly released + with pytest.raises(RuntimeError) as excinfo: + await to_thread_run_sync(lambda: None) # pragma: no cover + assert "engines" in str(excinfo.value) + + assert limiter.borrowed_tokens == 0 + + +async def test_trio_to_thread_run_sync_token(): + # Test that to_thread_run_sync automatically injects the current trio token + # into a spawned thread + def thread_fn(): + callee_token = from_thread_run_sync(_core.current_trio_token) + return callee_token + + caller_token = _core.current_trio_token() + callee_token = await to_thread_run_sync(thread_fn) + assert callee_token == caller_token + + +async def test_trio_to_thread_run_sync_expected_error(): + # Test correct error when passed async function + async def async_fn(): # pragma: no cover + pass + + with pytest.raises(TypeError, match="expected a sync function"): + await to_thread_run_sync(async_fn) + + +trio_test_contextvar = contextvars.ContextVar("trio_test_contextvar") + + +async def test_trio_to_thread_run_sync_contextvars(): + trio_thread = threading.current_thread() + trio_test_contextvar.set("main") + + def f(): + value = trio_test_contextvar.get() + sniffio_cvar_value = current_async_library_cvar.get() + return (value, sniffio_cvar_value, threading.current_thread()) + + value, sniffio_cvar_value, child_thread = await to_thread_run_sync(f) + assert value == "main" + assert sniffio_cvar_value == None + assert child_thread != trio_thread + + def g(): + parent_value = trio_test_contextvar.get() + trio_test_contextvar.set("worker") + inner_value = trio_test_contextvar.get() + sniffio_cvar_value = current_async_library_cvar.get() + return ( + parent_value, + inner_value, + sniffio_cvar_value, + threading.current_thread(), + ) + + ( + parent_value, + inner_value, + sniffio_cvar_value, + child_thread, + ) = await to_thread_run_sync(g) + current_value = trio_test_contextvar.get() + sniffio_outer_value = current_async_library_cvar.get() + assert parent_value == "main" + assert inner_value == "worker" + assert ( + current_value == "main" + ), "The contextvar value set on the worker would not propagate back to the main thread" + assert sniffio_cvar_value is None + assert sniffio_outer_value == "trio" + + +async def test_trio_from_thread_run_sync(): + # Test that to_thread_run_sync correctly "hands off" the trio token to + # trio.from_thread.run_sync() + def thread_fn(): + trio_time = from_thread_run_sync(_core.current_time) + return trio_time + + trio_time = await to_thread_run_sync(thread_fn) + assert isinstance(trio_time, float) + + # Test correct error when passed async function + async def async_fn(): # pragma: no cover + pass + + def thread_fn(): + from_thread_run_sync(async_fn) + + with pytest.raises(TypeError, match="expected a sync function"): + await to_thread_run_sync(thread_fn) + + +async def test_trio_from_thread_run(): + # Test that to_thread_run_sync correctly "hands off" the trio token to + # trio.from_thread.run() + record = [] + + async def back_in_trio_fn(): + _core.current_time() # implicitly checks that we're in trio + record.append("back in trio") + + def thread_fn(): + record.append("in thread") + from_thread_run(back_in_trio_fn) + + await to_thread_run_sync(thread_fn) + assert record == ["in thread", "back in trio"] + + # Test correct error when passed sync function + def sync_fn(): # pragma: no cover + pass + + with pytest.raises(TypeError, match="appears to be synchronous"): + await to_thread_run_sync(from_thread_run, sync_fn) + + +async def test_trio_from_thread_token(): + # Test that to_thread_run_sync and spawned trio.from_thread.run_sync() + # share the same Trio token + def thread_fn(): + callee_token = from_thread_run_sync(_core.current_trio_token) + return callee_token + + caller_token = _core.current_trio_token() + callee_token = await to_thread_run_sync(thread_fn) + assert callee_token == caller_token + + +async def test_trio_from_thread_token_kwarg(): + # Test that to_thread_run_sync and spawned trio.from_thread.run_sync() can + # use an explicitly defined token + def thread_fn(token): + callee_token = from_thread_run_sync(_core.current_trio_token, trio_token=token) + return callee_token + + caller_token = _core.current_trio_token() + callee_token = await to_thread_run_sync(thread_fn, caller_token) + assert callee_token == caller_token + + +async def test_from_thread_no_token(): + # Test that a "raw call" to trio.from_thread.run() fails because no token + # has been provided + + with pytest.raises(RuntimeError): + from_thread_run_sync(_core.current_time) + + +async def test_trio_from_thread_run_sync_contextvars(): + trio_test_contextvar.set("main") + + def thread_fn(): + thread_parent_value = trio_test_contextvar.get() + trio_test_contextvar.set("worker") + thread_current_value = trio_test_contextvar.get() + sniffio_cvar_thread_pre_value = current_async_library_cvar.get() + + def back_in_main(): + back_parent_value = trio_test_contextvar.get() + trio_test_contextvar.set("back_in_main") + back_current_value = trio_test_contextvar.get() + sniffio_cvar_back_value = current_async_library_cvar.get() + return back_parent_value, back_current_value, sniffio_cvar_back_value + + ( + back_parent_value, + back_current_value, + sniffio_cvar_back_value, + ) = from_thread_run_sync(back_in_main) + thread_after_value = trio_test_contextvar.get() + sniffio_cvar_thread_after_value = current_async_library_cvar.get() + return ( + thread_parent_value, + thread_current_value, + thread_after_value, + sniffio_cvar_thread_pre_value, + sniffio_cvar_thread_after_value, + back_parent_value, + back_current_value, + sniffio_cvar_back_value, + ) + + ( + thread_parent_value, + thread_current_value, + thread_after_value, + sniffio_cvar_thread_pre_value, + sniffio_cvar_thread_after_value, + back_parent_value, + back_current_value, + sniffio_cvar_back_value, + ) = await to_thread_run_sync(thread_fn) + current_value = trio_test_contextvar.get() + sniffio_cvar_out_value = current_async_library_cvar.get() + assert current_value == thread_parent_value == "main" + assert thread_current_value == back_parent_value == thread_after_value == "worker" + assert back_current_value == "back_in_main" + assert sniffio_cvar_out_value == sniffio_cvar_back_value == "trio" + assert sniffio_cvar_thread_pre_value == sniffio_cvar_thread_after_value == None + + +async def test_trio_from_thread_run_contextvars(): + trio_test_contextvar.set("main") + + def thread_fn(): + thread_parent_value = trio_test_contextvar.get() + trio_test_contextvar.set("worker") + thread_current_value = trio_test_contextvar.get() + sniffio_cvar_thread_pre_value = current_async_library_cvar.get() + + async def async_back_in_main(): + back_parent_value = trio_test_contextvar.get() + trio_test_contextvar.set("back_in_main") + back_current_value = trio_test_contextvar.get() + sniffio_cvar_back_value = current_async_library_cvar.get() + return back_parent_value, back_current_value, sniffio_cvar_back_value + + ( + back_parent_value, + back_current_value, + sniffio_cvar_back_value, + ) = from_thread_run(async_back_in_main) + thread_after_value = trio_test_contextvar.get() + sniffio_cvar_thread_after_value = current_async_library_cvar.get() + return ( + thread_parent_value, + thread_current_value, + thread_after_value, + sniffio_cvar_thread_pre_value, + sniffio_cvar_thread_after_value, + back_parent_value, + back_current_value, + sniffio_cvar_back_value, + ) + + ( + thread_parent_value, + thread_current_value, + thread_after_value, + sniffio_cvar_thread_pre_value, + sniffio_cvar_thread_after_value, + back_parent_value, + back_current_value, + sniffio_cvar_back_value, + ) = await to_thread_run_sync(thread_fn) + current_value = trio_test_contextvar.get() + assert current_value == thread_parent_value == "main" + assert thread_current_value == back_parent_value == thread_after_value == "worker" + assert back_current_value == "back_in_main" + assert sniffio_cvar_thread_pre_value == sniffio_cvar_thread_after_value == None + assert sniffio_cvar_back_value == "trio" + + +def test_run_fn_as_system_task_catched_badly_typed_token(): + with pytest.raises(RuntimeError): + from_thread_run_sync(_core.current_time, trio_token="Not TrioTokentype") + + +async def test_from_thread_inside_trio_thread(): + def not_called(): # pragma: no cover + assert False + + trio_token = _core.current_trio_token() + with pytest.raises(RuntimeError): + from_thread_run_sync(not_called, trio_token=trio_token) + + +@pytest.mark.skipif(buggy_pypy_asyncgens, reason="pypy 7.2.0 is buggy") +def test_from_thread_run_during_shutdown(): + save = [] + record = [] + + async def agen(): + try: + yield + finally: + with pytest.raises(_core.RunFinishedError), _core.CancelScope(shield=True): + await to_thread_run_sync(from_thread_run, sleep, 0) + record.append("ok") + + async def main(): + save.append(agen()) + await save[-1].asend(None) + + _core.run(main) + assert record == ["ok"] + + +async def test_trio_token_weak_referenceable(): + token = current_trio_token() + assert isinstance(token, TrioToken) + weak_reference = weakref.ref(token) + assert token is weak_reference() + + +async def test_unsafe_cancellable_kwarg(): + + # This is a stand in for a numpy ndarray or other objects + # that (maybe surprisingly) lack a notion of truthiness + class BadBool: + def __bool__(self): + raise NotImplementedError + + with pytest.raises(NotImplementedError): + await to_thread_run_sync(int, cancellable=BadBool()) diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_timeouts.py b/venv/lib/python3.9/site-packages/trio/tests/test_timeouts.py new file mode 100644 index 00000000..382c015b --- /dev/null +++ b/venv/lib/python3.9/site-packages/trio/tests/test_timeouts.py @@ -0,0 +1,104 @@ +import outcome +import pytest +import time + +from .._core.tests.tutil import slow +from .. import _core +from ..testing import assert_checkpoints +from .._timeouts import * + + +async def check_takes_about(f, expected_dur): + start = time.perf_counter() + result = await outcome.acapture(f) + dur = time.perf_counter() - start + print(dur / expected_dur) + # 1.5 is an arbitrary fudge factor because there's always some delay + # between when we become eligible to wake up and when we actually do. We + # used to sleep for 0.05, and regularly observed overruns of 1.6x on + # Appveyor, and then started seeing overruns of 2.3x on Travis's macOS, so + # now we bumped up the sleep to 1 second, marked the tests as slow, and + # hopefully now the proportional error will be less huge. + # + # We also also for durations that are a hair shorter than expected. For + # example, here's a run on Windows where a 1.0 second sleep was measured + # to take 0.9999999999999858 seconds: + # https://ci.appveyor.com/project/njsmith/trio/build/1.0.768/job/3lbdyxl63q3h9s21 + # I believe that what happened here is that Windows's low clock resolution + # meant that our calls to time.monotonic() returned exactly the same + # values as the calls inside the actual run loop, but the two subtractions + # returned slightly different values because the run loop's clock adds a + # random floating point offset to both times, which should cancel out, but + # lol floating point we got slightly different rounding errors. (That + # value above is exactly 128 ULPs below 1.0, which would make sense if it + # started as a 1 ULP error at a different dynamic range.) + assert (1 - 1e-8) <= (dur / expected_dur) < 1.5 + return result.unwrap() + + +# How long to (attempt to) sleep for when testing. Smaller numbers make the +# test suite go faster. +TARGET = 1.0 + + +@slow +async def test_sleep(): + async def sleep_1(): + await sleep_until(_core.current_time() + TARGET) + + await check_takes_about(sleep_1, TARGET) + + async def sleep_2(): + await sleep(TARGET) + + await check_takes_about(sleep_2, TARGET) + + with pytest.raises(ValueError): + await sleep(-1) + + with assert_checkpoints(): + await sleep(0) + # This also serves as a test of the trivial move_on_at + with move_on_at(_core.current_time()): + with pytest.raises(_core.Cancelled): + await sleep(0) + + +@slow +async def test_move_on_after(): + with pytest.raises(ValueError): + with move_on_after(-1): + pass # pragma: no cover + + async def sleep_3(): + with move_on_after(TARGET): + await sleep(100) + + await check_takes_about(sleep_3, TARGET) + + +@slow +async def test_fail(): + async def sleep_4(): + with fail_at(_core.current_time() + TARGET): + await sleep(100) + + with pytest.raises(TooSlowError): + await check_takes_about(sleep_4, TARGET) + + with fail_at(_core.current_time() + 100): + await sleep(0) + + async def sleep_5(): + with fail_after(TARGET): + await sleep(100) + + with pytest.raises(TooSlowError): + await check_takes_about(sleep_5, TARGET) + + with fail_after(100): + await sleep(0) + + with pytest.raises(ValueError): + with fail_after(-1): + pass # pragma: no cover diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_unix_pipes.py b/venv/lib/python3.9/site-packages/trio/tests/test_unix_pipes.py new file mode 100644 index 00000000..cf98942e --- /dev/null +++ b/venv/lib/python3.9/site-packages/trio/tests/test_unix_pipes.py @@ -0,0 +1,276 @@ +import errno +import select +import os +import tempfile +import sys + +import pytest + +from .._core.tests.tutil import gc_collect_harder, skip_if_fbsd_pipes_broken +from .. import _core, move_on_after +from ..testing import wait_all_tasks_blocked, check_one_way_stream + +posix = os.name == "posix" +pytestmark = pytest.mark.skipif(not posix, reason="posix only") +if posix: + from .._unix_pipes import FdStream +else: + with pytest.raises(ImportError): + from .._unix_pipes import FdStream + + +# Have to use quoted types so import doesn't crash on windows +async def make_pipe() -> "Tuple[FdStream, FdStream]": + """Makes a new pair of pipes.""" + (r, w) = os.pipe() + return FdStream(w), FdStream(r) + + +async def make_clogged_pipe(): + s, r = await make_pipe() + try: + while True: + # We want to totally fill up the pipe buffer. + # This requires working around a weird feature that POSIX pipes + # have. + # If you do a write of <= PIPE_BUF bytes, then it's guaranteed + # to either complete entirely, or not at all. So if we tried to + # write PIPE_BUF bytes, and the buffer's free space is only + # PIPE_BUF/2, then the write will raise BlockingIOError... even + # though a smaller write could still succeed! To avoid this, + # make sure to write >PIPE_BUF bytes each time, which disables + # the special behavior. + # For details, search for PIPE_BUF here: + # http://pubs.opengroup.org/onlinepubs/9699919799/functions/write.html + + # for the getattr: + # https://bitbucket.org/pypy/pypy/issues/2876/selectpipe_buf-is-missing-on-pypy3 + buf_size = getattr(select, "PIPE_BUF", 8192) + os.write(s.fileno(), b"x" * buf_size * 2) + except BlockingIOError: + pass + return s, r + + +async def test_send_pipe(): + r, w = os.pipe() + async with FdStream(w) as send: + assert send.fileno() == w + await send.send_all(b"123") + assert (os.read(r, 8)) == b"123" + + os.close(r) + + +async def test_receive_pipe(): + r, w = os.pipe() + async with FdStream(r) as recv: + assert (recv.fileno()) == r + os.write(w, b"123") + assert (await recv.receive_some(8)) == b"123" + + os.close(w) + + +async def test_pipes_combined(): + write, read = await make_pipe() + count = 2**20 + + async def sender(): + big = bytearray(count) + await write.send_all(big) + + async def reader(): + await wait_all_tasks_blocked() + received = 0 + while received < count: + received += len(await read.receive_some(4096)) + + assert received == count + + async with _core.open_nursery() as n: + n.start_soon(sender) + n.start_soon(reader) + + await read.aclose() + await write.aclose() + + +async def test_pipe_errors(): + with pytest.raises(TypeError): + FdStream(None) + + r, w = os.pipe() + os.close(w) + async with FdStream(r) as s: + with pytest.raises(ValueError): + await s.receive_some(0) + + +async def test_del(): + w, r = await make_pipe() + f1, f2 = w.fileno(), r.fileno() + del w, r + gc_collect_harder() + + with pytest.raises(OSError) as excinfo: + os.close(f1) + assert excinfo.value.errno == errno.EBADF + + with pytest.raises(OSError) as excinfo: + os.close(f2) + assert excinfo.value.errno == errno.EBADF + + +async def test_async_with(): + w, r = await make_pipe() + async with w, r: + pass + + assert w.fileno() == -1 + assert r.fileno() == -1 + + with pytest.raises(OSError) as excinfo: + os.close(w.fileno()) + assert excinfo.value.errno == errno.EBADF + + with pytest.raises(OSError) as excinfo: + os.close(r.fileno()) + assert excinfo.value.errno == errno.EBADF + + +async def test_misdirected_aclose_regression(): + # https://github.com/python-trio/trio/issues/661#issuecomment-456582356 + w, r = await make_pipe() + old_r_fd = r.fileno() + + # Close the original objects + await w.aclose() + await r.aclose() + + # Do a little dance to get a new pipe whose receive handle matches the old + # receive handle. + r2_fd, w2_fd = os.pipe() + if r2_fd != old_r_fd: # pragma: no cover + os.dup2(r2_fd, old_r_fd) + os.close(r2_fd) + async with FdStream(old_r_fd) as r2: + assert r2.fileno() == old_r_fd + + # And now set up a background task that's working on the new receive + # handle + async def expect_eof(): + assert await r2.receive_some(10) == b"" + + async with _core.open_nursery() as nursery: + nursery.start_soon(expect_eof) + await wait_all_tasks_blocked() + + # Here's the key test: does calling aclose() again on the *old* + # handle, cause the task blocked on the *new* handle to raise + # ClosedResourceError? + await r.aclose() + await wait_all_tasks_blocked() + + # Guess we survived! Close the new write handle so that the task + # gets an EOF and can exit cleanly. + os.close(w2_fd) + + +async def test_close_at_bad_time_for_receive_some(monkeypatch): + # We used to have race conditions where if one task was using the pipe, + # and another closed it at *just* the wrong moment, it would give an + # unexpected error instead of ClosedResourceError: + # https://github.com/python-trio/trio/issues/661 + # + # This tests what happens if the pipe gets closed in the moment *between* + # when receive_some wakes up, and when it tries to call os.read + async def expect_closedresourceerror(): + with pytest.raises(_core.ClosedResourceError): + await r.receive_some(10) + + orig_wait_readable = _core._run.TheIOManager.wait_readable + + async def patched_wait_readable(*args, **kwargs): + await orig_wait_readable(*args, **kwargs) + await r.aclose() + + monkeypatch.setattr(_core._run.TheIOManager, "wait_readable", patched_wait_readable) + s, r = await make_pipe() + async with s, r: + async with _core.open_nursery() as nursery: + nursery.start_soon(expect_closedresourceerror) + await wait_all_tasks_blocked() + # Trigger everything by waking up the receiver + await s.send_all(b"x") + + +async def test_close_at_bad_time_for_send_all(monkeypatch): + # We used to have race conditions where if one task was using the pipe, + # and another closed it at *just* the wrong moment, it would give an + # unexpected error instead of ClosedResourceError: + # https://github.com/python-trio/trio/issues/661 + # + # This tests what happens if the pipe gets closed in the moment *between* + # when send_all wakes up, and when it tries to call os.write + async def expect_closedresourceerror(): + with pytest.raises(_core.ClosedResourceError): + await s.send_all(b"x" * 100) + + orig_wait_writable = _core._run.TheIOManager.wait_writable + + async def patched_wait_writable(*args, **kwargs): + await orig_wait_writable(*args, **kwargs) + await s.aclose() + + monkeypatch.setattr(_core._run.TheIOManager, "wait_writable", patched_wait_writable) + s, r = await make_clogged_pipe() + async with s, r: + async with _core.open_nursery() as nursery: + nursery.start_soon(expect_closedresourceerror) + await wait_all_tasks_blocked() + # Trigger everything by waking up the sender. On ppc64el, PIPE_BUF + # is 8192 but make_clogged_pipe() ends up writing a total of + # 1048576 bytes before the pipe is full, and then a subsequent + # receive_some(10000) isn't sufficient for orig_wait_writable() to + # return for our subsequent aclose() call. It's necessary to empty + # the pipe further before this happens. So we loop here until the + # pipe is empty to make sure that the sender wakes up even in this + # case. Otherwise patched_wait_writable() never gets to the + # aclose(), so expect_closedresourceerror() never returns, the + # nursery never finishes all tasks and this test hangs. + received_data = await r.receive_some(10000) + while received_data: + received_data = await r.receive_some(10000) + + +# On FreeBSD, directories are readable, and we haven't found any other trick +# for making an unreadable fd, so there's no way to run this test. Fortunately +# the logic this is testing doesn't depend on the platform, so testing on +# other platforms is probably good enough. +@pytest.mark.skipif( + sys.platform.startswith("freebsd"), + reason="no way to make read() return a bizarro error on FreeBSD", +) +async def test_bizarro_OSError_from_receive(): + # Make sure that if the read syscall returns some bizarro error, then we + # get a BrokenResourceError. This is incredibly unlikely; there's almost + # no way to trigger a failure here intentionally (except for EBADF, but we + # exploit that to detect file closure, so it takes a different path). So + # we set up a strange scenario where the pipe fd somehow transmutes into a + # directory fd, causing os.read to raise IsADirectoryError (yes, that's a + # real built-in exception type). + s, r = await make_pipe() + async with s, r: + dir_fd = os.open("/", os.O_DIRECTORY, 0) + try: + os.dup2(dir_fd, r.fileno()) + with pytest.raises(_core.BrokenResourceError): + await r.receive_some(10) + finally: + os.close(dir_fd) + + +@skip_if_fbsd_pipes_broken +async def test_pipe_fully(): + await check_one_way_stream(make_pipe, make_clogged_pipe) diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_util.py b/venv/lib/python3.9/site-packages/trio/tests/test_util.py new file mode 100644 index 00000000..15ab09a8 --- /dev/null +++ b/venv/lib/python3.9/site-packages/trio/tests/test_util.py @@ -0,0 +1,193 @@ +import signal +import sys + +import pytest + +import trio +from .. import _core +from .._core.tests.tutil import ( + ignore_coroutine_never_awaited_warnings, + create_asyncio_future_in_new_loop, +) +from .._util import ( + signal_raise, + ConflictDetector, + is_main_thread, + coroutine_or_error, + generic_function, + Final, + NoPublicConstructor, +) +from ..testing import wait_all_tasks_blocked + + +def test_signal_raise(): + record = [] + + def handler(signum, _): + record.append(signum) + + old = signal.signal(signal.SIGFPE, handler) + try: + signal_raise(signal.SIGFPE) + finally: + signal.signal(signal.SIGFPE, old) + assert record == [signal.SIGFPE] + + +async def test_ConflictDetector(): + ul1 = ConflictDetector("ul1") + ul2 = ConflictDetector("ul2") + + with ul1: + with ul2: + print("ok") + + with pytest.raises(_core.BusyResourceError) as excinfo: + with ul1: + with ul1: + pass # pragma: no cover + assert "ul1" in str(excinfo.value) + + async def wait_with_ul1(): + with ul1: + await wait_all_tasks_blocked() + + with pytest.raises(_core.BusyResourceError) as excinfo: + async with _core.open_nursery() as nursery: + nursery.start_soon(wait_with_ul1) + nursery.start_soon(wait_with_ul1) + assert "ul1" in str(excinfo.value) + + +def test_module_metadata_is_fixed_up(): + import trio + import trio.testing + + assert trio.Cancelled.__module__ == "trio" + assert trio.open_nursery.__module__ == "trio" + assert trio.abc.Stream.__module__ == "trio.abc" + assert trio.lowlevel.wait_task_rescheduled.__module__ == "trio.lowlevel" + assert trio.testing.trio_test.__module__ == "trio.testing" + + # Also check methods + assert trio.lowlevel.ParkingLot.__init__.__module__ == "trio.lowlevel" + assert trio.abc.Stream.send_all.__module__ == "trio.abc" + + # And names + assert trio.Cancelled.__name__ == "Cancelled" + assert trio.Cancelled.__qualname__ == "Cancelled" + assert trio.abc.SendStream.send_all.__name__ == "send_all" + assert trio.abc.SendStream.send_all.__qualname__ == "SendStream.send_all" + assert trio.to_thread.__name__ == "trio.to_thread" + assert trio.to_thread.run_sync.__name__ == "run_sync" + assert trio.to_thread.run_sync.__qualname__ == "run_sync" + + +async def test_is_main_thread(): + assert is_main_thread() + + def not_main_thread(): + assert not is_main_thread() + + await trio.to_thread.run_sync(not_main_thread) + + +# @coroutine is deprecated since python 3.8, which is fine with us. +@pytest.mark.filterwarnings("ignore:.*@coroutine.*:DeprecationWarning") +def test_coroutine_or_error(): + class Deferred: + "Just kidding" + + with ignore_coroutine_never_awaited_warnings(): + + async def f(): # pragma: no cover + pass + + with pytest.raises(TypeError) as excinfo: + coroutine_or_error(f()) + assert "expecting an async function" in str(excinfo.value) + + import asyncio + + if sys.version_info < (3, 11): + + @asyncio.coroutine + def generator_based_coro(): # pragma: no cover + yield from asyncio.sleep(1) + + with pytest.raises(TypeError) as excinfo: + coroutine_or_error(generator_based_coro()) + assert "asyncio" in str(excinfo.value) + + with pytest.raises(TypeError) as excinfo: + coroutine_or_error(create_asyncio_future_in_new_loop()) + assert "asyncio" in str(excinfo.value) + + with pytest.raises(TypeError) as excinfo: + coroutine_or_error(create_asyncio_future_in_new_loop) + assert "asyncio" in str(excinfo.value) + + with pytest.raises(TypeError) as excinfo: + coroutine_or_error(Deferred()) + assert "twisted" in str(excinfo.value) + + with pytest.raises(TypeError) as excinfo: + coroutine_or_error(lambda: Deferred()) + assert "twisted" in str(excinfo.value) + + with pytest.raises(TypeError) as excinfo: + coroutine_or_error(len, [[1, 2, 3]]) + + assert "appears to be synchronous" in str(excinfo.value) + + async def async_gen(arg): # pragma: no cover + yield + + with pytest.raises(TypeError) as excinfo: + coroutine_or_error(async_gen, [0]) + msg = "expected an async function but got an async generator" + assert msg in str(excinfo.value) + + # Make sure no references are kept around to keep anything alive + del excinfo + + +def test_generic_function(): + @generic_function + def test_func(arg): + """Look, a docstring!""" + return arg + + assert test_func is test_func[int] is test_func[int, str] + assert test_func(42) == test_func[int](42) == 42 + assert test_func.__doc__ == "Look, a docstring!" + assert test_func.__qualname__ == "test_generic_function..test_func" + assert test_func.__name__ == "test_func" + assert test_func.__module__ == __name__ + + +def test_final_metaclass(): + class FinalClass(metaclass=Final): + pass + + with pytest.raises(TypeError): + + class SubClass(FinalClass): + pass + + +def test_no_public_constructor_metaclass(): + class SpecialClass(metaclass=NoPublicConstructor): + pass + + with pytest.raises(TypeError): + SpecialClass() + + with pytest.raises(TypeError): + + class SubClass(SpecialClass): + pass + + # Private constructor should not raise + assert isinstance(SpecialClass._create(), SpecialClass) diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_wait_for_object.py b/venv/lib/python3.9/site-packages/trio/tests/test_wait_for_object.py new file mode 100644 index 00000000..38acfa80 --- /dev/null +++ b/venv/lib/python3.9/site-packages/trio/tests/test_wait_for_object.py @@ -0,0 +1,220 @@ +import os + +import pytest + +on_windows = os.name == "nt" +# Mark all the tests in this file as being windows-only +pytestmark = pytest.mark.skipif(not on_windows, reason="windows only") + +from .._core.tests.tutil import slow +import trio +from .. import _core +from .. import _timeouts + +if on_windows: + from .._core._windows_cffi import ffi, kernel32 + from .._wait_for_object import ( + WaitForSingleObject, + WaitForMultipleObjects_sync, + ) + + +async def test_WaitForMultipleObjects_sync(): + # This does a series of tests where we set/close the handle before + # initiating the waiting for it. + # + # Note that closing the handle (not signaling) will cause the + # *initiation* of a wait to return immediately. But closing a handle + # that is already being waited on will not stop whatever is waiting + # for it. + + # One handle + handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL) + kernel32.SetEvent(handle1) + WaitForMultipleObjects_sync(handle1) + kernel32.CloseHandle(handle1) + print("test_WaitForMultipleObjects_sync one OK") + + # Two handles, signal first + handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL) + handle2 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL) + kernel32.SetEvent(handle1) + WaitForMultipleObjects_sync(handle1, handle2) + kernel32.CloseHandle(handle1) + kernel32.CloseHandle(handle2) + print("test_WaitForMultipleObjects_sync set first OK") + + # Two handles, signal second + handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL) + handle2 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL) + kernel32.SetEvent(handle2) + WaitForMultipleObjects_sync(handle1, handle2) + kernel32.CloseHandle(handle1) + kernel32.CloseHandle(handle2) + print("test_WaitForMultipleObjects_sync set second OK") + + # Two handles, close first + handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL) + handle2 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL) + kernel32.CloseHandle(handle1) + with pytest.raises(OSError): + WaitForMultipleObjects_sync(handle1, handle2) + kernel32.CloseHandle(handle2) + print("test_WaitForMultipleObjects_sync close first OK") + + # Two handles, close second + handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL) + handle2 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL) + kernel32.CloseHandle(handle2) + with pytest.raises(OSError): + WaitForMultipleObjects_sync(handle1, handle2) + kernel32.CloseHandle(handle1) + print("test_WaitForMultipleObjects_sync close second OK") + + +@slow +async def test_WaitForMultipleObjects_sync_slow(): + # This does a series of test in which the main thread sync-waits for + # handles, while we spawn a thread to set the handles after a short while. + + TIMEOUT = 0.3 + + # One handle + handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL) + t0 = _core.current_time() + async with _core.open_nursery() as nursery: + nursery.start_soon( + trio.to_thread.run_sync, WaitForMultipleObjects_sync, handle1 + ) + await _timeouts.sleep(TIMEOUT) + # If we would comment the line below, the above thread will be stuck, + # and Trio won't exit this scope + kernel32.SetEvent(handle1) + t1 = _core.current_time() + assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT + kernel32.CloseHandle(handle1) + print("test_WaitForMultipleObjects_sync_slow one OK") + + # Two handles, signal first + handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL) + handle2 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL) + t0 = _core.current_time() + async with _core.open_nursery() as nursery: + nursery.start_soon( + trio.to_thread.run_sync, WaitForMultipleObjects_sync, handle1, handle2 + ) + await _timeouts.sleep(TIMEOUT) + kernel32.SetEvent(handle1) + t1 = _core.current_time() + assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT + kernel32.CloseHandle(handle1) + kernel32.CloseHandle(handle2) + print("test_WaitForMultipleObjects_sync_slow thread-set first OK") + + # Two handles, signal second + handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL) + handle2 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL) + t0 = _core.current_time() + async with _core.open_nursery() as nursery: + nursery.start_soon( + trio.to_thread.run_sync, WaitForMultipleObjects_sync, handle1, handle2 + ) + await _timeouts.sleep(TIMEOUT) + kernel32.SetEvent(handle2) + t1 = _core.current_time() + assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT + kernel32.CloseHandle(handle1) + kernel32.CloseHandle(handle2) + print("test_WaitForMultipleObjects_sync_slow thread-set second OK") + + +async def test_WaitForSingleObject(): + # This does a series of test for setting/closing the handle before + # initiating the wait. + + # Test already set + handle = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL) + kernel32.SetEvent(handle) + await WaitForSingleObject(handle) # should return at once + kernel32.CloseHandle(handle) + print("test_WaitForSingleObject already set OK") + + # Test already set, as int + handle = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL) + handle_int = int(ffi.cast("intptr_t", handle)) + kernel32.SetEvent(handle) + await WaitForSingleObject(handle_int) # should return at once + kernel32.CloseHandle(handle) + print("test_WaitForSingleObject already set OK") + + # Test already closed + handle = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL) + kernel32.CloseHandle(handle) + with pytest.raises(OSError): + await WaitForSingleObject(handle) # should return at once + print("test_WaitForSingleObject already closed OK") + + # Not a handle + with pytest.raises(TypeError): + await WaitForSingleObject("not a handle") # Wrong type + # with pytest.raises(OSError): + # await WaitForSingleObject(99) # If you're unlucky, it actually IS a handle :( + print("test_WaitForSingleObject not a handle OK") + + +@slow +async def test_WaitForSingleObject_slow(): + # This does a series of test for setting the handle in another task, + # and cancelling the wait task. + + # Set the timeout used in the tests. We test the waiting time against + # the timeout with a certain margin. + TIMEOUT = 0.3 + + async def signal_soon_async(handle): + await _timeouts.sleep(TIMEOUT) + kernel32.SetEvent(handle) + + # Test handle is SET after TIMEOUT in separate coroutine + + handle = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL) + t0 = _core.current_time() + + async with _core.open_nursery() as nursery: + nursery.start_soon(WaitForSingleObject, handle) + nursery.start_soon(signal_soon_async, handle) + + kernel32.CloseHandle(handle) + t1 = _core.current_time() + assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT + print("test_WaitForSingleObject_slow set from task OK") + + # Test handle is SET after TIMEOUT in separate coroutine, as int + + handle = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL) + handle_int = int(ffi.cast("intptr_t", handle)) + t0 = _core.current_time() + + async with _core.open_nursery() as nursery: + nursery.start_soon(WaitForSingleObject, handle_int) + nursery.start_soon(signal_soon_async, handle) + + kernel32.CloseHandle(handle) + t1 = _core.current_time() + assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT + print("test_WaitForSingleObject_slow set from task as int OK") + + # Test handle is CLOSED after 1 sec - NOPE see comment above + + # Test cancellation + + handle = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL) + t0 = _core.current_time() + + with _timeouts.move_on_after(TIMEOUT): + await WaitForSingleObject(handle) + + kernel32.CloseHandle(handle) + t1 = _core.current_time() + assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT + print("test_WaitForSingleObject_slow cancellation OK") diff --git a/venv/lib/python3.9/site-packages/trio/tests/test_windows_pipes.py b/venv/lib/python3.9/site-packages/trio/tests/test_windows_pipes.py new file mode 100644 index 00000000..0a6b3516 --- /dev/null +++ b/venv/lib/python3.9/site-packages/trio/tests/test_windows_pipes.py @@ -0,0 +1,110 @@ +import errno +import select + +import os +import sys +import pytest + +from .._core.tests.tutil import gc_collect_harder +from .. import _core, move_on_after +from ..testing import wait_all_tasks_blocked, check_one_way_stream + +if sys.platform == "win32": + from .._windows_pipes import PipeSendStream, PipeReceiveStream + from .._core._windows_cffi import _handle, kernel32 + from asyncio.windows_utils import pipe +else: + pytestmark = pytest.mark.skip(reason="windows only") + pipe = None # type: Any + PipeSendStream = None # type: Any + PipeReceiveStream = None # type: Any + + +async def make_pipe() -> "Tuple[PipeSendStream, PipeReceiveStream]": + """Makes a new pair of pipes.""" + (r, w) = pipe() + return PipeSendStream(w), PipeReceiveStream(r) + + +async def test_pipe_typecheck(): + with pytest.raises(TypeError): + PipeSendStream(1.0) + with pytest.raises(TypeError): + PipeReceiveStream(None) + + +async def test_pipe_error_on_close(): + # Make sure we correctly handle a failure from kernel32.CloseHandle + r, w = pipe() + + send_stream = PipeSendStream(w) + receive_stream = PipeReceiveStream(r) + + assert kernel32.CloseHandle(_handle(r)) + assert kernel32.CloseHandle(_handle(w)) + + with pytest.raises(OSError): + await send_stream.aclose() + with pytest.raises(OSError): + await receive_stream.aclose() + + +async def test_pipes_combined(): + write, read = await make_pipe() + count = 2**20 + replicas = 3 + + async def sender(): + async with write: + big = bytearray(count) + for _ in range(replicas): + await write.send_all(big) + + async def reader(): + async with read: + await wait_all_tasks_blocked() + total_received = 0 + while True: + # 5000 is chosen because it doesn't evenly divide 2**20 + received = len(await read.receive_some(5000)) + if not received: + break + total_received += received + + assert total_received == count * replicas + + async with _core.open_nursery() as n: + n.start_soon(sender) + n.start_soon(reader) + + +async def test_async_with(): + w, r = await make_pipe() + async with w, r: + pass + + with pytest.raises(_core.ClosedResourceError): + await w.send_all(b"") + with pytest.raises(_core.ClosedResourceError): + await r.receive_some(10) + + +async def test_close_during_write(): + w, r = await make_pipe() + async with _core.open_nursery() as nursery: + + async def write_forever(): + with pytest.raises(_core.ClosedResourceError) as excinfo: + while True: + await w.send_all(b"x" * 4096) + assert "another task" in str(excinfo.value) + + nursery.start_soon(write_forever) + await wait_all_tasks_blocked(0.1) + await w.aclose() + + +async def test_pipe_fully(): + # passing make_clogged_pipe tests wait_send_all_might_not_block, and we + # can't implement that on Windows + await check_one_way_stream(make_pipe, None) diff --git a/venv/lib/python3.9/site-packages/trio/tests/tools/__init__.py b/venv/lib/python3.9/site-packages/trio/tests/tools/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/venv/lib/python3.9/site-packages/trio/tests/tools/test_gen_exports.py b/venv/lib/python3.9/site-packages/trio/tests/tools/test_gen_exports.py new file mode 100644 index 00000000..e4e388c2 --- /dev/null +++ b/venv/lib/python3.9/site-packages/trio/tests/tools/test_gen_exports.py @@ -0,0 +1,72 @@ +import ast +import astor +import pytest +import os +import sys + +from shutil import copyfile +from trio._tools.gen_exports import ( + get_public_methods, + create_passthrough_args, + process, +) + +SOURCE = '''from _run import _public + +class Test: + @_public + def public_func(self): + """With doc string""" + + @ignore_this + @_public + @another_decorator + async def public_async_func(self): + pass # no doc string + + def not_public(self): + pass + + async def not_public_async(self): + pass +''' + + +def test_get_public_methods(): + methods = list(get_public_methods(ast.parse(SOURCE))) + assert {m.name for m in methods} == {"public_func", "public_async_func"} + + +def test_create_pass_through_args(): + testcases = [ + ("def f()", "()"), + ("def f(one)", "(one)"), + ("def f(one, two)", "(one, two)"), + ("def f(one, *args)", "(one, *args)"), + ( + "def f(one, *args, kw1, kw2=None, **kwargs)", + "(one, *args, kw1=kw1, kw2=kw2, **kwargs)", + ), + ] + + for (funcdef, expected) in testcases: + func_node = ast.parse(funcdef + ":\n pass").body[0] + assert isinstance(func_node, ast.FunctionDef) + assert create_passthrough_args(func_node) == expected + + +def test_process(tmp_path): + modpath = tmp_path / "_module.py" + genpath = tmp_path / "_generated_module.py" + modpath.write_text(SOURCE, encoding="utf-8") + assert not genpath.exists() + with pytest.raises(SystemExit) as excinfo: + process([(str(modpath), "runner")], do_test=True) + assert excinfo.value.code == 1 + process([(str(modpath), "runner")], do_test=False) + assert genpath.exists() + process([(str(modpath), "runner")], do_test=True) + # But if we change the lookup path it notices + with pytest.raises(SystemExit) as excinfo: + process([(str(modpath), "runner.io_manager")], do_test=True) + assert excinfo.value.code == 1 -- cgit v1.2.3