# coding: utf-8
# Little utilities we use internally
from abc import ABCMeta
import os
import signal
import sys
import pathlib
from functools import wraps, update_wrapper
import typing as t
import threading
import collections
from async_generator import isasyncgen
import trio
# Equivalent to the C function raise(), which Python doesn't wrap
if os.name == "nt":
    # On windows, os.kill exists but is really weird.
    #
    # If you give it CTRL_C_EVENT or CTRL_BREAK_EVENT, it tries to deliver
    # those using GenerateConsoleCtrlEvent. But I found that when I tried
    # to run my test normally, it would freeze waiting... unless I added
    # print statements, in which case the test suddenly worked. So I guess
    # these signals are only delivered if/when you access the console? I
    # don't really know what was going on there. From reading the
    # GenerateConsoleCtrlEvent docs I don't know how it worked at all.
    #
    # I later spent a bunch of time trying to make GenerateConsoleCtrlEvent
    # work for creating synthetic control-C events, and... failed
    # utterly. There are lots of details in the code and comments
    # removed/added at this commit:
    # https://github.com/python-trio/trio/commit/95843654173e3e826c34d70a90b369ba6edf2c23
    #
    # OTOH, if you pass os.kill any *other* signal number... then CPython
    # just calls TerminateProcess (wtf).
    #
    # So, anyway, os.kill is not so useful for testing purposes. Instead
    # we use raise():
    #
    # https://msdn.microsoft.com/en-us/library/dwwzkt4c.aspx
    #
    # Have to import cffi inside the 'if os.name' block because we don't
    # depend on cffi on non-Windows platforms. (It would be easy to switch
    # this to ctypes though if we ever remove the cffi dependency.)
    #
    # Some more information:
    # https://bugs.python.org/issue26350
    #
    # Anyway, we use this for two things:
    # - redelivering unhandled signals
    # - generating synthetic signals for tests
    # and for both of those purposes, 'raise' works fine.
    import cffi

    _ffi = cffi.FFI()
    _ffi.cdef("int raise(int);")
    _lib = _ffi.dlopen("api-ms-win-crt-runtime-l1-1-0.dll")
    signal_raise = getattr(_lib, "raise")
else:

    def signal_raise(signum):
        signal.pthread_kill(threading.get_ident(), signum)
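
# A minimal usage sketch (illustrative; not executed by this module): a test
# can use signal_raise() to deliver a synthetic SIGINT to the current
# process. It must run on the main thread with a Python-level handler
# installed; the handler fires at the next bytecode check:
#
#     got = []
#     signal.signal(signal.SIGINT, lambda signum, frame: got.append(signum))
#     signal_raise(signal.SIGINT)   # shortly afterwards, got == [signal.SIGINT]
#     signal.signal(signal.SIGINT, signal.default_int_handler)  # restore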

# See: #461 as to why this is needed.
# The gist is that threading.main_thread() can lie to us: if somebody else
# edits the threading ident cache to replace the main thread, then
# threading.current_thread() returns a _DummyThread, the control-C check
# fails, and so on.
# Calling signal.signal() from outside the main thread raises ValueError, so
# we can use that to reliably check whether this is the main thread without
# relying on a potentially modified threading module.
def is_main_thread():
    """Attempt to reliably check if we are in the main thread."""
    try:
        signal.signal(signal.SIGINT, signal.getsignal(signal.SIGINT))
        return True
    except ValueError:
        return False
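
# A minimal usage sketch (illustrative; not executed here):
#
#     assert is_main_thread()  # True when called from the main thread
#
#     def check():
#         print(is_main_thread())  # prints False in a worker thread
#
#     threading.Thread(target=check).start()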

######
# Call the function and get the coroutine object, while giving helpful
# errors for common mistakes. Returns the coroutine object.
######


def coroutine_or_error(async_fn, *args):
    def _return_value_looks_like_wrong_library(value):
        # Returned by legacy @asyncio.coroutine functions, which includes
        # a surprising proportion of asyncio builtins.
        if isinstance(value, collections.abc.Generator):
            return True
        # The protocol for detecting an asyncio Future-like object
        if getattr(value, "_asyncio_future_blocking", None) is not None:
            return True
        # This janky check catches tornado Futures and twisted Deferreds.
        # By the time we're calling this function, we already know
        # something has gone wrong, so a heuristic is pretty safe.
        if value.__class__.__name__ in ("Future", "Deferred"):
            return True
        return False

    try:
        coro = async_fn(*args)
    except TypeError:
        # Give good error for: nursery.start_soon(trio.sleep(1))
        if isinstance(async_fn, collections.abc.Coroutine):
            # Explicitly close the coroutine to avoid a RuntimeWarning.
            async_fn.close()
            raise TypeError(
                "Trio was expecting an async function, but instead it got "
                "a coroutine object {async_fn!r}\n"
                "\n"
                "Probably you did something like:\n"
                "\n"
                "  trio.run({async_fn.__name__}(...))            # incorrect!\n"
                "  nursery.start_soon({async_fn.__name__}(...))  # incorrect!\n"
                "\n"
                "Instead, you want (notice the parentheses!):\n"
                "\n"
                "  trio.run({async_fn.__name__}, ...)            # correct!\n"
                "  nursery.start_soon({async_fn.__name__}, ...)  # correct!".format(
                    async_fn=async_fn
                )
            ) from None

        # Give good error for: nursery.start_soon(future)
        if _return_value_looks_like_wrong_library(async_fn):
            raise TypeError(
                "Trio was expecting an async function, but instead it got "
                "{!r} – are you trying to use a library written for "
                "asyncio/twisted/tornado or similar? That won't work "
                "without some sort of compatibility shim.".format(async_fn)
            ) from None

        raise

    # We can't check iscoroutinefunction(async_fn), because that will fail
    # for things like functools.partial objects wrapping an async
    # function. So we have to just call it and then check whether the
    # return value is a coroutine object.
    if not isinstance(coro, collections.abc.Coroutine):
        # Give good error for: nursery.start_soon(func_returning_future)
        if _return_value_looks_like_wrong_library(coro):
            raise TypeError(
                "Trio got unexpected {!r} – are you trying to use a "
                "library written for asyncio/twisted/tornado or similar? "
                "That won't work without some sort of compatibility shim.".format(coro)
            )

        if isasyncgen(coro):
            raise TypeError(
                "start_soon expected an async function but got an async "
                "generator {!r}".format(coro)
            )

        # Give good error for: nursery.start_soon(some_sync_fn)
        raise TypeError(
            "Trio expected an async function, but {!r} appears to be "
            "synchronous".format(getattr(async_fn, "__qualname__", async_fn))
        )

    return coro
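
# A minimal usage sketch (illustrative; the function name is made up). The
# point is the calling convention: pass the async function and its arguments
# separately, not an already-called coroutine:
#
#     async def sleeper(seconds):
#         await trio.sleep(seconds)
#
#     coro = coroutine_or_error(sleeper, 1)  # OK: returns a coroutine object
#     coro.close()                           # close it if you won't run it
#
#     coroutine_or_error(sleeper(1))         # TypeError with a helpful message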


class ConflictDetector:
    """Detect when two tasks are about to perform operations that would
    conflict.

    Use as a synchronous context manager; if two tasks enter it at the same
    time then the second one raises an error. You can use it when there are
    two pieces of code that *would* collide and need a lock if they were ever
    called at the same time, but that should never happen.

    We use this in particular for things like making sure that two different
    tasks don't call sendall simultaneously on the same stream.
    """

    def __init__(self, msg):
        self._msg = msg
        self._held = False

    def __enter__(self):
        if self._held:
            raise trio.BusyResourceError(self._msg)
        else:
            self._held = True

    def __exit__(self, *args):
        self._held = False
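
# A minimal usage sketch (illustrative; the class and method names are made
# up): a stream-like object can guard its send path so that two tasks calling
# send_all() at the same time get a BusyResourceError instead of silently
# interleaving their data:
#
#     class ExampleStream:
#         def __init__(self):
#             self._send_conflict = ConflictDetector(
#                 "another task is currently sending data on this stream"
#             )
#
#         async def send_all(self, data):
#             with self._send_conflict:
#                 ...  # actually transmit the data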


def async_wraps(cls, wrapped_cls, attr_name):
    """Similar to wraps, but for async wrappers of non-async functions."""

    def decorator(func):
        func.__name__ = attr_name
        func.__qualname__ = ".".join((cls.__qualname__, attr_name))

        func.__doc__ = """Like :meth:`~{}.{}.{}`, but async.

        """.format(
            wrapped_cls.__module__, wrapped_cls.__qualname__, attr_name
        )

        return func

    return decorator
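
# A minimal usage sketch (illustrative; the helper and attribute names are
# made up): generate an async method that delegates to a synchronous
# pathlib.Path method in a worker thread, while reporting the wrapped
# method's name and pointing its docstring at the original:
#
#     def _wrap_method(cls, wrapped_cls, attr_name):
#         @async_wraps(cls, wrapped_cls, attr_name)
#         async def wrapper(self, *args, **kwargs):
#             sync_method = getattr(self._wrapped, attr_name)
#             return await trio.to_thread.run_sync(
#                 lambda: sync_method(*args, **kwargs)
#             )
#         return wrapper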


def fixup_module_metadata(module_name, namespace):
    seen_ids = set()

    def fix_one(qualname, name, obj):
        # avoid infinite recursion (relevant when using
        # typing.Generic, for example)
        if id(obj) in seen_ids:
            return
        seen_ids.add(id(obj))

        mod = getattr(obj, "__module__", None)
        if mod is not None and mod.startswith("trio."):
            obj.__module__ = module_name
            # Modules, unlike everything else in Python, put fully-qualified
            # names into their __name__ attribute. We check for "." to avoid
            # rewriting these.
            if hasattr(obj, "__name__") and "." not in obj.__name__:
                obj.__name__ = name
                obj.__qualname__ = qualname
            if isinstance(obj, type):
                for attr_name, attr_value in obj.__dict__.items():
                    fix_one(objname + "." + attr_name, attr_name, attr_value)

    for objname, obj in namespace.items():
        if not objname.startswith("_"):  # ignore private attributes
            fix_one(objname, objname, obj)
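
# A minimal usage sketch (illustrative): a package __init__ that re-exports
# names from private submodules can call this at the end, so that
# introspection and documentation tools report e.g. "trio.Event" rather than
# the private submodule it actually lives in:
#
#     # at the bottom of the package's __init__.py
#     fixup_module_metadata(__name__, globals())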


class generic_function:
    """Decorator that makes a function indexable, to communicate
    non-inferrable generic type parameters to a static type checker.

    If you write::

        @generic_function
        def open_memory_channel(max_buffer_size: int) -> Tuple[
            SendChannel[T], ReceiveChannel[T]
        ]: ...

    it is valid at runtime to say ``open_memory_channel[bytes](5)``.

    This behaves identically to ``open_memory_channel(5)`` at runtime,
    and currently won't type-check without a mypy plugin or clever stubs,
    but at least it becomes possible to write those.
    """

    def __init__(self, fn):
        update_wrapper(self, fn)
        self._fn = fn

    def __call__(self, *args, **kwargs):
        return self._fn(*args, **kwargs)

    def __getitem__(self, _):
        return self


class Final(ABCMeta):
    """Metaclass that enforces a class to be final (i.e., subclassing is not
    allowed).

    If a class uses this metaclass like this::

        class SomeClass(metaclass=Final):
            pass

    then the metaclass will ensure that no subclass of ``SomeClass`` can be
    created.

    Raises
    ------
    - TypeError if a subclass is created
    """

    def __new__(cls, name, bases, cls_namespace):
        for base in bases:
            if isinstance(base, Final):
                raise TypeError(
                    f"{base.__module__}.{base.__qualname__} does not support subclassing"
                )
        return super().__new__(cls, name, bases, cls_namespace)
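
# A minimal usage sketch (illustrative; class names are made up): subclassing
# a class that uses this metaclass fails at class-definition time:
#
#     class Door(metaclass=Final):
#         pass
#
#     class TrapDoor(Door):  # raises TypeError: ... does not support subclassing
#         pass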


T = t.TypeVar("T")


class NoPublicConstructor(Final):
    """Metaclass that enforces a class to be final (i.e., subclassing is not
    allowed) and ensures a private constructor.

    If a class uses this metaclass like this::

        class SomeClass(metaclass=NoPublicConstructor):
            pass

    then the metaclass will ensure that no subclass can be created and that
    no instance can be initialized directly.

    If you try to instantiate your class (``SomeClass()``), a TypeError will
    be raised; instances must be created through the private ``_create``
    class method instead.

    Raises
    ------
    - TypeError if a subclass or an instance is created.
    """

    def __call__(cls, *args, **kwargs):
        raise TypeError(
            f"{cls.__module__}.{cls.__qualname__} has no public constructor"
        )

    def _create(cls: t.Type[T], *args: t.Any, **kwargs: t.Any) -> T:
        return super().__call__(*args, **kwargs)  # type: ignore
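
# A minimal usage sketch (illustrative; the class is made up): instances can
# only be built internally through the private _create() class method:
#
#     class Handle(metaclass=NoPublicConstructor):
#         def __init__(self, fd):
#             self._fd = fd
#
#     Handle(3)                    # raises TypeError: ... has no public constructor
#     handle = Handle._create(3)   # the internal way to construct one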


def name_asyncgen(agen):
    """Return the fully-qualified name of the async generator function
    that produced the async generator iterator *agen*.
    """
    if not hasattr(agen, "ag_code"):  # pragma: no cover
        return repr(agen)
    try:
        module = agen.ag_frame.f_globals["__name__"]
    except (AttributeError, KeyError):
        module = "<{}>".format(agen.ag_code.co_filename)
    try:
        qualname = agen.__qualname__
    except AttributeError:
        qualname = agen.ag_code.co_name
    return f"{module}.{qualname}"
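
# A minimal usage sketch (illustrative; module name is made up):
#
#     async def ticker():
#         yield 1
#
#     name_asyncgen(ticker())  # -> something like "mypackage.mymodule.ticker"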