author | noptuno <repollo.marrero@gmail.com> | 2023-04-28 02:29:30 +0200
---|---|---
committer | noptuno <repollo.marrero@gmail.com> | 2023-04-28 02:29:30 +0200
commit | 355dee533bb34a571b9367820a63cccb668cf866 (patch) |
tree | 838af886b4fec07320aeb10f0d1e74ba79e79b5c /venv/lib/python3.9/site-packages/tornado |
parent | added pyproject.toml file (diff) |
Diffstat (limited to 'venv/lib/python3.9/site-packages/tornado')
89 files changed, 42177 insertions, 0 deletions
diff --git a/venv/lib/python3.9/site-packages/tornado/__init__.py b/venv/lib/python3.9/site-packages/tornado/__init__.py
new file mode 100644
index 00000000..afbd7150
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/tornado/__init__.py
@@ -0,0 +1,67 @@
+#
+# Copyright 2009 Facebook
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""The Tornado web server and tools."""
+
+# version is a human-readable version number.
+
+# version_info is a four-tuple for programmatic comparison. The first
+# three numbers are the components of the version number. The fourth
+# is zero for an official release, positive for a development branch,
+# or negative for a release candidate or beta (after the base version
+# number has been incremented)
+version = "6.3.1"
+version_info = (6, 3, 1, 0)
+
+import importlib
+import typing
+
+__all__ = [
+    "auth",
+    "autoreload",
+    "concurrent",
+    "curl_httpclient",
+    "escape",
+    "gen",
+    "http1connection",
+    "httpclient",
+    "httpserver",
+    "httputil",
+    "ioloop",
+    "iostream",
+    "locale",
+    "locks",
+    "log",
+    "netutil",
+    "options",
+    "platform",
+    "process",
+    "queues",
+    "routing",
+    "simple_httpclient",
+    "tcpclient",
+    "tcpserver",
+    "template",
+    "testing",
+    "util",
+    "web",
+]
+
+
+# Copied from https://peps.python.org/pep-0562/
+def __getattr__(name: str) -> typing.Any:
+    if name in __all__:
+        return importlib.import_module("." + name, __name__)
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
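The `__getattr__` hook above is the PEP 562 lazy-import pattern: submodules named in `__all__` are imported only on first attribute access, keeping `import tornado` cheap. A minimal sketch of the same pattern for a hypothetical package (`mypkg` and `heavy_module` are made-up names, not part of Tornado):

    # mypkg/__init__.py -- hypothetical package using PEP 562 lazy imports
    import importlib
    import typing

    __all__ = ["heavy_module"]

    def __getattr__(name: str) -> typing.Any:
        # Import the submodule on first access; Python then caches it
        # as an ordinary attribute of the package.
        if name in __all__:
            return importlib.import_module("." + name, __name__)
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

With this in place, `import mypkg` does no submodule work, while `mypkg.heavy_module` triggers (and caches) the real import.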
+ +"""Data used by the tornado.locale module.""" + +LOCALE_NAMES = { + "af_ZA": {"name_en": "Afrikaans", "name": "Afrikaans"}, + "am_ET": {"name_en": "Amharic", "name": "አማርኛ"}, + "ar_AR": {"name_en": "Arabic", "name": "العربية"}, + "bg_BG": {"name_en": "Bulgarian", "name": "Български"}, + "bn_IN": {"name_en": "Bengali", "name": "বাংলা"}, + "bs_BA": {"name_en": "Bosnian", "name": "Bosanski"}, + "ca_ES": {"name_en": "Catalan", "name": "Català"}, + "cs_CZ": {"name_en": "Czech", "name": "Čeština"}, + "cy_GB": {"name_en": "Welsh", "name": "Cymraeg"}, + "da_DK": {"name_en": "Danish", "name": "Dansk"}, + "de_DE": {"name_en": "German", "name": "Deutsch"}, + "el_GR": {"name_en": "Greek", "name": "Ελληνικά"}, + "en_GB": {"name_en": "English (UK)", "name": "English (UK)"}, + "en_US": {"name_en": "English (US)", "name": "English (US)"}, + "es_ES": {"name_en": "Spanish (Spain)", "name": "Español (España)"}, + "es_LA": {"name_en": "Spanish", "name": "Español"}, + "et_EE": {"name_en": "Estonian", "name": "Eesti"}, + "eu_ES": {"name_en": "Basque", "name": "Euskara"}, + "fa_IR": {"name_en": "Persian", "name": "فارسی"}, + "fi_FI": {"name_en": "Finnish", "name": "Suomi"}, + "fr_CA": {"name_en": "French (Canada)", "name": "Français (Canada)"}, + "fr_FR": {"name_en": "French", "name": "Français"}, + "ga_IE": {"name_en": "Irish", "name": "Gaeilge"}, + "gl_ES": {"name_en": "Galician", "name": "Galego"}, + "he_IL": {"name_en": "Hebrew", "name": "עברית"}, + "hi_IN": {"name_en": "Hindi", "name": "हिन्दी"}, + "hr_HR": {"name_en": "Croatian", "name": "Hrvatski"}, + "hu_HU": {"name_en": "Hungarian", "name": "Magyar"}, + "id_ID": {"name_en": "Indonesian", "name": "Bahasa Indonesia"}, + "is_IS": {"name_en": "Icelandic", "name": "Íslenska"}, + "it_IT": {"name_en": "Italian", "name": "Italiano"}, + "ja_JP": {"name_en": "Japanese", "name": "日本語"}, + "ko_KR": {"name_en": "Korean", "name": "한국어"}, + "lt_LT": {"name_en": "Lithuanian", "name": "Lietuvių"}, + "lv_LV": {"name_en": "Latvian", "name": "Latviešu"}, + "mk_MK": {"name_en": "Macedonian", "name": "Македонски"}, + "ml_IN": {"name_en": "Malayalam", "name": "മലയാളം"}, + "ms_MY": {"name_en": "Malay", "name": "Bahasa Melayu"}, + "nb_NO": {"name_en": "Norwegian (bokmal)", "name": "Norsk (bokmål)"}, + "nl_NL": {"name_en": "Dutch", "name": "Nederlands"}, + "nn_NO": {"name_en": "Norwegian (nynorsk)", "name": "Norsk (nynorsk)"}, + "pa_IN": {"name_en": "Punjabi", "name": "ਪੰਜਾਬੀ"}, + "pl_PL": {"name_en": "Polish", "name": "Polski"}, + "pt_BR": {"name_en": "Portuguese (Brazil)", "name": "Português (Brasil)"}, + "pt_PT": {"name_en": "Portuguese (Portugal)", "name": "Português (Portugal)"}, + "ro_RO": {"name_en": "Romanian", "name": "Română"}, + "ru_RU": {"name_en": "Russian", "name": "Русский"}, + "sk_SK": {"name_en": "Slovak", "name": "Slovenčina"}, + "sl_SI": {"name_en": "Slovenian", "name": "Slovenščina"}, + "sq_AL": {"name_en": "Albanian", "name": "Shqip"}, + "sr_RS": {"name_en": "Serbian", "name": "Српски"}, + "sv_SE": {"name_en": "Swedish", "name": "Svenska"}, + "sw_KE": {"name_en": "Swahili", "name": "Kiswahili"}, + "ta_IN": {"name_en": "Tamil", "name": "தமிழ்"}, + "te_IN": {"name_en": "Telugu", "name": "తెలుగు"}, + "th_TH": {"name_en": "Thai", "name": "ภาษาไทย"}, + "tl_PH": {"name_en": "Filipino", "name": "Filipino"}, + "tr_TR": {"name_en": "Turkish", "name": "Türkçe"}, + "uk_UA": {"name_en": "Ukraini ", "name": "Українська"}, + "vi_VN": {"name_en": "Vietnamese", "name": "Tiếng Việt"}, + "zh_CN": {"name_en": "Chinese (Simplified)", "name": "中文(简体)"}, + "zh_TW": {"name_en": 
"Chinese (Traditional)", "name": "中文(繁體)"}, +} diff --git a/venv/lib/python3.9/site-packages/tornado/auth.py b/venv/lib/python3.9/site-packages/tornado/auth.py new file mode 100644 index 00000000..59501f56 --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/auth.py @@ -0,0 +1,1212 @@ +# +# Copyright 2009 Facebook +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""This module contains implementations of various third-party +authentication schemes. + +All the classes in this file are class mixins designed to be used with +the `tornado.web.RequestHandler` class. They are used in two ways: + +* On a login handler, use methods such as ``authenticate_redirect()``, + ``authorize_redirect()``, and ``get_authenticated_user()`` to + establish the user's identity and store authentication tokens to your + database and/or cookies. +* In non-login handlers, use methods such as ``facebook_request()`` + or ``twitter_request()`` to use the authentication tokens to make + requests to the respective services. + +They all take slightly different arguments due to the fact all these +services implement authentication and authorization slightly differently. +See the individual service classes below for complete documentation. + +Example usage for Google OAuth: + +.. testcode:: + + class GoogleOAuth2LoginHandler(tornado.web.RequestHandler, + tornado.auth.GoogleOAuth2Mixin): + async def get(self): + if self.get_argument('code', False): + user = await self.get_authenticated_user( + redirect_uri='http://your.site.com/auth/google', + code=self.get_argument('code')) + # Save the user with e.g. set_signed_cookie + else: + self.authorize_redirect( + redirect_uri='http://your.site.com/auth/google', + client_id=self.settings['google_oauth']['key'], + scope=['profile', 'email'], + response_type='code', + extra_params={'approval_prompt': 'auto'}) + +.. testoutput:: + :hide: + +""" + +import base64 +import binascii +import hashlib +import hmac +import time +import urllib.parse +import uuid + +from tornado import httpclient +from tornado import escape +from tornado.httputil import url_concat +from tornado.util import unicode_type +from tornado.web import RequestHandler + +from typing import List, Any, Dict, cast, Iterable, Union, Optional + + +class AuthError(Exception): + pass + + +class OpenIdMixin(object): + """Abstract implementation of OpenID and Attribute Exchange. + + Class attributes: + + * ``_OPENID_ENDPOINT``: the identity provider's URI. + """ + + def authenticate_redirect( + self, + callback_uri: Optional[str] = None, + ax_attrs: List[str] = ["name", "email", "language", "username"], + ) -> None: + """Redirects to the authentication URL for this service. + + After authentication, the service will redirect back to the given + callback URI with additional parameters including ``openid.mode``. + + We request the given attributes for the authenticated user by + default (name, email, language, and username). 
+        If you don't need all those attributes for your app, you can
+        request fewer with the ax_attrs keyword argument.
+
+        .. versionchanged:: 6.0
+
+            The ``callback`` argument was removed and this method no
+            longer returns an awaitable object. It is now an ordinary
+            synchronous function.
+        """
+        handler = cast(RequestHandler, self)
+        callback_uri = callback_uri or handler.request.uri
+        assert callback_uri is not None
+        args = self._openid_args(callback_uri, ax_attrs=ax_attrs)
+        endpoint = self._OPENID_ENDPOINT  # type: ignore
+        handler.redirect(endpoint + "?" + urllib.parse.urlencode(args))
+
+    async def get_authenticated_user(
+        self, http_client: Optional[httpclient.AsyncHTTPClient] = None
+    ) -> Dict[str, Any]:
+        """Fetches the authenticated user data upon redirect.
+
+        This method should be called by the handler that receives the
+        redirect from the `authenticate_redirect()` method (which is
+        often the same as the one that calls it; in that case you would
+        call `get_authenticated_user` if the ``openid.mode`` parameter
+        is present and `authenticate_redirect` if it is not).
+
+        The result of this method will generally be used to set a cookie.
+
+        .. versionchanged:: 6.0
+
+            The ``callback`` argument was removed. Use the returned
+            awaitable object instead.
+        """
+        handler = cast(RequestHandler, self)
+        # Verify the OpenID response via direct request to the OP
+        args = dict(
+            (k, v[-1]) for k, v in handler.request.arguments.items()
+        )  # type: Dict[str, Union[str, bytes]]
+        args["openid.mode"] = "check_authentication"
+        url = self._OPENID_ENDPOINT  # type: ignore
+        if http_client is None:
+            http_client = self.get_auth_http_client()
+        resp = await http_client.fetch(
+            url, method="POST", body=urllib.parse.urlencode(args)
+        )
+        return self._on_authentication_verified(resp)
+
+    def _openid_args(
+        self,
+        callback_uri: str,
+        ax_attrs: Iterable[str] = [],
+        oauth_scope: Optional[str] = None,
+    ) -> Dict[str, str]:
+        handler = cast(RequestHandler, self)
+        url = urllib.parse.urljoin(handler.request.full_url(), callback_uri)
+        args = {
+            "openid.ns": "http://specs.openid.net/auth/2.0",
+            "openid.claimed_id": "http://specs.openid.net/auth/2.0/identifier_select",
+            "openid.identity": "http://specs.openid.net/auth/2.0/identifier_select",
+            "openid.return_to": url,
+            "openid.realm": urllib.parse.urljoin(url, "/"),
+            "openid.mode": "checkid_setup",
+        }
+        if ax_attrs:
+            args.update(
+                {
+                    "openid.ns.ax": "http://openid.net/srv/ax/1.0",
+                    "openid.ax.mode": "fetch_request",
+                }
+            )
+            ax_attrs = set(ax_attrs)
+            required = []  # type: List[str]
+            if "name" in ax_attrs:
+                ax_attrs -= set(["name", "firstname", "fullname", "lastname"])
+                required += ["firstname", "fullname", "lastname"]
+                args.update(
+                    {
+                        "openid.ax.type.firstname": "http://axschema.org/namePerson/first",
+                        "openid.ax.type.fullname": "http://axschema.org/namePerson",
+                        "openid.ax.type.lastname": "http://axschema.org/namePerson/last",
+                    }
+                )
+            known_attrs = {
+                "email": "http://axschema.org/contact/email",
+                "language": "http://axschema.org/pref/language",
+                "username": "http://axschema.org/namePerson/friendly",
+            }
+            for name in ax_attrs:
+                args["openid.ax.type." + name] = known_attrs[name]
+                required.append(name)
+            args["openid.ax.required"] = ",".join(required)
+        if oauth_scope:
+            args.update(
+                {
+                    "openid.ns.oauth": "http://specs.openid.net/extensions/oauth/1.0",
+                    "openid.oauth.consumer": handler.request.host.split(":")[0],
+                    "openid.oauth.scope": oauth_scope,
+                }
+            )
+        return args
+
+    def _on_authentication_verified(
+        self, response: httpclient.HTTPResponse
+    ) -> Dict[str, Any]:
+        handler = cast(RequestHandler, self)
+        if b"is_valid:true" not in response.body:
+            raise AuthError("Invalid OpenID response: %r" % response.body)
+
+        # Make sure we got back at least an email from attribute exchange
+        ax_ns = None
+        for key in handler.request.arguments:
+            if (
+                key.startswith("openid.ns.")
+                and handler.get_argument(key) == "http://openid.net/srv/ax/1.0"
+            ):
+                ax_ns = key[10:]
+                break
+
+        def get_ax_arg(uri: str) -> str:
+            if not ax_ns:
+                return ""
+            prefix = "openid." + ax_ns + ".type."
+            ax_name = None
+            for name in handler.request.arguments.keys():
+                if handler.get_argument(name) == uri and name.startswith(prefix):
+                    part = name[len(prefix) :]
+                    ax_name = "openid." + ax_ns + ".value." + part
+                    break
+            if not ax_name:
+                return ""
+            return handler.get_argument(ax_name, "")
+
+        email = get_ax_arg("http://axschema.org/contact/email")
+        name = get_ax_arg("http://axschema.org/namePerson")
+        first_name = get_ax_arg("http://axschema.org/namePerson/first")
+        last_name = get_ax_arg("http://axschema.org/namePerson/last")
+        username = get_ax_arg("http://axschema.org/namePerson/friendly")
+        locale = get_ax_arg("http://axschema.org/pref/language").lower()
+        user = dict()
+        name_parts = []
+        if first_name:
+            user["first_name"] = first_name
+            name_parts.append(first_name)
+        if last_name:
+            user["last_name"] = last_name
+            name_parts.append(last_name)
+        if name:
+            user["name"] = name
+        elif name_parts:
+            user["name"] = " ".join(name_parts)
+        elif email:
+            user["name"] = email.split("@")[0]
+        if email:
+            user["email"] = email
+        if locale:
+            user["locale"] = locale
+        if username:
+            user["username"] = username
+        claimed_id = handler.get_argument("openid.claimed_id", None)
+        if claimed_id:
+            user["claimed_id"] = claimed_id
+        return user
+
+    def get_auth_http_client(self) -> httpclient.AsyncHTTPClient:
+        """Returns the `.AsyncHTTPClient` instance to be used for auth requests.
+
+        May be overridden by subclasses to use an HTTP client other than
+        the default.
+        """
+        return httpclient.AsyncHTTPClient()
+
+
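The docstrings above describe a two-step flow: redirect to the provider, then handle the redirect back by checking for ``openid.mode``. A minimal sketch of a login handler wiring the two calls together (the endpoint URL and handler name are hypothetical):

    import tornado.auth
    import tornado.web

    class OpenIdLoginHandler(tornado.web.RequestHandler, tornado.auth.OpenIdMixin):
        _OPENID_ENDPOINT = "https://openid.example.com/auth"  # hypothetical provider

        async def get(self):
            if self.get_argument("openid.mode", None):
                # Second leg: the provider redirected back to us.
                user = await self.get_authenticated_user()
                # Persist `user` (e.g. in a signed cookie), then redirect onward.
            else:
                # First leg: send the user to the provider.
                self.authenticate_redirect()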
+class OAuthMixin(object):
+    """Abstract implementation of OAuth 1.0 and 1.0a.
+
+    See `TwitterMixin` below for an example implementation.
+
+    Class attributes:
+
+    * ``_OAUTH_AUTHORIZE_URL``: The service's OAuth authorization url.
+    * ``_OAUTH_ACCESS_TOKEN_URL``: The service's OAuth access token url.
+    * ``_OAUTH_VERSION``: May be either "1.0" or "1.0a".
+    * ``_OAUTH_NO_CALLBACKS``: Set this to True if the service requires
+      advance registration of callbacks.
+
+    Subclasses must also override the `_oauth_get_user_future` and
+    `_oauth_consumer_token` methods.
+    """
+
+    async def authorize_redirect(
+        self,
+        callback_uri: Optional[str] = None,
+        extra_params: Optional[Dict[str, Any]] = None,
+        http_client: Optional[httpclient.AsyncHTTPClient] = None,
+    ) -> None:
+        """Redirects the user to obtain OAuth authorization for this service.
+
+        The ``callback_uri`` may be omitted if you have previously
+        registered a callback URI with the third-party service. For
+        some services, you must use a previously-registered callback
+        URI and cannot specify a callback via this method.
+
+        This method sets a cookie called ``_oauth_request_token`` which is
+        subsequently used (and cleared) in `get_authenticated_user` for
+        security purposes.
+
+        This method is asynchronous and must be called with ``await``
+        or ``yield`` (This is different from other ``auth*_redirect``
+        methods defined in this module). It calls
+        `.RequestHandler.finish` for you so you should not write any
+        other response after it returns.
+
+        .. versionchanged:: 3.1
+           Now returns a `.Future` and takes an optional callback, for
+           compatibility with `.gen.coroutine`.
+
+        .. versionchanged:: 6.0
+
+            The ``callback`` argument was removed. Use the returned
+            awaitable object instead.
+
+        """
+        if callback_uri and getattr(self, "_OAUTH_NO_CALLBACKS", False):
+            raise Exception("This service does not support oauth_callback")
+        if http_client is None:
+            http_client = self.get_auth_http_client()
+        assert http_client is not None
+        if getattr(self, "_OAUTH_VERSION", "1.0a") == "1.0a":
+            response = await http_client.fetch(
+                self._oauth_request_token_url(
+                    callback_uri=callback_uri, extra_params=extra_params
+                )
+            )
+        else:
+            response = await http_client.fetch(self._oauth_request_token_url())
+        url = self._OAUTH_AUTHORIZE_URL  # type: ignore
+        self._on_request_token(url, callback_uri, response)
+
+    async def get_authenticated_user(
+        self, http_client: Optional[httpclient.AsyncHTTPClient] = None
+    ) -> Dict[str, Any]:
+        """Gets the OAuth authorized user and access token.
+
+        This method should be called from the handler for your
+        OAuth callback URL to complete the registration process. We run the
+        callback with the authenticated user dictionary. This dictionary
+        will contain an ``access_key`` which can be used to make authorized
+        requests to this service on behalf of the user. The dictionary will
+        also contain other fields such as ``name``, depending on the service
+        used.
+
+        .. versionchanged:: 6.0
+
+            The ``callback`` argument was removed. Use the returned
+            awaitable object instead.
+ """ + handler = cast(RequestHandler, self) + request_key = escape.utf8(handler.get_argument("oauth_token")) + oauth_verifier = handler.get_argument("oauth_verifier", None) + request_cookie = handler.get_cookie("_oauth_request_token") + if not request_cookie: + raise AuthError("Missing OAuth request token cookie") + handler.clear_cookie("_oauth_request_token") + cookie_key, cookie_secret = [ + base64.b64decode(escape.utf8(i)) for i in request_cookie.split("|") + ] + if cookie_key != request_key: + raise AuthError("Request token does not match cookie") + token = dict( + key=cookie_key, secret=cookie_secret + ) # type: Dict[str, Union[str, bytes]] + if oauth_verifier: + token["verifier"] = oauth_verifier + if http_client is None: + http_client = self.get_auth_http_client() + assert http_client is not None + response = await http_client.fetch(self._oauth_access_token_url(token)) + access_token = _oauth_parse_response(response.body) + user = await self._oauth_get_user_future(access_token) + if not user: + raise AuthError("Error getting user") + user["access_token"] = access_token + return user + + def _oauth_request_token_url( + self, + callback_uri: Optional[str] = None, + extra_params: Optional[Dict[str, Any]] = None, + ) -> str: + handler = cast(RequestHandler, self) + consumer_token = self._oauth_consumer_token() + url = self._OAUTH_REQUEST_TOKEN_URL # type: ignore + args = dict( + oauth_consumer_key=escape.to_basestring(consumer_token["key"]), + oauth_signature_method="HMAC-SHA1", + oauth_timestamp=str(int(time.time())), + oauth_nonce=escape.to_basestring(binascii.b2a_hex(uuid.uuid4().bytes)), + oauth_version="1.0", + ) + if getattr(self, "_OAUTH_VERSION", "1.0a") == "1.0a": + if callback_uri == "oob": + args["oauth_callback"] = "oob" + elif callback_uri: + args["oauth_callback"] = urllib.parse.urljoin( + handler.request.full_url(), callback_uri + ) + if extra_params: + args.update(extra_params) + signature = _oauth10a_signature(consumer_token, "GET", url, args) + else: + signature = _oauth_signature(consumer_token, "GET", url, args) + + args["oauth_signature"] = signature + return url + "?" + urllib.parse.urlencode(args) + + def _on_request_token( + self, + authorize_url: str, + callback_uri: Optional[str], + response: httpclient.HTTPResponse, + ) -> None: + handler = cast(RequestHandler, self) + request_token = _oauth_parse_response(response.body) + data = ( + base64.b64encode(escape.utf8(request_token["key"])) + + b"|" + + base64.b64encode(escape.utf8(request_token["secret"])) + ) + handler.set_cookie("_oauth_request_token", data) + args = dict(oauth_token=request_token["key"]) + if callback_uri == "oob": + handler.finish(authorize_url + "?" + urllib.parse.urlencode(args)) + return + elif callback_uri: + args["oauth_callback"] = urllib.parse.urljoin( + handler.request.full_url(), callback_uri + ) + handler.redirect(authorize_url + "?" 
+
+    def _oauth_access_token_url(self, request_token: Dict[str, Any]) -> str:
+        consumer_token = self._oauth_consumer_token()
+        url = self._OAUTH_ACCESS_TOKEN_URL  # type: ignore
+        args = dict(
+            oauth_consumer_key=escape.to_basestring(consumer_token["key"]),
+            oauth_token=escape.to_basestring(request_token["key"]),
+            oauth_signature_method="HMAC-SHA1",
+            oauth_timestamp=str(int(time.time())),
+            oauth_nonce=escape.to_basestring(binascii.b2a_hex(uuid.uuid4().bytes)),
+            oauth_version="1.0",
+        )
+        if "verifier" in request_token:
+            args["oauth_verifier"] = request_token["verifier"]
+
+        if getattr(self, "_OAUTH_VERSION", "1.0a") == "1.0a":
+            signature = _oauth10a_signature(
+                consumer_token, "GET", url, args, request_token
+            )
+        else:
+            signature = _oauth_signature(
+                consumer_token, "GET", url, args, request_token
+            )
+
+        args["oauth_signature"] = signature
+        return url + "?" + urllib.parse.urlencode(args)
+
+    def _oauth_consumer_token(self) -> Dict[str, Any]:
+        """Subclasses must override this to return their OAuth consumer keys.
+
+        The return value should be a `dict` with keys ``key`` and ``secret``.
+        """
+        raise NotImplementedError()
+
+    async def _oauth_get_user_future(
+        self, access_token: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        """Subclasses must override this to get basic information about the
+        user.
+
+        Should be a coroutine whose result is a dictionary
+        containing information about the user, which may have been
+        retrieved by using ``access_token`` to make a request to the
+        service.
+
+        The access token will be added to the returned dictionary to make
+        the result of `get_authenticated_user`.
+
+        .. versionchanged:: 5.1
+
+           Subclasses may also define this method with ``async def``.
+
+        .. versionchanged:: 6.0
+
+           A synchronous fallback to ``_oauth_get_user`` was removed.
+        """
+        raise NotImplementedError()
+
+    def _oauth_request_parameters(
+        self,
+        url: str,
+        access_token: Dict[str, Any],
+        parameters: Dict[str, Any] = {},
+        method: str = "GET",
+    ) -> Dict[str, Any]:
+        """Returns the OAuth parameters as a dict for the given request.
+
+        parameters should include all POST arguments and query string arguments
+        that will be sent with the request.
+        """
+        consumer_token = self._oauth_consumer_token()
+        base_args = dict(
+            oauth_consumer_key=escape.to_basestring(consumer_token["key"]),
+            oauth_token=escape.to_basestring(access_token["key"]),
+            oauth_signature_method="HMAC-SHA1",
+            oauth_timestamp=str(int(time.time())),
+            oauth_nonce=escape.to_basestring(binascii.b2a_hex(uuid.uuid4().bytes)),
+            oauth_version="1.0",
+        )
+        args = {}
+        args.update(base_args)
+        args.update(parameters)
+        if getattr(self, "_OAUTH_VERSION", "1.0a") == "1.0a":
+            signature = _oauth10a_signature(
+                consumer_token, method, url, args, access_token
+            )
+        else:
+            signature = _oauth_signature(
+                consumer_token, method, url, args, access_token
+            )
+        base_args["oauth_signature"] = escape.to_basestring(signature)
+        return base_args
+
+    def get_auth_http_client(self) -> httpclient.AsyncHTTPClient:
+        """Returns the `.AsyncHTTPClient` instance to be used for auth requests.
+
+        May be overridden by subclasses to use an HTTP client other than
+        the default.
+        """
+        return httpclient.AsyncHTTPClient()
+
+
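The two `NotImplementedError` methods above are the entire subclass contract for `OAuthMixin`. A minimal sketch of a provider mixin (all URLs, keys, and the returned user shape are hypothetical; a real mixin would read its consumer token from application settings and fetch the user over HTTP, as `TwitterMixin` does below):

    class ExampleOAuth1Mixin(tornado.auth.OAuthMixin):
        _OAUTH_REQUEST_TOKEN_URL = "https://api.example.com/oauth/request_token"
        _OAUTH_ACCESS_TOKEN_URL = "https://api.example.com/oauth/access_token"
        _OAUTH_AUTHORIZE_URL = "https://api.example.com/oauth/authorize"
        _OAUTH_VERSION = "1.0a"

        def _oauth_consumer_token(self):
            # Normally pulled from self.settings; hard-coded for the sketch.
            return dict(key="my_consumer_key", secret="my_consumer_secret")

        async def _oauth_get_user_future(self, access_token):
            # Fetch basic profile data with the signed token; the shape of
            # the returned dict is provider-specific.
            return {"username": "example_user"}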
+class OAuth2Mixin(object):
+    """Abstract implementation of OAuth 2.0.
+
+    See `FacebookGraphMixin` or `GoogleOAuth2Mixin` below for example
+    implementations.
+
+    Class attributes:
+
+    * ``_OAUTH_AUTHORIZE_URL``: The service's authorization url.
+    * ``_OAUTH_ACCESS_TOKEN_URL``: The service's access token url.
+    """
+
+    def authorize_redirect(
+        self,
+        redirect_uri: Optional[str] = None,
+        client_id: Optional[str] = None,
+        client_secret: Optional[str] = None,
+        extra_params: Optional[Dict[str, Any]] = None,
+        scope: Optional[List[str]] = None,
+        response_type: str = "code",
+    ) -> None:
+        """Redirects the user to obtain OAuth authorization for this service.
+
+        Some providers require that you register a redirect URL with
+        your application instead of passing one via this method. You
+        should call this method to log the user in, and then call
+        ``get_authenticated_user`` in the handler for your
+        redirect URL to complete the authorization process.
+
+        .. versionchanged:: 6.0
+
+           The ``callback`` argument and returned awaitable were removed;
+           this is now an ordinary synchronous function.
+        """
+        handler = cast(RequestHandler, self)
+        args = {"response_type": response_type}
+        if redirect_uri is not None:
+            args["redirect_uri"] = redirect_uri
+        if client_id is not None:
+            args["client_id"] = client_id
+        if extra_params:
+            args.update(extra_params)
+        if scope:
+            args["scope"] = " ".join(scope)
+        url = self._OAUTH_AUTHORIZE_URL  # type: ignore
+        handler.redirect(url_concat(url, args))
+
+    def _oauth_request_token_url(
+        self,
+        redirect_uri: Optional[str] = None,
+        client_id: Optional[str] = None,
+        client_secret: Optional[str] = None,
+        code: Optional[str] = None,
+        extra_params: Optional[Dict[str, Any]] = None,
+    ) -> str:
+        url = self._OAUTH_ACCESS_TOKEN_URL  # type: ignore
+        args = {}  # type: Dict[str, str]
+        if redirect_uri is not None:
+            args["redirect_uri"] = redirect_uri
+        if code is not None:
+            args["code"] = code
+        if client_id is not None:
+            args["client_id"] = client_id
+        if client_secret is not None:
+            args["client_secret"] = client_secret
+        if extra_params:
+            args.update(extra_params)
+        return url_concat(url, args)
+
+    async def oauth2_request(
+        self,
+        url: str,
+        access_token: Optional[str] = None,
+        post_args: Optional[Dict[str, Any]] = None,
+        **args: Any
+    ) -> Any:
+        """Fetches the given URL with an OAuth2 access token.
+
+        If the request is a POST, ``post_args`` should be provided. Query
+        string arguments should be given as keyword arguments.
+
+        Example usage:
+
+        .. testcode::
+
+            class MainHandler(tornado.web.RequestHandler,
+                              tornado.auth.FacebookGraphMixin):
+                @tornado.web.authenticated
+                async def get(self):
+                    new_entry = await self.oauth2_request(
+                        "https://graph.facebook.com/me/feed",
+                        post_args={"message": "I am posting from my Tornado application!"},
+                        access_token=self.current_user["access_token"])
+
+                    if not new_entry:
+                        # Call failed; perhaps missing permission?
+                        self.authorize_redirect()
+                        return
+                    self.finish("Posted a message!")
+
+        .. testoutput::
+           :hide:
+
+        .. versionadded:: 4.3
+
+        .. versionchanged:: 6.0
+
+           The ``callback`` argument was removed. Use the returned awaitable object instead.
+        """
+        all_args = {}
+        if access_token:
+            all_args["access_token"] = access_token
+        all_args.update(args)
+
+        if all_args:
+            url += "?" + urllib.parse.urlencode(all_args)
+        http = self.get_auth_http_client()
+        if post_args is not None:
+            response = await http.fetch(
+                url, method="POST", body=urllib.parse.urlencode(post_args)
+            )
+        else:
+            response = await http.fetch(url)
+        return escape.json_decode(response.body)
+
+    def get_auth_http_client(self) -> httpclient.AsyncHTTPClient:
+        """Returns the `.AsyncHTTPClient` instance to be used for auth requests.
+
+        May be overridden by subclasses to use an HTTP client other than
+        the default.
+
+        .. versionadded:: 4.3
+        """
+        return httpclient.AsyncHTTPClient()
+
+
+class TwitterMixin(OAuthMixin):
+    """Twitter OAuth authentication.
+
+    To authenticate with Twitter, register your application with
+    Twitter at http://twitter.com/apps. Then copy your Consumer Key
+    and Consumer Secret to the application
+    `~tornado.web.Application.settings` ``twitter_consumer_key`` and
+    ``twitter_consumer_secret``. Use this mixin on the handler for the
+    URL you registered as your application's callback URL.
+
+    When your application is set up, you can use this mixin like this
+    to authenticate the user with Twitter and get access to their stream:
+
+    .. testcode::
+
+        class TwitterLoginHandler(tornado.web.RequestHandler,
+                                  tornado.auth.TwitterMixin):
+            async def get(self):
+                if self.get_argument("oauth_token", None):
+                    user = await self.get_authenticated_user()
+                    # Save the user using e.g. set_signed_cookie()
+                else:
+                    await self.authorize_redirect()
+
+    .. testoutput::
+       :hide:
+
+    The user object returned by `~OAuthMixin.get_authenticated_user`
+    includes the attributes ``username``, ``name``, ``access_token``,
+    and all of the custom Twitter user attributes described at
+    https://dev.twitter.com/docs/api/1.1/get/users/show
+    """
+
+    _OAUTH_REQUEST_TOKEN_URL = "https://api.twitter.com/oauth/request_token"
+    _OAUTH_ACCESS_TOKEN_URL = "https://api.twitter.com/oauth/access_token"
+    _OAUTH_AUTHORIZE_URL = "https://api.twitter.com/oauth/authorize"
+    _OAUTH_AUTHENTICATE_URL = "https://api.twitter.com/oauth/authenticate"
+    _OAUTH_NO_CALLBACKS = False
+    _TWITTER_BASE_URL = "https://api.twitter.com/1.1"
+
+    async def authenticate_redirect(self, callback_uri: Optional[str] = None) -> None:
+        """Just like `~OAuthMixin.authorize_redirect`, but
+        auto-redirects if authorized.
+
+        This is generally the right interface to use if you are using
+        Twitter for single-sign on.
+
+        .. versionchanged:: 3.1
+           Now returns a `.Future` and takes an optional callback, for
+           compatibility with `.gen.coroutine`.
+
+        .. versionchanged:: 6.0
+
+           The ``callback`` argument was removed. Use the returned
+           awaitable object instead.
+        """
+        http = self.get_auth_http_client()
+        response = await http.fetch(
+            self._oauth_request_token_url(callback_uri=callback_uri)
+        )
+        self._on_request_token(self._OAUTH_AUTHENTICATE_URL, None, response)
+
+    async def twitter_request(
+        self,
+        path: str,
+        access_token: Dict[str, Any],
+        post_args: Optional[Dict[str, Any]] = None,
+        **args: Any
+    ) -> Any:
+        """Fetches the given API path, e.g., ``statuses/user_timeline/btaylor``
+
+        The path should not include the format or API version number.
+        (we automatically use JSON format and API version 1).
+
+        If the request is a POST, ``post_args`` should be provided. Query
+        string arguments should be given as keyword arguments.
+
+        All the Twitter methods are documented at http://dev.twitter.com/
+
+        Many methods require an OAuth access token which you can
+        obtain through `~OAuthMixin.authorize_redirect` and
+        `~OAuthMixin.get_authenticated_user`. The user returned through that
+        process includes an 'access_token' attribute that can be used
+        to make authenticated requests via this method. Example
+        usage:
+
+        .. testcode::
+
+            class MainHandler(tornado.web.RequestHandler,
+                              tornado.auth.TwitterMixin):
+                @tornado.web.authenticated
+                async def get(self):
+                    new_entry = await self.twitter_request(
+                        "/statuses/update",
+                        post_args={"status": "Testing Tornado Web Server"},
+                        access_token=self.current_user["access_token"])
+                    if not new_entry:
+                        # Call failed; perhaps missing permission?
+                        await self.authorize_redirect()
+                        return
+                    self.finish("Posted a message!")
+
+        .. testoutput::
+           :hide:
+
+        .. versionchanged:: 6.0
+
+           The ``callback`` argument was removed. Use the returned
+           awaitable object instead.
+        """
+        if path.startswith("http:") or path.startswith("https:"):
+            # Raw urls are useful for e.g. search which doesn't follow the
+            # usual pattern: http://search.twitter.com/search.json
+            url = path
+        else:
+            url = self._TWITTER_BASE_URL + path + ".json"
+        # Add the OAuth resource request signature if we have credentials
+        if access_token:
+            all_args = {}
+            all_args.update(args)
+            all_args.update(post_args or {})
+            method = "POST" if post_args is not None else "GET"
+            oauth = self._oauth_request_parameters(
+                url, access_token, all_args, method=method
+            )
+            args.update(oauth)
+        if args:
+            url += "?" + urllib.parse.urlencode(args)
+        http = self.get_auth_http_client()
+        if post_args is not None:
+            response = await http.fetch(
+                url, method="POST", body=urllib.parse.urlencode(post_args)
+            )
+        else:
+            response = await http.fetch(url)
+        return escape.json_decode(response.body)
+
+    def _oauth_consumer_token(self) -> Dict[str, Any]:
+        handler = cast(RequestHandler, self)
+        handler.require_setting("twitter_consumer_key", "Twitter OAuth")
+        handler.require_setting("twitter_consumer_secret", "Twitter OAuth")
+        return dict(
+            key=handler.settings["twitter_consumer_key"],
+            secret=handler.settings["twitter_consumer_secret"],
+        )
+
+    async def _oauth_get_user_future(
+        self, access_token: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        user = await self.twitter_request(
+            "/account/verify_credentials", access_token=access_token
+        )
+        if user:
+            user["username"] = user["screen_name"]
+        return user
+
+
+class GoogleOAuth2Mixin(OAuth2Mixin):
+    """Google authentication using OAuth2.
+
+    In order to use, register your application with Google and copy the
+    relevant parameters to your application settings.
+
+    * Go to the Google Dev Console at http://console.developers.google.com
+    * Select a project, or create a new one.
+    * In the sidebar on the left, select Credentials.
+    * Click CREATE CREDENTIALS and click OAuth client ID.
+    * Under Application type, select Web application.
+    * Name OAuth 2.0 client and click Create.
+    * Copy the "Client secret" and "Client ID" to the application settings as
+      ``{"google_oauth": {"key": CLIENT_ID, "secret": CLIENT_SECRET}}``
+
+    .. versionadded:: 3.2
+    """
+
+    _OAUTH_AUTHORIZE_URL = "https://accounts.google.com/o/oauth2/v2/auth"
+    _OAUTH_ACCESS_TOKEN_URL = "https://www.googleapis.com/oauth2/v4/token"
+    _OAUTH_USERINFO_URL = "https://www.googleapis.com/oauth2/v1/userinfo"
+    _OAUTH_NO_CALLBACKS = False
+    _OAUTH_SETTINGS_KEY = "google_oauth"
+
+    def get_google_oauth_settings(self) -> Dict[str, str]:
+        """Return the Google OAuth 2.0 credentials that you created with
+        [Google Cloud
+        Platform](https://console.cloud.google.com/apis/credentials). The dict
+        format is::
+
+            {
+                "key": "your_client_id", "secret": "your_client_secret"
+            }
+
+        If your credentials are stored differently (e.g. in a db) you can
+        override this method for custom provision.
+ """ + handler = cast(RequestHandler, self) + return handler.settings[self._OAUTH_SETTINGS_KEY] + + async def get_authenticated_user( + self, + redirect_uri: str, + code: str, + client_id: Optional[str] = None, + client_secret: Optional[str] = None, + ) -> Dict[str, Any]: + """Handles the login for the Google user, returning an access token. + + The result is a dictionary containing an ``access_token`` field + ([among others](https://developers.google.com/identity/protocols/OAuth2WebServer#handlingtheresponse)). + Unlike other ``get_authenticated_user`` methods in this package, + this method does not return any additional information about the user. + The returned access token can be used with `OAuth2Mixin.oauth2_request` + to request additional information (perhaps from + ``https://www.googleapis.com/oauth2/v2/userinfo``) + + Example usage: + + .. testcode:: + + class GoogleOAuth2LoginHandler(tornado.web.RequestHandler, + tornado.auth.GoogleOAuth2Mixin): + async def get(self): + if self.get_argument('code', False): + access = await self.get_authenticated_user( + redirect_uri='http://your.site.com/auth/google', + code=self.get_argument('code')) + user = await self.oauth2_request( + "https://www.googleapis.com/oauth2/v1/userinfo", + access_token=access["access_token"]) + # Save the user and access token with + # e.g. set_signed_cookie. + else: + self.authorize_redirect( + redirect_uri='http://your.site.com/auth/google', + client_id=self.get_google_oauth_settings()['key'], + scope=['profile', 'email'], + response_type='code', + extra_params={'approval_prompt': 'auto'}) + + .. testoutput:: + :hide: + + .. versionchanged:: 6.0 + + The ``callback`` argument was removed. Use the returned awaitable object instead. + """ # noqa: E501 + + if client_id is None or client_secret is None: + settings = self.get_google_oauth_settings() + if client_id is None: + client_id = settings["key"] + if client_secret is None: + client_secret = settings["secret"] + http = self.get_auth_http_client() + body = urllib.parse.urlencode( + { + "redirect_uri": redirect_uri, + "code": code, + "client_id": client_id, + "client_secret": client_secret, + "grant_type": "authorization_code", + } + ) + + response = await http.fetch( + self._OAUTH_ACCESS_TOKEN_URL, + method="POST", + headers={"Content-Type": "application/x-www-form-urlencoded"}, + body=body, + ) + return escape.json_decode(response.body) + + +class FacebookGraphMixin(OAuth2Mixin): + """Facebook authentication using the new Graph API and OAuth2.""" + + _OAUTH_ACCESS_TOKEN_URL = "https://graph.facebook.com/oauth/access_token?" + _OAUTH_AUTHORIZE_URL = "https://www.facebook.com/dialog/oauth?" + _OAUTH_NO_CALLBACKS = False + _FACEBOOK_BASE_URL = "https://graph.facebook.com" + + async def get_authenticated_user( + self, + redirect_uri: str, + client_id: str, + client_secret: str, + code: str, + extra_fields: Optional[Dict[str, Any]] = None, + ) -> Optional[Dict[str, Any]]: + """Handles the login for the Facebook user, returning a user object. + + Example usage: + + .. testcode:: + + class FacebookGraphLoginHandler(tornado.web.RequestHandler, + tornado.auth.FacebookGraphMixin): + async def get(self): + if self.get_argument("code", False): + user = await self.get_authenticated_user( + redirect_uri='/auth/facebookgraph/', + client_id=self.settings["facebook_api_key"], + client_secret=self.settings["facebook_secret"], + code=self.get_argument("code")) + # Save the user with e.g. 
+                    else:
+                        self.authorize_redirect(
+                            redirect_uri='/auth/facebookgraph/',
+                            client_id=self.settings["facebook_api_key"],
+                            extra_params={"scope": "read_stream,offline_access"})
+
+        .. testoutput::
+           :hide:
+
+        This method returns a dictionary which may contain the following fields:
+
+        * ``access_token``, a string which may be passed to `facebook_request`
+        * ``session_expires``, an integer encoded as a string representing
+          the time until the access token expires in seconds. This field should
+          be used like ``int(user['session_expires'])``; in a future version of
+          Tornado it will change from a string to an integer.
+        * ``id``, ``name``, ``first_name``, ``last_name``, ``locale``, ``picture``,
+          ``link``, plus any fields named in the ``extra_fields`` argument. These
+          fields are copied from the Facebook graph API
+          `user object <https://developers.facebook.com/docs/graph-api/reference/user>`_
+
+        .. versionchanged:: 4.5
+           The ``session_expires`` field was updated to support changes made to the
+           Facebook API in March 2017.
+
+        .. versionchanged:: 6.0
+
+           The ``callback`` argument was removed. Use the returned awaitable object instead.
+        """
+        http = self.get_auth_http_client()
+        args = {
+            "redirect_uri": redirect_uri,
+            "code": code,
+            "client_id": client_id,
+            "client_secret": client_secret,
+        }
+
+        fields = set(
+            ["id", "name", "first_name", "last_name", "locale", "picture", "link"]
+        )
+        if extra_fields:
+            fields.update(extra_fields)
+
+        response = await http.fetch(
+            self._oauth_request_token_url(**args)  # type: ignore
+        )
+        args = escape.json_decode(response.body)
+        session = {
+            "access_token": args.get("access_token"),
+            "expires_in": args.get("expires_in"),
+        }
+        assert session["access_token"] is not None
+
+        user = await self.facebook_request(
+            path="/me",
+            access_token=session["access_token"],
+            appsecret_proof=hmac.new(
+                key=client_secret.encode("utf8"),
+                msg=session["access_token"].encode("utf8"),
+                digestmod=hashlib.sha256,
+            ).hexdigest(),
+            fields=",".join(fields),
+        )
+
+        if user is None:
+            return None
+
+        fieldmap = {}
+        for field in fields:
+            fieldmap[field] = user.get(field)
+
+        # session_expires is converted to str for compatibility with
+        # older versions in which the server used url-encoding and
+        # this code simply returned the string verbatim.
+        # This should change in Tornado 5.0.
+        fieldmap.update(
+            {
+                "access_token": session["access_token"],
+                "session_expires": str(session.get("expires_in")),
+            }
+        )
+        return fieldmap
+
+    async def facebook_request(
+        self,
+        path: str,
+        access_token: Optional[str] = None,
+        post_args: Optional[Dict[str, Any]] = None,
+        **args: Any
+    ) -> Any:
+        """Fetches the given relative API path, e.g., "/btaylor/picture"
+
+        If the request is a POST, ``post_args`` should be provided. Query
+        string arguments should be given as keyword arguments.
+
+        An introduction to the Facebook Graph API can be found at
+        http://developers.facebook.com/docs/api
+
+        Many methods require an OAuth access token which you can
+        obtain through `~OAuth2Mixin.authorize_redirect` and
+        `get_authenticated_user`. The user returned through that
+        process includes an ``access_token`` attribute that can be
+        used to make authenticated requests via this method.
+
+        Example usage:
+
+        .. testcode::
+
+            class MainHandler(tornado.web.RequestHandler,
+                              tornado.auth.FacebookGraphMixin):
+                @tornado.web.authenticated
+                async def get(self):
+                    new_entry = await self.facebook_request(
+                        "/me/feed",
+                        post_args={"message": "I am posting from my Tornado application!"},
+                        access_token=self.current_user["access_token"])
+
+                    if not new_entry:
+                        # Call failed; perhaps missing permission?
+                        self.authorize_redirect()
+                        return
+                    self.finish("Posted a message!")
+
+        .. testoutput::
+           :hide:
+
+        The given path is relative to ``self._FACEBOOK_BASE_URL``,
+        by default "https://graph.facebook.com".
+
+        This method is a wrapper around `OAuth2Mixin.oauth2_request`;
+        the only difference is that this method takes a relative path,
+        while ``oauth2_request`` takes a complete url.
+
+        .. versionchanged:: 3.1
+           Added the ability to override ``self._FACEBOOK_BASE_URL``.
+
+        .. versionchanged:: 6.0
+
+           The ``callback`` argument was removed. Use the returned awaitable object instead.
+        """
+        url = self._FACEBOOK_BASE_URL + path
+        return await self.oauth2_request(
+            url, access_token=access_token, post_args=post_args, **args
+        )
+
+
+def _oauth_signature(
+    consumer_token: Dict[str, Any],
+    method: str,
+    url: str,
+    parameters: Dict[str, Any] = {},
+    token: Optional[Dict[str, Any]] = None,
+) -> bytes:
+    """Calculates the HMAC-SHA1 OAuth signature for the given request.
+
+    See http://oauth.net/core/1.0/#signing_process
+    """
+    parts = urllib.parse.urlparse(url)
+    scheme, netloc, path = parts[:3]
+    normalized_url = scheme.lower() + "://" + netloc.lower() + path
+
+    base_elems = []
+    base_elems.append(method.upper())
+    base_elems.append(normalized_url)
+    base_elems.append(
+        "&".join(
+            "%s=%s" % (k, _oauth_escape(str(v))) for k, v in sorted(parameters.items())
+        )
+    )
+    base_string = "&".join(_oauth_escape(e) for e in base_elems)
+
+    key_elems = [escape.utf8(consumer_token["secret"])]
+    key_elems.append(escape.utf8(token["secret"] if token else ""))
+    key = b"&".join(key_elems)
+
+    hash = hmac.new(key, escape.utf8(base_string), hashlib.sha1)
+    return binascii.b2a_base64(hash.digest())[:-1]
+
+
+def _oauth10a_signature(
+    consumer_token: Dict[str, Any],
+    method: str,
+    url: str,
+    parameters: Dict[str, Any] = {},
+    token: Optional[Dict[str, Any]] = None,
+) -> bytes:
+    """Calculates the HMAC-SHA1 OAuth 1.0a signature for the given request.
+
+    See http://oauth.net/core/1.0a/#signing_process
+    """
+    parts = urllib.parse.urlparse(url)
+    scheme, netloc, path = parts[:3]
+    normalized_url = scheme.lower() + "://" + netloc.lower() + path
+
+    base_elems = []
+    base_elems.append(method.upper())
+    base_elems.append(normalized_url)
+    base_elems.append(
+        "&".join(
+            "%s=%s" % (k, _oauth_escape(str(v))) for k, v in sorted(parameters.items())
+        )
+    )
+
+    base_string = "&".join(_oauth_escape(e) for e in base_elems)
+    key_elems = [escape.utf8(urllib.parse.quote(consumer_token["secret"], safe="~"))]
+    key_elems.append(
+        escape.utf8(urllib.parse.quote(token["secret"], safe="~") if token else "")
+    )
+    key = b"&".join(key_elems)
+
+    hash = hmac.new(key, escape.utf8(base_string), hashlib.sha1)
+    return binascii.b2a_base64(hash.digest())[:-1]
+
+
+def _oauth_escape(val: Union[str, bytes]) -> str:
+    if isinstance(val, unicode_type):
+        val = val.encode("utf-8")
+    return urllib.parse.quote(val, safe="~")
+
+
+def _oauth_parse_response(body: bytes) -> Dict[str, Any]:
+    # I can't find an officially-defined encoding for oauth responses and
+    # have never seen anyone use non-ascii.  Leave the response in a byte
+    # string for python 2, and use utf8 on python 3.
+    body_str = escape.native_str(body)
+    p = urllib.parse.parse_qs(body_str, keep_blank_values=False)
+    token = dict(key=p["oauth_token"][0], secret=p["oauth_token_secret"][0])
+
+    # Add the extra parameters the Provider included to the token
+    special = ("oauth_token", "oauth_token_secret")
+    token.update((k, p[k][0]) for k in p if k not in special)
+    return token
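The `_oauth_signature`/`_oauth10a_signature` helpers above implement the OAuth 1.0/1.0a HMAC-SHA1 signing process linked in their docstrings. A small worked call (these are private helpers, and all token values here are made-up):

    from tornado.auth import _oauth10a_signature

    consumer = {"key": "ck", "secret": "cs"}
    token = {"key": "tk", "secret": "ts"}
    sig = _oauth10a_signature(
        consumer,
        "GET",
        "https://api.example.com/resource",
        {"oauth_nonce": "abc", "oauth_timestamp": "1234567890"},
        token,
    )
    print(sig)  # base64-encoded HMAC-SHA1 digest, as bytes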
diff --git a/venv/lib/python3.9/site-packages/tornado/autoreload.py b/venv/lib/python3.9/site-packages/tornado/autoreload.py
new file mode 100644
index 00000000..0ac44966
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/tornado/autoreload.py
@@ -0,0 +1,360 @@
+#
+# Copyright 2009 Facebook
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""Automatically restart the server when a source file is modified.
+
+Most applications should not access this module directly. Instead,
+pass the keyword argument ``autoreload=True`` to the
+`tornado.web.Application` constructor (or ``debug=True``, which
+enables this setting and several others). This will enable autoreload
+mode as well as checking for changes to templates and static
+resources. Note that restarting is a destructive operation and any
+requests in progress will be aborted when the process restarts. (If
+you want to disable autoreload while using other debug-mode features,
+pass both ``debug=True`` and ``autoreload=False``).
+
+This module can also be used as a command-line wrapper around scripts
+such as unit test runners. See the `main` method for details.
+
+The command-line wrapper and Application debug modes can be used together.
+This combination is encouraged as the wrapper catches syntax errors and
+other import-time failures, while debug mode catches changes once
+the server has started.
+
+This module will not work correctly when `.HTTPServer`'s multi-process
+mode is used.
+
+Reloading loses any Python interpreter command-line arguments (e.g. ``-u``)
+because it re-executes Python using ``sys.executable`` and ``sys.argv``.
+Additionally, modifying these variables will cause reloading to behave
+incorrectly.
+
+"""
+
+import os
+import sys
+
+# sys.path handling
+# -----------------
+#
+# If a module is run with "python -m", the current directory (i.e. "")
+# is automatically prepended to sys.path, but not if it is run as
+# "path/to/file.py". The processing for "-m" rewrites the former to
+# the latter, so subsequent executions won't have the same path as the
+# original.
+#
+# Conversely, when run as path/to/file.py, the directory containing
+# file.py gets added to the path, which can cause confusion as imports
+# may become relative in spite of the future import.
+#
+# We address the former problem by reconstructing the original command
+# line (Python >= 3.4) or by setting the $PYTHONPATH environment
+# variable (Python < 3.4) before re-execution so the new process will
+# see the correct path. We attempt to address the latter problem when
+# tornado.autoreload is run as __main__.
+
+if __name__ == "__main__":
+    # This sys.path manipulation must come before our imports (as much
+    # as possible - if we introduced a tornado.sys or tornado.os
+    # module we'd be in trouble), or else our imports would become
+    # relative again despite the future import.
+    #
+    # There is a separate __main__ block at the end of the file to call main().
+    if sys.path[0] == os.path.dirname(__file__):
+        del sys.path[0]
+
+import functools
+import os
+import pkgutil  # type: ignore
+import sys
+import traceback
+import types
+import subprocess
+import weakref
+
+from tornado import ioloop
+from tornado.log import gen_log
+from tornado import process
+from tornado.util import exec_in
+
+try:
+    import signal
+except ImportError:
+    signal = None  # type: ignore
+
+import typing
+from typing import Callable, Dict
+
+if typing.TYPE_CHECKING:
+    from typing import List, Optional, Union  # noqa: F401
+
+# os.execv is broken on Windows and can't properly parse command line
+# arguments and executable name if they contain whitespaces. subprocess
+# fixes that behavior.
+_has_execv = sys.platform != "win32"
+
+_watched_files = set()
+_reload_hooks = []
+_reload_attempted = False
+_io_loops = weakref.WeakKeyDictionary()  # type: ignore
+_autoreload_is_main = False
+_original_argv = None  # type: Optional[List[str]]
+_original_spec = None
+
+
+def start(check_time: int = 500) -> None:
+    """Begins watching source files for changes.
+
+    .. versionchanged:: 5.0
+       The ``io_loop`` argument (deprecated since version 4.1) has been removed.
+    """
+    io_loop = ioloop.IOLoop.current()
+    if io_loop in _io_loops:
+        return
+    _io_loops[io_loop] = True
+    if len(_io_loops) > 1:
+        gen_log.warning("tornado.autoreload started more than once in the same process")
+    modify_times = {}  # type: Dict[str, float]
+    callback = functools.partial(_reload_on_update, modify_times)
+    scheduler = ioloop.PeriodicCallback(callback, check_time)
+    scheduler.start()
+
+
+def wait() -> None:
+    """Wait for a watched file to change, then restart the process.
+
+    Intended to be used at the end of scripts like unit test runners,
+    to run the tests again after any source file changes (but see also
+    the command-line interface in `main`)
+    """
+    io_loop = ioloop.IOLoop()
+    io_loop.add_callback(start)
+    io_loop.start()
+
+
+def watch(filename: str) -> None:
+    """Add a file to the watch list.
+
+    All imported modules are watched by default.
+    """
+    _watched_files.add(filename)
+
+
+def add_reload_hook(fn: Callable[[], None]) -> None:
+    """Add a function to be called before reloading the process.
+
+    Note that for open file and socket handles it is generally
+    preferable to set the ``FD_CLOEXEC`` flag (using `fcntl` or
+    `os.set_inheritable`) instead of using a reload hook to close them.
+    """
+    _reload_hooks.append(fn)
+
+
+def _reload_on_update(modify_times: Dict[str, float]) -> None:
+    if _reload_attempted:
+        # We already tried to reload and it didn't work, so don't try again.
+        return
+    if process.task_id() is not None:
+        # We're in a child process created by fork_processes. If child
+        # processes restarted themselves, they'd all restart and then
+        # all call fork_processes again.
+        return
+    for module in list(sys.modules.values()):
+        # Some modules play games with sys.modules (e.g. email/__init__.py
+        # in the standard library), and occasionally this can cause strange
+        # failures in getattr. Just ignore anything that's not an ordinary
+        # module.
+        if not isinstance(module, types.ModuleType):
+            continue
+        path = getattr(module, "__file__", None)
+        if not path:
+            continue
+        if path.endswith(".pyc") or path.endswith(".pyo"):
+            path = path[:-1]
+        _check_file(modify_times, path)
+    for path in _watched_files:
+        _check_file(modify_times, path)
+
+
+def _check_file(modify_times: Dict[str, float], path: str) -> None:
+    try:
+        modified = os.stat(path).st_mtime
+    except Exception:
+        return
+    if path not in modify_times:
+        modify_times[path] = modified
+        return
+    if modify_times[path] != modified:
+        gen_log.info("%s modified; restarting server", path)
+        _reload()
+
+
+def _reload() -> None:
+    global _reload_attempted
+    _reload_attempted = True
+    for fn in _reload_hooks:
+        fn()
+    if sys.platform != "win32":
+        # Clear the alarm signal set by
+        # ioloop.set_blocking_log_threshold so it doesn't fire
+        # after the exec.
+        signal.setitimer(signal.ITIMER_REAL, 0, 0)
+    # sys.path fixes: see comments at top of file. If __main__.__spec__
+    # exists, we were invoked with -m and the effective path is about to
+    # change on re-exec. Reconstruct the original command line to
+    # ensure that the new process sees the same path we did. If
+    # __spec__ is not available (Python < 3.4), check instead if
+    # sys.path[0] is an empty string and add the current directory to
+    # $PYTHONPATH.
+    if _autoreload_is_main:
+        assert _original_argv is not None
+        spec = _original_spec
+        argv = _original_argv
+    else:
+        spec = getattr(sys.modules["__main__"], "__spec__", None)
+        argv = sys.argv
+    if spec:
+        argv = ["-m", spec.name] + argv[1:]
+    else:
+        path_prefix = "." + os.pathsep
+        if sys.path[0] == "" and not os.environ.get("PYTHONPATH", "").startswith(
+            path_prefix
+        ):
+            os.environ["PYTHONPATH"] = path_prefix + os.environ.get("PYTHONPATH", "")
+    if not _has_execv:
+        subprocess.Popen([sys.executable] + argv)
+        os._exit(0)
+    else:
+        try:
+            os.execv(sys.executable, [sys.executable] + argv)
+        except OSError:
+            # Mac OS X versions prior to 10.6 do not support execv in
+            # a process that contains multiple threads. Instead of
+            # re-executing in the current process, start a new one
+            # and cause the current process to exit. This isn't
+            # ideal since the new process is detached from the parent
+            # terminal and thus cannot easily be killed with ctrl-C,
+            # but it's better than not being able to autoreload at
+            # all.
+            # Unfortunately the errno returned in this case does not
+            # appear to be consistent, so we can't easily check for
+            # this error specifically.
+            os.spawnv(
+                os.P_NOWAIT, sys.executable, [sys.executable] + argv  # type: ignore
+            )
+            # At this point the IOLoop has been closed and finally
+            # blocks will experience errors if we allow the stack to
+            # unwind, so just exit uncleanly.
+            os._exit(0)
+
+
+_USAGE = """\
+Usage:
+  python -m tornado.autoreload -m module.to.run [args...]
+  python -m tornado.autoreload path/to/script.py [args...]
+"""
+
+
+def main() -> None:
+    """Command-line wrapper to re-run a script whenever its source changes.
+
+    Scripts may be specified by filename or module name::
+
+        python -m tornado.autoreload -m tornado.test.runtests
+        python -m tornado.autoreload tornado/test/runtests.py
+
+    Running a script with this wrapper is similar to calling
+    `tornado.autoreload.wait` at the end of the script, but this wrapper
+    can catch import-time problems like syntax errors that would otherwise
+    prevent the script from reaching its call to `wait`.
+    """
+    # Remember that we were launched with autoreload as main.
+ # The main module can be tricky; set the variables both in our globals + # (which may be __main__) and the real importable version. + import tornado.autoreload + + global _autoreload_is_main + global _original_argv, _original_spec + tornado.autoreload._autoreload_is_main = _autoreload_is_main = True + original_argv = sys.argv + tornado.autoreload._original_argv = _original_argv = original_argv + original_spec = getattr(sys.modules["__main__"], "__spec__", None) + tornado.autoreload._original_spec = _original_spec = original_spec + sys.argv = sys.argv[:] + if len(sys.argv) >= 3 and sys.argv[1] == "-m": + mode = "module" + module = sys.argv[2] + del sys.argv[1:3] + elif len(sys.argv) >= 2: + mode = "script" + script = sys.argv[1] + sys.argv = sys.argv[1:] + else: + print(_USAGE, file=sys.stderr) + sys.exit(1) + + try: + if mode == "module": + import runpy + + runpy.run_module(module, run_name="__main__", alter_sys=True) + elif mode == "script": + with open(script) as f: + # Execute the script in our namespace instead of creating + # a new one so that something that tries to import __main__ + # (e.g. the unittest module) will see names defined in the + # script instead of just those defined in this module. + global __file__ + __file__ = script + # If __package__ is defined, imports may be incorrectly + # interpreted as relative to this module. + global __package__ + del __package__ + exec_in(f.read(), globals(), globals()) + except SystemExit as e: + gen_log.info("Script exited with status %s", e.code) + except Exception as e: + gen_log.warning("Script exited with uncaught exception", exc_info=True) + # If an exception occurred at import time, the file with the error + # never made it into sys.modules and so we won't know to watch it. + # Just to make sure we've covered everything, walk the stack trace + # from the exception and watch every file. + for (filename, lineno, name, line) in traceback.extract_tb(sys.exc_info()[2]): + watch(filename) + if isinstance(e, SyntaxError): + # SyntaxErrors are special: their innermost stack frame is fake + # so extract_tb won't see it and we have to get the filename + # from the exception object. + if e.filename is not None: + watch(e.filename) + else: + gen_log.info("Script exited normally") + # restore sys.argv so subsequent executions will include autoreload + sys.argv = original_argv + + if mode == "module": + # runpy did a fake import of the module as __main__, but now it's + # no longer in sys.modules. Figure out where it is and watch it. + loader = pkgutil.get_loader(module) + if loader is not None: + watch(loader.get_filename()) # type: ignore + + wait() + + +if __name__ == "__main__": + # See also the other __main__ block at the top of the file, which modifies + # sys.path before our imports + main() diff --git a/venv/lib/python3.9/site-packages/tornado/concurrent.py b/venv/lib/python3.9/site-packages/tornado/concurrent.py new file mode 100644 index 00000000..6e05346b --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/concurrent.py @@ -0,0 +1,265 @@ +# +# Copyright 2012 Facebook +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the +# License for the specific language governing permissions and limitations +# under the License. +"""Utilities for working with ``Future`` objects. + +Tornado previously provided its own ``Future`` class, but now uses +`asyncio.Future`. This module contains utility functions for working +with `asyncio.Future` in a way that is backwards-compatible with +Tornado's old ``Future`` implementation. + +While this module is an important part of Tornado's internal +implementation, applications rarely need to interact with it +directly. + +""" + +import asyncio +from concurrent import futures +import functools +import sys +import types + +from tornado.log import app_log + +import typing +from typing import Any, Callable, Optional, Tuple, Union + +_T = typing.TypeVar("_T") + + +class ReturnValueIgnoredError(Exception): + # No longer used; was previously used by @return_future + pass + + +Future = asyncio.Future + +FUTURES = (futures.Future, Future) + + +def is_future(x: Any) -> bool: + return isinstance(x, FUTURES) + + +class DummyExecutor(futures.Executor): + def submit( + self, fn: Callable[..., _T], *args: Any, **kwargs: Any + ) -> "futures.Future[_T]": + future = futures.Future() # type: futures.Future[_T] + try: + future_set_result_unless_cancelled(future, fn(*args, **kwargs)) + except Exception: + future_set_exc_info(future, sys.exc_info()) + return future + + def shutdown(self, wait: bool = True) -> None: + pass + + +dummy_executor = DummyExecutor() + + +def run_on_executor(*args: Any, **kwargs: Any) -> Callable: + """Decorator to run a synchronous method asynchronously on an executor. + + Returns a future. + + The executor to be used is determined by the ``executor`` + attributes of ``self``. To use a different attribute name, pass a + keyword argument to the decorator:: + + @run_on_executor(executor='_thread_pool') + def foo(self): + pass + + This decorator should not be confused with the similarly-named + `.IOLoop.run_in_executor`. In general, using ``run_in_executor`` + when *calling* a blocking method is recommended instead of using + this decorator when *defining* a method. If compatibility with older + versions of Tornado is required, consider defining an executor + and using ``executor.submit()`` at the call site. + + .. versionchanged:: 4.2 + Added keyword arguments to use alternative attributes. + + .. versionchanged:: 5.0 + Always uses the current IOLoop instead of ``self.io_loop``. + + .. versionchanged:: 5.1 + Returns a `.Future` compatible with ``await`` instead of a + `concurrent.futures.Future`. + + .. deprecated:: 5.1 + + The ``callback`` argument is deprecated and will be removed in + 6.0. The decorator itself is discouraged in new code but will + not be removed in 6.0. + + .. versionchanged:: 6.0 + + The ``callback`` argument was removed. + """ + # Fully type-checking decorators is tricky, and this one is + # discouraged anyway so it doesn't have all the generic magic. 
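+    #
+    # A minimal usage sketch (``Worker`` and ``compute`` are hypothetical
+    # names, shown for illustration only):
+    #
+    #     class Worker:
+    #         executor = concurrent.futures.ThreadPoolExecutor(4)
+    #
+    #         @run_on_executor
+    #         def compute(self, x):
+    #             return x * 2  # runs on self.executor; result is awaitable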
+    def run_on_executor_decorator(fn: Callable) -> Callable[..., Future]:
+        executor = kwargs.get("executor", "executor")
+
+        @functools.wraps(fn)
+        def wrapper(self: Any, *args: Any, **kwargs: Any) -> Future:
+            async_future = Future()  # type: Future
+            conc_future = getattr(self, executor).submit(fn, self, *args, **kwargs)
+            chain_future(conc_future, async_future)
+            return async_future
+
+        return wrapper
+
+    if args and kwargs:
+        raise ValueError("cannot combine positional and keyword args")
+    if len(args) == 1:
+        return run_on_executor_decorator(args[0])
+    elif len(args) != 0:
+        raise ValueError("expected 1 argument, got %d" % len(args))
+    return run_on_executor_decorator
+
+
+_NO_RESULT = object()
+
+
+def chain_future(a: "Future[_T]", b: "Future[_T]") -> None:
+    """Chain two futures together so that when one completes, so does the other.
+
+    The result (success or failure) of ``a`` will be copied to ``b``, unless
+    ``b`` has already been completed or cancelled by the time ``a`` finishes.
+
+    .. versionchanged:: 5.0
+
+       Now accepts both Tornado/asyncio `Future` objects and
+       `concurrent.futures.Future`.
+
+    """
+
+    def copy(future: "Future[_T]") -> None:
+        assert future is a
+        if b.done():
+            return
+        if hasattr(a, "exc_info") and a.exc_info() is not None:  # type: ignore
+            future_set_exc_info(b, a.exc_info())  # type: ignore
+        else:
+            a_exc = a.exception()
+            if a_exc is not None:
+                b.set_exception(a_exc)
+            else:
+                b.set_result(a.result())
+
+    if isinstance(a, Future):
+        future_add_done_callback(a, copy)
+    else:
+        # concurrent.futures.Future
+        from tornado.ioloop import IOLoop
+
+        IOLoop.current().add_future(a, copy)
+
+
+def future_set_result_unless_cancelled(
+    future: "Union[futures.Future[_T], Future[_T]]", value: _T
+) -> None:
+    """Set the given ``value`` as the `Future`'s result, if not cancelled.
+
+    Avoids ``asyncio.InvalidStateError`` when calling ``set_result()`` on
+    a cancelled `asyncio.Future`.
+
+    .. versionadded:: 5.0
+    """
+    if not future.cancelled():
+        future.set_result(value)
+
+
+def future_set_exception_unless_cancelled(
+    future: "Union[futures.Future[_T], Future[_T]]", exc: BaseException
+) -> None:
+    """Set the given ``exc`` as the `Future`'s exception.
+
+    If the Future is already cancelled, logs the exception instead. If
+    this logging is not desired, the caller should explicitly check
+    the state of the Future and call ``Future.set_exception`` instead of
+    this wrapper.
+
+    Avoids ``asyncio.InvalidStateError`` when calling ``set_exception()`` on
+    a cancelled `asyncio.Future`.
+
+    .. versionadded:: 6.0
+
+    """
+    if not future.cancelled():
+        future.set_exception(exc)
+    else:
+        app_log.error("Exception after Future was cancelled", exc_info=exc)
+
+
+def future_set_exc_info(
+    future: "Union[futures.Future[_T], Future[_T]]",
+    exc_info: Tuple[
+        Optional[type], Optional[BaseException], Optional[types.TracebackType]
+    ],
+) -> None:
+    """Set the given ``exc_info`` as the `Future`'s exception.
+
+    Understands both `asyncio.Future` and the extensions in older
+    versions of Tornado to enable better tracebacks on Python 2.
+
+    .. versionadded:: 5.0
+
+    .. versionchanged:: 6.0
+
+       If the future is already cancelled, this function is a no-op.
+ (previously ``asyncio.InvalidStateError`` would be raised) + + """ + if exc_info[1] is None: + raise Exception("future_set_exc_info called with no exception") + future_set_exception_unless_cancelled(future, exc_info[1]) + + +@typing.overload +def future_add_done_callback( + future: "futures.Future[_T]", callback: Callable[["futures.Future[_T]"], None] +) -> None: + pass + + +@typing.overload # noqa: F811 +def future_add_done_callback( + future: "Future[_T]", callback: Callable[["Future[_T]"], None] +) -> None: + pass + + +def future_add_done_callback( # noqa: F811 + future: "Union[futures.Future[_T], Future[_T]]", callback: Callable[..., None] +) -> None: + """Arrange to call ``callback`` when ``future`` is complete. + + ``callback`` is invoked with one argument, the ``future``. + + If ``future`` is already done, ``callback`` is invoked immediately. + This may differ from the behavior of ``Future.add_done_callback``, + which makes no such guarantee. + + .. versionadded:: 5.0 + """ + if future.done(): + callback(future) + else: + future.add_done_callback(callback) diff --git a/venv/lib/python3.9/site-packages/tornado/curl_httpclient.py b/venv/lib/python3.9/site-packages/tornado/curl_httpclient.py new file mode 100644 index 00000000..23320e48 --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/curl_httpclient.py @@ -0,0 +1,584 @@ +# +# Copyright 2009 Facebook +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Non-blocking HTTP client implementation using pycurl.""" + +import collections +import functools +import logging +import pycurl +import threading +import time +from io import BytesIO + +from tornado import httputil +from tornado import ioloop + +from tornado.escape import utf8, native_str +from tornado.httpclient import ( + HTTPRequest, + HTTPResponse, + HTTPError, + AsyncHTTPClient, + main, +) +from tornado.log import app_log + +from typing import Dict, Any, Callable, Union, Optional +import typing + +if typing.TYPE_CHECKING: + from typing import Deque, Tuple # noqa: F401 + +curl_log = logging.getLogger("tornado.curl_httpclient") + + +class CurlAsyncHTTPClient(AsyncHTTPClient): + def initialize( # type: ignore + self, max_clients: int = 10, defaults: Optional[Dict[str, Any]] = None + ) -> None: + super().initialize(defaults=defaults) + # Typeshed is incomplete for CurlMulti, so just use Any for now. + self._multi = pycurl.CurlMulti() # type: Any + self._multi.setopt(pycurl.M_TIMERFUNCTION, self._set_timeout) + self._multi.setopt(pycurl.M_SOCKETFUNCTION, self._handle_socket) + self._curls = [self._curl_create() for i in range(max_clients)] + self._free_list = self._curls[:] + self._requests = ( + collections.deque() + ) # type: Deque[Tuple[HTTPRequest, Callable[[HTTPResponse], None], float]] + self._fds = {} # type: Dict[int, int] + self._timeout = None # type: Optional[object] + + # libcurl has bugs that sometimes cause it to not report all + # relevant file descriptors and timeouts to TIMERFUNCTION/ + # SOCKETFUNCTION. 
Mitigate the effects of such bugs by
+        # forcing a periodic scan of all active requests.
+        self._force_timeout_callback = ioloop.PeriodicCallback(
+            self._handle_force_timeout, 1000
+        )
+        self._force_timeout_callback.start()
+
+        # Work around a bug in libcurl 7.29.0: Some fields in the curl
+        # multi object are initialized lazily, and its destructor will
+        # segfault if it is destroyed without having been used. Add
+        # and remove a dummy handle to make sure everything is
+        # initialized.
+        dummy_curl_handle = pycurl.Curl()
+        self._multi.add_handle(dummy_curl_handle)
+        self._multi.remove_handle(dummy_curl_handle)
+
+    def close(self) -> None:
+        self._force_timeout_callback.stop()
+        if self._timeout is not None:
+            self.io_loop.remove_timeout(self._timeout)
+        for curl in self._curls:
+            curl.close()
+        self._multi.close()
+        super().close()
+
+        # Set the properties below to None to reduce the reference count of
+        # this instance: they hold bound methods of the instance, which
+        # would otherwise create circular references.
+        self._force_timeout_callback = None  # type: ignore
+        self._multi = None
+
+    def fetch_impl(
+        self, request: HTTPRequest, callback: Callable[[HTTPResponse], None]
+    ) -> None:
+        self._requests.append((request, callback, self.io_loop.time()))
+        self._process_queue()
+        self._set_timeout(0)
+
+    def _handle_socket(self, event: int, fd: int, multi: Any, data: bytes) -> None:
+        """Called by libcurl when it wants to change the file descriptors
+        it cares about.
+        """
+        event_map = {
+            pycurl.POLL_NONE: ioloop.IOLoop.NONE,
+            pycurl.POLL_IN: ioloop.IOLoop.READ,
+            pycurl.POLL_OUT: ioloop.IOLoop.WRITE,
+            pycurl.POLL_INOUT: ioloop.IOLoop.READ | ioloop.IOLoop.WRITE,
+        }
+        if event == pycurl.POLL_REMOVE:
+            if fd in self._fds:
+                self.io_loop.remove_handler(fd)
+                del self._fds[fd]
+        else:
+            ioloop_event = event_map[event]
+            # libcurl sometimes closes a socket and then opens a new
+            # one using the same FD without giving us a POLL_NONE in
+            # between. This is a problem with the epoll IOLoop,
+            # because the kernel can tell when a socket is closed and
+            # removes it from the epoll automatically, causing future
+            # update_handler calls to fail. Since we can't tell when
+            # this has happened, always use remove and re-add
+            # instead of update.
+            if fd in self._fds:
+                self.io_loop.remove_handler(fd)
+            self.io_loop.add_handler(fd, self._handle_events, ioloop_event)
+            self._fds[fd] = ioloop_event
+
+    def _set_timeout(self, msecs: int) -> None:
+        """Called by libcurl to schedule a timeout."""
+        if self._timeout is not None:
+            self.io_loop.remove_timeout(self._timeout)
+        self._timeout = self.io_loop.add_timeout(
+            self.io_loop.time() + msecs / 1000.0, self._handle_timeout
+        )
+
+    def _handle_events(self, fd: int, events: int) -> None:
+        """Called by IOLoop when there is activity on one of our
+        file descriptors.
+ """ + action = 0 + if events & ioloop.IOLoop.READ: + action |= pycurl.CSELECT_IN + if events & ioloop.IOLoop.WRITE: + action |= pycurl.CSELECT_OUT + while True: + try: + ret, num_handles = self._multi.socket_action(fd, action) + except pycurl.error as e: + ret = e.args[0] + if ret != pycurl.E_CALL_MULTI_PERFORM: + break + self._finish_pending_requests() + + def _handle_timeout(self) -> None: + """Called by IOLoop when the requested timeout has passed.""" + self._timeout = None + while True: + try: + ret, num_handles = self._multi.socket_action(pycurl.SOCKET_TIMEOUT, 0) + except pycurl.error as e: + ret = e.args[0] + if ret != pycurl.E_CALL_MULTI_PERFORM: + break + self._finish_pending_requests() + + # In theory, we shouldn't have to do this because curl will + # call _set_timeout whenever the timeout changes. However, + # sometimes after _handle_timeout we will need to reschedule + # immediately even though nothing has changed from curl's + # perspective. This is because when socket_action is + # called with SOCKET_TIMEOUT, libcurl decides internally which + # timeouts need to be processed by using a monotonic clock + # (where available) while tornado uses python's time.time() + # to decide when timeouts have occurred. When those clocks + # disagree on elapsed time (as they will whenever there is an + # NTP adjustment), tornado might call _handle_timeout before + # libcurl is ready. After each timeout, resync the scheduled + # timeout with libcurl's current state. + new_timeout = self._multi.timeout() + if new_timeout >= 0: + self._set_timeout(new_timeout) + + def _handle_force_timeout(self) -> None: + """Called by IOLoop periodically to ask libcurl to process any + events it may have forgotten about. + """ + while True: + try: + ret, num_handles = self._multi.socket_all() + except pycurl.error as e: + ret = e.args[0] + if ret != pycurl.E_CALL_MULTI_PERFORM: + break + self._finish_pending_requests() + + def _finish_pending_requests(self) -> None: + """Process any requests that were completed by the last + call to multi.socket_action. + """ + while True: + num_q, ok_list, err_list = self._multi.info_read() + for curl in ok_list: + self._finish(curl) + for curl, errnum, errmsg in err_list: + self._finish(curl, errnum, errmsg) + if num_q == 0: + break + self._process_queue() + + def _process_queue(self) -> None: + while True: + started = 0 + while self._free_list and self._requests: + started += 1 + curl = self._free_list.pop() + (request, callback, queue_start_time) = self._requests.popleft() + # TODO: Don't smuggle extra data on an attribute of the Curl object. + curl.info = { # type: ignore + "headers": httputil.HTTPHeaders(), + "buffer": BytesIO(), + "request": request, + "callback": callback, + "queue_start_time": queue_start_time, + "curl_start_time": time.time(), + "curl_start_ioloop_time": self.io_loop.current().time(), # type: ignore + } + try: + self._curl_setup_request( + curl, + request, + curl.info["buffer"], # type: ignore + curl.info["headers"], # type: ignore + ) + except Exception as e: + # If there was an error in setup, pass it on + # to the callback. Note that allowing the + # error to escape here will appear to work + # most of the time since we are still in the + # caller's original stack frame, but when + # _process_queue() is called from + # _finish_pending_requests the exceptions have + # nowhere to go. 
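+                    # Report the failure as a synthetic 599 response;
+                    # Tornado uses code 599 for errors raised before any
+                    # HTTP response was received.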
+ self._free_list.append(curl) + callback(HTTPResponse(request=request, code=599, error=e)) + else: + self._multi.add_handle(curl) + + if not started: + break + + def _finish( + self, + curl: pycurl.Curl, + curl_error: Optional[int] = None, + curl_message: Optional[str] = None, + ) -> None: + info = curl.info # type: ignore + curl.info = None # type: ignore + self._multi.remove_handle(curl) + self._free_list.append(curl) + buffer = info["buffer"] + if curl_error: + assert curl_message is not None + error = CurlError(curl_error, curl_message) # type: Optional[CurlError] + assert error is not None + code = error.code + effective_url = None + buffer.close() + buffer = None + else: + error = None + code = curl.getinfo(pycurl.HTTP_CODE) + effective_url = curl.getinfo(pycurl.EFFECTIVE_URL) + buffer.seek(0) + # the various curl timings are documented at + # http://curl.haxx.se/libcurl/c/curl_easy_getinfo.html + time_info = dict( + queue=info["curl_start_ioloop_time"] - info["queue_start_time"], + namelookup=curl.getinfo(pycurl.NAMELOOKUP_TIME), + connect=curl.getinfo(pycurl.CONNECT_TIME), + appconnect=curl.getinfo(pycurl.APPCONNECT_TIME), + pretransfer=curl.getinfo(pycurl.PRETRANSFER_TIME), + starttransfer=curl.getinfo(pycurl.STARTTRANSFER_TIME), + total=curl.getinfo(pycurl.TOTAL_TIME), + redirect=curl.getinfo(pycurl.REDIRECT_TIME), + ) + try: + info["callback"]( + HTTPResponse( + request=info["request"], + code=code, + headers=info["headers"], + buffer=buffer, + effective_url=effective_url, + error=error, + reason=info["headers"].get("X-Http-Reason", None), + request_time=self.io_loop.time() - info["curl_start_ioloop_time"], + start_time=info["curl_start_time"], + time_info=time_info, + ) + ) + except Exception: + self.handle_callback_exception(info["callback"]) + + def handle_callback_exception(self, callback: Any) -> None: + app_log.error("Exception in callback %r", callback, exc_info=True) + + def _curl_create(self) -> pycurl.Curl: + curl = pycurl.Curl() + if curl_log.isEnabledFor(logging.DEBUG): + curl.setopt(pycurl.VERBOSE, 1) + curl.setopt(pycurl.DEBUGFUNCTION, self._curl_debug) + if hasattr( + pycurl, "PROTOCOLS" + ): # PROTOCOLS first appeared in pycurl 7.19.5 (2014-07-12) + curl.setopt(pycurl.PROTOCOLS, pycurl.PROTO_HTTP | pycurl.PROTO_HTTPS) + curl.setopt(pycurl.REDIR_PROTOCOLS, pycurl.PROTO_HTTP | pycurl.PROTO_HTTPS) + return curl + + def _curl_setup_request( + self, + curl: pycurl.Curl, + request: HTTPRequest, + buffer: BytesIO, + headers: httputil.HTTPHeaders, + ) -> None: + curl.setopt(pycurl.URL, native_str(request.url)) + + # libcurl's magic "Expect: 100-continue" behavior causes delays + # with servers that don't support it (which include, among others, + # Google's OpenID endpoint). Additionally, this behavior has + # a bug in conjunction with the curl_multi_socket_action API + # (https://sourceforge.net/tracker/?func=detail&atid=100976&aid=3039744&group_id=976), + # which increases the delays. 
It's more trouble than it's worth, + # so just turn off the feature (yes, setting Expect: to an empty + # value is the official way to disable this) + if "Expect" not in request.headers: + request.headers["Expect"] = "" + + # libcurl adds Pragma: no-cache by default; disable that too + if "Pragma" not in request.headers: + request.headers["Pragma"] = "" + + curl.setopt( + pycurl.HTTPHEADER, + [ + b"%s: %s" + % (native_str(k).encode("ASCII"), native_str(v).encode("ISO8859-1")) + for k, v in request.headers.get_all() + ], + ) + + curl.setopt( + pycurl.HEADERFUNCTION, + functools.partial( + self._curl_header_callback, headers, request.header_callback + ), + ) + if request.streaming_callback: + + def write_function(b: Union[bytes, bytearray]) -> int: + assert request.streaming_callback is not None + self.io_loop.add_callback(request.streaming_callback, b) + return len(b) + + else: + write_function = buffer.write # type: ignore + curl.setopt(pycurl.WRITEFUNCTION, write_function) + curl.setopt(pycurl.FOLLOWLOCATION, request.follow_redirects) + curl.setopt(pycurl.MAXREDIRS, request.max_redirects) + assert request.connect_timeout is not None + curl.setopt(pycurl.CONNECTTIMEOUT_MS, int(1000 * request.connect_timeout)) + assert request.request_timeout is not None + curl.setopt(pycurl.TIMEOUT_MS, int(1000 * request.request_timeout)) + if request.user_agent: + curl.setopt(pycurl.USERAGENT, native_str(request.user_agent)) + else: + curl.setopt(pycurl.USERAGENT, "Mozilla/5.0 (compatible; pycurl)") + if request.network_interface: + curl.setopt(pycurl.INTERFACE, request.network_interface) + if request.decompress_response: + curl.setopt(pycurl.ENCODING, "gzip,deflate") + else: + curl.setopt(pycurl.ENCODING, None) + if request.proxy_host and request.proxy_port: + curl.setopt(pycurl.PROXY, request.proxy_host) + curl.setopt(pycurl.PROXYPORT, request.proxy_port) + if request.proxy_username: + assert request.proxy_password is not None + credentials = httputil.encode_username_password( + request.proxy_username, request.proxy_password + ) + curl.setopt(pycurl.PROXYUSERPWD, credentials) + + if request.proxy_auth_mode is None or request.proxy_auth_mode == "basic": + curl.setopt(pycurl.PROXYAUTH, pycurl.HTTPAUTH_BASIC) + elif request.proxy_auth_mode == "digest": + curl.setopt(pycurl.PROXYAUTH, pycurl.HTTPAUTH_DIGEST) + else: + raise ValueError( + "Unsupported proxy_auth_mode %s" % request.proxy_auth_mode + ) + else: + try: + curl.unsetopt(pycurl.PROXY) + except TypeError: # not supported, disable proxy + curl.setopt(pycurl.PROXY, "") + curl.unsetopt(pycurl.PROXYUSERPWD) + if request.validate_cert: + curl.setopt(pycurl.SSL_VERIFYPEER, 1) + curl.setopt(pycurl.SSL_VERIFYHOST, 2) + else: + curl.setopt(pycurl.SSL_VERIFYPEER, 0) + curl.setopt(pycurl.SSL_VERIFYHOST, 0) + if request.ca_certs is not None: + curl.setopt(pycurl.CAINFO, request.ca_certs) + else: + # There is no way to restore pycurl.CAINFO to its default value + # (Using unsetopt makes it reject all certificates). + # I don't see any way to read the default value from python so it + # can be restored later. We'll have to just leave CAINFO untouched + # if no ca_certs file was specified, and require that if any + # request uses a custom ca_certs file, they all must. + pass + + if request.allow_ipv6 is False: + # Curl behaves reasonably when DNS resolution gives an ipv6 address + # that we can't reach, so allow ipv6 unless the user asks to disable. 
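+            # (IPRESOLVE_V4 restricts libcurl's DNS resolution to IPv4
+            # records only.)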
+ curl.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_V4) + else: + curl.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_WHATEVER) + + # Set the request method through curl's irritating interface which makes + # up names for almost every single method + curl_options = { + "GET": pycurl.HTTPGET, + "POST": pycurl.POST, + "PUT": pycurl.UPLOAD, + "HEAD": pycurl.NOBODY, + } + custom_methods = set(["DELETE", "OPTIONS", "PATCH"]) + for o in curl_options.values(): + curl.setopt(o, False) + if request.method in curl_options: + curl.unsetopt(pycurl.CUSTOMREQUEST) + curl.setopt(curl_options[request.method], True) + elif request.allow_nonstandard_methods or request.method in custom_methods: + curl.setopt(pycurl.CUSTOMREQUEST, request.method) + else: + raise KeyError("unknown method " + request.method) + + body_expected = request.method in ("POST", "PATCH", "PUT") + body_present = request.body is not None + if not request.allow_nonstandard_methods: + # Some HTTP methods nearly always have bodies while others + # almost never do. Fail in this case unless the user has + # opted out of sanity checks with allow_nonstandard_methods. + if (body_expected and not body_present) or ( + body_present and not body_expected + ): + raise ValueError( + "Body must %sbe None for method %s (unless " + "allow_nonstandard_methods is true)" + % ("not " if body_expected else "", request.method) + ) + + if body_expected or body_present: + if request.method == "GET": + # Even with `allow_nonstandard_methods` we disallow + # GET with a body (because libcurl doesn't allow it + # unless we use CUSTOMREQUEST). While the spec doesn't + # forbid clients from sending a body, it arguably + # disallows the server from doing anything with them. + raise ValueError("Body must be None for GET request") + request_buffer = BytesIO(utf8(request.body or "")) + + def ioctl(cmd: int) -> None: + if cmd == curl.IOCMD_RESTARTREAD: # type: ignore + request_buffer.seek(0) + + curl.setopt(pycurl.READFUNCTION, request_buffer.read) + curl.setopt(pycurl.IOCTLFUNCTION, ioctl) + if request.method == "POST": + curl.setopt(pycurl.POSTFIELDSIZE, len(request.body or "")) + else: + curl.setopt(pycurl.UPLOAD, True) + curl.setopt(pycurl.INFILESIZE, len(request.body or "")) + + if request.auth_username is not None: + assert request.auth_password is not None + if request.auth_mode is None or request.auth_mode == "basic": + curl.setopt(pycurl.HTTPAUTH, pycurl.HTTPAUTH_BASIC) + elif request.auth_mode == "digest": + curl.setopt(pycurl.HTTPAUTH, pycurl.HTTPAUTH_DIGEST) + else: + raise ValueError("Unsupported auth_mode %s" % request.auth_mode) + + userpwd = httputil.encode_username_password( + request.auth_username, request.auth_password + ) + curl.setopt(pycurl.USERPWD, userpwd) + curl_log.debug( + "%s %s (username: %r)", + request.method, + request.url, + request.auth_username, + ) + else: + curl.unsetopt(pycurl.USERPWD) + curl_log.debug("%s %s", request.method, request.url) + + if request.client_cert is not None: + curl.setopt(pycurl.SSLCERT, request.client_cert) + + if request.client_key is not None: + curl.setopt(pycurl.SSLKEY, request.client_key) + + if request.ssl_options is not None: + raise ValueError("ssl_options not supported in curl_httpclient") + + if threading.active_count() > 1: + # libcurl/pycurl is not thread-safe by default. When multiple threads + # are used, signals should be disabled. This has the side effect + # of disabling DNS timeouts in some environments (when libcurl is + # not linked against ares), so we don't do it when there is only one + # thread. 
Applications that use many short-lived threads may need
+            # to set NOSIGNAL manually in a prepare_curl_callback since
+            # there may not be any other threads running at the time we call
+            # threading.active_count.
+            curl.setopt(pycurl.NOSIGNAL, 1)
+        if request.prepare_curl_callback is not None:
+            request.prepare_curl_callback(curl)
+
+    def _curl_header_callback(
+        self,
+        headers: httputil.HTTPHeaders,
+        header_callback: Callable[[str], None],
+        header_line_bytes: bytes,
+    ) -> None:
+        header_line = native_str(header_line_bytes.decode("latin1"))
+        if header_callback is not None:
+            self.io_loop.add_callback(header_callback, header_line)
+        # header_line as returned by curl includes the end-of-line characters.
+        # whitespace at the start should be preserved to allow multi-line headers
+        header_line = header_line.rstrip()
+        if header_line.startswith("HTTP/"):
+            headers.clear()
+            try:
+                (__, __, reason) = httputil.parse_response_start_line(header_line)
+                header_line = "X-Http-Reason: %s" % reason
+            except httputil.HTTPInputError:
+                return
+        if not header_line:
+            return
+        headers.parse_line(header_line)
+
+    def _curl_debug(self, debug_type: int, debug_msg: str) -> None:
+        debug_types = ("I", "<", ">", "<", ">")
+        if debug_type == 0:
+            debug_msg = native_str(debug_msg)
+            curl_log.debug("%s", debug_msg.strip())
+        elif debug_type in (1, 2):
+            debug_msg = native_str(debug_msg)
+            for line in debug_msg.splitlines():
+                curl_log.debug("%s %s", debug_types[debug_type], line)
+        elif debug_type == 4:
+            curl_log.debug("%s %r", debug_types[debug_type], debug_msg)
+
+
+class CurlError(HTTPError):
+    def __init__(self, errno: int, message: str) -> None:
+        HTTPError.__init__(self, 599, message)
+        self.errno = errno
+
+
+if __name__ == "__main__":
+    AsyncHTTPClient.configure(CurlAsyncHTTPClient)
+    main()
diff --git a/venv/lib/python3.9/site-packages/tornado/escape.py b/venv/lib/python3.9/site-packages/tornado/escape.py
new file mode 100644
index 00000000..55354c30
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/tornado/escape.py
@@ -0,0 +1,402 @@
+#
+# Copyright 2009 Facebook
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""Escaping/unescaping methods for HTML, JSON, URLs, and others.
+
+Also includes a few other miscellaneous string manipulation functions that
+have crept in over time.
+"""
+
+import html.entities
+import json
+import re
+import urllib.parse
+
+from tornado.util import unicode_type
+
+import typing
+from typing import Union, Any, Optional, Dict, List, Callable
+
+
+_XHTML_ESCAPE_RE = re.compile("[&<>\"']")
+_XHTML_ESCAPE_DICT = {
+    "&": "&amp;",
+    "<": "&lt;",
+    ">": "&gt;",
+    '"': "&quot;",
+    "'": "&#39;",
+}
+
+
+def xhtml_escape(value: Union[str, bytes]) -> str:
+    """Escapes a string so it is valid within HTML or XML.
+
+    Escapes the characters ``<``, ``>``, ``"``, ``'``, and ``&``.
+    When used in attribute values the escaped strings must be enclosed
+    in quotes.
+
+    .. versionchanged:: 3.2
+
+       Added the single quote to the list of escaped characters.
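+
+    For example (an illustrative call; the output follows directly from
+    the escape table above)::
+
+        xhtml_escape('<a href="x">Tom & Jerry</a>')
+        # -> '&lt;a href=&quot;x&quot;&gt;Tom &amp; Jerry&lt;/a&gt;'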
+ """ + return _XHTML_ESCAPE_RE.sub( + lambda match: _XHTML_ESCAPE_DICT[match.group(0)], to_basestring(value) + ) + + +def xhtml_unescape(value: Union[str, bytes]) -> str: + """Un-escapes an XML-escaped string.""" + return re.sub(r"&(#?)(\w+?);", _convert_entity, _unicode(value)) + + +# The fact that json_encode wraps json.dumps is an implementation detail. +# Please see https://github.com/tornadoweb/tornado/pull/706 +# before sending a pull request that adds **kwargs to this function. +def json_encode(value: Any) -> str: + """JSON-encodes the given Python object.""" + # JSON permits but does not require forward slashes to be escaped. + # This is useful when json data is emitted in a <script> tag + # in HTML, as it prevents </script> tags from prematurely terminating + # the JavaScript. Some json libraries do this escaping by default, + # although python's standard library does not, so we do it here. + # http://stackoverflow.com/questions/1580647/json-why-are-forward-slashes-escaped + return json.dumps(value).replace("</", "<\\/") + + +def json_decode(value: Union[str, bytes]) -> Any: + """Returns Python objects for the given JSON string. + + Supports both `str` and `bytes` inputs. + """ + return json.loads(to_basestring(value)) + + +def squeeze(value: str) -> str: + """Replace all sequences of whitespace chars with a single space.""" + return re.sub(r"[\x00-\x20]+", " ", value).strip() + + +def url_escape(value: Union[str, bytes], plus: bool = True) -> str: + """Returns a URL-encoded version of the given value. + + If ``plus`` is true (the default), spaces will be represented + as "+" instead of "%20". This is appropriate for query strings + but not for the path component of a URL. Note that this default + is the reverse of Python's urllib module. + + .. versionadded:: 3.1 + The ``plus`` argument + """ + quote = urllib.parse.quote_plus if plus else urllib.parse.quote + return quote(utf8(value)) + + +@typing.overload +def url_unescape(value: Union[str, bytes], encoding: None, plus: bool = True) -> bytes: + pass + + +@typing.overload # noqa: F811 +def url_unescape( + value: Union[str, bytes], encoding: str = "utf-8", plus: bool = True +) -> str: + pass + + +def url_unescape( # noqa: F811 + value: Union[str, bytes], encoding: Optional[str] = "utf-8", plus: bool = True +) -> Union[str, bytes]: + """Decodes the given value from a URL. + + The argument may be either a byte or unicode string. + + If encoding is None, the result will be a byte string. Otherwise, + the result is a unicode string in the specified encoding. + + If ``plus`` is true (the default), plus signs will be interpreted + as spaces (literal plus signs must be represented as "%2B"). This + is appropriate for query strings and form-encoded values but not + for the path component of a URL. Note that this default is the + reverse of Python's urllib module. + + .. versionadded:: 3.1 + The ``plus`` argument + """ + if encoding is None: + if plus: + # unquote_to_bytes doesn't have a _plus variant + value = to_basestring(value).replace("+", " ") + return urllib.parse.unquote_to_bytes(value) + else: + unquote = urllib.parse.unquote_plus if plus else urllib.parse.unquote + return unquote(to_basestring(value), encoding=encoding) + + +def parse_qs_bytes( + qs: Union[str, bytes], keep_blank_values: bool = False, strict_parsing: bool = False +) -> Dict[str, List[bytes]]: + """Parses a query string like urlparse.parse_qs, + but takes bytes and returns the values as byte strings. 
+ + Keys still become type str (interpreted as latin1 in python3!) + because it's too painful to keep them as byte strings in + python3 and in practice they're nearly always ascii anyway. + """ + # This is gross, but python3 doesn't give us another way. + # Latin1 is the universal donor of character encodings. + if isinstance(qs, bytes): + qs = qs.decode("latin1") + result = urllib.parse.parse_qs( + qs, keep_blank_values, strict_parsing, encoding="latin1", errors="strict" + ) + encoded = {} + for k, v in result.items(): + encoded[k] = [i.encode("latin1") for i in v] + return encoded + + +_UTF8_TYPES = (bytes, type(None)) + + +@typing.overload +def utf8(value: bytes) -> bytes: + pass + + +@typing.overload # noqa: F811 +def utf8(value: str) -> bytes: + pass + + +@typing.overload # noqa: F811 +def utf8(value: None) -> None: + pass + + +def utf8(value: Union[None, str, bytes]) -> Optional[bytes]: # noqa: F811 + """Converts a string argument to a byte string. + + If the argument is already a byte string or None, it is returned unchanged. + Otherwise it must be a unicode string and is encoded as utf8. + """ + if isinstance(value, _UTF8_TYPES): + return value + if not isinstance(value, unicode_type): + raise TypeError("Expected bytes, unicode, or None; got %r" % type(value)) + return value.encode("utf-8") + + +_TO_UNICODE_TYPES = (unicode_type, type(None)) + + +@typing.overload +def to_unicode(value: str) -> str: + pass + + +@typing.overload # noqa: F811 +def to_unicode(value: bytes) -> str: + pass + + +@typing.overload # noqa: F811 +def to_unicode(value: None) -> None: + pass + + +def to_unicode(value: Union[None, str, bytes]) -> Optional[str]: # noqa: F811 + """Converts a string argument to a unicode string. + + If the argument is already a unicode string or None, it is returned + unchanged. Otherwise it must be a byte string and is decoded as utf8. + """ + if isinstance(value, _TO_UNICODE_TYPES): + return value + if not isinstance(value, bytes): + raise TypeError("Expected bytes, unicode, or None; got %r" % type(value)) + return value.decode("utf-8") + + +# to_unicode was previously named _unicode not because it was private, +# but to avoid conflicts with the built-in unicode() function/type +_unicode = to_unicode + +# When dealing with the standard library across python 2 and 3 it is +# sometimes useful to have a direct conversion to the native string type +native_str = to_unicode +to_basestring = to_unicode + + +def recursive_unicode(obj: Any) -> Any: + """Walks a simple data structure, converting byte strings to unicode. + + Supports lists, tuples, and dictionaries. + """ + if isinstance(obj, dict): + return dict( + (recursive_unicode(k), recursive_unicode(v)) for (k, v) in obj.items() + ) + elif isinstance(obj, list): + return list(recursive_unicode(i) for i in obj) + elif isinstance(obj, tuple): + return tuple(recursive_unicode(i) for i in obj) + elif isinstance(obj, bytes): + return to_unicode(obj) + else: + return obj + + +# I originally used the regex from +# http://daringfireball.net/2010/07/improved_regex_for_matching_urls +# but it gets all exponential on certain patterns (such as too many trailing +# dots), causing the regex matcher to never return. +# This regex should avoid those problems. +# Use to_unicode instead of tornado.util.u - we don't want backslashes getting +# processed as escapes. 
+_URL_RE = re.compile(
+    to_unicode(
+        r"""\b((?:([\w-]+):(/{1,3})|www[.])(?:(?:(?:[^\s&()]|&amp;|&quot;)*(?:[^!"#$%&'()*+,.:;<=>?@\[\]^`{|}~\s]))|(?:\((?:[^\s&()]|&amp;|&quot;)*\)))+)"""  # noqa: E501
+    )
+)
+
+
+def linkify(
+    text: Union[str, bytes],
+    shorten: bool = False,
+    extra_params: Union[str, Callable[[str], str]] = "",
+    require_protocol: bool = False,
+    permitted_protocols: List[str] = ["http", "https"],
+) -> str:
+    """Converts plain text into HTML with links.
+
+    For example: ``linkify("Hello http://tornadoweb.org!")`` would return
+    ``Hello <a href="http://tornadoweb.org">http://tornadoweb.org</a>!``
+
+    Parameters:
+
+    * ``shorten``: Long urls will be shortened for display.
+
+    * ``extra_params``: Extra text to include in the link tag, or a callable
+      taking the link as an argument and returning the extra text
+      e.g. ``linkify(text, extra_params='rel="nofollow" class="external"')``,
+      or::
+
+          def extra_params_cb(url):
+              if url.startswith("http://example.com"):
+                  return 'class="internal"'
+              else:
+                  return 'class="external" rel="nofollow"'
+          linkify(text, extra_params=extra_params_cb)
+
+    * ``require_protocol``: Only linkify urls which include a protocol. If
+      this is False, urls such as www.facebook.com will also be linkified.
+
+    * ``permitted_protocols``: List (or set) of protocols which should be
+      linkified, e.g. ``linkify(text, permitted_protocols=["http", "ftp",
+      "mailto"])``. It is very unsafe to include protocols such as
+      ``javascript``.
+    """
+    if extra_params and not callable(extra_params):
+        extra_params = " " + extra_params.strip()
+
+    def make_link(m: typing.Match) -> str:
+        url = m.group(1)
+        proto = m.group(2)
+        if require_protocol and not proto:
+            return url  # no protocol, no linkify
+
+        if proto and proto not in permitted_protocols:
+            return url  # bad protocol, no linkify
+
+        href = m.group(1)
+        if not proto:
+            href = "http://" + href  # no proto specified, use http
+
+        if callable(extra_params):
+            params = " " + extra_params(href).strip()
+        else:
+            params = extra_params
+
+        # clip long urls. max_len is just an approximation
+        max_len = 30
+        if shorten and len(url) > max_len:
+            before_clip = url
+            if proto:
+                proto_len = len(proto) + 1 + len(m.group(3) or "")  # +1 for :
+            else:
+                proto_len = 0
+
+            parts = url[proto_len:].split("/")
+            if len(parts) > 1:
+                # Grab the whole host part plus the first bit of the path
+                # The path is usually not that interesting once shortened
+                # (no more slug, etc), so it really just provides a little
+                # extra indication of shortening.
+                url = (
+                    url[:proto_len]
+                    + parts[0]
+                    + "/"
+                    + parts[1][:8].split("?")[0].split(".")[0]
+                )
+
+            if len(url) > max_len * 1.5:  # still too long
+                url = url[:max_len]
+
+            if url != before_clip:
+                amp = url.rfind("&amp;")
+                # avoid splitting html char entities
+                if amp > max_len - 5:
+                    url = url[:amp]
+                url += "..."
+
+                if len(url) >= len(before_clip):
+                    url = before_clip
+                else:
+                    # full url is visible on mouse-over (for those who don't
+                    # have a status bar, such as Safari by default)
+                    params += ' title="%s"' % href
+
+        return '<a href="%s"%s>%s</a>' % (href, params, url)
+
+    # First HTML-escape so that our strings are all safe.
+    # The regex is modified to avoid character entities other than &amp; so
+    # that we won't pick up &quot;, etc.
+ text = _unicode(xhtml_escape(text)) + return _URL_RE.sub(make_link, text) + + +def _convert_entity(m: typing.Match) -> str: + if m.group(1) == "#": + try: + if m.group(2)[:1].lower() == "x": + return chr(int(m.group(2)[1:], 16)) + else: + return chr(int(m.group(2))) + except ValueError: + return "&#%s;" % m.group(2) + try: + return _HTML_UNICODE_MAP[m.group(2)] + except KeyError: + return "&%s;" % m.group(2) + + +def _build_unicode_map() -> Dict[str, str]: + unicode_map = {} + for name, value in html.entities.name2codepoint.items(): + unicode_map[name] = chr(value) + return unicode_map + + +_HTML_UNICODE_MAP = _build_unicode_map() diff --git a/venv/lib/python3.9/site-packages/tornado/gen.py b/venv/lib/python3.9/site-packages/tornado/gen.py new file mode 100644 index 00000000..4819b857 --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/gen.py @@ -0,0 +1,883 @@ +"""``tornado.gen`` implements generator-based coroutines. + +.. note:: + + The "decorator and generator" approach in this module is a + precursor to native coroutines (using ``async def`` and ``await``) + which were introduced in Python 3.5. Applications that do not + require compatibility with older versions of Python should use + native coroutines instead. Some parts of this module are still + useful with native coroutines, notably `multi`, `sleep`, + `WaitIterator`, and `with_timeout`. Some of these functions have + counterparts in the `asyncio` module which may be used as well, + although the two may not necessarily be 100% compatible. + +Coroutines provide an easier way to work in an asynchronous +environment than chaining callbacks. Code using coroutines is +technically asynchronous, but it is written as a single generator +instead of a collection of separate functions. + +For example, here's a coroutine-based handler: + +.. testcode:: + + class GenAsyncHandler(RequestHandler): + @gen.coroutine + def get(self): + http_client = AsyncHTTPClient() + response = yield http_client.fetch("http://example.com") + do_something_with_response(response) + self.render("template.html") + +.. testoutput:: + :hide: + +Asynchronous functions in Tornado return an ``Awaitable`` or `.Future`; +yielding this object returns its result. + +You can also yield a list or dict of other yieldable objects, which +will be started at the same time and run in parallel; a list or dict +of results will be returned when they are all finished: + +.. testcode:: + + @gen.coroutine + def get(self): + http_client = AsyncHTTPClient() + response1, response2 = yield [http_client.fetch(url1), + http_client.fetch(url2)] + response_dict = yield dict(response3=http_client.fetch(url3), + response4=http_client.fetch(url4)) + response3 = response_dict['response3'] + response4 = response_dict['response4'] + +.. testoutput:: + :hide: + +If ``tornado.platform.twisted`` is imported, it is also possible to +yield Twisted's ``Deferred`` objects. See the `convert_yielded` +function to extend this mechanism. + +.. versionchanged:: 3.2 + Dict support added. + +.. versionchanged:: 4.1 + Support added for yielding ``asyncio`` Futures and Twisted Deferreds + via ``singledispatch``. 
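+
+The same handler written as a native coroutine, for comparison (this
+sketch assumes the same imports as the example above)::
+
+    class GenAsyncHandler2(RequestHandler):
+        async def get(self):
+            http_client = AsyncHTTPClient()
+            response = await http_client.fetch("http://example.com")
+            do_something_with_response(response)
+            self.render("template.html")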
+ +""" +import asyncio +import builtins +import collections +from collections.abc import Generator +import concurrent.futures +import datetime +import functools +from functools import singledispatch +from inspect import isawaitable +import sys +import types + +from tornado.concurrent import ( + Future, + is_future, + chain_future, + future_set_exc_info, + future_add_done_callback, + future_set_result_unless_cancelled, +) +from tornado.ioloop import IOLoop +from tornado.log import app_log +from tornado.util import TimeoutError + +try: + import contextvars +except ImportError: + contextvars = None # type: ignore + +import typing +from typing import Union, Any, Callable, List, Type, Tuple, Awaitable, Dict, overload + +if typing.TYPE_CHECKING: + from typing import Sequence, Deque, Optional, Set, Iterable # noqa: F401 + +_T = typing.TypeVar("_T") + +_Yieldable = Union[ + None, Awaitable, List[Awaitable], Dict[Any, Awaitable], concurrent.futures.Future +] + + +class KeyReuseError(Exception): + pass + + +class UnknownKeyError(Exception): + pass + + +class LeakedCallbackError(Exception): + pass + + +class BadYieldError(Exception): + pass + + +class ReturnValueIgnoredError(Exception): + pass + + +def _value_from_stopiteration(e: Union[StopIteration, "Return"]) -> Any: + try: + # StopIteration has a value attribute beginning in py33. + # So does our Return class. + return e.value + except AttributeError: + pass + try: + # Cython backports coroutine functionality by putting the value in + # e.args[0]. + return e.args[0] + except (AttributeError, IndexError): + return None + + +def _create_future() -> Future: + future = Future() # type: Future + # Fixup asyncio debug info by removing extraneous stack entries + source_traceback = getattr(future, "_source_traceback", ()) + while source_traceback: + # Each traceback entry is equivalent to a + # (filename, self.lineno, self.name, self.line) tuple + filename = source_traceback[-1][0] + if filename == __file__: + del source_traceback[-1] + else: + break + return future + + +def _fake_ctx_run(f: Callable[..., _T], *args: Any, **kw: Any) -> _T: + return f(*args, **kw) + + +@overload +def coroutine( + func: Callable[..., "Generator[Any, Any, _T]"] +) -> Callable[..., "Future[_T]"]: + ... + + +@overload +def coroutine(func: Callable[..., _T]) -> Callable[..., "Future[_T]"]: + ... + + +def coroutine( + func: Union[Callable[..., "Generator[Any, Any, _T]"], Callable[..., _T]] +) -> Callable[..., "Future[_T]"]: + """Decorator for asynchronous generators. + + For compatibility with older versions of Python, coroutines may + also "return" by raising the special exception `Return(value) + <Return>`. + + Functions with this decorator return a `.Future`. + + .. warning:: + + When exceptions occur inside a coroutine, the exception + information will be stored in the `.Future` object. You must + examine the result of the `.Future` object, or the exception + may go unnoticed by your code. This means yielding the function + if called from another coroutine, using something like + `.IOLoop.run_sync` for top-level calls, or passing the `.Future` + to `.IOLoop.add_future`. + + .. versionchanged:: 6.0 + + The ``callback`` argument was removed. Use the returned + awaitable object instead. 
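+
+    A minimal usage sketch (``fetch_and_parse`` and ``parse`` are
+    hypothetical names, shown only for illustration)::
+
+        @gen.coroutine
+        def fetch_and_parse(url):
+            response = yield AsyncHTTPClient().fetch(url)
+            raise gen.Return(parse(response.body))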
+ + """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): + # type: (*Any, **Any) -> Future[_T] + # This function is type-annotated with a comment to work around + # https://bitbucket.org/pypy/pypy/issues/2868/segfault-with-args-type-annotation-in + future = _create_future() + if contextvars is not None: + ctx_run = contextvars.copy_context().run # type: Callable + else: + ctx_run = _fake_ctx_run + try: + result = ctx_run(func, *args, **kwargs) + except (Return, StopIteration) as e: + result = _value_from_stopiteration(e) + except Exception: + future_set_exc_info(future, sys.exc_info()) + try: + return future + finally: + # Avoid circular references + future = None # type: ignore + else: + if isinstance(result, Generator): + # Inline the first iteration of Runner.run. This lets us + # avoid the cost of creating a Runner when the coroutine + # never actually yields, which in turn allows us to + # use "optional" coroutines in critical path code without + # performance penalty for the synchronous case. + try: + yielded = ctx_run(next, result) + except (StopIteration, Return) as e: + future_set_result_unless_cancelled( + future, _value_from_stopiteration(e) + ) + except Exception: + future_set_exc_info(future, sys.exc_info()) + else: + # Provide strong references to Runner objects as long + # as their result future objects also have strong + # references (typically from the parent coroutine's + # Runner). This keeps the coroutine's Runner alive. + # We do this by exploiting the public API + # add_done_callback() instead of putting a private + # attribute on the Future. + # (GitHub issues #1769, #2229). + runner = Runner(ctx_run, result, future, yielded) + future.add_done_callback(lambda _: runner) + yielded = None + try: + return future + finally: + # Subtle memory optimization: if next() raised an exception, + # the future's exc_info contains a traceback which + # includes this stack frame. This creates a cycle, + # which will be collected at the next full GC but has + # been shown to greatly increase memory usage of + # benchmarks (relative to the refcount-based scheme + # used in the absence of cycles). We can avoid the + # cycle by clearing the local variable after we return it. + future = None # type: ignore + future_set_result_unless_cancelled(future, result) + return future + + wrapper.__wrapped__ = func # type: ignore + wrapper.__tornado_coroutine__ = True # type: ignore + return wrapper + + +def is_coroutine_function(func: Any) -> bool: + """Return whether *func* is a coroutine function, i.e. a function + wrapped with `~.gen.coroutine`. + + .. versionadded:: 4.5 + """ + return getattr(func, "__tornado_coroutine__", False) + + +class Return(Exception): + """Special exception to return a value from a `coroutine`. + + If this exception is raised, its value argument is used as the + result of the coroutine:: + + @gen.coroutine + def fetch_json(url): + response = yield AsyncHTTPClient().fetch(url) + raise gen.Return(json_decode(response.body)) + + In Python 3.3, this exception is no longer necessary: the ``return`` + statement can be used directly to return a value (previously + ``yield`` and ``return`` with a value could not be combined in the + same function). + + By analogy with the return statement, the value argument is optional, + but it is never necessary to ``raise gen.Return()``. The ``return`` + statement can be used with no arguments instead. 
+ """ + + def __init__(self, value: Any = None) -> None: + super().__init__() + self.value = value + # Cython recognizes subclasses of StopIteration with a .args tuple. + self.args = (value,) + + +class WaitIterator(object): + """Provides an iterator to yield the results of awaitables as they finish. + + Yielding a set of awaitables like this: + + ``results = yield [awaitable1, awaitable2]`` + + pauses the coroutine until both ``awaitable1`` and ``awaitable2`` + return, and then restarts the coroutine with the results of both + awaitables. If either awaitable raises an exception, the + expression will raise that exception and all the results will be + lost. + + If you need to get the result of each awaitable as soon as possible, + or if you need the result of some awaitables even if others produce + errors, you can use ``WaitIterator``:: + + wait_iterator = gen.WaitIterator(awaitable1, awaitable2) + while not wait_iterator.done(): + try: + result = yield wait_iterator.next() + except Exception as e: + print("Error {} from {}".format(e, wait_iterator.current_future)) + else: + print("Result {} received from {} at {}".format( + result, wait_iterator.current_future, + wait_iterator.current_index)) + + Because results are returned as soon as they are available the + output from the iterator *will not be in the same order as the + input arguments*. If you need to know which future produced the + current result, you can use the attributes + ``WaitIterator.current_future``, or ``WaitIterator.current_index`` + to get the index of the awaitable from the input list. (if keyword + arguments were used in the construction of the `WaitIterator`, + ``current_index`` will use the corresponding keyword). + + On Python 3.5, `WaitIterator` implements the async iterator + protocol, so it can be used with the ``async for`` statement (note + that in this version the entire iteration is aborted if any value + raises an exception, while the previous example can continue past + individual errors):: + + async for result in gen.WaitIterator(future1, future2): + print("Result {} received from {} at {}".format( + result, wait_iterator.current_future, + wait_iterator.current_index)) + + .. versionadded:: 4.1 + + .. versionchanged:: 4.3 + Added ``async for`` support in Python 3.5. + + """ + + _unfinished = {} # type: Dict[Future, Union[int, str]] + + def __init__(self, *args: Future, **kwargs: Future) -> None: + if args and kwargs: + raise ValueError("You must provide args or kwargs, not both") + + if kwargs: + self._unfinished = dict((f, k) for (k, f) in kwargs.items()) + futures = list(kwargs.values()) # type: Sequence[Future] + else: + self._unfinished = dict((f, i) for (i, f) in enumerate(args)) + futures = args + + self._finished = collections.deque() # type: Deque[Future] + self.current_index = None # type: Optional[Union[str, int]] + self.current_future = None # type: Optional[Future] + self._running_future = None # type: Optional[Future] + + for future in futures: + future_add_done_callback(future, self._done_callback) + + def done(self) -> bool: + """Returns True if this iterator has no more results.""" + if self._finished or self._unfinished: + return False + # Clear the 'current' values when iteration is done. + self.current_index = self.current_future = None + return True + + def next(self) -> Future: + """Returns a `.Future` that will yield the next available result. + + Note that this `.Future` will not be the same object as any of + the inputs. 
+        """
+        self._running_future = Future()
+
+        if self._finished:
+            return self._return_result(self._finished.popleft())
+
+        return self._running_future
+
+    def _done_callback(self, done: Future) -> None:
+        if self._running_future and not self._running_future.done():
+            self._return_result(done)
+        else:
+            self._finished.append(done)
+
+    def _return_result(self, done: Future) -> Future:
+        """Set the returned future's state to that of the future we yielded,
+        and set the current future for the iterator.
+        """
+        if self._running_future is None:
+            raise Exception("no future is running")
+        chain_future(done, self._running_future)
+
+        res = self._running_future
+        self._running_future = None
+        self.current_future = done
+        self.current_index = self._unfinished.pop(done)
+
+        return res
+
+    def __aiter__(self) -> typing.AsyncIterator:
+        return self
+
+    def __anext__(self) -> Future:
+        if self.done():
+            # Lookup by name to silence pyflakes on older versions.
+            raise getattr(builtins, "StopAsyncIteration")()
+        return self.next()
+
+
+def multi(
+    children: Union[List[_Yieldable], Dict[Any, _Yieldable]],
+    quiet_exceptions: "Union[Type[Exception], Tuple[Type[Exception], ...]]" = (),
+) -> "Union[Future[List], Future[Dict]]":
+    """Runs multiple asynchronous operations in parallel.
+
+    ``children`` may either be a list or a dict whose values are
+    yieldable objects. ``multi()`` returns a new yieldable
+    object that resolves to a parallel structure containing their
+    results. If ``children`` is a list, the result is a list of
+    results in the same order; if it is a dict, the result is a dict
+    with the same keys.
+
+    That is, ``results = yield multi(list_of_futures)`` is equivalent
+    to::
+
+        results = []
+        for future in list_of_futures:
+            results.append(yield future)
+
+    If any children raise exceptions, ``multi()`` will raise the first
+    one. All others will be logged, unless they are of types
+    contained in the ``quiet_exceptions`` argument.
+
+    In a ``yield``-based coroutine, it is not normally necessary to
+    call this function directly, since the coroutine runner will
+    do it automatically when a list or dict is yielded. However,
+    it is necessary in ``await``-based coroutines, or to pass
+    the ``quiet_exceptions`` argument.
+
+    This function is available under the names ``multi()`` and ``Multi()``
+    for historical reasons.
+
+    Cancelling a `.Future` returned by ``multi()`` does not cancel its
+    children. `asyncio.gather` is similar to ``multi()``, but it does
+    cancel its children.
+
+    .. versionchanged:: 4.2
+       If multiple yieldables fail, any exceptions after the first
+       (which is raised) will be logged. Added the ``quiet_exceptions``
+       argument to suppress this logging for selected exception types.
+
+    .. versionchanged:: 4.3
+       Replaced the class ``Multi`` and the function ``multi_future``
+       with a unified function ``multi``. Added support for yieldables
+       other than ``YieldPoint`` and `.Future`.
+
+    """
+    return multi_future(children, quiet_exceptions=quiet_exceptions)
+
+
+Multi = multi
+
+
+def multi_future(
+    children: Union[List[_Yieldable], Dict[Any, _Yieldable]],
+    quiet_exceptions: "Union[Type[Exception], Tuple[Type[Exception], ...]]" = (),
+) -> "Union[Future[List], Future[Dict]]":
+    """Wait for multiple asynchronous futures in parallel.
+
+    Since Tornado 6.0, this function is exactly the same as `multi`.
+
+    .. versionadded:: 4.0
+
+    .. versionchanged:: 4.2
+       If multiple ``Futures`` fail, any exceptions after the first (which is
+       raised) will be logged.
Added the ``quiet_exceptions`` + argument to suppress this logging for selected exception types. + + .. deprecated:: 4.3 + Use `multi` instead. + """ + if isinstance(children, dict): + keys = list(children.keys()) # type: Optional[List] + children_seq = children.values() # type: Iterable + else: + keys = None + children_seq = children + children_futs = list(map(convert_yielded, children_seq)) + assert all(is_future(i) or isinstance(i, _NullFuture) for i in children_futs) + unfinished_children = set(children_futs) + + future = _create_future() + if not children_futs: + future_set_result_unless_cancelled(future, {} if keys is not None else []) + + def callback(fut: Future) -> None: + unfinished_children.remove(fut) + if not unfinished_children: + result_list = [] + for f in children_futs: + try: + result_list.append(f.result()) + except Exception as e: + if future.done(): + if not isinstance(e, quiet_exceptions): + app_log.error( + "Multiple exceptions in yield list", exc_info=True + ) + else: + future_set_exc_info(future, sys.exc_info()) + if not future.done(): + if keys is not None: + future_set_result_unless_cancelled( + future, dict(zip(keys, result_list)) + ) + else: + future_set_result_unless_cancelled(future, result_list) + + listening = set() # type: Set[Future] + for f in children_futs: + if f not in listening: + listening.add(f) + future_add_done_callback(f, callback) + return future + + +def maybe_future(x: Any) -> Future: + """Converts ``x`` into a `.Future`. + + If ``x`` is already a `.Future`, it is simply returned; otherwise + it is wrapped in a new `.Future`. This is suitable for use as + ``result = yield gen.maybe_future(f())`` when you don't know whether + ``f()`` returns a `.Future` or not. + + .. deprecated:: 4.3 + This function only handles ``Futures``, not other yieldable objects. + Instead of `maybe_future`, check for the non-future result types + you expect (often just ``None``), and ``yield`` anything unknown. + """ + if is_future(x): + return x + else: + fut = _create_future() + fut.set_result(x) + return fut + + +def with_timeout( + timeout: Union[float, datetime.timedelta], + future: _Yieldable, + quiet_exceptions: "Union[Type[Exception], Tuple[Type[Exception], ...]]" = (), +) -> Future: + """Wraps a `.Future` (or other yieldable object) in a timeout. + + Raises `tornado.util.TimeoutError` if the input future does not + complete before ``timeout``, which may be specified in any form + allowed by `.IOLoop.add_timeout` (i.e. a `datetime.timedelta` or + an absolute time relative to `.IOLoop.time`) + + If the wrapped `.Future` fails after it has timed out, the exception + will be logged unless it is either of a type contained in + ``quiet_exceptions`` (which may be an exception type or a sequence of + types), or an ``asyncio.CancelledError``. + + The wrapped `.Future` is not canceled when the timeout expires, + permitting it to be reused. `asyncio.wait_for` is similar to this + function but it does cancel the wrapped `.Future` on timeout. + + .. versionadded:: 4.0 + + .. versionchanged:: 4.1 + Added the ``quiet_exceptions`` argument and the logging of unhandled + exceptions. + + .. versionchanged:: 4.4 + Added support for yieldable objects other than `.Future`. + + .. versionchanged:: 6.0.3 + ``asyncio.CancelledError`` is now always considered "quiet". + + .. versionchanged:: 6.2 + ``tornado.util.TimeoutError`` is now an alias to ``asyncio.TimeoutError``. 
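+
+    For example, a coroutine can convert a hang into a catchable error
+    (``fetch_url()`` here is a stand-in for any coroutine or `.Future`)::
+
+        try:
+            result = await gen.with_timeout(
+                datetime.timedelta(seconds=5), fetch_url()
+            )
+        except gen.TimeoutError:
+            result = None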
+ + """ + # It's tempting to optimize this by cancelling the input future on timeout + # instead of creating a new one, but A) we can't know if we are the only + # one waiting on the input future, so cancelling it might disrupt other + # callers and B) concurrent futures can only be cancelled while they are + # in the queue, so cancellation cannot reliably bound our waiting time. + future_converted = convert_yielded(future) + result = _create_future() + chain_future(future_converted, result) + io_loop = IOLoop.current() + + def error_callback(future: Future) -> None: + try: + future.result() + except asyncio.CancelledError: + pass + except Exception as e: + if not isinstance(e, quiet_exceptions): + app_log.error( + "Exception in Future %r after timeout", future, exc_info=True + ) + + def timeout_callback() -> None: + if not result.done(): + result.set_exception(TimeoutError("Timeout")) + # In case the wrapped future goes on to fail, log it. + future_add_done_callback(future_converted, error_callback) + + timeout_handle = io_loop.add_timeout(timeout, timeout_callback) + if isinstance(future_converted, Future): + # We know this future will resolve on the IOLoop, so we don't + # need the extra thread-safety of IOLoop.add_future (and we also + # don't care about StackContext here. + future_add_done_callback( + future_converted, lambda future: io_loop.remove_timeout(timeout_handle) + ) + else: + # concurrent.futures.Futures may resolve on any thread, so we + # need to route them back to the IOLoop. + io_loop.add_future( + future_converted, lambda future: io_loop.remove_timeout(timeout_handle) + ) + return result + + +def sleep(duration: float) -> "Future[None]": + """Return a `.Future` that resolves after the given number of seconds. + + When used with ``yield`` in a coroutine, this is a non-blocking + analogue to `time.sleep` (which should not be used in coroutines + because it is blocking):: + + yield gen.sleep(0.5) + + Note that calling this function on its own does nothing; you must + wait on the `.Future` it returns (usually by yielding it). + + .. versionadded:: 4.1 + """ + f = _create_future() + IOLoop.current().call_later( + duration, lambda: future_set_result_unless_cancelled(f, None) + ) + return f + + +class _NullFuture(object): + """_NullFuture resembles a Future that finished with a result of None. + + It's not actually a `Future` to avoid depending on a particular event loop. + Handled as a special case in the coroutine runner. + + We lie and tell the type checker that a _NullFuture is a Future so + we don't have to leak _NullFuture into lots of public APIs. But + this means that the type checker can't warn us when we're passing + a _NullFuture into a code path that doesn't understand what to do + with it. + """ + + def result(self) -> None: + return None + + def done(self) -> bool: + return True + + +# _null_future is used as a dummy value in the coroutine runner. It differs +# from moment in that moment always adds a delay of one IOLoop iteration +# while _null_future is processed as soon as possible. +_null_future = typing.cast(Future, _NullFuture()) + +moment = typing.cast(Future, _NullFuture()) +moment.__doc__ = """A special object which may be yielded to allow the IOLoop to run for +one iteration. + +This is not needed in normal use but it can be helpful in long-running +coroutines that are likely to yield Futures that are ready instantly. + +Usage: ``yield gen.moment`` + +In native coroutines, the equivalent of ``yield gen.moment`` is +``await asyncio.sleep(0)``. + +.. 
versionadded:: 4.0 + +.. deprecated:: 4.5 + ``yield None`` (or ``yield`` with no argument) is now equivalent to + ``yield gen.moment``. +""" + + +class Runner(object): + """Internal implementation of `tornado.gen.coroutine`. + + Maintains information about pending callbacks and their results. + + The results of the generator are stored in ``result_future`` (a + `.Future`) + """ + + def __init__( + self, + ctx_run: Callable, + gen: "Generator[_Yieldable, Any, _T]", + result_future: "Future[_T]", + first_yielded: _Yieldable, + ) -> None: + self.ctx_run = ctx_run + self.gen = gen + self.result_future = result_future + self.future = _null_future # type: Union[None, Future] + self.running = False + self.finished = False + self.io_loop = IOLoop.current() + if self.ctx_run(self.handle_yield, first_yielded): + gen = result_future = first_yielded = None # type: ignore + self.ctx_run(self.run) + + def run(self) -> None: + """Starts or resumes the generator, running until it reaches a + yield point that is not ready. + """ + if self.running or self.finished: + return + try: + self.running = True + while True: + future = self.future + if future is None: + raise Exception("No pending future") + if not future.done(): + return + self.future = None + try: + try: + value = future.result() + except Exception as e: + # Save the exception for later. It's important that + # gen.throw() not be called inside this try/except block + # because that makes sys.exc_info behave unexpectedly. + exc: Optional[Exception] = e + else: + exc = None + finally: + future = None + + if exc is not None: + try: + yielded = self.gen.throw(exc) + finally: + # Break up a circular reference for faster GC on + # CPython. + del exc + else: + yielded = self.gen.send(value) + + except (StopIteration, Return) as e: + self.finished = True + self.future = _null_future + future_set_result_unless_cancelled( + self.result_future, _value_from_stopiteration(e) + ) + self.result_future = None # type: ignore + return + except Exception: + self.finished = True + self.future = _null_future + future_set_exc_info(self.result_future, sys.exc_info()) + self.result_future = None # type: ignore + return + if not self.handle_yield(yielded): + return + yielded = None + finally: + self.running = False + + def handle_yield(self, yielded: _Yieldable) -> bool: + try: + self.future = convert_yielded(yielded) + except BadYieldError: + self.future = Future() + future_set_exc_info(self.future, sys.exc_info()) + + if self.future is moment: + self.io_loop.add_callback(self.ctx_run, self.run) + return False + elif self.future is None: + raise Exception("no pending future") + elif not self.future.done(): + + def inner(f: Any) -> None: + # Break a reference cycle to speed GC. + f = None # noqa: F841 + self.ctx_run(self.run) + + self.io_loop.add_future(self.future, inner) + return False + return True + + def handle_exception( + self, typ: Type[Exception], value: Exception, tb: types.TracebackType + ) -> bool: + if not self.running and not self.finished: + self.future = Future() + future_set_exc_info(self.future, (typ, value, tb)) + self.ctx_run(self.run) + return True + else: + return False + + +# Convert Awaitables into Futures. +try: + _wrap_awaitable = asyncio.ensure_future +except AttributeError: + # asyncio.ensure_future was introduced in Python 3.4.4, but + # Debian jessie still ships with 3.4.2 so try the old name. + _wrap_awaitable = getattr(asyncio, "async") + + +def convert_yielded(yielded: _Yieldable) -> Future: + """Convert a yielded object into a `.Future`. 
+ + The default implementation accepts lists, dictionaries, and + Futures. This has the side effect of starting any coroutines that + did not start themselves, similar to `asyncio.ensure_future`. + + If the `~functools.singledispatch` library is available, this function + may be extended to support additional types. For example:: + + @convert_yielded.register(asyncio.Future) + def _(asyncio_future): + return tornado.platform.asyncio.to_tornado_future(asyncio_future) + + .. versionadded:: 4.1 + + """ + if yielded is None or yielded is moment: + return moment + elif yielded is _null_future: + return _null_future + elif isinstance(yielded, (list, dict)): + return multi(yielded) # type: ignore + elif is_future(yielded): + return typing.cast(Future, yielded) + elif isawaitable(yielded): + return _wrap_awaitable(yielded) # type: ignore + else: + raise BadYieldError("yielded unknown object %r" % (yielded,)) + + +convert_yielded = singledispatch(convert_yielded) diff --git a/venv/lib/python3.9/site-packages/tornado/http1connection.py b/venv/lib/python3.9/site-packages/tornado/http1connection.py new file mode 100644 index 00000000..5ca91688 --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/http1connection.py @@ -0,0 +1,844 @@ +# +# Copyright 2014 Facebook +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Client and server implementations of HTTP/1.x. + +.. versionadded:: 4.0 +""" + +import asyncio +import logging +import re +import types + +from tornado.concurrent import ( + Future, + future_add_done_callback, + future_set_result_unless_cancelled, +) +from tornado.escape import native_str, utf8 +from tornado import gen +from tornado import httputil +from tornado import iostream +from tornado.log import gen_log, app_log +from tornado.util import GzipDecompressor + + +from typing import cast, Optional, Type, Awaitable, Callable, Union, Tuple + + +class _QuietException(Exception): + def __init__(self) -> None: + pass + + +class _ExceptionLoggingContext(object): + """Used with the ``with`` statement when calling delegate methods to + log any exceptions with the given logger. 
+    Any exceptions caught are converted to ``_QuietException``.
+    """
+
+    def __init__(self, logger: logging.Logger) -> None:
+        self.logger = logger
+
+    def __enter__(self) -> None:
+        pass
+
+    def __exit__(
+        self,
+        typ: "Optional[Type[BaseException]]",
+        value: Optional[BaseException],
+        tb: types.TracebackType,
+    ) -> None:
+        if value is not None:
+            assert typ is not None
+            self.logger.error("Uncaught exception", exc_info=(typ, value, tb))
+            raise _QuietException
+
+
+class HTTP1ConnectionParameters(object):
+    """Parameters for `.HTTP1Connection` and `.HTTP1ServerConnection`."""
+
+    def __init__(
+        self,
+        no_keep_alive: bool = False,
+        chunk_size: Optional[int] = None,
+        max_header_size: Optional[int] = None,
+        header_timeout: Optional[float] = None,
+        max_body_size: Optional[int] = None,
+        body_timeout: Optional[float] = None,
+        decompress: bool = False,
+    ) -> None:
+        """
+        :arg bool no_keep_alive: If true, always close the connection after
+            one request.
+        :arg int chunk_size: how much data to read into memory at once
+        :arg int max_header_size: maximum amount of data for HTTP headers
+        :arg float header_timeout: how long to wait for all headers (seconds)
+        :arg int max_body_size: maximum amount of data for body
+        :arg float body_timeout: how long to wait while reading body (seconds)
+        :arg bool decompress: if true, decode incoming
+            ``Content-Encoding: gzip``
+        """
+        self.no_keep_alive = no_keep_alive
+        self.chunk_size = chunk_size or 65536
+        self.max_header_size = max_header_size or 65536
+        self.header_timeout = header_timeout
+        self.max_body_size = max_body_size
+        self.body_timeout = body_timeout
+        self.decompress = decompress
+
+
+class HTTP1Connection(httputil.HTTPConnection):
+    """Implements the HTTP/1.x protocol.
+
+    This class can be used on its own for clients, or via
+    `HTTP1ServerConnection` for servers.
+    """
+
+    def __init__(
+        self,
+        stream: iostream.IOStream,
+        is_client: bool,
+        params: Optional[HTTP1ConnectionParameters] = None,
+        context: Optional[object] = None,
+    ) -> None:
+        """
+        :arg stream: an `.IOStream`
+        :arg bool is_client: client or server
+        :arg params: a `.HTTP1ConnectionParameters` instance or ``None``
+        :arg context: an opaque application-defined object that can be accessed
+            as ``connection.context``.
+        """
+        self.is_client = is_client
+        self.stream = stream
+        if params is None:
+            params = HTTP1ConnectionParameters()
+        self.params = params
+        self.context = context
+        self.no_keep_alive = params.no_keep_alive
+        # The body limits can be altered by the delegate, so save them
+        # here instead of just referencing self.params later.
+        self._max_body_size = (
+            self.params.max_body_size
+            if self.params.max_body_size is not None
+            else self.stream.max_buffer_size
+        )
+        self._body_timeout = self.params.body_timeout
+        # _write_finished is set to True when finish() has been called,
+        # i.e. there will be no more data sent. Data may still be in the
+        # stream's write buffer.
+        self._write_finished = False
+        # True when we have read the entire incoming body.
+        self._read_finished = False
+        # _finish_future resolves when all data has been written and flushed
+        # to the IOStream.
+        self._finish_future = Future()  # type: Future[None]
+        # If true, the connection should be closed after this request
+        # (after the response has been written in the server side,
+        # and after it has been read in the client)
+        self._disconnect_on_finish = False
+        self._clear_callbacks()
+        # Save the start lines after we read or write them; they
+        # affect later processing (e.g.
304 responses and HEAD methods + # have content-length but no bodies) + self._request_start_line = None # type: Optional[httputil.RequestStartLine] + self._response_start_line = None # type: Optional[httputil.ResponseStartLine] + self._request_headers = None # type: Optional[httputil.HTTPHeaders] + # True if we are writing output with chunked encoding. + self._chunking_output = False + # While reading a body with a content-length, this is the + # amount left to read. + self._expected_content_remaining = None # type: Optional[int] + # A Future for our outgoing writes, returned by IOStream.write. + self._pending_write = None # type: Optional[Future[None]] + + def read_response(self, delegate: httputil.HTTPMessageDelegate) -> Awaitable[bool]: + """Read a single HTTP response. + + Typical client-mode usage is to write a request using `write_headers`, + `write`, and `finish`, and then call ``read_response``. + + :arg delegate: a `.HTTPMessageDelegate` + + Returns a `.Future` that resolves to a bool after the full response has + been read. The result is true if the stream is still open. + """ + if self.params.decompress: + delegate = _GzipMessageDelegate(delegate, self.params.chunk_size) + return self._read_message(delegate) + + async def _read_message(self, delegate: httputil.HTTPMessageDelegate) -> bool: + need_delegate_close = False + try: + header_future = self.stream.read_until_regex( + b"\r?\n\r?\n", max_bytes=self.params.max_header_size + ) + if self.params.header_timeout is None: + header_data = await header_future + else: + try: + header_data = await gen.with_timeout( + self.stream.io_loop.time() + self.params.header_timeout, + header_future, + quiet_exceptions=iostream.StreamClosedError, + ) + except gen.TimeoutError: + self.close() + return False + start_line_str, headers = self._parse_headers(header_data) + if self.is_client: + resp_start_line = httputil.parse_response_start_line(start_line_str) + self._response_start_line = resp_start_line + start_line = ( + resp_start_line + ) # type: Union[httputil.RequestStartLine, httputil.ResponseStartLine] + # TODO: this will need to change to support client-side keepalive + self._disconnect_on_finish = False + else: + req_start_line = httputil.parse_request_start_line(start_line_str) + self._request_start_line = req_start_line + self._request_headers = headers + start_line = req_start_line + self._disconnect_on_finish = not self._can_keep_alive( + req_start_line, headers + ) + need_delegate_close = True + with _ExceptionLoggingContext(app_log): + header_recv_future = delegate.headers_received(start_line, headers) + if header_recv_future is not None: + await header_recv_future + if self.stream is None: + # We've been detached. + need_delegate_close = False + return False + skip_body = False + if self.is_client: + assert isinstance(start_line, httputil.ResponseStartLine) + if ( + self._request_start_line is not None + and self._request_start_line.method == "HEAD" + ): + skip_body = True + code = start_line.code + if code == 304: + # 304 responses may include the content-length header + # but do not actually have a body. + # http://tools.ietf.org/html/rfc7230#section-3.3 + skip_body = True + if 100 <= code < 200: + # 1xx responses should never indicate the presence of + # a body. + if "Content-Length" in headers or "Transfer-Encoding" in headers: + raise httputil.HTTPInputError( + "Response code %d cannot have body" % code + ) + # TODO: client delegates will get headers_received twice + # in the case of a 100-continue. Document or change? 
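+                    # An informational 1xx response is always followed by the
+                    # real response on the same connection, so recurse to read
+                    # the next message.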
+                    await self._read_message(delegate)
+            else:
+                if headers.get("Expect") == "100-continue" and not self._write_finished:
+                    self.stream.write(b"HTTP/1.1 100 (Continue)\r\n\r\n")
+            if not skip_body:
+                body_future = self._read_body(
+                    resp_start_line.code if self.is_client else 0, headers, delegate
+                )
+                if body_future is not None:
+                    if self._body_timeout is None:
+                        await body_future
+                    else:
+                        try:
+                            await gen.with_timeout(
+                                self.stream.io_loop.time() + self._body_timeout,
+                                body_future,
+                                quiet_exceptions=iostream.StreamClosedError,
+                            )
+                        except gen.TimeoutError:
+                            gen_log.info("Timeout reading body from %s", self.context)
+                            self.stream.close()
+                            return False
+            self._read_finished = True
+            if not self._write_finished or self.is_client:
+                need_delegate_close = False
+                with _ExceptionLoggingContext(app_log):
+                    delegate.finish()
+            # If we're waiting for the application to produce an asynchronous
+            # response, and we're not detached, register a close callback
+            # on the stream (we didn't need one while we were reading)
+            if (
+                not self._finish_future.done()
+                and self.stream is not None
+                and not self.stream.closed()
+            ):
+                self.stream.set_close_callback(self._on_connection_close)
+                await self._finish_future
+            if self.is_client and self._disconnect_on_finish:
+                self.close()
+            if self.stream is None:
+                return False
+        except httputil.HTTPInputError as e:
+            gen_log.info("Malformed HTTP message from %s: %s", self.context, e)
+            if not self.is_client:
+                await self.stream.write(b"HTTP/1.1 400 Bad Request\r\n\r\n")
+            self.close()
+            return False
+        finally:
+            if need_delegate_close:
+                with _ExceptionLoggingContext(app_log):
+                    delegate.on_connection_close()
+            header_future = None  # type: ignore
+            self._clear_callbacks()
+        return True
+
+    def _clear_callbacks(self) -> None:
+        """Clears the callback attributes.
+
+        This allows the request handler to be garbage collected more
+        quickly in CPython by breaking up reference cycles.
+        """
+        self._write_callback = None
+        self._write_future = None  # type: Optional[Future[None]]
+        self._close_callback = None  # type: Optional[Callable[[], None]]
+        if self.stream is not None:
+            self.stream.set_close_callback(None)
+
+    def set_close_callback(self, callback: Optional[Callable[[], None]]) -> None:
+        """Sets a callback that will be run when the connection is closed.
+
+        Note that this callback is slightly different from
+        `.HTTPMessageDelegate.on_connection_close`: The
+        `.HTTPMessageDelegate` method is called when the connection is
+        closed while receiving a message. This callback is used when
+        there is not an active delegate (for example, on the server
+        side this callback is used if the client closes the connection
+        after sending its request but before receiving all of the
+        response).
+        """
+        self._close_callback = callback
+
+    def _on_connection_close(self) -> None:
+        # Note that this callback is only registered on the IOStream
+        # when we have finished reading the request and are waiting for
+        # the application to produce its response.
+        if self._close_callback is not None:
+            callback = self._close_callback
+            self._close_callback = None
+            callback()
+        if not self._finish_future.done():
+            future_set_result_unless_cancelled(self._finish_future, None)
+        self._clear_callbacks()
+
+    def close(self) -> None:
+        if self.stream is not None:
+            self.stream.close()
+        self._clear_callbacks()
+        if not self._finish_future.done():
+            future_set_result_unless_cancelled(self._finish_future, None)
+
+    def detach(self) -> iostream.IOStream:
+        """Take control of the underlying stream.
+ + Returns the underlying `.IOStream` object and stops all further + HTTP processing. May only be called during + `.HTTPMessageDelegate.headers_received`. Intended for implementing + protocols like websockets that tunnel over an HTTP handshake. + """ + self._clear_callbacks() + stream = self.stream + self.stream = None # type: ignore + if not self._finish_future.done(): + future_set_result_unless_cancelled(self._finish_future, None) + return stream + + def set_body_timeout(self, timeout: float) -> None: + """Sets the body timeout for a single request. + + Overrides the value from `.HTTP1ConnectionParameters`. + """ + self._body_timeout = timeout + + def set_max_body_size(self, max_body_size: int) -> None: + """Sets the body size limit for a single request. + + Overrides the value from `.HTTP1ConnectionParameters`. + """ + self._max_body_size = max_body_size + + def write_headers( + self, + start_line: Union[httputil.RequestStartLine, httputil.ResponseStartLine], + headers: httputil.HTTPHeaders, + chunk: Optional[bytes] = None, + ) -> "Future[None]": + """Implements `.HTTPConnection.write_headers`.""" + lines = [] + if self.is_client: + assert isinstance(start_line, httputil.RequestStartLine) + self._request_start_line = start_line + lines.append(utf8("%s %s HTTP/1.1" % (start_line[0], start_line[1]))) + # Client requests with a non-empty body must have either a + # Content-Length or a Transfer-Encoding. + self._chunking_output = ( + start_line.method in ("POST", "PUT", "PATCH") + and "Content-Length" not in headers + and ( + "Transfer-Encoding" not in headers + or headers["Transfer-Encoding"] == "chunked" + ) + ) + else: + assert isinstance(start_line, httputil.ResponseStartLine) + assert self._request_start_line is not None + assert self._request_headers is not None + self._response_start_line = start_line + lines.append(utf8("HTTP/1.1 %d %s" % (start_line[1], start_line[2]))) + self._chunking_output = ( + # TODO: should this use + # self._request_start_line.version or + # start_line.version? + self._request_start_line.version == "HTTP/1.1" + # Omit payload header field for HEAD request. + and self._request_start_line.method != "HEAD" + # 1xx, 204 and 304 responses have no body (not even a zero-length + # body), and so should not have either Content-Length or + # Transfer-Encoding headers. + and start_line.code not in (204, 304) + and (start_line.code < 100 or start_line.code >= 200) + # No need to chunk the output if a Content-Length is specified. + and "Content-Length" not in headers + # Applications are discouraged from touching Transfer-Encoding, + # but if they do, leave it alone. + and "Transfer-Encoding" not in headers + ) + # If connection to a 1.1 client will be closed, inform client + if ( + self._request_start_line.version == "HTTP/1.1" + and self._disconnect_on_finish + ): + headers["Connection"] = "close" + # If a 1.0 client asked for keep-alive, add the header. 
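+        # (HTTP/1.0 connections close by default; the client opts in with
+        # this header and the server confirms by echoing it back.)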
+ if ( + self._request_start_line.version == "HTTP/1.0" + and self._request_headers.get("Connection", "").lower() == "keep-alive" + ): + headers["Connection"] = "Keep-Alive" + if self._chunking_output: + headers["Transfer-Encoding"] = "chunked" + if not self.is_client and ( + self._request_start_line.method == "HEAD" + or cast(httputil.ResponseStartLine, start_line).code == 304 + ): + self._expected_content_remaining = 0 + elif "Content-Length" in headers: + self._expected_content_remaining = int(headers["Content-Length"]) + else: + self._expected_content_remaining = None + # TODO: headers are supposed to be of type str, but we still have some + # cases that let bytes slip through. Remove these native_str calls when those + # are fixed. + header_lines = ( + native_str(n) + ": " + native_str(v) for n, v in headers.get_all() + ) + lines.extend(line.encode("latin1") for line in header_lines) + for line in lines: + if b"\n" in line: + raise ValueError("Newline in header: " + repr(line)) + future = None + if self.stream.closed(): + future = self._write_future = Future() + future.set_exception(iostream.StreamClosedError()) + future.exception() + else: + future = self._write_future = Future() + data = b"\r\n".join(lines) + b"\r\n\r\n" + if chunk: + data += self._format_chunk(chunk) + self._pending_write = self.stream.write(data) + future_add_done_callback(self._pending_write, self._on_write_complete) + return future + + def _format_chunk(self, chunk: bytes) -> bytes: + if self._expected_content_remaining is not None: + self._expected_content_remaining -= len(chunk) + if self._expected_content_remaining < 0: + # Close the stream now to stop further framing errors. + self.stream.close() + raise httputil.HTTPOutputError( + "Tried to write more data than Content-Length" + ) + if self._chunking_output and chunk: + # Don't write out empty chunks because that means END-OF-STREAM + # with chunked encoding + return utf8("%x" % len(chunk)) + b"\r\n" + chunk + b"\r\n" + else: + return chunk + + def write(self, chunk: bytes) -> "Future[None]": + """Implements `.HTTPConnection.write`. + + For backwards compatibility it is allowed but deprecated to + skip `write_headers` and instead call `write()` with a + pre-encoded header block. + """ + future = None + if self.stream.closed(): + future = self._write_future = Future() + self._write_future.set_exception(iostream.StreamClosedError()) + self._write_future.exception() + else: + future = self._write_future = Future() + self._pending_write = self.stream.write(self._format_chunk(chunk)) + future_add_done_callback(self._pending_write, self._on_write_complete) + return future + + def finish(self) -> None: + """Implements `.HTTPConnection.finish`.""" + if ( + self._expected_content_remaining is not None + and self._expected_content_remaining != 0 + and not self.stream.closed() + ): + self.stream.close() + raise httputil.HTTPOutputError( + "Tried to write %d bytes less than Content-Length" + % self._expected_content_remaining + ) + if self._chunking_output: + if not self.stream.closed(): + self._pending_write = self.stream.write(b"0\r\n\r\n") + self._pending_write.add_done_callback(self._on_write_complete) + self._write_finished = True + # If the app finished the request while we're still reading, + # divert any remaining data away from the delegate and + # close the connection when we're done sending our response. + # Closing the connection is the only way to avoid reading the + # whole input body. 
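+        # (Keeping the connection alive would require reading and discarding
+        # the rest of the request body before the next request could start.)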
+        if not self._read_finished:
+            self._disconnect_on_finish = True
+            # No more data is coming, so instruct TCP to send any remaining
+            # data immediately instead of waiting for a full packet or ack.
+            self.stream.set_nodelay(True)
+        if self._pending_write is None:
+            self._finish_request(None)
+        else:
+            future_add_done_callback(self._pending_write, self._finish_request)
+
+    def _on_write_complete(self, future: "Future[None]") -> None:
+        exc = future.exception()
+        if exc is not None and not isinstance(exc, iostream.StreamClosedError):
+            future.result()
+        if self._write_callback is not None:
+            callback = self._write_callback
+            self._write_callback = None
+            self.stream.io_loop.add_callback(callback)
+        if self._write_future is not None:
+            future = self._write_future
+            self._write_future = None
+            future_set_result_unless_cancelled(future, None)
+
+    def _can_keep_alive(
+        self, start_line: httputil.RequestStartLine, headers: httputil.HTTPHeaders
+    ) -> bool:
+        if self.params.no_keep_alive:
+            return False
+        connection_header = headers.get("Connection")
+        if connection_header is not None:
+            connection_header = connection_header.lower()
+        if start_line.version == "HTTP/1.1":
+            return connection_header != "close"
+        elif (
+            "Content-Length" in headers
+            or headers.get("Transfer-Encoding", "").lower() == "chunked"
+            or getattr(start_line, "method", None) in ("HEAD", "GET")
+        ):
+            # start_line may be a request or response start line; only
+            # the former has a method attribute.
+            return connection_header == "keep-alive"
+        return False
+
+    def _finish_request(self, future: "Optional[Future[None]]") -> None:
+        self._clear_callbacks()
+        if not self.is_client and self._disconnect_on_finish:
+            self.close()
+            return
+        # Turn Nagle's algorithm back on, leaving the stream in its
+        # default state for the next request.
+        self.stream.set_nodelay(False)
+        if not self._finish_future.done():
+            future_set_result_unless_cancelled(self._finish_future, None)
+
+    def _parse_headers(self, data: bytes) -> Tuple[str, httputil.HTTPHeaders]:
+        # The lstrip removes newlines that some implementations sometimes
+        # insert between messages of a reused connection. Per RFC 7230,
+        # we SHOULD ignore at least one empty line before the request.
+        # http://tools.ietf.org/html/rfc7230#section-3.5
+        data_str = native_str(data.decode("latin1")).lstrip("\r\n")
+        # RFC 7230 section 3.5 allows for both CRLF and bare LF.
+        eol = data_str.find("\n")
+        start_line = data_str[:eol].rstrip("\r")
+        headers = httputil.HTTPHeaders.parse(data_str[eol:])
+        return start_line, headers
+
+    def _read_body(
+        self,
+        code: int,
+        headers: httputil.HTTPHeaders,
+        delegate: httputil.HTTPMessageDelegate,
+    ) -> Optional[Awaitable[None]]:
+        if "Content-Length" in headers:
+            if "Transfer-Encoding" in headers:
+                # Response cannot contain both Content-Length and
+                # Transfer-Encoding headers.
+                # http://tools.ietf.org/html/rfc7230#section-3.3.3
+                raise httputil.HTTPInputError(
+                    "Response with both Transfer-Encoding and Content-Length"
+                )
+            if "," in headers["Content-Length"]:
+                # Proxies sometimes cause Content-Length headers to get
+                # duplicated. If all the values are identical then we can
+                # use them but if they differ it's an error.
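+                # e.g. "Content-Length: 304,304" collapses to "304" below,
+                # while "Content-Length: 304,200" is rejected as malformed.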
+ pieces = re.split(r",\s*", headers["Content-Length"]) + if any(i != pieces[0] for i in pieces): + raise httputil.HTTPInputError( + "Multiple unequal Content-Lengths: %r" + % headers["Content-Length"] + ) + headers["Content-Length"] = pieces[0] + + try: + content_length = int(headers["Content-Length"]) # type: Optional[int] + except ValueError: + # Handles non-integer Content-Length value. + raise httputil.HTTPInputError( + "Only integer Content-Length is allowed: %s" + % headers["Content-Length"] + ) + + if cast(int, content_length) > self._max_body_size: + raise httputil.HTTPInputError("Content-Length too long") + else: + content_length = None + + if code == 204: + # This response code is not allowed to have a non-empty body, + # and has an implicit length of zero instead of read-until-close. + # http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.3 + if "Transfer-Encoding" in headers or content_length not in (None, 0): + raise httputil.HTTPInputError( + "Response with code %d should not have body" % code + ) + content_length = 0 + + if content_length is not None: + return self._read_fixed_body(content_length, delegate) + if headers.get("Transfer-Encoding", "").lower() == "chunked": + return self._read_chunked_body(delegate) + if self.is_client: + return self._read_body_until_close(delegate) + return None + + async def _read_fixed_body( + self, content_length: int, delegate: httputil.HTTPMessageDelegate + ) -> None: + while content_length > 0: + body = await self.stream.read_bytes( + min(self.params.chunk_size, content_length), partial=True + ) + content_length -= len(body) + if not self._write_finished or self.is_client: + with _ExceptionLoggingContext(app_log): + ret = delegate.data_received(body) + if ret is not None: + await ret + + async def _read_chunked_body(self, delegate: httputil.HTTPMessageDelegate) -> None: + # TODO: "chunk extensions" http://tools.ietf.org/html/rfc2616#section-3.6.1 + total_size = 0 + while True: + chunk_len_str = await self.stream.read_until(b"\r\n", max_bytes=64) + chunk_len = int(chunk_len_str.strip(), 16) + if chunk_len == 0: + crlf = await self.stream.read_bytes(2) + if crlf != b"\r\n": + raise httputil.HTTPInputError( + "improperly terminated chunked request" + ) + return + total_size += chunk_len + if total_size > self._max_body_size: + raise httputil.HTTPInputError("chunked body too large") + bytes_to_read = chunk_len + while bytes_to_read: + chunk = await self.stream.read_bytes( + min(bytes_to_read, self.params.chunk_size), partial=True + ) + bytes_to_read -= len(chunk) + if not self._write_finished or self.is_client: + with _ExceptionLoggingContext(app_log): + ret = delegate.data_received(chunk) + if ret is not None: + await ret + # chunk ends with \r\n + crlf = await self.stream.read_bytes(2) + assert crlf == b"\r\n" + + async def _read_body_until_close( + self, delegate: httputil.HTTPMessageDelegate + ) -> None: + body = await self.stream.read_until_close() + if not self._write_finished or self.is_client: + with _ExceptionLoggingContext(app_log): + ret = delegate.data_received(body) + if ret is not None: + await ret + + +class _GzipMessageDelegate(httputil.HTTPMessageDelegate): + """Wraps an `HTTPMessageDelegate` to decode ``Content-Encoding: gzip``.""" + + def __init__(self, delegate: httputil.HTTPMessageDelegate, chunk_size: int) -> None: + self._delegate = delegate + self._chunk_size = chunk_size + self._decompressor = None # type: Optional[GzipDecompressor] + + def headers_received( + self, + start_line: 
Union[httputil.RequestStartLine, httputil.ResponseStartLine], + headers: httputil.HTTPHeaders, + ) -> Optional[Awaitable[None]]: + if headers.get("Content-Encoding", "").lower() == "gzip": + self._decompressor = GzipDecompressor() + # Downstream delegates will only see uncompressed data, + # so rename the content-encoding header. + # (but note that curl_httpclient doesn't do this). + headers.add("X-Consumed-Content-Encoding", headers["Content-Encoding"]) + del headers["Content-Encoding"] + return self._delegate.headers_received(start_line, headers) + + async def data_received(self, chunk: bytes) -> None: + if self._decompressor: + compressed_data = chunk + while compressed_data: + decompressed = self._decompressor.decompress( + compressed_data, self._chunk_size + ) + if decompressed: + ret = self._delegate.data_received(decompressed) + if ret is not None: + await ret + compressed_data = self._decompressor.unconsumed_tail + if compressed_data and not decompressed: + raise httputil.HTTPInputError( + "encountered unconsumed gzip data without making progress" + ) + else: + ret = self._delegate.data_received(chunk) + if ret is not None: + await ret + + def finish(self) -> None: + if self._decompressor is not None: + tail = self._decompressor.flush() + if tail: + # The tail should always be empty: decompress returned + # all that it can in data_received and the only + # purpose of the flush call is to detect errors such + # as truncated input. If we did legitimately get a new + # chunk at this point we'd need to change the + # interface to make finish() a coroutine. + raise ValueError( + "decompressor.flush returned data; possible truncated input" + ) + return self._delegate.finish() + + def on_connection_close(self) -> None: + return self._delegate.on_connection_close() + + +class HTTP1ServerConnection(object): + """An HTTP/1.x server.""" + + def __init__( + self, + stream: iostream.IOStream, + params: Optional[HTTP1ConnectionParameters] = None, + context: Optional[object] = None, + ) -> None: + """ + :arg stream: an `.IOStream` + :arg params: a `.HTTP1ConnectionParameters` or None + :arg context: an opaque application-defined object that is accessible + as ``connection.context`` + """ + self.stream = stream + if params is None: + params = HTTP1ConnectionParameters() + self.params = params + self.context = context + self._serving_future = None # type: Optional[Future[None]] + + async def close(self) -> None: + """Closes the connection. + + Returns a `.Future` that resolves after the serving loop has exited. + """ + self.stream.close() + # Block until the serving loop is done, but ignore any exceptions + # (start_serving is already responsible for logging them). + assert self._serving_future is not None + try: + await self._serving_future + except Exception: + pass + + def start_serving(self, delegate: httputil.HTTPServerConnectionDelegate) -> None: + """Starts serving requests on this connection. + + :arg delegate: a `.HTTPServerConnectionDelegate` + """ + assert isinstance(delegate, httputil.HTTPServerConnectionDelegate) + fut = gen.convert_yielded(self._server_request_loop(delegate)) + self._serving_future = fut + # Register the future on the IOLoop so its errors get logged. 
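+        # (The lambda re-raises any exception via f.result() inside an
+        # IOLoop callback, where it is reported as an uncaught exception.)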
+ self.stream.io_loop.add_future(fut, lambda f: f.result()) + + async def _server_request_loop( + self, delegate: httputil.HTTPServerConnectionDelegate + ) -> None: + try: + while True: + conn = HTTP1Connection(self.stream, False, self.params, self.context) + request_delegate = delegate.start_request(self, conn) + try: + ret = await conn.read_response(request_delegate) + except ( + iostream.StreamClosedError, + iostream.UnsatisfiableReadError, + asyncio.CancelledError, + ): + return + except _QuietException: + # This exception was already logged. + conn.close() + return + except Exception: + gen_log.error("Uncaught exception", exc_info=True) + conn.close() + return + if not ret: + return + await asyncio.sleep(0) + finally: + delegate.on_close(self) diff --git a/venv/lib/python3.9/site-packages/tornado/httpclient.py b/venv/lib/python3.9/site-packages/tornado/httpclient.py new file mode 100644 index 00000000..3011c371 --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/httpclient.py @@ -0,0 +1,790 @@ +"""Blocking and non-blocking HTTP client interfaces. + +This module defines a common interface shared by two implementations, +``simple_httpclient`` and ``curl_httpclient``. Applications may either +instantiate their chosen implementation class directly or use the +`AsyncHTTPClient` class from this module, which selects an implementation +that can be overridden with the `AsyncHTTPClient.configure` method. + +The default implementation is ``simple_httpclient``, and this is expected +to be suitable for most users' needs. However, some applications may wish +to switch to ``curl_httpclient`` for reasons such as the following: + +* ``curl_httpclient`` has some features not found in ``simple_httpclient``, + including support for HTTP proxies and the ability to use a specified + network interface. + +* ``curl_httpclient`` is more likely to be compatible with sites that are + not-quite-compliant with the HTTP spec, or sites that use little-exercised + features of HTTP. + +* ``curl_httpclient`` is faster. + +Note that if you are using ``curl_httpclient``, it is highly +recommended that you use a recent version of ``libcurl`` and +``pycurl``. Currently the minimum supported version of libcurl is +7.22.0, and the minimum version of pycurl is 7.18.2. It is highly +recommended that your ``libcurl`` installation is built with +asynchronous DNS resolver (threaded or c-ares), otherwise you may +encounter various problems with request timeouts (for more +information, see +http://curl.haxx.se/libcurl/c/curl_easy_setopt.html#CURLOPTCONNECTTIMEOUTMS +and comments in curl_httpclient.py). + +To select ``curl_httpclient``, call `AsyncHTTPClient.configure` at startup:: + + AsyncHTTPClient.configure("tornado.curl_httpclient.CurlAsyncHTTPClient") +""" + +import datetime +import functools +from io import BytesIO +import ssl +import time +import weakref + +from tornado.concurrent import ( + Future, + future_set_result_unless_cancelled, + future_set_exception_unless_cancelled, +) +from tornado.escape import utf8, native_str +from tornado import gen, httputil +from tornado.ioloop import IOLoop +from tornado.util import Configurable + +from typing import Type, Any, Union, Dict, Callable, Optional, cast + + +class HTTPClient(object): + """A blocking HTTP client. + + This interface is provided to make it easier to share code between + synchronous and asynchronous applications. Applications that are + running an `.IOLoop` must use `AsyncHTTPClient` instead. 
+
+    Typical usage looks like this::
+
+        http_client = httpclient.HTTPClient()
+        try:
+            response = http_client.fetch("http://www.google.com/")
+            print(response.body)
+        except httpclient.HTTPError as e:
+            # HTTPError is raised for non-200 responses; the response
+            # can be found in e.response.
+            print("Error: " + str(e))
+        except Exception as e:
+            # Other errors are possible, such as IOError.
+            print("Error: " + str(e))
+        http_client.close()
+
+    .. versionchanged:: 5.0
+
+       Due to limitations in `asyncio`, it is no longer possible to
+       use the synchronous ``HTTPClient`` while an `.IOLoop` is running.
+       Use `AsyncHTTPClient` instead.
+
+    """
+
+    def __init__(
+        self,
+        async_client_class: "Optional[Type[AsyncHTTPClient]]" = None,
+        **kwargs: Any
+    ) -> None:
+        # Initialize self._closed at the beginning of the constructor
+        # so that an exception raised here doesn't lead to confusing
+        # failures in __del__.
+        self._closed = True
+        self._io_loop = IOLoop(make_current=False)
+        if async_client_class is None:
+            async_client_class = AsyncHTTPClient
+
+        # Create the client while our IOLoop is "current", without
+        # clobbering the thread's real current IOLoop (if any).
+        async def make_client() -> "AsyncHTTPClient":
+            await gen.sleep(0)
+            assert async_client_class is not None
+            return async_client_class(**kwargs)
+
+        self._async_client = self._io_loop.run_sync(make_client)
+        self._closed = False
+
+    def __del__(self) -> None:
+        self.close()
+
+    def close(self) -> None:
+        """Closes the HTTPClient, freeing any resources used."""
+        if not self._closed:
+            self._async_client.close()
+            self._io_loop.close()
+            self._closed = True
+
+    def fetch(
+        self, request: Union["HTTPRequest", str], **kwargs: Any
+    ) -> "HTTPResponse":
+        """Executes a request, returning an `HTTPResponse`.
+
+        The request may be either a string URL or an `HTTPRequest` object.
+        If it is a string, we construct an `HTTPRequest` using any additional
+        kwargs: ``HTTPRequest(request, **kwargs)``
+
+        If an error occurs during the fetch, we raise an `HTTPError` unless
+        the ``raise_error`` keyword argument is set to False.
+        """
+        response = self._io_loop.run_sync(
+            functools.partial(self._async_client.fetch, request, **kwargs)
+        )
+        return response
+
+
+class AsyncHTTPClient(Configurable):
+    """A non-blocking HTTP client.
+
+    Example usage::
+
+        async def f():
+            http_client = AsyncHTTPClient()
+            try:
+                response = await http_client.fetch("http://www.google.com")
+            except Exception as e:
+                print("Error: %s" % e)
+            else:
+                print(response.body)
+
+    The constructor for this class is magic in several respects: It
+    actually creates an instance of an implementation-specific
+    subclass, and instances are reused as a kind of pseudo-singleton
+    (one per `.IOLoop`). The keyword argument ``force_instance=True``
+    can be used to suppress this singleton behavior. Unless
+    ``force_instance=True`` is used, no arguments should be passed to
+    the `AsyncHTTPClient` constructor. The implementation subclass as
+    well as arguments to its constructor can be set with the static
+    method `configure()`.
+
+    All `AsyncHTTPClient` implementations support a ``defaults``
+    keyword argument, which can be used to set default values for
+    `HTTPRequest` attributes. For example::
+
+        AsyncHTTPClient.configure(
+            None, defaults=dict(user_agent="MyUserAgent"))
+        # or with force_instance:
+        client = AsyncHTTPClient(force_instance=True,
+                                 defaults=dict(user_agent="MyUserAgent"))
+
+    ..
versionchanged:: 5.0 + The ``io_loop`` argument (deprecated since version 4.1) has been removed. + + """ + + _instance_cache = None # type: Dict[IOLoop, AsyncHTTPClient] + + @classmethod + def configurable_base(cls) -> Type[Configurable]: + return AsyncHTTPClient + + @classmethod + def configurable_default(cls) -> Type[Configurable]: + from tornado.simple_httpclient import SimpleAsyncHTTPClient + + return SimpleAsyncHTTPClient + + @classmethod + def _async_clients(cls) -> Dict[IOLoop, "AsyncHTTPClient"]: + attr_name = "_async_client_dict_" + cls.__name__ + if not hasattr(cls, attr_name): + setattr(cls, attr_name, weakref.WeakKeyDictionary()) + return getattr(cls, attr_name) + + def __new__(cls, force_instance: bool = False, **kwargs: Any) -> "AsyncHTTPClient": + io_loop = IOLoop.current() + if force_instance: + instance_cache = None + else: + instance_cache = cls._async_clients() + if instance_cache is not None and io_loop in instance_cache: + return instance_cache[io_loop] + instance = super(AsyncHTTPClient, cls).__new__(cls, **kwargs) # type: ignore + # Make sure the instance knows which cache to remove itself from. + # It can't simply call _async_clients() because we may be in + # __new__(AsyncHTTPClient) but instance.__class__ may be + # SimpleAsyncHTTPClient. + instance._instance_cache = instance_cache + if instance_cache is not None: + instance_cache[instance.io_loop] = instance + return instance + + def initialize(self, defaults: Optional[Dict[str, Any]] = None) -> None: + self.io_loop = IOLoop.current() + self.defaults = dict(HTTPRequest._DEFAULTS) + if defaults is not None: + self.defaults.update(defaults) + self._closed = False + + def close(self) -> None: + """Destroys this HTTP client, freeing any file descriptors used. + + This method is **not needed in normal use** due to the way + that `AsyncHTTPClient` objects are transparently reused. + ``close()`` is generally only necessary when either the + `.IOLoop` is also being closed, or the ``force_instance=True`` + argument was used when creating the `AsyncHTTPClient`. + + No other methods may be called on the `AsyncHTTPClient` after + ``close()``. + + """ + if self._closed: + return + self._closed = True + if self._instance_cache is not None: + cached_val = self._instance_cache.pop(self.io_loop, None) + # If there's an object other than self in the instance + # cache for our IOLoop, something has gotten mixed up. A + # value of None appears to be possible when this is called + # from a destructor (HTTPClient.__del__) as the weakref + # gets cleared before the destructor runs. + if cached_val is not None and cached_val is not self: + raise RuntimeError("inconsistent AsyncHTTPClient cache") + + def fetch( + self, + request: Union[str, "HTTPRequest"], + raise_error: bool = True, + **kwargs: Any + ) -> "Future[HTTPResponse]": + """Executes a request, asynchronously returning an `HTTPResponse`. + + The request may be either a string URL or an `HTTPRequest` object. + If it is a string, we construct an `HTTPRequest` using any additional + kwargs: ``HTTPRequest(request, **kwargs)`` + + This method returns a `.Future` whose result is an + `HTTPResponse`. By default, the ``Future`` will raise an + `HTTPError` if the request returned a non-200 response code + (other errors may also be raised if the server could not be + contacted). Instead, if ``raise_error`` is set to False, the + response will always be returned regardless of the response + code. + + If a ``callback`` is given, it will be invoked with the `HTTPResponse`. 
+ In the callback interface, `HTTPError` is not automatically raised. + Instead, you must check the response's ``error`` attribute or + call its `~HTTPResponse.rethrow` method. + + .. versionchanged:: 6.0 + + The ``callback`` argument was removed. Use the returned + `.Future` instead. + + The ``raise_error=False`` argument only affects the + `HTTPError` raised when a non-200 response code is used, + instead of suppressing all errors. + """ + if self._closed: + raise RuntimeError("fetch() called on closed AsyncHTTPClient") + if not isinstance(request, HTTPRequest): + request = HTTPRequest(url=request, **kwargs) + else: + if kwargs: + raise ValueError( + "kwargs can't be used if request is an HTTPRequest object" + ) + # We may modify this (to add Host, Accept-Encoding, etc), + # so make sure we don't modify the caller's object. This is also + # where normal dicts get converted to HTTPHeaders objects. + request.headers = httputil.HTTPHeaders(request.headers) + request_proxy = _RequestProxy(request, self.defaults) + future = Future() # type: Future[HTTPResponse] + + def handle_response(response: "HTTPResponse") -> None: + if response.error: + if raise_error or not response._error_is_response_code: + future_set_exception_unless_cancelled(future, response.error) + return + future_set_result_unless_cancelled(future, response) + + self.fetch_impl(cast(HTTPRequest, request_proxy), handle_response) + return future + + def fetch_impl( + self, request: "HTTPRequest", callback: Callable[["HTTPResponse"], None] + ) -> None: + raise NotImplementedError() + + @classmethod + def configure( + cls, impl: "Union[None, str, Type[Configurable]]", **kwargs: Any + ) -> None: + """Configures the `AsyncHTTPClient` subclass to use. + + ``AsyncHTTPClient()`` actually creates an instance of a subclass. + This method may be called with either a class object or the + fully-qualified name of such a class (or ``None`` to use the default, + ``SimpleAsyncHTTPClient``) + + If additional keyword arguments are given, they will be passed + to the constructor of each subclass instance created. The + keyword argument ``max_clients`` determines the maximum number + of simultaneous `~AsyncHTTPClient.fetch()` operations that can + execute in parallel on each `.IOLoop`. Additional arguments + may be supported depending on the implementation class in use. + + Example:: + + AsyncHTTPClient.configure("tornado.curl_httpclient.CurlAsyncHTTPClient") + """ + super(AsyncHTTPClient, cls).configure(impl, **kwargs) + + +class HTTPRequest(object): + """HTTP client request object.""" + + _headers = None # type: Union[Dict[str, str], httputil.HTTPHeaders] + + # Default values for HTTPRequest parameters. + # Merged with the values on the request object by AsyncHTTPClient + # implementations. 
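+    # For example, a request constructed without connect_timeout is fetched
+    # with the 20-second default below, unless it was overridden via
+    # AsyncHTTPClient.configure(None, defaults=dict(connect_timeout=...)).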
+ _DEFAULTS = dict( + connect_timeout=20.0, + request_timeout=20.0, + follow_redirects=True, + max_redirects=5, + decompress_response=True, + proxy_password="", + allow_nonstandard_methods=False, + validate_cert=True, + ) + + def __init__( + self, + url: str, + method: str = "GET", + headers: Optional[Union[Dict[str, str], httputil.HTTPHeaders]] = None, + body: Optional[Union[bytes, str]] = None, + auth_username: Optional[str] = None, + auth_password: Optional[str] = None, + auth_mode: Optional[str] = None, + connect_timeout: Optional[float] = None, + request_timeout: Optional[float] = None, + if_modified_since: Optional[Union[float, datetime.datetime]] = None, + follow_redirects: Optional[bool] = None, + max_redirects: Optional[int] = None, + user_agent: Optional[str] = None, + use_gzip: Optional[bool] = None, + network_interface: Optional[str] = None, + streaming_callback: Optional[Callable[[bytes], None]] = None, + header_callback: Optional[Callable[[str], None]] = None, + prepare_curl_callback: Optional[Callable[[Any], None]] = None, + proxy_host: Optional[str] = None, + proxy_port: Optional[int] = None, + proxy_username: Optional[str] = None, + proxy_password: Optional[str] = None, + proxy_auth_mode: Optional[str] = None, + allow_nonstandard_methods: Optional[bool] = None, + validate_cert: Optional[bool] = None, + ca_certs: Optional[str] = None, + allow_ipv6: Optional[bool] = None, + client_key: Optional[str] = None, + client_cert: Optional[str] = None, + body_producer: Optional[ + Callable[[Callable[[bytes], None]], "Future[None]"] + ] = None, + expect_100_continue: bool = False, + decompress_response: Optional[bool] = None, + ssl_options: Optional[Union[Dict[str, Any], ssl.SSLContext]] = None, + ) -> None: + r"""All parameters except ``url`` are optional. + + :arg str url: URL to fetch + :arg str method: HTTP method, e.g. "GET" or "POST" + :arg headers: Additional HTTP headers to pass on the request + :type headers: `~tornado.httputil.HTTPHeaders` or `dict` + :arg body: HTTP request body as a string (byte or unicode; if unicode + the utf-8 encoding will be used) + :type body: `str` or `bytes` + :arg collections.abc.Callable body_producer: Callable used for + lazy/asynchronous request bodies. + It is called with one argument, a ``write`` function, and should + return a `.Future`. It should call the write function with new + data as it becomes available. The write function returns a + `.Future` which can be used for flow control. + Only one of ``body`` and ``body_producer`` may + be specified. ``body_producer`` is not supported on + ``curl_httpclient``. When using ``body_producer`` it is recommended + to pass a ``Content-Length`` in the headers as otherwise chunked + encoding will be used, and many servers do not support chunked + encoding on requests. New in Tornado 4.0 + :arg str auth_username: Username for HTTP authentication + :arg str auth_password: Password for HTTP authentication + :arg str auth_mode: Authentication mode; default is "basic". 
+            Allowed values are implementation-defined; ``curl_httpclient``
+            supports "basic" and "digest"; ``simple_httpclient`` only supports
+            "basic"
+        :arg float connect_timeout: Timeout for initial connection in seconds,
+            default 20 seconds (0 means no timeout)
+        :arg float request_timeout: Timeout for entire request in seconds,
+            default 20 seconds (0 means no timeout)
+        :arg if_modified_since: Timestamp for ``If-Modified-Since`` header
+        :type if_modified_since: `datetime` or `float`
+        :arg bool follow_redirects: Should redirects be followed automatically
+            or return the 3xx response? Default True.
+        :arg int max_redirects: Limit for ``follow_redirects``, default 5.
+        :arg str user_agent: String to send as ``User-Agent`` header
+        :arg bool decompress_response: Request a compressed response from
+            the server and decompress it after downloading. Default is True.
+            New in Tornado 4.0.
+        :arg bool use_gzip: Deprecated alias for ``decompress_response``
+            since Tornado 4.0.
+        :arg str network_interface: Network interface or source IP to use for request.
+            See ``curl_httpclient`` note below.
+        :arg collections.abc.Callable streaming_callback: If set, ``streaming_callback`` will
+            be run with each chunk of data as it is received, and
+            ``HTTPResponse.body`` and ``HTTPResponse.buffer`` will be empty in
+            the final response.
+        :arg collections.abc.Callable header_callback: If set, ``header_callback`` will
+            be run with each header line as it is received (including the
+            first line, e.g. ``HTTP/1.0 200 OK\r\n``, and a final line
+            containing only ``\r\n``; all lines include the trailing newline
+            characters). ``HTTPResponse.headers`` will be empty in the final
+            response. This is most useful in conjunction with
+            ``streaming_callback``, because it's the only way to get access to
+            header data while the request is in progress.
+        :arg collections.abc.Callable prepare_curl_callback: If set, will be called with
+            a ``pycurl.Curl`` object to allow the application to make additional
+            ``setopt`` calls.
+        :arg str proxy_host: HTTP proxy hostname. To use proxies,
+            ``proxy_host`` and ``proxy_port`` must be set; ``proxy_username``,
+            ``proxy_password`` and ``proxy_auth_mode`` are optional. Proxies are
+            currently only supported with ``curl_httpclient``.
+        :arg int proxy_port: HTTP proxy port
+        :arg str proxy_username: HTTP proxy username
+        :arg str proxy_password: HTTP proxy password
+        :arg str proxy_auth_mode: HTTP proxy authentication mode;
+            default is "basic". Supports "basic" and "digest".
+        :arg bool allow_nonstandard_methods: Allow unknown values for ``method``
+            argument? Default is False.
+        :arg bool validate_cert: For HTTPS requests, validate the server's
+            certificate? Default is True.
+        :arg str ca_certs: filename of CA certificates in PEM format,
+            or None to use defaults. See note below when used with
+            ``curl_httpclient``.
+        :arg str client_key: Filename for client SSL key, if any. See
+            note below when used with ``curl_httpclient``.
+        :arg str client_cert: Filename for client SSL certificate, if any.
+            See note below when used with ``curl_httpclient``.
+        :arg ssl.SSLContext ssl_options: `ssl.SSLContext` object for use in
+            ``simple_httpclient`` (unsupported by ``curl_httpclient``).
+            Overrides ``validate_cert``, ``ca_certs``, ``client_key``,
+            and ``client_cert``.
+        :arg bool allow_ipv6: Use IPv6 when available? Default is True.
+        :arg bool expect_100_continue: If true, send the
+            ``Expect: 100-continue`` header and wait for a continue response
+            before sending the request body.
Only supported with + ``simple_httpclient``. + + .. note:: + + When using ``curl_httpclient`` certain options may be + inherited by subsequent fetches because ``pycurl`` does + not allow them to be cleanly reset. This applies to the + ``ca_certs``, ``client_key``, ``client_cert``, and + ``network_interface`` arguments. If you use these + options, you should pass them on every request (you don't + have to always use the same values, but it's not possible + to mix requests that specify these options with ones that + use the defaults). + + .. versionadded:: 3.1 + The ``auth_mode`` argument. + + .. versionadded:: 4.0 + The ``body_producer`` and ``expect_100_continue`` arguments. + + .. versionadded:: 4.2 + The ``ssl_options`` argument. + + .. versionadded:: 4.5 + The ``proxy_auth_mode`` argument. + """ + # Note that some of these attributes go through property setters + # defined below. + self.headers = headers # type: ignore + if if_modified_since: + self.headers["If-Modified-Since"] = httputil.format_timestamp( + if_modified_since + ) + self.proxy_host = proxy_host + self.proxy_port = proxy_port + self.proxy_username = proxy_username + self.proxy_password = proxy_password + self.proxy_auth_mode = proxy_auth_mode + self.url = url + self.method = method + self.body = body # type: ignore + self.body_producer = body_producer + self.auth_username = auth_username + self.auth_password = auth_password + self.auth_mode = auth_mode + self.connect_timeout = connect_timeout + self.request_timeout = request_timeout + self.follow_redirects = follow_redirects + self.max_redirects = max_redirects + self.user_agent = user_agent + if decompress_response is not None: + self.decompress_response = decompress_response # type: Optional[bool] + else: + self.decompress_response = use_gzip + self.network_interface = network_interface + self.streaming_callback = streaming_callback + self.header_callback = header_callback + self.prepare_curl_callback = prepare_curl_callback + self.allow_nonstandard_methods = allow_nonstandard_methods + self.validate_cert = validate_cert + self.ca_certs = ca_certs + self.allow_ipv6 = allow_ipv6 + self.client_key = client_key + self.client_cert = client_cert + self.ssl_options = ssl_options + self.expect_100_continue = expect_100_continue + self.start_time = time.time() + + @property + def headers(self) -> httputil.HTTPHeaders: + # TODO: headers may actually be a plain dict until fairly late in + # the process (AsyncHTTPClient.fetch), but practically speaking, + # whenever the property is used they're already HTTPHeaders. + return self._headers # type: ignore + + @headers.setter + def headers(self, value: Union[Dict[str, str], httputil.HTTPHeaders]) -> None: + if value is None: + self._headers = httputil.HTTPHeaders() + else: + self._headers = value # type: ignore + + @property + def body(self) -> bytes: + return self._body + + @body.setter + def body(self, value: Union[bytes, str]) -> None: + self._body = utf8(value) + + +class HTTPResponse(object): + """HTTP Response object. + + Attributes: + + * ``request``: HTTPRequest object + + * ``code``: numeric HTTP status code, e.g. 
200 or 404 + + * ``reason``: human-readable reason phrase describing the status code + + * ``headers``: `tornado.httputil.HTTPHeaders` object + + * ``effective_url``: final location of the resource after following any + redirects + + * ``buffer``: ``cStringIO`` object for response body + + * ``body``: response body as bytes (created on demand from ``self.buffer``) + + * ``error``: Exception object, if any + + * ``request_time``: seconds from request start to finish. Includes all + network operations from DNS resolution to receiving the last byte of + data. Does not include time spent in the queue (due to the + ``max_clients`` option). If redirects were followed, only includes + the final request. + + * ``start_time``: Time at which the HTTP operation started, based on + `time.time` (not the monotonic clock used by `.IOLoop.time`). May + be ``None`` if the request timed out while in the queue. + + * ``time_info``: dictionary of diagnostic timing information from the + request. Available data are subject to change, but currently uses timings + available from http://curl.haxx.se/libcurl/c/curl_easy_getinfo.html, + plus ``queue``, which is the delay (if any) introduced by waiting for + a slot under `AsyncHTTPClient`'s ``max_clients`` setting. + + .. versionadded:: 5.1 + + Added the ``start_time`` attribute. + + .. versionchanged:: 5.1 + + The ``request_time`` attribute previously included time spent in the queue + for ``simple_httpclient``, but not in ``curl_httpclient``. Now queueing time + is excluded in both implementations. ``request_time`` is now more accurate for + ``curl_httpclient`` because it uses a monotonic clock when available. + """ + + # I'm not sure why these don't get type-inferred from the references in __init__. + error = None # type: Optional[BaseException] + _error_is_response_code = False + request = None # type: HTTPRequest + + def __init__( + self, + request: HTTPRequest, + code: int, + headers: Optional[httputil.HTTPHeaders] = None, + buffer: Optional[BytesIO] = None, + effective_url: Optional[str] = None, + error: Optional[BaseException] = None, + request_time: Optional[float] = None, + time_info: Optional[Dict[str, float]] = None, + reason: Optional[str] = None, + start_time: Optional[float] = None, + ) -> None: + if isinstance(request, _RequestProxy): + self.request = request.request + else: + self.request = request + self.code = code + self.reason = reason or httputil.responses.get(code, "Unknown") + if headers is not None: + self.headers = headers + else: + self.headers = httputil.HTTPHeaders() + self.buffer = buffer + self._body = None # type: Optional[bytes] + if effective_url is None: + self.effective_url = request.url + else: + self.effective_url = effective_url + self._error_is_response_code = False + if error is None: + if self.code < 200 or self.code >= 300: + self._error_is_response_code = True + self.error = HTTPError(self.code, message=self.reason, response=self) + else: + self.error = None + else: + self.error = error + self.start_time = start_time + self.request_time = request_time + self.time_info = time_info or {} + + @property + def body(self) -> bytes: + if self.buffer is None: + return b"" + elif self._body is None: + self._body = self.buffer.getvalue() + + return self._body + + def rethrow(self) -> None: + """If there was an error on the request, raise an `HTTPError`.""" + if self.error: + raise self.error + + def __repr__(self) -> str: + args = ",".join("%s=%r" % i for i in sorted(self.__dict__.items())) + return "%s(%s)" % 
(self.__class__.__name__, args) + + +class HTTPClientError(Exception): + """Exception thrown for an unsuccessful HTTP request. + + Attributes: + + * ``code`` - HTTP error integer error code, e.g. 404. Error code 599 is + used when no HTTP response was received, e.g. for a timeout. + + * ``response`` - `HTTPResponse` object, if any. + + Note that if ``follow_redirects`` is False, redirects become HTTPErrors, + and you can look at ``error.response.headers['Location']`` to see the + destination of the redirect. + + .. versionchanged:: 5.1 + + Renamed from ``HTTPError`` to ``HTTPClientError`` to avoid collisions with + `tornado.web.HTTPError`. The name ``tornado.httpclient.HTTPError`` remains + as an alias. + """ + + def __init__( + self, + code: int, + message: Optional[str] = None, + response: Optional[HTTPResponse] = None, + ) -> None: + self.code = code + self.message = message or httputil.responses.get(code, "Unknown") + self.response = response + super().__init__(code, message, response) + + def __str__(self) -> str: + return "HTTP %d: %s" % (self.code, self.message) + + # There is a cyclic reference between self and self.response, + # which breaks the default __repr__ implementation. + # (especially on pypy, which doesn't have the same recursion + # detection as cpython). + __repr__ = __str__ + + +HTTPError = HTTPClientError + + +class _RequestProxy(object): + """Combines an object with a dictionary of defaults. + + Used internally by AsyncHTTPClient implementations. + """ + + def __init__( + self, request: HTTPRequest, defaults: Optional[Dict[str, Any]] + ) -> None: + self.request = request + self.defaults = defaults + + def __getattr__(self, name: str) -> Any: + request_attr = getattr(self.request, name) + if request_attr is not None: + return request_attr + elif self.defaults is not None: + return self.defaults.get(name, None) + else: + return None + + +def main() -> None: + from tornado.options import define, options, parse_command_line + + define("print_headers", type=bool, default=False) + define("print_body", type=bool, default=True) + define("follow_redirects", type=bool, default=True) + define("validate_cert", type=bool, default=True) + define("proxy_host", type=str) + define("proxy_port", type=int) + args = parse_command_line() + client = HTTPClient() + for arg in args: + try: + response = client.fetch( + arg, + follow_redirects=options.follow_redirects, + validate_cert=options.validate_cert, + proxy_host=options.proxy_host, + proxy_port=options.proxy_port, + ) + except HTTPError as e: + if e.response is not None: + response = e.response + else: + raise + if options.print_headers: + print(response.headers) + if options.print_body: + print(native_str(response.body)) + client.close() + + +if __name__ == "__main__": + main() diff --git a/venv/lib/python3.9/site-packages/tornado/httpserver.py b/venv/lib/python3.9/site-packages/tornado/httpserver.py new file mode 100644 index 00000000..77dc541e --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/httpserver.py @@ -0,0 +1,410 @@ +# +# Copyright 2009 Facebook +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""A non-blocking, single-threaded HTTP server.
+
+Typical applications have little direct interaction with the `HTTPServer`
+class except to start a server at the beginning of the process
+(and even that is often done indirectly via `tornado.web.Application.listen`).
+
+.. versionchanged:: 4.0
+
+   The ``HTTPRequest`` class that used to live in this module has been moved
+   to `tornado.httputil.HTTPServerRequest`. The old name remains as an alias.
+"""
+
+import socket
+import ssl
+
+from tornado.escape import native_str
+from tornado.http1connection import HTTP1ServerConnection, HTTP1ConnectionParameters
+from tornado import httputil
+from tornado import iostream
+from tornado import netutil
+from tornado.tcpserver import TCPServer
+from tornado.util import Configurable
+
+import typing
+from typing import Union, Any, Dict, Callable, List, Type, Tuple, Optional, Awaitable
+
+if typing.TYPE_CHECKING:
+    from typing import Set  # noqa: F401
+
+
+class HTTPServer(TCPServer, Configurable, httputil.HTTPServerConnectionDelegate):
+    r"""A non-blocking, single-threaded HTTP server.
+
+    A server is defined by a subclass of `.HTTPServerConnectionDelegate`,
+    or, for backwards compatibility, a callback that takes an
+    `.HTTPServerRequest` as an argument. The delegate is usually a
+    `tornado.web.Application`.
+
+    `HTTPServer` supports keep-alive connections by default
+    (automatically for HTTP/1.1, or for HTTP/1.0 when the client
+    requests ``Connection: keep-alive``).
+
+    If ``xheaders`` is ``True``, we support the
+    ``X-Real-Ip``/``X-Forwarded-For`` and
+    ``X-Scheme``/``X-Forwarded-Proto`` headers, which override the
+    remote IP and URI scheme/protocol for all requests. These headers
+    are useful when running Tornado behind a reverse proxy or load
+    balancer. The ``protocol`` argument can also be set to ``https``
+    if Tornado is run behind an SSL-decoding proxy that does not set one of
+    the supported ``xheaders``.
+
+    By default, when parsing the ``X-Forwarded-For`` header, Tornado will
+    select the last (i.e., the closest) address on the list of hosts as the
+    remote host IP address. To select the next server in the chain, a list of
+    trusted downstream hosts may be passed as the ``trusted_downstream``
+    argument. These hosts will be skipped when parsing the ``X-Forwarded-For``
+    header.
+
+    To make this server serve SSL traffic, send the ``ssl_options`` keyword
+    argument with an `ssl.SSLContext` object. For compatibility with older
+    versions of Python ``ssl_options`` may also be a dictionary of keyword
+    arguments for the `ssl.wrap_socket` method.::
+
+       ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
+       ssl_ctx.load_cert_chain(os.path.join(data_dir, "mydomain.crt"),
+                               os.path.join(data_dir, "mydomain.key"))
+       HTTPServer(application, ssl_options=ssl_ctx)
+
+    `HTTPServer` initialization follows one of three patterns (the
+    initialization methods are defined on `tornado.tcpserver.TCPServer`):
+
+    1. `~tornado.tcpserver.TCPServer.listen`: single-process::
+
+        async def main():
+            server = HTTPServer()
+            server.listen(8888)
+            await asyncio.Event().wait()
+
+        asyncio.run(main())
+
+       In many cases, `tornado.web.Application.listen` can be used to avoid
+       the need to explicitly create the `HTTPServer`.
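+
+       For example, a minimal sketch of this pattern built on
+       `tornado.web.Application.listen` (``MainHandler`` is a
+       hypothetical handler, not defined in this module)::
+
+           import asyncio
+           import tornado.web
+
+           async def main():
+               # Application.listen creates and starts an HTTPServer
+               # internally, listening on the given port.
+               app = tornado.web.Application([(r"/", MainHandler)])
+               app.listen(8888)
+               await asyncio.Event().wait()
+
+           asyncio.run(main())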
+ + While this example does not create multiple processes on its own, when + the ``reuse_port=True`` argument is passed to ``listen()`` you can run + the program multiple times to create a multi-process service. + + 2. `~tornado.tcpserver.TCPServer.add_sockets`: multi-process:: + + sockets = bind_sockets(8888) + tornado.process.fork_processes(0) + async def post_fork_main(): + server = HTTPServer() + server.add_sockets(sockets) + await asyncio.Event().wait() + asyncio.run(post_fork_main()) + + The ``add_sockets`` interface is more complicated, but it can be used with + `tornado.process.fork_processes` to run a multi-process service with all + worker processes forked from a single parent. ``add_sockets`` can also be + used in single-process servers if you want to create your listening + sockets in some way other than `~tornado.netutil.bind_sockets`. + + Note that when using this pattern, nothing that touches the event loop + can be run before ``fork_processes``. + + 3. `~tornado.tcpserver.TCPServer.bind`/`~tornado.tcpserver.TCPServer.start`: + simple **deprecated** multi-process:: + + server = HTTPServer() + server.bind(8888) + server.start(0) # Forks multiple sub-processes + IOLoop.current().start() + + This pattern is deprecated because it requires interfaces in the + `asyncio` module that have been deprecated since Python 3.10. Support for + creating multiple processes in the ``start`` method will be removed in a + future version of Tornado. + + .. versionchanged:: 4.0 + Added ``decompress_request``, ``chunk_size``, ``max_header_size``, + ``idle_connection_timeout``, ``body_timeout``, ``max_body_size`` + arguments. Added support for `.HTTPServerConnectionDelegate` + instances as ``request_callback``. + + .. versionchanged:: 4.1 + `.HTTPServerConnectionDelegate.start_request` is now called with + two arguments ``(server_conn, request_conn)`` (in accordance with the + documentation) instead of one ``(request_conn)``. + + .. versionchanged:: 4.2 + `HTTPServer` is now a subclass of `tornado.util.Configurable`. + + .. versionchanged:: 4.5 + Added the ``trusted_downstream`` argument. + + .. versionchanged:: 5.0 + The ``io_loop`` argument has been removed. + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + # Ignore args to __init__; real initialization belongs in + # initialize since we're Configurable. (there's something + # weird in initialization order between this class, + # Configurable, and TCPServer so we can't leave __init__ out + # completely) + pass + + def initialize( + self, + request_callback: Union[ + httputil.HTTPServerConnectionDelegate, + Callable[[httputil.HTTPServerRequest], None], + ], + no_keep_alive: bool = False, + xheaders: bool = False, + ssl_options: Optional[Union[Dict[str, Any], ssl.SSLContext]] = None, + protocol: Optional[str] = None, + decompress_request: bool = False, + chunk_size: Optional[int] = None, + max_header_size: Optional[int] = None, + idle_connection_timeout: Optional[float] = None, + body_timeout: Optional[float] = None, + max_body_size: Optional[int] = None, + max_buffer_size: Optional[int] = None, + trusted_downstream: Optional[List[str]] = None, + ) -> None: + # This method's signature is not extracted with autodoc + # because we want its arguments to appear on the class + # constructor. When changing this signature, also update the + # copy in httpserver.rst. 
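+        #
+        # Example of configuring this server (``app`` is a hypothetical
+        # tornado.web.Application; the keyword arguments are parameters
+        # of this method, shown with illustrative values):
+        #
+        #     server = HTTPServer(app, xheaders=True,
+        #                         decompress_request=True,
+        #                         idle_connection_timeout=60.0)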
+ self.request_callback = request_callback + self.xheaders = xheaders + self.protocol = protocol + self.conn_params = HTTP1ConnectionParameters( + decompress=decompress_request, + chunk_size=chunk_size, + max_header_size=max_header_size, + header_timeout=idle_connection_timeout or 3600, + max_body_size=max_body_size, + body_timeout=body_timeout, + no_keep_alive=no_keep_alive, + ) + TCPServer.__init__( + self, + ssl_options=ssl_options, + max_buffer_size=max_buffer_size, + read_chunk_size=chunk_size, + ) + self._connections = set() # type: Set[HTTP1ServerConnection] + self.trusted_downstream = trusted_downstream + + @classmethod + def configurable_base(cls) -> Type[Configurable]: + return HTTPServer + + @classmethod + def configurable_default(cls) -> Type[Configurable]: + return HTTPServer + + async def close_all_connections(self) -> None: + """Close all open connections and asynchronously wait for them to finish. + + This method is used in combination with `~.TCPServer.stop` to + support clean shutdowns (especially for unittests). Typical + usage would call ``stop()`` first to stop accepting new + connections, then ``await close_all_connections()`` to wait for + existing connections to finish. + + This method does not currently close open websocket connections. + + Note that this method is a coroutine and must be called with ``await``. + + """ + while self._connections: + # Peek at an arbitrary element of the set + conn = next(iter(self._connections)) + await conn.close() + + def handle_stream(self, stream: iostream.IOStream, address: Tuple) -> None: + context = _HTTPRequestContext( + stream, address, self.protocol, self.trusted_downstream + ) + conn = HTTP1ServerConnection(stream, self.conn_params, context) + self._connections.add(conn) + conn.start_serving(self) + + def start_request( + self, server_conn: object, request_conn: httputil.HTTPConnection + ) -> httputil.HTTPMessageDelegate: + if isinstance(self.request_callback, httputil.HTTPServerConnectionDelegate): + delegate = self.request_callback.start_request(server_conn, request_conn) + else: + delegate = _CallableAdapter(self.request_callback, request_conn) + + if self.xheaders: + delegate = _ProxyAdapter(delegate, request_conn) + + return delegate + + def on_close(self, server_conn: object) -> None: + self._connections.remove(typing.cast(HTTP1ServerConnection, server_conn)) + + +class _CallableAdapter(httputil.HTTPMessageDelegate): + def __init__( + self, + request_callback: Callable[[httputil.HTTPServerRequest], None], + request_conn: httputil.HTTPConnection, + ) -> None: + self.connection = request_conn + self.request_callback = request_callback + self.request = None # type: Optional[httputil.HTTPServerRequest] + self.delegate = None + self._chunks = [] # type: List[bytes] + + def headers_received( + self, + start_line: Union[httputil.RequestStartLine, httputil.ResponseStartLine], + headers: httputil.HTTPHeaders, + ) -> Optional[Awaitable[None]]: + self.request = httputil.HTTPServerRequest( + connection=self.connection, + start_line=typing.cast(httputil.RequestStartLine, start_line), + headers=headers, + ) + return None + + def data_received(self, chunk: bytes) -> Optional[Awaitable[None]]: + self._chunks.append(chunk) + return None + + def finish(self) -> None: + assert self.request is not None + self.request.body = b"".join(self._chunks) + self.request._parse_body() + self.request_callback(self.request) + + def on_connection_close(self) -> None: + del self._chunks + + +class _HTTPRequestContext(object): + def __init__( + self, 
+ stream: iostream.IOStream, + address: Tuple, + protocol: Optional[str], + trusted_downstream: Optional[List[str]] = None, + ) -> None: + self.address = address + # Save the socket's address family now so we know how to + # interpret self.address even after the stream is closed + # and its socket attribute replaced with None. + if stream.socket is not None: + self.address_family = stream.socket.family + else: + self.address_family = None + # In HTTPServerRequest we want an IP, not a full socket address. + if ( + self.address_family in (socket.AF_INET, socket.AF_INET6) + and address is not None + ): + self.remote_ip = address[0] + else: + # Unix (or other) socket; fake the remote address. + self.remote_ip = "0.0.0.0" + if protocol: + self.protocol = protocol + elif isinstance(stream, iostream.SSLIOStream): + self.protocol = "https" + else: + self.protocol = "http" + self._orig_remote_ip = self.remote_ip + self._orig_protocol = self.protocol + self.trusted_downstream = set(trusted_downstream or []) + + def __str__(self) -> str: + if self.address_family in (socket.AF_INET, socket.AF_INET6): + return self.remote_ip + elif isinstance(self.address, bytes): + # Python 3 with the -bb option warns about str(bytes), + # so convert it explicitly. + # Unix socket addresses are str on mac but bytes on linux. + return native_str(self.address) + else: + return str(self.address) + + def _apply_xheaders(self, headers: httputil.HTTPHeaders) -> None: + """Rewrite the ``remote_ip`` and ``protocol`` fields.""" + # Squid uses X-Forwarded-For, others use X-Real-Ip + ip = headers.get("X-Forwarded-For", self.remote_ip) + # Skip trusted downstream hosts in X-Forwarded-For list + for ip in (cand.strip() for cand in reversed(ip.split(","))): + if ip not in self.trusted_downstream: + break + ip = headers.get("X-Real-Ip", ip) + if netutil.is_valid_ip(ip): + self.remote_ip = ip + # AWS uses X-Forwarded-Proto + proto_header = headers.get( + "X-Scheme", headers.get("X-Forwarded-Proto", self.protocol) + ) + if proto_header: + # use only the last proto entry if there is more than one + # TODO: support trusting multiple layers of proxied protocol + proto_header = proto_header.split(",")[-1].strip() + if proto_header in ("http", "https"): + self.protocol = proto_header + + def _unapply_xheaders(self) -> None: + """Undo changes from `_apply_xheaders`. + + Xheaders are per-request so they should not leak to the next + request on the same connection. + """ + self.remote_ip = self._orig_remote_ip + self.protocol = self._orig_protocol + + +class _ProxyAdapter(httputil.HTTPMessageDelegate): + def __init__( + self, + delegate: httputil.HTTPMessageDelegate, + request_conn: httputil.HTTPConnection, + ) -> None: + self.connection = request_conn + self.delegate = delegate + + def headers_received( + self, + start_line: Union[httputil.RequestStartLine, httputil.ResponseStartLine], + headers: httputil.HTTPHeaders, + ) -> Optional[Awaitable[None]]: + # TODO: either make context an official part of the + # HTTPConnection interface or figure out some other way to do this. 
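+        # The context is the _HTTPRequestContext created in
+        # HTTPServer.handle_stream. _apply_xheaders rewrites remote_ip and
+        # protocol from the X-Real-Ip/X-Forwarded-For and
+        # X-Scheme/X-Forwarded-Proto headers; _cleanup() below restores the
+        # originals via _unapply_xheaders so the overrides do not leak into
+        # the next request on a keep-alive connection.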
+ self.connection.context._apply_xheaders(headers) # type: ignore + return self.delegate.headers_received(start_line, headers) + + def data_received(self, chunk: bytes) -> Optional[Awaitable[None]]: + return self.delegate.data_received(chunk) + + def finish(self) -> None: + self.delegate.finish() + self._cleanup() + + def on_connection_close(self) -> None: + self.delegate.on_connection_close() + self._cleanup() + + def _cleanup(self) -> None: + self.connection.context._unapply_xheaders() # type: ignore + + +HTTPRequest = httputil.HTTPServerRequest diff --git a/venv/lib/python3.9/site-packages/tornado/httputil.py b/venv/lib/python3.9/site-packages/tornado/httputil.py new file mode 100644 index 00000000..9c341d47 --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/httputil.py @@ -0,0 +1,1134 @@ +# +# Copyright 2009 Facebook +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""HTTP utility code shared by clients and servers. + +This module also defines the `HTTPServerRequest` class which is exposed +via `tornado.web.RequestHandler.request`. +""" + +import calendar +import collections.abc +import copy +import datetime +import email.utils +from functools import lru_cache +from http.client import responses +import http.cookies +import re +from ssl import SSLError +import time +import unicodedata +from urllib.parse import urlencode, urlparse, urlunparse, parse_qsl + +from tornado.escape import native_str, parse_qs_bytes, utf8 +from tornado.log import gen_log +from tornado.util import ObjectDict, unicode_type + + +# responses is unused in this file, but we re-export it to other files. +# Reference it so pyflakes doesn't complain. +responses + +import typing +from typing import ( + Tuple, + Iterable, + List, + Mapping, + Iterator, + Dict, + Union, + Optional, + Awaitable, + Generator, + AnyStr, +) + +if typing.TYPE_CHECKING: + from typing import Deque # noqa: F401 + from asyncio import Future # noqa: F401 + import unittest # noqa: F401 + + +@lru_cache(1000) +def _normalize_header(name: str) -> str: + """Map a header name to Http-Header-Case. + + >>> _normalize_header("coNtent-TYPE") + 'Content-Type' + """ + return "-".join([w.capitalize() for w in name.split("-")]) + + +class HTTPHeaders(collections.abc.MutableMapping): + """A dictionary that maintains ``Http-Header-Case`` for all keys. + + Supports multiple values per key via a pair of new methods, + `add()` and `get_list()`. The regular dictionary interface + returns a single value per key, with multiple values joined by a + comma. + + >>> h = HTTPHeaders({"content-type": "text/html"}) + >>> list(h.keys()) + ['Content-Type'] + >>> h["Content-Type"] + 'text/html' + + >>> h.add("Set-Cookie", "A=B") + >>> h.add("Set-Cookie", "C=D") + >>> h["set-cookie"] + 'A=B,C=D' + >>> h.get_list("set-cookie") + ['A=B', 'C=D'] + + >>> for (k,v) in sorted(h.get_all()): + ... print('%s: %s' % (k,v)) + ... 
+ Content-Type: text/html + Set-Cookie: A=B + Set-Cookie: C=D + """ + + @typing.overload + def __init__(self, __arg: Mapping[str, List[str]]) -> None: + pass + + @typing.overload # noqa: F811 + def __init__(self, __arg: Mapping[str, str]) -> None: + pass + + @typing.overload # noqa: F811 + def __init__(self, *args: Tuple[str, str]) -> None: + pass + + @typing.overload # noqa: F811 + def __init__(self, **kwargs: str) -> None: + pass + + def __init__(self, *args: typing.Any, **kwargs: str) -> None: # noqa: F811 + self._dict = {} # type: typing.Dict[str, str] + self._as_list = {} # type: typing.Dict[str, typing.List[str]] + self._last_key = None # type: Optional[str] + if len(args) == 1 and len(kwargs) == 0 and isinstance(args[0], HTTPHeaders): + # Copy constructor + for k, v in args[0].get_all(): + self.add(k, v) + else: + # Dict-style initialization + self.update(*args, **kwargs) + + # new public methods + + def add(self, name: str, value: str) -> None: + """Adds a new value for the given key.""" + norm_name = _normalize_header(name) + self._last_key = norm_name + if norm_name in self: + self._dict[norm_name] = ( + native_str(self[norm_name]) + "," + native_str(value) + ) + self._as_list[norm_name].append(value) + else: + self[norm_name] = value + + def get_list(self, name: str) -> List[str]: + """Returns all values for the given header as a list.""" + norm_name = _normalize_header(name) + return self._as_list.get(norm_name, []) + + def get_all(self) -> Iterable[Tuple[str, str]]: + """Returns an iterable of all (name, value) pairs. + + If a header has multiple values, multiple pairs will be + returned with the same name. + """ + for name, values in self._as_list.items(): + for value in values: + yield (name, value) + + def parse_line(self, line: str) -> None: + """Updates the dictionary with a single header line. + + >>> h = HTTPHeaders() + >>> h.parse_line("Content-Type: text/html") + >>> h.get('content-type') + 'text/html' + """ + if line[0].isspace(): + # continuation of a multi-line header + if self._last_key is None: + raise HTTPInputError("first header line cannot start with whitespace") + new_part = " " + line.lstrip() + self._as_list[self._last_key][-1] += new_part + self._dict[self._last_key] += new_part + else: + try: + name, value = line.split(":", 1) + except ValueError: + raise HTTPInputError("no colon in header line") + self.add(name, value.strip()) + + @classmethod + def parse(cls, headers: str) -> "HTTPHeaders": + """Returns a dictionary from HTTP header text. + + >>> h = HTTPHeaders.parse("Content-Type: text/html\\r\\nContent-Length: 42\\r\\n") + >>> sorted(h.items()) + [('Content-Length', '42'), ('Content-Type', 'text/html')] + + .. versionchanged:: 5.1 + + Raises `HTTPInputError` on malformed headers instead of a + mix of `KeyError`, and `ValueError`. + + """ + h = cls() + # RFC 7230 section 3.5: a recipient MAY recognize a single LF as a line + # terminator and ignore any preceding CR. + for line in headers.split("\n"): + if line.endswith("\r"): + line = line[:-1] + if line: + h.parse_line(line) + return h + + # MutableMapping abstract method implementations. 
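+    # These methods keep the two internal views in sync: _dict stores one
+    # comma-joined string per header, while _as_list stores every value
+    # separately. __setitem__ replaces all values for a header; add()
+    # (above) appends instead.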
+ + def __setitem__(self, name: str, value: str) -> None: + norm_name = _normalize_header(name) + self._dict[norm_name] = value + self._as_list[norm_name] = [value] + + def __getitem__(self, name: str) -> str: + return self._dict[_normalize_header(name)] + + def __delitem__(self, name: str) -> None: + norm_name = _normalize_header(name) + del self._dict[norm_name] + del self._as_list[norm_name] + + def __len__(self) -> int: + return len(self._dict) + + def __iter__(self) -> Iterator[typing.Any]: + return iter(self._dict) + + def copy(self) -> "HTTPHeaders": + # defined in dict but not in MutableMapping. + return HTTPHeaders(self) + + # Use our overridden copy method for the copy.copy module. + # This makes shallow copies one level deeper, but preserves + # the appearance that HTTPHeaders is a single container. + __copy__ = copy + + def __str__(self) -> str: + lines = [] + for name, value in self.get_all(): + lines.append("%s: %s\n" % (name, value)) + return "".join(lines) + + __unicode__ = __str__ + + +class HTTPServerRequest(object): + """A single HTTP request. + + All attributes are type `str` unless otherwise noted. + + .. attribute:: method + + HTTP request method, e.g. "GET" or "POST" + + .. attribute:: uri + + The requested uri. + + .. attribute:: path + + The path portion of `uri` + + .. attribute:: query + + The query portion of `uri` + + .. attribute:: version + + HTTP version specified in request, e.g. "HTTP/1.1" + + .. attribute:: headers + + `.HTTPHeaders` dictionary-like object for request headers. Acts like + a case-insensitive dictionary with additional methods for repeated + headers. + + .. attribute:: body + + Request body, if present, as a byte string. + + .. attribute:: remote_ip + + Client's IP address as a string. If ``HTTPServer.xheaders`` is set, + will pass along the real IP address provided by a load balancer + in the ``X-Real-Ip`` or ``X-Forwarded-For`` header. + + .. versionchanged:: 3.1 + The list format of ``X-Forwarded-For`` is now supported. + + .. attribute:: protocol + + The protocol used, either "http" or "https". If ``HTTPServer.xheaders`` + is set, will pass along the protocol used by a load balancer if + reported via an ``X-Scheme`` header. + + .. attribute:: host + + The requested hostname, usually taken from the ``Host`` header. + + .. attribute:: arguments + + GET/POST arguments are available in the arguments property, which + maps arguments names to lists of values (to support multiple values + for individual names). Names are of type `str`, while arguments + are byte strings. Note that this is different from + `.RequestHandler.get_argument`, which returns argument values as + unicode strings. + + .. attribute:: query_arguments + + Same format as ``arguments``, but contains only arguments extracted + from the query string. + + .. versionadded:: 3.2 + + .. attribute:: body_arguments + + Same format as ``arguments``, but contains only arguments extracted + from the request body. + + .. versionadded:: 3.2 + + .. attribute:: files + + File uploads are available in the files property, which maps file + names to lists of `.HTTPFile`. + + .. attribute:: connection + + An HTTP request is attached to a single HTTP connection, which can + be accessed through the "connection" attribute. Since connections + are typically kept open in HTTP/1.1, multiple requests can be handled + sequentially on a single connection. + + .. versionchanged:: 4.0 + Moved from ``tornado.httpserver.HTTPRequest``. 
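+
+    As an illustration (a sketch; ``req`` stands for a request instance
+    as delivered to a server callback), the parsed attributes can be
+    read like this::
+
+        def dump_request(req):
+            print(req.method, req.uri, req.version)
+            for name, values in req.arguments.items():
+                # argument values are byte strings
+                print(name, [v.decode("utf-8") for v in values])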
+ """ + + path = None # type: str + query = None # type: str + + # HACK: Used for stream_request_body + _body_future = None # type: Future[None] + + def __init__( + self, + method: Optional[str] = None, + uri: Optional[str] = None, + version: str = "HTTP/1.0", + headers: Optional[HTTPHeaders] = None, + body: Optional[bytes] = None, + host: Optional[str] = None, + files: Optional[Dict[str, List["HTTPFile"]]] = None, + connection: Optional["HTTPConnection"] = None, + start_line: Optional["RequestStartLine"] = None, + server_connection: Optional[object] = None, + ) -> None: + if start_line is not None: + method, uri, version = start_line + self.method = method + self.uri = uri + self.version = version + self.headers = headers or HTTPHeaders() + self.body = body or b"" + + # set remote IP and protocol + context = getattr(connection, "context", None) + self.remote_ip = getattr(context, "remote_ip", None) + self.protocol = getattr(context, "protocol", "http") + + self.host = host or self.headers.get("Host") or "127.0.0.1" + self.host_name = split_host_and_port(self.host.lower())[0] + self.files = files or {} + self.connection = connection + self.server_connection = server_connection + self._start_time = time.time() + self._finish_time = None + + if uri is not None: + self.path, sep, self.query = uri.partition("?") + self.arguments = parse_qs_bytes(self.query, keep_blank_values=True) + self.query_arguments = copy.deepcopy(self.arguments) + self.body_arguments = {} # type: Dict[str, List[bytes]] + + @property + def cookies(self) -> Dict[str, http.cookies.Morsel]: + """A dictionary of ``http.cookies.Morsel`` objects.""" + if not hasattr(self, "_cookies"): + self._cookies = ( + http.cookies.SimpleCookie() + ) # type: http.cookies.SimpleCookie + if "Cookie" in self.headers: + try: + parsed = parse_cookie(self.headers["Cookie"]) + except Exception: + pass + else: + for k, v in parsed.items(): + try: + self._cookies[k] = v + except Exception: + # SimpleCookie imposes some restrictions on keys; + # parse_cookie does not. Discard any cookies + # with disallowed keys. + pass + return self._cookies + + def full_url(self) -> str: + """Reconstructs the full URL for this request.""" + return self.protocol + "://" + self.host + self.uri # type: ignore[operator] + + def request_time(self) -> float: + """Returns the amount of time it took for this request to execute.""" + if self._finish_time is None: + return time.time() - self._start_time + else: + return self._finish_time - self._start_time + + def get_ssl_certificate( + self, binary_form: bool = False + ) -> Union[None, Dict, bytes]: + """Returns the client's SSL certificate, if any. + + To use client certificates, the HTTPServer's + `ssl.SSLContext.verify_mode` field must be set, e.g.:: + + ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + ssl_ctx.load_cert_chain("foo.crt", "foo.key") + ssl_ctx.load_verify_locations("cacerts.pem") + ssl_ctx.verify_mode = ssl.CERT_REQUIRED + server = HTTPServer(app, ssl_options=ssl_ctx) + + By default, the return value is a dictionary (or None, if no + client certificate is present). If ``binary_form`` is true, a + DER-encoded form of the certificate is returned instead. See + SSLSocket.getpeercert() in the standard library for more + details. 
+ http://docs.python.org/library/ssl.html#sslsocket-objects + """ + try: + if self.connection is None: + return None + # TODO: add a method to HTTPConnection for this so it can work with HTTP/2 + return self.connection.stream.socket.getpeercert( # type: ignore + binary_form=binary_form + ) + except SSLError: + return None + + def _parse_body(self) -> None: + parse_body_arguments( + self.headers.get("Content-Type", ""), + self.body, + self.body_arguments, + self.files, + self.headers, + ) + + for k, v in self.body_arguments.items(): + self.arguments.setdefault(k, []).extend(v) + + def __repr__(self) -> str: + attrs = ("protocol", "host", "method", "uri", "version", "remote_ip") + args = ", ".join(["%s=%r" % (n, getattr(self, n)) for n in attrs]) + return "%s(%s)" % (self.__class__.__name__, args) + + +class HTTPInputError(Exception): + """Exception class for malformed HTTP requests or responses + from remote sources. + + .. versionadded:: 4.0 + """ + + pass + + +class HTTPOutputError(Exception): + """Exception class for errors in HTTP output. + + .. versionadded:: 4.0 + """ + + pass + + +class HTTPServerConnectionDelegate(object): + """Implement this interface to handle requests from `.HTTPServer`. + + .. versionadded:: 4.0 + """ + + def start_request( + self, server_conn: object, request_conn: "HTTPConnection" + ) -> "HTTPMessageDelegate": + """This method is called by the server when a new request has started. + + :arg server_conn: is an opaque object representing the long-lived + (e.g. tcp-level) connection. + :arg request_conn: is a `.HTTPConnection` object for a single + request/response exchange. + + This method should return a `.HTTPMessageDelegate`. + """ + raise NotImplementedError() + + def on_close(self, server_conn: object) -> None: + """This method is called when a connection has been closed. + + :arg server_conn: is a server connection that has previously been + passed to ``start_request``. + """ + pass + + +class HTTPMessageDelegate(object): + """Implement this interface to handle an HTTP request or response. + + .. versionadded:: 4.0 + """ + + # TODO: genericize this class to avoid exposing the Union. + def headers_received( + self, + start_line: Union["RequestStartLine", "ResponseStartLine"], + headers: HTTPHeaders, + ) -> Optional[Awaitable[None]]: + """Called when the HTTP headers have been received and parsed. + + :arg start_line: a `.RequestStartLine` or `.ResponseStartLine` + depending on whether this is a client or server message. + :arg headers: a `.HTTPHeaders` instance. + + Some `.HTTPConnection` methods can only be called during + ``headers_received``. + + May return a `.Future`; if it does the body will not be read + until it is done. + """ + pass + + def data_received(self, chunk: bytes) -> Optional[Awaitable[None]]: + """Called when a chunk of data has been received. + + May return a `.Future` for flow control. + """ + pass + + def finish(self) -> None: + """Called after the last chunk of data has been received.""" + pass + + def on_connection_close(self) -> None: + """Called if the connection is closed without finishing the request. + + If ``headers_received`` is called, either ``finish`` or + ``on_connection_close`` will be called, but not both. + """ + pass + + +class HTTPConnection(object): + """Applications use this interface to write their responses. + + .. 
versionadded:: 4.0 + """ + + def write_headers( + self, + start_line: Union["RequestStartLine", "ResponseStartLine"], + headers: HTTPHeaders, + chunk: Optional[bytes] = None, + ) -> "Future[None]": + """Write an HTTP header block. + + :arg start_line: a `.RequestStartLine` or `.ResponseStartLine`. + :arg headers: a `.HTTPHeaders` instance. + :arg chunk: the first (optional) chunk of data. This is an optimization + so that small responses can be written in the same call as their + headers. + + The ``version`` field of ``start_line`` is ignored. + + Returns a future for flow control. + + .. versionchanged:: 6.0 + + The ``callback`` argument was removed. + """ + raise NotImplementedError() + + def write(self, chunk: bytes) -> "Future[None]": + """Writes a chunk of body data. + + Returns a future for flow control. + + .. versionchanged:: 6.0 + + The ``callback`` argument was removed. + """ + raise NotImplementedError() + + def finish(self) -> None: + """Indicates that the last body data has been written.""" + raise NotImplementedError() + + +def url_concat( + url: str, + args: Union[ + None, Dict[str, str], List[Tuple[str, str]], Tuple[Tuple[str, str], ...] + ], +) -> str: + """Concatenate url and arguments regardless of whether + url has existing query parameters. + + ``args`` may be either a dictionary or a list of key-value pairs + (the latter allows for multiple values with the same key. + + >>> url_concat("http://example.com/foo", dict(c="d")) + 'http://example.com/foo?c=d' + >>> url_concat("http://example.com/foo?a=b", dict(c="d")) + 'http://example.com/foo?a=b&c=d' + >>> url_concat("http://example.com/foo?a=b", [("c", "d"), ("c", "d2")]) + 'http://example.com/foo?a=b&c=d&c=d2' + """ + if args is None: + return url + parsed_url = urlparse(url) + if isinstance(args, dict): + parsed_query = parse_qsl(parsed_url.query, keep_blank_values=True) + parsed_query.extend(args.items()) + elif isinstance(args, list) or isinstance(args, tuple): + parsed_query = parse_qsl(parsed_url.query, keep_blank_values=True) + parsed_query.extend(args) + else: + err = "'args' parameter should be dict, list or tuple. Not {0}".format( + type(args) + ) + raise TypeError(err) + final_query = urlencode(parsed_query) + url = urlunparse( + ( + parsed_url[0], + parsed_url[1], + parsed_url[2], + parsed_url[3], + final_query, + parsed_url[5], + ) + ) + return url + + +class HTTPFile(ObjectDict): + """Represents a file uploaded via a form. + + For backwards compatibility, its instance attributes are also + accessible as dictionary keys. + + * ``filename`` + * ``body`` + * ``content_type`` + """ + + filename: str + body: bytes + content_type: str + + +def _parse_request_range( + range_header: str, +) -> Optional[Tuple[Optional[int], Optional[int]]]: + """Parses a Range header. + + Returns either ``None`` or tuple ``(start, end)``. + Note that while the HTTP headers use inclusive byte positions, + this method returns indexes suitable for use in slices. + + >>> start, end = _parse_request_range("bytes=1-2") + >>> start, end + (1, 3) + >>> [0, 1, 2, 3, 4][start:end] + [1, 2] + >>> _parse_request_range("bytes=6-") + (6, None) + >>> _parse_request_range("bytes=-6") + (-6, None) + >>> _parse_request_range("bytes=-0") + (None, 0) + >>> _parse_request_range("bytes=") + (None, None) + >>> _parse_request_range("foo=42") + >>> _parse_request_range("bytes=1-2,6-10") + + Note: only supports one range (ex, ``bytes=1-2,6-10`` is not allowed). + + See [0] for the details of the range header. 
+ + [0]: http://greenbytes.de/tech/webdav/draft-ietf-httpbis-p5-range-latest.html#byte.ranges + """ + unit, _, value = range_header.partition("=") + unit, value = unit.strip(), value.strip() + if unit != "bytes": + return None + start_b, _, end_b = value.partition("-") + try: + start = _int_or_none(start_b) + end = _int_or_none(end_b) + except ValueError: + return None + if end is not None: + if start is None: + if end != 0: + start = -end + end = None + else: + end += 1 + return (start, end) + + +def _get_content_range(start: Optional[int], end: Optional[int], total: int) -> str: + """Returns a suitable Content-Range header: + + >>> print(_get_content_range(None, 1, 4)) + bytes 0-0/4 + >>> print(_get_content_range(1, 3, 4)) + bytes 1-2/4 + >>> print(_get_content_range(None, None, 4)) + bytes 0-3/4 + """ + start = start or 0 + end = (end or total) - 1 + return "bytes %s-%s/%s" % (start, end, total) + + +def _int_or_none(val: str) -> Optional[int]: + val = val.strip() + if val == "": + return None + return int(val) + + +def parse_body_arguments( + content_type: str, + body: bytes, + arguments: Dict[str, List[bytes]], + files: Dict[str, List[HTTPFile]], + headers: Optional[HTTPHeaders] = None, +) -> None: + """Parses a form request body. + + Supports ``application/x-www-form-urlencoded`` and + ``multipart/form-data``. The ``content_type`` parameter should be + a string and ``body`` should be a byte string. The ``arguments`` + and ``files`` parameters are dictionaries that will be updated + with the parsed contents. + """ + if content_type.startswith("application/x-www-form-urlencoded"): + if headers and "Content-Encoding" in headers: + gen_log.warning( + "Unsupported Content-Encoding: %s", headers["Content-Encoding"] + ) + return + try: + # real charset decoding will happen in RequestHandler.decode_argument() + uri_arguments = parse_qs_bytes(body, keep_blank_values=True) + except Exception as e: + gen_log.warning("Invalid x-www-form-urlencoded body: %s", e) + uri_arguments = {} + for name, values in uri_arguments.items(): + if values: + arguments.setdefault(name, []).extend(values) + elif content_type.startswith("multipart/form-data"): + if headers and "Content-Encoding" in headers: + gen_log.warning( + "Unsupported Content-Encoding: %s", headers["Content-Encoding"] + ) + return + try: + fields = content_type.split(";") + for field in fields: + k, sep, v = field.strip().partition("=") + if k == "boundary" and v: + parse_multipart_form_data(utf8(v), body, arguments, files) + break + else: + raise ValueError("multipart boundary not found") + except Exception as e: + gen_log.warning("Invalid multipart/form-data: %s", e) + + +def parse_multipart_form_data( + boundary: bytes, + data: bytes, + arguments: Dict[str, List[bytes]], + files: Dict[str, List[HTTPFile]], +) -> None: + """Parses a ``multipart/form-data`` body. + + The ``boundary`` and ``data`` parameters are both byte strings. + The dictionaries given in the arguments and files parameters + will be updated with the contents of the body. + + .. versionchanged:: 5.1 + + Now recognizes non-ASCII filenames in RFC 2231/5987 + (``filename*=``) format. + """ + # The standard allows for the boundary to be quoted in the header, + # although it's rare (it happens at least for google app engine + # xmpp). I think we're also supposed to handle backslash-escapes + # here but I'll save that until we see a client that uses them + # in the wild. 
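+    # For example, a (hypothetical) header of
+    #   Content-Type: multipart/form-data; boundary="1234"
+    # reaches this function with boundary == b'"1234"'; the check below
+    # strips the quotes before the body is split on the boundary.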
+ if boundary.startswith(b'"') and boundary.endswith(b'"'): + boundary = boundary[1:-1] + final_boundary_index = data.rfind(b"--" + boundary + b"--") + if final_boundary_index == -1: + gen_log.warning("Invalid multipart/form-data: no final boundary") + return + parts = data[:final_boundary_index].split(b"--" + boundary + b"\r\n") + for part in parts: + if not part: + continue + eoh = part.find(b"\r\n\r\n") + if eoh == -1: + gen_log.warning("multipart/form-data missing headers") + continue + headers = HTTPHeaders.parse(part[:eoh].decode("utf-8")) + disp_header = headers.get("Content-Disposition", "") + disposition, disp_params = _parse_header(disp_header) + if disposition != "form-data" or not part.endswith(b"\r\n"): + gen_log.warning("Invalid multipart/form-data") + continue + value = part[eoh + 4 : -2] + if not disp_params.get("name"): + gen_log.warning("multipart/form-data value missing name") + continue + name = disp_params["name"] + if disp_params.get("filename"): + ctype = headers.get("Content-Type", "application/unknown") + files.setdefault(name, []).append( + HTTPFile( + filename=disp_params["filename"], body=value, content_type=ctype + ) + ) + else: + arguments.setdefault(name, []).append(value) + + +def format_timestamp( + ts: Union[int, float, tuple, time.struct_time, datetime.datetime] +) -> str: + """Formats a timestamp in the format used by HTTP. + + The argument may be a numeric timestamp as returned by `time.time`, + a time tuple as returned by `time.gmtime`, or a `datetime.datetime` + object. + + >>> format_timestamp(1359312200) + 'Sun, 27 Jan 2013 18:43:20 GMT' + """ + if isinstance(ts, (int, float)): + time_num = ts + elif isinstance(ts, (tuple, time.struct_time)): + time_num = calendar.timegm(ts) + elif isinstance(ts, datetime.datetime): + time_num = calendar.timegm(ts.utctimetuple()) + else: + raise TypeError("unknown timestamp type: %r" % ts) + return email.utils.formatdate(time_num, usegmt=True) + + +RequestStartLine = collections.namedtuple( + "RequestStartLine", ["method", "path", "version"] +) + + +_http_version_re = re.compile(r"^HTTP/1\.[0-9]$") + + +def parse_request_start_line(line: str) -> RequestStartLine: + """Returns a (method, path, version) tuple for an HTTP 1.x request line. + + The response is a `collections.namedtuple`. + + >>> parse_request_start_line("GET /foo HTTP/1.1") + RequestStartLine(method='GET', path='/foo', version='HTTP/1.1') + """ + try: + method, path, version = line.split(" ") + except ValueError: + # https://tools.ietf.org/html/rfc7230#section-3.1.1 + # invalid request-line SHOULD respond with a 400 (Bad Request) + raise HTTPInputError("Malformed HTTP request line") + if not _http_version_re.match(version): + raise HTTPInputError( + "Malformed HTTP version in HTTP Request-Line: %r" % version + ) + return RequestStartLine(method, path, version) + + +ResponseStartLine = collections.namedtuple( + "ResponseStartLine", ["version", "code", "reason"] +) + + +_http_response_line_re = re.compile(r"(HTTP/1.[0-9]) ([0-9]+) ([^\r]*)") + + +def parse_response_start_line(line: str) -> ResponseStartLine: + """Returns a (version, code, reason) tuple for an HTTP 1.x response line. + + The response is a `collections.namedtuple`. 
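+
+    Raises `HTTPInputError` if the line cannot be parsed; for example,
+    ``parse_response_start_line("HTTP/2 200 OK")`` raises because only
+    HTTP/1.x start lines are handled here.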
+ + >>> parse_response_start_line("HTTP/1.1 200 OK") + ResponseStartLine(version='HTTP/1.1', code=200, reason='OK') + """ + line = native_str(line) + match = _http_response_line_re.match(line) + if not match: + raise HTTPInputError("Error parsing response start line") + return ResponseStartLine(match.group(1), int(match.group(2)), match.group(3)) + + +# _parseparam and _parse_header are copied and modified from python2.7's cgi.py +# The original 2.7 version of this code did not correctly support some +# combinations of semicolons and double quotes. +# It has also been modified to support valueless parameters as seen in +# websocket extension negotiations, and to support non-ascii values in +# RFC 2231/5987 format. + + +def _parseparam(s: str) -> Generator[str, None, None]: + while s[:1] == ";": + s = s[1:] + end = s.find(";") + while end > 0 and (s.count('"', 0, end) - s.count('\\"', 0, end)) % 2: + end = s.find(";", end + 1) + if end < 0: + end = len(s) + f = s[:end] + yield f.strip() + s = s[end:] + + +def _parse_header(line: str) -> Tuple[str, Dict[str, str]]: + r"""Parse a Content-type like header. + + Return the main content-type and a dictionary of options. + + >>> d = "form-data; foo=\"b\\\\a\\\"r\"; file*=utf-8''T%C3%A4st" + >>> ct, d = _parse_header(d) + >>> ct + 'form-data' + >>> d['file'] == r'T\u00e4st'.encode('ascii').decode('unicode_escape') + True + >>> d['foo'] + 'b\\a"r' + """ + parts = _parseparam(";" + line) + key = next(parts) + # decode_params treats first argument special, but we already stripped key + params = [("Dummy", "value")] + for p in parts: + i = p.find("=") + if i >= 0: + name = p[:i].strip().lower() + value = p[i + 1 :].strip() + params.append((name, native_str(value))) + decoded_params = email.utils.decode_params(params) + decoded_params.pop(0) # get rid of the dummy again + pdict = {} + for name, decoded_value in decoded_params: + value = email.utils.collapse_rfc2231_value(decoded_value) + if len(value) >= 2 and value[0] == '"' and value[-1] == '"': + value = value[1:-1] + pdict[name] = value + return key, pdict + + +def _encode_header(key: str, pdict: Dict[str, str]) -> str: + """Inverse of _parse_header. + + >>> _encode_header('permessage-deflate', + ... {'client_max_window_bits': 15, 'client_no_context_takeover': None}) + 'permessage-deflate; client_max_window_bits=15; client_no_context_takeover' + """ + if not pdict: + return key + out = [key] + # Sort the parameters just to make it easy to test. + for k, v in sorted(pdict.items()): + if v is None: + out.append(k) + else: + # TODO: quote if necessary. + out.append("%s=%s" % (k, v)) + return "; ".join(out) + + +def encode_username_password( + username: Union[str, bytes], password: Union[str, bytes] +) -> bytes: + """Encodes a username/password pair in the format used by HTTP auth. + + The return value is a byte string in the form ``username:password``. + + .. versionadded:: 5.1 + """ + if isinstance(username, unicode_type): + username = unicodedata.normalize("NFC", username) + if isinstance(password, unicode_type): + password = unicodedata.normalize("NFC", password) + return utf8(username) + b":" + utf8(password) + + +def doctests(): + # type: () -> unittest.TestSuite + import doctest + + return doctest.DocTestSuite() + + +_netloc_re = re.compile(r"^(.+):(\d+)$") + + +def split_host_and_port(netloc: str) -> Tuple[str, Optional[int]]: + """Returns ``(host, port)`` tuple from ``netloc``. + + Returned ``port`` will be ``None`` if not present. + + .. 
versionadded:: 4.1 + """ + match = _netloc_re.match(netloc) + if match: + host = match.group(1) + port = int(match.group(2)) # type: Optional[int] + else: + host = netloc + port = None + return (host, port) + + +def qs_to_qsl(qs: Dict[str, List[AnyStr]]) -> Iterable[Tuple[str, AnyStr]]: + """Generator converting a result of ``parse_qs`` back to name-value pairs. + + .. versionadded:: 5.0 + """ + for k, vs in qs.items(): + for v in vs: + yield (k, v) + + +_OctalPatt = re.compile(r"\\[0-3][0-7][0-7]") +_QuotePatt = re.compile(r"[\\].") +_nulljoin = "".join + + +def _unquote_cookie(s: str) -> str: + """Handle double quotes and escaping in cookie values. + + This method is copied verbatim from the Python 3.5 standard + library (http.cookies._unquote) so we don't have to depend on + non-public interfaces. + """ + # If there aren't any doublequotes, + # then there can't be any special characters. See RFC 2109. + if s is None or len(s) < 2: + return s + if s[0] != '"' or s[-1] != '"': + return s + + # We have to assume that we must decode this string. + # Down to work. + + # Remove the "s + s = s[1:-1] + + # Check for special sequences. Examples: + # \012 --> \n + # \" --> " + # + i = 0 + n = len(s) + res = [] + while 0 <= i < n: + o_match = _OctalPatt.search(s, i) + q_match = _QuotePatt.search(s, i) + if not o_match and not q_match: # Neither matched + res.append(s[i:]) + break + # else: + j = k = -1 + if o_match: + j = o_match.start(0) + if q_match: + k = q_match.start(0) + if q_match and (not o_match or k < j): # QuotePatt matched + res.append(s[i:k]) + res.append(s[k + 1]) + i = k + 2 + else: # OctalPatt matched + res.append(s[i:j]) + res.append(chr(int(s[j + 1 : j + 4], 8))) + i = j + 4 + return _nulljoin(res) + + +def parse_cookie(cookie: str) -> Dict[str, str]: + """Parse a ``Cookie`` HTTP header into a dict of name/value pairs. + + This function attempts to mimic browser cookie parsing behavior; + it specifically does not follow any of the cookie-related RFCs + (because browsers don't either). + + The algorithm used is identical to that used by Django version 1.9.10. + + .. versionadded:: 4.4.2 + """ + cookiedict = {} + for chunk in cookie.split(str(";")): + if str("=") in chunk: + key, val = chunk.split(str("="), 1) + else: + # Assume an empty name per + # https://bugzilla.mozilla.org/show_bug.cgi?id=169091 + key, val = str(""), chunk + key, val = key.strip(), val.strip() + if key or val: + # unquote using Python's algorithm. + cookiedict[key] = _unquote_cookie(val) + return cookiedict diff --git a/venv/lib/python3.9/site-packages/tornado/ioloop.py b/venv/lib/python3.9/site-packages/tornado/ioloop.py new file mode 100644 index 00000000..bcdcca09 --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/ioloop.py @@ -0,0 +1,960 @@ +# +# Copyright 2009 Facebook +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""An I/O event loop for non-blocking sockets. + +In Tornado 6.0, `.IOLoop` is a wrapper around the `asyncio` event loop, with a +slightly different interface. 
The `.IOLoop` interface is now provided primarily +for backwards compatibility; new code should generally use the `asyncio` event +loop interface directly. The `IOLoop.current` class method provides the +`IOLoop` instance corresponding to the running `asyncio` event loop. + +""" + +import asyncio +import concurrent.futures +import datetime +import functools +import numbers +import os +import sys +import time +import math +import random +import warnings +from inspect import isawaitable + +from tornado.concurrent import ( + Future, + is_future, + chain_future, + future_set_exc_info, + future_add_done_callback, +) +from tornado.log import app_log +from tornado.util import Configurable, TimeoutError, import_object + +import typing +from typing import Union, Any, Type, Optional, Callable, TypeVar, Tuple, Awaitable + +if typing.TYPE_CHECKING: + from typing import Dict, List # noqa: F401 + + from typing_extensions import Protocol +else: + Protocol = object + + +class _Selectable(Protocol): + def fileno(self) -> int: + pass + + def close(self) -> None: + pass + + +_T = TypeVar("_T") +_S = TypeVar("_S", bound=_Selectable) + + +class IOLoop(Configurable): + """An I/O event loop. + + As of Tornado 6.0, `IOLoop` is a wrapper around the `asyncio` event loop. + + Example usage for a simple TCP server: + + .. testcode:: + + import asyncio + import errno + import functools + import socket + + import tornado + from tornado.iostream import IOStream + + async def handle_connection(connection, address): + stream = IOStream(connection) + message = await stream.read_until_close() + print("message from client:", message.decode().strip()) + + def connection_ready(sock, fd, events): + while True: + try: + connection, address = sock.accept() + except BlockingIOError: + return + connection.setblocking(0) + io_loop = tornado.ioloop.IOLoop.current() + io_loop.spawn_callback(handle_connection, connection, address) + + async def main(): + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.setblocking(0) + sock.bind(("", 8888)) + sock.listen(128) + + io_loop = tornado.ioloop.IOLoop.current() + callback = functools.partial(connection_ready, sock) + io_loop.add_handler(sock.fileno(), callback, io_loop.READ) + await asyncio.Event().wait() + + if __name__ == "__main__": + asyncio.run(main()) + + .. testoutput:: + :hide: + + Most applications should not attempt to construct an `IOLoop` directly, + and instead initialize the `asyncio` event loop and use `IOLoop.current()`. + In some cases, such as in test frameworks when initializing an `IOLoop` + to be run in a secondary thread, it may be appropriate to construct + an `IOLoop` with ``IOLoop(make_current=False)``. + + In general, an `IOLoop` cannot survive a fork or be shared across processes + in any way. When multiple processes are being used, each process should + create its own `IOLoop`, which also implies that any objects which depend on + the `IOLoop` (such as `.AsyncHTTPClient`) must also be created in the child + processes. As a guideline, anything that starts processes (including the + `tornado.process` and `multiprocessing` modules) should do so as early as + possible, ideally the first thing the application does after loading its + configuration, and *before* any calls to `.IOLoop.start` or `asyncio.run`. + + .. versionchanged:: 4.2 + Added the ``make_current`` keyword argument to the `IOLoop` + constructor. + + .. versionchanged:: 5.0 + + Uses the `asyncio` event loop by default. 
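# ---------------------------------------------------------------------------
# Editorial example (not part of the patch): a sketch of the IOLoop/asyncio
# relationship described above. ``asyncio_loop`` is an attribute of the
# asyncio-backed IOLoop implementations, used here only to show that both
# APIs drive the same underlying loop.
import asyncio

import tornado.ioloop

async def demo():
    io_loop = tornado.ioloop.IOLoop.current()  # wraps the running asyncio loop
    assert io_loop.asyncio_loop is asyncio.get_running_loop()
    # Tornado-style and asyncio-style scheduling land on the same loop:
    io_loop.add_callback(print, "via IOLoop.add_callback")
    asyncio.get_running_loop().call_soon(print, "via loop.call_soon")
    await asyncio.sleep(0.1)

asyncio.run(demo())
# ---------------------------------------------------------------------------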
The ``IOLoop.configure`` method + cannot be used on Python 3 except to redundantly specify the `asyncio` + event loop. + + .. versionchanged:: 6.3 + ``make_current=True`` is now the default when creating an IOLoop - + previously the default was to make the event loop current if there wasn't + already a current one. + """ + + # These constants were originally based on constants from the epoll module. + NONE = 0 + READ = 0x001 + WRITE = 0x004 + ERROR = 0x018 + + # In Python 3, _ioloop_for_asyncio maps from asyncio loops to IOLoops. + _ioloop_for_asyncio = dict() # type: Dict[asyncio.AbstractEventLoop, IOLoop] + + @classmethod + def configure( + cls, impl: "Union[None, str, Type[Configurable]]", **kwargs: Any + ) -> None: + from tornado.platform.asyncio import BaseAsyncIOLoop + + if isinstance(impl, str): + impl = import_object(impl) + if isinstance(impl, type) and not issubclass(impl, BaseAsyncIOLoop): + raise RuntimeError("only AsyncIOLoop is allowed when asyncio is available") + super(IOLoop, cls).configure(impl, **kwargs) + + @staticmethod + def instance() -> "IOLoop": + """Deprecated alias for `IOLoop.current()`. + + .. versionchanged:: 5.0 + + Previously, this method returned a global singleton + `IOLoop`, in contrast with the per-thread `IOLoop` returned + by `current()`. In nearly all cases the two were the same + (when they differed, it was generally used from non-Tornado + threads to communicate back to the main thread's `IOLoop`). + This distinction is not present in `asyncio`, so in order + to facilitate integration with that package `instance()` + was changed to be an alias to `current()`. Applications + using the cross-thread communications aspect of + `instance()` should instead set their own global variable + to point to the `IOLoop` they want to use. + + .. deprecated:: 5.0 + """ + return IOLoop.current() + + def install(self) -> None: + """Deprecated alias for `make_current()`. + + .. versionchanged:: 5.0 + + Previously, this method would set this `IOLoop` as the + global singleton used by `IOLoop.instance()`. Now that + `instance()` is an alias for `current()`, `install()` + is an alias for `make_current()`. + + .. deprecated:: 5.0 + """ + self.make_current() + + @staticmethod + def clear_instance() -> None: + """Deprecated alias for `clear_current()`. + + .. versionchanged:: 5.0 + + Previously, this method would clear the `IOLoop` used as + the global singleton by `IOLoop.instance()`. Now that + `instance()` is an alias for `current()`, + `clear_instance()` is an alias for `clear_current()`. + + .. deprecated:: 5.0 + + """ + IOLoop.clear_current() + + @typing.overload + @staticmethod + def current() -> "IOLoop": + pass + + @typing.overload + @staticmethod + def current(instance: bool = True) -> Optional["IOLoop"]: # noqa: F811 + pass + + @staticmethod + def current(instance: bool = True) -> Optional["IOLoop"]: # noqa: F811 + """Returns the current thread's `IOLoop`. + + If an `IOLoop` is currently running or has been marked as + current by `make_current`, returns that instance. If there is + no current `IOLoop` and ``instance`` is true, creates one. + + .. versionchanged:: 4.1 + Added ``instance`` argument to control the fallback to + `IOLoop.instance()`. + .. versionchanged:: 5.0 + On Python 3, control of the current `IOLoop` is delegated + to `asyncio`, with this and other methods as pass-through accessors. 
+ The ``instance`` argument now controls whether an `IOLoop` + is created automatically when there is none, instead of + whether we fall back to `IOLoop.instance()` (which is now + an alias for this method). ``instance=False`` is deprecated, + since even if we do not create an `IOLoop`, this method + may initialize the asyncio loop. + + .. deprecated:: 6.2 + It is deprecated to call ``IOLoop.current()`` when no `asyncio` + event loop is running. + """ + try: + loop = asyncio.get_event_loop() + except RuntimeError: + if not instance: + return None + # Create a new asyncio event loop for this thread. + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + try: + return IOLoop._ioloop_for_asyncio[loop] + except KeyError: + if instance: + from tornado.platform.asyncio import AsyncIOMainLoop + + current = AsyncIOMainLoop() # type: Optional[IOLoop] + else: + current = None + return current + + def make_current(self) -> None: + """Makes this the `IOLoop` for the current thread. + + An `IOLoop` automatically becomes current for its thread + when it is started, but it is sometimes useful to call + `make_current` explicitly before starting the `IOLoop`, + so that code run at startup time can find the right + instance. + + .. versionchanged:: 4.1 + An `IOLoop` created while there is no current `IOLoop` + will automatically become current. + + .. versionchanged:: 5.0 + This method also sets the current `asyncio` event loop. + + .. deprecated:: 6.2 + Setting and clearing the current event loop through Tornado is + deprecated. Use ``asyncio.set_event_loop`` instead if you need this. + """ + warnings.warn( + "make_current is deprecated; start the event loop first", + DeprecationWarning, + stacklevel=2, + ) + self._make_current() + + def _make_current(self) -> None: + # The asyncio event loops override this method. + raise NotImplementedError() + + @staticmethod + def clear_current() -> None: + """Clears the `IOLoop` for the current thread. + + Intended primarily for use by test frameworks in between tests. + + .. versionchanged:: 5.0 + This method also clears the current `asyncio` event loop. + .. deprecated:: 6.2 + """ + warnings.warn( + "clear_current is deprecated", + DeprecationWarning, + stacklevel=2, + ) + IOLoop._clear_current() + + @staticmethod + def _clear_current() -> None: + old = IOLoop.current(instance=False) + if old is not None: + old._clear_current_hook() + + def _clear_current_hook(self) -> None: + """Instance method called when an IOLoop ceases to be current. + + May be overridden by subclasses as a counterpart to make_current. + """ + pass + + @classmethod + def configurable_base(cls) -> Type[Configurable]: + return IOLoop + + @classmethod + def configurable_default(cls) -> Type[Configurable]: + from tornado.platform.asyncio import AsyncIOLoop + + return AsyncIOLoop + + def initialize(self, make_current: bool = True) -> None: + if make_current: + self._make_current() + + def close(self, all_fds: bool = False) -> None: + """Closes the `IOLoop`, freeing any resources used. + + If ``all_fds`` is true, all file descriptors registered on the + IOLoop will be closed (not just the ones created by the + `IOLoop` itself). + + Many applications will only use a single `IOLoop` that runs for the + entire lifetime of the process. In that case closing the `IOLoop` + is not necessary since everything will be cleaned up when the + process exits. `IOLoop.close` is provided mainly for scenarios + such as unit tests, which create and destroy a large number of + ``IOLoops``. 
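# ---------------------------------------------------------------------------
# Editorial example (not part of the patch): per the deprecation notes above,
# a sketch of running an IOLoop in a helper thread by managing the asyncio
# loop directly instead of calling the deprecated make_current().
import asyncio
import threading

import tornado.ioloop

def thread_main():
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)               # replaces make_current()
    io_loop = tornado.ioloop.IOLoop.current()  # wraps ``loop``
    io_loop.add_callback(io_loop.stop)         # stop as soon as it starts
    io_loop.start()
    io_loop.close()                            # also closes the asyncio loop

t = threading.Thread(target=thread_main)
t.start()
t.join()
# ---------------------------------------------------------------------------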
+ + An `IOLoop` must be completely stopped before it can be closed. This + means that `IOLoop.stop()` must be called *and* `IOLoop.start()` must + be allowed to return before attempting to call `IOLoop.close()`. + Therefore the call to `close` will usually appear just after + the call to `start` rather than near the call to `stop`. + + .. versionchanged:: 3.1 + If the `IOLoop` implementation supports non-integer objects + for "file descriptors", those objects will have their + ``close`` method called when ``all_fds`` is true. + """ + raise NotImplementedError() + + @typing.overload + def add_handler( + self, fd: int, handler: Callable[[int, int], None], events: int + ) -> None: + pass + + @typing.overload # noqa: F811 + def add_handler( + self, fd: _S, handler: Callable[[_S, int], None], events: int + ) -> None: + pass + + def add_handler( # noqa: F811 + self, fd: Union[int, _Selectable], handler: Callable[..., None], events: int + ) -> None: + """Registers the given handler to receive the given events for ``fd``. + + The ``fd`` argument may either be an integer file descriptor or + a file-like object with a ``fileno()`` and ``close()`` method. + + The ``events`` argument is a bitwise or of the constants + ``IOLoop.READ``, ``IOLoop.WRITE``, and ``IOLoop.ERROR``. + + When an event occurs, ``handler(fd, events)`` will be run. + + .. versionchanged:: 4.0 + Added the ability to pass file-like objects in addition to + raw file descriptors. + """ + raise NotImplementedError() + + def update_handler(self, fd: Union[int, _Selectable], events: int) -> None: + """Changes the events we listen for on ``fd``. + + .. versionchanged:: 4.0 + Added the ability to pass file-like objects in addition to + raw file descriptors. + """ + raise NotImplementedError() + + def remove_handler(self, fd: Union[int, _Selectable]) -> None: + """Stop listening for events on ``fd``. + + .. versionchanged:: 4.0 + Added the ability to pass file-like objects in addition to + raw file descriptors. + """ + raise NotImplementedError() + + def start(self) -> None: + """Starts the I/O loop. + + The loop will run until one of the callbacks calls `stop()`, which + will make the loop stop after the current event iteration completes. + """ + raise NotImplementedError() + + def stop(self) -> None: + """Stop the I/O loop. + + If the event loop is not currently running, the next call to `start()` + will return immediately. + + Note that even after `stop` has been called, the `IOLoop` is not + completely stopped until `IOLoop.start` has also returned. + Some work that was scheduled before the call to `stop` may still + be run before the `IOLoop` shuts down. + """ + raise NotImplementedError() + + def run_sync(self, func: Callable, timeout: Optional[float] = None) -> Any: + """Starts the `IOLoop`, runs the given function, and stops the loop. + + The function must return either an awaitable object or + ``None``. If the function returns an awaitable object, the + `IOLoop` will run until the awaitable is resolved (and + `run_sync()` will return the awaitable's result). If it raises + an exception, the `IOLoop` will stop and the exception will be + re-raised to the caller. + + The keyword-only argument ``timeout`` may be used to set + a maximum duration for the function. If the timeout expires, + an `asyncio.TimeoutError` is raised. + + This method is useful to allow asynchronous calls in a + ``main()`` function:: + + async def main(): + # do stuff... + + if __name__ == '__main__': + IOLoop.current().run_sync(main) + + ..
versionchanged:: 4.3 + Returning a non-``None``, non-awaitable value is now an error. + + .. versionchanged:: 5.0 + If a timeout occurs, the ``func`` coroutine will be cancelled. + + .. versionchanged:: 6.2 + ``tornado.util.TimeoutError`` is now an alias to ``asyncio.TimeoutError``. + """ + future_cell = [None] # type: List[Optional[Future]] + + def run() -> None: + try: + result = func() + if result is not None: + from tornado.gen import convert_yielded + + result = convert_yielded(result) + except Exception: + fut = Future() # type: Future[Any] + future_cell[0] = fut + future_set_exc_info(fut, sys.exc_info()) + else: + if is_future(result): + future_cell[0] = result + else: + fut = Future() + future_cell[0] = fut + fut.set_result(result) + assert future_cell[0] is not None + self.add_future(future_cell[0], lambda future: self.stop()) + + self.add_callback(run) + if timeout is not None: + + def timeout_callback() -> None: + # If we can cancel the future, do so and wait on it. If not, + # Just stop the loop and return with the task still pending. + # (If we neither cancel nor wait for the task, a warning + # will be logged). + assert future_cell[0] is not None + if not future_cell[0].cancel(): + self.stop() + + timeout_handle = self.add_timeout(self.time() + timeout, timeout_callback) + self.start() + if timeout is not None: + self.remove_timeout(timeout_handle) + assert future_cell[0] is not None + if future_cell[0].cancelled() or not future_cell[0].done(): + raise TimeoutError("Operation timed out after %s seconds" % timeout) + return future_cell[0].result() + + def time(self) -> float: + """Returns the current time according to the `IOLoop`'s clock. + + The return value is a floating-point number relative to an + unspecified time in the past. + + Historically, the IOLoop could be customized to use e.g. + `time.monotonic` instead of `time.time`, but this is not + currently supported and so this method is equivalent to + `time.time`. + + """ + return time.time() + + def add_timeout( + self, + deadline: Union[float, datetime.timedelta], + callback: Callable, + *args: Any, + **kwargs: Any + ) -> object: + """Runs the ``callback`` at the time ``deadline`` from the I/O loop. + + Returns an opaque handle that may be passed to + `remove_timeout` to cancel. + + ``deadline`` may be a number denoting a time (on the same + scale as `IOLoop.time`, normally `time.time`), or a + `datetime.timedelta` object for a deadline relative to the + current time. Since Tornado 4.0, `call_later` is a more + convenient alternative for the relative case since it does not + require a timedelta object. + + Note that it is not safe to call `add_timeout` from other threads. + Instead, you must use `add_callback` to transfer control to the + `IOLoop`'s thread, and then call `add_timeout` from there. + + Subclasses of IOLoop must implement either `add_timeout` or + `call_at`; the default implementations of each will call + the other. `call_at` is usually easier to implement, but + subclasses that wish to maintain compatibility with Tornado + versions prior to 4.0 must use `add_timeout` instead. + + .. versionchanged:: 4.0 + Now passes through ``*args`` and ``**kwargs`` to the callback. 
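# ---------------------------------------------------------------------------
# Editorial example (not part of the patch): the two deadline forms accepted
# by add_timeout above, plus the call_later shorthand. A current IOLoop is
# assumed; as of 6.2 that normally means calling this under a running loop.
import datetime

import tornado.ioloop

def schedule(io_loop: tornado.ioloop.IOLoop) -> None:
    h1 = io_loop.add_timeout(io_loop.time() + 1.0, print, "absolute deadline")
    h2 = io_loop.add_timeout(datetime.timedelta(seconds=1), print, "timedelta")
    h3 = io_loop.call_later(1.0, print, "call_later shorthand")
    # Handles are opaque; the only way to cancel is remove_timeout:
    io_loop.remove_timeout(h2)
# ---------------------------------------------------------------------------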
+ """ + if isinstance(deadline, numbers.Real): + return self.call_at(deadline, callback, *args, **kwargs) + elif isinstance(deadline, datetime.timedelta): + return self.call_at( + self.time() + deadline.total_seconds(), callback, *args, **kwargs + ) + else: + raise TypeError("Unsupported deadline %r" % deadline) + + def call_later( + self, delay: float, callback: Callable, *args: Any, **kwargs: Any + ) -> object: + """Runs the ``callback`` after ``delay`` seconds have passed. + + Returns an opaque handle that may be passed to `remove_timeout` + to cancel. Note that unlike the `asyncio` method of the same + name, the returned object does not have a ``cancel()`` method. + + See `add_timeout` for comments on thread-safety and subclassing. + + .. versionadded:: 4.0 + """ + return self.call_at(self.time() + delay, callback, *args, **kwargs) + + def call_at( + self, when: float, callback: Callable, *args: Any, **kwargs: Any + ) -> object: + """Runs the ``callback`` at the absolute time designated by ``when``. + + ``when`` must be a number using the same reference point as + `IOLoop.time`. + + Returns an opaque handle that may be passed to `remove_timeout` + to cancel. Note that unlike the `asyncio` method of the same + name, the returned object does not have a ``cancel()`` method. + + See `add_timeout` for comments on thread-safety and subclassing. + + .. versionadded:: 4.0 + """ + return self.add_timeout(when, callback, *args, **kwargs) + + def remove_timeout(self, timeout: object) -> None: + """Cancels a pending timeout. + + The argument is a handle as returned by `add_timeout`. It is + safe to call `remove_timeout` even if the callback has already + been run. + """ + raise NotImplementedError() + + def add_callback(self, callback: Callable, *args: Any, **kwargs: Any) -> None: + """Calls the given callback on the next I/O loop iteration. + + It is safe to call this method from any thread at any time, + except from a signal handler. Note that this is the **only** + method in `IOLoop` that makes this thread-safety guarantee; all + other interaction with the `IOLoop` must be done from that + `IOLoop`'s thread. `add_callback()` may be used to transfer + control from other threads to the `IOLoop`'s thread. + + To add a callback from a signal handler, see + `add_callback_from_signal`. + """ + raise NotImplementedError() + + def add_callback_from_signal( + self, callback: Callable, *args: Any, **kwargs: Any + ) -> None: + """Calls the given callback on the next I/O loop iteration. + + Safe for use from a Python signal handler; should not be used + otherwise. + """ + raise NotImplementedError() + + def spawn_callback(self, callback: Callable, *args: Any, **kwargs: Any) -> None: + """Calls the given callback on the next IOLoop iteration. + + As of Tornado 6.0, this method is equivalent to `add_callback`. + + .. versionadded:: 4.0 + """ + self.add_callback(callback, *args, **kwargs) + + def add_future( + self, + future: "Union[Future[_T], concurrent.futures.Future[_T]]", + callback: Callable[["Future[_T]"], None], + ) -> None: + """Schedules a callback on the ``IOLoop`` when the given + `.Future` is finished. + + The callback is invoked with one argument, the + `.Future`. + + This method only accepts `.Future` objects and not other + awaitables (unlike most of Tornado where the two are + interchangeable). + """ + if isinstance(future, Future): + # Note that we specifically do not want the inline behavior of + # tornado.concurrent.future_add_done_callback. 
We always want + # this callback scheduled on the next IOLoop iteration (which + # asyncio.Future always does). + # + # Wrap the callback in self._run_callback so we control + # the error logging (i.e. it goes to tornado.log.app_log + # instead of asyncio's log). + future.add_done_callback( + lambda f: self._run_callback(functools.partial(callback, future)) + ) + else: + assert is_future(future) + # For concurrent futures, we use self.add_callback, so + # it's fine if future_add_done_callback inlines that call. + future_add_done_callback( + future, lambda f: self.add_callback(callback, future) + ) + + def run_in_executor( + self, + executor: Optional[concurrent.futures.Executor], + func: Callable[..., _T], + *args: Any + ) -> Awaitable[_T]: + """Runs a function in a ``concurrent.futures.Executor``. If + ``executor`` is ``None``, the IO loop's default executor will be used. + + Use `functools.partial` to pass keyword arguments to ``func``. + + .. versionadded:: 5.0 + """ + if executor is None: + if not hasattr(self, "_executor"): + from tornado.process import cpu_count + + self._executor = concurrent.futures.ThreadPoolExecutor( + max_workers=(cpu_count() * 5) + ) # type: concurrent.futures.Executor + executor = self._executor + c_future = executor.submit(func, *args) + # Concurrent Futures are not usable with await. Wrap this in a + # Tornado Future instead, using self.add_future for thread-safety. + t_future = Future() # type: Future[_T] + self.add_future(c_future, lambda f: chain_future(f, t_future)) + return t_future + + def set_default_executor(self, executor: concurrent.futures.Executor) -> None: + """Sets the default executor to use with :meth:`run_in_executor`. + + .. versionadded:: 5.0 + """ + self._executor = executor + + def _run_callback(self, callback: Callable[[], Any]) -> None: + """Runs a callback with error handling. + + .. versionchanged:: 6.0 + + CancelledErrors are no longer logged. + """ + try: + ret = callback() + if ret is not None: + from tornado import gen + + # Functions that return Futures typically swallow all + # exceptions and store them in the Future. If a Future + # makes it out to the IOLoop, ensure its exception (if any) + # gets logged too. + try: + ret = gen.convert_yielded(ret) + except gen.BadYieldError: + # It's not unusual for add_callback to be used with + # methods returning a non-None and non-yieldable + # result, which should just be ignored. + pass + else: + self.add_future(ret, self._discard_future_result) + except asyncio.CancelledError: + pass + except Exception: + app_log.error("Exception in callback %r", callback, exc_info=True) + + def _discard_future_result(self, future: Future) -> None: + """Avoid unhandled-exception warnings from spawned coroutines.""" + future.result() + + def split_fd( + self, fd: Union[int, _Selectable] + ) -> Tuple[int, Union[int, _Selectable]]: + # """Returns an (fd, obj) pair from an ``fd`` parameter. + + # We accept both raw file descriptors and file-like objects as + # input to `add_handler` and related methods. When a file-like + # object is passed, we must retain the object itself so we can + # close it correctly when the `IOLoop` shuts down, but the + # poller interfaces favor file descriptors (they will accept + # file-like objects and call ``fileno()`` for you, but they + # always return the descriptor itself). + + # This method is provided for use by `IOLoop` subclasses and should + # not generally be used by application code. + + # .. 
versionadded:: 4.0 + # """ + if isinstance(fd, int): + return fd, fd + return fd.fileno(), fd + + def close_fd(self, fd: Union[int, _Selectable]) -> None: + # """Utility method to close an ``fd``. + + # If ``fd`` is a file-like object, we close it directly; otherwise + # we use `os.close`. + + # This method is provided for use by `IOLoop` subclasses (in + # implementations of ``IOLoop.close(all_fds=True)`` and should + # not generally be used by application code. + + # .. versionadded:: 4.0 + # """ + try: + if isinstance(fd, int): + os.close(fd) + else: + fd.close() + except OSError: + pass + + +class _Timeout(object): + """An IOLoop timeout, a UNIX timestamp and a callback""" + + # Reduce memory overhead when there are lots of pending callbacks + __slots__ = ["deadline", "callback", "tdeadline"] + + def __init__( + self, deadline: float, callback: Callable[[], None], io_loop: IOLoop + ) -> None: + if not isinstance(deadline, numbers.Real): + raise TypeError("Unsupported deadline %r" % deadline) + self.deadline = deadline + self.callback = callback + self.tdeadline = ( + deadline, + next(io_loop._timeout_counter), + ) # type: Tuple[float, int] + + # Comparison methods to sort by deadline, with object id as a tiebreaker + # to guarantee a consistent ordering. The heapq module uses __le__ + # in python2.5, and __lt__ in 2.6+ (sort() and most other comparisons + # use __lt__). + def __lt__(self, other: "_Timeout") -> bool: + return self.tdeadline < other.tdeadline + + def __le__(self, other: "_Timeout") -> bool: + return self.tdeadline <= other.tdeadline + + +class PeriodicCallback(object): + """Schedules the given callback to be called periodically. + + The callback is called every ``callback_time`` milliseconds when + ``callback_time`` is a float. Note that the timeout is given in + milliseconds, while most other time-related functions in Tornado use + seconds. ``callback_time`` may alternatively be given as a + `datetime.timedelta` object. + + If ``jitter`` is specified, each callback time will be randomly selected + within a window of ``jitter * callback_time`` milliseconds. + Jitter can be used to reduce alignment of events with similar periods. + A jitter of 0.1 means allowing a 10% variation in callback time. + The window is centered on ``callback_time`` so the total number of calls + within a given interval should not be significantly affected by adding + jitter. + + If the callback runs for longer than ``callback_time`` milliseconds, + subsequent invocations will be skipped to get back on schedule. + + `start` must be called after the `PeriodicCallback` is created. + + .. versionchanged:: 5.0 + The ``io_loop`` argument (deprecated since version 4.1) has been removed. + + .. versionchanged:: 5.1 + The ``jitter`` argument is added. + + .. versionchanged:: 6.2 + If the ``callback`` argument is a coroutine, and a callback runs for + longer than ``callback_time``, subsequent invocations will be skipped. + Previously this was only true for regular functions, not coroutines, + which were "fire-and-forget" for `PeriodicCallback`. + + The ``callback_time`` argument now accepts `datetime.timedelta` objects, + in addition to the previous numeric milliseconds. 
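# ---------------------------------------------------------------------------
# Editorial example (not part of the patch): a jittered PeriodicCallback,
# started from within a running loop as the docstring above requires.
import asyncio

import tornado.ioloop

async def demo():
    ticks = []
    # 100 ms period; jitter=0.1 spreads each deadline over a 10% window.
    pc = tornado.ioloop.PeriodicCallback(
        lambda: ticks.append(1), callback_time=100, jitter=0.1
    )
    pc.start()
    await asyncio.sleep(0.55)
    pc.stop()
    print("ticks observed:", len(ticks))  # roughly 5, depending on timing

asyncio.run(demo())
# ---------------------------------------------------------------------------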
+ """ + + def __init__( + self, + callback: Callable[[], Optional[Awaitable]], + callback_time: Union[datetime.timedelta, float], + jitter: float = 0, + ) -> None: + self.callback = callback + if isinstance(callback_time, datetime.timedelta): + self.callback_time = callback_time / datetime.timedelta(milliseconds=1) + else: + if callback_time <= 0: + raise ValueError("Periodic callback must have a positive callback_time") + self.callback_time = callback_time + self.jitter = jitter + self._running = False + self._timeout = None # type: object + + def start(self) -> None: + """Starts the timer.""" + # Looking up the IOLoop here allows to first instantiate the + # PeriodicCallback in another thread, then start it using + # IOLoop.add_callback(). + self.io_loop = IOLoop.current() + self._running = True + self._next_timeout = self.io_loop.time() + self._schedule_next() + + def stop(self) -> None: + """Stops the timer.""" + self._running = False + if self._timeout is not None: + self.io_loop.remove_timeout(self._timeout) + self._timeout = None + + def is_running(self) -> bool: + """Returns ``True`` if this `.PeriodicCallback` has been started. + + .. versionadded:: 4.1 + """ + return self._running + + async def _run(self) -> None: + if not self._running: + return + try: + val = self.callback() + if val is not None and isawaitable(val): + await val + except Exception: + app_log.error("Exception in callback %r", self.callback, exc_info=True) + finally: + self._schedule_next() + + def _schedule_next(self) -> None: + if self._running: + self._update_next(self.io_loop.time()) + self._timeout = self.io_loop.add_timeout(self._next_timeout, self._run) + + def _update_next(self, current_time: float) -> None: + callback_time_sec = self.callback_time / 1000.0 + if self.jitter: + # apply jitter fraction + callback_time_sec *= 1 + (self.jitter * (random.random() - 0.5)) + if self._next_timeout <= current_time: + # The period should be measured from the start of one call + # to the start of the next. If one call takes too long, + # skip cycles to get back to a multiple of the original + # schedule. + self._next_timeout += ( + math.floor((current_time - self._next_timeout) / callback_time_sec) + 1 + ) * callback_time_sec + else: + # If the clock moved backwards, ensure we advance the next + # timeout instead of recomputing the same value again. + # This may result in long gaps between callbacks if the + # clock jumps backwards by a lot, but the far more common + # scenario is a small NTP adjustment that should just be + # ignored. + # + # Note that on some systems if time.time() runs slower + # than time.monotonic() (most common on windows), we + # effectively experience a small backwards time jump on + # every iteration because PeriodicCallback uses + # time.time() while asyncio schedules callbacks using + # time.monotonic(). + # https://github.com/tornadoweb/tornado/issues/2333 + self._next_timeout += callback_time_sec diff --git a/venv/lib/python3.9/site-packages/tornado/iostream.py b/venv/lib/python3.9/site-packages/tornado/iostream.py new file mode 100644 index 00000000..a408be59 --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/iostream.py @@ -0,0 +1,1654 @@ +# +# Copyright 2009 Facebook +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. 
You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Utility classes to write to and read from non-blocking files and sockets. + +Contents: + +* `BaseIOStream`: Generic interface for reading and writing. +* `IOStream`: Implementation of BaseIOStream using non-blocking sockets. +* `SSLIOStream`: SSL-aware version of IOStream. +* `PipeIOStream`: Pipe-based IOStream implementation. +""" + +import asyncio +import collections +import errno +import io +import numbers +import os +import socket +import ssl +import sys +import re + +from tornado.concurrent import Future, future_set_result_unless_cancelled +from tornado import ioloop +from tornado.log import gen_log +from tornado.netutil import ssl_wrap_socket, _client_ssl_defaults, _server_ssl_defaults +from tornado.util import errno_from_exception + +import typing +from typing import ( + Union, + Optional, + Awaitable, + Callable, + Pattern, + Any, + Dict, + TypeVar, + Tuple, +) +from types import TracebackType + +if typing.TYPE_CHECKING: + from typing import Deque, List, Type # noqa: F401 + +_IOStreamType = TypeVar("_IOStreamType", bound="IOStream") + +# These errnos indicate that a connection has been abruptly terminated. +# They should be caught and handled less noisily than other errors. +_ERRNO_CONNRESET = (errno.ECONNRESET, errno.ECONNABORTED, errno.EPIPE, errno.ETIMEDOUT) + +if hasattr(errno, "WSAECONNRESET"): + _ERRNO_CONNRESET += ( # type: ignore + errno.WSAECONNRESET, # type: ignore + errno.WSAECONNABORTED, # type: ignore + errno.WSAETIMEDOUT, # type: ignore + ) + +if sys.platform == "darwin": + # OSX appears to have a race condition that causes send(2) to return + # EPROTOTYPE if called while a socket is being torn down: + # http://erickt.github.io/blog/2014/11/19/adventures-in-debugging-a-potential-osx-kernel-bug/ + # Since the socket is being closed anyway, treat this as an ECONNRESET + # instead of an unexpected error. + _ERRNO_CONNRESET += (errno.EPROTOTYPE,) # type: ignore + +_WINDOWS = sys.platform.startswith("win") + + +class StreamClosedError(IOError): + """Exception raised by `IOStream` methods when the stream is closed. + + Note that the close callback is scheduled to run *after* other + callbacks on the stream (to allow for buffered data to be processed), + so you may see this error before you see the close callback. + + The ``real_error`` attribute contains the underlying error that caused + the stream to close (if any). + + .. versionchanged:: 4.3 + Added the ``real_error`` attribute. + """ + + def __init__(self, real_error: Optional[BaseException] = None) -> None: + super().__init__("Stream is closed") + self.real_error = real_error + + +class UnsatisfiableReadError(Exception): + """Exception raised when a read cannot be satisfied. + + Raised by ``read_until`` and ``read_until_regex`` with a ``max_bytes`` + argument. + """ + + pass + + +class StreamBufferFullError(Exception): + """Exception raised by `IOStream` methods when the buffer is full.""" + + +class _StreamBuffer(object): + """ + A specialized buffer that tries to avoid copies when large pieces + of data are encountered. 
+ """ + + def __init__(self) -> None: + # A sequence of (False, bytearray) and (True, memoryview) objects + self._buffers = ( + collections.deque() + ) # type: Deque[Tuple[bool, Union[bytearray, memoryview]]] + # Position in the first buffer + self._first_pos = 0 + self._size = 0 + + def __len__(self) -> int: + return self._size + + # Data above this size will be appended separately instead + # of extending an existing bytearray + _large_buf_threshold = 2048 + + def append(self, data: Union[bytes, bytearray, memoryview]) -> None: + """ + Append the given piece of data (should be a buffer-compatible object). + """ + size = len(data) + if size > self._large_buf_threshold: + if not isinstance(data, memoryview): + data = memoryview(data) + self._buffers.append((True, data)) + elif size > 0: + if self._buffers: + is_memview, b = self._buffers[-1] + new_buf = is_memview or len(b) >= self._large_buf_threshold + else: + new_buf = True + if new_buf: + self._buffers.append((False, bytearray(data))) + else: + b += data # type: ignore + + self._size += size + + def peek(self, size: int) -> memoryview: + """ + Get a view over at most ``size`` bytes (possibly fewer) at the + current buffer position. + """ + assert size > 0 + try: + is_memview, b = self._buffers[0] + except IndexError: + return memoryview(b"") + + pos = self._first_pos + if is_memview: + return typing.cast(memoryview, b[pos : pos + size]) + else: + return memoryview(b)[pos : pos + size] + + def advance(self, size: int) -> None: + """ + Advance the current buffer position by ``size`` bytes. + """ + assert 0 < size <= self._size + self._size -= size + pos = self._first_pos + + buffers = self._buffers + while buffers and size > 0: + is_large, b = buffers[0] + b_remain = len(b) - size - pos + if b_remain <= 0: + buffers.popleft() + size -= len(b) - pos + pos = 0 + elif is_large: + pos += size + size = 0 + else: + pos += size + del typing.cast(bytearray, b)[:pos] + pos = 0 + size = 0 + + assert size == 0 + self._first_pos = pos + + +class BaseIOStream(object): + """A utility class to write to and read from a non-blocking file or socket. + + We support a non-blocking ``write()`` and a family of ``read_*()`` + methods. When the operation completes, the ``Awaitable`` will resolve + with the data read (or ``None`` for ``write()``). All outstanding + ``Awaitables`` will resolve with a `StreamClosedError` when the + stream is closed; `.BaseIOStream.set_close_callback` can also be used + to be notified of a closed stream. + + When a stream is closed due to an error, the IOStream's ``error`` + attribute contains the exception object. + + Subclasses must implement `fileno`, `close_fd`, `write_to_fd`, + `read_from_fd`, and optionally `get_fd_error`. + + """ + + def __init__( + self, + max_buffer_size: Optional[int] = None, + read_chunk_size: Optional[int] = None, + max_write_buffer_size: Optional[int] = None, + ) -> None: + """`BaseIOStream` constructor. + + :arg max_buffer_size: Maximum amount of incoming data to buffer; + defaults to 100MB. + :arg read_chunk_size: Amount of data to read at one time from the + underlying transport; defaults to 64KB. + :arg max_write_buffer_size: Amount of outgoing data to buffer; + defaults to unlimited. + + .. versionchanged:: 4.0 + Add the ``max_write_buffer_size`` parameter. Changed default + ``read_chunk_size`` to 64KB. + .. versionchanged:: 5.0 + The ``io_loop`` argument (deprecated since version 4.1) has been + removed. 
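# ---------------------------------------------------------------------------
# Editorial example (not part of the patch): the _StreamBuffer behavior
# implemented above -- small appends coalesce into bytearrays while large
# ones are kept as zero-copy memoryviews. This is a private class; the
# sketch only clarifies the peek/advance conventions.
from tornado.iostream import _StreamBuffer

buf = _StreamBuffer()
buf.append(b"small")     # below _large_buf_threshold: copied into a bytearray
buf.append(b"x" * 4096)  # large: wrapped in a memoryview without copying
assert len(buf) == 5 + 4096

assert bytes(buf.peek(3)) == b"sma"  # a view into the front of the buffer
buf.advance(5)                       # consume the whole first chunk
assert bytes(buf.peek(10)) == b"x" * 10
# ---------------------------------------------------------------------------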
+ """ + self.io_loop = ioloop.IOLoop.current() + self.max_buffer_size = max_buffer_size or 104857600 + # A chunk size that is too close to max_buffer_size can cause + # spurious failures. + self.read_chunk_size = min(read_chunk_size or 65536, self.max_buffer_size // 2) + self.max_write_buffer_size = max_write_buffer_size + self.error = None # type: Optional[BaseException] + self._read_buffer = bytearray() + self._read_buffer_size = 0 + self._user_read_buffer = False + self._after_user_read_buffer = None # type: Optional[bytearray] + self._write_buffer = _StreamBuffer() + self._total_write_index = 0 + self._total_write_done_index = 0 + self._read_delimiter = None # type: Optional[bytes] + self._read_regex = None # type: Optional[Pattern] + self._read_max_bytes = None # type: Optional[int] + self._read_bytes = None # type: Optional[int] + self._read_partial = False + self._read_until_close = False + self._read_future = None # type: Optional[Future] + self._write_futures = ( + collections.deque() + ) # type: Deque[Tuple[int, Future[None]]] + self._close_callback = None # type: Optional[Callable[[], None]] + self._connect_future = None # type: Optional[Future[IOStream]] + # _ssl_connect_future should be defined in SSLIOStream + # but it's here so we can clean it up in _signal_closed + # TODO: refactor that so subclasses can add additional futures + # to be cancelled. + self._ssl_connect_future = None # type: Optional[Future[SSLIOStream]] + self._connecting = False + self._state = None # type: Optional[int] + self._closed = False + + def fileno(self) -> Union[int, ioloop._Selectable]: + """Returns the file descriptor for this stream.""" + raise NotImplementedError() + + def close_fd(self) -> None: + """Closes the file underlying this stream. + + ``close_fd`` is called by `BaseIOStream` and should not be called + elsewhere; other users should call `close` instead. + """ + raise NotImplementedError() + + def write_to_fd(self, data: memoryview) -> int: + """Attempts to write ``data`` to the underlying file. + + Returns the number of bytes written. + """ + raise NotImplementedError() + + def read_from_fd(self, buf: Union[bytearray, memoryview]) -> Optional[int]: + """Attempts to read from the underlying file. + + Reads up to ``len(buf)`` bytes, storing them in the buffer. + Returns the number of bytes read. Returns None if there was + nothing to read (the socket returned `~errno.EWOULDBLOCK` or + equivalent), and zero on EOF. + + .. versionchanged:: 5.0 + + Interface redesigned to take a buffer and return a number + of bytes instead of a freshly-allocated object. + """ + raise NotImplementedError() + + def get_fd_error(self) -> Optional[Exception]: + """Returns information about any error on the underlying file. + + This method is called after the `.IOLoop` has signaled an error on the + file descriptor, and should return an Exception (such as `socket.error` + with additional information, or None if no such information is + available. + """ + return None + + def read_until_regex( + self, regex: bytes, max_bytes: Optional[int] = None + ) -> Awaitable[bytes]: + """Asynchronously read until we have matched the given regex. + + The result includes the data that matches the regex and anything + that came before it. + + If ``max_bytes`` is not None, the connection will be closed + if more than ``max_bytes`` bytes have been read and the regex is + not satisfied. + + .. versionchanged:: 4.0 + Added the ``max_bytes`` argument. 
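# ---------------------------------------------------------------------------
# Editorial example (not part of the patch): the subclass contract documented
# above, satisfied for the read side of an os.pipe(). Tornado already ships
# PipeIOStream for real use; this POSIX-only toy exists to spell out the
# return conventions of read_from_fd (None = would block, 0 = EOF).
import os
from typing import Optional

from tornado.iostream import BaseIOStream

class PipeReadStream(BaseIOStream):
    def __init__(self, fd: int, **kwargs) -> None:
        self._fd = fd
        os.set_blocking(fd, False)
        super().__init__(**kwargs)

    def fileno(self) -> int:
        return self._fd

    def close_fd(self) -> None:
        os.close(self._fd)

    def write_to_fd(self, data: memoryview) -> int:
        raise OSError("read-only end of a pipe")

    def read_from_fd(self, buf) -> Optional[int]:
        try:
            return os.readv(self._fd, [buf])  # 0 here signals EOF
        except BlockingIOError:
            return None  # nothing available right now; try again later
# ---------------------------------------------------------------------------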
The ``callback`` argument is + now optional and a `.Future` will be returned if it is omitted. + + .. versionchanged:: 6.0 + + The ``callback`` argument was removed. Use the returned + `.Future` instead. + + """ + future = self._start_read() + self._read_regex = re.compile(regex) + self._read_max_bytes = max_bytes + try: + self._try_inline_read() + except UnsatisfiableReadError as e: + # Handle this the same way as in _handle_events. + gen_log.info("Unsatisfiable read, closing connection: %s" % e) + self.close(exc_info=e) + return future + except: + # Ensure that the future doesn't log an error because its + # failure was never examined. + future.add_done_callback(lambda f: f.exception()) + raise + return future + + def read_until( + self, delimiter: bytes, max_bytes: Optional[int] = None + ) -> Awaitable[bytes]: + """Asynchronously read until we have found the given delimiter. + + The result includes all the data read including the delimiter. + + If ``max_bytes`` is not None, the connection will be closed + if more than ``max_bytes`` bytes have been read and the delimiter + is not found. + + .. versionchanged:: 4.0 + Added the ``max_bytes`` argument. The ``callback`` argument is + now optional and a `.Future` will be returned if it is omitted. + + .. versionchanged:: 6.0 + + The ``callback`` argument was removed. Use the returned + `.Future` instead. + """ + future = self._start_read() + self._read_delimiter = delimiter + self._read_max_bytes = max_bytes + try: + self._try_inline_read() + except UnsatisfiableReadError as e: + # Handle this the same way as in _handle_events. + gen_log.info("Unsatisfiable read, closing connection: %s" % e) + self.close(exc_info=e) + return future + except: + future.add_done_callback(lambda f: f.exception()) + raise + return future + + def read_bytes(self, num_bytes: int, partial: bool = False) -> Awaitable[bytes]: + """Asynchronously read a number of bytes. + + If ``partial`` is true, data is returned as soon as we have + any bytes to return (but never more than ``num_bytes``) + + .. versionchanged:: 4.0 + Added the ``partial`` argument. The callback argument is now + optional and a `.Future` will be returned if it is omitted. + + .. versionchanged:: 6.0 + + The ``callback`` and ``streaming_callback`` arguments have + been removed. Use the returned `.Future` (and + ``partial=True`` for ``streaming_callback``) instead. + + """ + future = self._start_read() + assert isinstance(num_bytes, numbers.Integral) + self._read_bytes = num_bytes + self._read_partial = partial + try: + self._try_inline_read() + except: + future.add_done_callback(lambda f: f.exception()) + raise + return future + + def read_into(self, buf: bytearray, partial: bool = False) -> Awaitable[int]: + """Asynchronously read a number of bytes. + + ``buf`` must be a writable buffer into which data will be read. + + If ``partial`` is true, the callback is run as soon as any bytes + have been read. Otherwise, it is run when the ``buf`` has been + entirely filled with read data. + + .. versionadded:: 5.0 + + .. versionchanged:: 6.0 + + The ``callback`` argument was removed. Use the returned + `.Future` instead. 
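# ---------------------------------------------------------------------------
# Editorial example (not part of the patch): the read_* family above driving
# a minimal HTTP/1.0 exchange. The host, port, and request bytes are
# placeholders; something must be listening there for this to run.
import socket

import tornado.ioloop
from tornado.iostream import IOStream

async def fetch():
    stream = IOStream(socket.socket())
    await stream.connect(("localhost", 8888))
    await stream.write(b"GET / HTTP/1.0\r\nHost: localhost\r\n\r\n")
    # Delimiter reads include the delimiter; max_bytes bounds the scan:
    status = await stream.read_until(b"\r\n", max_bytes=1024)
    headers = await stream.read_until_regex(b"\r\n\r\n", max_bytes=64 * 1024)
    body = await stream.read_until_close()  # buffers everything until EOF
    stream.close()
    return status, headers, body

tornado.ioloop.IOLoop.current().run_sync(fetch)
# ---------------------------------------------------------------------------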
+ + """ + future = self._start_read() + + # First copy data already in read buffer + available_bytes = self._read_buffer_size + n = len(buf) + if available_bytes >= n: + buf[:] = memoryview(self._read_buffer)[:n] + del self._read_buffer[:n] + self._after_user_read_buffer = self._read_buffer + elif available_bytes > 0: + buf[:available_bytes] = memoryview(self._read_buffer)[:] + + # Set up the supplied buffer as our temporary read buffer. + # The original (if it had any data remaining) has been + # saved for later. + self._user_read_buffer = True + self._read_buffer = buf + self._read_buffer_size = available_bytes + self._read_bytes = n + self._read_partial = partial + + try: + self._try_inline_read() + except: + future.add_done_callback(lambda f: f.exception()) + raise + return future + + def read_until_close(self) -> Awaitable[bytes]: + """Asynchronously reads all data from the socket until it is closed. + + This will buffer all available data until ``max_buffer_size`` + is reached. If flow control or cancellation are desired, use a + loop with `read_bytes(partial=True) <.read_bytes>` instead. + + .. versionchanged:: 4.0 + The callback argument is now optional and a `.Future` will + be returned if it is omitted. + + .. versionchanged:: 6.0 + + The ``callback`` and ``streaming_callback`` arguments have + been removed. Use the returned `.Future` (and `read_bytes` + with ``partial=True`` for ``streaming_callback``) instead. + + """ + future = self._start_read() + if self.closed(): + self._finish_read(self._read_buffer_size) + return future + self._read_until_close = True + try: + self._try_inline_read() + except: + future.add_done_callback(lambda f: f.exception()) + raise + return future + + def write(self, data: Union[bytes, memoryview]) -> "Future[None]": + """Asynchronously write the given data to this stream. + + This method returns a `.Future` that resolves (with a result + of ``None``) when the write has been completed. + + The ``data`` argument may be of type `bytes` or `memoryview`. + + .. versionchanged:: 4.0 + Now returns a `.Future` if no callback is given. + + .. versionchanged:: 4.5 + Added support for `memoryview` arguments. + + .. versionchanged:: 6.0 + + The ``callback`` argument was removed. Use the returned + `.Future` instead. + + """ + self._check_closed() + if data: + if isinstance(data, memoryview): + # Make sure that ``len(data) == data.nbytes`` + data = memoryview(data).cast("B") + if ( + self.max_write_buffer_size is not None + and len(self._write_buffer) + len(data) > self.max_write_buffer_size + ): + raise StreamBufferFullError("Reached maximum write buffer size") + self._write_buffer.append(data) + self._total_write_index += len(data) + future = Future() # type: Future[None] + future.add_done_callback(lambda f: f.exception()) + self._write_futures.append((self._total_write_index, future)) + if not self._connecting: + self._handle_write() + if self._write_buffer: + self._add_io_state(self.io_loop.WRITE) + self._maybe_add_error_listener() + return future + + def set_close_callback(self, callback: Optional[Callable[[], None]]) -> None: + """Call the given callback when the stream is closed. + + This mostly is not necessary for applications that use the + `.Future` interface; all outstanding ``Futures`` will resolve + with a `StreamClosedError` when the stream is closed. However, + it is still useful as a way to signal that the stream has been + closed while no other read or write is in progress. 
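# ---------------------------------------------------------------------------
# Editorial example (not part of the patch): read_into fills a caller-owned
# buffer, avoiding a fresh allocation per read -- useful for length-prefixed
# framing. ``stream`` is any connected IOStream (hypothetical here).
async def read_frame(stream):
    header = bytearray(4)
    await stream.read_into(header)  # resolves once all 4 bytes are present
    length = int.from_bytes(header, "big")
    payload = bytearray(length)
    await stream.read_into(payload)
    return bytes(payload)
# ---------------------------------------------------------------------------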
+ + Unlike other callback-based interfaces, ``set_close_callback`` + was not removed in Tornado 6.0. + """ + self._close_callback = callback + self._maybe_add_error_listener() + + def close( + self, + exc_info: Union[ + None, + bool, + BaseException, + Tuple[ + "Optional[Type[BaseException]]", + Optional[BaseException], + Optional[TracebackType], + ], + ] = False, + ) -> None: + """Close this stream. + + If ``exc_info`` is true, set the ``error`` attribute to the current + exception from `sys.exc_info` (or if ``exc_info`` is a tuple, + use that instead of `sys.exc_info`). + """ + if not self.closed(): + if exc_info: + if isinstance(exc_info, tuple): + self.error = exc_info[1] + elif isinstance(exc_info, BaseException): + self.error = exc_info + else: + exc_info = sys.exc_info() + if any(exc_info): + self.error = exc_info[1] + if self._read_until_close: + self._read_until_close = False + self._finish_read(self._read_buffer_size) + elif self._read_future is not None: + # resolve reads that are pending and ready to complete + try: + pos = self._find_read_pos() + except UnsatisfiableReadError: + pass + else: + if pos is not None: + self._read_from_buffer(pos) + if self._state is not None: + self.io_loop.remove_handler(self.fileno()) + self._state = None + self.close_fd() + self._closed = True + self._signal_closed() + + def _signal_closed(self) -> None: + futures = [] # type: List[Future] + if self._read_future is not None: + futures.append(self._read_future) + self._read_future = None + futures += [future for _, future in self._write_futures] + self._write_futures.clear() + if self._connect_future is not None: + futures.append(self._connect_future) + self._connect_future = None + for future in futures: + if not future.done(): + future.set_exception(StreamClosedError(real_error=self.error)) + # Reference the exception to silence warnings. Annoyingly, + # this raises if the future was cancelled, but just + # returns any other error. + try: + future.exception() + except asyncio.CancelledError: + pass + if self._ssl_connect_future is not None: + # _ssl_connect_future expects to see the real exception (typically + # an ssl.SSLError), not just StreamClosedError. + if not self._ssl_connect_future.done(): + if self.error is not None: + self._ssl_connect_future.set_exception(self.error) + else: + self._ssl_connect_future.set_exception(StreamClosedError()) + self._ssl_connect_future.exception() + self._ssl_connect_future = None + if self._close_callback is not None: + cb = self._close_callback + self._close_callback = None + self.io_loop.add_callback(cb) + # Clear the buffers so they can be cleared immediately even + # if the IOStream object is kept alive by a reference cycle. + # TODO: Clear the read buffer too; it currently breaks some tests. + self._write_buffer = None # type: ignore + + def reading(self) -> bool: + """Returns ``True`` if we are currently reading from the stream.""" + return self._read_future is not None + + def writing(self) -> bool: + """Returns ``True`` if we are currently writing to the stream.""" + return bool(self._write_buffer) + + def closed(self) -> bool: + """Returns ``True`` if the stream has been closed.""" + return self._closed + + def set_nodelay(self, value: bool) -> None: + """Sets the no-delay flag for this stream. + + By default, data written to TCP streams may be held for a time + to make the most efficient use of bandwidth (according to + Nagle's algorithm). 
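# ---------------------------------------------------------------------------
# Editorial example (not part of the patch): with the Future interface, a
# read loop usually catches StreamClosedError instead of registering a close
# callback; per the docs above, real_error carries the underlying cause.
from tornado.iostream import StreamClosedError

async def pump(stream, on_chunk):
    try:
        while True:
            on_chunk(await stream.read_bytes(65536, partial=True))
    except StreamClosedError as e:
        if e.real_error is not None:
            raise e.real_error  # abnormal close: surface the real cause
        # otherwise this was an orderly EOF; just return
# ---------------------------------------------------------------------------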
The no-delay flag requests that data be + written as soon as possible, even if doing so would consume + additional bandwidth. + + This flag is currently defined only for TCP-based ``IOStreams``. + + .. versionadded:: 3.1 + """ + pass + + def _handle_connect(self) -> None: + raise NotImplementedError() + + def _handle_events(self, fd: Union[int, ioloop._Selectable], events: int) -> None: + if self.closed(): + gen_log.warning("Got events for closed stream %s", fd) + return + try: + if self._connecting: + # Most IOLoops will report a write failed connect + # with the WRITE event, but SelectIOLoop reports a + # READ as well so we must check for connecting before + # either. + self._handle_connect() + if self.closed(): + return + if events & self.io_loop.READ: + self._handle_read() + if self.closed(): + return + if events & self.io_loop.WRITE: + self._handle_write() + if self.closed(): + return + if events & self.io_loop.ERROR: + self.error = self.get_fd_error() + # We may have queued up a user callback in _handle_read or + # _handle_write, so don't close the IOStream until those + # callbacks have had a chance to run. + self.io_loop.add_callback(self.close) + return + state = self.io_loop.ERROR + if self.reading(): + state |= self.io_loop.READ + if self.writing(): + state |= self.io_loop.WRITE + if state == self.io_loop.ERROR and self._read_buffer_size == 0: + # If the connection is idle, listen for reads too so + # we can tell if the connection is closed. If there is + # data in the read buffer we won't run the close callback + # yet anyway, so we don't need to listen in this case. + state |= self.io_loop.READ + if state != self._state: + assert ( + self._state is not None + ), "shouldn't happen: _handle_events without self._state" + self._state = state + self.io_loop.update_handler(self.fileno(), self._state) + except UnsatisfiableReadError as e: + gen_log.info("Unsatisfiable read, closing connection: %s" % e) + self.close(exc_info=e) + except Exception as e: + gen_log.error("Uncaught exception, closing connection.", exc_info=True) + self.close(exc_info=e) + raise + + def _read_to_buffer_loop(self) -> Optional[int]: + # This method is called from _handle_read and _try_inline_read. + if self._read_bytes is not None: + target_bytes = self._read_bytes # type: Optional[int] + elif self._read_max_bytes is not None: + target_bytes = self._read_max_bytes + elif self.reading(): + # For read_until without max_bytes, or + # read_until_close, read as much as we can before + # scanning for the delimiter. + target_bytes = None + else: + target_bytes = 0 + next_find_pos = 0 + while not self.closed(): + # Read from the socket until we get EWOULDBLOCK or equivalent. + # SSL sockets do some internal buffering, and if the data is + # sitting in the SSL object's buffer select() and friends + # can't see it; the only way to find out if it's there is to + # try to read it. + if self._read_to_buffer() == 0: + break + + # If we've read all the bytes we can use, break out of + # this loop. + + # If we've reached target_bytes, we know we're done. + if target_bytes is not None and self._read_buffer_size >= target_bytes: + break + + # Otherwise, we need to call the more expensive find_read_pos. + # It's inefficient to do this on every read, so instead + # do it on the first read and whenever the read buffer + # size has doubled. 
+ if self._read_buffer_size >= next_find_pos: + pos = self._find_read_pos() + if pos is not None: + return pos + next_find_pos = self._read_buffer_size * 2 + return self._find_read_pos() + + def _handle_read(self) -> None: + try: + pos = self._read_to_buffer_loop() + except UnsatisfiableReadError: + raise + except asyncio.CancelledError: + raise + except Exception as e: + gen_log.warning("error on read: %s" % e) + self.close(exc_info=e) + return + if pos is not None: + self._read_from_buffer(pos) + + def _start_read(self) -> Future: + if self._read_future is not None: + # It is an error to start a read while a prior read is unresolved. + # However, if the prior read is unresolved because the stream was + # closed without satisfying it, it's better to raise + # StreamClosedError instead of AssertionError. In particular, this + # situation occurs in harmless situations in http1connection.py and + # an AssertionError would be logged noisily. + # + # On the other hand, it is legal to start a new read while the + # stream is closed, in case the read can be satisfied from the + # read buffer. So we only want to check the closed status of the + # stream if we need to decide what kind of error to raise for + # "already reading". + # + # These conditions have proven difficult to test; we have no + # unittests that reliably verify this behavior so be careful + # when making changes here. See #2651 and #2719. + self._check_closed() + assert self._read_future is None, "Already reading" + self._read_future = Future() + return self._read_future + + def _finish_read(self, size: int) -> None: + if self._user_read_buffer: + self._read_buffer = self._after_user_read_buffer or bytearray() + self._after_user_read_buffer = None + self._read_buffer_size = len(self._read_buffer) + self._user_read_buffer = False + result = size # type: Union[int, bytes] + else: + result = self._consume(size) + if self._read_future is not None: + future = self._read_future + self._read_future = None + future_set_result_unless_cancelled(future, result) + self._maybe_add_error_listener() + + def _try_inline_read(self) -> None: + """Attempt to complete the current read operation from buffered data. + + If the read can be completed without blocking, schedules the + read callback on the next IOLoop iteration; otherwise starts + listening for reads on the socket. + """ + # See if we've already got the data from a previous read + pos = self._find_read_pos() + if pos is not None: + self._read_from_buffer(pos) + return + self._check_closed() + pos = self._read_to_buffer_loop() + if pos is not None: + self._read_from_buffer(pos) + return + # We couldn't satisfy the read inline, so make sure we're + # listening for new data unless the stream is closed. + if not self.closed(): + self._add_io_state(ioloop.IOLoop.READ) + + def _read_to_buffer(self) -> Optional[int]: + """Reads from the socket and appends the result to the read buffer. + + Returns the number of bytes read. Returns 0 if there is nothing + to read (i.e. the read returns EWOULDBLOCK or equivalent). On + error closes the socket and raises an exception. 
+ """ + try: + while True: + try: + if self._user_read_buffer: + buf = memoryview(self._read_buffer)[ + self._read_buffer_size : + ] # type: Union[memoryview, bytearray] + else: + buf = bytearray(self.read_chunk_size) + bytes_read = self.read_from_fd(buf) + except (socket.error, IOError, OSError) as e: + # ssl.SSLError is a subclass of socket.error + if self._is_connreset(e): + # Treat ECONNRESET as a connection close rather than + # an error to minimize log spam (the exception will + # be available on self.error for apps that care). + self.close(exc_info=e) + return None + self.close(exc_info=e) + raise + break + if bytes_read is None: + return 0 + elif bytes_read == 0: + self.close() + return 0 + if not self._user_read_buffer: + self._read_buffer += memoryview(buf)[:bytes_read] + self._read_buffer_size += bytes_read + finally: + # Break the reference to buf so we don't waste a chunk's worth of + # memory in case an exception hangs on to our stack frame. + del buf + if self._read_buffer_size > self.max_buffer_size: + gen_log.error("Reached maximum read buffer size") + self.close() + raise StreamBufferFullError("Reached maximum read buffer size") + return bytes_read + + def _read_from_buffer(self, pos: int) -> None: + """Attempts to complete the currently-pending read from the buffer. + + The argument is either a position in the read buffer or None, + as returned by _find_read_pos. + """ + self._read_bytes = self._read_delimiter = self._read_regex = None + self._read_partial = False + self._finish_read(pos) + + def _find_read_pos(self) -> Optional[int]: + """Attempts to find a position in the read buffer that satisfies + the currently-pending read. + + Returns a position in the buffer if the current read can be satisfied, + or None if it cannot. + """ + if self._read_bytes is not None and ( + self._read_buffer_size >= self._read_bytes + or (self._read_partial and self._read_buffer_size > 0) + ): + num_bytes = min(self._read_bytes, self._read_buffer_size) + return num_bytes + elif self._read_delimiter is not None: + # Multi-byte delimiters (e.g. '\r\n') may straddle two + # chunks in the read buffer, so we can't easily find them + # without collapsing the buffer. However, since protocols + # using delimited reads (as opposed to reads of a known + # length) tend to be "line" oriented, the delimiter is likely + # to be in the first few chunks. Merge the buffer gradually + # since large merges are relatively expensive and get undone in + # _consume(). 
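+            # Editor's illustration: for read_until(b"\r\n") the two
+            # delimiter bytes may arrive via separate recv() calls; the
+            # find() below still sees them because the read buffer is a
+            # single contiguous bytearray.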
+ if self._read_buffer: + loc = self._read_buffer.find(self._read_delimiter) + if loc != -1: + delimiter_len = len(self._read_delimiter) + self._check_max_bytes(self._read_delimiter, loc + delimiter_len) + return loc + delimiter_len + self._check_max_bytes(self._read_delimiter, self._read_buffer_size) + elif self._read_regex is not None: + if self._read_buffer: + m = self._read_regex.search(self._read_buffer) + if m is not None: + loc = m.end() + self._check_max_bytes(self._read_regex, loc) + return loc + self._check_max_bytes(self._read_regex, self._read_buffer_size) + return None + + def _check_max_bytes(self, delimiter: Union[bytes, Pattern], size: int) -> None: + if self._read_max_bytes is not None and size > self._read_max_bytes: + raise UnsatisfiableReadError( + "delimiter %r not found within %d bytes" + % (delimiter, self._read_max_bytes) + ) + + def _handle_write(self) -> None: + while True: + size = len(self._write_buffer) + if not size: + break + assert size > 0 + try: + if _WINDOWS: + # On windows, socket.send blows up if given a + # write buffer that's too large, instead of just + # returning the number of bytes it was able to + # process. Therefore we must not call socket.send + # with more than 128KB at a time. + size = 128 * 1024 + + num_bytes = self.write_to_fd(self._write_buffer.peek(size)) + if num_bytes == 0: + break + self._write_buffer.advance(num_bytes) + self._total_write_done_index += num_bytes + except BlockingIOError: + break + except (socket.error, IOError, OSError) as e: + if not self._is_connreset(e): + # Broken pipe errors are usually caused by connection + # reset, and its better to not log EPIPE errors to + # minimize log spam + gen_log.warning("Write error on %s: %s", self.fileno(), e) + self.close(exc_info=e) + return + + while self._write_futures: + index, future = self._write_futures[0] + if index > self._total_write_done_index: + break + self._write_futures.popleft() + future_set_result_unless_cancelled(future, None) + + def _consume(self, loc: int) -> bytes: + # Consume loc bytes from the read buffer and return them + if loc == 0: + return b"" + assert loc <= self._read_buffer_size + # Slice the bytearray buffer into bytes, without intermediate copying + b = (memoryview(self._read_buffer)[:loc]).tobytes() + self._read_buffer_size -= loc + del self._read_buffer[:loc] + return b + + def _check_closed(self) -> None: + if self.closed(): + raise StreamClosedError(real_error=self.error) + + def _maybe_add_error_listener(self) -> None: + # This method is part of an optimization: to detect a connection that + # is closed when we're not actively reading or writing, we must listen + # for read events. However, it is inefficient to do this when the + # connection is first established because we are going to read or write + # immediately anyway. Instead, we insert checks at various times to + # see if the connection is idle and add the read listener then. + if self._state is None or self._state == ioloop.IOLoop.ERROR: + if ( + not self.closed() + and self._read_buffer_size == 0 + and self._close_callback is not None + ): + self._add_io_state(ioloop.IOLoop.READ) + + def _add_io_state(self, state: int) -> None: + """Adds `state` (IOLoop.{READ,WRITE} flags) to our event handler. + + Implementation notes: Reads and writes have a fast path and a + slow path. The fast path reads synchronously from socket + buffers, while the slow path uses `_add_io_state` to schedule + an IOLoop callback. 
+ + To detect closed connections, we must have called + `_add_io_state` at some point, but we want to delay this as + much as possible so we don't have to set an `IOLoop.ERROR` + listener that will be overwritten by the next slow-path + operation. If a sequence of fast-path ops do not end in a + slow-path op, (e.g. for an @asynchronous long-poll request), + we must add the error handler. + + TODO: reevaluate this now that callbacks are gone. + + """ + if self.closed(): + # connection has been closed, so there can be no future events + return + if self._state is None: + self._state = ioloop.IOLoop.ERROR | state + self.io_loop.add_handler(self.fileno(), self._handle_events, self._state) + elif not self._state & state: + self._state = self._state | state + self.io_loop.update_handler(self.fileno(), self._state) + + def _is_connreset(self, exc: BaseException) -> bool: + """Return ``True`` if exc is ECONNRESET or equivalent. + + May be overridden in subclasses. + """ + return ( + isinstance(exc, (socket.error, IOError)) + and errno_from_exception(exc) in _ERRNO_CONNRESET + ) + + +class IOStream(BaseIOStream): + r"""Socket-based `IOStream` implementation. + + This class supports the read and write methods from `BaseIOStream` + plus a `connect` method. + + The ``socket`` parameter may either be connected or unconnected. + For server operations the socket is the result of calling + `socket.accept <socket.socket.accept>`. For client operations the + socket is created with `socket.socket`, and may either be + connected before passing it to the `IOStream` or connected with + `IOStream.connect`. + + A very simple (and broken) HTTP client using this class: + + .. testcode:: + + import socket + import tornado + + async def main(): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) + stream = tornado.iostream.IOStream(s) + await stream.connect(("friendfeed.com", 80)) + await stream.write(b"GET / HTTP/1.0\r\nHost: friendfeed.com\r\n\r\n") + header_data = await stream.read_until(b"\r\n\r\n") + headers = {} + for line in header_data.split(b"\r\n"): + parts = line.split(b":") + if len(parts) == 2: + headers[parts[0].strip()] = parts[1].strip() + body_data = await stream.read_bytes(int(headers[b"Content-Length"])) + print(body_data) + stream.close() + + if __name__ == '__main__': + asyncio.run(main()) + + .. testoutput:: + :hide: + + """ + + def __init__(self, socket: socket.socket, *args: Any, **kwargs: Any) -> None: + self.socket = socket + self.socket.setblocking(False) + super().__init__(*args, **kwargs) + + def fileno(self) -> Union[int, ioloop._Selectable]: + return self.socket + + def close_fd(self) -> None: + self.socket.close() + self.socket = None # type: ignore + + def get_fd_error(self) -> Optional[Exception]: + errno = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + return socket.error(errno, os.strerror(errno)) + + def read_from_fd(self, buf: Union[bytearray, memoryview]) -> Optional[int]: + try: + return self.socket.recv_into(buf, len(buf)) + except BlockingIOError: + return None + finally: + del buf + + def write_to_fd(self, data: memoryview) -> int: + try: + return self.socket.send(data) # type: ignore + finally: + # Avoid keeping to data, which can be a memoryview. + # See https://github.com/tornadoweb/tornado/pull/2008 + del data + + def connect( + self: _IOStreamType, address: Any, server_hostname: Optional[str] = None + ) -> "Future[_IOStreamType]": + """Connects the socket to a remote address without blocking. 
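+
+        A minimal editorial sketch (not from the upstream docs; the
+        address and payload are placeholders)::
+
+            import socket
+            from tornado.iostream import IOStream
+
+            async def send_ping():
+                sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+                stream = IOStream(sock)
+                await stream.connect(("127.0.0.1", 8888))
+                # write() is safe even while the connect is still pending.
+                await stream.write(b"ping\r\n")
+                stream.close()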
+ + May only be called if the socket passed to the constructor was + not previously connected. The address parameter is in the + same format as for `socket.connect <socket.socket.connect>` for + the type of socket passed to the IOStream constructor, + e.g. an ``(ip, port)`` tuple. Hostnames are accepted here, + but will be resolved synchronously and block the IOLoop. + If you have a hostname instead of an IP address, the `.TCPClient` + class is recommended instead of calling this method directly. + `.TCPClient` will do asynchronous DNS resolution and handle + both IPv4 and IPv6. + + If ``callback`` is specified, it will be called with no + arguments when the connection is completed; if not this method + returns a `.Future` (whose result after a successful + connection will be the stream itself). + + In SSL mode, the ``server_hostname`` parameter will be used + for certificate validation (unless disabled in the + ``ssl_options``) and SNI (if supported; requires Python + 2.7.9+). + + Note that it is safe to call `IOStream.write + <BaseIOStream.write>` while the connection is pending, in + which case the data will be written as soon as the connection + is ready. Calling `IOStream` read methods before the socket is + connected works on some platforms but is non-portable. + + .. versionchanged:: 4.0 + If no callback is given, returns a `.Future`. + + .. versionchanged:: 4.2 + SSL certificates are validated by default; pass + ``ssl_options=dict(cert_reqs=ssl.CERT_NONE)`` or a + suitably-configured `ssl.SSLContext` to the + `SSLIOStream` constructor to disable. + + .. versionchanged:: 6.0 + + The ``callback`` argument was removed. Use the returned + `.Future` instead. + + """ + self._connecting = True + future = Future() # type: Future[_IOStreamType] + self._connect_future = typing.cast("Future[IOStream]", future) + try: + self.socket.connect(address) + except BlockingIOError: + # In non-blocking mode we expect connect() to raise an + # exception with EINPROGRESS or EWOULDBLOCK. + pass + except socket.error as e: + # On freebsd, other errors such as ECONNREFUSED may be + # returned immediately when attempting to connect to + # localhost, so handle them the same way as an error + # reported later in _handle_connect. + if future is None: + gen_log.warning("Connect error on fd %s: %s", self.socket.fileno(), e) + self.close(exc_info=e) + return future + self._add_io_state(self.io_loop.WRITE) + return future + + def start_tls( + self, + server_side: bool, + ssl_options: Optional[Union[Dict[str, Any], ssl.SSLContext]] = None, + server_hostname: Optional[str] = None, + ) -> Awaitable["SSLIOStream"]: + """Convert this `IOStream` to an `SSLIOStream`. + + This enables protocols that begin in clear-text mode and + switch to SSL after some initial negotiation (such as the + ``STARTTLS`` extension to SMTP and IMAP). + + This method cannot be used if there are outstanding reads + or writes on the stream, or if there is any data in the + IOStream's buffer (data in the operating system's socket + buffer is allowed). This means it must generally be used + immediately after reading or writing the last clear-text + data. It can also be used immediately after connecting, + before any reads or writes. + + The ``ssl_options`` argument may be either an `ssl.SSLContext` + object or a dictionary of keyword arguments for the + `ssl.wrap_socket` function. The ``server_hostname`` argument + will be used for certificate validation unless disabled + in the ``ssl_options``. 
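+
+        A hedged editorial sketch of the STARTTLS pattern (the SMTP
+        exchange is abbreviated; the hostname is a placeholder)::
+
+            async def upgrade_smtp(stream):
+                await stream.read_until(b"\r\n")   # 220 greeting
+                await stream.write(b"STARTTLS\r\n")
+                await stream.read_until(b"\r\n")   # 220 go ahead
+                # The plain-text stream must not be used after this call.
+                return await stream.start_tls(
+                    server_side=False, server_hostname="mail.example.com"
+                )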
+ + This method returns a `.Future` whose result is the new + `SSLIOStream`. After this method has been called, + any other operation on the original stream is undefined. + + If a close callback is defined on this stream, it will be + transferred to the new stream. + + .. versionadded:: 4.0 + + .. versionchanged:: 4.2 + SSL certificates are validated by default; pass + ``ssl_options=dict(cert_reqs=ssl.CERT_NONE)`` or a + suitably-configured `ssl.SSLContext` to disable. + """ + if ( + self._read_future + or self._write_futures + or self._connect_future + or self._closed + or self._read_buffer + or self._write_buffer + ): + raise ValueError("IOStream is not idle; cannot convert to SSL") + if ssl_options is None: + if server_side: + ssl_options = _server_ssl_defaults + else: + ssl_options = _client_ssl_defaults + + socket = self.socket + self.io_loop.remove_handler(socket) + self.socket = None # type: ignore + socket = ssl_wrap_socket( + socket, + ssl_options, + server_hostname=server_hostname, + server_side=server_side, + do_handshake_on_connect=False, + ) + orig_close_callback = self._close_callback + self._close_callback = None + + future = Future() # type: Future[SSLIOStream] + ssl_stream = SSLIOStream(socket, ssl_options=ssl_options) + ssl_stream.set_close_callback(orig_close_callback) + ssl_stream._ssl_connect_future = future + ssl_stream.max_buffer_size = self.max_buffer_size + ssl_stream.read_chunk_size = self.read_chunk_size + return future + + def _handle_connect(self) -> None: + try: + err = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + except socket.error as e: + # Hurd doesn't allow SO_ERROR for loopback sockets because all + # errors for such sockets are reported synchronously. + if errno_from_exception(e) == errno.ENOPROTOOPT: + err = 0 + if err != 0: + self.error = socket.error(err, os.strerror(err)) + # IOLoop implementations may vary: some of them return + # an error state before the socket becomes writable, so + # in that case a connection failure would be handled by the + # error path in _handle_events instead of here. + if self._connect_future is None: + gen_log.warning( + "Connect error on fd %s: %s", + self.socket.fileno(), + errno.errorcode[err], + ) + self.close() + return + if self._connect_future is not None: + future = self._connect_future + self._connect_future = None + future_set_result_unless_cancelled(future, self) + self._connecting = False + + def set_nodelay(self, value: bool) -> None: + if self.socket is not None and self.socket.family in ( + socket.AF_INET, + socket.AF_INET6, + ): + try: + self.socket.setsockopt( + socket.IPPROTO_TCP, socket.TCP_NODELAY, 1 if value else 0 + ) + except socket.error as e: + # Sometimes setsockopt will fail if the socket is closed + # at the wrong time. This can happen with HTTPServer + # resetting the value to ``False`` between requests. + if e.errno != errno.EINVAL and not self._is_connreset(e): + raise + + +class SSLIOStream(IOStream): + """A utility class to write to and read from a non-blocking SSL socket. + + If the socket passed to the constructor is already connected, + it should be wrapped with:: + + ssl.wrap_socket(sock, do_handshake_on_connect=False, **kwargs) + + before constructing the `SSLIOStream`. Unconnected sockets will be + wrapped when `IOStream.connect` is finished. 
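+
+    A minimal client-side sketch (editorial illustration; the address and
+    ``server_hostname`` are placeholders)::
+
+        import socket
+        from tornado.iostream import SSLIOStream
+
+        async def tls_connect():
+            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+            stream = SSLIOStream(sock)
+            # server_hostname drives SNI and hostname verification.
+            await stream.connect(("127.0.0.1", 443), server_hostname="localhost")
+            return stream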
+ """ + + socket = None # type: ssl.SSLSocket + + def __init__(self, *args: Any, **kwargs: Any) -> None: + """The ``ssl_options`` keyword argument may either be an + `ssl.SSLContext` object or a dictionary of keywords arguments + for `ssl.wrap_socket` + """ + self._ssl_options = kwargs.pop("ssl_options", _client_ssl_defaults) + super().__init__(*args, **kwargs) + self._ssl_accepting = True + self._handshake_reading = False + self._handshake_writing = False + self._server_hostname = None # type: Optional[str] + + # If the socket is already connected, attempt to start the handshake. + try: + self.socket.getpeername() + except socket.error: + pass + else: + # Indirectly start the handshake, which will run on the next + # IOLoop iteration and then the real IO state will be set in + # _handle_events. + self._add_io_state(self.io_loop.WRITE) + + def reading(self) -> bool: + return self._handshake_reading or super().reading() + + def writing(self) -> bool: + return self._handshake_writing or super().writing() + + def _do_ssl_handshake(self) -> None: + # Based on code from test_ssl.py in the python stdlib + try: + self._handshake_reading = False + self._handshake_writing = False + self.socket.do_handshake() + except ssl.SSLError as err: + if err.args[0] == ssl.SSL_ERROR_WANT_READ: + self._handshake_reading = True + return + elif err.args[0] == ssl.SSL_ERROR_WANT_WRITE: + self._handshake_writing = True + return + elif err.args[0] in (ssl.SSL_ERROR_EOF, ssl.SSL_ERROR_ZERO_RETURN): + return self.close(exc_info=err) + elif err.args[0] == ssl.SSL_ERROR_SSL: + try: + peer = self.socket.getpeername() + except Exception: + peer = "(not connected)" + gen_log.warning( + "SSL Error on %s %s: %s", self.socket.fileno(), peer, err + ) + return self.close(exc_info=err) + raise + except ssl.CertificateError as err: + # CertificateError can happen during handshake (hostname + # verification) and should be passed to user. Starting + # in Python 3.7, this error is a subclass of SSLError + # and will be handled by the previous block instead. + return self.close(exc_info=err) + except socket.error as err: + # Some port scans (e.g. nmap in -sT mode) have been known + # to cause do_handshake to raise EBADF and ENOTCONN, so make + # those errors quiet as well. + # https://groups.google.com/forum/?fromgroups#!topic/python-tornado/ApucKJat1_0 + # Errno 0 is also possible in some cases (nc -z). + # https://github.com/tornadoweb/tornado/issues/2504 + if self._is_connreset(err) or err.args[0] in ( + 0, + errno.EBADF, + errno.ENOTCONN, + ): + return self.close(exc_info=err) + raise + except AttributeError as err: + # On Linux, if the connection was reset before the call to + # wrap_socket, do_handshake will fail with an + # AttributeError. + return self.close(exc_info=err) + else: + self._ssl_accepting = False + if not self._verify_cert(self.socket.getpeercert()): + self.close() + return + self._finish_ssl_connect() + + def _finish_ssl_connect(self) -> None: + if self._ssl_connect_future is not None: + future = self._ssl_connect_future + self._ssl_connect_future = None + future_set_result_unless_cancelled(future, self) + + def _verify_cert(self, peercert: Any) -> bool: + """Returns ``True`` if peercert is valid according to the configured + validation mode and hostname. + + The ssl handshake already tested the certificate for a valid + CA signature; the only thing that remains is to check + the hostname. 
+ """ + if isinstance(self._ssl_options, dict): + verify_mode = self._ssl_options.get("cert_reqs", ssl.CERT_NONE) + elif isinstance(self._ssl_options, ssl.SSLContext): + verify_mode = self._ssl_options.verify_mode + assert verify_mode in (ssl.CERT_NONE, ssl.CERT_REQUIRED, ssl.CERT_OPTIONAL) + if verify_mode == ssl.CERT_NONE or self._server_hostname is None: + return True + cert = self.socket.getpeercert() + if cert is None and verify_mode == ssl.CERT_REQUIRED: + gen_log.warning("No SSL certificate given") + return False + try: + ssl.match_hostname(peercert, self._server_hostname) + except ssl.CertificateError as e: + gen_log.warning("Invalid SSL certificate: %s" % e) + return False + else: + return True + + def _handle_read(self) -> None: + if self._ssl_accepting: + self._do_ssl_handshake() + return + super()._handle_read() + + def _handle_write(self) -> None: + if self._ssl_accepting: + self._do_ssl_handshake() + return + super()._handle_write() + + def connect( + self, address: Tuple, server_hostname: Optional[str] = None + ) -> "Future[SSLIOStream]": + self._server_hostname = server_hostname + # Ignore the result of connect(). If it fails, + # wait_for_handshake will raise an error too. This is + # necessary for the old semantics of the connect callback + # (which takes no arguments). In 6.0 this can be refactored to + # be a regular coroutine. + # TODO: This is trickier than it looks, since if write() + # is called with a connect() pending, we want the connect + # to resolve before the write. Or do we care about this? + # (There's a test for it, but I think in practice users + # either wait for the connect before performing a write or + # they don't care about the connect Future at all) + fut = super().connect(address) + fut.add_done_callback(lambda f: f.exception()) + return self.wait_for_handshake() + + def _handle_connect(self) -> None: + # Call the superclass method to check for errors. + super()._handle_connect() + if self.closed(): + return + # When the connection is complete, wrap the socket for SSL + # traffic. Note that we do this by overriding _handle_connect + # instead of by passing a callback to super().connect because + # user callbacks are enqueued asynchronously on the IOLoop, + # but since _handle_events calls _handle_connect immediately + # followed by _handle_write we need this to be synchronous. + # + # The IOLoop will get confused if we swap out self.socket while the + # fd is registered, so remove it now and re-register after + # wrap_socket(). + self.io_loop.remove_handler(self.socket) + old_state = self._state + assert old_state is not None + self._state = None + self.socket = ssl_wrap_socket( + self.socket, + self._ssl_options, + server_hostname=self._server_hostname, + do_handshake_on_connect=False, + server_side=False, + ) + self._add_io_state(old_state) + + def wait_for_handshake(self) -> "Future[SSLIOStream]": + """Wait for the initial SSL handshake to complete. + + If a ``callback`` is given, it will be called with no + arguments once the handshake is complete; otherwise this + method returns a `.Future` which will resolve to the + stream itself after the handshake is complete. + + Once the handshake is complete, information such as + the peer's certificate and NPN/ALPN selections may be + accessed on ``self.socket``. + + This method is intended for use on server-side streams + or after using `IOStream.start_tls`; it should not be used + with `IOStream.connect` (which already waits for the + handshake to complete). It may only be called once per stream. 
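+
+        Editorial sketch of server-side use (assumes a `.TCPServer`
+        subclass constructed with ``ssl_options``; names are placeholders)::
+
+            class SecureServer(TCPServer):
+                async def handle_stream(self, stream, address):
+                    await stream.wait_for_handshake()
+                    print(stream.socket.getpeercert())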
+ + .. versionadded:: 4.2 + + .. versionchanged:: 6.0 + + The ``callback`` argument was removed. Use the returned + `.Future` instead. + + """ + if self._ssl_connect_future is not None: + raise RuntimeError("Already waiting") + future = self._ssl_connect_future = Future() + if not self._ssl_accepting: + self._finish_ssl_connect() + return future + + def write_to_fd(self, data: memoryview) -> int: + # clip buffer size at 1GB since SSL sockets only support upto 2GB + # this change in behaviour is transparent, since the function is + # already expected to (possibly) write less than the provided buffer + if len(data) >> 30: + data = memoryview(data)[: 1 << 30] + try: + return self.socket.send(data) # type: ignore + except ssl.SSLError as e: + if e.args[0] == ssl.SSL_ERROR_WANT_WRITE: + # In Python 3.5+, SSLSocket.send raises a WANT_WRITE error if + # the socket is not writeable; we need to transform this into + # an EWOULDBLOCK socket.error or a zero return value, + # either of which will be recognized by the caller of this + # method. Prior to Python 3.5, an unwriteable socket would + # simply return 0 bytes written. + return 0 + raise + finally: + # Avoid keeping to data, which can be a memoryview. + # See https://github.com/tornadoweb/tornado/pull/2008 + del data + + def read_from_fd(self, buf: Union[bytearray, memoryview]) -> Optional[int]: + try: + if self._ssl_accepting: + # If the handshake hasn't finished yet, there can't be anything + # to read (attempting to read may or may not raise an exception + # depending on the SSL version) + return None + # clip buffer size at 1GB since SSL sockets only support upto 2GB + # this change in behaviour is transparent, since the function is + # already expected to (possibly) read less than the provided buffer + if len(buf) >> 30: + buf = memoryview(buf)[: 1 << 30] + try: + return self.socket.recv_into(buf, len(buf)) + except ssl.SSLError as e: + # SSLError is a subclass of socket.error, so this except + # block must come first. + if e.args[0] == ssl.SSL_ERROR_WANT_READ: + return None + else: + raise + except BlockingIOError: + return None + finally: + del buf + + def _is_connreset(self, e: BaseException) -> bool: + if isinstance(e, ssl.SSLError) and e.args[0] == ssl.SSL_ERROR_EOF: + return True + return super()._is_connreset(e) + + +class PipeIOStream(BaseIOStream): + """Pipe-based `IOStream` implementation. + + The constructor takes an integer file descriptor (such as one returned + by `os.pipe`) rather than an open file object. Pipes are generally + one-way, so a `PipeIOStream` can be used for reading or writing but not + both. + + ``PipeIOStream`` is only available on Unix-based platforms. + """ + + def __init__(self, fd: int, *args: Any, **kwargs: Any) -> None: + self.fd = fd + self._fio = io.FileIO(self.fd, "r+") + if sys.platform == "win32": + # The form and placement of this assertion is important to mypy. + # A plain assert statement isn't recognized here. If the assertion + # were earlier it would worry that the attributes of self aren't + # set on windows. If it were missing it would complain about + # the absence of the set_blocking function. + raise AssertionError("PipeIOStream is not supported on Windows") + os.set_blocking(fd, False) + super().__init__(*args, **kwargs) + + def fileno(self) -> int: + return self.fd + + def close_fd(self) -> None: + self._fio.close() + + def write_to_fd(self, data: memoryview) -> int: + try: + return os.write(self.fd, data) # type: ignore + finally: + # Avoid keeping to data, which can be a memoryview. 
+ # See https://github.com/tornadoweb/tornado/pull/2008 + del data + + def read_from_fd(self, buf: Union[bytearray, memoryview]) -> Optional[int]: + try: + return self._fio.readinto(buf) # type: ignore + except (IOError, OSError) as e: + if errno_from_exception(e) == errno.EBADF: + # If the writing half of a pipe is closed, select will + # report it as readable but reads will fail with EBADF. + self.close(exc_info=e) + return None + else: + raise + finally: + del buf + + +def doctests() -> Any: + import doctest + + return doctest.DocTestSuite() diff --git a/venv/lib/python3.9/site-packages/tornado/locale.py b/venv/lib/python3.9/site-packages/tornado/locale.py new file mode 100644 index 00000000..55072af2 --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/locale.py @@ -0,0 +1,581 @@ +# Copyright 2009 Facebook +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Translation methods for generating localized strings. + +To load a locale and generate a translated string:: + + user_locale = tornado.locale.get("es_LA") + print(user_locale.translate("Sign out")) + +`tornado.locale.get()` returns the closest matching locale, not necessarily the +specific locale you requested. You can support pluralization with +additional arguments to `~Locale.translate()`, e.g.:: + + people = [...] + message = user_locale.translate( + "%(list)s is online", "%(list)s are online", len(people)) + print(message % {"list": user_locale.list(people)}) + +The first string is chosen if ``len(people) == 1``, otherwise the second +string is chosen. + +Applications should call one of `load_translations` (which uses a simple +CSV format) or `load_gettext_translations` (which uses the ``.mo`` format +supported by `gettext` and related tools). If neither method is called, +the `Locale.translate` method will simply return the original string. +""" + +import codecs +import csv +import datetime +import gettext +import glob +import os +import re + +from tornado import escape +from tornado.log import gen_log + +from tornado._locale_data import LOCALE_NAMES + +from typing import Iterable, Any, Union, Dict, Optional + +_default_locale = "en_US" +_translations = {} # type: Dict[str, Any] +_supported_locales = frozenset([_default_locale]) +_use_gettext = False +CONTEXT_SEPARATOR = "\x04" + + +def get(*locale_codes: str) -> "Locale": + """Returns the closest match for the given locale codes. + + We iterate over all given locale codes in order. If we have a tight + or a loose match for the code (e.g., "en" for "en_US"), we return + the locale. Otherwise we move to the next code in the list. + + By default we return ``en_US`` if no translations are found for any of + the specified locales. You can change the default locale with + `set_default_locale()`. + """ + return Locale.get_closest(*locale_codes) + + +def set_default_locale(code: str) -> None: + """Sets the default locale. + + The default locale is assumed to be the language used for all strings + in the system. 
The translations loaded from disk are mappings from + the default locale to the destination locale. Consequently, you don't + need to create a translation file for the default locale. + """ + global _default_locale + global _supported_locales + _default_locale = code + _supported_locales = frozenset(list(_translations.keys()) + [_default_locale]) + + +def load_translations(directory: str, encoding: Optional[str] = None) -> None: + """Loads translations from CSV files in a directory. + + Translations are strings with optional Python-style named placeholders + (e.g., ``My name is %(name)s``) and their associated translations. + + The directory should have translation files of the form ``LOCALE.csv``, + e.g. ``es_GT.csv``. The CSV files should have two or three columns: string, + translation, and an optional plural indicator. Plural indicators should + be one of "plural" or "singular". A given string can have both singular + and plural forms. For example ``%(name)s liked this`` may have a + different verb conjugation depending on whether %(name)s is one + name or a list of names. There should be two rows in the CSV file for + that string, one with plural indicator "singular", and one "plural". + For strings with no verbs that would change on translation, simply + use "unknown" or the empty string (or don't include the column at all). + + The file is read using the `csv` module in the default "excel" dialect. + In this format there should not be spaces after the commas. + + If no ``encoding`` parameter is given, the encoding will be + detected automatically (among UTF-8 and UTF-16) if the file + contains a byte-order marker (BOM), defaulting to UTF-8 if no BOM + is present. + + Example translation ``es_LA.csv``:: + + "I love you","Te amo" + "%(name)s liked this","A %(name)s les gustó esto","plural" + "%(name)s liked this","A %(name)s le gustó esto","singular" + + .. versionchanged:: 4.3 + Added ``encoding`` parameter. Added support for BOM-based encoding + detection, UTF-16, and UTF-8-with-BOM. + """ + global _translations + global _supported_locales + _translations = {} + for path in os.listdir(directory): + if not path.endswith(".csv"): + continue + locale, extension = path.split(".") + if not re.match("[a-z]+(_[A-Z]+)?$", locale): + gen_log.error( + "Unrecognized locale %r (path: %s)", + locale, + os.path.join(directory, path), + ) + continue + full_path = os.path.join(directory, path) + if encoding is None: + # Try to autodetect encoding based on the BOM. + with open(full_path, "rb") as bf: + data = bf.read(len(codecs.BOM_UTF16_LE)) + if data in (codecs.BOM_UTF16_LE, codecs.BOM_UTF16_BE): + encoding = "utf-16" + else: + # utf-8-sig is "utf-8 with optional BOM". It's discouraged + # in most cases but is common with CSV files because Excel + # cannot read utf-8 files without a BOM. + encoding = "utf-8-sig" + # python 3: csv.reader requires a file open in text mode. + # Specify an encoding to avoid dependence on $LANG environment variable. 
+ with open(full_path, encoding=encoding) as f: + _translations[locale] = {} + for i, row in enumerate(csv.reader(f)): + if not row or len(row) < 2: + continue + row = [escape.to_unicode(c).strip() for c in row] + english, translation = row[:2] + if len(row) > 2: + plural = row[2] or "unknown" + else: + plural = "unknown" + if plural not in ("plural", "singular", "unknown"): + gen_log.error( + "Unrecognized plural indicator %r in %s line %d", + plural, + path, + i + 1, + ) + continue + _translations[locale].setdefault(plural, {})[english] = translation + _supported_locales = frozenset(list(_translations.keys()) + [_default_locale]) + gen_log.debug("Supported locales: %s", sorted(_supported_locales)) + + +def load_gettext_translations(directory: str, domain: str) -> None: + """Loads translations from `gettext`'s locale tree + + Locale tree is similar to system's ``/usr/share/locale``, like:: + + {directory}/{lang}/LC_MESSAGES/{domain}.mo + + Three steps are required to have your app translated: + + 1. Generate POT translation file:: + + xgettext --language=Python --keyword=_:1,2 -d mydomain file1.py file2.html etc + + 2. Merge against existing POT file:: + + msgmerge old.po mydomain.po > new.po + + 3. Compile:: + + msgfmt mydomain.po -o {directory}/pt_BR/LC_MESSAGES/mydomain.mo + """ + global _translations + global _supported_locales + global _use_gettext + _translations = {} + + for filename in glob.glob( + os.path.join(directory, "*", "LC_MESSAGES", domain + ".mo") + ): + lang = os.path.basename(os.path.dirname(os.path.dirname(filename))) + try: + _translations[lang] = gettext.translation( + domain, directory, languages=[lang] + ) + except Exception as e: + gen_log.error("Cannot load translation for '%s': %s", lang, str(e)) + continue + _supported_locales = frozenset(list(_translations.keys()) + [_default_locale]) + _use_gettext = True + gen_log.debug("Supported locales: %s", sorted(_supported_locales)) + + +def get_supported_locales() -> Iterable[str]: + """Returns a list of all the supported locale codes.""" + return _supported_locales + + +class Locale(object): + """Object representing a locale. + + After calling one of `load_translations` or `load_gettext_translations`, + call `get` or `get_closest` to get a Locale object. + """ + + _cache = {} # type: Dict[str, Locale] + + @classmethod + def get_closest(cls, *locale_codes: str) -> "Locale": + """Returns the closest match for the given locale code.""" + for code in locale_codes: + if not code: + continue + code = code.replace("-", "_") + parts = code.split("_") + if len(parts) > 2: + continue + elif len(parts) == 2: + code = parts[0].lower() + "_" + parts[1].upper() + if code in _supported_locales: + return cls.get(code) + if parts[0].lower() in _supported_locales: + return cls.get(parts[0].lower()) + return cls.get(_default_locale) + + @classmethod + def get(cls, code: str) -> "Locale": + """Returns the Locale for the given locale code. + + If it is not supported, we raise an exception. 
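+
+        Editorial usage sketch (assumes translations were already loaded
+        via `load_translations` or `load_gettext_translations`)::
+
+            locale = Locale.get("es_GT")
+            print(locale.translate("Sign out"))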
+
+        """
+        if code not in cls._cache:
+            assert code in _supported_locales
+            translations = _translations.get(code, None)
+            if translations is None:
+                locale = CSVLocale(code, {})  # type: Locale
+            elif _use_gettext:
+                locale = GettextLocale(code, translations)
+            else:
+                locale = CSVLocale(code, translations)
+            cls._cache[code] = locale
+        return cls._cache[code]
+
+    def __init__(self, code: str) -> None:
+        self.code = code
+        self.name = LOCALE_NAMES.get(code, {}).get("name", "Unknown")
+        self.rtl = False
+        for prefix in ["fa", "ar", "he"]:
+            if self.code.startswith(prefix):
+                self.rtl = True
+                break
+
+        # Initialize strings for date formatting
+        _ = self.translate
+        self._months = [
+            _("January"),
+            _("February"),
+            _("March"),
+            _("April"),
+            _("May"),
+            _("June"),
+            _("July"),
+            _("August"),
+            _("September"),
+            _("October"),
+            _("November"),
+            _("December"),
+        ]
+        self._weekdays = [
+            _("Monday"),
+            _("Tuesday"),
+            _("Wednesday"),
+            _("Thursday"),
+            _("Friday"),
+            _("Saturday"),
+            _("Sunday"),
+        ]
+
+    def translate(
+        self,
+        message: str,
+        plural_message: Optional[str] = None,
+        count: Optional[int] = None,
+    ) -> str:
+        """Returns the translation for the given message for this locale.
+
+        If ``plural_message`` is given, you must also provide
+        ``count``. We return ``plural_message`` when ``count != 1``,
+        and we return the singular form for the given message when
+        ``count == 1``.
+        """
+        raise NotImplementedError()
+
+    def pgettext(
+        self,
+        context: str,
+        message: str,
+        plural_message: Optional[str] = None,
+        count: Optional[int] = None,
+    ) -> str:
+        raise NotImplementedError()
+
+    def format_date(
+        self,
+        date: Union[int, float, datetime.datetime],
+        gmt_offset: int = 0,
+        relative: bool = True,
+        shorter: bool = False,
+        full_format: bool = False,
+    ) -> str:
+        """Formats the given date (which should be GMT).
+
+        By default, we return a relative time (e.g., "2 minutes ago"). You
+        can return an absolute date string with ``relative=False``.
+
+        You can force a full format date ("July 10, 1980") with
+        ``full_format=True``.
+
+        This method is primarily intended for dates in the past.
+        For dates in the future, we fall back to full format.
+        """
+        if isinstance(date, (int, float)):
+            date = datetime.datetime.utcfromtimestamp(date)
+        now = datetime.datetime.utcnow()
+        if date > now:
+            if relative and (date - now).seconds < 60:
+                # Due to clock skew, some things appear slightly in the
+                # future. Round timestamps in the immediate future down
+                # to now in relative mode.
+                date = now
+            else:
+                # Otherwise, future dates always use the full format.
+                full_format = True
+        local_date = date - datetime.timedelta(minutes=gmt_offset)
+        local_now = now - datetime.timedelta(minutes=gmt_offset)
+        local_yesterday = local_now - datetime.timedelta(hours=24)
+        difference = now - date
+        seconds = difference.seconds
+        days = difference.days
+
+        _ = self.translate
+        format = None
+        if not full_format:
+            if relative and days == 0:
+                if seconds < 50:
+                    return _("1 second ago", "%(seconds)d seconds ago", seconds) % {
+                        "seconds": seconds
+                    }
+
+                if seconds < 50 * 60:
+                    minutes = round(seconds / 60.0)
+                    return _("1 minute ago", "%(minutes)d minutes ago", minutes) % {
+                        "minutes": minutes
+                    }
+
+                hours = round(seconds / (60.0 * 60))
+                return _("1 hour ago", "%(hours)d hours ago", hours) % {"hours": hours}
+
+            if days == 0:
+                format = _("%(time)s")
+            elif days == 1 and local_date.day == local_yesterday.day and relative:
+                format = _("yesterday") if shorter else _("yesterday at %(time)s")
+            elif days < 5:
+                format = _("%(weekday)s") if shorter else _("%(weekday)s at %(time)s")
+            elif days < 334:  # ~11 months; the same month last year would be confusing
+                format = (
+                    _("%(month_name)s %(day)s")
+                    if shorter
+                    else _("%(month_name)s %(day)s at %(time)s")
+                )
+
+        if format is None:
+            format = (
+                _("%(month_name)s %(day)s, %(year)s")
+                if shorter
+                else _("%(month_name)s %(day)s, %(year)s at %(time)s")
+            )
+
+        tfhour_clock = self.code not in ("en", "en_US", "zh_CN")
+        if tfhour_clock:
+            str_time = "%d:%02d" % (local_date.hour, local_date.minute)
+        elif self.code == "zh_CN":
+            str_time = "%s%d:%02d" % (
+                ("\u4e0a\u5348", "\u4e0b\u5348")[local_date.hour >= 12],
+                local_date.hour % 12 or 12,
+                local_date.minute,
+            )
+        else:
+            str_time = "%d:%02d %s" % (
+                local_date.hour % 12 or 12,
+                local_date.minute,
+                ("am", "pm")[local_date.hour >= 12],
+            )
+
+        return format % {
+            "month_name": self._months[local_date.month - 1],
+            "weekday": self._weekdays[local_date.weekday()],
+            "day": str(local_date.day),
+            "year": str(local_date.year),
+            "time": str_time,
+        }
+
+    def format_day(
+        self, date: datetime.datetime, gmt_offset: int = 0, dow: bool = True
+    ) -> str:
+        """Formats the given date as a day of week.
+
+        Example: "Monday, January 22". You can remove the day of week with
+        ``dow=False``.
+        """
+        local_date = date - datetime.timedelta(minutes=gmt_offset)
+        _ = self.translate
+        if dow:
+            return _("%(weekday)s, %(month_name)s %(day)s") % {
+                "month_name": self._months[local_date.month - 1],
+                "weekday": self._weekdays[local_date.weekday()],
+                "day": str(local_date.day),
+            }
+        else:
+            return _("%(month_name)s %(day)s") % {
+                "month_name": self._months[local_date.month - 1],
+                "day": str(local_date.day),
+            }
+
+    def list(self, parts: Any) -> str:
+        """Returns a comma-separated list for the given list of parts.
+
+        The format is, e.g., "A, B and C", "A and B" or just "A" for lists
+        of size 1.
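+
+        Editorial example (in the default English locale)::
+
+            locale.list(["A", "B", "C"])  # -> "A, B and C"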
+ """ + _ = self.translate + if len(parts) == 0: + return "" + if len(parts) == 1: + return parts[0] + comma = " \u0648 " if self.code.startswith("fa") else ", " + return _("%(commas)s and %(last)s") % { + "commas": comma.join(parts[:-1]), + "last": parts[len(parts) - 1], + } + + def friendly_number(self, value: int) -> str: + """Returns a comma-separated number for the given integer.""" + if self.code not in ("en", "en_US"): + return str(value) + s = str(value) + parts = [] + while s: + parts.append(s[-3:]) + s = s[:-3] + return ",".join(reversed(parts)) + + +class CSVLocale(Locale): + """Locale implementation using tornado's CSV translation format.""" + + def __init__(self, code: str, translations: Dict[str, Dict[str, str]]) -> None: + self.translations = translations + super().__init__(code) + + def translate( + self, + message: str, + plural_message: Optional[str] = None, + count: Optional[int] = None, + ) -> str: + if plural_message is not None: + assert count is not None + if count != 1: + message = plural_message + message_dict = self.translations.get("plural", {}) + else: + message_dict = self.translations.get("singular", {}) + else: + message_dict = self.translations.get("unknown", {}) + return message_dict.get(message, message) + + def pgettext( + self, + context: str, + message: str, + plural_message: Optional[str] = None, + count: Optional[int] = None, + ) -> str: + if self.translations: + gen_log.warning("pgettext is not supported by CSVLocale") + return self.translate(message, plural_message, count) + + +class GettextLocale(Locale): + """Locale implementation using the `gettext` module.""" + + def __init__(self, code: str, translations: gettext.NullTranslations) -> None: + self.ngettext = translations.ngettext + self.gettext = translations.gettext + # self.gettext must exist before __init__ is called, since it + # calls into self.translate + super().__init__(code) + + def translate( + self, + message: str, + plural_message: Optional[str] = None, + count: Optional[int] = None, + ) -> str: + if plural_message is not None: + assert count is not None + return self.ngettext(message, plural_message, count) + else: + return self.gettext(message) + + def pgettext( + self, + context: str, + message: str, + plural_message: Optional[str] = None, + count: Optional[int] = None, + ) -> str: + """Allows to set context for translation, accepts plural forms. + + Usage example:: + + pgettext("law", "right") + pgettext("good", "right") + + Plural message example:: + + pgettext("organization", "club", "clubs", len(clubs)) + pgettext("stick", "club", "clubs", len(clubs)) + + To generate POT file with context, add following options to step 1 + of `load_gettext_translations` sequence:: + + xgettext [basic options] --keyword=pgettext:1c,2 --keyword=pgettext:1c,2,3 + + .. 
versionadded:: 4.2 + """ + if plural_message is not None: + assert count is not None + msgs_with_ctxt = ( + "%s%s%s" % (context, CONTEXT_SEPARATOR, message), + "%s%s%s" % (context, CONTEXT_SEPARATOR, plural_message), + count, + ) + result = self.ngettext(*msgs_with_ctxt) + if CONTEXT_SEPARATOR in result: + # Translation not found + result = self.ngettext(message, plural_message, count) + return result + else: + msg_with_ctxt = "%s%s%s" % (context, CONTEXT_SEPARATOR, message) + result = self.gettext(msg_with_ctxt) + if CONTEXT_SEPARATOR in result: + # Translation not found + result = message + return result diff --git a/venv/lib/python3.9/site-packages/tornado/locks.py b/venv/lib/python3.9/site-packages/tornado/locks.py new file mode 100644 index 00000000..1bcec1b3 --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/locks.py @@ -0,0 +1,572 @@ +# Copyright 2015 The Tornado Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import collections +import datetime +import types + +from tornado import gen, ioloop +from tornado.concurrent import Future, future_set_result_unless_cancelled + +from typing import Union, Optional, Type, Any, Awaitable +import typing + +if typing.TYPE_CHECKING: + from typing import Deque, Set # noqa: F401 + +__all__ = ["Condition", "Event", "Semaphore", "BoundedSemaphore", "Lock"] + + +class _TimeoutGarbageCollector(object): + """Base class for objects that periodically clean up timed-out waiters. + + Avoids memory leak in a common pattern like: + + while True: + yield condition.wait(short_timeout) + print('looping....') + """ + + def __init__(self) -> None: + self._waiters = collections.deque() # type: Deque[Future] + self._timeouts = 0 + + def _garbage_collect(self) -> None: + # Occasionally clear timed-out waiters. + self._timeouts += 1 + if self._timeouts > 100: + self._timeouts = 0 + self._waiters = collections.deque(w for w in self._waiters if not w.done()) + + +class Condition(_TimeoutGarbageCollector): + """A condition allows one or more coroutines to wait until notified. + + Like a standard `threading.Condition`, but does not need an underlying lock + that is acquired and released. + + With a `Condition`, coroutines can wait to be notified by other coroutines: + + .. testcode:: + + import asyncio + from tornado import gen + from tornado.locks import Condition + + condition = Condition() + + async def waiter(): + print("I'll wait right here") + await condition.wait() + print("I'm done waiting") + + async def notifier(): + print("About to notify") + condition.notify() + print("Done notifying") + + async def runner(): + # Wait for waiter() and notifier() in parallel + await gen.multi([waiter(), notifier()]) + + asyncio.run(runner()) + + .. testoutput:: + + I'll wait right here + About to notify + Done notifying + I'm done waiting + + `wait` takes an optional ``timeout`` argument, which is either an absolute + timestamp:: + + io_loop = IOLoop.current() + + # Wait up to 1 second for a notification. 
+ await condition.wait(timeout=io_loop.time() + 1) + + ...or a `datetime.timedelta` for a timeout relative to the current time:: + + # Wait up to 1 second. + await condition.wait(timeout=datetime.timedelta(seconds=1)) + + The method returns False if there's no notification before the deadline. + + .. versionchanged:: 5.0 + Previously, waiters could be notified synchronously from within + `notify`. Now, the notification will always be received on the + next iteration of the `.IOLoop`. + """ + + def __repr__(self) -> str: + result = "<%s" % (self.__class__.__name__,) + if self._waiters: + result += " waiters[%s]" % len(self._waiters) + return result + ">" + + def wait( + self, timeout: Optional[Union[float, datetime.timedelta]] = None + ) -> Awaitable[bool]: + """Wait for `.notify`. + + Returns a `.Future` that resolves ``True`` if the condition is notified, + or ``False`` after a timeout. + """ + waiter = Future() # type: Future[bool] + self._waiters.append(waiter) + if timeout: + + def on_timeout() -> None: + if not waiter.done(): + future_set_result_unless_cancelled(waiter, False) + self._garbage_collect() + + io_loop = ioloop.IOLoop.current() + timeout_handle = io_loop.add_timeout(timeout, on_timeout) + waiter.add_done_callback(lambda _: io_loop.remove_timeout(timeout_handle)) + return waiter + + def notify(self, n: int = 1) -> None: + """Wake ``n`` waiters.""" + waiters = [] # Waiters we plan to run right now. + while n and self._waiters: + waiter = self._waiters.popleft() + if not waiter.done(): # Might have timed out. + n -= 1 + waiters.append(waiter) + + for waiter in waiters: + future_set_result_unless_cancelled(waiter, True) + + def notify_all(self) -> None: + """Wake all waiters.""" + self.notify(len(self._waiters)) + + +class Event(object): + """An event blocks coroutines until its internal flag is set to True. + + Similar to `threading.Event`. + + A coroutine can wait for an event to be set. Once it is set, calls to + ``yield event.wait()`` will not block unless the event has been cleared: + + .. testcode:: + + import asyncio + from tornado import gen + from tornado.locks import Event + + event = Event() + + async def waiter(): + print("Waiting for event") + await event.wait() + print("Not waiting this time") + await event.wait() + print("Done") + + async def setter(): + print("About to set the event") + event.set() + + async def runner(): + await gen.multi([waiter(), setter()]) + + asyncio.run(runner()) + + .. testoutput:: + + Waiting for event + About to set the event + Not waiting this time + Done + """ + + def __init__(self) -> None: + self._value = False + self._waiters = set() # type: Set[Future[None]] + + def __repr__(self) -> str: + return "<%s %s>" % ( + self.__class__.__name__, + "set" if self.is_set() else "clear", + ) + + def is_set(self) -> bool: + """Return ``True`` if the internal flag is true.""" + return self._value + + def set(self) -> None: + """Set the internal flag to ``True``. All waiters are awakened. + + Calling `.wait` once the flag is set will not block. + """ + if not self._value: + self._value = True + + for fut in self._waiters: + if not fut.done(): + fut.set_result(None) + + def clear(self) -> None: + """Reset the internal flag to ``False``. + + Calls to `.wait` will block until `.set` is called. + """ + self._value = False + + def wait( + self, timeout: Optional[Union[float, datetime.timedelta]] = None + ) -> Awaitable[None]: + """Block until the internal flag is true. 
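+
+        Editorial sketch of a timed wait (``event`` is assumed to be an
+        `Event` created elsewhere)::
+
+            try:
+                await event.wait(timeout=datetime.timedelta(seconds=1))
+            except tornado.util.TimeoutError:
+                pass  # flag was not set within one second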
+ + Returns an awaitable, which raises `tornado.util.TimeoutError` after a + timeout. + """ + fut = Future() # type: Future[None] + if self._value: + fut.set_result(None) + return fut + self._waiters.add(fut) + fut.add_done_callback(lambda fut: self._waiters.remove(fut)) + if timeout is None: + return fut + else: + timeout_fut = gen.with_timeout(timeout, fut) + # This is a slightly clumsy workaround for the fact that + # gen.with_timeout doesn't cancel its futures. Cancelling + # fut will remove it from the waiters list. + timeout_fut.add_done_callback( + lambda tf: fut.cancel() if not fut.done() else None + ) + return timeout_fut + + +class _ReleasingContextManager(object): + """Releases a Lock or Semaphore at the end of a "with" statement. + + with (yield semaphore.acquire()): + pass + + # Now semaphore.release() has been called. + """ + + def __init__(self, obj: Any) -> None: + self._obj = obj + + def __enter__(self) -> None: + pass + + def __exit__( + self, + exc_type: "Optional[Type[BaseException]]", + exc_val: Optional[BaseException], + exc_tb: Optional[types.TracebackType], + ) -> None: + self._obj.release() + + +class Semaphore(_TimeoutGarbageCollector): + """A lock that can be acquired a fixed number of times before blocking. + + A Semaphore manages a counter representing the number of `.release` calls + minus the number of `.acquire` calls, plus an initial value. The `.acquire` + method blocks if necessary until it can return without making the counter + negative. + + Semaphores limit access to a shared resource. To allow access for two + workers at a time: + + .. testsetup:: semaphore + + from collections import deque + + from tornado import gen + from tornado.ioloop import IOLoop + from tornado.concurrent import Future + + inited = False + + async def simulator(futures): + for f in futures: + # simulate the asynchronous passage of time + await gen.sleep(0) + await gen.sleep(0) + f.set_result(None) + + def use_some_resource(): + global inited + global futures_q + if not inited: + inited = True + # Ensure reliable doctest output: resolve Futures one at a time. + futures_q = deque([Future() for _ in range(3)]) + IOLoop.current().add_callback(simulator, list(futures_q)) + + return futures_q.popleft() + + .. testcode:: semaphore + + import asyncio + from tornado import gen + from tornado.locks import Semaphore + + sem = Semaphore(2) + + async def worker(worker_id): + await sem.acquire() + try: + print("Worker %d is working" % worker_id) + await use_some_resource() + finally: + print("Worker %d is done" % worker_id) + sem.release() + + async def runner(): + # Join all workers. + await gen.multi([worker(i) for i in range(3)]) + + asyncio.run(runner()) + + .. testoutput:: semaphore + + Worker 0 is working + Worker 1 is working + Worker 0 is done + Worker 2 is working + Worker 1 is done + Worker 2 is done + + Workers 0 and 1 are allowed to run concurrently, but worker 2 waits until + the semaphore has been released once, by worker 0. + + The semaphore can be used as an async context manager:: + + async def worker(worker_id): + async with sem: + print("Worker %d is working" % worker_id) + await use_some_resource() + + # Now the semaphore has been released. 
+ print("Worker %d is done" % worker_id) + + For compatibility with older versions of Python, `.acquire` is a + context manager, so ``worker`` could also be written as:: + + @gen.coroutine + def worker(worker_id): + with (yield sem.acquire()): + print("Worker %d is working" % worker_id) + yield use_some_resource() + + # Now the semaphore has been released. + print("Worker %d is done" % worker_id) + + .. versionchanged:: 4.3 + Added ``async with`` support in Python 3.5. + + """ + + def __init__(self, value: int = 1) -> None: + super().__init__() + if value < 0: + raise ValueError("semaphore initial value must be >= 0") + + self._value = value + + def __repr__(self) -> str: + res = super().__repr__() + extra = ( + "locked" if self._value == 0 else "unlocked,value:{0}".format(self._value) + ) + if self._waiters: + extra = "{0},waiters:{1}".format(extra, len(self._waiters)) + return "<{0} [{1}]>".format(res[1:-1], extra) + + def release(self) -> None: + """Increment the counter and wake one waiter.""" + self._value += 1 + while self._waiters: + waiter = self._waiters.popleft() + if not waiter.done(): + self._value -= 1 + + # If the waiter is a coroutine paused at + # + # with (yield semaphore.acquire()): + # + # then the context manager's __exit__ calls release() at the end + # of the "with" block. + waiter.set_result(_ReleasingContextManager(self)) + break + + def acquire( + self, timeout: Optional[Union[float, datetime.timedelta]] = None + ) -> Awaitable[_ReleasingContextManager]: + """Decrement the counter. Returns an awaitable. + + Block if the counter is zero and wait for a `.release`. The awaitable + raises `.TimeoutError` after the deadline. + """ + waiter = Future() # type: Future[_ReleasingContextManager] + if self._value > 0: + self._value -= 1 + waiter.set_result(_ReleasingContextManager(self)) + else: + self._waiters.append(waiter) + if timeout: + + def on_timeout() -> None: + if not waiter.done(): + waiter.set_exception(gen.TimeoutError()) + self._garbage_collect() + + io_loop = ioloop.IOLoop.current() + timeout_handle = io_loop.add_timeout(timeout, on_timeout) + waiter.add_done_callback( + lambda _: io_loop.remove_timeout(timeout_handle) + ) + return waiter + + def __enter__(self) -> None: + raise RuntimeError("Use 'async with' instead of 'with' for Semaphore") + + def __exit__( + self, + typ: "Optional[Type[BaseException]]", + value: Optional[BaseException], + traceback: Optional[types.TracebackType], + ) -> None: + self.__enter__() + + async def __aenter__(self) -> None: + await self.acquire() + + async def __aexit__( + self, + typ: "Optional[Type[BaseException]]", + value: Optional[BaseException], + tb: Optional[types.TracebackType], + ) -> None: + self.release() + + +class BoundedSemaphore(Semaphore): + """A semaphore that prevents release() being called too many times. + + If `.release` would increment the semaphore's value past the initial + value, it raises `ValueError`. Semaphores are mostly used to guard + resources with limited capacity, so a semaphore released too many times + is a sign of a bug. + """ + + def __init__(self, value: int = 1) -> None: + super().__init__(value=value) + self._initial_value = value + + def release(self) -> None: + """Increment the counter and wake one waiter.""" + if self._value >= self._initial_value: + raise ValueError("Semaphore released too many times") + super().release() + + +class Lock(object): + """A lock for coroutines. + + A Lock begins unlocked, and `acquire` locks it immediately. 
While it is + locked, a coroutine that yields `acquire` waits until another coroutine + calls `release`. + + Releasing an unlocked lock raises `RuntimeError`. + + A Lock can be used as an async context manager with the ``async + with`` statement: + + >>> from tornado import locks + >>> lock = locks.Lock() + >>> + >>> async def f(): + ... async with lock: + ... # Do something holding the lock. + ... pass + ... + ... # Now the lock is released. + + For compatibility with older versions of Python, the `.acquire` + method asynchronously returns a regular context manager: + + >>> async def f2(): + ... with (yield lock.acquire()): + ... # Do something holding the lock. + ... pass + ... + ... # Now the lock is released. + + .. versionchanged:: 4.3 + Added ``async with`` support in Python 3.5. + + """ + + def __init__(self) -> None: + self._block = BoundedSemaphore(value=1) + + def __repr__(self) -> str: + return "<%s _block=%s>" % (self.__class__.__name__, self._block) + + def acquire( + self, timeout: Optional[Union[float, datetime.timedelta]] = None + ) -> Awaitable[_ReleasingContextManager]: + """Attempt to lock. Returns an awaitable. + + Returns an awaitable, which raises `tornado.util.TimeoutError` after a + timeout. + """ + return self._block.acquire(timeout) + + def release(self) -> None: + """Unlock. + + The first coroutine in line waiting for `acquire` gets the lock. + + If not locked, raise a `RuntimeError`. + """ + try: + self._block.release() + except ValueError: + raise RuntimeError("release unlocked lock") + + def __enter__(self) -> None: + raise RuntimeError("Use `async with` instead of `with` for Lock") + + def __exit__( + self, + typ: "Optional[Type[BaseException]]", + value: Optional[BaseException], + tb: Optional[types.TracebackType], + ) -> None: + self.__enter__() + + async def __aenter__(self) -> None: + await self.acquire() + + async def __aexit__( + self, + typ: "Optional[Type[BaseException]]", + value: Optional[BaseException], + tb: Optional[types.TracebackType], + ) -> None: + self.release() diff --git a/venv/lib/python3.9/site-packages/tornado/log.py b/venv/lib/python3.9/site-packages/tornado/log.py new file mode 100644 index 00000000..86998961 --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/log.py @@ -0,0 +1,343 @@ +# +# Copyright 2012 Facebook +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +"""Logging support for Tornado. + +Tornado uses three logger streams: + +* ``tornado.access``: Per-request logging for Tornado's HTTP servers (and + potentially other servers in the future) +* ``tornado.application``: Logging of errors from application code (i.e. + uncaught exceptions from callbacks) +* ``tornado.general``: General-purpose logging, including any errors + or warnings from Tornado itself. + +These streams may be configured independently using the standard library's +`logging` module. For example, you may wish to send ``tornado.access`` logs +to a separate file for analysis. 
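+
+For example, a minimal sketch of sending ``tornado.access`` records to their
+own file (the filename here is illustrative)::
+
+    import logging
+
+    access_log = logging.getLogger("tornado.access")
+    access_log.propagate = False
+    access_log.setLevel(logging.INFO)
+    access_log.addHandler(logging.FileHandler("access.log"))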
+""" +import logging +import logging.handlers +import sys + +from tornado.escape import _unicode +from tornado.util import unicode_type, basestring_type + +try: + import colorama # type: ignore +except ImportError: + colorama = None + +try: + import curses +except ImportError: + curses = None # type: ignore + +from typing import Dict, Any, cast, Optional + +# Logger objects for internal tornado use +access_log = logging.getLogger("tornado.access") +app_log = logging.getLogger("tornado.application") +gen_log = logging.getLogger("tornado.general") + + +def _stderr_supports_color() -> bool: + try: + if hasattr(sys.stderr, "isatty") and sys.stderr.isatty(): + if curses: + curses.setupterm() + if curses.tigetnum("colors") > 0: + return True + elif colorama: + if sys.stderr is getattr( + colorama.initialise, "wrapped_stderr", object() + ): + return True + except Exception: + # Very broad exception handling because it's always better to + # fall back to non-colored logs than to break at startup. + pass + return False + + +def _safe_unicode(s: Any) -> str: + try: + return _unicode(s) + except UnicodeDecodeError: + return repr(s) + + +class LogFormatter(logging.Formatter): + """Log formatter used in Tornado. + + Key features of this formatter are: + + * Color support when logging to a terminal that supports it. + * Timestamps on every log line. + * Robust against str/bytes encoding problems. + + This formatter is enabled automatically by + `tornado.options.parse_command_line` or `tornado.options.parse_config_file` + (unless ``--logging=none`` is used). + + Color support on Windows versions that do not support ANSI color codes is + enabled by use of the colorama__ library. Applications that wish to use + this must first initialize colorama with a call to ``colorama.init``. + See the colorama documentation for details. + + __ https://pypi.python.org/pypi/colorama + + .. versionchanged:: 4.5 + Added support for ``colorama``. Changed the constructor + signature to be compatible with `logging.config.dictConfig`. + """ + + DEFAULT_FORMAT = "%(color)s[%(levelname)1.1s %(asctime)s %(module)s:%(lineno)d]%(end_color)s %(message)s" # noqa: E501 + DEFAULT_DATE_FORMAT = "%y%m%d %H:%M:%S" + DEFAULT_COLORS = { + logging.DEBUG: 4, # Blue + logging.INFO: 2, # Green + logging.WARNING: 3, # Yellow + logging.ERROR: 1, # Red + logging.CRITICAL: 5, # Magenta + } + + def __init__( + self, + fmt: str = DEFAULT_FORMAT, + datefmt: str = DEFAULT_DATE_FORMAT, + style: str = "%", + color: bool = True, + colors: Dict[int, int] = DEFAULT_COLORS, + ) -> None: + r""" + :arg bool color: Enables color support. + :arg str fmt: Log message format. + It will be applied to the attributes dict of log records. The + text between ``%(color)s`` and ``%(end_color)s`` will be colored + depending on the level if color support is on. + :arg dict colors: color mappings from logging level to terminal color + code + :arg str datefmt: Datetime format. + Used for formatting ``(asctime)`` placeholder in ``prefix_fmt``. + + .. versionchanged:: 3.2 + + Added ``fmt`` and ``datefmt`` arguments. + """ + logging.Formatter.__init__(self, datefmt=datefmt) + self._fmt = fmt + + self._colors = {} # type: Dict[int, str] + if color and _stderr_supports_color(): + if curses is not None: + fg_color = curses.tigetstr("setaf") or curses.tigetstr("setf") or b"" + + for levelno, code in colors.items(): + # Convert the terminal control characters from + # bytes to unicode strings for easier use with the + # logging module. 
+ self._colors[levelno] = unicode_type( + curses.tparm(fg_color, code), "ascii" + ) + normal = curses.tigetstr("sgr0") + if normal is not None: + self._normal = unicode_type(normal, "ascii") + else: + self._normal = "" + else: + # If curses is not present (currently we'll only get here for + # colorama on windows), assume hard-coded ANSI color codes. + for levelno, code in colors.items(): + self._colors[levelno] = "\033[2;3%dm" % code + self._normal = "\033[0m" + else: + self._normal = "" + + def format(self, record: Any) -> str: + try: + message = record.getMessage() + assert isinstance(message, basestring_type) # guaranteed by logging + # Encoding notes: The logging module prefers to work with character + # strings, but only enforces that log messages are instances of + # basestring. In python 2, non-ascii bytestrings will make + # their way through the logging framework until they blow up with + # an unhelpful decoding error (with this formatter it happens + # when we attach the prefix, but there are other opportunities for + # exceptions further along in the framework). + # + # If a byte string makes it this far, convert it to unicode to + # ensure it will make it out to the logs. Use repr() as a fallback + # to ensure that all byte strings can be converted successfully, + # but don't do it by default so we don't add extra quotes to ascii + # bytestrings. This is a bit of a hacky place to do this, but + # it's worth it since the encoding errors that would otherwise + # result are so useless (and tornado is fond of using utf8-encoded + # byte strings wherever possible). + record.message = _safe_unicode(message) + except Exception as e: + record.message = "Bad message (%r): %r" % (e, record.__dict__) + + record.asctime = self.formatTime(record, cast(str, self.datefmt)) + + if record.levelno in self._colors: + record.color = self._colors[record.levelno] + record.end_color = self._normal + else: + record.color = record.end_color = "" + + formatted = self._fmt % record.__dict__ + + if record.exc_info: + if not record.exc_text: + record.exc_text = self.formatException(record.exc_info) + if record.exc_text: + # exc_text contains multiple lines. We need to _safe_unicode + # each line separately so that non-utf8 bytes don't cause + # all the newlines to turn into '\n'. + lines = [formatted.rstrip()] + lines.extend(_safe_unicode(ln) for ln in record.exc_text.split("\n")) + formatted = "\n".join(lines) + return formatted.replace("\n", "\n ") + + +def enable_pretty_logging( + options: Any = None, logger: Optional[logging.Logger] = None +) -> None: + """Turns on formatted logging output as configured. + + This is called automatically by `tornado.options.parse_command_line` + and `tornado.options.parse_config_file`. 
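+
+    A minimal sketch of calling it directly (this assumes the default global
+    options object, which already defines the ``logging`` family of flags)::
+
+        import logging
+        from tornado.log import enable_pretty_logging
+
+        enable_pretty_logging(logger=logging.getLogger())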
+ """ + if options is None: + import tornado.options + + options = tornado.options.options + if options.logging is None or options.logging.lower() == "none": + return + if logger is None: + logger = logging.getLogger() + logger.setLevel(getattr(logging, options.logging.upper())) + if options.log_file_prefix: + rotate_mode = options.log_rotate_mode + if rotate_mode == "size": + channel = logging.handlers.RotatingFileHandler( + filename=options.log_file_prefix, + maxBytes=options.log_file_max_size, + backupCount=options.log_file_num_backups, + encoding="utf-8", + ) # type: logging.Handler + elif rotate_mode == "time": + channel = logging.handlers.TimedRotatingFileHandler( + filename=options.log_file_prefix, + when=options.log_rotate_when, + interval=options.log_rotate_interval, + backupCount=options.log_file_num_backups, + encoding="utf-8", + ) + else: + error_message = ( + "The value of log_rotate_mode option should be " + + '"size" or "time", not "%s".' % rotate_mode + ) + raise ValueError(error_message) + channel.setFormatter(LogFormatter(color=False)) + logger.addHandler(channel) + + if options.log_to_stderr or (options.log_to_stderr is None and not logger.handlers): + # Set up color if we are in a tty and curses is installed + channel = logging.StreamHandler() + channel.setFormatter(LogFormatter()) + logger.addHandler(channel) + + +def define_logging_options(options: Any = None) -> None: + """Add logging-related flags to ``options``. + + These options are present automatically on the default options instance; + this method is only necessary if you have created your own `.OptionParser`. + + .. versionadded:: 4.2 + This function existed in prior versions but was broken and undocumented until 4.2. + """ + if options is None: + # late import to prevent cycle + import tornado.options + + options = tornado.options.options + options.define( + "logging", + default="info", + help=( + "Set the Python log level. If 'none', tornado won't touch the " + "logging configuration." + ), + metavar="debug|info|warning|error|none", + ) + options.define( + "log_to_stderr", + type=bool, + default=None, + help=( + "Send log output to stderr (colorized if possible). " + "By default use stderr if --log_file_prefix is not set and " + "no other logging is configured." + ), + ) + options.define( + "log_file_prefix", + type=str, + default=None, + metavar="PATH", + help=( + "Path prefix for log files. " + "Note that if you are running multiple tornado processes, " + "log_file_prefix must be different for each of them (e.g. 
" + "include the port number)" + ), + ) + options.define( + "log_file_max_size", + type=int, + default=100 * 1000 * 1000, + help="max size of log files before rollover", + ) + options.define( + "log_file_num_backups", type=int, default=10, help="number of log files to keep" + ) + + options.define( + "log_rotate_when", + type=str, + default="midnight", + help=( + "specify the type of TimedRotatingFileHandler interval " + "other options:('S', 'M', 'H', 'D', 'W0'-'W6')" + ), + ) + options.define( + "log_rotate_interval", + type=int, + default=1, + help="The interval value of timed rotating", + ) + + options.define( + "log_rotate_mode", + type=str, + default="size", + help="The mode of rotating files(time or size)", + ) + + options.add_parse_callback(lambda: enable_pretty_logging(options)) diff --git a/venv/lib/python3.9/site-packages/tornado/netutil.py b/venv/lib/python3.9/site-packages/tornado/netutil.py new file mode 100644 index 00000000..04db085a --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/netutil.py @@ -0,0 +1,677 @@ +# +# Copyright 2011 Facebook +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Miscellaneous network utility code.""" + +import asyncio +import concurrent.futures +import errno +import os +import sys +import socket +import ssl +import stat + +from tornado.concurrent import dummy_executor, run_on_executor +from tornado.ioloop import IOLoop +from tornado.util import Configurable, errno_from_exception + +from typing import List, Callable, Any, Type, Dict, Union, Tuple, Awaitable, Optional + +# Note that the naming of ssl.Purpose is confusing; the purpose +# of a context is to authenticate the opposite side of the connection. +_client_ssl_defaults = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) +_server_ssl_defaults = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) +if hasattr(ssl, "OP_NO_COMPRESSION"): + # See netutil.ssl_options_to_context + _client_ssl_defaults.options |= ssl.OP_NO_COMPRESSION + _server_ssl_defaults.options |= ssl.OP_NO_COMPRESSION + +# ThreadedResolver runs getaddrinfo on a thread. If the hostname is unicode, +# getaddrinfo attempts to import encodings.idna. If this is done at +# module-import time, the import lock is already held by the main thread, +# leading to deadlock. Avoid it by caching the idna encoder on the main +# thread now. +"foo".encode("idna") + +# For undiagnosed reasons, 'latin1' codec may also need to be preloaded. +"foo".encode("latin1") + +# Default backlog used when calling sock.listen() +_DEFAULT_BACKLOG = 128 + + +def bind_sockets( + port: int, + address: Optional[str] = None, + family: socket.AddressFamily = socket.AF_UNSPEC, + backlog: int = _DEFAULT_BACKLOG, + flags: Optional[int] = None, + reuse_port: bool = False, +) -> List[socket.socket]: + """Creates listening sockets bound to the given port and address. + + Returns a list of socket objects (multiple sockets are returned if + the given address maps to multiple IP addresses, which is most common + for mixed IPv4 and IPv6 use). 
+ + Address may be either an IP address or hostname. If it's a hostname, + the server will listen on all IP addresses associated with the + name. Address may be an empty string or None to listen on all + available interfaces. Family may be set to either `socket.AF_INET` + or `socket.AF_INET6` to restrict to IPv4 or IPv6 addresses, otherwise + both will be used if available. + + The ``backlog`` argument has the same meaning as for + `socket.listen() <socket.socket.listen>`. + + ``flags`` is a bitmask of AI_* flags to `~socket.getaddrinfo`, like + ``socket.AI_PASSIVE | socket.AI_NUMERICHOST``. + + ``reuse_port`` option sets ``SO_REUSEPORT`` option for every socket + in the list. If your platform doesn't support this option ValueError will + be raised. + """ + if reuse_port and not hasattr(socket, "SO_REUSEPORT"): + raise ValueError("the platform doesn't support SO_REUSEPORT") + + sockets = [] + if address == "": + address = None + if not socket.has_ipv6 and family == socket.AF_UNSPEC: + # Python can be compiled with --disable-ipv6, which causes + # operations on AF_INET6 sockets to fail, but does not + # automatically exclude those results from getaddrinfo + # results. + # http://bugs.python.org/issue16208 + family = socket.AF_INET + if flags is None: + flags = socket.AI_PASSIVE + bound_port = None + unique_addresses = set() # type: set + for res in sorted( + socket.getaddrinfo(address, port, family, socket.SOCK_STREAM, 0, flags), + key=lambda x: x[0], + ): + if res in unique_addresses: + continue + + unique_addresses.add(res) + + af, socktype, proto, canonname, sockaddr = res + if ( + sys.platform == "darwin" + and address == "localhost" + and af == socket.AF_INET6 + and sockaddr[3] != 0 # type: ignore + ): + # Mac OS X includes a link-local address fe80::1%lo0 in the + # getaddrinfo results for 'localhost'. However, the firewall + # doesn't understand that this is a local address and will + # prompt for access (often repeatedly, due to an apparent + # bug in its ability to remember granting access to an + # application). Skip these addresses. + continue + try: + sock = socket.socket(af, socktype, proto) + except socket.error as e: + if errno_from_exception(e) == errno.EAFNOSUPPORT: + continue + raise + if os.name != "nt": + try: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + except socket.error as e: + if errno_from_exception(e) != errno.ENOPROTOOPT: + # Hurd doesn't support SO_REUSEADDR. + raise + if reuse_port: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + if af == socket.AF_INET6: + # On linux, ipv6 sockets accept ipv4 too by default, + # but this makes it impossible to bind to both + # 0.0.0.0 in ipv4 and :: in ipv6. On other systems, + # separate sockets *must* be used to listen for both ipv4 + # and ipv6. For consistency, always disable ipv4 on our + # ipv6 sockets and use a separate ipv4 socket when needed. + # + # Python 2.x on windows doesn't have IPPROTO_IPV6. 
+            if hasattr(socket, "IPPROTO_IPV6"):
+                sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)
+
+            # automatic port allocation with port=None
+            # should bind on the same port on IPv4 and IPv6
+            host, requested_port = sockaddr[:2]
+            if requested_port == 0 and bound_port is not None:
+                sockaddr = tuple([host, bound_port] + list(sockaddr[2:]))
+
+        sock.setblocking(False)
+        try:
+            sock.bind(sockaddr)
+        except OSError as e:
+            if (
+                errno_from_exception(e) == errno.EADDRNOTAVAIL
+                and address == "localhost"
+                and sockaddr[0] == "::1"
+            ):
+                # On some systems (most notably docker with default
+                # configurations), ipv6 is partially disabled:
+                # socket.has_ipv6 is true, we can create AF_INET6
+                # sockets, and getaddrinfo("localhost", ...,
+                # AF_PASSIVE) resolves to ::1, but we get an error
+                # when binding.
+                #
+                # Swallow the error, but only for this specific case.
+                # If EADDRNOTAVAIL occurs in other situations, it
+                # might be a real problem like a typo in a
+                # configuration.
+                sock.close()
+                continue
+            else:
+                raise
+        bound_port = sock.getsockname()[1]
+        sock.listen(backlog)
+        sockets.append(sock)
+    return sockets
+
+
+if hasattr(socket, "AF_UNIX"):
+
+    def bind_unix_socket(
+        file: str, mode: int = 0o600, backlog: int = _DEFAULT_BACKLOG
+    ) -> socket.socket:
+        """Creates a listening unix socket.
+
+        If a socket with the given name already exists, it will be deleted.
+        If any other file with that name exists, an exception will be
+        raised.
+
+        Returns a socket object (not a list of socket objects like
+        `bind_sockets`).
+        """
+        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+        try:
+            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+        except socket.error as e:
+            if errno_from_exception(e) != errno.ENOPROTOOPT:
+                # Hurd doesn't support SO_REUSEADDR
+                raise
+        sock.setblocking(False)
+        try:
+            st = os.stat(file)
+        except FileNotFoundError:
+            pass
+        else:
+            if stat.S_ISSOCK(st.st_mode):
+                os.remove(file)
+            else:
+                raise ValueError("File %s exists and is not a socket" % file)
+        sock.bind(file)
+        os.chmod(file, mode)
+        sock.listen(backlog)
+        return sock
+
+
+def add_accept_handler(
+    sock: socket.socket, callback: Callable[[socket.socket, Any], None]
+) -> Callable[[], None]:
+    """Adds an `.IOLoop` event handler to accept new connections on ``sock``.
+
+    When a connection is accepted, ``callback(connection, address)`` will
+    be run (``connection`` is a socket object, and ``address`` is the
+    address of the other end of the connection). Note that this signature
+    is different from the ``callback(fd, events)`` signature used for
+    `.IOLoop` handlers.
+
+    A callable is returned which, when called, will remove the `.IOLoop`
+    event handler and stop processing further incoming connections.
+
+    .. versionchanged:: 5.0
+       The ``io_loop`` argument (deprecated since version 4.1) has been removed.
+
+    .. versionchanged:: 5.0
+       A callable is returned (``None`` was returned before).
+    """
+    io_loop = IOLoop.current()
+    removed = [False]
+
+    def accept_handler(fd: socket.socket, events: int) -> None:
+        # More connections may come in while we're handling callbacks;
+        # to prevent starvation of other tasks we must limit the number
+        # of connections we accept at a time. 
Ideally we would accept + # up to the number of connections that were waiting when we + # entered this method, but this information is not available + # (and rearranging this method to call accept() as many times + # as possible before running any callbacks would have adverse + # effects on load balancing in multiprocess configurations). + # Instead, we use the (default) listen backlog as a rough + # heuristic for the number of connections we can reasonably + # accept at once. + for i in range(_DEFAULT_BACKLOG): + if removed[0]: + # The socket was probably closed + return + try: + connection, address = sock.accept() + except BlockingIOError: + # EWOULDBLOCK indicates we have accepted every + # connection that is available. + return + except ConnectionAbortedError: + # ECONNABORTED indicates that there was a connection + # but it was closed while still in the accept queue. + # (observed on FreeBSD). + continue + callback(connection, address) + + def remove_handler() -> None: + io_loop.remove_handler(sock) + removed[0] = True + + io_loop.add_handler(sock, accept_handler, IOLoop.READ) + return remove_handler + + +def is_valid_ip(ip: str) -> bool: + """Returns ``True`` if the given string is a well-formed IP address. + + Supports IPv4 and IPv6. + """ + if not ip or "\x00" in ip: + # getaddrinfo resolves empty strings to localhost, and truncates + # on zero bytes. + return False + try: + res = socket.getaddrinfo( + ip, 0, socket.AF_UNSPEC, socket.SOCK_STREAM, 0, socket.AI_NUMERICHOST + ) + return bool(res) + except socket.gaierror as e: + if e.args[0] == socket.EAI_NONAME: + return False + raise + except UnicodeError: + # `socket.getaddrinfo` will raise a UnicodeError from the + # `idna` decoder if the input is longer than 63 characters, + # even for socket.AI_NUMERICHOST. See + # https://bugs.python.org/issue32958 for discussion + return False + return True + + +class Resolver(Configurable): + """Configurable asynchronous DNS resolver interface. + + By default, a blocking implementation is used (which simply calls + `socket.getaddrinfo`). An alternative implementation can be + chosen with the `Resolver.configure <.Configurable.configure>` + class method:: + + Resolver.configure('tornado.netutil.ThreadedResolver') + + The implementations of this interface included with Tornado are + + * `tornado.netutil.DefaultLoopResolver` + * `tornado.netutil.DefaultExecutorResolver` (deprecated) + * `tornado.netutil.BlockingResolver` (deprecated) + * `tornado.netutil.ThreadedResolver` (deprecated) + * `tornado.netutil.OverrideResolver` + * `tornado.platform.twisted.TwistedResolver` (deprecated) + * `tornado.platform.caresresolver.CaresResolver` (deprecated) + + .. versionchanged:: 5.0 + The default implementation has changed from `BlockingResolver` to + `DefaultExecutorResolver`. + + .. versionchanged:: 6.2 + The default implementation has changed from `DefaultExecutorResolver` to + `DefaultLoopResolver`. + """ + + @classmethod + def configurable_base(cls) -> Type["Resolver"]: + return Resolver + + @classmethod + def configurable_default(cls) -> Type["Resolver"]: + return DefaultLoopResolver + + def resolve( + self, host: str, port: int, family: socket.AddressFamily = socket.AF_UNSPEC + ) -> Awaitable[List[Tuple[int, Any]]]: + """Resolves an address. + + The ``host`` argument is a string which may be a hostname or a + literal IP address. + + Returns a `.Future` whose result is a list of (family, + address) pairs, where address is a tuple suitable to pass to + `socket.connect <socket.socket.connect>` (i.e. 
a ``(host,
+        port)`` pair for IPv4; additional fields may be present for
+        IPv6).
+
+        :raises IOError: if the address cannot be resolved.
+
+        .. versionchanged:: 4.4
+           Standardized all implementations to raise `IOError`.
+
+        .. versionchanged:: 6.0 The ``callback`` argument was removed.
+           Use the returned awaitable object instead.
+
+        """
+        raise NotImplementedError()
+
+    def close(self) -> None:
+        """Closes the `Resolver`, freeing any resources used.
+
+        .. versionadded:: 3.1
+
+        """
+        pass
+
+
+def _resolve_addr(
+    host: str, port: int, family: socket.AddressFamily = socket.AF_UNSPEC
+) -> List[Tuple[int, Any]]:
+    # On Solaris, getaddrinfo fails if the given port is not found
+    # in /etc/services and no socket type is given, so we must pass
+    # one here. The socket type used here doesn't seem to actually
+    # matter (we discard the one we get back in the results),
+    # so the addresses we return should still be usable with SOCK_DGRAM.
+    addrinfo = socket.getaddrinfo(host, port, family, socket.SOCK_STREAM)
+    results = []
+    for fam, socktype, proto, canonname, address in addrinfo:
+        results.append((fam, address))
+    return results  # type: ignore
+
+
+class DefaultExecutorResolver(Resolver):
+    """Resolver implementation using `.IOLoop.run_in_executor`.
+
+    .. versionadded:: 5.0
+
+    .. deprecated:: 6.2
+
+       Use `DefaultLoopResolver` instead.
+    """
+
+    async def resolve(
+        self, host: str, port: int, family: socket.AddressFamily = socket.AF_UNSPEC
+    ) -> List[Tuple[int, Any]]:
+        result = await IOLoop.current().run_in_executor(
+            None, _resolve_addr, host, port, family
+        )
+        return result
+
+
+class DefaultLoopResolver(Resolver):
+    """Resolver implementation using `asyncio.loop.getaddrinfo`."""
+
+    async def resolve(
+        self, host: str, port: int, family: socket.AddressFamily = socket.AF_UNSPEC
+    ) -> List[Tuple[int, Any]]:
+        # On Solaris, getaddrinfo fails if the given port is not found
+        # in /etc/services and no socket type is given, so we must pass
+        # one here. The socket type used here doesn't seem to actually
+        # matter (we discard the one we get back in the results),
+        # so the addresses we return should still be usable with SOCK_DGRAM.
+        return [
+            (fam, address)
+            for fam, _, _, _, address in await asyncio.get_running_loop().getaddrinfo(
+                host, port, family=family, type=socket.SOCK_STREAM
+            )
+        ]
+
+
+class ExecutorResolver(Resolver):
+    """Resolver implementation using a `concurrent.futures.Executor`.
+
+    Use this instead of `ThreadedResolver` when you require additional
+    control over the executor being used.
+
+    The executor will be shut down when the resolver is closed unless
+    ``close_executor=False``; use this if you want to reuse the same
+    executor elsewhere.
+
+    .. versionchanged:: 5.0
+       The ``io_loop`` argument (deprecated since version 4.1) has been removed.
+
+    .. deprecated:: 5.0
+       The default `Resolver` now uses `asyncio.loop.getaddrinfo`;
+       use that instead of this class.
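+
+    A construction sketch (the thread count is an arbitrary choice)::
+
+        from concurrent.futures import ThreadPoolExecutor
+        from tornado.netutil import ExecutorResolver
+
+        resolver = ExecutorResolver(executor=ThreadPoolExecutor(4))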
+ """ + + def initialize( + self, + executor: Optional[concurrent.futures.Executor] = None, + close_executor: bool = True, + ) -> None: + if executor is not None: + self.executor = executor + self.close_executor = close_executor + else: + self.executor = dummy_executor + self.close_executor = False + + def close(self) -> None: + if self.close_executor: + self.executor.shutdown() + self.executor = None # type: ignore + + @run_on_executor + def resolve( + self, host: str, port: int, family: socket.AddressFamily = socket.AF_UNSPEC + ) -> List[Tuple[int, Any]]: + return _resolve_addr(host, port, family) + + +class BlockingResolver(ExecutorResolver): + """Default `Resolver` implementation, using `socket.getaddrinfo`. + + The `.IOLoop` will be blocked during the resolution, although the + callback will not be run until the next `.IOLoop` iteration. + + .. deprecated:: 5.0 + The default `Resolver` now uses `.IOLoop.run_in_executor`; use that instead + of this class. + """ + + def initialize(self) -> None: # type: ignore + super().initialize() + + +class ThreadedResolver(ExecutorResolver): + """Multithreaded non-blocking `Resolver` implementation. + + Requires the `concurrent.futures` package to be installed + (available in the standard library since Python 3.2, + installable with ``pip install futures`` in older versions). + + The thread pool size can be configured with:: + + Resolver.configure('tornado.netutil.ThreadedResolver', + num_threads=10) + + .. versionchanged:: 3.1 + All ``ThreadedResolvers`` share a single thread pool, whose + size is set by the first one to be created. + + .. deprecated:: 5.0 + The default `Resolver` now uses `.IOLoop.run_in_executor`; use that instead + of this class. + """ + + _threadpool = None # type: ignore + _threadpool_pid = None # type: int + + def initialize(self, num_threads: int = 10) -> None: # type: ignore + threadpool = ThreadedResolver._create_threadpool(num_threads) + super().initialize(executor=threadpool, close_executor=False) + + @classmethod + def _create_threadpool( + cls, num_threads: int + ) -> concurrent.futures.ThreadPoolExecutor: + pid = os.getpid() + if cls._threadpool_pid != pid: + # Threads cannot survive after a fork, so if our pid isn't what it + # was when we created the pool then delete it. + cls._threadpool = None + if cls._threadpool is None: + cls._threadpool = concurrent.futures.ThreadPoolExecutor(num_threads) + cls._threadpool_pid = pid + return cls._threadpool + + +class OverrideResolver(Resolver): + """Wraps a resolver with a mapping of overrides. + + This can be used to make local DNS changes (e.g. for testing) + without modifying system-wide settings. + + The mapping can be in three formats:: + + { + # Hostname to host or ip + "example.com": "127.0.1.1", + + # Host+port to host+port + ("login.example.com", 443): ("localhost", 1443), + + # Host+port+address family to host+port + ("login.example.com", 443, socket.AF_INET6): ("::1", 1443), + } + + .. versionchanged:: 5.0 + Added support for host-port-family triplets. 
+ """ + + def initialize(self, resolver: Resolver, mapping: dict) -> None: + self.resolver = resolver + self.mapping = mapping + + def close(self) -> None: + self.resolver.close() + + def resolve( + self, host: str, port: int, family: socket.AddressFamily = socket.AF_UNSPEC + ) -> Awaitable[List[Tuple[int, Any]]]: + if (host, port, family) in self.mapping: + host, port = self.mapping[(host, port, family)] + elif (host, port) in self.mapping: + host, port = self.mapping[(host, port)] + elif host in self.mapping: + host = self.mapping[host] + return self.resolver.resolve(host, port, family) + + +# These are the keyword arguments to ssl.wrap_socket that must be translated +# to their SSLContext equivalents (the other arguments are still passed +# to SSLContext.wrap_socket). +_SSL_CONTEXT_KEYWORDS = frozenset( + ["ssl_version", "certfile", "keyfile", "cert_reqs", "ca_certs", "ciphers"] +) + + +def ssl_options_to_context( + ssl_options: Union[Dict[str, Any], ssl.SSLContext], + server_side: Optional[bool] = None, +) -> ssl.SSLContext: + """Try to convert an ``ssl_options`` dictionary to an + `~ssl.SSLContext` object. + + The ``ssl_options`` dictionary contains keywords to be passed to + `ssl.wrap_socket`. In Python 2.7.9+, `ssl.SSLContext` objects can + be used instead. This function converts the dict form to its + `~ssl.SSLContext` equivalent, and may be used when a component which + accepts both forms needs to upgrade to the `~ssl.SSLContext` version + to use features like SNI or NPN. + + .. versionchanged:: 6.2 + + Added server_side argument. Omitting this argument will + result in a DeprecationWarning on Python 3.10. + + """ + if isinstance(ssl_options, ssl.SSLContext): + return ssl_options + assert isinstance(ssl_options, dict) + assert all(k in _SSL_CONTEXT_KEYWORDS for k in ssl_options), ssl_options + # TODO: Now that we have the server_side argument, can we switch to + # create_default_context or would that change behavior? + default_version = ssl.PROTOCOL_TLS + if server_side: + default_version = ssl.PROTOCOL_TLS_SERVER + elif server_side is not None: + default_version = ssl.PROTOCOL_TLS_CLIENT + context = ssl.SSLContext(ssl_options.get("ssl_version", default_version)) + if "certfile" in ssl_options: + context.load_cert_chain( + ssl_options["certfile"], ssl_options.get("keyfile", None) + ) + if "cert_reqs" in ssl_options: + if ssl_options["cert_reqs"] == ssl.CERT_NONE: + # This may have been set automatically by PROTOCOL_TLS_CLIENT but is + # incompatible with CERT_NONE so we must manually clear it. + context.check_hostname = False + context.verify_mode = ssl_options["cert_reqs"] + if "ca_certs" in ssl_options: + context.load_verify_locations(ssl_options["ca_certs"]) + if "ciphers" in ssl_options: + context.set_ciphers(ssl_options["ciphers"]) + if hasattr(ssl, "OP_NO_COMPRESSION"): + # Disable TLS compression to avoid CRIME and related attacks. + # This constant depends on openssl version 1.0. + # TODO: Do we need to do this ourselves or can we trust + # the defaults? + context.options |= ssl.OP_NO_COMPRESSION + return context + + +def ssl_wrap_socket( + socket: socket.socket, + ssl_options: Union[Dict[str, Any], ssl.SSLContext], + server_hostname: Optional[str] = None, + server_side: Optional[bool] = None, + **kwargs: Any +) -> ssl.SSLSocket: + """Returns an ``ssl.SSLSocket`` wrapping the given socket. + + ``ssl_options`` may be either an `ssl.SSLContext` object or a + dictionary (as accepted by `ssl_options_to_context`). 
Additional + keyword arguments are passed to ``wrap_socket`` (either the + `~ssl.SSLContext` method or the `ssl` module function as + appropriate). + + .. versionchanged:: 6.2 + + Added server_side argument. Omitting this argument will + result in a DeprecationWarning on Python 3.10. + """ + context = ssl_options_to_context(ssl_options, server_side=server_side) + if server_side is None: + server_side = False + if ssl.HAS_SNI: + # In python 3.4, wrap_socket only accepts the server_hostname + # argument if HAS_SNI is true. + # TODO: add a unittest (python added server-side SNI support in 3.4) + # In the meantime it can be manually tested with + # python3 -m tornado.httpclient https://sni.velox.ch + return context.wrap_socket( + socket, server_hostname=server_hostname, server_side=server_side, **kwargs + ) + else: + return context.wrap_socket(socket, server_side=server_side, **kwargs) diff --git a/venv/lib/python3.9/site-packages/tornado/options.py b/venv/lib/python3.9/site-packages/tornado/options.py new file mode 100644 index 00000000..b8296691 --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/options.py @@ -0,0 +1,750 @@ +# +# Copyright 2009 Facebook +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""A command line parsing module that lets modules define their own options. + +This module is inspired by Google's `gflags +<https://github.com/google/python-gflags>`_. The primary difference +with libraries such as `argparse` is that a global registry is used so +that options may be defined in any module (it also enables +`tornado.log` by default). The rest of Tornado does not depend on this +module, so feel free to use `argparse` or other configuration +libraries if you prefer them. + +Options must be defined with `tornado.options.define` before use, +generally at the top level of a module. The options are then +accessible as attributes of `tornado.options.options`:: + + # myapp/db.py + from tornado.options import define, options + + define("mysql_host", default="127.0.0.1:3306", help="Main user DB") + define("memcache_hosts", default="127.0.0.1:11011", multiple=True, + help="Main user memcache servers") + + def connect(): + db = database.Connection(options.mysql_host) + ... + + # myapp/server.py + from tornado.options import define, options + + define("port", default=8080, help="port to listen on") + + def start_server(): + app = make_app() + app.listen(options.port) + +The ``main()`` method of your application does not need to be aware of all of +the options used throughout your program; they are all automatically loaded +when the modules are loaded. However, all modules that define options +must have been imported before the command line is parsed. + +Your ``main()`` method can parse the command line or parse a config file with +either `parse_command_line` or `parse_config_file`:: + + import myapp.db, myapp.server + import tornado + + if __name__ == '__main__': + tornado.options.parse_command_line() + # or + tornado.options.parse_config_file("/etc/server.conf") + +.. 
note::
+
+   When using multiple ``parse_*`` functions, pass ``final=False`` to all
+   but the last one, or side effects may occur twice (in particular,
+   this can result in log messages being doubled).
+
+`tornado.options.options` is a singleton instance of `OptionParser`, and
+the top-level functions in this module (`define`, `parse_command_line`, etc)
+simply call methods on it.  You may create additional `OptionParser`
+instances to define isolated sets of options, such as for subcommands.
+
+.. note::
+
+   By default, several options are defined that will configure the
+   standard `logging` module when `parse_command_line` or `parse_config_file`
+   are called.  If you want Tornado to leave the logging configuration
+   alone so you can manage it yourself, either pass ``--logging=none``
+   on the command line or do the following to disable it in code::
+
+       from tornado.options import options, parse_command_line
+       options.logging = None
+       parse_command_line()
+
+.. note::
+
+   `parse_command_line` and `parse_config_file` should be called only after
+   logging configuration and any user-defined command line flags that use the
+   ``callback`` option definition have been set up, or those configurations
+   will not take effect.
+
+.. versionchanged:: 4.3
+   Dashes and underscores are fully interchangeable in option names;
+   options can be defined, set, and read with any mix of the two.
+   Dashes are typical for command-line usage while config files require
+   underscores.
+"""
+
+import datetime
+import numbers
+import re
+import sys
+import os
+import textwrap
+
+from tornado.escape import _unicode, native_str
+from tornado.log import define_logging_options
+from tornado.util import basestring_type, exec_in
+
+from typing import (
+    Any,
+    Iterator,
+    Iterable,
+    Tuple,
+    Set,
+    Dict,
+    Callable,
+    List,
+    TextIO,
+    Optional,
+)
+
+
+class Error(Exception):
+    """Exception raised by errors in the options module."""
+
+    pass
+
+
+class OptionParser(object):
+    """A collection of options, a dictionary with object-like access.
+
+    Normally accessed via static functions in the `tornado.options` module,
+    which reference a global instance.
+    """
+
+    def __init__(self) -> None:
+        # we have to use self.__dict__ because we override setattr.
+        self.__dict__["_options"] = {}
+        self.__dict__["_parse_callbacks"] = []
+        self.define(
+            "help",
+            type=bool,
+            help="show this help information",
+            callback=self._help_callback,
+        )
+
+    def _normalize_name(self, name: str) -> str:
+        return name.replace("_", "-")
+
+    def __getattr__(self, name: str) -> Any:
+        name = self._normalize_name(name)
+        if isinstance(self._options.get(name), _Option):
+            return self._options[name].value()
+        raise AttributeError("Unrecognized option %r" % name)
+
+    def __setattr__(self, name: str, value: Any) -> None:
+        name = self._normalize_name(name)
+        if isinstance(self._options.get(name), _Option):
+            return self._options[name].set(value)
+        raise AttributeError("Unrecognized option %r" % name)
+
+    def __iter__(self) -> Iterator:
+        return (opt.name for opt in self._options.values())
+
+    def __contains__(self, name: str) -> bool:
+        name = self._normalize_name(name)
+        return name in self._options
+
+    def __getitem__(self, name: str) -> Any:
+        return self.__getattr__(name)
+
+    def __setitem__(self, name: str, value: Any) -> None:
+        return self.__setattr__(name, value)
+
+    def items(self) -> Iterable[Tuple[str, Any]]:
+        """An iterable of (name, value) pairs.
+
+        .. 
versionadded:: 3.1 + """ + return [(opt.name, opt.value()) for name, opt in self._options.items()] + + def groups(self) -> Set[str]: + """The set of option-groups created by ``define``. + + .. versionadded:: 3.1 + """ + return set(opt.group_name for opt in self._options.values()) + + def group_dict(self, group: str) -> Dict[str, Any]: + """The names and values of options in a group. + + Useful for copying options into Application settings:: + + from tornado.options import define, parse_command_line, options + + define('template_path', group='application') + define('static_path', group='application') + + parse_command_line() + + application = Application( + handlers, **options.group_dict('application')) + + .. versionadded:: 3.1 + """ + return dict( + (opt.name, opt.value()) + for name, opt in self._options.items() + if not group or group == opt.group_name + ) + + def as_dict(self) -> Dict[str, Any]: + """The names and values of all options. + + .. versionadded:: 3.1 + """ + return dict((opt.name, opt.value()) for name, opt in self._options.items()) + + def define( + self, + name: str, + default: Any = None, + type: Optional[type] = None, + help: Optional[str] = None, + metavar: Optional[str] = None, + multiple: bool = False, + group: Optional[str] = None, + callback: Optional[Callable[[Any], None]] = None, + ) -> None: + """Defines a new command line option. + + ``type`` can be any of `str`, `int`, `float`, `bool`, + `~datetime.datetime`, or `~datetime.timedelta`. If no ``type`` + is given but a ``default`` is, ``type`` is the type of + ``default``. Otherwise, ``type`` defaults to `str`. + + If ``multiple`` is True, the option value is a list of ``type`` + instead of an instance of ``type``. + + ``help`` and ``metavar`` are used to construct the + automatically generated command line help string. The help + message is formatted like:: + + --name=METAVAR help string + + ``group`` is used to group the defined options in logical + groups. By default, command line options are grouped by the + file in which they are defined. + + Command line option names must be unique globally. + + If a ``callback`` is given, it will be run with the new value whenever + the option is changed. This can be used to combine command-line + and file-based options:: + + define("config", type=str, help="path to config file", + callback=lambda path: parse_config_file(path, final=False)) + + With this definition, options in the file specified by ``--config`` will + override options set earlier on the command line, but can be overridden + by later flags. + + """ + normalized = self._normalize_name(name) + if normalized in self._options: + raise Error( + "Option %r already defined in %s" + % (normalized, self._options[normalized].file_name) + ) + frame = sys._getframe(0) + if frame is not None: + options_file = frame.f_code.co_filename + + # Can be called directly, or through top level define() fn, in which + # case, step up above that frame to look for real caller. 
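+            # (sys._getframe(0) is this ``define`` call's own frame, so the
+            # module that actually defined the option is one or two frames up.)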
+ if ( + frame.f_back is not None + and frame.f_back.f_code.co_filename == options_file + and frame.f_back.f_code.co_name == "define" + ): + frame = frame.f_back + + assert frame.f_back is not None + file_name = frame.f_back.f_code.co_filename + else: + file_name = "<unknown>" + if file_name == options_file: + file_name = "" + if type is None: + if not multiple and default is not None: + type = default.__class__ + else: + type = str + if group: + group_name = group # type: Optional[str] + else: + group_name = file_name + option = _Option( + name, + file_name=file_name, + default=default, + type=type, + help=help, + metavar=metavar, + multiple=multiple, + group_name=group_name, + callback=callback, + ) + self._options[normalized] = option + + def parse_command_line( + self, args: Optional[List[str]] = None, final: bool = True + ) -> List[str]: + """Parses all options given on the command line (defaults to + `sys.argv`). + + Options look like ``--option=value`` and are parsed according + to their ``type``. For boolean options, ``--option`` is + equivalent to ``--option=true`` + + If the option has ``multiple=True``, comma-separated values + are accepted. For multi-value integer options, the syntax + ``x:y`` is also accepted and equivalent to ``range(x, y)``. + + Note that ``args[0]`` is ignored since it is the program name + in `sys.argv`. + + We return a list of all arguments that are not parsed as options. + + If ``final`` is ``False``, parse callbacks will not be run. + This is useful for applications that wish to combine configurations + from multiple sources. + + """ + if args is None: + args = sys.argv + remaining = [] # type: List[str] + for i in range(1, len(args)): + # All things after the last option are command line arguments + if not args[i].startswith("-"): + remaining = args[i:] + break + if args[i] == "--": + remaining = args[i + 1 :] + break + arg = args[i].lstrip("-") + name, equals, value = arg.partition("=") + name = self._normalize_name(name) + if name not in self._options: + self.print_help() + raise Error("Unrecognized command line option: %r" % name) + option = self._options[name] + if not equals: + if option.type == bool: + value = "true" + else: + raise Error("Option %r requires a value" % name) + option.parse(value) + + if final: + self.run_parse_callbacks() + + return remaining + + def parse_config_file(self, path: str, final: bool = True) -> None: + """Parses and loads the config file at the given path. + + The config file contains Python code that will be executed (so + it is **not safe** to use untrusted config files). Anything in + the global namespace that matches a defined option will be + used to set that option's value. + + Options may either be the specified type for the option or + strings (in which case they will be parsed the same way as in + `.parse_command_line`) + + Example (using the options defined in the top-level docs of + this module):: + + port = 80 + mysql_host = 'mydb.example.com:3306' + # Both lists and comma-separated strings are allowed for + # multiple=True. + memcache_hosts = ['cache1.example.com:11011', + 'cache2.example.com:11011'] + memcache_hosts = 'cache1.example.com:11011,cache2.example.com:11011' + + If ``final`` is ``False``, parse callbacks will not be run. + This is useful for applications that wish to combine configurations + from multiple sources. + + .. note:: + + `tornado.options` is primarily a command-line library. 
+ Config file support is provided for applications that wish + to use it, but applications that prefer config files may + wish to look at other libraries instead. + + .. versionchanged:: 4.1 + Config files are now always interpreted as utf-8 instead of + the system default encoding. + + .. versionchanged:: 4.4 + The special variable ``__file__`` is available inside config + files, specifying the absolute path to the config file itself. + + .. versionchanged:: 5.1 + Added the ability to set options via strings in config files. + + """ + config = {"__file__": os.path.abspath(path)} + with open(path, "rb") as f: + exec_in(native_str(f.read()), config, config) + for name in config: + normalized = self._normalize_name(name) + if normalized in self._options: + option = self._options[normalized] + if option.multiple: + if not isinstance(config[name], (list, str)): + raise Error( + "Option %r is required to be a list of %s " + "or a comma-separated string" + % (option.name, option.type.__name__) + ) + + if type(config[name]) == str and ( + option.type != str or option.multiple + ): + option.parse(config[name]) + else: + option.set(config[name]) + + if final: + self.run_parse_callbacks() + + def print_help(self, file: Optional[TextIO] = None) -> None: + """Prints all the command line options to stderr (or another file).""" + if file is None: + file = sys.stderr + print("Usage: %s [OPTIONS]" % sys.argv[0], file=file) + print("\nOptions:\n", file=file) + by_group = {} # type: Dict[str, List[_Option]] + for option in self._options.values(): + by_group.setdefault(option.group_name, []).append(option) + + for filename, o in sorted(by_group.items()): + if filename: + print("\n%s options:\n" % os.path.normpath(filename), file=file) + o.sort(key=lambda option: option.name) + for option in o: + # Always print names with dashes in a CLI context. + prefix = self._normalize_name(option.name) + if option.metavar: + prefix += "=" + option.metavar + description = option.help or "" + if option.default is not None and option.default != "": + description += " (default %s)" % option.default + lines = textwrap.wrap(description, 79 - 35) + if len(prefix) > 30 or len(lines) == 0: + lines.insert(0, "") + print(" --%-30s %s" % (prefix, lines[0]), file=file) + for line in lines[1:]: + print("%-34s %s" % (" ", line), file=file) + print(file=file) + + def _help_callback(self, value: bool) -> None: + if value: + self.print_help() + sys.exit(0) + + def add_parse_callback(self, callback: Callable[[], None]) -> None: + """Adds a parse callback, to be invoked when option parsing is done.""" + self._parse_callbacks.append(callback) + + def run_parse_callbacks(self) -> None: + for callback in self._parse_callbacks: + callback() + + def mockable(self) -> "_Mockable": + """Returns a wrapper around self that is compatible with + `mock.patch <unittest.mock.patch>`. + + The `mock.patch <unittest.mock.patch>` function (included in + the standard library `unittest.mock` package since Python 3.3, + or in the third-party ``mock`` package for older versions of + Python) is incompatible with objects like ``options`` that + override ``__getattr__`` and ``__setattr__``. This function + returns an object that can be used with `mock.patch.object + <unittest.mock.patch.object>` to modify option values:: + + with mock.patch.object(options.mockable(), 'name', value): + assert options.name == value + """ + return _Mockable(self) + + +class _Mockable(object): + """`mock.patch` compatible wrapper for `OptionParser`. 
+ + As of ``mock`` version 1.0.1, when an object uses ``__getattr__`` + hooks instead of ``__dict__``, ``patch.__exit__`` tries to delete + the attribute it set instead of setting a new one (assuming that + the object does not capture ``__setattr__``, so the patch + created a new attribute in ``__dict__``). + + _Mockable's getattr and setattr pass through to the underlying + OptionParser, and delattr undoes the effect of a previous setattr. + """ + + def __init__(self, options: OptionParser) -> None: + # Modify __dict__ directly to bypass __setattr__ + self.__dict__["_options"] = options + self.__dict__["_originals"] = {} + + def __getattr__(self, name: str) -> Any: + return getattr(self._options, name) + + def __setattr__(self, name: str, value: Any) -> None: + assert name not in self._originals, "don't reuse mockable objects" + self._originals[name] = getattr(self._options, name) + setattr(self._options, name, value) + + def __delattr__(self, name: str) -> None: + setattr(self._options, name, self._originals.pop(name)) + + +class _Option(object): + # This class could almost be made generic, but the way the types + # interact with the multiple argument makes this tricky. (default + # and the callback use List[T], but type is still Type[T]). + UNSET = object() + + def __init__( + self, + name: str, + default: Any = None, + type: Optional[type] = None, + help: Optional[str] = None, + metavar: Optional[str] = None, + multiple: bool = False, + file_name: Optional[str] = None, + group_name: Optional[str] = None, + callback: Optional[Callable[[Any], None]] = None, + ) -> None: + if default is None and multiple: + default = [] + self.name = name + if type is None: + raise ValueError("type must not be None") + self.type = type + self.help = help + self.metavar = metavar + self.multiple = multiple + self.file_name = file_name + self.group_name = group_name + self.callback = callback + self.default = default + self._value = _Option.UNSET # type: Any + + def value(self) -> Any: + return self.default if self._value is _Option.UNSET else self._value + + def parse(self, value: str) -> Any: + _parse = { + datetime.datetime: self._parse_datetime, + datetime.timedelta: self._parse_timedelta, + bool: self._parse_bool, + basestring_type: self._parse_string, + }.get( + self.type, self.type + ) # type: Callable[[str], Any] + if self.multiple: + self._value = [] + for part in value.split(","): + if issubclass(self.type, numbers.Integral): + # allow ranges of the form X:Y (inclusive at both ends) + lo_str, _, hi_str = part.partition(":") + lo = _parse(lo_str) + hi = _parse(hi_str) if hi_str else lo + self._value.extend(range(lo, hi + 1)) + else: + self._value.append(_parse(part)) + else: + self._value = _parse(value) + if self.callback is not None: + self.callback(self._value) + return self.value() + + def set(self, value: Any) -> None: + if self.multiple: + if not isinstance(value, list): + raise Error( + "Option %r is required to be a list of %s" + % (self.name, self.type.__name__) + ) + for item in value: + if item is not None and not isinstance(item, self.type): + raise Error( + "Option %r is required to be a list of %s" + % (self.name, self.type.__name__) + ) + else: + if value is not None and not isinstance(value, self.type): + raise Error( + "Option %r is required to be a %s (%s given)" + % (self.name, self.type.__name__, type(value)) + ) + self._value = value + if self.callback is not None: + self.callback(self._value) + + # Supported date/time formats in our options + _DATETIME_FORMATS = [ + "%a %b %d 
%H:%M:%S %Y", + "%Y-%m-%d %H:%M:%S", + "%Y-%m-%d %H:%M", + "%Y-%m-%dT%H:%M", + "%Y%m%d %H:%M:%S", + "%Y%m%d %H:%M", + "%Y-%m-%d", + "%Y%m%d", + "%H:%M:%S", + "%H:%M", + ] + + def _parse_datetime(self, value: str) -> datetime.datetime: + for format in self._DATETIME_FORMATS: + try: + return datetime.datetime.strptime(value, format) + except ValueError: + pass + raise Error("Unrecognized date/time format: %r" % value) + + _TIMEDELTA_ABBREV_DICT = { + "h": "hours", + "m": "minutes", + "min": "minutes", + "s": "seconds", + "sec": "seconds", + "ms": "milliseconds", + "us": "microseconds", + "d": "days", + "w": "weeks", + } + + _FLOAT_PATTERN = r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?" + + _TIMEDELTA_PATTERN = re.compile( + r"\s*(%s)\s*(\w*)\s*" % _FLOAT_PATTERN, re.IGNORECASE + ) + + def _parse_timedelta(self, value: str) -> datetime.timedelta: + try: + sum = datetime.timedelta() + start = 0 + while start < len(value): + m = self._TIMEDELTA_PATTERN.match(value, start) + if not m: + raise Exception() + num = float(m.group(1)) + units = m.group(2) or "seconds" + units = self._TIMEDELTA_ABBREV_DICT.get(units, units) + # This line confuses mypy when setup.py sets python_version=3.6 + # https://github.com/python/mypy/issues/9676 + sum += datetime.timedelta(**{units: num}) # type: ignore + start = m.end() + return sum + except Exception: + raise + + def _parse_bool(self, value: str) -> bool: + return value.lower() not in ("false", "0", "f") + + def _parse_string(self, value: str) -> str: + return _unicode(value) + + +options = OptionParser() +"""Global options object. + +All defined options are available as attributes on this object. +""" + + +def define( + name: str, + default: Any = None, + type: Optional[type] = None, + help: Optional[str] = None, + metavar: Optional[str] = None, + multiple: bool = False, + group: Optional[str] = None, + callback: Optional[Callable[[Any], None]] = None, +) -> None: + """Defines an option in the global namespace. + + See `OptionParser.define`. + """ + return options.define( + name, + default=default, + type=type, + help=help, + metavar=metavar, + multiple=multiple, + group=group, + callback=callback, + ) + + +def parse_command_line( + args: Optional[List[str]] = None, final: bool = True +) -> List[str]: + """Parses global options from the command line. + + See `OptionParser.parse_command_line`. + """ + return options.parse_command_line(args, final=final) + + +def parse_config_file(path: str, final: bool = True) -> None: + """Parses global options from a config file. + + See `OptionParser.parse_config_file`. + """ + return options.parse_config_file(path, final=final) + + +def print_help(file: Optional[TextIO] = None) -> None: + """Prints all the command line options to stderr (or another file). + + See `OptionParser.print_help`. + """ + return options.print_help(file) + + +def add_parse_callback(callback: Callable[[], None]) -> None: + """Adds a parse callback, to be invoked when option parsing is done. 
+ + See `OptionParser.add_parse_callback` + """ + options.add_parse_callback(callback) + + +# Default options +define_logging_options(options) diff --git a/venv/lib/python3.9/site-packages/tornado/platform/__init__.py b/venv/lib/python3.9/site-packages/tornado/platform/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/platform/__init__.py diff --git a/venv/lib/python3.9/site-packages/tornado/platform/asyncio.py b/venv/lib/python3.9/site-packages/tornado/platform/asyncio.py new file mode 100644 index 00000000..a15a74df --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/platform/asyncio.py @@ -0,0 +1,662 @@ +"""Bridges between the `asyncio` module and Tornado IOLoop. + +.. versionadded:: 3.2 + +This module integrates Tornado with the ``asyncio`` module introduced +in Python 3.4. This makes it possible to combine the two libraries on +the same event loop. + +.. deprecated:: 5.0 + + While the code in this module is still used, it is now enabled + automatically when `asyncio` is available, so applications should + no longer need to refer to this module directly. + +.. note:: + + Tornado is designed to use a selector-based event loop. On Windows, + where a proactor-based event loop has been the default since Python 3.8, + a selector event loop is emulated by running ``select`` on a separate thread. + Configuring ``asyncio`` to use a selector event loop may improve performance + of Tornado (but may reduce performance of other ``asyncio``-based libraries + in the same process). +""" + +import asyncio +import atexit +import concurrent.futures +import errno +import functools +import select +import socket +import sys +import threading +import typing +import warnings +from tornado.gen import convert_yielded +from tornado.ioloop import IOLoop, _Selectable + +from typing import Any, TypeVar, Awaitable, Callable, Union, Optional, List, Dict + +if typing.TYPE_CHECKING: + from typing import Set, Tuple # noqa: F401 + from typing_extensions import Protocol + + class _HasFileno(Protocol): + def fileno(self) -> int: + pass + + _FileDescriptorLike = Union[int, _HasFileno] + +_T = TypeVar("_T") + + +# Collection of selector thread event loops to shut down on exit. +_selector_loops = set() # type: Set[AddThreadSelectorEventLoop] + + +def _atexit_callback() -> None: + for loop in _selector_loops: + with loop._select_cond: + loop._closing_selector = True + loop._select_cond.notify() + try: + loop._waker_w.send(b"a") + except BlockingIOError: + pass + # If we don't join our (daemon) thread here, we may get a deadlock + # during interpreter shutdown. I don't really understand why. This + # deadlock happens every time in CI (both travis and appveyor) but + # I've never been able to reproduce locally. + loop._thread.join() + _selector_loops.clear() + + +atexit.register(_atexit_callback) + + +class BaseAsyncIOLoop(IOLoop): + def initialize( # type: ignore + self, asyncio_loop: asyncio.AbstractEventLoop, **kwargs: Any + ) -> None: + # asyncio_loop is always the real underlying IOLoop. This is used in + # ioloop.py to maintain the asyncio-to-ioloop mappings. + self.asyncio_loop = asyncio_loop + # selector_loop is an event loop that implements the add_reader family of + # methods. Usually the same as asyncio_loop but differs on platforms such + # as windows where the default event loop does not implement these methods. 
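+        # (AddThreadSelectorEventLoop, used below, emulates those methods by
+        # running select() on a helper thread and dispatching the callbacks
+        # back onto this event loop's thread.)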
+ self.selector_loop = asyncio_loop + if hasattr(asyncio, "ProactorEventLoop") and isinstance( + asyncio_loop, asyncio.ProactorEventLoop # type: ignore + ): + # Ignore this line for mypy because the abstract method checker + # doesn't understand dynamic proxies. + self.selector_loop = AddThreadSelectorEventLoop(asyncio_loop) # type: ignore + # Maps fd to (fileobj, handler function) pair (as in IOLoop.add_handler) + self.handlers = {} # type: Dict[int, Tuple[Union[int, _Selectable], Callable]] + # Set of fds listening for reads/writes + self.readers = set() # type: Set[int] + self.writers = set() # type: Set[int] + self.closing = False + # If an asyncio loop was closed through an asyncio interface + # instead of IOLoop.close(), we'd never hear about it and may + # have left a dangling reference in our map. In case an + # application (or, more likely, a test suite) creates and + # destroys a lot of event loops in this way, check here to + # ensure that we don't have a lot of dead loops building up in + # the map. + # + # TODO(bdarnell): consider making self.asyncio_loop a weakref + # for AsyncIOMainLoop and make _ioloop_for_asyncio a + # WeakKeyDictionary. + for loop in IOLoop._ioloop_for_asyncio.copy(): + if loop.is_closed(): + try: + del IOLoop._ioloop_for_asyncio[loop] + except KeyError: + pass + + # Make sure we don't already have an IOLoop for this asyncio loop + existing_loop = IOLoop._ioloop_for_asyncio.setdefault(asyncio_loop, self) + if existing_loop is not self: + raise RuntimeError( + f"IOLoop {existing_loop} already associated with asyncio loop {asyncio_loop}" + ) + + super().initialize(**kwargs) + + def close(self, all_fds: bool = False) -> None: + self.closing = True + for fd in list(self.handlers): + fileobj, handler_func = self.handlers[fd] + self.remove_handler(fd) + if all_fds: + self.close_fd(fileobj) + # Remove the mapping before closing the asyncio loop. If this + # happened in the other order, we could race against another + # initialize() call which would see the closed asyncio loop, + # assume it was closed from the asyncio side, and do this + # cleanup for us, leading to a KeyError. 
+ del IOLoop._ioloop_for_asyncio[self.asyncio_loop] + if self.selector_loop is not self.asyncio_loop: + self.selector_loop.close() + self.asyncio_loop.close() + + def add_handler( + self, fd: Union[int, _Selectable], handler: Callable[..., None], events: int + ) -> None: + fd, fileobj = self.split_fd(fd) + if fd in self.handlers: + raise ValueError("fd %s added twice" % fd) + self.handlers[fd] = (fileobj, handler) + if events & IOLoop.READ: + self.selector_loop.add_reader(fd, self._handle_events, fd, IOLoop.READ) + self.readers.add(fd) + if events & IOLoop.WRITE: + self.selector_loop.add_writer(fd, self._handle_events, fd, IOLoop.WRITE) + self.writers.add(fd) + + def update_handler(self, fd: Union[int, _Selectable], events: int) -> None: + fd, fileobj = self.split_fd(fd) + if events & IOLoop.READ: + if fd not in self.readers: + self.selector_loop.add_reader(fd, self._handle_events, fd, IOLoop.READ) + self.readers.add(fd) + else: + if fd in self.readers: + self.selector_loop.remove_reader(fd) + self.readers.remove(fd) + if events & IOLoop.WRITE: + if fd not in self.writers: + self.selector_loop.add_writer(fd, self._handle_events, fd, IOLoop.WRITE) + self.writers.add(fd) + else: + if fd in self.writers: + self.selector_loop.remove_writer(fd) + self.writers.remove(fd) + + def remove_handler(self, fd: Union[int, _Selectable]) -> None: + fd, fileobj = self.split_fd(fd) + if fd not in self.handlers: + return + if fd in self.readers: + self.selector_loop.remove_reader(fd) + self.readers.remove(fd) + if fd in self.writers: + self.selector_loop.remove_writer(fd) + self.writers.remove(fd) + del self.handlers[fd] + + def _handle_events(self, fd: int, events: int) -> None: + fileobj, handler_func = self.handlers[fd] + handler_func(fileobj, events) + + def start(self) -> None: + self.asyncio_loop.run_forever() + + def stop(self) -> None: + self.asyncio_loop.stop() + + def call_at( + self, when: float, callback: Callable, *args: Any, **kwargs: Any + ) -> object: + # asyncio.call_at supports *args but not **kwargs, so bind them here. + # We do not synchronize self.time and asyncio_loop.time, so + # convert from absolute to relative. + return self.asyncio_loop.call_later( + max(0, when - self.time()), + self._run_callback, + functools.partial(callback, *args, **kwargs), + ) + + def remove_timeout(self, timeout: object) -> None: + timeout.cancel() # type: ignore + + def add_callback(self, callback: Callable, *args: Any, **kwargs: Any) -> None: + try: + if asyncio.get_running_loop() is self.asyncio_loop: + call_soon = self.asyncio_loop.call_soon + else: + call_soon = self.asyncio_loop.call_soon_threadsafe + except RuntimeError: + call_soon = self.asyncio_loop.call_soon_threadsafe + + try: + call_soon(self._run_callback, functools.partial(callback, *args, **kwargs)) + except RuntimeError: + # "Event loop is closed". Swallow the exception for + # consistency with PollIOLoop (and logical consistency + # with the fact that we can't guarantee that an + # add_callback that completes without error will + # eventually execute). + pass + except AttributeError: + # ProactorEventLoop may raise this instead of RuntimeError + # if call_soon_threadsafe races with a call to close(). + # Swallow it too for consistency. 
+ pass + + def add_callback_from_signal( + self, callback: Callable, *args: Any, **kwargs: Any + ) -> None: + try: + self.asyncio_loop.call_soon_threadsafe( + self._run_callback, functools.partial(callback, *args, **kwargs) + ) + except RuntimeError: + pass + + def run_in_executor( + self, + executor: Optional[concurrent.futures.Executor], + func: Callable[..., _T], + *args: Any, + ) -> Awaitable[_T]: + return self.asyncio_loop.run_in_executor(executor, func, *args) + + def set_default_executor(self, executor: concurrent.futures.Executor) -> None: + return self.asyncio_loop.set_default_executor(executor) + + +class AsyncIOMainLoop(BaseAsyncIOLoop): + """``AsyncIOMainLoop`` creates an `.IOLoop` that corresponds to the + current ``asyncio`` event loop (i.e. the one returned by + ``asyncio.get_event_loop()``). + + .. deprecated:: 5.0 + + Now used automatically when appropriate; it is no longer necessary + to refer to this class directly. + + .. versionchanged:: 5.0 + + Closing an `AsyncIOMainLoop` now closes the underlying asyncio loop. + """ + + def initialize(self, **kwargs: Any) -> None: # type: ignore + super().initialize(asyncio.get_event_loop(), **kwargs) + + def _make_current(self) -> None: + # AsyncIOMainLoop already refers to the current asyncio loop so + # nothing to do here. + pass + + +class AsyncIOLoop(BaseAsyncIOLoop): + """``AsyncIOLoop`` is an `.IOLoop` that runs on an ``asyncio`` event loop. + This class follows the usual Tornado semantics for creating new + ``IOLoops``; these loops are not necessarily related to the + ``asyncio`` default event loop. + + Each ``AsyncIOLoop`` creates a new ``asyncio.EventLoop``; this object + can be accessed with the ``asyncio_loop`` attribute. + + .. versionchanged:: 6.2 + + Support explicit ``asyncio_loop`` argument + for specifying the asyncio loop to attach to, + rather than always creating a new one with the default policy. + + .. versionchanged:: 5.0 + + When an ``AsyncIOLoop`` becomes the current `.IOLoop`, it also sets + the current `asyncio` event loop. + + .. deprecated:: 5.0 + + Now used automatically when appropriate; it is no longer necessary + to refer to this class directly. + """ + + def initialize(self, **kwargs: Any) -> None: # type: ignore + self.is_current = False + loop = None + if "asyncio_loop" not in kwargs: + kwargs["asyncio_loop"] = loop = asyncio.new_event_loop() + try: + super().initialize(**kwargs) + except Exception: + # If initialize() does not succeed (taking ownership of the loop), + # we have to close it. + if loop is not None: + loop.close() + raise + + def close(self, all_fds: bool = False) -> None: + if self.is_current: + self._clear_current() + super().close(all_fds=all_fds) + + def _make_current(self) -> None: + if not self.is_current: + try: + self.old_asyncio = asyncio.get_event_loop() + except (RuntimeError, AssertionError): + self.old_asyncio = None # type: ignore + self.is_current = True + asyncio.set_event_loop(self.asyncio_loop) + + def _clear_current_hook(self) -> None: + if self.is_current: + asyncio.set_event_loop(self.old_asyncio) + self.is_current = False + + +def to_tornado_future(asyncio_future: asyncio.Future) -> asyncio.Future: + """Convert an `asyncio.Future` to a `tornado.concurrent.Future`. + + .. versionadded:: 4.1 + + .. deprecated:: 5.0 + Tornado ``Futures`` have been merged with `asyncio.Future`, + so this method is now a no-op. 
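+
+    Since the merge made this function a no-op, the input future is
+    returned unchanged (illustrative sketch)::
+
+        fut = asyncio.Future()
+        assert to_tornado_future(fut) is fut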
+ """ + return asyncio_future + + +def to_asyncio_future(tornado_future: asyncio.Future) -> asyncio.Future: + """Convert a Tornado yieldable object to an `asyncio.Future`. + + .. versionadded:: 4.1 + + .. versionchanged:: 4.3 + Now accepts any yieldable object, not just + `tornado.concurrent.Future`. + + .. deprecated:: 5.0 + Tornado ``Futures`` have been merged with `asyncio.Future`, + so this method is now equivalent to `tornado.gen.convert_yielded`. + """ + return convert_yielded(tornado_future) + + +if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"): + # "Any thread" and "selector" should be orthogonal, but there's not a clean + # interface for composing policies so pick the right base. + _BasePolicy = asyncio.WindowsSelectorEventLoopPolicy # type: ignore +else: + _BasePolicy = asyncio.DefaultEventLoopPolicy + + +class AnyThreadEventLoopPolicy(_BasePolicy): # type: ignore + """Event loop policy that allows loop creation on any thread. + + The default `asyncio` event loop policy only automatically creates + event loops in the main threads. Other threads must create event + loops explicitly or `asyncio.get_event_loop` (and therefore + `.IOLoop.current`) will fail. Installing this policy allows event + loops to be created automatically on any thread, matching the + behavior of Tornado versions prior to 5.0 (or 5.0 on Python 2). + + Usage:: + + asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy()) + + .. versionadded:: 5.0 + + .. deprecated:: 6.2 + + ``AnyThreadEventLoopPolicy`` affects the implicit creation + of an event loop, which is deprecated in Python 3.10 and + will be removed in a future version of Python. At that time + ``AnyThreadEventLoopPolicy`` will no longer be useful. + If you are relying on it, use `asyncio.new_event_loop` + or `asyncio.run` explicitly in any non-main threads that + need event loops. + """ + + def __init__(self) -> None: + super().__init__() + warnings.warn( + "AnyThreadEventLoopPolicy is deprecated, use asyncio.run " + "or asyncio.new_event_loop instead", + DeprecationWarning, + stacklevel=2, + ) + + def get_event_loop(self) -> asyncio.AbstractEventLoop: + try: + return super().get_event_loop() + except (RuntimeError, AssertionError): + # This was an AssertionError in Python 3.4.2 (which ships with Debian Jessie) + # and changed to a RuntimeError in 3.4.3. + # "There is no current event loop in thread %r" + loop = self.new_event_loop() + self.set_event_loop(loop) + return loop + + +class AddThreadSelectorEventLoop(asyncio.AbstractEventLoop): + """Wrap an event loop to add implementations of the ``add_reader`` method family. + + Instances of this class start a second thread to run a selector. + This thread is completely hidden from the user; all callbacks are + run on the wrapped event loop's thread. + + This class is used automatically by Tornado; applications should not need + to refer to it directly. + + It is safe to wrap any event loop with this class, although it only makes sense + for event loops that do not implement the ``add_reader`` family of methods + themselves (i.e. ``WindowsProactorEventLoop``) + + Closing the ``AddThreadSelectorEventLoop`` also closes the wrapped event loop. + + """ + + # This class is a __getattribute__-based proxy. All attributes other than those + # in this set are proxied through to the underlying loop. 
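+    # For example, self.add_reader resolves to the implementation defined
+    # below, while self.call_soon falls through to the wrapped loop.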
+ MY_ATTRIBUTES = { + "_consume_waker", + "_select_cond", + "_select_args", + "_closing_selector", + "_thread", + "_handle_event", + "_readers", + "_real_loop", + "_start_select", + "_run_select", + "_handle_select", + "_wake_selector", + "_waker_r", + "_waker_w", + "_writers", + "add_reader", + "add_writer", + "close", + "remove_reader", + "remove_writer", + } + + def __getattribute__(self, name: str) -> Any: + if name in AddThreadSelectorEventLoop.MY_ATTRIBUTES: + return super().__getattribute__(name) + return getattr(self._real_loop, name) + + def __init__(self, real_loop: asyncio.AbstractEventLoop) -> None: + self._real_loop = real_loop + + # Create a thread to run the select system call. We manage this thread + # manually so we can trigger a clean shutdown from an atexit hook. Note + # that due to the order of operations at shutdown, only daemon threads + # can be shut down in this way (non-daemon threads would require the + # introduction of a new hook: https://bugs.python.org/issue41962) + self._select_cond = threading.Condition() + self._select_args = ( + None + ) # type: Optional[Tuple[List[_FileDescriptorLike], List[_FileDescriptorLike]]] + self._closing_selector = False + self._thread = threading.Thread( + name="Tornado selector", + daemon=True, + target=self._run_select, + ) + self._thread.start() + # Start the select loop once the loop is started. + self._real_loop.call_soon(self._start_select) + + self._readers = {} # type: Dict[_FileDescriptorLike, Callable] + self._writers = {} # type: Dict[_FileDescriptorLike, Callable] + + # Writing to _waker_w will wake up the selector thread, which + # watches for _waker_r to be readable. + self._waker_r, self._waker_w = socket.socketpair() + self._waker_r.setblocking(False) + self._waker_w.setblocking(False) + _selector_loops.add(self) + self.add_reader(self._waker_r, self._consume_waker) + + def __del__(self) -> None: + # If the top-level application code uses asyncio interfaces to + # start and stop the event loop, no objects created in Tornado + # can get a clean shutdown notification. If we're just left to + # be GC'd, we must explicitly close our sockets to avoid + # logging warnings. + _selector_loops.discard(self) + self._waker_r.close() + self._waker_w.close() + + def close(self) -> None: + with self._select_cond: + self._closing_selector = True + self._select_cond.notify() + self._wake_selector() + self._thread.join() + _selector_loops.discard(self) + self._waker_r.close() + self._waker_w.close() + self._real_loop.close() + + def _wake_selector(self) -> None: + try: + self._waker_w.send(b"a") + except BlockingIOError: + pass + + def _consume_waker(self) -> None: + try: + self._waker_r.recv(1024) + except BlockingIOError: + pass + + def _start_select(self) -> None: + # Capture reader and writer sets here in the event loop + # thread to avoid any problems with concurrent + # modification while the select loop uses them. 
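+        # This runs on the event loop thread; the selector thread blocks
+        # in _run_select until _select_args is published under the
+        # condition variable.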
+ with self._select_cond: + assert self._select_args is None + self._select_args = (list(self._readers.keys()), list(self._writers.keys())) + self._select_cond.notify() + + def _run_select(self) -> None: + while True: + with self._select_cond: + while self._select_args is None and not self._closing_selector: + self._select_cond.wait() + if self._closing_selector: + return + assert self._select_args is not None + to_read, to_write = self._select_args + self._select_args = None + + # We use the simpler interface of the select module instead of + # the more stateful interface in the selectors module because + # this class is only intended for use on windows, where + # select.select is the only option. The selector interface + # does not have well-documented thread-safety semantics that + # we can rely on so ensuring proper synchronization would be + # tricky. + try: + # On windows, selecting on a socket for write will not + # return the socket when there is an error (but selecting + # for reads works). Also select for errors when selecting + # for writes, and merge the results. + # + # This pattern is also used in + # https://github.com/python/cpython/blob/v3.8.0/Lib/selectors.py#L312-L317 + rs, ws, xs = select.select(to_read, to_write, to_write) + ws = ws + xs + except OSError as e: + # After remove_reader or remove_writer is called, the file + # descriptor may subsequently be closed on the event loop + # thread. It's possible that this select thread hasn't + # gotten into the select system call by the time that + # happens in which case (at least on macOS), select may + # raise a "bad file descriptor" error. If we get that + # error, check and see if we're also being woken up by + # polling the waker alone. If we are, just return to the + # event loop and we'll get the updated set of file + # descriptors on the next iteration. Otherwise, raise the + # original error. + if e.errno == getattr(errno, "WSAENOTSOCK", errno.EBADF): + rs, _, _ = select.select([self._waker_r.fileno()], [], [], 0) + if rs: + ws = [] + else: + raise + else: + raise + + try: + self._real_loop.call_soon_threadsafe(self._handle_select, rs, ws) + except RuntimeError: + # "Event loop is closed". Swallow the exception for + # consistency with PollIOLoop (and logical consistency + # with the fact that we can't guarantee that an + # add_callback that completes without error will + # eventually execute). + pass + except AttributeError: + # ProactorEventLoop may raise this instead of RuntimeError + # if call_soon_threadsafe races with a call to close(). + # Swallow it too for consistency. 
+ pass + + def _handle_select( + self, rs: List["_FileDescriptorLike"], ws: List["_FileDescriptorLike"] + ) -> None: + for r in rs: + self._handle_event(r, self._readers) + for w in ws: + self._handle_event(w, self._writers) + self._start_select() + + def _handle_event( + self, + fd: "_FileDescriptorLike", + cb_map: Dict["_FileDescriptorLike", Callable], + ) -> None: + try: + callback = cb_map[fd] + except KeyError: + return + callback() + + def add_reader( + self, fd: "_FileDescriptorLike", callback: Callable[..., None], *args: Any + ) -> None: + self._readers[fd] = functools.partial(callback, *args) + self._wake_selector() + + def add_writer( + self, fd: "_FileDescriptorLike", callback: Callable[..., None], *args: Any + ) -> None: + self._writers[fd] = functools.partial(callback, *args) + self._wake_selector() + + def remove_reader(self, fd: "_FileDescriptorLike") -> bool: + try: + del self._readers[fd] + except KeyError: + return False + self._wake_selector() + return True + + def remove_writer(self, fd: "_FileDescriptorLike") -> bool: + try: + del self._writers[fd] + except KeyError: + return False + self._wake_selector() + return True diff --git a/venv/lib/python3.9/site-packages/tornado/platform/caresresolver.py b/venv/lib/python3.9/site-packages/tornado/platform/caresresolver.py new file mode 100644 index 00000000..1ba45c9a --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/platform/caresresolver.py @@ -0,0 +1,94 @@ +import pycares # type: ignore +import socket + +from tornado.concurrent import Future +from tornado import gen +from tornado.ioloop import IOLoop +from tornado.netutil import Resolver, is_valid_ip + +import typing + +if typing.TYPE_CHECKING: + from typing import Generator, Any, List, Tuple, Dict # noqa: F401 + + +class CaresResolver(Resolver): + """Name resolver based on the c-ares library. + + This is a non-blocking and non-threaded resolver. It may not produce the + same results as the system resolver, but can be used for non-blocking + resolution when threads cannot be used. + + ``pycares`` will not return a mix of ``AF_INET`` and ``AF_INET6`` when + ``family`` is ``AF_UNSPEC``, so it is only recommended for use in + ``AF_INET`` (i.e. IPv4). This is the default for + ``tornado.simple_httpclient``, but other libraries may default to + ``AF_UNSPEC``. + + .. versionchanged:: 5.0 + The ``io_loop`` argument (deprecated since version 4.1) has been removed. + + .. deprecated:: 6.2 + This class is deprecated and will be removed in Tornado 7.0. Use the default + thread-based resolver instead. 
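+
+    A typical way to select this resolver (sketch; requires the
+    ``pycares`` package)::
+
+        from tornado.netutil import Resolver
+        Resolver.configure('tornado.platform.caresresolver.CaresResolver')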
+ """ + + def initialize(self) -> None: + self.io_loop = IOLoop.current() + self.channel = pycares.Channel(sock_state_cb=self._sock_state_cb) + self.fds = {} # type: Dict[int, int] + + def _sock_state_cb(self, fd: int, readable: bool, writable: bool) -> None: + state = (IOLoop.READ if readable else 0) | (IOLoop.WRITE if writable else 0) + if not state: + self.io_loop.remove_handler(fd) + del self.fds[fd] + elif fd in self.fds: + self.io_loop.update_handler(fd, state) + self.fds[fd] = state + else: + self.io_loop.add_handler(fd, self._handle_events, state) + self.fds[fd] = state + + def _handle_events(self, fd: int, events: int) -> None: + read_fd = pycares.ARES_SOCKET_BAD + write_fd = pycares.ARES_SOCKET_BAD + if events & IOLoop.READ: + read_fd = fd + if events & IOLoop.WRITE: + write_fd = fd + self.channel.process_fd(read_fd, write_fd) + + @gen.coroutine + def resolve( + self, host: str, port: int, family: int = 0 + ) -> "Generator[Any, Any, List[Tuple[int, Any]]]": + if is_valid_ip(host): + addresses = [host] + else: + # gethostbyname doesn't take callback as a kwarg + fut = Future() # type: Future[Tuple[Any, Any]] + self.channel.gethostbyname( + host, family, lambda result, error: fut.set_result((result, error)) + ) + result, error = yield fut + if error: + raise IOError( + "C-Ares returned error %s: %s while resolving %s" + % (error, pycares.errno.strerror(error), host) + ) + addresses = result.addresses + addrinfo = [] + for address in addresses: + if "." in address: + address_family = socket.AF_INET + elif ":" in address: + address_family = socket.AF_INET6 + else: + address_family = socket.AF_UNSPEC + if family != socket.AF_UNSPEC and family != address_family: + raise IOError( + "Requested socket family %d but got %d" % (family, address_family) + ) + addrinfo.append((typing.cast(int, address_family), (address, port))) + return addrinfo diff --git a/venv/lib/python3.9/site-packages/tornado/platform/twisted.py b/venv/lib/python3.9/site-packages/tornado/platform/twisted.py new file mode 100644 index 00000000..153fe436 --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/platform/twisted.py @@ -0,0 +1,150 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +"""Bridges between the Twisted package and Tornado. +""" + +import socket +import sys + +import twisted.internet.abstract # type: ignore +import twisted.internet.asyncioreactor # type: ignore +from twisted.internet.defer import Deferred # type: ignore +from twisted.python import failure # type: ignore +import twisted.names.cache # type: ignore +import twisted.names.client # type: ignore +import twisted.names.hosts # type: ignore +import twisted.names.resolve # type: ignore + + +from tornado.concurrent import Future, future_set_exc_info +from tornado.escape import utf8 +from tornado import gen +from tornado.netutil import Resolver + +import typing + +if typing.TYPE_CHECKING: + from typing import Generator, Any, List, Tuple # noqa: F401 + + +class TwistedResolver(Resolver): + """Twisted-based asynchronous resolver. 
+ + This is a non-blocking and non-threaded resolver. It is + recommended only when threads cannot be used, since it has + limitations compared to the standard ``getaddrinfo``-based + `~tornado.netutil.Resolver` and + `~tornado.netutil.DefaultExecutorResolver`. Specifically, it returns at + most one result, and arguments other than ``host`` and ``family`` + are ignored. It may fail to resolve when ``family`` is not + ``socket.AF_UNSPEC``. + + Requires Twisted 12.1 or newer. + + .. versionchanged:: 5.0 + The ``io_loop`` argument (deprecated since version 4.1) has been removed. + + .. deprecated:: 6.2 + This class is deprecated and will be removed in Tornado 7.0. Use the default + thread-based resolver instead. + """ + + def initialize(self) -> None: + # partial copy of twisted.names.client.createResolver, which doesn't + # allow for a reactor to be passed in. + self.reactor = twisted.internet.asyncioreactor.AsyncioSelectorReactor() + + host_resolver = twisted.names.hosts.Resolver("/etc/hosts") + cache_resolver = twisted.names.cache.CacheResolver(reactor=self.reactor) + real_resolver = twisted.names.client.Resolver( + "/etc/resolv.conf", reactor=self.reactor + ) + self.resolver = twisted.names.resolve.ResolverChain( + [host_resolver, cache_resolver, real_resolver] + ) + + @gen.coroutine + def resolve( + self, host: str, port: int, family: int = 0 + ) -> "Generator[Any, Any, List[Tuple[int, Any]]]": + # getHostByName doesn't accept IP addresses, so if the input + # looks like an IP address just return it immediately. + if twisted.internet.abstract.isIPAddress(host): + resolved = host + resolved_family = socket.AF_INET + elif twisted.internet.abstract.isIPv6Address(host): + resolved = host + resolved_family = socket.AF_INET6 + else: + deferred = self.resolver.getHostByName(utf8(host)) + fut = Future() # type: Future[Any] + deferred.addBoth(fut.set_result) + resolved = yield fut + if isinstance(resolved, failure.Failure): + try: + resolved.raiseException() + except twisted.names.error.DomainError as e: + raise IOError(e) + elif twisted.internet.abstract.isIPAddress(resolved): + resolved_family = socket.AF_INET + elif twisted.internet.abstract.isIPv6Address(resolved): + resolved_family = socket.AF_INET6 + else: + resolved_family = socket.AF_UNSPEC + if family != socket.AF_UNSPEC and family != resolved_family: + raise Exception( + "Requested socket family %d but got %d" % (family, resolved_family) + ) + result = [(typing.cast(int, resolved_family), (resolved, port))] + return result + + +def install() -> None: + """Install ``AsyncioSelectorReactor`` as the default Twisted reactor. + + .. deprecated:: 5.1 + + This function is provided for backwards compatibility; code + that does not require compatibility with older versions of + Tornado should use + ``twisted.internet.asyncioreactor.install()`` directly. + + .. versionchanged:: 6.0.3 + + In Tornado 5.x and before, this function installed a reactor + based on the Tornado ``IOLoop``. When that reactor + implementation was removed in Tornado 6.0.0, this function was + removed as well. It was restored in Tornado 6.0.3 using the + ``asyncio`` reactor instead. 
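+
+    Typical call order (sketch; the reactor must be installed before
+    ``twisted.internet.reactor`` is first imported)::
+
+        import tornado.platform.twisted
+        tornado.platform.twisted.install()
+        from twisted.internet import reactor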
+ + """ + from twisted.internet.asyncioreactor import install + + install() + + +if hasattr(gen.convert_yielded, "register"): + + @gen.convert_yielded.register(Deferred) # type: ignore + def _(d: Deferred) -> Future: + f = Future() # type: Future[Any] + + def errback(failure: failure.Failure) -> None: + try: + failure.raiseException() + # Should never happen, but just in case + raise Exception("errback called without error") + except: + future_set_exc_info(f, sys.exc_info()) + + d.addCallbacks(f.set_result, errback) + return f diff --git a/venv/lib/python3.9/site-packages/tornado/process.py b/venv/lib/python3.9/site-packages/tornado/process.py new file mode 100644 index 00000000..26428feb --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/process.py @@ -0,0 +1,373 @@ +# +# Copyright 2011 Facebook +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Utilities for working with multiple processes, including both forking +the server into multiple processes and managing subprocesses. +""" + +import os +import multiprocessing +import signal +import subprocess +import sys +import time + +from binascii import hexlify + +from tornado.concurrent import ( + Future, + future_set_result_unless_cancelled, + future_set_exception_unless_cancelled, +) +from tornado import ioloop +from tornado.iostream import PipeIOStream +from tornado.log import gen_log + +import typing +from typing import Optional, Any, Callable + +if typing.TYPE_CHECKING: + from typing import List # noqa: F401 + +# Re-export this exception for convenience. +CalledProcessError = subprocess.CalledProcessError + + +def cpu_count() -> int: + """Returns the number of processors on this machine.""" + if multiprocessing is None: + return 1 + try: + return multiprocessing.cpu_count() + except NotImplementedError: + pass + try: + return os.sysconf("SC_NPROCESSORS_CONF") # type: ignore + except (AttributeError, ValueError): + pass + gen_log.error("Could not detect number of processors; assuming 1") + return 1 + + +def _reseed_random() -> None: + if "random" not in sys.modules: + return + import random + + # If os.urandom is available, this method does the same thing as + # random.seed (at least as of python 2.6). If os.urandom is not + # available, we mix in the pid in addition to a timestamp. + try: + seed = int(hexlify(os.urandom(16)), 16) + except NotImplementedError: + seed = int(time.time() * 1000) ^ os.getpid() + random.seed(seed) + + +_task_id = None + + +def fork_processes( + num_processes: Optional[int], max_restarts: Optional[int] = None +) -> int: + """Starts multiple worker processes. + + If ``num_processes`` is None or <= 0, we detect the number of cores + available on this machine and fork that number of child + processes. If ``num_processes`` is given and > 0, we fork that + specific number of sub-processes. + + Since we use processes and not threads, there is no shared memory + between any server code. 
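+
+    A typical multi-process server setup looks like this (sketch; ``app``
+    is assumed to be a `tornado.web.Application`)::
+
+        sockets = tornado.netutil.bind_sockets(8888)
+        tornado.process.fork_processes(0)
+        server = tornado.httpserver.HTTPServer(app)
+        server.add_sockets(sockets)
+        tornado.ioloop.IOLoop.current().start()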
+ + Note that multiple processes are not compatible with the autoreload + module (or the ``autoreload=True`` option to `tornado.web.Application` + which defaults to True when ``debug=True``). + When using multiple processes, no IOLoops can be created or + referenced until after the call to ``fork_processes``. + + In each child process, ``fork_processes`` returns its *task id*, a + number between 0 and ``num_processes``. Processes that exit + abnormally (due to a signal or non-zero exit status) are restarted + with the same id (up to ``max_restarts`` times). In the parent + process, ``fork_processes`` calls ``sys.exit(0)`` after all child + processes have exited normally. + + max_restarts defaults to 100. + + Availability: Unix + """ + if sys.platform == "win32": + # The exact form of this condition matters to mypy; it understands + # if but not assert in this context. + raise Exception("fork not available on windows") + if max_restarts is None: + max_restarts = 100 + + global _task_id + assert _task_id is None + if num_processes is None or num_processes <= 0: + num_processes = cpu_count() + gen_log.info("Starting %d processes", num_processes) + children = {} + + def start_child(i: int) -> Optional[int]: + pid = os.fork() + if pid == 0: + # child process + _reseed_random() + global _task_id + _task_id = i + return i + else: + children[pid] = i + return None + + for i in range(num_processes): + id = start_child(i) + if id is not None: + return id + num_restarts = 0 + while children: + pid, status = os.wait() + if pid not in children: + continue + id = children.pop(pid) + if os.WIFSIGNALED(status): + gen_log.warning( + "child %d (pid %d) killed by signal %d, restarting", + id, + pid, + os.WTERMSIG(status), + ) + elif os.WEXITSTATUS(status) != 0: + gen_log.warning( + "child %d (pid %d) exited with status %d, restarting", + id, + pid, + os.WEXITSTATUS(status), + ) + else: + gen_log.info("child %d (pid %d) exited normally", id, pid) + continue + num_restarts += 1 + if num_restarts > max_restarts: + raise RuntimeError("Too many child restarts, giving up") + new_id = start_child(id) + if new_id is not None: + return new_id + # All child processes exited cleanly, so exit the master process + # instead of just returning to right after the call to + # fork_processes (which will probably just start up another IOLoop + # unless the caller checks the return value). + sys.exit(0) + + +def task_id() -> Optional[int]: + """Returns the current task id, if any. + + Returns None if this process was not created by `fork_processes`. + """ + global _task_id + return _task_id + + +class Subprocess(object): + """Wraps ``subprocess.Popen`` with IOStream support. + + The constructor is the same as ``subprocess.Popen`` with the following + additions: + + * ``stdin``, ``stdout``, and ``stderr`` may have the value + ``tornado.process.Subprocess.STREAM``, which will make the corresponding + attribute of the resulting Subprocess a `.PipeIOStream`. If this option + is used, the caller is responsible for closing the streams when done + with them. + + The ``Subprocess.STREAM`` option and the ``set_exit_callback`` and + ``wait_for_exit`` methods do not work on Windows. There is + therefore no reason to use this class instead of + ``subprocess.Popen`` on that platform. + + .. versionchanged:: 5.0 + The ``io_loop`` argument (deprecated since version 4.1) has been removed. 
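+
+    A minimal usage sketch (Unix; run inside a coroutine, and assumes an
+    ``echo`` executable on the path)::
+
+        proc = Subprocess(["echo", "hello"], stdout=Subprocess.STREAM)
+        output = await proc.stdout.read_until_close()
+        await proc.wait_for_exit()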
+ + """ + + STREAM = object() + + _initialized = False + _waiting = {} # type: ignore + _old_sigchld = None + + def __init__(self, *args: Any, **kwargs: Any) -> None: + self.io_loop = ioloop.IOLoop.current() + # All FDs we create should be closed on error; those in to_close + # should be closed in the parent process on success. + pipe_fds = [] # type: List[int] + to_close = [] # type: List[int] + if kwargs.get("stdin") is Subprocess.STREAM: + in_r, in_w = os.pipe() + kwargs["stdin"] = in_r + pipe_fds.extend((in_r, in_w)) + to_close.append(in_r) + self.stdin = PipeIOStream(in_w) + if kwargs.get("stdout") is Subprocess.STREAM: + out_r, out_w = os.pipe() + kwargs["stdout"] = out_w + pipe_fds.extend((out_r, out_w)) + to_close.append(out_w) + self.stdout = PipeIOStream(out_r) + if kwargs.get("stderr") is Subprocess.STREAM: + err_r, err_w = os.pipe() + kwargs["stderr"] = err_w + pipe_fds.extend((err_r, err_w)) + to_close.append(err_w) + self.stderr = PipeIOStream(err_r) + try: + self.proc = subprocess.Popen(*args, **kwargs) + except: + for fd in pipe_fds: + os.close(fd) + raise + for fd in to_close: + os.close(fd) + self.pid = self.proc.pid + for attr in ["stdin", "stdout", "stderr"]: + if not hasattr(self, attr): # don't clobber streams set above + setattr(self, attr, getattr(self.proc, attr)) + self._exit_callback = None # type: Optional[Callable[[int], None]] + self.returncode = None # type: Optional[int] + + def set_exit_callback(self, callback: Callable[[int], None]) -> None: + """Runs ``callback`` when this process exits. + + The callback takes one argument, the return code of the process. + + This method uses a ``SIGCHLD`` handler, which is a global setting + and may conflict if you have other libraries trying to handle the + same signal. If you are using more than one ``IOLoop`` it may + be necessary to call `Subprocess.initialize` first to designate + one ``IOLoop`` to run the signal handlers. + + In many cases a close callback on the stdout or stderr streams + can be used as an alternative to an exit callback if the + signal handler is causing a problem. + + Availability: Unix + """ + self._exit_callback = callback + Subprocess.initialize() + Subprocess._waiting[self.pid] = self + Subprocess._try_cleanup_process(self.pid) + + def wait_for_exit(self, raise_error: bool = True) -> "Future[int]": + """Returns a `.Future` which resolves when the process exits. + + Usage:: + + ret = yield proc.wait_for_exit() + + This is a coroutine-friendly alternative to `set_exit_callback` + (and a replacement for the blocking `subprocess.Popen.wait`). + + By default, raises `subprocess.CalledProcessError` if the process + has a non-zero exit status. Use ``wait_for_exit(raise_error=False)`` + to suppress this behavior and return the exit status without raising. + + .. versionadded:: 4.2 + + Availability: Unix + """ + future = Future() # type: Future[int] + + def callback(ret: int) -> None: + if ret != 0 and raise_error: + # Unfortunately we don't have the original args any more. + future_set_exception_unless_cancelled( + future, CalledProcessError(ret, "unknown") + ) + else: + future_set_result_unless_cancelled(future, ret) + + self.set_exit_callback(callback) + return future + + @classmethod + def initialize(cls) -> None: + """Initializes the ``SIGCHLD`` handler. + + The signal handler is run on an `.IOLoop` to avoid locking issues. 
+ Note that the `.IOLoop` used for signal handling need not be the + same one used by individual Subprocess objects (as long as the + ``IOLoops`` are each running in separate threads). + + .. versionchanged:: 5.0 + The ``io_loop`` argument (deprecated since version 4.1) has been + removed. + + Availability: Unix + """ + if cls._initialized: + return + io_loop = ioloop.IOLoop.current() + cls._old_sigchld = signal.signal( + signal.SIGCHLD, + lambda sig, frame: io_loop.add_callback_from_signal(cls._cleanup), + ) + cls._initialized = True + + @classmethod + def uninitialize(cls) -> None: + """Removes the ``SIGCHLD`` handler.""" + if not cls._initialized: + return + signal.signal(signal.SIGCHLD, cls._old_sigchld) + cls._initialized = False + + @classmethod + def _cleanup(cls) -> None: + for pid in list(cls._waiting.keys()): # make a copy + cls._try_cleanup_process(pid) + + @classmethod + def _try_cleanup_process(cls, pid: int) -> None: + try: + ret_pid, status = os.waitpid(pid, os.WNOHANG) # type: ignore + except ChildProcessError: + return + if ret_pid == 0: + return + assert ret_pid == pid + subproc = cls._waiting.pop(pid) + subproc.io_loop.add_callback_from_signal(subproc._set_returncode, status) + + def _set_returncode(self, status: int) -> None: + if sys.platform == "win32": + self.returncode = -1 + else: + if os.WIFSIGNALED(status): + self.returncode = -os.WTERMSIG(status) + else: + assert os.WIFEXITED(status) + self.returncode = os.WEXITSTATUS(status) + # We've taken over wait() duty from the subprocess.Popen + # object. If we don't inform it of the process's return code, + # it will log a warning at destruction in python 3.6+. + self.proc.returncode = self.returncode + if self._exit_callback: + callback = self._exit_callback + self._exit_callback = None + callback(self.returncode) diff --git a/venv/lib/python3.9/site-packages/tornado/py.typed b/venv/lib/python3.9/site-packages/tornado/py.typed new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/py.typed diff --git a/venv/lib/python3.9/site-packages/tornado/queues.py b/venv/lib/python3.9/site-packages/tornado/queues.py new file mode 100644 index 00000000..1358d0ec --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/queues.py @@ -0,0 +1,422 @@ +# Copyright 2015 The Tornado Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Asynchronous queues for coroutines. These classes are very similar +to those provided in the standard library's `asyncio package +<https://docs.python.org/3/library/asyncio-queue.html>`_. + +.. warning:: + + Unlike the standard library's `queue` module, the classes defined here + are *not* thread-safe. To use these queues from another thread, + use `.IOLoop.add_callback` to transfer control to the `.IOLoop` thread + before calling any queue methods. 
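+
+    For example, a producer thread can hand items over safely with
+    (sketch; ``io_loop`` is the loop where the queue is consumed)::
+
+        io_loop.add_callback(q.put_nowait, item)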
+ +""" + +import collections +import datetime +import heapq + +from tornado import gen, ioloop +from tornado.concurrent import Future, future_set_result_unless_cancelled +from tornado.locks import Event + +from typing import Union, TypeVar, Generic, Awaitable, Optional +import typing + +if typing.TYPE_CHECKING: + from typing import Deque, Tuple, Any # noqa: F401 + +_T = TypeVar("_T") + +__all__ = ["Queue", "PriorityQueue", "LifoQueue", "QueueFull", "QueueEmpty"] + + +class QueueEmpty(Exception): + """Raised by `.Queue.get_nowait` when the queue has no items.""" + + pass + + +class QueueFull(Exception): + """Raised by `.Queue.put_nowait` when a queue is at its maximum size.""" + + pass + + +def _set_timeout( + future: Future, timeout: Union[None, float, datetime.timedelta] +) -> None: + if timeout: + + def on_timeout() -> None: + if not future.done(): + future.set_exception(gen.TimeoutError()) + + io_loop = ioloop.IOLoop.current() + timeout_handle = io_loop.add_timeout(timeout, on_timeout) + future.add_done_callback(lambda _: io_loop.remove_timeout(timeout_handle)) + + +class _QueueIterator(Generic[_T]): + def __init__(self, q: "Queue[_T]") -> None: + self.q = q + + def __anext__(self) -> Awaitable[_T]: + return self.q.get() + + +class Queue(Generic[_T]): + """Coordinate producer and consumer coroutines. + + If maxsize is 0 (the default) the queue size is unbounded. + + .. testcode:: + + import asyncio + from tornado.ioloop import IOLoop + from tornado.queues import Queue + + q = Queue(maxsize=2) + + async def consumer(): + async for item in q: + try: + print('Doing work on %s' % item) + await asyncio.sleep(0.01) + finally: + q.task_done() + + async def producer(): + for item in range(5): + await q.put(item) + print('Put %s' % item) + + async def main(): + # Start consumer without waiting (since it never finishes). + IOLoop.current().spawn_callback(consumer) + await producer() # Wait for producer to put all tasks. + await q.join() # Wait for consumer to finish all tasks. + print('Done') + + asyncio.run(main()) + + .. testoutput:: + + Put 0 + Put 1 + Doing work on 0 + Put 2 + Doing work on 1 + Put 3 + Doing work on 2 + Put 4 + Doing work on 3 + Doing work on 4 + Done + + + In versions of Python without native coroutines (before 3.5), + ``consumer()`` could be written as:: + + @gen.coroutine + def consumer(): + while True: + item = yield q.get() + try: + print('Doing work on %s' % item) + yield gen.sleep(0.01) + finally: + q.task_done() + + .. versionchanged:: 4.3 + Added ``async for`` support in Python 3.5. + + """ + + # Exact type depends on subclass. Could be another generic + # parameter and use protocols to be more precise here. 
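+    # The concrete container is chosen by the overridable _init/_get/_put
+    # methods below: a deque for Queue (FIFO), a heap for PriorityQueue,
+    # and a plain list for LifoQueue.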
+ _queue = None # type: Any + + def __init__(self, maxsize: int = 0) -> None: + if maxsize is None: + raise TypeError("maxsize can't be None") + + if maxsize < 0: + raise ValueError("maxsize can't be negative") + + self._maxsize = maxsize + self._init() + self._getters = collections.deque([]) # type: Deque[Future[_T]] + self._putters = collections.deque([]) # type: Deque[Tuple[_T, Future[None]]] + self._unfinished_tasks = 0 + self._finished = Event() + self._finished.set() + + @property + def maxsize(self) -> int: + """Number of items allowed in the queue.""" + return self._maxsize + + def qsize(self) -> int: + """Number of items in the queue.""" + return len(self._queue) + + def empty(self) -> bool: + return not self._queue + + def full(self) -> bool: + if self.maxsize == 0: + return False + else: + return self.qsize() >= self.maxsize + + def put( + self, item: _T, timeout: Optional[Union[float, datetime.timedelta]] = None + ) -> "Future[None]": + """Put an item into the queue, perhaps waiting until there is room. + + Returns a Future, which raises `tornado.util.TimeoutError` after a + timeout. + + ``timeout`` may be a number denoting a time (on the same + scale as `tornado.ioloop.IOLoop.time`, normally `time.time`), or a + `datetime.timedelta` object for a deadline relative to the + current time. + """ + future = Future() # type: Future[None] + try: + self.put_nowait(item) + except QueueFull: + self._putters.append((item, future)) + _set_timeout(future, timeout) + else: + future.set_result(None) + return future + + def put_nowait(self, item: _T) -> None: + """Put an item into the queue without blocking. + + If no free slot is immediately available, raise `QueueFull`. + """ + self._consume_expired() + if self._getters: + assert self.empty(), "queue non-empty, why are getters waiting?" + getter = self._getters.popleft() + self.__put_internal(item) + future_set_result_unless_cancelled(getter, self._get()) + elif self.full(): + raise QueueFull + else: + self.__put_internal(item) + + def get( + self, timeout: Optional[Union[float, datetime.timedelta]] = None + ) -> Awaitable[_T]: + """Remove and return an item from the queue. + + Returns an awaitable which resolves once an item is available, or raises + `tornado.util.TimeoutError` after a timeout. + + ``timeout`` may be a number denoting a time (on the same + scale as `tornado.ioloop.IOLoop.time`, normally `time.time`), or a + `datetime.timedelta` object for a deadline relative to the + current time. + + .. note:: + + The ``timeout`` argument of this method differs from that + of the standard library's `queue.Queue.get`. That method + interprets numeric values as relative timeouts; this one + interprets them as absolute deadlines and requires + ``timedelta`` objects for relative timeouts (consistent + with other timeouts in Tornado). + + """ + future = Future() # type: Future[_T] + try: + future.set_result(self.get_nowait()) + except QueueEmpty: + self._getters.append(future) + _set_timeout(future, timeout) + return future + + def get_nowait(self) -> _T: + """Remove and return an item from the queue without blocking. + + Return an item if one is immediately available, else raise + `QueueEmpty`. + """ + self._consume_expired() + if self._putters: + assert self.full(), "queue not full, why are putters waiting?" 
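+            # A putter is blocked on a full queue: absorb its item, wake it
+            # with a result, and hand the oldest item to this caller.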
+ item, putter = self._putters.popleft() + self.__put_internal(item) + future_set_result_unless_cancelled(putter, None) + return self._get() + elif self.qsize(): + return self._get() + else: + raise QueueEmpty + + def task_done(self) -> None: + """Indicate that a formerly enqueued task is complete. + + Used by queue consumers. For each `.get` used to fetch a task, a + subsequent call to `.task_done` tells the queue that the processing + on the task is complete. + + If a `.join` is blocking, it resumes when all items have been + processed; that is, when every `.put` is matched by a `.task_done`. + + Raises `ValueError` if called more times than `.put`. + """ + if self._unfinished_tasks <= 0: + raise ValueError("task_done() called too many times") + self._unfinished_tasks -= 1 + if self._unfinished_tasks == 0: + self._finished.set() + + def join( + self, timeout: Optional[Union[float, datetime.timedelta]] = None + ) -> Awaitable[None]: + """Block until all items in the queue are processed. + + Returns an awaitable, which raises `tornado.util.TimeoutError` after a + timeout. + """ + return self._finished.wait(timeout) + + def __aiter__(self) -> _QueueIterator[_T]: + return _QueueIterator(self) + + # These three are overridable in subclasses. + def _init(self) -> None: + self._queue = collections.deque() + + def _get(self) -> _T: + return self._queue.popleft() + + def _put(self, item: _T) -> None: + self._queue.append(item) + + # End of the overridable methods. + + def __put_internal(self, item: _T) -> None: + self._unfinished_tasks += 1 + self._finished.clear() + self._put(item) + + def _consume_expired(self) -> None: + # Remove timed-out waiters. + while self._putters and self._putters[0][1].done(): + self._putters.popleft() + + while self._getters and self._getters[0].done(): + self._getters.popleft() + + def __repr__(self) -> str: + return "<%s at %s %s>" % (type(self).__name__, hex(id(self)), self._format()) + + def __str__(self) -> str: + return "<%s %s>" % (type(self).__name__, self._format()) + + def _format(self) -> str: + result = "maxsize=%r" % (self.maxsize,) + if getattr(self, "_queue", None): + result += " queue=%r" % self._queue + if self._getters: + result += " getters[%s]" % len(self._getters) + if self._putters: + result += " putters[%s]" % len(self._putters) + if self._unfinished_tasks: + result += " tasks=%s" % self._unfinished_tasks + return result + + +class PriorityQueue(Queue): + """A `.Queue` that retrieves entries in priority order, lowest first. + + Entries are typically tuples like ``(priority number, data)``. + + .. testcode:: + + import asyncio + from tornado.queues import PriorityQueue + + async def main(): + q = PriorityQueue() + q.put((1, 'medium-priority item')) + q.put((0, 'high-priority item')) + q.put((10, 'low-priority item')) + + print(await q.get()) + print(await q.get()) + print(await q.get()) + + asyncio.run(main()) + + .. testoutput:: + + (0, 'high-priority item') + (1, 'medium-priority item') + (10, 'low-priority item') + """ + + def _init(self) -> None: + self._queue = [] + + def _put(self, item: _T) -> None: + heapq.heappush(self._queue, item) + + def _get(self) -> _T: # type: ignore[type-var] + return heapq.heappop(self._queue) + + +class LifoQueue(Queue): + """A `.Queue` that retrieves the most recently put items first. + + .. 
testcode:: + + import asyncio + from tornado.queues import LifoQueue + + async def main(): + q = LifoQueue() + q.put(3) + q.put(2) + q.put(1) + + print(await q.get()) + print(await q.get()) + print(await q.get()) + + asyncio.run(main()) + + .. testoutput:: + + 1 + 2 + 3 + """ + + def _init(self) -> None: + self._queue = [] + + def _put(self, item: _T) -> None: + self._queue.append(item) + + def _get(self) -> _T: # type: ignore[type-var] + return self._queue.pop() diff --git a/venv/lib/python3.9/site-packages/tornado/routing.py b/venv/lib/python3.9/site-packages/tornado/routing.py new file mode 100644 index 00000000..a145d719 --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/routing.py @@ -0,0 +1,717 @@ +# Copyright 2015 The Tornado Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Flexible routing implementation. + +Tornado routes HTTP requests to appropriate handlers using `Router` +class implementations. The `tornado.web.Application` class is a +`Router` implementation and may be used directly, or the classes in +this module may be used for additional flexibility. The `RuleRouter` +class can match on more criteria than `.Application`, or the `Router` +interface can be subclassed for maximum customization. + +`Router` interface extends `~.httputil.HTTPServerConnectionDelegate` +to provide additional routing capabilities. This also means that any +`Router` implementation can be used directly as a ``request_callback`` +for `~.httpserver.HTTPServer` constructor. + +`Router` subclass must implement a ``find_handler`` method to provide +a suitable `~.httputil.HTTPMessageDelegate` instance to handle the +request: + +.. code-block:: python + + class CustomRouter(Router): + def find_handler(self, request, **kwargs): + # some routing logic providing a suitable HTTPMessageDelegate instance + return MessageDelegate(request.connection) + + class MessageDelegate(HTTPMessageDelegate): + def __init__(self, connection): + self.connection = connection + + def finish(self): + self.connection.write_headers( + ResponseStartLine("HTTP/1.1", 200, "OK"), + HTTPHeaders({"Content-Length": "2"}), + b"OK") + self.connection.finish() + + router = CustomRouter() + server = HTTPServer(router) + +The main responsibility of `Router` implementation is to provide a +mapping from a request to `~.httputil.HTTPMessageDelegate` instance +that will handle this request. In the example above we can see that +routing is possible even without instantiating an `~.web.Application`. + +For routing to `~.web.RequestHandler` implementations we need an +`~.web.Application` instance. `~.web.Application.get_handler_delegate` +provides a convenient way to create `~.httputil.HTTPMessageDelegate` +for a given request and `~.web.RequestHandler`. + +Here is a simple example of how we can we route to +`~.web.RequestHandler` subclasses by HTTP method: + +.. 
code-block:: python + + resources = {} + + class GetResource(RequestHandler): + def get(self, path): + if path not in resources: + raise HTTPError(404) + + self.finish(resources[path]) + + class PostResource(RequestHandler): + def post(self, path): + resources[path] = self.request.body + + class HTTPMethodRouter(Router): + def __init__(self, app): + self.app = app + + def find_handler(self, request, **kwargs): + handler = GetResource if request.method == "GET" else PostResource + return self.app.get_handler_delegate(request, handler, path_args=[request.path]) + + router = HTTPMethodRouter(Application()) + server = HTTPServer(router) + +`ReversibleRouter` interface adds the ability to distinguish between +the routes and reverse them to the original urls using route's name +and additional arguments. `~.web.Application` is itself an +implementation of `ReversibleRouter` class. + +`RuleRouter` and `ReversibleRuleRouter` are implementations of +`Router` and `ReversibleRouter` interfaces and can be used for +creating rule-based routing configurations. + +Rules are instances of `Rule` class. They contain a `Matcher`, which +provides the logic for determining whether the rule is a match for a +particular request and a target, which can be one of the following. + +1) An instance of `~.httputil.HTTPServerConnectionDelegate`: + +.. code-block:: python + + router = RuleRouter([ + Rule(PathMatches("/handler"), ConnectionDelegate()), + # ... more rules + ]) + + class ConnectionDelegate(HTTPServerConnectionDelegate): + def start_request(self, server_conn, request_conn): + return MessageDelegate(request_conn) + +2) A callable accepting a single argument of `~.httputil.HTTPServerRequest` type: + +.. code-block:: python + + router = RuleRouter([ + Rule(PathMatches("/callable"), request_callable) + ]) + + def request_callable(request): + request.write(b"HTTP/1.1 200 OK\\r\\nContent-Length: 2\\r\\n\\r\\nOK") + request.finish() + +3) Another `Router` instance: + +.. code-block:: python + + router = RuleRouter([ + Rule(PathMatches("/router.*"), CustomRouter()) + ]) + +Of course a nested `RuleRouter` or a `~.web.Application` is allowed: + +.. code-block:: python + + router = RuleRouter([ + Rule(HostMatches("example.com"), RuleRouter([ + Rule(PathMatches("/app1/.*"), Application([(r"/app1/handler", Handler)])), + ])) + ]) + + server = HTTPServer(router) + +In the example below `RuleRouter` is used to route between applications: + +.. code-block:: python + + app1 = Application([ + (r"/app1/handler", Handler1), + # other handlers ... + ]) + + app2 = Application([ + (r"/app2/handler", Handler2), + # other handlers ... + ]) + + router = RuleRouter([ + Rule(PathMatches("/app1.*"), app1), + Rule(PathMatches("/app2.*"), app2) + ]) + + server = HTTPServer(router) + +For more information on application-level routing see docs for `~.web.Application`. + +.. 
versionadded:: 4.5 + +""" + +import re +from functools import partial + +from tornado import httputil +from tornado.httpserver import _CallableAdapter +from tornado.escape import url_escape, url_unescape, utf8 +from tornado.log import app_log +from tornado.util import basestring_type, import_object, re_unescape, unicode_type + +from typing import Any, Union, Optional, Awaitable, List, Dict, Pattern, Tuple, overload + + +class Router(httputil.HTTPServerConnectionDelegate): + """Abstract router interface.""" + + def find_handler( + self, request: httputil.HTTPServerRequest, **kwargs: Any + ) -> Optional[httputil.HTTPMessageDelegate]: + """Must be implemented to return an appropriate instance of `~.httputil.HTTPMessageDelegate` + that can serve the request. + Routing implementations may pass additional kwargs to extend the routing logic. + + :arg httputil.HTTPServerRequest request: current HTTP request. + :arg kwargs: additional keyword arguments passed by routing implementation. + :returns: an instance of `~.httputil.HTTPMessageDelegate` that will be used to + process the request. + """ + raise NotImplementedError() + + def start_request( + self, server_conn: object, request_conn: httputil.HTTPConnection + ) -> httputil.HTTPMessageDelegate: + return _RoutingDelegate(self, server_conn, request_conn) + + +class ReversibleRouter(Router): + """Abstract router interface for routers that can handle named routes + and support reversing them to original urls. + """ + + def reverse_url(self, name: str, *args: Any) -> Optional[str]: + """Returns url string for a given route name and arguments + or ``None`` if no match is found. + + :arg str name: route name. + :arg args: url parameters. + :returns: parametrized url string for a given route name (or ``None``). + """ + raise NotImplementedError() + + +class _RoutingDelegate(httputil.HTTPMessageDelegate): + def __init__( + self, router: Router, server_conn: object, request_conn: httputil.HTTPConnection + ) -> None: + self.server_conn = server_conn + self.request_conn = request_conn + self.delegate = None # type: Optional[httputil.HTTPMessageDelegate] + self.router = router # type: Router + + def headers_received( + self, + start_line: Union[httputil.RequestStartLine, httputil.ResponseStartLine], + headers: httputil.HTTPHeaders, + ) -> Optional[Awaitable[None]]: + assert isinstance(start_line, httputil.RequestStartLine) + request = httputil.HTTPServerRequest( + connection=self.request_conn, + server_connection=self.server_conn, + start_line=start_line, + headers=headers, + ) + + self.delegate = self.router.find_handler(request) + if self.delegate is None: + app_log.debug( + "Delegate for %s %s request not found", + start_line.method, + start_line.path, + ) + self.delegate = _DefaultMessageDelegate(self.request_conn) + + return self.delegate.headers_received(start_line, headers) + + def data_received(self, chunk: bytes) -> Optional[Awaitable[None]]: + assert self.delegate is not None + return self.delegate.data_received(chunk) + + def finish(self) -> None: + assert self.delegate is not None + self.delegate.finish() + + def on_connection_close(self) -> None: + assert self.delegate is not None + self.delegate.on_connection_close() + + +class _DefaultMessageDelegate(httputil.HTTPMessageDelegate): + def __init__(self, connection: httputil.HTTPConnection) -> None: + self.connection = connection + + def finish(self) -> None: + self.connection.write_headers( + httputil.ResponseStartLine("HTTP/1.1", 404, "Not Found"), + httputil.HTTPHeaders(), + ) + 
self.connection.finish() + + +# _RuleList can either contain pre-constructed Rules or a sequence of +# arguments to be passed to the Rule constructor. +_RuleList = List[ + Union[ + "Rule", + List[Any], # Can't do detailed typechecking of lists. + Tuple[Union[str, "Matcher"], Any], + Tuple[Union[str, "Matcher"], Any, Dict[str, Any]], + Tuple[Union[str, "Matcher"], Any, Dict[str, Any], str], + ] +] + + +class RuleRouter(Router): + """Rule-based router implementation.""" + + def __init__(self, rules: Optional[_RuleList] = None) -> None: + """Constructs a router from an ordered list of rules:: + + RuleRouter([ + Rule(PathMatches("/handler"), Target), + # ... more rules + ]) + + You can also omit explicit `Rule` constructor and use tuples of arguments:: + + RuleRouter([ + (PathMatches("/handler"), Target), + ]) + + `PathMatches` is a default matcher, so the example above can be simplified:: + + RuleRouter([ + ("/handler", Target), + ]) + + In the examples above, ``Target`` can be a nested `Router` instance, an instance of + `~.httputil.HTTPServerConnectionDelegate` or an old-style callable, + accepting a request argument. + + :arg rules: a list of `Rule` instances or tuples of `Rule` + constructor arguments. + """ + self.rules = [] # type: List[Rule] + if rules: + self.add_rules(rules) + + def add_rules(self, rules: _RuleList) -> None: + """Appends new rules to the router. + + :arg rules: a list of Rule instances (or tuples of arguments, which are + passed to Rule constructor). + """ + for rule in rules: + if isinstance(rule, (tuple, list)): + assert len(rule) in (2, 3, 4) + if isinstance(rule[0], basestring_type): + rule = Rule(PathMatches(rule[0]), *rule[1:]) + else: + rule = Rule(*rule) + + self.rules.append(self.process_rule(rule)) + + def process_rule(self, rule: "Rule") -> "Rule": + """Override this method for additional preprocessing of each rule. + + :arg Rule rule: a rule to be processed. + :returns: the same or modified Rule instance. + """ + return rule + + def find_handler( + self, request: httputil.HTTPServerRequest, **kwargs: Any + ) -> Optional[httputil.HTTPMessageDelegate]: + for rule in self.rules: + target_params = rule.matcher.match(request) + if target_params is not None: + if rule.target_kwargs: + target_params["target_kwargs"] = rule.target_kwargs + + delegate = self.get_target_delegate( + rule.target, request, **target_params + ) + + if delegate is not None: + return delegate + + return None + + def get_target_delegate( + self, target: Any, request: httputil.HTTPServerRequest, **target_params: Any + ) -> Optional[httputil.HTTPMessageDelegate]: + """Returns an instance of `~.httputil.HTTPMessageDelegate` for a + Rule's target. This method is called by `~.find_handler` and can be + extended to provide additional target types. + + :arg target: a Rule's target. + :arg httputil.HTTPServerRequest request: current request. + :arg target_params: additional parameters that can be useful + for `~.httputil.HTTPMessageDelegate` creation. 
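For illustration, this hook is where support for extra target types would go; a minimal sketch (``TextRuleRouter`` and ``_TextDelegate`` are illustrative names, not part of Tornado) that serves ``bytes`` targets as canned responses:

.. code-block:: python

    from tornado import httputil
    from tornado.routing import RuleRouter

    class _TextDelegate(httputil.HTTPMessageDelegate):
        # Writes a fixed body once the request has been fully read.
        def __init__(self, connection, body):
            self.connection = connection
            self.body = body

        def finish(self):
            self.connection.write_headers(
                httputil.ResponseStartLine("HTTP/1.1", 200, "OK"),
                httputil.HTTPHeaders({"Content-Length": str(len(self.body))}),
                self.body,
            )
            self.connection.finish()

    class TextRuleRouter(RuleRouter):
        def get_target_delegate(self, target, request, **target_params):
            if isinstance(target, bytes):
                # New target type: a literal response body.
                return _TextDelegate(request.connection, target)
            return super().get_target_delegate(target, request, **target_params)

    router = TextRuleRouter([("/ping", b"pong")])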
+ """ + if isinstance(target, Router): + return target.find_handler(request, **target_params) + + elif isinstance(target, httputil.HTTPServerConnectionDelegate): + assert request.connection is not None + return target.start_request(request.server_connection, request.connection) + + elif callable(target): + assert request.connection is not None + return _CallableAdapter( + partial(target, **target_params), request.connection + ) + + return None + + +class ReversibleRuleRouter(ReversibleRouter, RuleRouter): + """A rule-based router that implements ``reverse_url`` method. + + Each rule added to this router may have a ``name`` attribute that can be + used to reconstruct an original uri. The actual reconstruction takes place + in a rule's matcher (see `Matcher.reverse`). + """ + + def __init__(self, rules: Optional[_RuleList] = None) -> None: + self.named_rules = {} # type: Dict[str, Any] + super().__init__(rules) + + def process_rule(self, rule: "Rule") -> "Rule": + rule = super().process_rule(rule) + + if rule.name: + if rule.name in self.named_rules: + app_log.warning( + "Multiple handlers named %s; replacing previous value", rule.name + ) + self.named_rules[rule.name] = rule + + return rule + + def reverse_url(self, name: str, *args: Any) -> Optional[str]: + if name in self.named_rules: + return self.named_rules[name].matcher.reverse(*args) + + for rule in self.rules: + if isinstance(rule.target, ReversibleRouter): + reversed_url = rule.target.reverse_url(name, *args) + if reversed_url is not None: + return reversed_url + + return None + + +class Rule(object): + """A routing rule.""" + + def __init__( + self, + matcher: "Matcher", + target: Any, + target_kwargs: Optional[Dict[str, Any]] = None, + name: Optional[str] = None, + ) -> None: + """Constructs a Rule instance. + + :arg Matcher matcher: a `Matcher` instance used for determining + whether the rule should be considered a match for a specific + request. + :arg target: a Rule's target (typically a ``RequestHandler`` or + `~.httputil.HTTPServerConnectionDelegate` subclass or even a nested `Router`, + depending on routing implementation). + :arg dict target_kwargs: a dict of parameters that can be useful + at the moment of target instantiation (for example, ``status_code`` + for a ``RequestHandler`` subclass). They end up in + ``target_params['target_kwargs']`` of `RuleRouter.get_target_delegate` + method. + :arg str name: the name of the rule that can be used to find it + in `ReversibleRouter.reverse_url` implementation. + """ + if isinstance(target, str): + # import the Module and instantiate the class + # Must be a fully qualified name (module.ClassName) + target = import_object(target) + + self.matcher = matcher # type: Matcher + self.target = target + self.target_kwargs = target_kwargs if target_kwargs else {} + self.name = name + + def reverse(self, *args: Any) -> Optional[str]: + return self.matcher.reverse(*args) + + def __repr__(self) -> str: + return "%s(%r, %s, kwargs=%r, name=%r)" % ( + self.__class__.__name__, + self.matcher, + self.target, + self.target_kwargs, + self.name, + ) + + +class Matcher(object): + """Represents a matcher for request features.""" + + def match(self, request: httputil.HTTPServerRequest) -> Optional[Dict[str, Any]]: + """Matches current instance against the request. 
+ + :arg httputil.HTTPServerRequest request: current HTTP request + :returns: a dict of parameters to be passed to the target handler + (for example, ``handler_kwargs``, ``path_args``, ``path_kwargs`` + can be passed for proper `~.web.RequestHandler` instantiation). + An empty dict is a valid (and common) return value to indicate a match + when the argument-passing features are not used. + ``None`` must be returned to indicate that there is no match.""" + raise NotImplementedError() + + def reverse(self, *args: Any) -> Optional[str]: + """Reconstructs full url from matcher instance and additional arguments.""" + return None + + +class AnyMatches(Matcher): + """Matches any request.""" + + def match(self, request: httputil.HTTPServerRequest) -> Optional[Dict[str, Any]]: + return {} + + +class HostMatches(Matcher): + """Matches requests from hosts specified by ``host_pattern`` regex.""" + + def __init__(self, host_pattern: Union[str, Pattern]) -> None: + if isinstance(host_pattern, basestring_type): + if not host_pattern.endswith("$"): + host_pattern += "$" + self.host_pattern = re.compile(host_pattern) + else: + self.host_pattern = host_pattern + + def match(self, request: httputil.HTTPServerRequest) -> Optional[Dict[str, Any]]: + if self.host_pattern.match(request.host_name): + return {} + + return None + + +class DefaultHostMatches(Matcher): + """Matches requests from host that is equal to application's default_host. + Always returns no match if ``X-Real-Ip`` header is present. + """ + + def __init__(self, application: Any, host_pattern: Pattern) -> None: + self.application = application + self.host_pattern = host_pattern + + def match(self, request: httputil.HTTPServerRequest) -> Optional[Dict[str, Any]]: + # Look for default host if not behind load balancer (for debugging) + if "X-Real-Ip" not in request.headers: + if self.host_pattern.match(self.application.default_host): + return {} + return None + + +class PathMatches(Matcher): + """Matches requests with paths specified by ``path_pattern`` regex.""" + + def __init__(self, path_pattern: Union[str, Pattern]) -> None: + if isinstance(path_pattern, basestring_type): + if not path_pattern.endswith("$"): + path_pattern += "$" + self.regex = re.compile(path_pattern) + else: + self.regex = path_pattern + + assert len(self.regex.groupindex) in (0, self.regex.groups), ( + "groups in url regexes must either be all named or all " + "positional: %r" % self.regex.pattern + ) + + self._path, self._group_count = self._find_groups() + + def match(self, request: httputil.HTTPServerRequest) -> Optional[Dict[str, Any]]: + match = self.regex.match(request.path) + if match is None: + return None + if not self.regex.groups: + return {} + + path_args = [] # type: List[bytes] + path_kwargs = {} # type: Dict[str, bytes] + + # Pass matched groups to the handler. Since + # match.groups() includes both named and + # unnamed groups, we want to use either groups + # or groupdict but not both. 
+ if self.regex.groupindex: + path_kwargs = dict( + (str(k), _unquote_or_none(v)) for (k, v) in match.groupdict().items() + ) + else: + path_args = [_unquote_or_none(s) for s in match.groups()] + + return dict(path_args=path_args, path_kwargs=path_kwargs) + + def reverse(self, *args: Any) -> Optional[str]: + if self._path is None: + raise ValueError("Cannot reverse url regex " + self.regex.pattern) + assert len(args) == self._group_count, ( + "required number of arguments " "not found" + ) + if not len(args): + return self._path + converted_args = [] + for a in args: + if not isinstance(a, (unicode_type, bytes)): + a = str(a) + converted_args.append(url_escape(utf8(a), plus=False)) + return self._path % tuple(converted_args) + + def _find_groups(self) -> Tuple[Optional[str], Optional[int]]: + """Returns a tuple (reverse string, group count) for a url. + + For example: Given the url pattern /([0-9]{4})/([a-z-]+)/, this method + would return ('/%s/%s/', 2). + """ + pattern = self.regex.pattern + if pattern.startswith("^"): + pattern = pattern[1:] + if pattern.endswith("$"): + pattern = pattern[:-1] + + if self.regex.groups != pattern.count("("): + # The pattern is too complicated for our simplistic matching, + # so we can't support reversing it. + return None, None + + pieces = [] + for fragment in pattern.split("("): + if ")" in fragment: + paren_loc = fragment.index(")") + if paren_loc >= 0: + try: + unescaped_fragment = re_unescape(fragment[paren_loc + 1 :]) + except ValueError: + # If we can't unescape part of it, we can't + # reverse this url. + return (None, None) + pieces.append("%s" + unescaped_fragment) + else: + try: + unescaped_fragment = re_unescape(fragment) + except ValueError: + # If we can't unescape part of it, we can't + # reverse this url. + return (None, None) + pieces.append(unescaped_fragment) + + return "".join(pieces), self.regex.groups + + +class URLSpec(Rule): + """Specifies mappings between URLs and handlers. + + .. versionchanged: 4.5 + `URLSpec` is now a subclass of a `Rule` with `PathMatches` matcher and is preserved for + backwards compatibility. + """ + + def __init__( + self, + pattern: Union[str, Pattern], + handler: Any, + kwargs: Optional[Dict[str, Any]] = None, + name: Optional[str] = None, + ) -> None: + """Parameters: + + * ``pattern``: Regular expression to be matched. Any capturing + groups in the regex will be passed in to the handler's + get/post/etc methods as arguments (by keyword if named, by + position if unnamed. Named and unnamed capturing groups + may not be mixed in the same rule). + + * ``handler``: `~.web.RequestHandler` subclass to be invoked. + + * ``kwargs`` (optional): A dictionary of additional arguments + to be passed to the handler's constructor. + + * ``name`` (optional): A name for this handler. Used by + `~.web.Application.reverse_url`. + + """ + matcher = PathMatches(pattern) + super().__init__(matcher, handler, kwargs, name) + + self.regex = matcher.regex + self.handler_class = self.target + self.kwargs = kwargs + + def __repr__(self) -> str: + return "%s(%r, %s, kwargs=%r, name=%r)" % ( + self.__class__.__name__, + self.regex.pattern, + self.handler_class, + self.kwargs, + self.name, + ) + + +@overload +def _unquote_or_none(s: str) -> bytes: + pass + + +@overload # noqa: F811 +def _unquote_or_none(s: None) -> None: + pass + + +def _unquote_or_none(s: Optional[str]) -> Optional[bytes]: # noqa: F811 + """None-safe wrapper around url_unescape to handle unmatched optional + groups correctly. 
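The reversing machinery above can be exercised directly; a short sketch (the pattern is illustrative):

.. code-block:: python

    from tornado.routing import PathMatches

    matcher = PathMatches(r"/posts/([0-9]+)/comments/([0-9]+)")
    # _find_groups() reduces the pattern to '/posts/%s/comments/%s' with two
    # positional groups, so reverse() substitutes url-escaped arguments:
    print(matcher.reverse(42, 7))  # /posts/42/comments/7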
+ + Note that args are passed as bytes so the handler can decide what + encoding to use. + """ + if s is None: + return s + return url_unescape(s, encoding=None, plus=False) diff --git a/venv/lib/python3.9/site-packages/tornado/simple_httpclient.py b/venv/lib/python3.9/site-packages/tornado/simple_httpclient.py new file mode 100644 index 00000000..2460863f --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/simple_httpclient.py @@ -0,0 +1,704 @@ +from tornado.escape import _unicode +from tornado import gen, version +from tornado.httpclient import ( + HTTPResponse, + HTTPError, + AsyncHTTPClient, + main, + _RequestProxy, + HTTPRequest, +) +from tornado import httputil +from tornado.http1connection import HTTP1Connection, HTTP1ConnectionParameters +from tornado.ioloop import IOLoop +from tornado.iostream import StreamClosedError, IOStream +from tornado.netutil import ( + Resolver, + OverrideResolver, + _client_ssl_defaults, + is_valid_ip, +) +from tornado.log import gen_log +from tornado.tcpclient import TCPClient + +import base64 +import collections +import copy +import functools +import re +import socket +import ssl +import sys +import time +from io import BytesIO +import urllib.parse + +from typing import Dict, Any, Callable, Optional, Type, Union +from types import TracebackType +import typing + +if typing.TYPE_CHECKING: + from typing import Deque, Tuple, List # noqa: F401 + + +class HTTPTimeoutError(HTTPError): + """Error raised by SimpleAsyncHTTPClient on timeout. + + For historical reasons, this is a subclass of `.HTTPClientError` + which simulates a response code of 599. + + .. versionadded:: 5.1 + """ + + def __init__(self, message: str) -> None: + super().__init__(599, message=message) + + def __str__(self) -> str: + return self.message or "Timeout" + + +class HTTPStreamClosedError(HTTPError): + """Error raised by SimpleAsyncHTTPClient when the underlying stream is closed. + + When a more specific exception is available (such as `ConnectionResetError`), + it may be raised instead of this one. + + For historical reasons, this is a subclass of `.HTTPClientError` + which simulates a response code of 599. + + .. versionadded:: 5.1 + """ + + def __init__(self, message: str) -> None: + super().__init__(599, message=message) + + def __str__(self) -> str: + return self.message or "Stream closed" + + +class SimpleAsyncHTTPClient(AsyncHTTPClient): + """Non-blocking HTTP client with no external dependencies. + + This class implements an HTTP 1.1 client on top of Tornado's IOStreams. + Some features found in the curl-based AsyncHTTPClient are not yet + supported. In particular, proxies are not supported, connections + are not reused, and callers cannot select the network interface to be + used. + + This implementation supports the following arguments, which can be passed + to ``configure()`` to control the global singleton, or to the constructor + when ``force_instance=True``. + + ``max_clients`` is the number of concurrent requests that can be + in progress; when this limit is reached additional requests will be + queued. Note that time spent waiting in this queue still counts + against the ``request_timeout``. + + ``defaults`` is a dict of parameters that will be used as defaults on all + `.HTTPRequest` objects submitted to this client. + + ``hostname_mapping`` is a dictionary mapping hostnames to IP addresses. + It can be used to make local DNS changes when modifying system-wide + settings like ``/etc/hosts`` is not possible or desirable (e.g. in + unittests). 
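A sketch of how ``hostname_mapping`` might be configured on the global client (the host name and address are illustrative):

.. code-block:: python

    from tornado.httpclient import AsyncHTTPClient

    AsyncHTTPClient.configure(
        "tornado.simple_httpclient.SimpleAsyncHTTPClient",
        hostname_mapping={"api.example.com": "127.0.0.1"},
    )

    async def fetch_body(url):
        # "api.example.com" now resolves to 127.0.0.1 via OverrideResolver.
        response = await AsyncHTTPClient().fetch(url)
        return response.body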
``resolver`` is similar, but using the `.Resolver` interface + instead of a simple mapping. + + ``max_buffer_size`` (default 100MB) is the number of bytes + that can be read into memory at once. ``max_body_size`` + (defaults to ``max_buffer_size``) is the largest response body + that the client will accept. Without a + ``streaming_callback``, the smaller of these two limits + applies; with a ``streaming_callback`` only ``max_body_size`` + does. + + .. versionchanged:: 4.2 + Added the ``max_body_size`` argument. + """ + + def initialize( # type: ignore + self, + max_clients: int = 10, + hostname_mapping: Optional[Dict[str, str]] = None, + max_buffer_size: int = 104857600, + resolver: Optional[Resolver] = None, + defaults: Optional[Dict[str, Any]] = None, + max_header_size: Optional[int] = None, + max_body_size: Optional[int] = None, + ) -> None: + super().initialize(defaults=defaults) + self.max_clients = max_clients + self.queue = ( + collections.deque() + ) # type: Deque[Tuple[object, HTTPRequest, Callable[[HTTPResponse], None]]] + self.active = ( + {} + ) # type: Dict[object, Tuple[HTTPRequest, Callable[[HTTPResponse], None]]] + self.waiting = ( + {} + ) # type: Dict[object, Tuple[HTTPRequest, Callable[[HTTPResponse], None], object]] + self.max_buffer_size = max_buffer_size + self.max_header_size = max_header_size + self.max_body_size = max_body_size + # TCPClient could create a Resolver for us, but we have to do it + # ourselves to support hostname_mapping. + if resolver: + self.resolver = resolver + self.own_resolver = False + else: + self.resolver = Resolver() + self.own_resolver = True + if hostname_mapping is not None: + self.resolver = OverrideResolver( + resolver=self.resolver, mapping=hostname_mapping + ) + self.tcp_client = TCPClient(resolver=self.resolver) + + def close(self) -> None: + super().close() + if self.own_resolver: + self.resolver.close() + self.tcp_client.close() + + def fetch_impl( + self, request: HTTPRequest, callback: Callable[[HTTPResponse], None] + ) -> None: + key = object() + self.queue.append((key, request, callback)) + assert request.connect_timeout is not None + assert request.request_timeout is not None + timeout_handle = None + if len(self.active) >= self.max_clients: + timeout = ( + min(request.connect_timeout, request.request_timeout) + or request.connect_timeout + or request.request_timeout + ) # min but skip zero + if timeout: + timeout_handle = self.io_loop.add_timeout( + self.io_loop.time() + timeout, + functools.partial(self._on_timeout, key, "in request queue"), + ) + self.waiting[key] = (request, callback, timeout_handle) + self._process_queue() + if self.queue: + gen_log.debug( + "max_clients limit reached, request queued. " + "%d active, %d queued requests." 
% (len(self.active), len(self.queue)) + ) + + def _process_queue(self) -> None: + while self.queue and len(self.active) < self.max_clients: + key, request, callback = self.queue.popleft() + if key not in self.waiting: + continue + self._remove_timeout(key) + self.active[key] = (request, callback) + release_callback = functools.partial(self._release_fetch, key) + self._handle_request(request, release_callback, callback) + + def _connection_class(self) -> type: + return _HTTPConnection + + def _handle_request( + self, + request: HTTPRequest, + release_callback: Callable[[], None], + final_callback: Callable[[HTTPResponse], None], + ) -> None: + self._connection_class()( + self, + request, + release_callback, + final_callback, + self.max_buffer_size, + self.tcp_client, + self.max_header_size, + self.max_body_size, + ) + + def _release_fetch(self, key: object) -> None: + del self.active[key] + self._process_queue() + + def _remove_timeout(self, key: object) -> None: + if key in self.waiting: + request, callback, timeout_handle = self.waiting[key] + if timeout_handle is not None: + self.io_loop.remove_timeout(timeout_handle) + del self.waiting[key] + + def _on_timeout(self, key: object, info: Optional[str] = None) -> None: + """Timeout callback of request. + + Construct a timeout HTTPResponse when a timeout occurs. + + :arg object key: A simple object to mark the request. + :info string key: More detailed timeout information. + """ + request, callback, timeout_handle = self.waiting[key] + self.queue.remove((key, request, callback)) + + error_message = "Timeout {0}".format(info) if info else "Timeout" + timeout_response = HTTPResponse( + request, + 599, + error=HTTPTimeoutError(error_message), + request_time=self.io_loop.time() - request.start_time, + ) + self.io_loop.add_callback(callback, timeout_response) + del self.waiting[key] + + +class _HTTPConnection(httputil.HTTPMessageDelegate): + _SUPPORTED_METHODS = set( + ["GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"] + ) + + def __init__( + self, + client: Optional[SimpleAsyncHTTPClient], + request: HTTPRequest, + release_callback: Callable[[], None], + final_callback: Callable[[HTTPResponse], None], + max_buffer_size: int, + tcp_client: TCPClient, + max_header_size: int, + max_body_size: int, + ) -> None: + self.io_loop = IOLoop.current() + self.start_time = self.io_loop.time() + self.start_wall_time = time.time() + self.client = client + self.request = request + self.release_callback = release_callback + self.final_callback = final_callback + self.max_buffer_size = max_buffer_size + self.tcp_client = tcp_client + self.max_header_size = max_header_size + self.max_body_size = max_body_size + self.code = None # type: Optional[int] + self.headers = None # type: Optional[httputil.HTTPHeaders] + self.chunks = [] # type: List[bytes] + self._decompressor = None + # Timeout handle returned by IOLoop.add_timeout + self._timeout = None # type: object + self._sockaddr = None + IOLoop.current().add_future( + gen.convert_yielded(self.run()), lambda f: f.result() + ) + + async def run(self) -> None: + try: + self.parsed = urllib.parse.urlsplit(_unicode(self.request.url)) + if self.parsed.scheme not in ("http", "https"): + raise ValueError("Unsupported url scheme: %s" % self.request.url) + # urlsplit results have hostname and port results, but they + # didn't support ipv6 literals until python 2.7. 
+ netloc = self.parsed.netloc + if "@" in netloc: + userpass, _, netloc = netloc.rpartition("@") + host, port = httputil.split_host_and_port(netloc) + if port is None: + port = 443 if self.parsed.scheme == "https" else 80 + if re.match(r"^\[.*\]$", host): + # raw ipv6 addresses in urls are enclosed in brackets + host = host[1:-1] + self.parsed_hostname = host # save final host for _on_connect + + if self.request.allow_ipv6 is False: + af = socket.AF_INET + else: + af = socket.AF_UNSPEC + + ssl_options = self._get_ssl_options(self.parsed.scheme) + + source_ip = None + if self.request.network_interface: + if is_valid_ip(self.request.network_interface): + source_ip = self.request.network_interface + else: + raise ValueError( + "Unrecognized IPv4 or IPv6 address for network_interface, got %r" + % (self.request.network_interface,) + ) + + if self.request.connect_timeout and self.request.request_timeout: + timeout = min( + self.request.connect_timeout, self.request.request_timeout + ) + elif self.request.connect_timeout: + timeout = self.request.connect_timeout + elif self.request.request_timeout: + timeout = self.request.request_timeout + else: + timeout = 0 + if timeout: + self._timeout = self.io_loop.add_timeout( + self.start_time + timeout, + functools.partial(self._on_timeout, "while connecting"), + ) + stream = await self.tcp_client.connect( + host, + port, + af=af, + ssl_options=ssl_options, + max_buffer_size=self.max_buffer_size, + source_ip=source_ip, + ) + + if self.final_callback is None: + # final_callback is cleared if we've hit our timeout. + stream.close() + return + self.stream = stream + self.stream.set_close_callback(self.on_connection_close) + self._remove_timeout() + if self.final_callback is None: + return + if self.request.request_timeout: + self._timeout = self.io_loop.add_timeout( + self.start_time + self.request.request_timeout, + functools.partial(self._on_timeout, "during request"), + ) + if ( + self.request.method not in self._SUPPORTED_METHODS + and not self.request.allow_nonstandard_methods + ): + raise KeyError("unknown method %s" % self.request.method) + for key in ( + "proxy_host", + "proxy_port", + "proxy_username", + "proxy_password", + "proxy_auth_mode", + ): + if getattr(self.request, key, None): + raise NotImplementedError("%s not supported" % key) + if "Connection" not in self.request.headers: + self.request.headers["Connection"] = "close" + if "Host" not in self.request.headers: + if "@" in self.parsed.netloc: + self.request.headers["Host"] = self.parsed.netloc.rpartition("@")[ + -1 + ] + else: + self.request.headers["Host"] = self.parsed.netloc + username, password = None, None + if self.parsed.username is not None: + username, password = self.parsed.username, self.parsed.password + elif self.request.auth_username is not None: + username = self.request.auth_username + password = self.request.auth_password or "" + if username is not None: + assert password is not None + if self.request.auth_mode not in (None, "basic"): + raise ValueError("unsupported auth_mode %s", self.request.auth_mode) + self.request.headers["Authorization"] = "Basic " + _unicode( + base64.b64encode( + httputil.encode_username_password(username, password) + ) + ) + if self.request.user_agent: + self.request.headers["User-Agent"] = self.request.user_agent + elif self.request.headers.get("User-Agent") is None: + self.request.headers["User-Agent"] = "Tornado/{}".format(version) + if not self.request.allow_nonstandard_methods: + # Some HTTP methods nearly always have bodies while others + # 
almost never do. Fail in this case unless the user has + # opted out of sanity checks with allow_nonstandard_methods. + body_expected = self.request.method in ("POST", "PATCH", "PUT") + body_present = ( + self.request.body is not None + or self.request.body_producer is not None + ) + if (body_expected and not body_present) or ( + body_present and not body_expected + ): + raise ValueError( + "Body must %sbe None for method %s (unless " + "allow_nonstandard_methods is true)" + % ("not " if body_expected else "", self.request.method) + ) + if self.request.expect_100_continue: + self.request.headers["Expect"] = "100-continue" + if self.request.body is not None: + # When body_producer is used the caller is responsible for + # setting Content-Length (or else chunked encoding will be used). + self.request.headers["Content-Length"] = str(len(self.request.body)) + if ( + self.request.method == "POST" + and "Content-Type" not in self.request.headers + ): + self.request.headers[ + "Content-Type" + ] = "application/x-www-form-urlencoded" + if self.request.decompress_response: + self.request.headers["Accept-Encoding"] = "gzip" + req_path = (self.parsed.path or "/") + ( + ("?" + self.parsed.query) if self.parsed.query else "" + ) + self.connection = self._create_connection(stream) + start_line = httputil.RequestStartLine(self.request.method, req_path, "") + self.connection.write_headers(start_line, self.request.headers) + if self.request.expect_100_continue: + await self.connection.read_response(self) + else: + await self._write_body(True) + except Exception: + if not self._handle_exception(*sys.exc_info()): + raise + + def _get_ssl_options( + self, scheme: str + ) -> Union[None, Dict[str, Any], ssl.SSLContext]: + if scheme == "https": + if self.request.ssl_options is not None: + return self.request.ssl_options + # If we are using the defaults, don't construct a + # new SSLContext. + if ( + self.request.validate_cert + and self.request.ca_certs is None + and self.request.client_cert is None + and self.request.client_key is None + ): + return _client_ssl_defaults + ssl_ctx = ssl.create_default_context( + ssl.Purpose.SERVER_AUTH, cafile=self.request.ca_certs + ) + if not self.request.validate_cert: + ssl_ctx.check_hostname = False + ssl_ctx.verify_mode = ssl.CERT_NONE + if self.request.client_cert is not None: + ssl_ctx.load_cert_chain( + self.request.client_cert, self.request.client_key + ) + if hasattr(ssl, "OP_NO_COMPRESSION"): + # See netutil.ssl_options_to_context + ssl_ctx.options |= ssl.OP_NO_COMPRESSION + return ssl_ctx + return None + + def _on_timeout(self, info: Optional[str] = None) -> None: + """Timeout callback of _HTTPConnection instance. + + Raise a `HTTPTimeoutError` when a timeout occurs. + + :info string key: More detailed timeout information. 
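From the caller's side these timeouts surface as `HTTPTimeoutError`; a usage sketch (the timeout values are illustrative):

.. code-block:: python

    from tornado.httpclient import AsyncHTTPClient, HTTPRequest
    from tornado.simple_httpclient import HTTPTimeoutError

    async def fetch_with_timeouts(url):
        request = HTTPRequest(
            url,
            connect_timeout=5.0,   # seconds allowed for connecting
            request_timeout=20.0,  # seconds allowed for the whole request
        )
        try:
            return await AsyncHTTPClient().fetch(request)
        except HTTPTimeoutError:
            # Simulates a 599 response code, as noted above.
            return None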
+ """ + self._timeout = None + error_message = "Timeout {0}".format(info) if info else "Timeout" + if self.final_callback is not None: + self._handle_exception( + HTTPTimeoutError, HTTPTimeoutError(error_message), None + ) + + def _remove_timeout(self) -> None: + if self._timeout is not None: + self.io_loop.remove_timeout(self._timeout) + self._timeout = None + + def _create_connection(self, stream: IOStream) -> HTTP1Connection: + stream.set_nodelay(True) + connection = HTTP1Connection( + stream, + True, + HTTP1ConnectionParameters( + no_keep_alive=True, + max_header_size=self.max_header_size, + max_body_size=self.max_body_size, + decompress=bool(self.request.decompress_response), + ), + self._sockaddr, + ) + return connection + + async def _write_body(self, start_read: bool) -> None: + if self.request.body is not None: + self.connection.write(self.request.body) + elif self.request.body_producer is not None: + fut = self.request.body_producer(self.connection.write) + if fut is not None: + await fut + self.connection.finish() + if start_read: + try: + await self.connection.read_response(self) + except StreamClosedError: + if not self._handle_exception(*sys.exc_info()): + raise + + def _release(self) -> None: + if self.release_callback is not None: + release_callback = self.release_callback + self.release_callback = None # type: ignore + release_callback() + + def _run_callback(self, response: HTTPResponse) -> None: + self._release() + if self.final_callback is not None: + final_callback = self.final_callback + self.final_callback = None # type: ignore + self.io_loop.add_callback(final_callback, response) + + def _handle_exception( + self, + typ: "Optional[Type[BaseException]]", + value: Optional[BaseException], + tb: Optional[TracebackType], + ) -> bool: + if self.final_callback is not None: + self._remove_timeout() + if isinstance(value, StreamClosedError): + if value.real_error is None: + value = HTTPStreamClosedError("Stream closed") + else: + value = value.real_error + self._run_callback( + HTTPResponse( + self.request, + 599, + error=value, + request_time=self.io_loop.time() - self.start_time, + start_time=self.start_wall_time, + ) + ) + + if hasattr(self, "stream"): + # TODO: this may cause a StreamClosedError to be raised + # by the connection's Future. Should we cancel the + # connection more gracefully? + self.stream.close() + return True + else: + # If our callback has already been called, we are probably + # catching an exception that is not caused by us but rather + # some child of our callback. Rather than drop it on the floor, + # pass it along, unless it's just the stream being closed. + return isinstance(value, StreamClosedError) + + def on_connection_close(self) -> None: + if self.final_callback is not None: + message = "Connection closed" + if self.stream.error: + raise self.stream.error + try: + raise HTTPStreamClosedError(message) + except HTTPStreamClosedError: + self._handle_exception(*sys.exc_info()) + + async def headers_received( + self, + first_line: Union[httputil.ResponseStartLine, httputil.RequestStartLine], + headers: httputil.HTTPHeaders, + ) -> None: + assert isinstance(first_line, httputil.ResponseStartLine) + if self.request.expect_100_continue and first_line.code == 100: + await self._write_body(False) + return + self.code = first_line.code + self.reason = first_line.reason + self.headers = headers + + if self._should_follow_redirect(): + return + + if self.request.header_callback is not None: + # Reassemble the start line. 
+ self.request.header_callback("%s %s %s\r\n" % first_line) + for k, v in self.headers.get_all(): + self.request.header_callback("%s: %s\r\n" % (k, v)) + self.request.header_callback("\r\n") + + def _should_follow_redirect(self) -> bool: + if self.request.follow_redirects: + assert self.request.max_redirects is not None + return ( + self.code in (301, 302, 303, 307, 308) + and self.request.max_redirects > 0 + and self.headers is not None + and self.headers.get("Location") is not None + ) + return False + + def finish(self) -> None: + assert self.code is not None + data = b"".join(self.chunks) + self._remove_timeout() + original_request = getattr(self.request, "original_request", self.request) + if self._should_follow_redirect(): + assert isinstance(self.request, _RequestProxy) + assert self.headers is not None + new_request = copy.copy(self.request.request) + new_request.url = urllib.parse.urljoin( + self.request.url, self.headers["Location"] + ) + assert self.request.max_redirects is not None + new_request.max_redirects = self.request.max_redirects - 1 + del new_request.headers["Host"] + # https://tools.ietf.org/html/rfc7231#section-6.4 + # + # The original HTTP spec said that after a 301 or 302 + # redirect, the request method should be preserved. + # However, browsers implemented this by changing the + # method to GET, and the behavior stuck. 303 redirects + # always specified this POST-to-GET behavior, arguably + # for *all* methods, but libcurl < 7.70 only does this + # for POST, while libcurl >= 7.70 does it for other methods. + if (self.code == 303 and self.request.method != "HEAD") or ( + self.code in (301, 302) and self.request.method == "POST" + ): + new_request.method = "GET" + new_request.body = None # type: ignore + for h in [ + "Content-Length", + "Content-Type", + "Content-Encoding", + "Transfer-Encoding", + ]: + try: + del self.request.headers[h] + except KeyError: + pass + new_request.original_request = original_request # type: ignore + final_callback = self.final_callback + self.final_callback = None # type: ignore + self._release() + assert self.client is not None + fut = self.client.fetch(new_request, raise_error=False) + fut.add_done_callback(lambda f: final_callback(f.result())) + self._on_end_request() + return + if self.request.streaming_callback: + buffer = BytesIO() + else: + buffer = BytesIO(data) # TODO: don't require one big string? + response = HTTPResponse( + original_request, + self.code, + reason=getattr(self, "reason", None), + headers=self.headers, + request_time=self.io_loop.time() - self.start_time, + start_time=self.start_wall_time, + buffer=buffer, + effective_url=self.request.url, + ) + self._run_callback(response) + self._on_end_request() + + def _on_end_request(self) -> None: + self.stream.close() + + def data_received(self, chunk: bytes) -> None: + if self._should_follow_redirect(): + # We're going to follow a redirect so just discard the body. 
+ return + if self.request.streaming_callback is not None: + self.request.streaming_callback(chunk) + else: + self.chunks.append(chunk) + + +if __name__ == "__main__": + AsyncHTTPClient.configure(SimpleAsyncHTTPClient) + main() diff --git a/venv/lib/python3.9/site-packages/tornado/speedups.abi3.so b/venv/lib/python3.9/site-packages/tornado/speedups.abi3.so Binary files differnew file mode 100755 index 00000000..8ef95904 --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/speedups.abi3.so diff --git a/venv/lib/python3.9/site-packages/tornado/tcpclient.py b/venv/lib/python3.9/site-packages/tornado/tcpclient.py new file mode 100644 index 00000000..0a829062 --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/tcpclient.py @@ -0,0 +1,332 @@ +# +# Copyright 2014 Facebook +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""A non-blocking TCP connection factory. +""" + +import functools +import socket +import numbers +import datetime +import ssl +import typing + +from tornado.concurrent import Future, future_add_done_callback +from tornado.ioloop import IOLoop +from tornado.iostream import IOStream +from tornado import gen +from tornado.netutil import Resolver +from tornado.gen import TimeoutError + +from typing import Any, Union, Dict, Tuple, List, Callable, Iterator, Optional + +if typing.TYPE_CHECKING: + from typing import Set # noqa(F401) + +_INITIAL_CONNECT_TIMEOUT = 0.3 + + +class _Connector(object): + """A stateless implementation of the "Happy Eyeballs" algorithm. + + "Happy Eyeballs" is documented in RFC6555 as the recommended practice + for when both IPv4 and IPv6 addresses are available. + + In this implementation, we partition the addresses by family, and + make the first connection attempt to whichever address was + returned first by ``getaddrinfo``. If that connection fails or + times out, we begin a connection in parallel to the first address + of the other family. If there are additional failures we retry + with other addresses, keeping one connection attempt per family + in flight at a time. + + http://tools.ietf.org/html/rfc6555 + + """ + + def __init__( + self, + addrinfo: List[Tuple], + connect: Callable[ + [socket.AddressFamily, Tuple], Tuple[IOStream, "Future[IOStream]"] + ], + ) -> None: + self.io_loop = IOLoop.current() + self.connect = connect + + self.future = ( + Future() + ) # type: Future[Tuple[socket.AddressFamily, Any, IOStream]] + self.timeout = None # type: Optional[object] + self.connect_timeout = None # type: Optional[object] + self.last_error = None # type: Optional[Exception] + self.remaining = len(addrinfo) + self.primary_addrs, self.secondary_addrs = self.split(addrinfo) + self.streams = set() # type: Set[IOStream] + + @staticmethod + def split( + addrinfo: List[Tuple], + ) -> Tuple[ + List[Tuple[socket.AddressFamily, Tuple]], + List[Tuple[socket.AddressFamily, Tuple]], + ]: + """Partition the ``addrinfo`` list by address family. + + Returns two lists. 
The first list contains the first entry from + ``addrinfo`` and all others with the same family, and the + second list contains all other addresses (normally one list will + be AF_INET and the other AF_INET6, although non-standard resolvers + may return additional families). + """ + primary = [] + secondary = [] + primary_af = addrinfo[0][0] + for af, addr in addrinfo: + if af == primary_af: + primary.append((af, addr)) + else: + secondary.append((af, addr)) + return primary, secondary + + def start( + self, + timeout: float = _INITIAL_CONNECT_TIMEOUT, + connect_timeout: Optional[Union[float, datetime.timedelta]] = None, + ) -> "Future[Tuple[socket.AddressFamily, Any, IOStream]]": + self.try_connect(iter(self.primary_addrs)) + self.set_timeout(timeout) + if connect_timeout is not None: + self.set_connect_timeout(connect_timeout) + return self.future + + def try_connect(self, addrs: Iterator[Tuple[socket.AddressFamily, Tuple]]) -> None: + try: + af, addr = next(addrs) + except StopIteration: + # We've reached the end of our queue, but the other queue + # might still be working. Send a final error on the future + # only when both queues are finished. + if self.remaining == 0 and not self.future.done(): + self.future.set_exception( + self.last_error or IOError("connection failed") + ) + return + stream, future = self.connect(af, addr) + self.streams.add(stream) + future_add_done_callback( + future, functools.partial(self.on_connect_done, addrs, af, addr) + ) + + def on_connect_done( + self, + addrs: Iterator[Tuple[socket.AddressFamily, Tuple]], + af: socket.AddressFamily, + addr: Tuple, + future: "Future[IOStream]", + ) -> None: + self.remaining -= 1 + try: + stream = future.result() + except Exception as e: + if self.future.done(): + return + # Error: try again (but remember what happened so we have an + # error to raise in the end) + self.last_error = e + self.try_connect(addrs) + if self.timeout is not None: + # If the first attempt failed, don't wait for the + # timeout to try an address from the secondary queue. + self.io_loop.remove_timeout(self.timeout) + self.on_timeout() + return + self.clear_timeouts() + if self.future.done(): + # This is a late arrival; just drop it. + stream.close() + else: + self.streams.discard(stream) + self.future.set_result((af, addr, stream)) + self.close_streams() + + def set_timeout(self, timeout: float) -> None: + self.timeout = self.io_loop.add_timeout( + self.io_loop.time() + timeout, self.on_timeout + ) + + def on_timeout(self) -> None: + self.timeout = None + if not self.future.done(): + self.try_connect(iter(self.secondary_addrs)) + + def clear_timeout(self) -> None: + if self.timeout is not None: + self.io_loop.remove_timeout(self.timeout) + + def set_connect_timeout( + self, connect_timeout: Union[float, datetime.timedelta] + ) -> None: + self.connect_timeout = self.io_loop.add_timeout( + connect_timeout, self.on_connect_timeout + ) + + def on_connect_timeout(self) -> None: + if not self.future.done(): + self.future.set_exception(TimeoutError()) + self.close_streams() + + def clear_timeouts(self) -> None: + if self.timeout is not None: + self.io_loop.remove_timeout(self.timeout) + if self.connect_timeout is not None: + self.io_loop.remove_timeout(self.connect_timeout) + + def close_streams(self) -> None: + for stream in self.streams: + stream.close() + + +class TCPClient(object): + """A non-blocking TCP connection factory. + + .. versionchanged:: 5.0 + The ``io_loop`` argument (deprecated since version 4.1) has been removed. 
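A typical use is a short-lived client connection; a sketch (host, port, and payload are illustrative):

.. code-block:: python

    import asyncio
    from tornado.tcpclient import TCPClient

    async def send_line(host, port, line):
        # A bare float timeout is interpreted as seconds from now.
        stream = await TCPClient().connect(host, port, timeout=10.0)
        try:
            await stream.write(line + b"\n")
            return await stream.read_until(b"\n")
        finally:
            stream.close()

    # asyncio.run(send_line("127.0.0.1", 8888, b"hello"))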
+ """ + + def __init__(self, resolver: Optional[Resolver] = None) -> None: + if resolver is not None: + self.resolver = resolver + self._own_resolver = False + else: + self.resolver = Resolver() + self._own_resolver = True + + def close(self) -> None: + if self._own_resolver: + self.resolver.close() + + async def connect( + self, + host: str, + port: int, + af: socket.AddressFamily = socket.AF_UNSPEC, + ssl_options: Optional[Union[Dict[str, Any], ssl.SSLContext]] = None, + max_buffer_size: Optional[int] = None, + source_ip: Optional[str] = None, + source_port: Optional[int] = None, + timeout: Optional[Union[float, datetime.timedelta]] = None, + ) -> IOStream: + """Connect to the given host and port. + + Asynchronously returns an `.IOStream` (or `.SSLIOStream` if + ``ssl_options`` is not None). + + Using the ``source_ip`` kwarg, one can specify the source + IP address to use when establishing the connection. + In case the user needs to resolve and + use a specific interface, it has to be handled outside + of Tornado as this depends very much on the platform. + + Raises `TimeoutError` if the input future does not complete before + ``timeout``, which may be specified in any form allowed by + `.IOLoop.add_timeout` (i.e. a `datetime.timedelta` or an absolute time + relative to `.IOLoop.time`) + + Similarly, when the user requires a certain source port, it can + be specified using the ``source_port`` arg. + + .. versionchanged:: 4.5 + Added the ``source_ip`` and ``source_port`` arguments. + + .. versionchanged:: 5.0 + Added the ``timeout`` argument. + """ + if timeout is not None: + if isinstance(timeout, numbers.Real): + timeout = IOLoop.current().time() + timeout + elif isinstance(timeout, datetime.timedelta): + timeout = IOLoop.current().time() + timeout.total_seconds() + else: + raise TypeError("Unsupported timeout %r" % timeout) + if timeout is not None: + addrinfo = await gen.with_timeout( + timeout, self.resolver.resolve(host, port, af) + ) + else: + addrinfo = await self.resolver.resolve(host, port, af) + connector = _Connector( + addrinfo, + functools.partial( + self._create_stream, + max_buffer_size, + source_ip=source_ip, + source_port=source_port, + ), + ) + af, addr, stream = await connector.start(connect_timeout=timeout) + # TODO: For better performance we could cache the (af, addr) + # information here and re-use it on subsequent connections to + # the same host. (http://tools.ietf.org/html/rfc6555#section-4.2) + if ssl_options is not None: + if timeout is not None: + stream = await gen.with_timeout( + timeout, + stream.start_tls( + False, ssl_options=ssl_options, server_hostname=host + ), + ) + else: + stream = await stream.start_tls( + False, ssl_options=ssl_options, server_hostname=host + ) + return stream + + def _create_stream( + self, + max_buffer_size: int, + af: socket.AddressFamily, + addr: Tuple, + source_ip: Optional[str] = None, + source_port: Optional[int] = None, + ) -> Tuple[IOStream, "Future[IOStream]"]: + # Always connect in plaintext; we'll convert to ssl if necessary + # after one connection has completed. + source_port_bind = source_port if isinstance(source_port, int) else 0 + source_ip_bind = source_ip + if source_port_bind and not source_ip: + # User required a specific port, but did not specify + # a certain source IP, will bind to the default loopback. 
+ source_ip_bind = "::1" if af == socket.AF_INET6 else "127.0.0.1" + # Trying to use the same address family as the requested af socket: + # - 127.0.0.1 for IPv4 + # - ::1 for IPv6 + socket_obj = socket.socket(af) + if source_port_bind or source_ip_bind: + # If the user requires binding also to a specific IP/port. + try: + socket_obj.bind((source_ip_bind, source_port_bind)) + except socket.error: + socket_obj.close() + # Fail loudly if unable to use the IP/port. + raise + try: + stream = IOStream(socket_obj, max_buffer_size=max_buffer_size) + except socket.error as e: + fu = Future() # type: Future[IOStream] + fu.set_exception(e) + return stream, fu + else: + return stream, stream.connect(addr) diff --git a/venv/lib/python3.9/site-packages/tornado/tcpserver.py b/venv/lib/python3.9/site-packages/tornado/tcpserver.py new file mode 100644 index 00000000..deab8f2a --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/tcpserver.py @@ -0,0 +1,390 @@ +# +# Copyright 2011 Facebook +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""A non-blocking, single-threaded TCP server.""" + +import errno +import os +import socket +import ssl + +from tornado import gen +from tornado.log import app_log +from tornado.ioloop import IOLoop +from tornado.iostream import IOStream, SSLIOStream +from tornado.netutil import ( + bind_sockets, + add_accept_handler, + ssl_wrap_socket, + _DEFAULT_BACKLOG, +) +from tornado import process +from tornado.util import errno_from_exception + +import typing +from typing import Union, Dict, Any, Iterable, Optional, Awaitable + +if typing.TYPE_CHECKING: + from typing import Callable, List # noqa: F401 + + +class TCPServer(object): + r"""A non-blocking, single-threaded TCP server. + + To use `TCPServer`, define a subclass which overrides the `handle_stream` + method. For example, a simple echo server could be defined like this:: + + from tornado.tcpserver import TCPServer + from tornado.iostream import StreamClosedError + + class EchoServer(TCPServer): + async def handle_stream(self, stream, address): + while True: + try: + data = await stream.read_until(b"\n") await + stream.write(data) + except StreamClosedError: + break + + To make this server serve SSL traffic, send the ``ssl_options`` keyword + argument with an `ssl.SSLContext` object. For compatibility with older + versions of Python ``ssl_options`` may also be a dictionary of keyword + arguments for the `ssl.wrap_socket` method.:: + + ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + ssl_ctx.load_cert_chain(os.path.join(data_dir, "mydomain.crt"), + os.path.join(data_dir, "mydomain.key")) + TCPServer(ssl_options=ssl_ctx) + + `TCPServer` initialization follows one of three patterns: + + 1. 
`listen`: single-process:: + + async def main(): + server = TCPServer() + server.listen(8888) + await asyncio.Event().wait() + + asyncio.run(main()) + + While this example does not create multiple processes on its own, when + the ``reuse_port=True`` argument is passed to ``listen()`` you can run + the program multiple times to create a multi-process service. + + 2. `add_sockets`: multi-process:: + + sockets = bind_sockets(8888) + tornado.process.fork_processes(0) + async def post_fork_main(): + server = TCPServer() + server.add_sockets(sockets) + await asyncio.Event().wait() + asyncio.run(post_fork_main()) + + The `add_sockets` interface is more complicated, but it can be used with + `tornado.process.fork_processes` to run a multi-process service with all + worker processes forked from a single parent. `add_sockets` can also be + used in single-process servers if you want to create your listening + sockets in some way other than `~tornado.netutil.bind_sockets`. + + Note that when using this pattern, nothing that touches the event loop + can be run before ``fork_processes``. + + 3. `bind`/`start`: simple **deprecated** multi-process:: + + server = TCPServer() + server.bind(8888) + server.start(0) # Forks multiple sub-processes + IOLoop.current().start() + + This pattern is deprecated because it requires interfaces in the + `asyncio` module that have been deprecated since Python 3.10. Support for + creating multiple processes in the ``start`` method will be removed in a + future version of Tornado. + + .. versionadded:: 3.1 + The ``max_buffer_size`` argument. + + .. versionchanged:: 5.0 + The ``io_loop`` argument has been removed. + """ + + def __init__( + self, + ssl_options: Optional[Union[Dict[str, Any], ssl.SSLContext]] = None, + max_buffer_size: Optional[int] = None, + read_chunk_size: Optional[int] = None, + ) -> None: + self.ssl_options = ssl_options + self._sockets = {} # type: Dict[int, socket.socket] + self._handlers = {} # type: Dict[int, Callable[[], None]] + self._pending_sockets = [] # type: List[socket.socket] + self._started = False + self._stopped = False + self.max_buffer_size = max_buffer_size + self.read_chunk_size = read_chunk_size + + # Verify the SSL options. Otherwise we don't get errors until clients + # connect. This doesn't verify that the keys are legitimate, but + # the SSL module doesn't do that until there is a connected socket + # which seems like too much work + if self.ssl_options is not None and isinstance(self.ssl_options, dict): + # Only certfile is required: it can contain both keys + if "certfile" not in self.ssl_options: + raise KeyError('missing key "certfile" in ssl_options') + + if not os.path.exists(self.ssl_options["certfile"]): + raise ValueError( + 'certfile "%s" does not exist' % self.ssl_options["certfile"] + ) + if "keyfile" in self.ssl_options and not os.path.exists( + self.ssl_options["keyfile"] + ): + raise ValueError( + 'keyfile "%s" does not exist' % self.ssl_options["keyfile"] + ) + + def listen( + self, + port: int, + address: Optional[str] = None, + family: socket.AddressFamily = socket.AF_UNSPEC, + backlog: int = _DEFAULT_BACKLOG, + flags: Optional[int] = None, + reuse_port: bool = False, + ) -> None: + """Starts accepting connections on the given port. + + This method may be called more than once to listen on multiple ports. + `listen` takes effect immediately; it is not necessary to call + `TCPServer.start` afterwards. It is, however, necessary to start the + event loop if it is not already running. 
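A server may therefore accept connections on several ports at once; a sketch (the ports are illustrative):

.. code-block:: python

    import asyncio
    import socket
    from tornado.iostream import StreamClosedError
    from tornado.tcpserver import TCPServer

    class DiscardServer(TCPServer):
        async def handle_stream(self, stream, address):
            try:
                while True:
                    # Read and discard data until the peer disconnects.
                    await stream.read_bytes(4096, partial=True)
            except StreamClosedError:
                pass

    async def main():
        server = DiscardServer()
        server.listen(8888)  # all interfaces
        server.listen(8889, address="127.0.0.1", family=socket.AF_INET)
        await asyncio.Event().wait()

    # asyncio.run(main())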
+ + All arguments have the same meaning as in + `tornado.netutil.bind_sockets`. + + .. versionchanged:: 6.2 + + Added ``family``, ``backlog``, ``flags``, and ``reuse_port`` + arguments to match `tornado.netutil.bind_sockets`. + """ + sockets = bind_sockets( + port, + address=address, + family=family, + backlog=backlog, + flags=flags, + reuse_port=reuse_port, + ) + self.add_sockets(sockets) + + def add_sockets(self, sockets: Iterable[socket.socket]) -> None: + """Makes this server start accepting connections on the given sockets. + + The ``sockets`` parameter is a list of socket objects such as + those returned by `~tornado.netutil.bind_sockets`. + `add_sockets` is typically used in combination with that + method and `tornado.process.fork_processes` to provide greater + control over the initialization of a multi-process server. + """ + for sock in sockets: + self._sockets[sock.fileno()] = sock + self._handlers[sock.fileno()] = add_accept_handler( + sock, self._handle_connection + ) + + def add_socket(self, socket: socket.socket) -> None: + """Singular version of `add_sockets`. Takes a single socket object.""" + self.add_sockets([socket]) + + def bind( + self, + port: int, + address: Optional[str] = None, + family: socket.AddressFamily = socket.AF_UNSPEC, + backlog: int = _DEFAULT_BACKLOG, + flags: Optional[int] = None, + reuse_port: bool = False, + ) -> None: + """Binds this server to the given port on the given address. + + To start the server, call `start`. If you want to run this server in a + single process, you can call `listen` as a shortcut to the sequence of + `bind` and `start` calls. + + Address may be either an IP address or hostname. If it's a hostname, + the server will listen on all IP addresses associated with the name. + Address may be an empty string or None to listen on all available + interfaces. Family may be set to either `socket.AF_INET` or + `socket.AF_INET6` to restrict to IPv4 or IPv6 addresses, otherwise both + will be used if available. + + The ``backlog`` argument has the same meaning as for `socket.listen + <socket.socket.listen>`. The ``reuse_port`` argument has the same + meaning as for `.bind_sockets`. + + This method may be called multiple times prior to `start` to listen on + multiple ports or interfaces. + + .. versionchanged:: 4.4 + Added the ``reuse_port`` argument. + + .. versionchanged:: 6.2 + Added the ``flags`` argument to match `.bind_sockets`. + + .. deprecated:: 6.2 + Use either ``listen()`` or ``add_sockets()`` instead of ``bind()`` + and ``start()``. + """ + sockets = bind_sockets( + port, + address=address, + family=family, + backlog=backlog, + flags=flags, + reuse_port=reuse_port, + ) + if self._started: + self.add_sockets(sockets) + else: + self._pending_sockets.extend(sockets) + + def start( + self, num_processes: Optional[int] = 1, max_restarts: Optional[int] = None + ) -> None: + """Starts this server in the `.IOLoop`. + + By default, we run the server in this process and do not fork any + additional child process. + + If num_processes is ``None`` or <= 0, we detect the number of cores + available on this machine and fork that number of child + processes. If num_processes is given and > 1, we fork that + specific number of sub-processes. + + Since we use processes and not threads, there is no shared memory + between any server code. + + Note that multiple processes are not compatible with the autoreload + module (or the ``autoreload=True`` option to `tornado.web.Application` + which defaults to True when ``debug=True``). 
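Because ``bind`` may be called several times before ``start``, the deprecated pattern can serve multiple ports or interfaces from one fork; a sketch (ports are illustrative; ``EchoServer`` refers to the subclass from the class docstring above):

.. code-block:: python

    from tornado.ioloop import IOLoop

    server = EchoServer()
    server.bind(8888)
    server.bind(8889, address="127.0.0.1")
    server.start(0)  # forks one sub-process per CPU
    IOLoop.current().start()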
+ When using multiple processes, no IOLoops can be created or + referenced until after the call to ``TCPServer.start(n)``. + + Values of ``num_processes`` other than 1 are not supported on Windows. + + The ``max_restarts`` argument is passed to `.fork_processes`. + + .. versionchanged:: 6.0 + + Added ``max_restarts`` argument. + + .. deprecated:: 6.2 + Use either ``listen()`` or ``add_sockets()`` instead of ``bind()`` + and ``start()``. + """ + assert not self._started + self._started = True + if num_processes != 1: + process.fork_processes(num_processes, max_restarts) + sockets = self._pending_sockets + self._pending_sockets = [] + self.add_sockets(sockets) + + def stop(self) -> None: + """Stops listening for new connections. + + Requests currently in progress may still continue after the + server is stopped. + """ + if self._stopped: + return + self._stopped = True + for fd, sock in self._sockets.items(): + assert sock.fileno() == fd + # Unregister socket from IOLoop + self._handlers.pop(fd)() + sock.close() + + def handle_stream( + self, stream: IOStream, address: tuple + ) -> Optional[Awaitable[None]]: + """Override to handle a new `.IOStream` from an incoming connection. + + This method may be a coroutine; if so any exceptions it raises + asynchronously will be logged. Accepting of incoming connections + will not be blocked by this coroutine. + + If this `TCPServer` is configured for SSL, ``handle_stream`` + may be called before the SSL handshake has completed. Use + `.SSLIOStream.wait_for_handshake` if you need to verify the client's + certificate or use NPN/ALPN. + + .. versionchanged:: 4.2 + Added the option for this method to be a coroutine. + """ + raise NotImplementedError() + + def _handle_connection(self, connection: socket.socket, address: Any) -> None: + if self.ssl_options is not None: + assert ssl, "Python 2.6+ and OpenSSL required for SSL" + try: + connection = ssl_wrap_socket( + connection, + self.ssl_options, + server_side=True, + do_handshake_on_connect=False, + ) + except ssl.SSLError as err: + if err.args[0] == ssl.SSL_ERROR_EOF: + return connection.close() + else: + raise + except socket.error as err: + # If the connection is closed immediately after it is created + # (as in a port scan), we can get one of several errors. + # wrap_socket makes an internal call to getpeername, + # which may return either EINVAL (Mac OS X) or ENOTCONN + # (Linux). If it returns ENOTCONN, this error is + # silently swallowed by the ssl module, so we need to + # catch another error later on (AttributeError in + # SSLIOStream._do_ssl_handshake). + # To test this behavior, try nmap with the -sT flag. 
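+            # (-sT performs a full TCP connect followed by an immediate
+            # close, which exercises exactly this failure path.)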
+ # https://github.com/tornadoweb/tornado/pull/750 + if errno_from_exception(err) in (errno.ECONNABORTED, errno.EINVAL): + return connection.close() + else: + raise + try: + if self.ssl_options is not None: + stream = SSLIOStream( + connection, + max_buffer_size=self.max_buffer_size, + read_chunk_size=self.read_chunk_size, + ) # type: IOStream + else: + stream = IOStream( + connection, + max_buffer_size=self.max_buffer_size, + read_chunk_size=self.read_chunk_size, + ) + + future = self.handle_stream(stream, address) + if future is not None: + IOLoop.current().add_future( + gen.convert_yielded(future), lambda f: f.result() + ) + except Exception: + app_log.error("Error in connection callback", exc_info=True) diff --git a/venv/lib/python3.9/site-packages/tornado/template.py b/venv/lib/python3.9/site-packages/tornado/template.py new file mode 100644 index 00000000..d53e977c --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/template.py @@ -0,0 +1,1047 @@ +# +# Copyright 2009 Facebook +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""A simple template system that compiles templates to Python code. + +Basic usage looks like:: + + t = template.Template("<html>{{ myvalue }}</html>") + print(t.generate(myvalue="XXX")) + +`Loader` is a class that loads templates from a root directory and caches +the compiled templates:: + + loader = template.Loader("/home/btaylor") + print(loader.load("test.html").generate(myvalue="XXX")) + +We compile all templates to raw Python. Error-reporting is currently... uh, +interesting. Syntax for the templates:: + + ### base.html + <html> + <head> + <title>{% block title %}Default title{% end %}</title> + </head> + <body> + <ul> + {% for student in students %} + {% block student %} + <li>{{ escape(student.name) }}</li> + {% end %} + {% end %} + </ul> + </body> + </html> + + ### bold.html + {% extends "base.html" %} + + {% block title %}A bolder title{% end %} + + {% block student %} + <li><span style="bold">{{ escape(student.name) }}</span></li> + {% end %} + +Unlike most other template systems, we do not put any restrictions on the +expressions you can include in your statements. ``if`` and ``for`` blocks get +translated exactly into Python, so you can do complex expressions like:: + + {% for student in [p for p in people if p.student and p.age > 23] %} + <li>{{ escape(student.name) }}</li> + {% end %} + +Translating directly to Python means you can apply functions to expressions +easily, like the ``escape()`` function in the examples above. You can pass +functions in to your template just like any other variable +(In a `.RequestHandler`, override `.RequestHandler.get_template_namespace`):: + + ### Python code + def add(x, y): + return x + y + template.execute(add=add) + + ### The template + {{ add(1, 2) }} + +We provide the functions `escape() <.xhtml_escape>`, `.url_escape()`, +`.json_encode()`, and `.squeeze()` to all templates by default. 
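+
+For instance, the default helpers can be called like any other function
+inside an expression (a brief sketch)::
+
+    t = template.Template("{{ url_escape(q) }} {{ squeeze(s) }}")
+    print(t.generate(q="a b&c", s="too   many   spaces"))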
+ +Typical applications do not create `Template` or `Loader` instances by +hand, but instead use the `~.RequestHandler.render` and +`~.RequestHandler.render_string` methods of +`tornado.web.RequestHandler`, which load templates automatically based +on the ``template_path`` `.Application` setting. + +Variable names beginning with ``_tt_`` are reserved by the template +system and should not be used by application code. + +Syntax Reference +---------------- + +Template expressions are surrounded by double curly braces: ``{{ ... }}``. +The contents may be any python expression, which will be escaped according +to the current autoescape setting and inserted into the output. Other +template directives use ``{% %}``. + +To comment out a section so that it is omitted from the output, surround it +with ``{# ... #}``. + + +To include a literal ``{{``, ``{%``, or ``{#`` in the output, escape them as +``{{!``, ``{%!``, and ``{#!``, respectively. + + +``{% apply *function* %}...{% end %}`` + Applies a function to the output of all template code between ``apply`` + and ``end``:: + + {% apply linkify %}{{name}} said: {{message}}{% end %} + + Note that as an implementation detail apply blocks are implemented + as nested functions and thus may interact strangely with variables + set via ``{% set %}``, or the use of ``{% break %}`` or ``{% continue %}`` + within loops. + +``{% autoescape *function* %}`` + Sets the autoescape mode for the current file. This does not affect + other files, even those referenced by ``{% include %}``. Note that + autoescaping can also be configured globally, at the `.Application` + or `Loader`.:: + + {% autoescape xhtml_escape %} + {% autoescape None %} + +``{% block *name* %}...{% end %}`` + Indicates a named, replaceable block for use with ``{% extends %}``. + Blocks in the parent template will be replaced with the contents of + the same-named block in a child template.:: + + <!-- base.html --> + <title>{% block title %}Default title{% end %}</title> + + <!-- mypage.html --> + {% extends "base.html" %} + {% block title %}My page title{% end %} + +``{% comment ... %}`` + A comment which will be removed from the template output. Note that + there is no ``{% end %}`` tag; the comment goes from the word ``comment`` + to the closing ``%}`` tag. + +``{% extends *filename* %}`` + Inherit from another template. Templates that use ``extends`` should + contain one or more ``block`` tags to replace content from the parent + template. Anything in the child template not contained in a ``block`` + tag will be ignored. For an example, see the ``{% block %}`` tag. + +``{% for *var* in *expr* %}...{% end %}`` + Same as the python ``for`` statement. ``{% break %}`` and + ``{% continue %}`` may be used inside the loop. + +``{% from *x* import *y* %}`` + Same as the python ``import`` statement. + +``{% if *condition* %}...{% elif *condition* %}...{% else %}...{% end %}`` + Conditional statement - outputs the first section whose condition is + true. (The ``elif`` and ``else`` sections are optional) + +``{% import *module* %}`` + Same as the python ``import`` statement. + +``{% include *filename* %}`` + Includes another template file. The included file can see all the local + variables as if it were copied directly to the point of the ``include`` + directive (the ``{% autoescape %}`` directive is an exception). + Alternately, ``{% module Template(filename, **kwargs) %}`` may be used + to include another template with an isolated namespace. 
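+
+    For example (a brief sketch; ``header.html`` is a hypothetical template
+    in the same loader root)::
+
+        {% include "header.html" %}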
+ +``{% module *expr* %}`` + Renders a `~tornado.web.UIModule`. The output of the ``UIModule`` is + not escaped:: + + {% module Template("foo.html", arg=42) %} + + ``UIModules`` are a feature of the `tornado.web.RequestHandler` + class (and specifically its ``render`` method) and will not work + when the template system is used on its own in other contexts. + +``{% raw *expr* %}`` + Outputs the result of the given expression without autoescaping. + +``{% set *x* = *y* %}`` + Sets a local variable. + +``{% try %}...{% except %}...{% else %}...{% finally %}...{% end %}`` + Same as the python ``try`` statement. + +``{% while *condition* %}... {% end %}`` + Same as the python ``while`` statement. ``{% break %}`` and + ``{% continue %}`` may be used inside the loop. + +``{% whitespace *mode* %}`` + Sets the whitespace mode for the remainder of the current file + (or until the next ``{% whitespace %}`` directive). See + `filter_whitespace` for available options. New in Tornado 4.3. +""" + +import datetime +from io import StringIO +import linecache +import os.path +import posixpath +import re +import threading + +from tornado import escape +from tornado.log import app_log +from tornado.util import ObjectDict, exec_in, unicode_type + +from typing import Any, Union, Callable, List, Dict, Iterable, Optional, TextIO +import typing + +if typing.TYPE_CHECKING: + from typing import Tuple, ContextManager # noqa: F401 + +_DEFAULT_AUTOESCAPE = "xhtml_escape" + + +class _UnsetMarker: + pass + + +_UNSET = _UnsetMarker() + + +def filter_whitespace(mode: str, text: str) -> str: + """Transform whitespace in ``text`` according to ``mode``. + + Available modes are: + + * ``all``: Return all whitespace unmodified. + * ``single``: Collapse consecutive whitespace with a single whitespace + character, preserving newlines. + * ``oneline``: Collapse all runs of whitespace into a single space + character, removing all newlines in the process. + + .. versionadded:: 4.3 + """ + if mode == "all": + return text + elif mode == "single": + text = re.sub(r"([\t ]+)", " ", text) + text = re.sub(r"(\s*\n\s*)", "\n", text) + return text + elif mode == "oneline": + return re.sub(r"(\s+)", " ", text) + else: + raise Exception("invalid whitespace mode %s" % mode) + + +class Template(object): + """A compiled template. + + We compile into Python from the given template_string. You can generate + the template from variables with generate(). + """ + + # note that the constructor's signature is not extracted with + # autodoc because _UNSET looks like garbage. When changing + # this signature update website/sphinx/template.rst too. + def __init__( + self, + template_string: Union[str, bytes], + name: str = "<string>", + loader: Optional["BaseLoader"] = None, + compress_whitespace: Union[bool, _UnsetMarker] = _UNSET, + autoescape: Optional[Union[str, _UnsetMarker]] = _UNSET, + whitespace: Optional[str] = None, + ) -> None: + """Construct a Template. + + :arg str template_string: the contents of the template file. + :arg str name: the filename from which the template was loaded + (used for error message). + :arg tornado.template.BaseLoader loader: the `~tornado.template.BaseLoader` responsible + for this template, used to resolve ``{% include %}`` and ``{% extend %}`` directives. + :arg bool compress_whitespace: Deprecated since Tornado 4.3. + Equivalent to ``whitespace="single"`` if true and + ``whitespace="all"`` if false. 
+ :arg str autoescape: The name of a function in the template + namespace, or ``None`` to disable escaping by default. + :arg str whitespace: A string specifying treatment of whitespace; + see `filter_whitespace` for options. + + .. versionchanged:: 4.3 + Added ``whitespace`` parameter; deprecated ``compress_whitespace``. + """ + self.name = escape.native_str(name) + + if compress_whitespace is not _UNSET: + # Convert deprecated compress_whitespace (bool) to whitespace (str). + if whitespace is not None: + raise Exception("cannot set both whitespace and compress_whitespace") + whitespace = "single" if compress_whitespace else "all" + if whitespace is None: + if loader and loader.whitespace: + whitespace = loader.whitespace + else: + # Whitespace defaults by filename. + if name.endswith(".html") or name.endswith(".js"): + whitespace = "single" + else: + whitespace = "all" + # Validate the whitespace setting. + assert whitespace is not None + filter_whitespace(whitespace, "") + + if not isinstance(autoescape, _UnsetMarker): + self.autoescape = autoescape # type: Optional[str] + elif loader: + self.autoescape = loader.autoescape + else: + self.autoescape = _DEFAULT_AUTOESCAPE + + self.namespace = loader.namespace if loader else {} + reader = _TemplateReader(name, escape.native_str(template_string), whitespace) + self.file = _File(self, _parse(reader, self)) + self.code = self._generate_python(loader) + self.loader = loader + try: + # Under python2.5, the fake filename used here must match + # the module name used in __name__ below. + # The dont_inherit flag prevents template.py's future imports + # from being applied to the generated code. + self.compiled = compile( + escape.to_unicode(self.code), + "%s.generated.py" % self.name.replace(".", "_"), + "exec", + dont_inherit=True, + ) + except Exception: + formatted_code = _format_code(self.code).rstrip() + app_log.error("%s code:\n%s", self.name, formatted_code) + raise + + def generate(self, **kwargs: Any) -> bytes: + """Generate this template with the given arguments.""" + namespace = { + "escape": escape.xhtml_escape, + "xhtml_escape": escape.xhtml_escape, + "url_escape": escape.url_escape, + "json_encode": escape.json_encode, + "squeeze": escape.squeeze, + "linkify": escape.linkify, + "datetime": datetime, + "_tt_utf8": escape.utf8, # for internal use + "_tt_string_types": (unicode_type, bytes), + # __name__ and __loader__ allow the traceback mechanism to find + # the generated source code. + "__name__": self.name.replace(".", "_"), + "__loader__": ObjectDict(get_source=lambda name: self.code), + } + namespace.update(self.namespace) + namespace.update(kwargs) + exec_in(self.compiled, namespace) + execute = typing.cast(Callable[[], bytes], namespace["_tt_execute"]) + # Clear the traceback module's cache of source data now that + # we've generated a new template (mainly for this module's + # unittests, where different tests reuse the same name). 
+ linecache.clearcache() + return execute() + + def _generate_python(self, loader: Optional["BaseLoader"]) -> str: + buffer = StringIO() + try: + # named_blocks maps from names to _NamedBlock objects + named_blocks = {} # type: Dict[str, _NamedBlock] + ancestors = self._get_ancestors(loader) + ancestors.reverse() + for ancestor in ancestors: + ancestor.find_named_blocks(loader, named_blocks) + writer = _CodeWriter(buffer, named_blocks, loader, ancestors[0].template) + ancestors[0].generate(writer) + return buffer.getvalue() + finally: + buffer.close() + + def _get_ancestors(self, loader: Optional["BaseLoader"]) -> List["_File"]: + ancestors = [self.file] + for chunk in self.file.body.chunks: + if isinstance(chunk, _ExtendsBlock): + if not loader: + raise ParseError( + "{% extends %} block found, but no " "template loader" + ) + template = loader.load(chunk.name, self.name) + ancestors.extend(template._get_ancestors(loader)) + return ancestors + + +class BaseLoader(object): + """Base class for template loaders. + + You must use a template loader to use template constructs like + ``{% extends %}`` and ``{% include %}``. The loader caches all + templates after they are loaded the first time. + """ + + def __init__( + self, + autoescape: str = _DEFAULT_AUTOESCAPE, + namespace: Optional[Dict[str, Any]] = None, + whitespace: Optional[str] = None, + ) -> None: + """Construct a template loader. + + :arg str autoescape: The name of a function in the template + namespace, such as "xhtml_escape", or ``None`` to disable + autoescaping by default. + :arg dict namespace: A dictionary to be added to the default template + namespace, or ``None``. + :arg str whitespace: A string specifying default behavior for + whitespace in templates; see `filter_whitespace` for options. + Default is "single" for files ending in ".html" and ".js" and + "all" for other files. + + .. versionchanged:: 4.3 + Added ``whitespace`` parameter. + """ + self.autoescape = autoescape + self.namespace = namespace or {} + self.whitespace = whitespace + self.templates = {} # type: Dict[str, Template] + # self.lock protects self.templates. It's a reentrant lock + # because templates may load other templates via `include` or + # `extends`. Note that thanks to the GIL this code would be safe + # even without the lock, but could lead to wasted work as multiple + # threads tried to compile the same template simultaneously. 
+ self.lock = threading.RLock() + + def reset(self) -> None: + """Resets the cache of compiled templates.""" + with self.lock: + self.templates = {} + + def resolve_path(self, name: str, parent_path: Optional[str] = None) -> str: + """Converts a possibly-relative path to absolute (used internally).""" + raise NotImplementedError() + + def load(self, name: str, parent_path: Optional[str] = None) -> Template: + """Loads a template.""" + name = self.resolve_path(name, parent_path=parent_path) + with self.lock: + if name not in self.templates: + self.templates[name] = self._create_template(name) + return self.templates[name] + + def _create_template(self, name: str) -> Template: + raise NotImplementedError() + + +class Loader(BaseLoader): + """A template loader that loads from a single root directory.""" + + def __init__(self, root_directory: str, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.root = os.path.abspath(root_directory) + + def resolve_path(self, name: str, parent_path: Optional[str] = None) -> str: + if ( + parent_path + and not parent_path.startswith("<") + and not parent_path.startswith("/") + and not name.startswith("/") + ): + current_path = os.path.join(self.root, parent_path) + file_dir = os.path.dirname(os.path.abspath(current_path)) + relative_path = os.path.abspath(os.path.join(file_dir, name)) + if relative_path.startswith(self.root): + name = relative_path[len(self.root) + 1 :] + return name + + def _create_template(self, name: str) -> Template: + path = os.path.join(self.root, name) + with open(path, "rb") as f: + template = Template(f.read(), name=name, loader=self) + return template + + +class DictLoader(BaseLoader): + """A template loader that loads from a dictionary.""" + + def __init__(self, dict: Dict[str, str], **kwargs: Any) -> None: + super().__init__(**kwargs) + self.dict = dict + + def resolve_path(self, name: str, parent_path: Optional[str] = None) -> str: + if ( + parent_path + and not parent_path.startswith("<") + and not parent_path.startswith("/") + and not name.startswith("/") + ): + file_dir = posixpath.dirname(parent_path) + name = posixpath.normpath(posixpath.join(file_dir, name)) + return name + + def _create_template(self, name: str) -> Template: + return Template(self.dict[name], name=name, loader=self) + + +class _Node(object): + def each_child(self) -> Iterable["_Node"]: + return () + + def generate(self, writer: "_CodeWriter") -> None: + raise NotImplementedError() + + def find_named_blocks( + self, loader: Optional[BaseLoader], named_blocks: Dict[str, "_NamedBlock"] + ) -> None: + for child in self.each_child(): + child.find_named_blocks(loader, named_blocks) + + +class _File(_Node): + def __init__(self, template: Template, body: "_ChunkList") -> None: + self.template = template + self.body = body + self.line = 0 + + def generate(self, writer: "_CodeWriter") -> None: + writer.write_line("def _tt_execute():", self.line) + with writer.indent(): + writer.write_line("_tt_buffer = []", self.line) + writer.write_line("_tt_append = _tt_buffer.append", self.line) + self.body.generate(writer) + writer.write_line("return _tt_utf8('').join(_tt_buffer)", self.line) + + def each_child(self) -> Iterable["_Node"]: + return (self.body,) + + +class _ChunkList(_Node): + def __init__(self, chunks: List[_Node]) -> None: + self.chunks = chunks + + def generate(self, writer: "_CodeWriter") -> None: + for chunk in self.chunks: + chunk.generate(writer) + + def each_child(self) -> Iterable["_Node"]: + return self.chunks + + +class _NamedBlock(_Node): + 
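+    # A {% block name %} node. find_named_blocks registers it by name so that
+    # a same-named block in a child template replaces it during code
+    # generation.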
def __init__(self, name: str, body: _Node, template: Template, line: int) -> None: + self.name = name + self.body = body + self.template = template + self.line = line + + def each_child(self) -> Iterable["_Node"]: + return (self.body,) + + def generate(self, writer: "_CodeWriter") -> None: + block = writer.named_blocks[self.name] + with writer.include(block.template, self.line): + block.body.generate(writer) + + def find_named_blocks( + self, loader: Optional[BaseLoader], named_blocks: Dict[str, "_NamedBlock"] + ) -> None: + named_blocks[self.name] = self + _Node.find_named_blocks(self, loader, named_blocks) + + +class _ExtendsBlock(_Node): + def __init__(self, name: str) -> None: + self.name = name + + +class _IncludeBlock(_Node): + def __init__(self, name: str, reader: "_TemplateReader", line: int) -> None: + self.name = name + self.template_name = reader.name + self.line = line + + def find_named_blocks( + self, loader: Optional[BaseLoader], named_blocks: Dict[str, _NamedBlock] + ) -> None: + assert loader is not None + included = loader.load(self.name, self.template_name) + included.file.find_named_blocks(loader, named_blocks) + + def generate(self, writer: "_CodeWriter") -> None: + assert writer.loader is not None + included = writer.loader.load(self.name, self.template_name) + with writer.include(included, self.line): + included.file.body.generate(writer) + + +class _ApplyBlock(_Node): + def __init__(self, method: str, line: int, body: _Node) -> None: + self.method = method + self.line = line + self.body = body + + def each_child(self) -> Iterable["_Node"]: + return (self.body,) + + def generate(self, writer: "_CodeWriter") -> None: + method_name = "_tt_apply%d" % writer.apply_counter + writer.apply_counter += 1 + writer.write_line("def %s():" % method_name, self.line) + with writer.indent(): + writer.write_line("_tt_buffer = []", self.line) + writer.write_line("_tt_append = _tt_buffer.append", self.line) + self.body.generate(writer) + writer.write_line("return _tt_utf8('').join(_tt_buffer)", self.line) + writer.write_line( + "_tt_append(_tt_utf8(%s(%s())))" % (self.method, method_name), self.line + ) + + +class _ControlBlock(_Node): + def __init__(self, statement: str, line: int, body: _Node) -> None: + self.statement = statement + self.line = line + self.body = body + + def each_child(self) -> Iterable[_Node]: + return (self.body,) + + def generate(self, writer: "_CodeWriter") -> None: + writer.write_line("%s:" % self.statement, self.line) + with writer.indent(): + self.body.generate(writer) + # Just in case the body was empty + writer.write_line("pass", self.line) + + +class _IntermediateControlBlock(_Node): + def __init__(self, statement: str, line: int) -> None: + self.statement = statement + self.line = line + + def generate(self, writer: "_CodeWriter") -> None: + # In case the previous block was empty + writer.write_line("pass", self.line) + writer.write_line("%s:" % self.statement, self.line, writer.indent_size() - 1) + + +class _Statement(_Node): + def __init__(self, statement: str, line: int) -> None: + self.statement = statement + self.line = line + + def generate(self, writer: "_CodeWriter") -> None: + writer.write_line(self.statement, self.line) + + +class _Expression(_Node): + def __init__(self, expression: str, line: int, raw: bool = False) -> None: + self.expression = expression + self.line = line + self.raw = raw + + def generate(self, writer: "_CodeWriter") -> None: + writer.write_line("_tt_tmp = %s" % self.expression, self.line) + writer.write_line( + "if 
isinstance(_tt_tmp, _tt_string_types):" " _tt_tmp = _tt_utf8(_tt_tmp)", + self.line, + ) + writer.write_line("else: _tt_tmp = _tt_utf8(str(_tt_tmp))", self.line) + if not self.raw and writer.current_template.autoescape is not None: + # In python3 functions like xhtml_escape return unicode, + # so we have to convert to utf8 again. + writer.write_line( + "_tt_tmp = _tt_utf8(%s(_tt_tmp))" % writer.current_template.autoescape, + self.line, + ) + writer.write_line("_tt_append(_tt_tmp)", self.line) + + +class _Module(_Expression): + def __init__(self, expression: str, line: int) -> None: + super().__init__("_tt_modules." + expression, line, raw=True) + + +class _Text(_Node): + def __init__(self, value: str, line: int, whitespace: str) -> None: + self.value = value + self.line = line + self.whitespace = whitespace + + def generate(self, writer: "_CodeWriter") -> None: + value = self.value + + # Compress whitespace if requested, with a crude heuristic to avoid + # altering preformatted whitespace. + if "<pre>" not in value: + value = filter_whitespace(self.whitespace, value) + + if value: + writer.write_line("_tt_append(%r)" % escape.utf8(value), self.line) + + +class ParseError(Exception): + """Raised for template syntax errors. + + ``ParseError`` instances have ``filename`` and ``lineno`` attributes + indicating the position of the error. + + .. versionchanged:: 4.3 + Added ``filename`` and ``lineno`` attributes. + """ + + def __init__( + self, message: str, filename: Optional[str] = None, lineno: int = 0 + ) -> None: + self.message = message + # The names "filename" and "lineno" are chosen for consistency + # with python SyntaxError. + self.filename = filename + self.lineno = lineno + + def __str__(self) -> str: + return "%s at %s:%d" % (self.message, self.filename, self.lineno) + + +class _CodeWriter(object): + def __init__( + self, + file: TextIO, + named_blocks: Dict[str, _NamedBlock], + loader: Optional[BaseLoader], + current_template: Template, + ) -> None: + self.file = file + self.named_blocks = named_blocks + self.loader = loader + self.current_template = current_template + self.apply_counter = 0 + self.include_stack = [] # type: List[Tuple[Template, int]] + self._indent = 0 + + def indent_size(self) -> int: + return self._indent + + def indent(self) -> "ContextManager": + class Indenter(object): + def __enter__(_) -> "_CodeWriter": + self._indent += 1 + return self + + def __exit__(_, *args: Any) -> None: + assert self._indent > 0 + self._indent -= 1 + + return Indenter() + + def include(self, template: Template, line: int) -> "ContextManager": + self.include_stack.append((self.current_template, line)) + self.current_template = template + + class IncludeTemplate(object): + def __enter__(_) -> "_CodeWriter": + return self + + def __exit__(_, *args: Any) -> None: + self.current_template = self.include_stack.pop()[0] + + return IncludeTemplate() + + def write_line( + self, line: str, line_number: int, indent: Optional[int] = None + ) -> None: + if indent is None: + indent = self._indent + line_comment = " # %s:%d" % (self.current_template.name, line_number) + if self.include_stack: + ancestors = [ + "%s:%d" % (tmpl.name, lineno) for (tmpl, lineno) in self.include_stack + ] + line_comment += " (via %s)" % ", ".join(reversed(ancestors)) + print(" " * indent + line + line_comment, file=self.file) + + +class _TemplateReader(object): + def __init__(self, name: str, text: str, whitespace: str) -> None: + self.name = name + self.text = text + self.whitespace = whitespace + self.line = 1 + 
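+        # ``line`` is 1-based and used in error messages (raise_parse_error);
+        # ``pos`` below is the read cursor into ``text``.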
self.pos = 0 + + def find(self, needle: str, start: int = 0, end: Optional[int] = None) -> int: + assert start >= 0, start + pos = self.pos + start += pos + if end is None: + index = self.text.find(needle, start) + else: + end += pos + assert end >= start + index = self.text.find(needle, start, end) + if index != -1: + index -= pos + return index + + def consume(self, count: Optional[int] = None) -> str: + if count is None: + count = len(self.text) - self.pos + newpos = self.pos + count + self.line += self.text.count("\n", self.pos, newpos) + s = self.text[self.pos : newpos] + self.pos = newpos + return s + + def remaining(self) -> int: + return len(self.text) - self.pos + + def __len__(self) -> int: + return self.remaining() + + def __getitem__(self, key: Union[int, slice]) -> str: + if isinstance(key, slice): + size = len(self) + start, stop, step = key.indices(size) + if start is None: + start = self.pos + else: + start += self.pos + if stop is not None: + stop += self.pos + return self.text[slice(start, stop, step)] + elif key < 0: + return self.text[key] + else: + return self.text[self.pos + key] + + def __str__(self) -> str: + return self.text[self.pos :] + + def raise_parse_error(self, msg: str) -> None: + raise ParseError(msg, self.name, self.line) + + +def _format_code(code: str) -> str: + lines = code.splitlines() + format = "%%%dd %%s\n" % len(repr(len(lines) + 1)) + return "".join([format % (i + 1, line) for (i, line) in enumerate(lines)]) + + +def _parse( + reader: _TemplateReader, + template: Template, + in_block: Optional[str] = None, + in_loop: Optional[str] = None, +) -> _ChunkList: + body = _ChunkList([]) + while True: + # Find next template directive + curly = 0 + while True: + curly = reader.find("{", curly) + if curly == -1 or curly + 1 == reader.remaining(): + # EOF + if in_block: + reader.raise_parse_error( + "Missing {%% end %%} block for %s" % in_block + ) + body.chunks.append( + _Text(reader.consume(), reader.line, reader.whitespace) + ) + return body + # If the first curly brace is not the start of a special token, + # start searching from the character after it + if reader[curly + 1] not in ("{", "%", "#"): + curly += 1 + continue + # When there are more than 2 curlies in a row, use the + # innermost ones. This is useful when generating languages + # like latex where curlies are also meaningful + if ( + curly + 2 < reader.remaining() + and reader[curly + 1] == "{" + and reader[curly + 2] == "{" + ): + curly += 1 + continue + break + + # Append any text before the special token + if curly > 0: + cons = reader.consume(curly) + body.chunks.append(_Text(cons, reader.line, reader.whitespace)) + + start_brace = reader.consume(2) + line = reader.line + + # Template directives may be escaped as "{{!" or "{%!". + # In this case output the braces and consume the "!". + # This is especially useful in conjunction with jquery templates, + # which also use double braces. 
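+        # e.g. "{{!" in the template source emits a literal "{{" in the output.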
+ if reader.remaining() and reader[0] == "!": + reader.consume(1) + body.chunks.append(_Text(start_brace, line, reader.whitespace)) + continue + + # Comment + if start_brace == "{#": + end = reader.find("#}") + if end == -1: + reader.raise_parse_error("Missing end comment #}") + contents = reader.consume(end).strip() + reader.consume(2) + continue + + # Expression + if start_brace == "{{": + end = reader.find("}}") + if end == -1: + reader.raise_parse_error("Missing end expression }}") + contents = reader.consume(end).strip() + reader.consume(2) + if not contents: + reader.raise_parse_error("Empty expression") + body.chunks.append(_Expression(contents, line)) + continue + + # Block + assert start_brace == "{%", start_brace + end = reader.find("%}") + if end == -1: + reader.raise_parse_error("Missing end block %}") + contents = reader.consume(end).strip() + reader.consume(2) + if not contents: + reader.raise_parse_error("Empty block tag ({% %})") + + operator, space, suffix = contents.partition(" ") + suffix = suffix.strip() + + # Intermediate ("else", "elif", etc) blocks + intermediate_blocks = { + "else": set(["if", "for", "while", "try"]), + "elif": set(["if"]), + "except": set(["try"]), + "finally": set(["try"]), + } + allowed_parents = intermediate_blocks.get(operator) + if allowed_parents is not None: + if not in_block: + reader.raise_parse_error( + "%s outside %s block" % (operator, allowed_parents) + ) + if in_block not in allowed_parents: + reader.raise_parse_error( + "%s block cannot be attached to %s block" % (operator, in_block) + ) + body.chunks.append(_IntermediateControlBlock(contents, line)) + continue + + # End tag + elif operator == "end": + if not in_block: + reader.raise_parse_error("Extra {% end %} block") + return body + + elif operator in ( + "extends", + "include", + "set", + "import", + "from", + "comment", + "autoescape", + "whitespace", + "raw", + "module", + ): + if operator == "comment": + continue + if operator == "extends": + suffix = suffix.strip('"').strip("'") + if not suffix: + reader.raise_parse_error("extends missing file path") + block = _ExtendsBlock(suffix) # type: _Node + elif operator in ("import", "from"): + if not suffix: + reader.raise_parse_error("import missing statement") + block = _Statement(contents, line) + elif operator == "include": + suffix = suffix.strip('"').strip("'") + if not suffix: + reader.raise_parse_error("include missing file path") + block = _IncludeBlock(suffix, reader, line) + elif operator == "set": + if not suffix: + reader.raise_parse_error("set missing statement") + block = _Statement(suffix, line) + elif operator == "autoescape": + fn = suffix.strip() # type: Optional[str] + if fn == "None": + fn = None + template.autoescape = fn + continue + elif operator == "whitespace": + mode = suffix.strip() + # Validate the selected mode + filter_whitespace(mode, "") + reader.whitespace = mode + continue + elif operator == "raw": + block = _Expression(suffix, line, raw=True) + elif operator == "module": + block = _Module(suffix, line) + body.chunks.append(block) + continue + + elif operator in ("apply", "block", "try", "if", "for", "while"): + # parse inner body recursively + if operator in ("for", "while"): + block_body = _parse(reader, template, operator, operator) + elif operator == "apply": + # apply creates a nested function so syntactically it's not + # in the loop. 
+ block_body = _parse(reader, template, operator, None) + else: + block_body = _parse(reader, template, operator, in_loop) + + if operator == "apply": + if not suffix: + reader.raise_parse_error("apply missing method name") + block = _ApplyBlock(suffix, line, block_body) + elif operator == "block": + if not suffix: + reader.raise_parse_error("block missing name") + block = _NamedBlock(suffix, block_body, template, line) + else: + block = _ControlBlock(contents, line, block_body) + body.chunks.append(block) + continue + + elif operator in ("break", "continue"): + if not in_loop: + reader.raise_parse_error( + "%s outside %s block" % (operator, set(["for", "while"])) + ) + body.chunks.append(_Statement(contents, line)) + continue + + else: + reader.raise_parse_error("unknown operator: %r" % operator) diff --git a/venv/lib/python3.9/site-packages/tornado/test/__init__.py b/venv/lib/python3.9/site-packages/tornado/test/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/test/__init__.py diff --git a/venv/lib/python3.9/site-packages/tornado/test/__main__.py b/venv/lib/python3.9/site-packages/tornado/test/__main__.py new file mode 100644 index 00000000..430c895f --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/test/__main__.py @@ -0,0 +1,12 @@ +"""Shim to allow python -m tornado.test. + +This only works in python 2.7+. +""" +from tornado.test.runtests import all, main + +# tornado.testing.main autodiscovery relies on 'all' being present in +# the main module, so import it here even though it is not used directly. +# The following line prevents a pyflakes warning. +all = all + +main() diff --git a/venv/lib/python3.9/site-packages/tornado/test/asyncio_test.py b/venv/lib/python3.9/site-packages/tornado/test/asyncio_test.py new file mode 100644 index 00000000..348c0ceb --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/test/asyncio_test.py @@ -0,0 +1,209 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import asyncio +import unittest +import warnings + +from concurrent.futures import ThreadPoolExecutor +from tornado import gen +from tornado.ioloop import IOLoop +from tornado.platform.asyncio import ( + AsyncIOLoop, + to_asyncio_future, + AnyThreadEventLoopPolicy, +) +from tornado.testing import AsyncTestCase, gen_test + + +class AsyncIOLoopTest(AsyncTestCase): + @property + def asyncio_loop(self): + return self.io_loop.asyncio_loop # type: ignore + + def test_asyncio_callback(self): + # Basic test that the asyncio loop is set up correctly. + async def add_callback(): + asyncio.get_event_loop().call_soon(self.stop) + + self.asyncio_loop.run_until_complete(add_callback()) + self.wait() + + @gen_test + def test_asyncio_future(self): + # Test that we can yield an asyncio future from a tornado coroutine. + # Without 'yield from', we must wrap coroutines in ensure_future, + # which was introduced during Python 3.4, deprecating the prior "async". 
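+        # (asyncio.ensure_future exists on every Python supported by Tornado 6,
+        # so the getattr fallback below is effectively legacy.)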
+ if hasattr(asyncio, "ensure_future"): + ensure_future = asyncio.ensure_future + else: + # async is a reserved word in Python 3.7 + ensure_future = getattr(asyncio, "async") + + x = yield ensure_future( + asyncio.get_event_loop().run_in_executor(None, lambda: 42) + ) + self.assertEqual(x, 42) + + @gen_test + def test_asyncio_yield_from(self): + @gen.coroutine + def f(): + event_loop = asyncio.get_event_loop() + x = yield from event_loop.run_in_executor(None, lambda: 42) + return x + + result = yield f() + self.assertEqual(result, 42) + + def test_asyncio_adapter(self): + # This test demonstrates that when using the asyncio coroutine + # runner (i.e. run_until_complete), the to_asyncio_future + # adapter is needed. No adapter is needed in the other direction, + # as demonstrated by other tests in the package. + @gen.coroutine + def tornado_coroutine(): + yield gen.moment + raise gen.Return(42) + + async def native_coroutine_without_adapter(): + return await tornado_coroutine() + + async def native_coroutine_with_adapter(): + return await to_asyncio_future(tornado_coroutine()) + + # Use the adapter, but two degrees from the tornado coroutine. + async def native_coroutine_with_adapter2(): + return await to_asyncio_future(native_coroutine_without_adapter()) + + # Tornado supports native coroutines both with and without adapters + self.assertEqual(self.io_loop.run_sync(native_coroutine_without_adapter), 42) + self.assertEqual(self.io_loop.run_sync(native_coroutine_with_adapter), 42) + self.assertEqual(self.io_loop.run_sync(native_coroutine_with_adapter2), 42) + + # Asyncio only supports coroutines that yield asyncio-compatible + # Futures (which our Future is since 5.0). + self.assertEqual( + self.asyncio_loop.run_until_complete(native_coroutine_without_adapter()), + 42, + ) + self.assertEqual( + self.asyncio_loop.run_until_complete(native_coroutine_with_adapter()), + 42, + ) + self.assertEqual( + self.asyncio_loop.run_until_complete(native_coroutine_with_adapter2()), + 42, + ) + + +class LeakTest(unittest.TestCase): + def setUp(self): + # Trigger a cleanup of the mapping so we start with a clean slate. + AsyncIOLoop(make_current=False).close() + # If we don't clean up after ourselves other tests may fail on + # py34. + self.orig_policy = asyncio.get_event_loop_policy() + asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy()) + + def tearDown(self): + try: + loop = asyncio.get_event_loop_policy().get_event_loop() + except Exception: + # We may not have a current event loop at this point. + pass + else: + loop.close() + asyncio.set_event_loop_policy(self.orig_policy) + + def test_ioloop_close_leak(self): + orig_count = len(IOLoop._ioloop_for_asyncio) + for i in range(10): + # Create and close an AsyncIOLoop using Tornado interfaces. + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + loop = AsyncIOLoop() + loop.close() + new_count = len(IOLoop._ioloop_for_asyncio) - orig_count + self.assertEqual(new_count, 0) + + def test_asyncio_close_leak(self): + orig_count = len(IOLoop._ioloop_for_asyncio) + for i in range(10): + # Create and close an AsyncIOMainLoop using asyncio interfaces. + loop = asyncio.new_event_loop() + loop.call_soon(IOLoop.current) + loop.call_soon(loop.stop) + loop.run_forever() + loop.close() + new_count = len(IOLoop._ioloop_for_asyncio) - orig_count + # Because the cleanup is run on new loop creation, we have one + # dangling entry in the map (but only one). 
+ self.assertEqual(new_count, 1) + + +class AnyThreadEventLoopPolicyTest(unittest.TestCase): + def setUp(self): + self.orig_policy = asyncio.get_event_loop_policy() + self.executor = ThreadPoolExecutor(1) + + def tearDown(self): + asyncio.set_event_loop_policy(self.orig_policy) + self.executor.shutdown() + + def get_event_loop_on_thread(self): + def get_and_close_event_loop(): + """Get the event loop. Close it if one is returned. + + Returns the (closed) event loop. This is a silly thing + to do and leaves the thread in a broken state, but it's + enough for this test. Closing the loop avoids resource + leak warnings. + """ + loop = asyncio.get_event_loop() + loop.close() + return loop + + future = self.executor.submit(get_and_close_event_loop) + return future.result() + + def test_asyncio_accessor(self): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + # With the default policy, non-main threads don't get an event + # loop. + self.assertRaises( + RuntimeError, self.executor.submit(asyncio.get_event_loop).result + ) + # Set the policy and we can get a loop. + asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy()) + self.assertIsInstance( + self.executor.submit(asyncio.get_event_loop).result(), + asyncio.AbstractEventLoop, + ) + # Clean up to silence leak warnings. Always use asyncio since + # IOLoop doesn't (currently) close the underlying loop. + self.executor.submit(lambda: asyncio.get_event_loop().close()).result() # type: ignore + + def test_tornado_accessor(self): + # Tornado's IOLoop.current() API can create a loop for any thread, + # regardless of this event loop policy. + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + self.assertIsInstance(self.executor.submit(IOLoop.current).result(), IOLoop) + # Clean up to silence leak warnings. Always use asyncio since + # IOLoop doesn't (currently) close the underlying loop. + self.executor.submit(lambda: asyncio.get_event_loop().close()).result() # type: ignore + + asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy()) + self.assertIsInstance(self.executor.submit(IOLoop.current).result(), IOLoop) + self.executor.submit(lambda: asyncio.get_event_loop().close()).result() # type: ignore diff --git a/venv/lib/python3.9/site-packages/tornado/test/auth_test.py b/venv/lib/python3.9/site-packages/tornado/test/auth_test.py new file mode 100644 index 00000000..3cd715f7 --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/test/auth_test.py @@ -0,0 +1,609 @@ +# These tests do not currently do much to verify the correct implementation +# of the openid/oauth protocols, they just exercise the major code paths +# and ensure that it doesn't blow up (e.g. 
with unicode/bytes issues in +# python 3) + +import unittest + +from tornado.auth import ( + OpenIdMixin, + OAuthMixin, + OAuth2Mixin, + GoogleOAuth2Mixin, + FacebookGraphMixin, + TwitterMixin, +) +from tornado.escape import json_decode +from tornado import gen +from tornado.httpclient import HTTPClientError +from tornado.httputil import url_concat +from tornado.log import app_log +from tornado.testing import AsyncHTTPTestCase, ExpectLog +from tornado.web import RequestHandler, Application, HTTPError + +try: + from unittest import mock +except ImportError: + mock = None # type: ignore + + +class OpenIdClientLoginHandler(RequestHandler, OpenIdMixin): + def initialize(self, test): + self._OPENID_ENDPOINT = test.get_url("/openid/server/authenticate") + + @gen.coroutine + def get(self): + if self.get_argument("openid.mode", None): + user = yield self.get_authenticated_user( + http_client=self.settings["http_client"] + ) + if user is None: + raise Exception("user is None") + self.finish(user) + return + res = self.authenticate_redirect() # type: ignore + assert res is None + + +class OpenIdServerAuthenticateHandler(RequestHandler): + def post(self): + if self.get_argument("openid.mode") != "check_authentication": + raise Exception("incorrect openid.mode %r") + self.write("is_valid:true") + + +class OAuth1ClientLoginHandler(RequestHandler, OAuthMixin): + def initialize(self, test, version): + self._OAUTH_VERSION = version + self._OAUTH_REQUEST_TOKEN_URL = test.get_url("/oauth1/server/request_token") + self._OAUTH_AUTHORIZE_URL = test.get_url("/oauth1/server/authorize") + self._OAUTH_ACCESS_TOKEN_URL = test.get_url("/oauth1/server/access_token") + + def _oauth_consumer_token(self): + return dict(key="asdf", secret="qwer") + + @gen.coroutine + def get(self): + if self.get_argument("oauth_token", None): + user = yield self.get_authenticated_user( + http_client=self.settings["http_client"] + ) + if user is None: + raise Exception("user is None") + self.finish(user) + return + yield self.authorize_redirect(http_client=self.settings["http_client"]) + + @gen.coroutine + def _oauth_get_user_future(self, access_token): + if self.get_argument("fail_in_get_user", None): + raise Exception("failing in get_user") + if access_token != dict(key="uiop", secret="5678"): + raise Exception("incorrect access token %r" % access_token) + return dict(email="foo@example.com") + + +class OAuth1ClientLoginCoroutineHandler(OAuth1ClientLoginHandler): + """Replaces OAuth1ClientLoginCoroutineHandler's get() with a coroutine.""" + + @gen.coroutine + def get(self): + if self.get_argument("oauth_token", None): + # Ensure that any exceptions are set on the returned Future, + # not simply thrown into the surrounding StackContext. 
+ try: + yield self.get_authenticated_user() + except Exception as e: + self.set_status(503) + self.write("got exception: %s" % e) + else: + yield self.authorize_redirect() + + +class OAuth1ClientRequestParametersHandler(RequestHandler, OAuthMixin): + def initialize(self, version): + self._OAUTH_VERSION = version + + def _oauth_consumer_token(self): + return dict(key="asdf", secret="qwer") + + def get(self): + params = self._oauth_request_parameters( + "http://www.example.com/api/asdf", + dict(key="uiop", secret="5678"), + parameters=dict(foo="bar"), + ) + self.write(params) + + +class OAuth1ServerRequestTokenHandler(RequestHandler): + def get(self): + self.write("oauth_token=zxcv&oauth_token_secret=1234") + + +class OAuth1ServerAccessTokenHandler(RequestHandler): + def get(self): + self.write("oauth_token=uiop&oauth_token_secret=5678") + + +class OAuth2ClientLoginHandler(RequestHandler, OAuth2Mixin): + def initialize(self, test): + self._OAUTH_AUTHORIZE_URL = test.get_url("/oauth2/server/authorize") + + def get(self): + res = self.authorize_redirect() # type: ignore + assert res is None + + +class FacebookClientLoginHandler(RequestHandler, FacebookGraphMixin): + def initialize(self, test): + self._OAUTH_AUTHORIZE_URL = test.get_url("/facebook/server/authorize") + self._OAUTH_ACCESS_TOKEN_URL = test.get_url("/facebook/server/access_token") + self._FACEBOOK_BASE_URL = test.get_url("/facebook/server") + + @gen.coroutine + def get(self): + if self.get_argument("code", None): + user = yield self.get_authenticated_user( + redirect_uri=self.request.full_url(), + client_id=self.settings["facebook_api_key"], + client_secret=self.settings["facebook_secret"], + code=self.get_argument("code"), + ) + self.write(user) + else: + self.authorize_redirect( + redirect_uri=self.request.full_url(), + client_id=self.settings["facebook_api_key"], + extra_params={"scope": "read_stream,offline_access"}, + ) + + +class FacebookServerAccessTokenHandler(RequestHandler): + def get(self): + self.write(dict(access_token="asdf", expires_in=3600)) + + +class FacebookServerMeHandler(RequestHandler): + def get(self): + self.write("{}") + + +class TwitterClientHandler(RequestHandler, TwitterMixin): + def initialize(self, test): + self._OAUTH_REQUEST_TOKEN_URL = test.get_url("/oauth1/server/request_token") + self._OAUTH_ACCESS_TOKEN_URL = test.get_url("/twitter/server/access_token") + self._OAUTH_AUTHORIZE_URL = test.get_url("/oauth1/server/authorize") + self._OAUTH_AUTHENTICATE_URL = test.get_url("/twitter/server/authenticate") + self._TWITTER_BASE_URL = test.get_url("/twitter/api") + + def get_auth_http_client(self): + return self.settings["http_client"] + + +class TwitterClientLoginHandler(TwitterClientHandler): + @gen.coroutine + def get(self): + if self.get_argument("oauth_token", None): + user = yield self.get_authenticated_user() + if user is None: + raise Exception("user is None") + self.finish(user) + return + yield self.authorize_redirect() + + +class TwitterClientAuthenticateHandler(TwitterClientHandler): + # Like TwitterClientLoginHandler, but uses authenticate_redirect + # instead of authorize_redirect. 
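+    # (Twitter's authenticate endpoint implements "Sign in with Twitter" and
+    # may skip the approval prompt for returning users, unlike authorize.)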
+ @gen.coroutine + def get(self): + if self.get_argument("oauth_token", None): + user = yield self.get_authenticated_user() + if user is None: + raise Exception("user is None") + self.finish(user) + return + yield self.authenticate_redirect() + + +class TwitterClientLoginGenCoroutineHandler(TwitterClientHandler): + @gen.coroutine + def get(self): + if self.get_argument("oauth_token", None): + user = yield self.get_authenticated_user() + self.finish(user) + else: + # New style: with @gen.coroutine the result must be yielded + # or else the request will be auto-finished too soon. + yield self.authorize_redirect() + + +class TwitterClientShowUserHandler(TwitterClientHandler): + @gen.coroutine + def get(self): + # TODO: would be nice to go through the login flow instead of + # cheating with a hard-coded access token. + try: + response = yield self.twitter_request( + "/users/show/%s" % self.get_argument("name"), + access_token=dict(key="hjkl", secret="vbnm"), + ) + except HTTPClientError: + # TODO(bdarnell): Should we catch HTTP errors and + # transform some of them (like 403s) into AuthError? + self.set_status(500) + self.finish("error from twitter request") + else: + self.finish(response) + + +class TwitterServerAccessTokenHandler(RequestHandler): + def get(self): + self.write("oauth_token=hjkl&oauth_token_secret=vbnm&screen_name=foo") + + +class TwitterServerShowUserHandler(RequestHandler): + def get(self, screen_name): + if screen_name == "error": + raise HTTPError(500) + assert "oauth_nonce" in self.request.arguments + assert "oauth_timestamp" in self.request.arguments + assert "oauth_signature" in self.request.arguments + assert self.get_argument("oauth_consumer_key") == "test_twitter_consumer_key" + assert self.get_argument("oauth_signature_method") == "HMAC-SHA1" + assert self.get_argument("oauth_version") == "1.0" + assert self.get_argument("oauth_token") == "hjkl" + self.write(dict(screen_name=screen_name, name=screen_name.capitalize())) + + +class TwitterServerVerifyCredentialsHandler(RequestHandler): + def get(self): + assert "oauth_nonce" in self.request.arguments + assert "oauth_timestamp" in self.request.arguments + assert "oauth_signature" in self.request.arguments + assert self.get_argument("oauth_consumer_key") == "test_twitter_consumer_key" + assert self.get_argument("oauth_signature_method") == "HMAC-SHA1" + assert self.get_argument("oauth_version") == "1.0" + assert self.get_argument("oauth_token") == "hjkl" + self.write(dict(screen_name="foo", name="Foo")) + + +class AuthTest(AsyncHTTPTestCase): + def get_app(self): + return Application( + [ + # test endpoints + ("/openid/client/login", OpenIdClientLoginHandler, dict(test=self)), + ( + "/oauth10/client/login", + OAuth1ClientLoginHandler, + dict(test=self, version="1.0"), + ), + ( + "/oauth10/client/request_params", + OAuth1ClientRequestParametersHandler, + dict(version="1.0"), + ), + ( + "/oauth10a/client/login", + OAuth1ClientLoginHandler, + dict(test=self, version="1.0a"), + ), + ( + "/oauth10a/client/login_coroutine", + OAuth1ClientLoginCoroutineHandler, + dict(test=self, version="1.0a"), + ), + ( + "/oauth10a/client/request_params", + OAuth1ClientRequestParametersHandler, + dict(version="1.0a"), + ), + ("/oauth2/client/login", OAuth2ClientLoginHandler, dict(test=self)), + ("/facebook/client/login", FacebookClientLoginHandler, dict(test=self)), + ("/twitter/client/login", TwitterClientLoginHandler, dict(test=self)), + ( + "/twitter/client/authenticate", + TwitterClientAuthenticateHandler, + dict(test=self), + ), + ( + 
"/twitter/client/login_gen_coroutine", + TwitterClientLoginGenCoroutineHandler, + dict(test=self), + ), + ( + "/twitter/client/show_user", + TwitterClientShowUserHandler, + dict(test=self), + ), + # simulated servers + ("/openid/server/authenticate", OpenIdServerAuthenticateHandler), + ("/oauth1/server/request_token", OAuth1ServerRequestTokenHandler), + ("/oauth1/server/access_token", OAuth1ServerAccessTokenHandler), + ("/facebook/server/access_token", FacebookServerAccessTokenHandler), + ("/facebook/server/me", FacebookServerMeHandler), + ("/twitter/server/access_token", TwitterServerAccessTokenHandler), + (r"/twitter/api/users/show/(.*)\.json", TwitterServerShowUserHandler), + ( + r"/twitter/api/account/verify_credentials\.json", + TwitterServerVerifyCredentialsHandler, + ), + ], + http_client=self.http_client, + twitter_consumer_key="test_twitter_consumer_key", + twitter_consumer_secret="test_twitter_consumer_secret", + facebook_api_key="test_facebook_api_key", + facebook_secret="test_facebook_secret", + ) + + def test_openid_redirect(self): + response = self.fetch("/openid/client/login", follow_redirects=False) + self.assertEqual(response.code, 302) + self.assertTrue("/openid/server/authenticate?" in response.headers["Location"]) + + def test_openid_get_user(self): + response = self.fetch( + "/openid/client/login?openid.mode=blah" + "&openid.ns.ax=http://openid.net/srv/ax/1.0" + "&openid.ax.type.email=http://axschema.org/contact/email" + "&openid.ax.value.email=foo@example.com" + ) + response.rethrow() + parsed = json_decode(response.body) + self.assertEqual(parsed["email"], "foo@example.com") + + def test_oauth10_redirect(self): + response = self.fetch("/oauth10/client/login", follow_redirects=False) + self.assertEqual(response.code, 302) + self.assertTrue( + response.headers["Location"].endswith( + "/oauth1/server/authorize?oauth_token=zxcv" + ) + ) + # the cookie is base64('zxcv')|base64('1234') + self.assertTrue( + '_oauth_request_token="enhjdg==|MTIzNA=="' + in response.headers["Set-Cookie"], + response.headers["Set-Cookie"], + ) + + def test_oauth10_get_user(self): + response = self.fetch( + "/oauth10/client/login?oauth_token=zxcv", + headers={"Cookie": "_oauth_request_token=enhjdg==|MTIzNA=="}, + ) + response.rethrow() + parsed = json_decode(response.body) + self.assertEqual(parsed["email"], "foo@example.com") + self.assertEqual(parsed["access_token"], dict(key="uiop", secret="5678")) + + def test_oauth10_request_parameters(self): + response = self.fetch("/oauth10/client/request_params") + response.rethrow() + parsed = json_decode(response.body) + self.assertEqual(parsed["oauth_consumer_key"], "asdf") + self.assertEqual(parsed["oauth_token"], "uiop") + self.assertTrue("oauth_nonce" in parsed) + self.assertTrue("oauth_signature" in parsed) + + def test_oauth10a_redirect(self): + response = self.fetch("/oauth10a/client/login", follow_redirects=False) + self.assertEqual(response.code, 302) + self.assertTrue( + response.headers["Location"].endswith( + "/oauth1/server/authorize?oauth_token=zxcv" + ) + ) + # the cookie is base64('zxcv')|base64('1234') + self.assertTrue( + '_oauth_request_token="enhjdg==|MTIzNA=="' + in response.headers["Set-Cookie"], + response.headers["Set-Cookie"], + ) + + @unittest.skipIf(mock is None, "mock package not present") + def test_oauth10a_redirect_error(self): + with mock.patch.object(OAuth1ServerRequestTokenHandler, "get") as get: + get.side_effect = Exception("boom") + with ExpectLog(app_log, "Uncaught exception"): + response = 
self.fetch("/oauth10a/client/login", follow_redirects=False) + self.assertEqual(response.code, 500) + + def test_oauth10a_get_user(self): + response = self.fetch( + "/oauth10a/client/login?oauth_token=zxcv", + headers={"Cookie": "_oauth_request_token=enhjdg==|MTIzNA=="}, + ) + response.rethrow() + parsed = json_decode(response.body) + self.assertEqual(parsed["email"], "foo@example.com") + self.assertEqual(parsed["access_token"], dict(key="uiop", secret="5678")) + + def test_oauth10a_request_parameters(self): + response = self.fetch("/oauth10a/client/request_params") + response.rethrow() + parsed = json_decode(response.body) + self.assertEqual(parsed["oauth_consumer_key"], "asdf") + self.assertEqual(parsed["oauth_token"], "uiop") + self.assertTrue("oauth_nonce" in parsed) + self.assertTrue("oauth_signature" in parsed) + + def test_oauth10a_get_user_coroutine_exception(self): + response = self.fetch( + "/oauth10a/client/login_coroutine?oauth_token=zxcv&fail_in_get_user=true", + headers={"Cookie": "_oauth_request_token=enhjdg==|MTIzNA=="}, + ) + self.assertEqual(response.code, 503) + + def test_oauth2_redirect(self): + response = self.fetch("/oauth2/client/login", follow_redirects=False) + self.assertEqual(response.code, 302) + self.assertTrue("/oauth2/server/authorize?" in response.headers["Location"]) + + def test_facebook_login(self): + response = self.fetch("/facebook/client/login", follow_redirects=False) + self.assertEqual(response.code, 302) + self.assertTrue("/facebook/server/authorize?" in response.headers["Location"]) + response = self.fetch( + "/facebook/client/login?code=1234", follow_redirects=False + ) + self.assertEqual(response.code, 200) + user = json_decode(response.body) + self.assertEqual(user["access_token"], "asdf") + self.assertEqual(user["session_expires"], "3600") + + def base_twitter_redirect(self, url): + # Same as test_oauth10a_redirect + response = self.fetch(url, follow_redirects=False) + self.assertEqual(response.code, 302) + self.assertTrue( + response.headers["Location"].endswith( + "/oauth1/server/authorize?oauth_token=zxcv" + ) + ) + # the cookie is base64('zxcv')|base64('1234') + self.assertTrue( + '_oauth_request_token="enhjdg==|MTIzNA=="' + in response.headers["Set-Cookie"], + response.headers["Set-Cookie"], + ) + + def test_twitter_redirect(self): + self.base_twitter_redirect("/twitter/client/login") + + def test_twitter_redirect_gen_coroutine(self): + self.base_twitter_redirect("/twitter/client/login_gen_coroutine") + + def test_twitter_authenticate_redirect(self): + response = self.fetch("/twitter/client/authenticate", follow_redirects=False) + self.assertEqual(response.code, 302) + self.assertTrue( + response.headers["Location"].endswith( + "/twitter/server/authenticate?oauth_token=zxcv" + ), + response.headers["Location"], + ) + # the cookie is base64('zxcv')|base64('1234') + self.assertTrue( + '_oauth_request_token="enhjdg==|MTIzNA=="' + in response.headers["Set-Cookie"], + response.headers["Set-Cookie"], + ) + + def test_twitter_get_user(self): + response = self.fetch( + "/twitter/client/login?oauth_token=zxcv", + headers={"Cookie": "_oauth_request_token=enhjdg==|MTIzNA=="}, + ) + response.rethrow() + parsed = json_decode(response.body) + self.assertEqual( + parsed, + { + "access_token": { + "key": "hjkl", + "screen_name": "foo", + "secret": "vbnm", + }, + "name": "Foo", + "screen_name": "foo", + "username": "foo", + }, + ) + + def test_twitter_show_user(self): + response = self.fetch("/twitter/client/show_user?name=somebody") + response.rethrow() 
+ self.assertEqual( + json_decode(response.body), {"name": "Somebody", "screen_name": "somebody"} + ) + + def test_twitter_show_user_error(self): + response = self.fetch("/twitter/client/show_user?name=error") + self.assertEqual(response.code, 500) + self.assertEqual(response.body, b"error from twitter request") + + +class GoogleLoginHandler(RequestHandler, GoogleOAuth2Mixin): + def initialize(self, test): + self.test = test + self._OAUTH_REDIRECT_URI = test.get_url("/client/login") + self._OAUTH_AUTHORIZE_URL = test.get_url("/google/oauth2/authorize") + self._OAUTH_ACCESS_TOKEN_URL = test.get_url("/google/oauth2/token") + + @gen.coroutine + def get(self): + code = self.get_argument("code", None) + if code is not None: + # retrieve authenticate google user + access = yield self.get_authenticated_user(self._OAUTH_REDIRECT_URI, code) + user = yield self.oauth2_request( + self.test.get_url("/google/oauth2/userinfo"), + access_token=access["access_token"], + ) + # return the user and access token as json + user["access_token"] = access["access_token"] + self.write(user) + else: + self.authorize_redirect( + redirect_uri=self._OAUTH_REDIRECT_URI, + client_id=self.settings["google_oauth"]["key"], + client_secret=self.settings["google_oauth"]["secret"], + scope=["profile", "email"], + response_type="code", + extra_params={"prompt": "select_account"}, + ) + + +class GoogleOAuth2AuthorizeHandler(RequestHandler): + def get(self): + # issue a fake auth code and redirect to redirect_uri + code = "fake-authorization-code" + self.redirect(url_concat(self.get_argument("redirect_uri"), dict(code=code))) + + +class GoogleOAuth2TokenHandler(RequestHandler): + def post(self): + assert self.get_argument("code") == "fake-authorization-code" + # issue a fake token + self.finish( + {"access_token": "fake-access-token", "expires_in": "never-expires"} + ) + + +class GoogleOAuth2UserinfoHandler(RequestHandler): + def get(self): + assert self.get_argument("access_token") == "fake-access-token" + # return a fake user + self.finish({"name": "Foo", "email": "foo@example.com"}) + + +class GoogleOAuth2Test(AsyncHTTPTestCase): + def get_app(self): + return Application( + [ + # test endpoints + ("/client/login", GoogleLoginHandler, dict(test=self)), + # simulated google authorization server endpoints + ("/google/oauth2/authorize", GoogleOAuth2AuthorizeHandler), + ("/google/oauth2/token", GoogleOAuth2TokenHandler), + ("/google/oauth2/userinfo", GoogleOAuth2UserinfoHandler), + ], + google_oauth={ + "key": "fake_google_client_id", + "secret": "fake_google_client_secret", + }, + ) + + def test_google_login(self): + response = self.fetch("/client/login") + self.assertDictEqual( + { + "name": "Foo", + "email": "foo@example.com", + "access_token": "fake-access-token", + }, + json_decode(response.body), + ) diff --git a/venv/lib/python3.9/site-packages/tornado/test/autoreload_test.py b/venv/lib/python3.9/site-packages/tornado/test/autoreload_test.py new file mode 100644 index 00000000..be481e10 --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/test/autoreload_test.py @@ -0,0 +1,127 @@ +import os +import shutil +import subprocess +from subprocess import Popen +import sys +from tempfile import mkdtemp +import time +import unittest + + +class AutoreloadTest(unittest.TestCase): + def setUp(self): + self.path = mkdtemp() + + def tearDown(self): + try: + shutil.rmtree(self.path) + except OSError: + # Windows disallows deleting files that are in use by + # another process, and even though we've waited for our + # child process 
below, it appears that its lock on these + # files is not guaranteed to be released by this point. + # Sleep and try again (once). + time.sleep(1) + shutil.rmtree(self.path) + + def test_reload_module(self): + main = """\ +import os +import sys + +from tornado import autoreload + +# This import will fail if path is not set up correctly +import testapp + +print('Starting') +if 'TESTAPP_STARTED' not in os.environ: + os.environ['TESTAPP_STARTED'] = '1' + sys.stdout.flush() + autoreload._reload() +""" + + # Create temporary test application + os.mkdir(os.path.join(self.path, "testapp")) + open(os.path.join(self.path, "testapp/__init__.py"), "w").close() + with open(os.path.join(self.path, "testapp/__main__.py"), "w") as f: + f.write(main) + + # Make sure the tornado module under test is available to the test + # application + pythonpath = os.getcwd() + if "PYTHONPATH" in os.environ: + pythonpath += os.pathsep + os.environ["PYTHONPATH"] + + p = Popen( + [sys.executable, "-m", "testapp"], + stdout=subprocess.PIPE, + cwd=self.path, + env=dict(os.environ, PYTHONPATH=pythonpath), + universal_newlines=True, + ) + out = p.communicate()[0] + self.assertEqual(out, "Starting\nStarting\n") + + def test_reload_wrapper_preservation(self): + # This test verifies that when `python -m tornado.autoreload` + # is used on an application that also has an internal + # autoreload, the reload wrapper is preserved on restart. + main = """\ +import os +import sys + +# This import will fail if path is not set up correctly +import testapp + +if 'tornado.autoreload' not in sys.modules: + raise Exception('started without autoreload wrapper') + +import tornado.autoreload + +print('Starting') +sys.stdout.flush() +if 'TESTAPP_STARTED' not in os.environ: + os.environ['TESTAPP_STARTED'] = '1' + # Simulate an internal autoreload (one not caused + # by the wrapper). + tornado.autoreload._reload() +else: + # Exit directly so autoreload doesn't catch it. + os._exit(0) +""" + + # Create temporary test application + os.mkdir(os.path.join(self.path, "testapp")) + init_file = os.path.join(self.path, "testapp", "__init__.py") + open(init_file, "w").close() + main_file = os.path.join(self.path, "testapp", "__main__.py") + with open(main_file, "w") as f: + f.write(main) + + # Make sure the tornado module under test is available to the test + # application + pythonpath = os.getcwd() + if "PYTHONPATH" in os.environ: + pythonpath += os.pathsep + os.environ["PYTHONPATH"] + + autoreload_proc = Popen( + [sys.executable, "-m", "tornado.autoreload", "-m", "testapp"], + stdout=subprocess.PIPE, + cwd=self.path, + env=dict(os.environ, PYTHONPATH=pythonpath), + universal_newlines=True, + ) + + # This timeout needs to be fairly generous for pypy due to jit + # warmup costs. + for i in range(40): + if autoreload_proc.poll() is not None: + break + time.sleep(0.1) + else: + autoreload_proc.kill() + raise Exception("subprocess failed to terminate") + + out = autoreload_proc.communicate()[0] + self.assertEqual(out, "Starting\n" * 2) diff --git a/venv/lib/python3.9/site-packages/tornado/test/concurrent_test.py b/venv/lib/python3.9/site-packages/tornado/test/concurrent_test.py new file mode 100644 index 00000000..33fcb650 --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/test/concurrent_test.py @@ -0,0 +1,212 @@ +# +# Copyright 2012 Facebook +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. 
You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from concurrent import futures +import logging +import re +import socket +import typing +import unittest + +from tornado.concurrent import ( + Future, + run_on_executor, + future_set_result_unless_cancelled, +) +from tornado.escape import utf8, to_unicode +from tornado import gen +from tornado.iostream import IOStream +from tornado.tcpserver import TCPServer +from tornado.testing import AsyncTestCase, bind_unused_port, gen_test + + +class MiscFutureTest(AsyncTestCase): + def test_future_set_result_unless_cancelled(self): + fut = Future() # type: Future[int] + future_set_result_unless_cancelled(fut, 42) + self.assertEqual(fut.result(), 42) + self.assertFalse(fut.cancelled()) + + fut = Future() + fut.cancel() + is_cancelled = fut.cancelled() + future_set_result_unless_cancelled(fut, 42) + self.assertEqual(fut.cancelled(), is_cancelled) + if not is_cancelled: + self.assertEqual(fut.result(), 42) + + +# The following series of classes demonstrate and test various styles +# of use, with and without generators and futures. + + +class CapServer(TCPServer): + @gen.coroutine + def handle_stream(self, stream, address): + data = yield stream.read_until(b"\n") + data = to_unicode(data) + if data == data.upper(): + stream.write(b"error\talready capitalized\n") + else: + # data already has \n + stream.write(utf8("ok\t%s" % data.upper())) + stream.close() + + +class CapError(Exception): + pass + + +class BaseCapClient(object): + def __init__(self, port): + self.port = port + + def process_response(self, data): + m = re.match("(.*)\t(.*)\n", to_unicode(data)) + if m is None: + raise Exception("did not match") + status, message = m.groups() + if status == "ok": + return message + else: + raise CapError(message) + + +class GeneratorCapClient(BaseCapClient): + @gen.coroutine + def capitalize(self, request_data): + logging.debug("capitalize") + stream = IOStream(socket.socket()) + logging.debug("connecting") + yield stream.connect(("127.0.0.1", self.port)) + stream.write(utf8(request_data + "\n")) + logging.debug("reading") + data = yield stream.read_until(b"\n") + logging.debug("returning") + stream.close() + raise gen.Return(self.process_response(data)) + + +class ClientTestMixin(object): + client_class = None # type: typing.Callable + + def setUp(self): + super().setUp() # type: ignore + self.server = CapServer() + sock, port = bind_unused_port() + self.server.add_sockets([sock]) + self.client = self.client_class(port=port) + + def tearDown(self): + self.server.stop() + super().tearDown() # type: ignore + + def test_future(self: typing.Any): + future = self.client.capitalize("hello") + self.io_loop.add_future(future, self.stop) + self.wait() + self.assertEqual(future.result(), "HELLO") + + def test_future_error(self: typing.Any): + future = self.client.capitalize("HELLO") + self.io_loop.add_future(future, self.stop) + self.wait() + self.assertRaisesRegex(CapError, "already capitalized", future.result) # type: ignore + + def test_generator(self: typing.Any): + @gen.coroutine + def f(): + result = yield self.client.capitalize("hello") + self.assertEqual(result, "HELLO") + + self.io_loop.run_sync(f) 
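+ # A @gen.coroutine function returns a Future the moment it is called;
+ # run_sync above spins the IOLoop until that Future resolves. A minimal
+ # stand-alone sketch of the same pattern (hypothetical names, not part
+ # of the original test file):
+ #
+ #     from tornado import gen
+ #     from tornado.ioloop import IOLoop
+ #
+ #     @gen.coroutine
+ #     def shout(word):
+ #         yield gen.moment  # give the event loop one turn
+ #         raise gen.Return(word.upper())
+ #
+ #     assert IOLoop.current().run_sync(lambda: shout("hi")) == "HI"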
+ + def test_generator_error(self: typing.Any): + @gen.coroutine + def f(): + with self.assertRaisesRegex(CapError, "already capitalized"): + yield self.client.capitalize("HELLO") + + self.io_loop.run_sync(f) + + +class GeneratorClientTest(ClientTestMixin, AsyncTestCase): + client_class = GeneratorCapClient + + +class RunOnExecutorTest(AsyncTestCase): + @gen_test + def test_no_calling(self): + class Object(object): + def __init__(self): + self.executor = futures.thread.ThreadPoolExecutor(1) + + @run_on_executor + def f(self): + return 42 + + o = Object() + answer = yield o.f() + self.assertEqual(answer, 42) + + @gen_test + def test_call_with_no_args(self): + class Object(object): + def __init__(self): + self.executor = futures.thread.ThreadPoolExecutor(1) + + @run_on_executor() + def f(self): + return 42 + + o = Object() + answer = yield o.f() + self.assertEqual(answer, 42) + + @gen_test + def test_call_with_executor(self): + class Object(object): + def __init__(self): + self.__executor = futures.thread.ThreadPoolExecutor(1) + + @run_on_executor(executor="_Object__executor") + def f(self): + return 42 + + o = Object() + answer = yield o.f() + self.assertEqual(answer, 42) + + @gen_test + def test_async_await(self): + class Object(object): + def __init__(self): + self.executor = futures.thread.ThreadPoolExecutor(1) + + @run_on_executor() + def f(self): + return 42 + + o = Object() + + async def f(): + answer = await o.f() + return answer + + result = yield f() + self.assertEqual(result, 42) + + +if __name__ == "__main__": + unittest.main() diff --git a/venv/lib/python3.9/site-packages/tornado/test/csv_translations/fr_FR.csv b/venv/lib/python3.9/site-packages/tornado/test/csv_translations/fr_FR.csv new file mode 100644 index 00000000..6321b6e7 --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/test/csv_translations/fr_FR.csv @@ -0,0 +1 @@ +"school","école" diff --git a/venv/lib/python3.9/site-packages/tornado/test/curl_httpclient_test.py b/venv/lib/python3.9/site-packages/tornado/test/curl_httpclient_test.py new file mode 100644 index 00000000..99af2933 --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/test/curl_httpclient_test.py @@ -0,0 +1,129 @@ +from hashlib import md5 +import unittest + +from tornado.escape import utf8 +from tornado.testing import AsyncHTTPTestCase +from tornado.test import httpclient_test +from tornado.web import Application, RequestHandler + + +try: + import pycurl +except ImportError: + pycurl = None # type: ignore + +if pycurl is not None: + from tornado.curl_httpclient import CurlAsyncHTTPClient + + +@unittest.skipIf(pycurl is None, "pycurl module not present") +class CurlHTTPClientCommonTestCase(httpclient_test.HTTPClientCommonTestCase): + def get_http_client(self): + client = CurlAsyncHTTPClient(defaults=dict(allow_ipv6=False)) + # make sure AsyncHTTPClient magic doesn't give us the wrong class + self.assertTrue(isinstance(client, CurlAsyncHTTPClient)) + return client + + +class DigestAuthHandler(RequestHandler): + def initialize(self, username, password): + self.username = username + self.password = password + + def get(self): + realm = "test" + opaque = "asdf" + # Real implementations would use a random nonce. 
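+ # (Illustrative only, not part of the original file: a real server
+ # could generate one per request, e.g.
+ #
+ #     import binascii, os
+ #     nonce = binascii.hexlify(os.urandom(16)).decode("ascii")
+ #
+ # here a fixed value keeps the digest below deterministic for the test.)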
+ nonce = "1234" + + auth_header = self.request.headers.get("Authorization", None) + if auth_header is not None: + auth_mode, params = auth_header.split(" ", 1) + assert auth_mode == "Digest" + param_dict = {} + for pair in params.split(","): + k, v = pair.strip().split("=", 1) + if v[0] == '"' and v[-1] == '"': + v = v[1:-1] + param_dict[k] = v + assert param_dict["realm"] == realm + assert param_dict["opaque"] == opaque + assert param_dict["nonce"] == nonce + assert param_dict["username"] == self.username + assert param_dict["uri"] == self.request.path + h1 = md5( + utf8("%s:%s:%s" % (self.username, realm, self.password)) + ).hexdigest() + h2 = md5( + utf8("%s:%s" % (self.request.method, self.request.path)) + ).hexdigest() + digest = md5(utf8("%s:%s:%s" % (h1, nonce, h2))).hexdigest() + if digest == param_dict["response"]: + self.write("ok") + else: + self.write("fail") + else: + self.set_status(401) + self.set_header( + "WWW-Authenticate", + 'Digest realm="%s", nonce="%s", opaque="%s"' % (realm, nonce, opaque), + ) + + +class CustomReasonHandler(RequestHandler): + def get(self): + self.set_status(200, "Custom reason") + + +class CustomFailReasonHandler(RequestHandler): + def get(self): + self.set_status(400, "Custom reason") + + +@unittest.skipIf(pycurl is None, "pycurl module not present") +class CurlHTTPClientTestCase(AsyncHTTPTestCase): + def setUp(self): + super().setUp() + self.http_client = self.create_client() + + def get_app(self): + return Application( + [ + ("/digest", DigestAuthHandler, {"username": "foo", "password": "bar"}), + ( + "/digest_non_ascii", + DigestAuthHandler, + {"username": "foo", "password": "barユ£"}, + ), + ("/custom_reason", CustomReasonHandler), + ("/custom_fail_reason", CustomFailReasonHandler), + ] + ) + + def create_client(self, **kwargs): + return CurlAsyncHTTPClient( + force_instance=True, defaults=dict(allow_ipv6=False), **kwargs + ) + + def test_digest_auth(self): + response = self.fetch( + "/digest", auth_mode="digest", auth_username="foo", auth_password="bar" + ) + self.assertEqual(response.body, b"ok") + + def test_custom_reason(self): + response = self.fetch("/custom_reason") + self.assertEqual(response.reason, "Custom reason") + + def test_fail_custom_reason(self): + response = self.fetch("/custom_fail_reason") + self.assertEqual(str(response.error), "HTTP 400: Custom reason") + + def test_digest_auth_non_ascii(self): + response = self.fetch( + "/digest_non_ascii", + auth_mode="digest", + auth_username="foo", + auth_password="barユ£", + ) + self.assertEqual(response.body, b"ok") diff --git a/venv/lib/python3.9/site-packages/tornado/test/escape_test.py b/venv/lib/python3.9/site-packages/tornado/test/escape_test.py new file mode 100644 index 00000000..a90d11d6 --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/test/escape_test.py @@ -0,0 +1,322 @@ +import unittest + +import tornado +from tornado.escape import ( + utf8, + xhtml_escape, + xhtml_unescape, + url_escape, + url_unescape, + to_unicode, + json_decode, + json_encode, + squeeze, + recursive_unicode, +) +from tornado.util import unicode_type + +from typing import List, Tuple, Union, Dict, Any # noqa: F401 + +linkify_tests = [ + # (input, linkify_kwargs, expected_output) + ( + "hello http://world.com/!", + {}, + 'hello <a href="http://world.com/">http://world.com/</a>!', + ), + ( + "hello http://world.com/with?param=true&stuff=yes", + {}, + 'hello <a href="http://world.com/with?param=true&stuff=yes">http://world.com/with?param=true&stuff=yes</a>', # noqa: E501 + ), + # an opened 
paren followed by many chars killed Gruber's regex + ( + "http://url.com/w(aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + {}, + '<a href="http://url.com/w">http://url.com/w</a>(aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', # noqa: E501 + ), + # as did too many dots at the end + ( + "http://url.com/withmany.......................................", + {}, + '<a href="http://url.com/withmany">http://url.com/withmany</a>.......................................', # noqa: E501 + ), + ( + "http://url.com/withmany((((((((((((((((((((((((((((((((((a)", + {}, + '<a href="http://url.com/withmany">http://url.com/withmany</a>((((((((((((((((((((((((((((((((((a)', # noqa: E501 + ), + # some examples from http://daringfireball.net/2009/11/liberal_regex_for_matching_urls + # plus a fex extras (such as multiple parentheses). + ( + "http://foo.com/blah_blah", + {}, + '<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>', + ), + ( + "http://foo.com/blah_blah/", + {}, + '<a href="http://foo.com/blah_blah/">http://foo.com/blah_blah/</a>', + ), + ( + "(Something like http://foo.com/blah_blah)", + {}, + '(Something like <a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>)', + ), + ( + "http://foo.com/blah_blah_(wikipedia)", + {}, + '<a href="http://foo.com/blah_blah_(wikipedia)">http://foo.com/blah_blah_(wikipedia)</a>', + ), + ( + "http://foo.com/blah_(blah)_(wikipedia)_blah", + {}, + '<a href="http://foo.com/blah_(blah)_(wikipedia)_blah">http://foo.com/blah_(blah)_(wikipedia)_blah</a>', # noqa: E501 + ), + ( + "(Something like http://foo.com/blah_blah_(wikipedia))", + {}, + '(Something like <a href="http://foo.com/blah_blah_(wikipedia)">http://foo.com/blah_blah_(wikipedia)</a>)', # noqa: E501 + ), + ( + "http://foo.com/blah_blah.", + {}, + '<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>.', + ), + ( + "http://foo.com/blah_blah/.", + {}, + '<a href="http://foo.com/blah_blah/">http://foo.com/blah_blah/</a>.', + ), + ( + "<http://foo.com/blah_blah>", + {}, + '<<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>>', + ), + ( + "<http://foo.com/blah_blah/>", + {}, + '<<a href="http://foo.com/blah_blah/">http://foo.com/blah_blah/</a>>', + ), + ( + "http://foo.com/blah_blah,", + {}, + '<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>,', + ), + ( + "http://www.example.com/wpstyle/?p=364.", + {}, + '<a href="http://www.example.com/wpstyle/?p=364">http://www.example.com/wpstyle/?p=364</a>.', # noqa: E501 + ), + ( + "rdar://1234", + {"permitted_protocols": ["http", "rdar"]}, + '<a href="rdar://1234">rdar://1234</a>', + ), + ( + "rdar:/1234", + {"permitted_protocols": ["rdar"]}, + '<a href="rdar:/1234">rdar:/1234</a>', + ), + ( + "http://userid:password@example.com:8080", + {}, + '<a href="http://userid:password@example.com:8080">http://userid:password@example.com:8080</a>', # noqa: E501 + ), + ( + "http://userid@example.com", + {}, + '<a href="http://userid@example.com">http://userid@example.com</a>', + ), + ( + "http://userid@example.com:8080", + {}, + '<a href="http://userid@example.com:8080">http://userid@example.com:8080</a>', + ), + ( + "http://userid:password@example.com", + {}, + '<a href="http://userid:password@example.com">http://userid:password@example.com</a>', + ), + ( + "message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e", + {"permitted_protocols": ["http", "message"]}, + '<a href="message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e">' + 
"message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e</a>", + ), + ( + "http://\u27a1.ws/\u4a39", + {}, + '<a href="http://\u27a1.ws/\u4a39">http://\u27a1.ws/\u4a39</a>', + ), + ( + "<tag>http://example.com</tag>", + {}, + '<tag><a href="http://example.com">http://example.com</a></tag>', + ), + ( + "Just a www.example.com link.", + {}, + 'Just a <a href="http://www.example.com">www.example.com</a> link.', + ), + ( + "Just a www.example.com link.", + {"require_protocol": True}, + "Just a www.example.com link.", + ), + ( + "A http://reallylong.com/link/that/exceedsthelenglimit.html", + {"require_protocol": True, "shorten": True}, + 'A <a href="http://reallylong.com/link/that/exceedsthelenglimit.html"' + ' title="http://reallylong.com/link/that/exceedsthelenglimit.html">http://reallylong.com/link...</a>', # noqa: E501 + ), + ( + "A http://reallylongdomainnamethatwillbetoolong.com/hi!", + {"shorten": True}, + 'A <a href="http://reallylongdomainnamethatwillbetoolong.com/hi"' + ' title="http://reallylongdomainnamethatwillbetoolong.com/hi">http://reallylongdomainnametha...</a>!', # noqa: E501 + ), + ( + "A file:///passwords.txt and http://web.com link", + {}, + 'A file:///passwords.txt and <a href="http://web.com">http://web.com</a> link', + ), + ( + "A file:///passwords.txt and http://web.com link", + {"permitted_protocols": ["file"]}, + 'A <a href="file:///passwords.txt">file:///passwords.txt</a> and http://web.com link', + ), + ( + "www.external-link.com", + {"extra_params": 'rel="nofollow" class="external"'}, + '<a href="http://www.external-link.com" rel="nofollow" class="external">www.external-link.com</a>', # noqa: E501 + ), + ( + "www.external-link.com and www.internal-link.com/blogs extra", + { + "extra_params": lambda href: 'class="internal"' + if href.startswith("http://www.internal-link.com") + else 'rel="nofollow" class="external"' + }, + '<a href="http://www.external-link.com" rel="nofollow" class="external">www.external-link.com</a>' # noqa: E501 + ' and <a href="http://www.internal-link.com/blogs" class="internal">www.internal-link.com/blogs</a> extra', # noqa: E501 + ), + ( + "www.external-link.com", + {"extra_params": lambda href: ' rel="nofollow" class="external" '}, + '<a href="http://www.external-link.com" rel="nofollow" class="external">www.external-link.com</a>', # noqa: E501 + ), +] # type: List[Tuple[Union[str, bytes], Dict[str, Any], str]] + + +class EscapeTestCase(unittest.TestCase): + def test_linkify(self): + for text, kwargs, html in linkify_tests: + linked = tornado.escape.linkify(text, **kwargs) + self.assertEqual(linked, html) + + def test_xhtml_escape(self): + tests = [ + ("<foo>", "<foo>"), + ("<foo>", "<foo>"), + (b"<foo>", b"<foo>"), + ("<>&\"'", "<>&"'"), + ("&", "&amp;"), + ("<\u00e9>", "<\u00e9>"), + (b"<\xc3\xa9>", b"<\xc3\xa9>"), + ] # type: List[Tuple[Union[str, bytes], Union[str, bytes]]] + for unescaped, escaped in tests: + self.assertEqual(utf8(xhtml_escape(unescaped)), utf8(escaped)) + self.assertEqual(utf8(unescaped), utf8(xhtml_unescape(escaped))) + + def test_xhtml_unescape_numeric(self): + tests = [ + ("foo bar", "foo bar"), + ("foo bar", "foo bar"), + ("foo bar", "foo bar"), + ("foo઼bar", "foo\u0abcbar"), + ("foo&#xyz;bar", "foo&#xyz;bar"), # invalid encoding + ("foo&#;bar", "foo&#;bar"), # invalid encoding + ("foo&#x;bar", "foo&#x;bar"), # invalid encoding + ] + for escaped, unescaped in tests: + self.assertEqual(unescaped, xhtml_unescape(escaped)) + + def test_url_escape_unicode(self): + tests = [ + # byte strings are passed 
through as-is + ("\u00e9".encode("utf8"), "%C3%A9"), + ("\u00e9".encode("latin1"), "%E9"), + # unicode strings become utf8 + ("\u00e9", "%C3%A9"), + ] # type: List[Tuple[Union[str, bytes], str]] + for unescaped, escaped in tests: + self.assertEqual(url_escape(unescaped), escaped) + + def test_url_unescape_unicode(self): + tests = [ + ("%C3%A9", "\u00e9", "utf8"), + ("%C3%A9", "\u00c3\u00a9", "latin1"), + ("%C3%A9", utf8("\u00e9"), None), + ] + for escaped, unescaped, encoding in tests: + # input strings to url_unescape should only contain ascii + # characters, but make sure the function accepts both byte + # and unicode strings. + self.assertEqual(url_unescape(to_unicode(escaped), encoding), unescaped) + self.assertEqual(url_unescape(utf8(escaped), encoding), unescaped) + + def test_url_escape_quote_plus(self): + unescaped = "+ #%" + plus_escaped = "%2B+%23%25" + escaped = "%2B%20%23%25" + self.assertEqual(url_escape(unescaped), plus_escaped) + self.assertEqual(url_escape(unescaped, plus=False), escaped) + self.assertEqual(url_unescape(plus_escaped), unescaped) + self.assertEqual(url_unescape(escaped, plus=False), unescaped) + self.assertEqual(url_unescape(plus_escaped, encoding=None), utf8(unescaped)) + self.assertEqual( + url_unescape(escaped, encoding=None, plus=False), utf8(unescaped) + ) + + def test_escape_return_types(self): + # On python2 the escape methods should generally return the same + # type as their argument + self.assertEqual(type(xhtml_escape("foo")), str) + self.assertEqual(type(xhtml_escape("foo")), unicode_type) + + def test_json_decode(self): + # json_decode accepts both bytes and unicode, but strings it returns + # are always unicode. + self.assertEqual(json_decode(b'"foo"'), "foo") + self.assertEqual(json_decode('"foo"'), "foo") + + # Non-ascii bytes are interpreted as utf8 + self.assertEqual(json_decode(utf8('"\u00e9"')), "\u00e9") + + def test_json_encode(self): + # json deals with strings, not bytes. On python 2 byte strings will + # convert automatically if they are utf8; on python 3 byte strings + # are not allowed. 
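+ # Illustrative behavior, matching the assertions below: a unicode
+ # string round-trips losslessly, while a non-utf8 byte string fails.
+ #
+ #     json_encode("\u00e9")  # returns '"\\u00e9"'
+ #     json_encode(b"\xe9")   # raises on Python 3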
+ self.assertEqual(json_decode(json_encode("\u00e9")), "\u00e9") + if bytes is str: + self.assertEqual(json_decode(json_encode(utf8("\u00e9"))), "\u00e9") + self.assertRaises(UnicodeDecodeError, json_encode, b"\xe9") + + def test_squeeze(self): + self.assertEqual( + squeeze("sequences of whitespace chars"), + "sequences of whitespace chars", + ) + + def test_recursive_unicode(self): + tests = { + "dict": {b"foo": b"bar"}, + "list": [b"foo", b"bar"], + "tuple": (b"foo", b"bar"), + "bytes": b"foo", + } + self.assertEqual(recursive_unicode(tests["dict"]), {"foo": "bar"}) + self.assertEqual(recursive_unicode(tests["list"]), ["foo", "bar"]) + self.assertEqual(recursive_unicode(tests["tuple"]), ("foo", "bar")) + self.assertEqual(recursive_unicode(tests["bytes"]), "foo") diff --git a/venv/lib/python3.9/site-packages/tornado/test/gen_test.py b/venv/lib/python3.9/site-packages/tornado/test/gen_test.py new file mode 100644 index 00000000..c17bf65f --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/test/gen_test.py @@ -0,0 +1,1129 @@ +import asyncio +from concurrent import futures +import gc +import datetime +import platform +import sys +import time +import weakref +import unittest + +from tornado.concurrent import Future +from tornado.log import app_log +from tornado.testing import AsyncHTTPTestCase, AsyncTestCase, ExpectLog, gen_test +from tornado.test.util import skipOnTravis, skipNotCPython +from tornado.web import Application, RequestHandler, HTTPError + +from tornado import gen + +try: + import contextvars +except ImportError: + contextvars = None # type: ignore + +import typing + +if typing.TYPE_CHECKING: + from typing import List, Optional # noqa: F401 + + +class GenBasicTest(AsyncTestCase): + @gen.coroutine + def delay(self, iterations, arg): + """Returns arg after a number of IOLoop iterations.""" + for i in range(iterations): + yield gen.moment + raise gen.Return(arg) + + @gen.coroutine + def async_future(self, result): + yield gen.moment + return result + + @gen.coroutine + def async_exception(self, e): + yield gen.moment + raise e + + @gen.coroutine + def add_one_async(self, x): + yield gen.moment + raise gen.Return(x + 1) + + def test_no_yield(self): + @gen.coroutine + def f(): + pass + + self.io_loop.run_sync(f) + + def test_exception_phase1(self): + @gen.coroutine + def f(): + 1 / 0 + + self.assertRaises(ZeroDivisionError, self.io_loop.run_sync, f) + + def test_exception_phase2(self): + @gen.coroutine + def f(): + yield gen.moment + 1 / 0 + + self.assertRaises(ZeroDivisionError, self.io_loop.run_sync, f) + + def test_bogus_yield(self): + @gen.coroutine + def f(): + yield 42 + + self.assertRaises(gen.BadYieldError, self.io_loop.run_sync, f) + + def test_bogus_yield_tuple(self): + @gen.coroutine + def f(): + yield (1, 2) + + self.assertRaises(gen.BadYieldError, self.io_loop.run_sync, f) + + def test_reuse(self): + @gen.coroutine + def f(): + yield gen.moment + + self.io_loop.run_sync(f) + self.io_loop.run_sync(f) + + def test_none(self): + @gen.coroutine + def f(): + yield None + + self.io_loop.run_sync(f) + + def test_multi(self): + @gen.coroutine + def f(): + results = yield [self.add_one_async(1), self.add_one_async(2)] + self.assertEqual(results, [2, 3]) + + self.io_loop.run_sync(f) + + def test_multi_dict(self): + @gen.coroutine + def f(): + results = yield dict(foo=self.add_one_async(1), bar=self.add_one_async(2)) + self.assertEqual(results, dict(foo=2, bar=3)) + + self.io_loop.run_sync(f) + + def test_multi_delayed(self): + @gen.coroutine + def f(): + # callbacks run at 
different times + responses = yield gen.multi_future( + [self.delay(3, "v1"), self.delay(1, "v2")] + ) + self.assertEqual(responses, ["v1", "v2"]) + + self.io_loop.run_sync(f) + + def test_multi_dict_delayed(self): + @gen.coroutine + def f(): + # callbacks run at different times + responses = yield gen.multi_future( + dict(foo=self.delay(3, "v1"), bar=self.delay(1, "v2")) + ) + self.assertEqual(responses, dict(foo="v1", bar="v2")) + + self.io_loop.run_sync(f) + + @skipOnTravis + @gen_test + def test_multi_performance(self): + # Yielding a list used to have quadratic performance; make + # sure a large list stays reasonable. On my laptop a list of + # 2000 used to take 1.8s, now it takes 0.12. + start = time.time() + yield [gen.moment for i in range(2000)] + end = time.time() + self.assertLess(end - start, 1.0) + + @gen_test + def test_multi_empty(self): + # Empty lists or dicts should return the same type. + x = yield [] + self.assertTrue(isinstance(x, list)) + y = yield {} + self.assertTrue(isinstance(y, dict)) + + @gen_test + def test_future(self): + result = yield self.async_future(1) + self.assertEqual(result, 1) + + @gen_test + def test_multi_future(self): + results = yield [self.async_future(1), self.async_future(2)] + self.assertEqual(results, [1, 2]) + + @gen_test + def test_multi_future_duplicate(self): + # Note that this doesn't work with native corotines, only with + # decorated coroutines. + f = self.async_future(2) + results = yield [self.async_future(1), f, self.async_future(3), f] + self.assertEqual(results, [1, 2, 3, 2]) + + @gen_test + def test_multi_dict_future(self): + results = yield dict(foo=self.async_future(1), bar=self.async_future(2)) + self.assertEqual(results, dict(foo=1, bar=2)) + + @gen_test + def test_multi_exceptions(self): + with ExpectLog(app_log, "Multiple exceptions in yield list"): + with self.assertRaises(RuntimeError) as cm: + yield gen.Multi( + [ + self.async_exception(RuntimeError("error 1")), + self.async_exception(RuntimeError("error 2")), + ] + ) + self.assertEqual(str(cm.exception), "error 1") + + # With only one exception, no error is logged. + with self.assertRaises(RuntimeError): + yield gen.Multi( + [self.async_exception(RuntimeError("error 1")), self.async_future(2)] + ) + + # Exception logging may be explicitly quieted. + with self.assertRaises(RuntimeError): + yield gen.Multi( + [ + self.async_exception(RuntimeError("error 1")), + self.async_exception(RuntimeError("error 2")), + ], + quiet_exceptions=RuntimeError, + ) + + @gen_test + def test_multi_future_exceptions(self): + with ExpectLog(app_log, "Multiple exceptions in yield list"): + with self.assertRaises(RuntimeError) as cm: + yield [ + self.async_exception(RuntimeError("error 1")), + self.async_exception(RuntimeError("error 2")), + ] + self.assertEqual(str(cm.exception), "error 1") + + # With only one exception, no error is logged. + with self.assertRaises(RuntimeError): + yield [self.async_exception(RuntimeError("error 1")), self.async_future(2)] + + # Exception logging may be explicitly quieted. 
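+ # quiet_exceptions accepts an exception type, or a tuple of types,
+ # whose secondary errors are dropped instead of logged, e.g.
+ # (illustrative):
+ #
+ #     yield gen.multi(futures, quiet_exceptions=(RuntimeError, IOError))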
+ with self.assertRaises(RuntimeError): + yield gen.multi_future( + [ + self.async_exception(RuntimeError("error 1")), + self.async_exception(RuntimeError("error 2")), + ], + quiet_exceptions=RuntimeError, + ) + + def test_sync_raise_return(self): + @gen.coroutine + def f(): + raise gen.Return() + + self.io_loop.run_sync(f) + + def test_async_raise_return(self): + @gen.coroutine + def f(): + yield gen.moment + raise gen.Return() + + self.io_loop.run_sync(f) + + def test_sync_raise_return_value(self): + @gen.coroutine + def f(): + raise gen.Return(42) + + self.assertEqual(42, self.io_loop.run_sync(f)) + + def test_sync_raise_return_value_tuple(self): + @gen.coroutine + def f(): + raise gen.Return((1, 2)) + + self.assertEqual((1, 2), self.io_loop.run_sync(f)) + + def test_async_raise_return_value(self): + @gen.coroutine + def f(): + yield gen.moment + raise gen.Return(42) + + self.assertEqual(42, self.io_loop.run_sync(f)) + + def test_async_raise_return_value_tuple(self): + @gen.coroutine + def f(): + yield gen.moment + raise gen.Return((1, 2)) + + self.assertEqual((1, 2), self.io_loop.run_sync(f)) + + +class GenCoroutineTest(AsyncTestCase): + def setUp(self): + # Stray StopIteration exceptions can lead to tests exiting prematurely, + # so we need explicit checks here to make sure the tests run all + # the way through. + self.finished = False + super().setUp() + + def tearDown(self): + super().tearDown() + assert self.finished + + def test_attributes(self): + self.finished = True + + def f(): + yield gen.moment + + coro = gen.coroutine(f) + self.assertEqual(coro.__name__, f.__name__) + self.assertEqual(coro.__module__, f.__module__) + self.assertIs(coro.__wrapped__, f) # type: ignore + + def test_is_coroutine_function(self): + self.finished = True + + def f(): + yield gen.moment + + coro = gen.coroutine(f) + self.assertFalse(gen.is_coroutine_function(f)) + self.assertTrue(gen.is_coroutine_function(coro)) + self.assertFalse(gen.is_coroutine_function(coro())) + + @gen_test + def test_sync_gen_return(self): + @gen.coroutine + def f(): + raise gen.Return(42) + + result = yield f() + self.assertEqual(result, 42) + self.finished = True + + @gen_test + def test_async_gen_return(self): + @gen.coroutine + def f(): + yield gen.moment + raise gen.Return(42) + + result = yield f() + self.assertEqual(result, 42) + self.finished = True + + @gen_test + def test_sync_return(self): + @gen.coroutine + def f(): + return 42 + + result = yield f() + self.assertEqual(result, 42) + self.finished = True + + @gen_test + def test_async_return(self): + @gen.coroutine + def f(): + yield gen.moment + return 42 + + result = yield f() + self.assertEqual(result, 42) + self.finished = True + + @gen_test + def test_async_early_return(self): + # A yield statement exists but is not executed, which means + # this function "returns" via an exception. This exception + # doesn't happen before the exception handling is set up. + @gen.coroutine + def f(): + if True: + return 42 + yield gen.Task(self.io_loop.add_callback) + + result = yield f() + self.assertEqual(result, 42) + self.finished = True + + @gen_test + def test_async_await(self): + @gen.coroutine + def f1(): + yield gen.moment + raise gen.Return(42) + + # This test verifies that an async function can await a + # yield-based gen.coroutine, and that a gen.coroutine + # (the test method itself) can yield an async function. 
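+ # On Tornado 5+ a @gen.coroutine call returns an asyncio-compatible
+ # Future, which is what makes the await below legal (illustrative):
+ #
+ #     fut = f1()          # Future; the work is already scheduled
+ #     result = await fut  # valid inside any native coroutine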
+ async def f2(): + result = await f1() + return result + + result = yield f2() + self.assertEqual(result, 42) + self.finished = True + + @gen_test + def test_asyncio_sleep_zero(self): + # asyncio.sleep(0) turns into a special case (equivalent to + # `yield None`) + async def f(): + import asyncio + + await asyncio.sleep(0) + return 42 + + result = yield f() + self.assertEqual(result, 42) + self.finished = True + + @gen_test + def test_async_await_mixed_multi_native_future(self): + @gen.coroutine + def f1(): + yield gen.moment + + async def f2(): + await f1() + return 42 + + @gen.coroutine + def f3(): + yield gen.moment + raise gen.Return(43) + + results = yield [f2(), f3()] + self.assertEqual(results, [42, 43]) + self.finished = True + + @gen_test + def test_async_with_timeout(self): + async def f1(): + return 42 + + result = yield gen.with_timeout(datetime.timedelta(hours=1), f1()) + self.assertEqual(result, 42) + self.finished = True + + @gen_test + def test_sync_return_no_value(self): + @gen.coroutine + def f(): + return + + result = yield f() + self.assertEqual(result, None) + self.finished = True + + @gen_test + def test_async_return_no_value(self): + # Without a return value we don't need python 3.3. + @gen.coroutine + def f(): + yield gen.moment + return + + result = yield f() + self.assertEqual(result, None) + self.finished = True + + @gen_test + def test_sync_raise(self): + @gen.coroutine + def f(): + 1 / 0 + + # The exception is raised when the future is yielded + # (or equivalently when its result method is called), + # not when the function itself is called). + future = f() + with self.assertRaises(ZeroDivisionError): + yield future + self.finished = True + + @gen_test + def test_async_raise(self): + @gen.coroutine + def f(): + yield gen.moment + 1 / 0 + + future = f() + with self.assertRaises(ZeroDivisionError): + yield future + self.finished = True + + @gen_test + def test_replace_yieldpoint_exception(self): + # Test exception handling: a coroutine can catch one exception + # raised by a yield point and raise a different one. + @gen.coroutine + def f1(): + 1 / 0 + + @gen.coroutine + def f2(): + try: + yield f1() + except ZeroDivisionError: + raise KeyError() + + future = f2() + with self.assertRaises(KeyError): + yield future + self.finished = True + + @gen_test + def test_swallow_yieldpoint_exception(self): + # Test exception handling: a coroutine can catch an exception + # raised by a yield point and not raise a different one. + @gen.coroutine + def f1(): + 1 / 0 + + @gen.coroutine + def f2(): + try: + yield f1() + except ZeroDivisionError: + raise gen.Return(42) + + result = yield f2() + self.assertEqual(result, 42) + self.finished = True + + @gen_test + def test_moment(self): + calls = [] + + @gen.coroutine + def f(name, yieldable): + for i in range(5): + calls.append(name) + yield yieldable + + # First, confirm the behavior without moment: each coroutine + # monopolizes the event loop until it finishes. + immediate = Future() # type: Future[None] + immediate.set_result(None) + yield [f("a", immediate), f("b", immediate)] + self.assertEqual("".join(calls), "aaaaabbbbb") + + # With moment, they take turns. 
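+ # gen.moment hands control back to the IOLoop for a single iteration
+ # (much like asyncio.sleep(0)), so each `yield gen.moment` lets the
+ # other coroutine advance one step.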
+ calls = [] + yield [f("a", gen.moment), f("b", gen.moment)] + self.assertEqual("".join(calls), "ababababab") + self.finished = True + + calls = [] + yield [f("a", gen.moment), f("b", immediate)] + self.assertEqual("".join(calls), "abbbbbaaaa") + + @gen_test + def test_sleep(self): + yield gen.sleep(0.01) + self.finished = True + + @gen_test + def test_py3_leak_exception_context(self): + class LeakedException(Exception): + pass + + @gen.coroutine + def inner(iteration): + raise LeakedException(iteration) + + try: + yield inner(1) + except LeakedException as e: + self.assertEqual(str(e), "1") + self.assertIsNone(e.__context__) + + try: + yield inner(2) + except LeakedException as e: + self.assertEqual(str(e), "2") + self.assertIsNone(e.__context__) + + self.finished = True + + @skipNotCPython + @unittest.skipIf( + (3,) < sys.version_info < (3, 6), "asyncio.Future has reference cycles" + ) + def test_coroutine_refcounting(self): + # On CPython, tasks and their arguments should be released immediately + # without waiting for garbage collection. + @gen.coroutine + def inner(): + class Foo(object): + pass + + local_var = Foo() + self.local_ref = weakref.ref(local_var) + + def dummy(): + pass + + yield gen.coroutine(dummy)() + raise ValueError("Some error") + + @gen.coroutine + def inner2(): + try: + yield inner() + except ValueError: + pass + + self.io_loop.run_sync(inner2, timeout=3) + + self.assertIs(self.local_ref(), None) + self.finished = True + + def test_asyncio_future_debug_info(self): + self.finished = True + # Enable debug mode + asyncio_loop = asyncio.get_event_loop() + self.addCleanup(asyncio_loop.set_debug, asyncio_loop.get_debug()) + asyncio_loop.set_debug(True) + + def f(): + yield gen.moment + + coro = gen.coroutine(f)() + self.assertIsInstance(coro, asyncio.Future) + # We expect the coroutine repr() to show the place where + # it was instantiated + expected = "created at %s:%d" % (__file__, f.__code__.co_firstlineno + 3) + actual = repr(coro) + self.assertIn(expected, actual) + + @gen_test + def test_asyncio_gather(self): + # This demonstrates that tornado coroutines can be understood + # by asyncio (This failed prior to Tornado 5.0). + @gen.coroutine + def f(): + yield gen.moment + raise gen.Return(1) + + ret = yield asyncio.gather(f(), f()) + self.assertEqual(ret, [1, 1]) + self.finished = True + + +class GenCoroutineSequenceHandler(RequestHandler): + @gen.coroutine + def get(self): + yield gen.moment + self.write("1") + yield gen.moment + self.write("2") + yield gen.moment + self.finish("3") + + +class GenCoroutineUnfinishedSequenceHandler(RequestHandler): + @gen.coroutine + def get(self): + yield gen.moment + self.write("1") + yield gen.moment + self.write("2") + yield gen.moment + # just write, don't finish + self.write("3") + + +# "Undecorated" here refers to the absence of @asynchronous. 
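+# Tornado 6 removed @asynchronous outright: when prepare() or get() is a
+# coroutine, the framework simply awaits it, which is what the handler
+# below relies on (compare NativeCoroutineHandler further down for the
+# native async def equivalent).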
+class UndecoratedCoroutinesHandler(RequestHandler): + @gen.coroutine + def prepare(self): + self.chunks = [] # type: List[str] + yield gen.moment + self.chunks.append("1") + + @gen.coroutine + def get(self): + self.chunks.append("2") + yield gen.moment + self.chunks.append("3") + yield gen.moment + self.write("".join(self.chunks)) + + +class AsyncPrepareErrorHandler(RequestHandler): + @gen.coroutine + def prepare(self): + yield gen.moment + raise HTTPError(403) + + def get(self): + self.finish("ok") + + +class NativeCoroutineHandler(RequestHandler): + async def get(self): + await asyncio.sleep(0) + self.write("ok") + + +class GenWebTest(AsyncHTTPTestCase): + def get_app(self): + return Application( + [ + ("/coroutine_sequence", GenCoroutineSequenceHandler), + ( + "/coroutine_unfinished_sequence", + GenCoroutineUnfinishedSequenceHandler, + ), + ("/undecorated_coroutine", UndecoratedCoroutinesHandler), + ("/async_prepare_error", AsyncPrepareErrorHandler), + ("/native_coroutine", NativeCoroutineHandler), + ] + ) + + def test_coroutine_sequence_handler(self): + response = self.fetch("/coroutine_sequence") + self.assertEqual(response.body, b"123") + + def test_coroutine_unfinished_sequence_handler(self): + response = self.fetch("/coroutine_unfinished_sequence") + self.assertEqual(response.body, b"123") + + def test_undecorated_coroutines(self): + response = self.fetch("/undecorated_coroutine") + self.assertEqual(response.body, b"123") + + def test_async_prepare_error_handler(self): + response = self.fetch("/async_prepare_error") + self.assertEqual(response.code, 403) + + def test_native_coroutine_handler(self): + response = self.fetch("/native_coroutine") + self.assertEqual(response.code, 200) + self.assertEqual(response.body, b"ok") + + +class WithTimeoutTest(AsyncTestCase): + @gen_test + def test_timeout(self): + with self.assertRaises(gen.TimeoutError): + yield gen.with_timeout(datetime.timedelta(seconds=0.1), Future()) + + @gen_test + def test_completes_before_timeout(self): + future = Future() # type: Future[str] + self.io_loop.add_timeout( + datetime.timedelta(seconds=0.1), lambda: future.set_result("asdf") + ) + result = yield gen.with_timeout(datetime.timedelta(seconds=3600), future) + self.assertEqual(result, "asdf") + + @gen_test + def test_fails_before_timeout(self): + future = Future() # type: Future[str] + self.io_loop.add_timeout( + datetime.timedelta(seconds=0.1), + lambda: future.set_exception(ZeroDivisionError()), + ) + with self.assertRaises(ZeroDivisionError): + yield gen.with_timeout(datetime.timedelta(seconds=3600), future) + + @gen_test + def test_already_resolved(self): + future = Future() # type: Future[str] + future.set_result("asdf") + result = yield gen.with_timeout(datetime.timedelta(seconds=3600), future) + self.assertEqual(result, "asdf") + + @gen_test + def test_timeout_concurrent_future(self): + # A concurrent future that does not resolve before the timeout. + with futures.ThreadPoolExecutor(1) as executor: + with self.assertRaises(gen.TimeoutError): + yield gen.with_timeout( + self.io_loop.time(), executor.submit(time.sleep, 0.1) + ) + + @gen_test + def test_completed_concurrent_future(self): + # A concurrent future that is resolved before we even submit it + # to with_timeout. 
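+ # gen.with_timeout accepts concurrent.futures.Future objects as well
+ # as Tornado/asyncio futures, e.g. (illustrative):
+ #
+ #     f = executor.submit(lambda: 42)
+ #     result = yield gen.with_timeout(datetime.timedelta(seconds=1), f)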
+ with futures.ThreadPoolExecutor(1) as executor: + + def dummy(): + pass + + f = executor.submit(dummy) + f.result() # wait for completion + yield gen.with_timeout(datetime.timedelta(seconds=3600), f) + + @gen_test + def test_normal_concurrent_future(self): + # A conccurrent future that resolves while waiting for the timeout. + with futures.ThreadPoolExecutor(1) as executor: + yield gen.with_timeout( + datetime.timedelta(seconds=3600), + executor.submit(lambda: time.sleep(0.01)), + ) + + +class WaitIteratorTest(AsyncTestCase): + @gen_test + def test_empty_iterator(self): + g = gen.WaitIterator() + self.assertTrue(g.done(), "empty generator iterated") + + with self.assertRaises(ValueError): + g = gen.WaitIterator(Future(), bar=Future()) + + self.assertEqual(g.current_index, None, "bad nil current index") + self.assertEqual(g.current_future, None, "bad nil current future") + + @gen_test + def test_already_done(self): + f1 = Future() # type: Future[int] + f2 = Future() # type: Future[int] + f3 = Future() # type: Future[int] + f1.set_result(24) + f2.set_result(42) + f3.set_result(84) + + g = gen.WaitIterator(f1, f2, f3) + i = 0 + while not g.done(): + r = yield g.next() + # Order is not guaranteed, but the current implementation + # preserves ordering of already-done Futures. + if i == 0: + self.assertEqual(g.current_index, 0) + self.assertIs(g.current_future, f1) + self.assertEqual(r, 24) + elif i == 1: + self.assertEqual(g.current_index, 1) + self.assertIs(g.current_future, f2) + self.assertEqual(r, 42) + elif i == 2: + self.assertEqual(g.current_index, 2) + self.assertIs(g.current_future, f3) + self.assertEqual(r, 84) + i += 1 + + self.assertEqual(g.current_index, None, "bad nil current index") + self.assertEqual(g.current_future, None, "bad nil current future") + + dg = gen.WaitIterator(f1=f1, f2=f2) + + while not dg.done(): + dr = yield dg.next() + if dg.current_index == "f1": + self.assertTrue( + dg.current_future == f1 and dr == 24, + "WaitIterator dict status incorrect", + ) + elif dg.current_index == "f2": + self.assertTrue( + dg.current_future == f2 and dr == 42, + "WaitIterator dict status incorrect", + ) + else: + self.fail("got bad WaitIterator index {}".format(dg.current_index)) + + i += 1 + + self.assertEqual(dg.current_index, None, "bad nil current index") + self.assertEqual(dg.current_future, None, "bad nil current future") + + def finish_coroutines(self, iteration, futures): + if iteration == 3: + futures[2].set_result(24) + elif iteration == 5: + futures[0].set_exception(ZeroDivisionError()) + elif iteration == 8: + futures[1].set_result(42) + futures[3].set_result(84) + + if iteration < 8: + self.io_loop.add_callback(self.finish_coroutines, iteration + 1, futures) + + @gen_test + def test_iterator(self): + futures = [Future(), Future(), Future(), Future()] # type: List[Future[int]] + + self.finish_coroutines(0, futures) + + g = gen.WaitIterator(*futures) + + i = 0 + while not g.done(): + try: + r = yield g.next() + except ZeroDivisionError: + self.assertIs(g.current_future, futures[0], "exception future invalid") + else: + if i == 0: + self.assertEqual(r, 24, "iterator value incorrect") + self.assertEqual(g.current_index, 2, "wrong index") + elif i == 2: + self.assertEqual(r, 42, "iterator value incorrect") + self.assertEqual(g.current_index, 1, "wrong index") + elif i == 3: + self.assertEqual(r, 84, "iterator value incorrect") + self.assertEqual(g.current_index, 3, "wrong index") + i += 1 + + @gen_test + def test_iterator_async_await(self): + # Recreate the previous test 
with py35 syntax. It's a little clunky + # because of the way the previous test handles an exception on + # a single iteration. + futures = [Future(), Future(), Future(), Future()] # type: List[Future[int]] + self.finish_coroutines(0, futures) + self.finished = False + + async def f(): + i = 0 + g = gen.WaitIterator(*futures) + try: + async for r in g: + if i == 0: + self.assertEqual(r, 24, "iterator value incorrect") + self.assertEqual(g.current_index, 2, "wrong index") + else: + raise Exception("expected exception on iteration 1") + i += 1 + except ZeroDivisionError: + i += 1 + async for r in g: + if i == 2: + self.assertEqual(r, 42, "iterator value incorrect") + self.assertEqual(g.current_index, 1, "wrong index") + elif i == 3: + self.assertEqual(r, 84, "iterator value incorrect") + self.assertEqual(g.current_index, 3, "wrong index") + else: + raise Exception("didn't expect iteration %d" % i) + i += 1 + self.finished = True + + yield f() + self.assertTrue(self.finished) + + @gen_test + def test_no_ref(self): + # In this usage, there is no direct hard reference to the + # WaitIterator itself, only the Future it returns. Since + # WaitIterator uses weak references internally to improve GC + # performance, this used to cause problems. + yield gen.with_timeout( + datetime.timedelta(seconds=0.1), gen.WaitIterator(gen.sleep(0)).next() + ) + + +class RunnerGCTest(AsyncTestCase): + def is_pypy3(self): + return platform.python_implementation() == "PyPy" and sys.version_info > (3,) + + @gen_test + def test_gc(self): + # GitHub issue 1769: Runner objects can get GCed unexpectedly + # while their future is alive. + weakref_scope = [None] # type: List[Optional[weakref.ReferenceType]] + + def callback(): + gc.collect(2) + weakref_scope[0]().set_result(123) # type: ignore + + @gen.coroutine + def tester(): + fut = Future() # type: Future[int] + weakref_scope[0] = weakref.ref(fut) + self.io_loop.add_callback(callback) + yield fut + + yield gen.with_timeout(datetime.timedelta(seconds=0.2), tester()) + + def test_gc_infinite_coro(self): + # GitHub issue 2229: suspended coroutines should be GCed when + # their loop is closed, even if they're involved in a reference + # cycle. 
+ loop = self.get_new_ioloop() + result = [] # type: List[Optional[bool]] + wfut = [] + + @gen.coroutine + def infinite_coro(): + try: + while True: + yield gen.sleep(1e-3) + result.append(True) + finally: + # coroutine finalizer + result.append(None) + + @gen.coroutine + def do_something(): + fut = infinite_coro() + fut._refcycle = fut # type: ignore + wfut.append(weakref.ref(fut)) + yield gen.sleep(0.2) + + loop.run_sync(do_something) + loop.close() + gc.collect() + # Future was collected + self.assertIs(wfut[0](), None) + # At least one wakeup + self.assertGreaterEqual(len(result), 2) + if not self.is_pypy3(): + # coroutine finalizer was called (not on PyPy3 apparently) + self.assertIs(result[-1], None) + + def test_gc_infinite_async_await(self): + # Same as test_gc_infinite_coro, but with a `async def` function + import asyncio + + async def infinite_coro(result): + try: + while True: + await gen.sleep(1e-3) + result.append(True) + finally: + # coroutine finalizer + result.append(None) + + loop = self.get_new_ioloop() + result = [] # type: List[Optional[bool]] + wfut = [] + + @gen.coroutine + def do_something(): + fut = asyncio.get_event_loop().create_task(infinite_coro(result)) + fut._refcycle = fut # type: ignore + wfut.append(weakref.ref(fut)) + yield gen.sleep(0.2) + + loop.run_sync(do_something) + with ExpectLog("asyncio", "Task was destroyed but it is pending"): + loop.close() + gc.collect() + # Future was collected + self.assertIs(wfut[0](), None) + # At least one wakeup and one finally + self.assertGreaterEqual(len(result), 2) + if not self.is_pypy3(): + # coroutine finalizer was called (not on PyPy3 apparently) + self.assertIs(result[-1], None) + + def test_multi_moment(self): + # Test gen.multi with moment + # now that it's not a real Future + @gen.coroutine + def wait_a_moment(): + result = yield gen.multi([gen.moment, gen.moment]) + raise gen.Return(result) + + loop = self.get_new_ioloop() + result = loop.run_sync(wait_a_moment) + self.assertEqual(result, [None, None]) + + +if contextvars is not None: + ctx_var = contextvars.ContextVar("ctx_var") # type: contextvars.ContextVar[int] + + +@unittest.skipIf(contextvars is None, "contextvars module not present") +class ContextVarsTest(AsyncTestCase): + async def native_root(self, x): + ctx_var.set(x) + await self.inner(x) + + @gen.coroutine + def gen_root(self, x): + ctx_var.set(x) + yield + yield self.inner(x) + + async def inner(self, x): + self.assertEqual(ctx_var.get(), x) + await self.gen_inner(x) + self.assertEqual(ctx_var.get(), x) + + # IOLoop.run_in_executor doesn't automatically copy context + ctx = contextvars.copy_context() + await self.io_loop.run_in_executor(None, lambda: ctx.run(self.thread_inner, x)) + self.assertEqual(ctx_var.get(), x) + + # Neither does asyncio's run_in_executor. + await asyncio.get_event_loop().run_in_executor( + None, lambda: ctx.run(self.thread_inner, x) + ) + self.assertEqual(ctx_var.get(), x) + + @gen.coroutine + def gen_inner(self, x): + self.assertEqual(ctx_var.get(), x) + yield + self.assertEqual(ctx_var.get(), x) + + def thread_inner(self, x): + self.assertEqual(ctx_var.get(), x) + + @gen_test + def test_propagate(self): + # Verify that context vars get propagated across various + # combinations of native and decorated coroutines. 
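+ # Each root coroutine runs with its own copy of the context, so the
+ # four concurrent roots below never observe each other's ctx_var
+ # values. Bare-contextvars sketch of the same isolation (illustrative):
+ #
+ #     v = contextvars.ContextVar("v")
+ #     ctx = contextvars.copy_context()
+ #     ctx.run(v.set, 1)   # visible only inside ctx
+ #     v.get(None)         # still None out here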
+ yield [ + self.native_root(1), + self.native_root(2), + self.gen_root(3), + self.gen_root(4), + ] + + @gen_test + def test_reset(self): + token = ctx_var.set(1) + yield + # reset asserts that we are still at the same level of the context tree, + # so we must make sure that we maintain that property across yield. + ctx_var.reset(token) + + @gen_test + def test_propagate_to_first_yield_with_native_async_function(self): + x = 10 + + async def native_async_function(): + self.assertEqual(ctx_var.get(), x) + + ctx_var.set(x) + yield native_async_function() + + +if __name__ == "__main__": + unittest.main() diff --git a/venv/lib/python3.9/site-packages/tornado/test/gettext_translations/fr_FR/LC_MESSAGES/tornado_test.mo b/venv/lib/python3.9/site-packages/tornado/test/gettext_translations/fr_FR/LC_MESSAGES/tornado_test.mo Binary files differnew file mode 100644 index 00000000..a97bf9c5 --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/test/gettext_translations/fr_FR/LC_MESSAGES/tornado_test.mo diff --git a/venv/lib/python3.9/site-packages/tornado/test/gettext_translations/fr_FR/LC_MESSAGES/tornado_test.po b/venv/lib/python3.9/site-packages/tornado/test/gettext_translations/fr_FR/LC_MESSAGES/tornado_test.po new file mode 100644 index 00000000..88d72c86 --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/test/gettext_translations/fr_FR/LC_MESSAGES/tornado_test.po @@ -0,0 +1,47 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) YEAR THE PACKAGE'S COPYRIGHT HOLDER +# This file is distributed under the same license as the PACKAGE package. +# FIRST AUTHOR <EMAIL@ADDRESS>, YEAR. +# +#, fuzzy +msgid "" +msgstr "" +"Project-Id-Version: PACKAGE VERSION\n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2015-01-27 11:05+0300\n" +"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" +"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n" +"Language-Team: LANGUAGE <LL@li.org>\n" +"Language: \n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=utf-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Plural-Forms: nplurals=2; plural=(n > 1);\n" + +#: extract_me.py:11 +msgid "school" +msgstr "école" + +#: extract_me.py:12 +msgctxt "law" +msgid "right" +msgstr "le droit" + +#: extract_me.py:13 +msgctxt "good" +msgid "right" +msgstr "le bien" + +#: extract_me.py:14 +msgctxt "organization" +msgid "club" +msgid_plural "clubs" +msgstr[0] "le club" +msgstr[1] "les clubs" + +#: extract_me.py:15 +msgctxt "stick" +msgid "club" +msgid_plural "clubs" +msgstr[0] "le bâton" +msgstr[1] "les bâtons" diff --git a/venv/lib/python3.9/site-packages/tornado/test/http1connection_test.py b/venv/lib/python3.9/site-packages/tornado/test/http1connection_test.py new file mode 100644 index 00000000..34de6d38 --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/test/http1connection_test.py @@ -0,0 +1,61 @@ +import socket +import typing # noqa(F401) + +from tornado.http1connection import HTTP1Connection +from tornado.httputil import HTTPMessageDelegate +from tornado.iostream import IOStream +from tornado.locks import Event +from tornado.netutil import add_accept_handler +from tornado.testing import AsyncTestCase, bind_unused_port, gen_test + + +class HTTP1ConnectionTest(AsyncTestCase): + code = None # type: typing.Optional[int] + + def setUp(self): + super().setUp() + self.asyncSetUp() + + @gen_test + def asyncSetUp(self): + listener, port = bind_unused_port() + event = Event() + + def accept_callback(conn, addr): + self.server_stream = IOStream(conn) + self.addCleanup(self.server_stream.close) + event.set() + + 
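+        # (add_accept_handler registers the listening socket with the
+        #  IOLoop and invokes accept_callback once per incoming
+        #  connection; waiting on both the client connect and the Event
+        #  below guarantees self.server_stream is set before any test
+        #  body runs.)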
add_accept_handler(listener, accept_callback) + self.client_stream = IOStream(socket.socket()) + self.addCleanup(self.client_stream.close) + yield [self.client_stream.connect(("127.0.0.1", port)), event.wait()] + self.io_loop.remove_handler(listener) + listener.close() + + @gen_test + def test_http10_no_content_length(self): + # Regression test for a bug in which can_keep_alive would crash + # for an HTTP/1.0 (not 1.1) response with no content-length. + conn = HTTP1Connection(self.client_stream, True) + self.server_stream.write(b"HTTP/1.0 200 Not Modified\r\n\r\nhello") + self.server_stream.close() + + event = Event() + test = self + body = [] + + class Delegate(HTTPMessageDelegate): + def headers_received(self, start_line, headers): + test.code = start_line.code + + def data_received(self, data): + body.append(data) + + def finish(self): + event.set() + + yield conn.read_response(Delegate()) + yield event.wait() + self.assertEqual(self.code, 200) + self.assertEqual(b"".join(body), b"hello") diff --git a/venv/lib/python3.9/site-packages/tornado/test/httpclient_test.py b/venv/lib/python3.9/site-packages/tornado/test/httpclient_test.py new file mode 100644 index 00000000..a71ec0af --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/test/httpclient_test.py @@ -0,0 +1,915 @@ +import base64 +import binascii +from contextlib import closing +import copy +import gzip +import threading +import datetime +from io import BytesIO +import subprocess +import sys +import time +import typing # noqa: F401 +import unicodedata +import unittest + +from tornado.escape import utf8, native_str, to_unicode +from tornado import gen +from tornado.httpclient import ( + HTTPRequest, + HTTPResponse, + _RequestProxy, + HTTPError, + HTTPClient, +) +from tornado.httpserver import HTTPServer +from tornado.ioloop import IOLoop +from tornado.iostream import IOStream +from tornado.log import gen_log, app_log +from tornado import netutil +from tornado.testing import AsyncHTTPTestCase, bind_unused_port, gen_test, ExpectLog +from tornado.test.util import skipOnTravis +from tornado.web import Application, RequestHandler, url +from tornado.httputil import format_timestamp, HTTPHeaders + + +class HelloWorldHandler(RequestHandler): + def get(self): + name = self.get_argument("name", "world") + self.set_header("Content-Type", "text/plain") + self.finish("Hello %s!" % name) + + +class PostHandler(RequestHandler): + def post(self): + self.finish( + "Post arg1: %s, arg2: %s" + % (self.get_argument("arg1"), self.get_argument("arg2")) + ) + + +class PutHandler(RequestHandler): + def put(self): + self.write("Put body: ") + self.write(self.request.body) + + +class RedirectHandler(RequestHandler): + def prepare(self): + self.write("redirects can have bodies too") + self.redirect( + self.get_argument("url"), status=int(self.get_argument("status", "302")) + ) + + +class RedirectWithoutLocationHandler(RequestHandler): + def prepare(self): + # For testing error handling of a redirect with no location header. + self.set_status(301) + self.finish() + + +class ChunkHandler(RequestHandler): + @gen.coroutine + def get(self): + self.write("asdf") + self.flush() + # Wait a bit to ensure the chunks are sent and received separately. 
+ yield gen.sleep(0.01) + self.write("qwer") + + +class AuthHandler(RequestHandler): + def get(self): + self.finish(self.request.headers["Authorization"]) + + +class CountdownHandler(RequestHandler): + def get(self, count): + count = int(count) + if count > 0: + self.redirect(self.reverse_url("countdown", count - 1)) + else: + self.write("Zero") + + +class EchoPostHandler(RequestHandler): + def post(self): + self.write(self.request.body) + + +class UserAgentHandler(RequestHandler): + def get(self): + self.write(self.request.headers.get("User-Agent", "User agent not set")) + + +class ContentLength304Handler(RequestHandler): + def get(self): + self.set_status(304) + self.set_header("Content-Length", 42) + + def _clear_representation_headers(self): + # Tornado strips content-length from 304 responses, but here we + # want to simulate servers that include the headers anyway. + pass + + +class PatchHandler(RequestHandler): + def patch(self): + "Return the request payload - so we can check it is being kept" + self.write(self.request.body) + + +class AllMethodsHandler(RequestHandler): + SUPPORTED_METHODS = RequestHandler.SUPPORTED_METHODS + ("OTHER",) # type: ignore + + def method(self): + assert self.request.method is not None + self.write(self.request.method) + + get = head = post = put = delete = options = patch = other = method # type: ignore + + +class SetHeaderHandler(RequestHandler): + def get(self): + # Use get_arguments for keys to get strings, but + # request.arguments for values to get bytes. + for k, v in zip(self.get_arguments("k"), self.request.arguments["v"]): + self.set_header(k, v) + + +class InvalidGzipHandler(RequestHandler): + def get(self) -> None: + # set Content-Encoding manually to avoid automatic gzip encoding + self.set_header("Content-Type", "text/plain") + self.set_header("Content-Encoding", "gzip") + # Triggering the potential bug seems to depend on input length. + # This length is taken from the bad-response example reported in + # https://github.com/tornadoweb/tornado/pull/2875 (uncompressed). + text = "".join("Hello World {}\n".format(i) for i in range(9000))[:149051] + body = gzip.compress(text.encode(), compresslevel=6) + b"\00" + self.write(body) + + +class HeaderEncodingHandler(RequestHandler): + def get(self): + self.finish(self.request.headers["Foo"].encode("ISO8859-1")) + + +# These tests end up getting run redundantly: once here with the default +# HTTPClient implementation, and then again in each implementation's own +# test suite. 
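+# (How that reuse works, roughly: each implementation's test module
+# subclasses HTTPClientCommonTestCase and overrides get_http_client(),
+# along the lines of the following sketch; the class name here is
+# illustrative:
+#
+#     class SimpleHTTPClientCommonTestCase(HTTPClientCommonTestCase):
+#         def get_http_client(self):
+#             return SimpleAsyncHTTPClient(force_instance=True)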
+ + +class HTTPClientCommonTestCase(AsyncHTTPTestCase): + def get_app(self): + return Application( + [ + url("/hello", HelloWorldHandler), + url("/post", PostHandler), + url("/put", PutHandler), + url("/redirect", RedirectHandler), + url("/redirect_without_location", RedirectWithoutLocationHandler), + url("/chunk", ChunkHandler), + url("/auth", AuthHandler), + url("/countdown/([0-9]+)", CountdownHandler, name="countdown"), + url("/echopost", EchoPostHandler), + url("/user_agent", UserAgentHandler), + url("/304_with_content_length", ContentLength304Handler), + url("/all_methods", AllMethodsHandler), + url("/patch", PatchHandler), + url("/set_header", SetHeaderHandler), + url("/invalid_gzip", InvalidGzipHandler), + url("/header-encoding", HeaderEncodingHandler), + ], + gzip=True, + ) + + def test_patch_receives_payload(self): + body = b"some patch data" + response = self.fetch("/patch", method="PATCH", body=body) + self.assertEqual(response.code, 200) + self.assertEqual(response.body, body) + + @skipOnTravis + def test_hello_world(self): + response = self.fetch("/hello") + self.assertEqual(response.code, 200) + self.assertEqual(response.headers["Content-Type"], "text/plain") + self.assertEqual(response.body, b"Hello world!") + assert response.request_time is not None + self.assertEqual(int(response.request_time), 0) + + response = self.fetch("/hello?name=Ben") + self.assertEqual(response.body, b"Hello Ben!") + + def test_streaming_callback(self): + # streaming_callback is also tested in test_chunked + chunks = [] # type: typing.List[bytes] + response = self.fetch("/hello", streaming_callback=chunks.append) + # with streaming_callback, data goes to the callback and not response.body + self.assertEqual(chunks, [b"Hello world!"]) + self.assertFalse(response.body) + + def test_post(self): + response = self.fetch("/post", method="POST", body="arg1=foo&arg2=bar") + self.assertEqual(response.code, 200) + self.assertEqual(response.body, b"Post arg1: foo, arg2: bar") + + def test_chunked(self): + response = self.fetch("/chunk") + self.assertEqual(response.body, b"asdfqwer") + + chunks = [] # type: typing.List[bytes] + response = self.fetch("/chunk", streaming_callback=chunks.append) + self.assertEqual(chunks, [b"asdf", b"qwer"]) + self.assertFalse(response.body) + + def test_chunked_close(self): + # test case in which chunks spread read-callback processing + # over several ioloop iterations, but the connection is already closed. + sock, port = bind_unused_port() + with closing(sock): + + @gen.coroutine + def accept_callback(conn, address): + # fake an HTTP server using chunked encoding where the final chunks + # and connection close all happen at once + stream = IOStream(conn) + request_data = yield stream.read_until(b"\r\n\r\n") + if b"HTTP/1." not in request_data: + self.skipTest("requires HTTP/1.x") + yield stream.write( + b"""\ +HTTP/1.1 200 OK +Transfer-Encoding: chunked + +1 +1 +1 +2 +0 + +""".replace( + b"\n", b"\r\n" + ) + ) + stream.close() + + netutil.add_accept_handler(sock, accept_callback) # type: ignore + resp = self.fetch("http://127.0.0.1:%d/" % port) + resp.rethrow() + self.assertEqual(resp.body, b"12") + self.io_loop.remove_handler(sock.fileno()) + + def test_basic_auth(self): + # This test data appears in section 2 of RFC 7617. 
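+        # (The expected token is just base64 of "user:password"; e.g.
+        #      base64.b64encode(b"Aladdin:open sesame")
+        #  yields b"QWxhZGRpbjpvcGVuIHNlc2FtZQ==", matching RFC 7617.)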
+ self.assertEqual( + self.fetch( + "/auth", auth_username="Aladdin", auth_password="open sesame" + ).body, + b"Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==", + ) + + def test_basic_auth_explicit_mode(self): + self.assertEqual( + self.fetch( + "/auth", + auth_username="Aladdin", + auth_password="open sesame", + auth_mode="basic", + ).body, + b"Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==", + ) + + def test_basic_auth_unicode(self): + # This test data appears in section 2.1 of RFC 7617. + self.assertEqual( + self.fetch("/auth", auth_username="test", auth_password="123£").body, + b"Basic dGVzdDoxMjPCow==", + ) + + # The standard mandates NFC. Give it a decomposed username + # and ensure it is normalized to composed form. + username = unicodedata.normalize("NFD", "josé") + self.assertEqual( + self.fetch("/auth", auth_username=username, auth_password="səcrət").body, + b"Basic am9zw6k6c8mZY3LJmXQ=", + ) + + def test_unsupported_auth_mode(self): + # curl and simple clients handle errors a bit differently; the + # important thing is that they don't fall back to basic auth + # on an unknown mode. + with ExpectLog(gen_log, "uncaught exception", required=False): + with self.assertRaises((ValueError, HTTPError)): # type: ignore + self.fetch( + "/auth", + auth_username="Aladdin", + auth_password="open sesame", + auth_mode="asdf", + raise_error=True, + ) + + def test_follow_redirect(self): + response = self.fetch("/countdown/2", follow_redirects=False) + self.assertEqual(302, response.code) + self.assertTrue(response.headers["Location"].endswith("/countdown/1")) + + response = self.fetch("/countdown/2") + self.assertEqual(200, response.code) + self.assertTrue(response.effective_url.endswith("/countdown/0")) + self.assertEqual(b"Zero", response.body) + + def test_redirect_without_location(self): + response = self.fetch("/redirect_without_location", follow_redirects=True) + # If there is no location header, the redirect response should + # just be returned as-is. (This should arguably raise an + # error, but libcurl doesn't treat this as an error, so we + # don't either). + self.assertEqual(301, response.code) + + def test_redirect_put_with_body(self): + response = self.fetch( + "/redirect?url=/put&status=307", method="PUT", body="hello" + ) + self.assertEqual(response.body, b"Put body: hello") + + def test_redirect_put_without_body(self): + # This "without body" edge case is similar to what happens with body_producer. + response = self.fetch( + "/redirect?url=/put&status=307", + method="PUT", + allow_nonstandard_methods=True, + ) + self.assertEqual(response.body, b"Put body: ") + + def test_method_after_redirect(self): + # Legacy redirect codes (301, 302) convert POST requests to GET. + for status in [301, 302, 303]: + url = "/redirect?url=/all_methods&status=%d" % status + resp = self.fetch(url, method="POST", body=b"") + self.assertEqual(b"GET", resp.body) + + # Other methods are left alone, except for 303 redirect, depending on client + for method in ["GET", "OPTIONS", "PUT", "DELETE"]: + resp = self.fetch(url, method=method, allow_nonstandard_methods=True) + if status in [301, 302]: + self.assertEqual(utf8(method), resp.body) + else: + self.assertIn(resp.body, [utf8(method), b"GET"]) + + # HEAD is different so check it separately. + resp = self.fetch(url, method="HEAD") + self.assertEqual(200, resp.code) + self.assertEqual(b"", resp.body) + + # Newer redirects always preserve the original method. 
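+        # (Summary of the method-rewriting rules exercised here: 301/302
+        #  historically convert POST to GET, 303 "See Other" always
+        #  redirects with GET, while 307/308 never rewrite the method.)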
+        for status in [307, 308]:
+            url = "/redirect?url=/all_methods&status=%d" % status
+            for method in ["GET", "OPTIONS", "POST", "PUT", "DELETE"]:
+                resp = self.fetch(url, method=method, allow_nonstandard_methods=True)
+                self.assertEqual(method, to_unicode(resp.body))
+            resp = self.fetch(url, method="HEAD")
+            self.assertEqual(200, resp.code)
+            self.assertEqual(b"", resp.body)
+
+    def test_credentials_in_url(self):
+        url = self.get_url("/auth").replace("http://", "http://me:secret@")
+        response = self.fetch(url)
+        self.assertEqual(b"Basic " + base64.b64encode(b"me:secret"), response.body)
+
+    def test_body_encoding(self):
+        unicode_body = "\xe9"
+        byte_body = binascii.a2b_hex(b"e9")
+
+        # unicode string in body gets converted to utf8
+        response = self.fetch(
+            "/echopost",
+            method="POST",
+            body=unicode_body,
+            headers={"Content-Type": "application/blah"},
+        )
+        self.assertEqual(response.headers["Content-Length"], "2")
+        self.assertEqual(response.body, utf8(unicode_body))
+
+        # byte strings pass through directly
+        response = self.fetch(
+            "/echopost",
+            method="POST",
+            body=byte_body,
+            headers={"Content-Type": "application/blah"},
+        )
+        self.assertEqual(response.headers["Content-Length"], "1")
+        self.assertEqual(response.body, byte_body)
+
+        # Mixing unicode in headers and byte string bodies shouldn't
+        # break anything
+        response = self.fetch(
+            "/echopost",
+            method="POST",
+            body=byte_body,
+            headers={"Content-Type": "application/blah"},
+            user_agent="foo",
+        )
+        self.assertEqual(response.headers["Content-Length"], "1")
+        self.assertEqual(response.body, byte_body)
+
+    def test_types(self):
+        response = self.fetch("/hello")
+        self.assertEqual(type(response.body), bytes)
+        self.assertEqual(type(response.headers["Content-Type"]), str)
+        self.assertEqual(type(response.code), int)
+        self.assertEqual(type(response.effective_url), str)
+
+    def test_gzip(self):
+        # All the tests in this file should be using gzip, but this test
+        # ensures that it is in fact getting compressed, and also tests
+        # the httpclient's decompress=False option.
+        # Setting Accept-Encoding manually bypasses the client's
+        # decompression so we can see the raw data.
+        response = self.fetch(
+            "/chunk", decompress_response=False, headers={"Accept-Encoding": "gzip"}
+        )
+        self.assertEqual(response.headers["Content-Encoding"], "gzip")
+        self.assertNotEqual(response.body, b"asdfqwer")
+        # Our test data gets bigger when gzipped. Oops. :)
+        # Chunked encoding bypasses the MIN_LENGTH check.
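+        # (gzip framing alone costs 18 bytes: a 10-byte header plus an
+        #  8-byte CRC32/length trailer.  An 8-byte payload like
+        #  b"asdfqwer" therefore always grows when compressed; the 34
+        #  bytes checked below are consistent with that overhead.)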
+ self.assertEqual(len(response.body), 34) + f = gzip.GzipFile(mode="r", fileobj=response.buffer) + self.assertEqual(f.read(), b"asdfqwer") + + def test_invalid_gzip(self): + # test if client hangs on tricky invalid gzip + # curl/simple httpclient have different behavior (exception, logging) + with ExpectLog( + app_log, "(Uncaught exception|Exception in callback)", required=False + ): + try: + response = self.fetch("/invalid_gzip") + self.assertEqual(response.code, 200) + self.assertEqual(response.body[:14], b"Hello World 0\n") + except HTTPError: + pass # acceptable + + def test_header_callback(self): + first_line = [] + headers = {} + chunks = [] + + def header_callback(header_line): + if header_line.startswith("HTTP/1.1 101"): + # Upgrading to HTTP/2 + pass + elif header_line.startswith("HTTP/"): + first_line.append(header_line) + elif header_line != "\r\n": + k, v = header_line.split(":", 1) + headers[k.lower()] = v.strip() + + def streaming_callback(chunk): + # All header callbacks are run before any streaming callbacks, + # so the header data is available to process the data as it + # comes in. + self.assertEqual(headers["content-type"], "text/html; charset=UTF-8") + chunks.append(chunk) + + self.fetch( + "/chunk", + header_callback=header_callback, + streaming_callback=streaming_callback, + ) + self.assertEqual(len(first_line), 1, first_line) + self.assertRegex(first_line[0], "HTTP/[0-9]\\.[0-9] 200.*\r\n") + self.assertEqual(chunks, [b"asdf", b"qwer"]) + + @gen_test + def test_configure_defaults(self): + defaults = dict(user_agent="TestDefaultUserAgent", allow_ipv6=False) + # Construct a new instance of the configured client class + client = self.http_client.__class__(force_instance=True, defaults=defaults) + try: + response = yield client.fetch(self.get_url("/user_agent")) + self.assertEqual(response.body, b"TestDefaultUserAgent") + finally: + client.close() + + def test_header_types(self): + # Header values may be passed as character or utf8 byte strings, + # in a plain dictionary or an HTTPHeaders object. + # Keys must always be the native str type. + # All combinations should have the same results on the wire. + for value in ["MyUserAgent", b"MyUserAgent"]: + for container in [dict, HTTPHeaders]: + headers = container() + headers["User-Agent"] = value + resp = self.fetch("/user_agent", headers=headers) + self.assertEqual( + resp.body, + b"MyUserAgent", + "response=%r, value=%r, container=%r" + % (resp.body, value, container), + ) + + def test_multi_line_headers(self): + # Multi-line http headers are rare but rfc-allowed + # http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2 + sock, port = bind_unused_port() + with closing(sock): + + @gen.coroutine + def accept_callback(conn, address): + stream = IOStream(conn) + request_data = yield stream.read_until(b"\r\n\r\n") + if b"HTTP/1." 
not in request_data: + self.skipTest("requires HTTP/1.x") + yield stream.write( + b"""\ +HTTP/1.1 200 OK +X-XSS-Protection: 1; +\tmode=block + +""".replace( + b"\n", b"\r\n" + ) + ) + stream.close() + + netutil.add_accept_handler(sock, accept_callback) # type: ignore + try: + resp = self.fetch("http://127.0.0.1:%d/" % port) + resp.rethrow() + self.assertEqual(resp.headers["X-XSS-Protection"], "1; mode=block") + finally: + self.io_loop.remove_handler(sock.fileno()) + + @gen_test + def test_header_encoding(self): + response = yield self.http_client.fetch( + self.get_url("/header-encoding"), + headers={ + "Foo": "b\xe4r", + }, + ) + self.assertEqual(response.body, "b\xe4r".encode("ISO8859-1")) + + def test_304_with_content_length(self): + # According to the spec 304 responses SHOULD NOT include + # Content-Length or other entity headers, but some servers do it + # anyway. + # http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3.5 + response = self.fetch("/304_with_content_length") + self.assertEqual(response.code, 304) + self.assertEqual(response.headers["Content-Length"], "42") + + @gen_test + def test_future_interface(self): + response = yield self.http_client.fetch(self.get_url("/hello")) + self.assertEqual(response.body, b"Hello world!") + + @gen_test + def test_future_http_error(self): + with self.assertRaises(HTTPError) as context: + yield self.http_client.fetch(self.get_url("/notfound")) + assert context.exception is not None + assert context.exception.response is not None + self.assertEqual(context.exception.code, 404) + self.assertEqual(context.exception.response.code, 404) + + @gen_test + def test_future_http_error_no_raise(self): + response = yield self.http_client.fetch( + self.get_url("/notfound"), raise_error=False + ) + self.assertEqual(response.code, 404) + + @gen_test + def test_reuse_request_from_response(self): + # The response.request attribute should be an HTTPRequest, not + # a _RequestProxy. + # This test uses self.http_client.fetch because self.fetch calls + # self.get_url on the input unconditionally. + url = self.get_url("/hello") + response = yield self.http_client.fetch(url) + self.assertEqual(response.request.url, url) + self.assertTrue(isinstance(response.request, HTTPRequest)) + response2 = yield self.http_client.fetch(response.request) + self.assertEqual(response2.body, b"Hello world!") + + @gen_test + def test_bind_source_ip(self): + url = self.get_url("/hello") + request = HTTPRequest(url, network_interface="127.0.0.1") + response = yield self.http_client.fetch(request) + self.assertEqual(response.code, 200) + + with self.assertRaises((ValueError, HTTPError)) as context: # type: ignore + request = HTTPRequest(url, network_interface="not-interface-or-ip") + yield self.http_client.fetch(request) + self.assertIn("not-interface-or-ip", str(context.exception)) + + def test_all_methods(self): + for method in ["GET", "DELETE", "OPTIONS"]: + response = self.fetch("/all_methods", method=method) + self.assertEqual(response.body, utf8(method)) + for method in ["POST", "PUT", "PATCH"]: + response = self.fetch("/all_methods", method=method, body=b"") + self.assertEqual(response.body, utf8(method)) + response = self.fetch("/all_methods", method="HEAD") + self.assertEqual(response.body, b"") + response = self.fetch( + "/all_methods", method="OTHER", allow_nonstandard_methods=True + ) + self.assertEqual(response.body, b"OTHER") + + def test_body_sanity_checks(self): + # These methods require a body. 
+ for method in ("POST", "PUT", "PATCH"): + with self.assertRaises(ValueError) as context: + self.fetch("/all_methods", method=method, raise_error=True) + self.assertIn("must not be None", str(context.exception)) + + resp = self.fetch( + "/all_methods", method=method, allow_nonstandard_methods=True + ) + self.assertEqual(resp.code, 200) + + # These methods don't allow a body. + for method in ("GET", "DELETE", "OPTIONS"): + with self.assertRaises(ValueError) as context: + self.fetch( + "/all_methods", method=method, body=b"asdf", raise_error=True + ) + self.assertIn("must be None", str(context.exception)) + + # In most cases this can be overridden, but curl_httpclient + # does not allow body with a GET at all. + if method != "GET": + self.fetch( + "/all_methods", + method=method, + body=b"asdf", + allow_nonstandard_methods=True, + raise_error=True, + ) + self.assertEqual(resp.code, 200) + + # This test causes odd failures with the combination of + # curl_httpclient (at least with the version of libcurl available + # on ubuntu 12.04), TwistedIOLoop, and epoll. For POST (but not PUT), + # curl decides the response came back too soon and closes the connection + # to start again. It does this *before* telling the socket callback to + # unregister the FD. Some IOLoop implementations have special kernel + # integration to discover this immediately. Tornado's IOLoops + # ignore errors on remove_handler to accommodate this behavior, but + # Twisted's reactor does not. The removeReader call fails and so + # do all future removeAll calls (which our tests do at cleanup). + # + # def test_post_307(self): + # response = self.fetch("/redirect?status=307&url=/post", + # method="POST", body=b"arg1=foo&arg2=bar") + # self.assertEqual(response.body, b"Post arg1: foo, arg2: bar") + + def test_put_307(self): + response = self.fetch( + "/redirect?status=307&url=/put", method="PUT", body=b"hello" + ) + response.rethrow() + self.assertEqual(response.body, b"Put body: hello") + + def test_non_ascii_header(self): + # Non-ascii headers are sent as latin1. + response = self.fetch("/set_header?k=foo&v=%E9") + response.rethrow() + self.assertEqual(response.headers["Foo"], native_str("\u00e9")) + + def test_response_times(self): + # A few simple sanity checks of the response time fields to + # make sure they're using the right basis (between the + # wall-time and monotonic clocks). + start_time = time.time() + response = self.fetch("/hello") + response.rethrow() + assert response.request_time is not None + self.assertGreaterEqual(response.request_time, 0) + self.assertLess(response.request_time, 1.0) + # A very crude check to make sure that start_time is based on + # wall time and not the monotonic clock. 
+ assert response.start_time is not None + self.assertLess(abs(response.start_time - start_time), 1.0) + + for k, v in response.time_info.items(): + self.assertTrue(0 <= v < 1.0, "time_info[%s] out of bounds: %s" % (k, v)) + + def test_zero_timeout(self): + response = self.fetch("/hello", connect_timeout=0) + self.assertEqual(response.code, 200) + + response = self.fetch("/hello", request_timeout=0) + self.assertEqual(response.code, 200) + + response = self.fetch("/hello", connect_timeout=0, request_timeout=0) + self.assertEqual(response.code, 200) + + @gen_test + def test_error_after_cancel(self): + fut = self.http_client.fetch(self.get_url("/404")) + self.assertTrue(fut.cancel()) + with ExpectLog(app_log, "Exception after Future was cancelled") as el: + # We can't wait on the cancelled Future any more, so just + # let the IOLoop run until the exception gets logged (or + # not, in which case we exit the loop and ExpectLog will + # raise). + for i in range(100): + yield gen.sleep(0.01) + if el.logged_stack: + break + + +class RequestProxyTest(unittest.TestCase): + def test_request_set(self): + proxy = _RequestProxy( + HTTPRequest("http://example.com/", user_agent="foo"), dict() + ) + self.assertEqual(proxy.user_agent, "foo") + + def test_default_set(self): + proxy = _RequestProxy( + HTTPRequest("http://example.com/"), dict(network_interface="foo") + ) + self.assertEqual(proxy.network_interface, "foo") + + def test_both_set(self): + proxy = _RequestProxy( + HTTPRequest("http://example.com/", proxy_host="foo"), dict(proxy_host="bar") + ) + self.assertEqual(proxy.proxy_host, "foo") + + def test_neither_set(self): + proxy = _RequestProxy(HTTPRequest("http://example.com/"), dict()) + self.assertIs(proxy.auth_username, None) + + def test_bad_attribute(self): + proxy = _RequestProxy(HTTPRequest("http://example.com/"), dict()) + with self.assertRaises(AttributeError): + proxy.foo + + def test_defaults_none(self): + proxy = _RequestProxy(HTTPRequest("http://example.com/"), None) + self.assertIs(proxy.auth_username, None) + + +class HTTPResponseTestCase(unittest.TestCase): + def test_str(self): + response = HTTPResponse( # type: ignore + HTTPRequest("http://example.com"), 200, buffer=BytesIO() + ) + s = str(response) + self.assertTrue(s.startswith("HTTPResponse(")) + self.assertIn("code=200", s) + + +class SyncHTTPClientTest(unittest.TestCase): + def setUp(self): + self.server_ioloop = IOLoop(make_current=False) + event = threading.Event() + + @gen.coroutine + def init_server(): + sock, self.port = bind_unused_port() + app = Application([("/", HelloWorldHandler)]) + self.server = HTTPServer(app) + self.server.add_socket(sock) + event.set() + + def start(): + self.server_ioloop.run_sync(init_server) + self.server_ioloop.start() + + self.server_thread = threading.Thread(target=start) + self.server_thread.start() + event.wait() + + self.http_client = HTTPClient() + + def tearDown(self): + def stop_server(): + self.server.stop() + # Delay the shutdown of the IOLoop by several iterations because + # the server may still have some cleanup work left when + # the client finishes with the response (this is noticeable + # with http/2, which leaves a Future with an unexamined + # StreamClosedError on the loop). + + @gen.coroutine + def slow_stop(): + yield self.server.close_all_connections() + # The number of iterations is difficult to predict. Typically, + # one is sufficient, although sometimes it needs more. 
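+                # (In a @gen.coroutine, a bare ``yield`` is equivalent to
+                #  ``yield gen.moment``: it suspends for one IOLoop
+                #  iteration.  This loop simply spins the loop a few
+                #  times before stopping it.)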
+ for i in range(5): + yield + self.server_ioloop.stop() + + self.server_ioloop.add_callback(slow_stop) + + self.server_ioloop.add_callback(stop_server) + self.server_thread.join() + self.http_client.close() + self.server_ioloop.close(all_fds=True) + + def get_url(self, path): + return "http://127.0.0.1:%d%s" % (self.port, path) + + def test_sync_client(self): + response = self.http_client.fetch(self.get_url("/")) + self.assertEqual(b"Hello world!", response.body) + + def test_sync_client_error(self): + # Synchronous HTTPClient raises errors directly; no need for + # response.rethrow() + with self.assertRaises(HTTPError) as assertion: + self.http_client.fetch(self.get_url("/notfound")) + self.assertEqual(assertion.exception.code, 404) + + +class SyncHTTPClientSubprocessTest(unittest.TestCase): + def test_destructor_log(self): + # Regression test for + # https://github.com/tornadoweb/tornado/issues/2539 + # + # In the past, the following program would log an + # "inconsistent AsyncHTTPClient cache" error from a destructor + # when the process is shutting down. The shutdown process is + # subtle and I don't fully understand it; the failure does not + # manifest if that lambda isn't there or is a simpler object + # like an int (nor does it manifest in the tornado test suite + # as a whole, which is why we use this subprocess). + proc = subprocess.run( + [ + sys.executable, + "-c", + "from tornado.httpclient import HTTPClient; f = lambda: None; c = HTTPClient()", + ], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + check=True, + timeout=5, + ) + if proc.stdout: + print("STDOUT:") + print(to_unicode(proc.stdout)) + if proc.stdout: + self.fail("subprocess produced unexpected output") + + +class HTTPRequestTestCase(unittest.TestCase): + def test_headers(self): + request = HTTPRequest("http://example.com", headers={"foo": "bar"}) + self.assertEqual(request.headers, {"foo": "bar"}) + + def test_headers_setter(self): + request = HTTPRequest("http://example.com") + request.headers = {"bar": "baz"} # type: ignore + self.assertEqual(request.headers, {"bar": "baz"}) + + def test_null_headers_setter(self): + request = HTTPRequest("http://example.com") + request.headers = None # type: ignore + self.assertEqual(request.headers, {}) + + def test_body(self): + request = HTTPRequest("http://example.com", body="foo") + self.assertEqual(request.body, utf8("foo")) + + def test_body_setter(self): + request = HTTPRequest("http://example.com") + request.body = "foo" # type: ignore + self.assertEqual(request.body, utf8("foo")) + + def test_if_modified_since(self): + http_date = datetime.datetime.utcnow() + request = HTTPRequest("http://example.com", if_modified_since=http_date) + self.assertEqual( + request.headers, {"If-Modified-Since": format_timestamp(http_date)} + ) + + +class HTTPErrorTestCase(unittest.TestCase): + def test_copy(self): + e = HTTPError(403) + e2 = copy.copy(e) + self.assertIsNot(e, e2) + self.assertEqual(e.code, e2.code) + + def test_plain_error(self): + e = HTTPError(403) + self.assertEqual(str(e), "HTTP 403: Forbidden") + self.assertEqual(repr(e), "HTTP 403: Forbidden") + + def test_error_with_response(self): + resp = HTTPResponse(HTTPRequest("http://example.com/"), 403) + with self.assertRaises(HTTPError) as cm: + resp.rethrow() + e = cm.exception + self.assertEqual(str(e), "HTTP 403: Forbidden") + self.assertEqual(repr(e), "HTTP 403: Forbidden") diff --git a/venv/lib/python3.9/site-packages/tornado/test/httpserver_test.py 
b/venv/lib/python3.9/site-packages/tornado/test/httpserver_test.py new file mode 100644 index 00000000..cd0a0e10 --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/test/httpserver_test.py @@ -0,0 +1,1356 @@ +from tornado import gen, netutil +from tornado.escape import ( + json_decode, + json_encode, + utf8, + _unicode, + recursive_unicode, + native_str, +) +from tornado.http1connection import HTTP1Connection +from tornado.httpclient import HTTPError +from tornado.httpserver import HTTPServer +from tornado.httputil import ( + HTTPHeaders, + HTTPMessageDelegate, + HTTPServerConnectionDelegate, + ResponseStartLine, +) +from tornado.iostream import IOStream +from tornado.locks import Event +from tornado.log import gen_log +from tornado.netutil import ssl_options_to_context +from tornado.simple_httpclient import SimpleAsyncHTTPClient +from tornado.testing import ( + AsyncHTTPTestCase, + AsyncHTTPSTestCase, + AsyncTestCase, + ExpectLog, + gen_test, +) +from tornado.test.util import skipOnTravis +from tornado.web import Application, RequestHandler, stream_request_body + +from contextlib import closing +import datetime +import gzip +import logging +import os +import shutil +import socket +import ssl +import sys +import tempfile +import unittest +import urllib.parse +from io import BytesIO + +import typing + +if typing.TYPE_CHECKING: + from typing import Dict, List # noqa: F401 + + +async def read_stream_body(stream): + """Reads an HTTP response from `stream` and returns a tuple of its + start_line, headers and body.""" + chunks = [] + + class Delegate(HTTPMessageDelegate): + def headers_received(self, start_line, headers): + self.headers = headers + self.start_line = start_line + + def data_received(self, chunk): + chunks.append(chunk) + + def finish(self): + conn.detach() # type: ignore + + conn = HTTP1Connection(stream, True) + delegate = Delegate() + await conn.read_response(delegate) + return delegate.start_line, delegate.headers, b"".join(chunks) + + +class HandlerBaseTestCase(AsyncHTTPTestCase): + Handler = None + + def get_app(self): + return Application([("/", self.__class__.Handler)]) + + def fetch_json(self, *args, **kwargs): + response = self.fetch(*args, **kwargs) + response.rethrow() + return json_decode(response.body) + + +class HelloWorldRequestHandler(RequestHandler): + def initialize(self, protocol="http"): + self.expected_protocol = protocol + + def get(self): + if self.request.protocol != self.expected_protocol: + raise Exception("unexpected protocol") + self.finish("Hello world") + + def post(self): + self.finish("Got %d bytes in POST" % len(self.request.body)) + + +# In pre-1.0 versions of openssl, SSLv23 clients always send SSLv2 +# ClientHello messages, which are rejected by SSLv3 and TLSv1 +# servers. 
Note that while the OPENSSL_VERSION_INFO was formally +# introduced in python3.2, it was present but undocumented in +# python 2.7 +skipIfOldSSL = unittest.skipIf( + getattr(ssl, "OPENSSL_VERSION_INFO", (0, 0)) < (1, 0), + "old version of ssl module and/or openssl", +) + + +class BaseSSLTest(AsyncHTTPSTestCase): + def get_app(self): + return Application([("/", HelloWorldRequestHandler, dict(protocol="https"))]) + + +class SSLTestMixin(object): + def get_ssl_options(self): + return dict( + ssl_version=self.get_ssl_version(), + **AsyncHTTPSTestCase.default_ssl_options() + ) + + def get_ssl_version(self): + raise NotImplementedError() + + def test_ssl(self: typing.Any): + response = self.fetch("/") + self.assertEqual(response.body, b"Hello world") + + def test_large_post(self: typing.Any): + response = self.fetch("/", method="POST", body="A" * 5000) + self.assertEqual(response.body, b"Got 5000 bytes in POST") + + def test_non_ssl_request(self: typing.Any): + # Make sure the server closes the connection when it gets a non-ssl + # connection, rather than waiting for a timeout or otherwise + # misbehaving. + with ExpectLog(gen_log, "(SSL Error|uncaught exception)"): + with ExpectLog(gen_log, "Uncaught exception", required=False): + with self.assertRaises((IOError, HTTPError)): # type: ignore + self.fetch( + self.get_url("/").replace("https:", "http:"), + request_timeout=3600, + connect_timeout=3600, + raise_error=True, + ) + + def test_error_logging(self: typing.Any): + # No stack traces are logged for SSL errors. + with ExpectLog(gen_log, "SSL Error") as expect_log: + with self.assertRaises((IOError, HTTPError)): # type: ignore + self.fetch( + self.get_url("/").replace("https:", "http:"), raise_error=True + ) + self.assertFalse(expect_log.logged_stack) + + +# Python's SSL implementation differs significantly between versions. +# For example, SSLv3 and TLSv1 throw an exception if you try to read +# from the socket before the handshake is complete, but the default +# of SSLv23 allows it. 
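+# (Background: despite its name, ssl.PROTOCOL_SSLv23 is an alias for
+# ssl.PROTOCOL_TLS and negotiates the highest version both sides
+# support; the version-pinned subclasses below exercise the stricter,
+# single-protocol behavior.)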
+
+
+class SSLv23Test(BaseSSLTest, SSLTestMixin):
+    def get_ssl_version(self):
+        return ssl.PROTOCOL_SSLv23
+
+
+@skipIfOldSSL
+class SSLv3Test(BaseSSLTest, SSLTestMixin):
+    def get_ssl_version(self):
+        return ssl.PROTOCOL_SSLv3
+
+
+@skipIfOldSSL
+class TLSv1Test(BaseSSLTest, SSLTestMixin):
+    def get_ssl_version(self):
+        return ssl.PROTOCOL_TLSv1
+
+
+class SSLContextTest(BaseSSLTest, SSLTestMixin):
+    def get_ssl_options(self):
+        context = ssl_options_to_context(
+            AsyncHTTPSTestCase.get_ssl_options(self), server_side=True
+        )
+        assert isinstance(context, ssl.SSLContext)
+        return context
+
+
+class BadSSLOptionsTest(unittest.TestCase):
+    def test_missing_arguments(self):
+        application = Application()
+        self.assertRaises(
+            KeyError,
+            HTTPServer,
+            application,
+            ssl_options={"keyfile": "/__missing__.crt"},
+        )
+
+    def test_missing_key(self):
+        """A missing SSL key should cause an immediate exception."""
+
+        application = Application()
+        module_dir = os.path.dirname(__file__)
+        existing_certificate = os.path.join(module_dir, "test.crt")
+        existing_key = os.path.join(module_dir, "test.key")
+
+        self.assertRaises(
+            (ValueError, IOError),
+            HTTPServer,
+            application,
+            ssl_options={"certfile": "/__missing__.crt"},
+        )
+        self.assertRaises(
+            (ValueError, IOError),
+            HTTPServer,
+            application,
+            ssl_options={
+                "certfile": existing_certificate,
+                "keyfile": "/__missing__.key",
+            },
+        )
+
+        # This actually works because both files exist
+        HTTPServer(
+            application,
+            ssl_options={"certfile": existing_certificate, "keyfile": existing_key},
+        )
+
+
+class MultipartTestHandler(RequestHandler):
+    def post(self):
+        self.finish(
+            {
+                "header": self.request.headers["X-Header-Encoding-Test"],
+                "argument": self.get_argument("argument"),
+                "filename": self.request.files["files"][0].filename,
+                "filebody": _unicode(self.request.files["files"][0]["body"]),
+            }
+        )
+
+
+# This test is also called from wsgi_test
+class HTTPConnectionTest(AsyncHTTPTestCase):
+    def get_handlers(self):
+        return [
+            ("/multipart", MultipartTestHandler),
+            ("/hello", HelloWorldRequestHandler),
+        ]
+
+    def get_app(self):
+        return Application(self.get_handlers())
+
+    def raw_fetch(self, headers, body, newline=b"\r\n"):
+        with closing(IOStream(socket.socket())) as stream:
+            self.io_loop.run_sync(
+                lambda: stream.connect(("127.0.0.1", self.get_http_port()))
+            )
+            stream.write(
+                newline.join(headers + [utf8("Content-Length: %d" % len(body))])
+                + newline
+                + newline
+                + body
+            )
+            start_line, headers, body = self.io_loop.run_sync(
+                lambda: read_stream_body(stream)
+            )
+            return body
+
+    def test_multipart_form(self):
+        # Encodings here are tricky: Headers are latin1, bodies can be
+        # anything (we use utf8 by default).
+        response = self.raw_fetch(
+            [
+                b"POST /multipart HTTP/1.0",
+                b"Content-Type: multipart/form-data; boundary=1234567890",
+                b"X-Header-encoding-test: \xe9",
+            ],
+            b"\r\n".join(
+                [
+                    b"Content-Disposition: form-data; name=argument",
+                    b"",
+                    "\u00e1".encode("utf-8"),
+                    b"--1234567890",
+                    'Content-Disposition: form-data; name="files"; filename="\u00f3"'.encode(
+                        "utf8"
+                    ),
+                    b"",
+                    "\u00fa".encode("utf-8"),
+                    b"--1234567890--",
+                    b"",
+                ]
+            ),
+        )
+        data = json_decode(response)
+        self.assertEqual("\u00e9", data["header"])
+        self.assertEqual("\u00e1", data["argument"])
+        self.assertEqual("\u00f3", data["filename"])
+        self.assertEqual("\u00fa", data["filebody"])
+
+    def test_newlines(self):
+        # We support both CRLF and bare LF as line separators.
+ for newline in (b"\r\n", b"\n"): + response = self.raw_fetch([b"GET /hello HTTP/1.0"], b"", newline=newline) + self.assertEqual(response, b"Hello world") + + @gen_test + def test_100_continue(self): + # Run through a 100-continue interaction by hand: + # When given Expect: 100-continue, we get a 100 response after the + # headers, and then the real response after the body. + stream = IOStream(socket.socket()) + yield stream.connect(("127.0.0.1", self.get_http_port())) + yield stream.write( + b"\r\n".join( + [ + b"POST /hello HTTP/1.1", + b"Content-Length: 1024", + b"Expect: 100-continue", + b"Connection: close", + b"\r\n", + ] + ) + ) + data = yield stream.read_until(b"\r\n\r\n") + self.assertTrue(data.startswith(b"HTTP/1.1 100 "), data) + stream.write(b"a" * 1024) + first_line = yield stream.read_until(b"\r\n") + self.assertTrue(first_line.startswith(b"HTTP/1.1 200"), first_line) + header_data = yield stream.read_until(b"\r\n\r\n") + headers = HTTPHeaders.parse(native_str(header_data.decode("latin1"))) + body = yield stream.read_bytes(int(headers["Content-Length"])) + self.assertEqual(body, b"Got 1024 bytes in POST") + stream.close() + + +class EchoHandler(RequestHandler): + def get(self): + self.write(recursive_unicode(self.request.arguments)) + + def post(self): + self.write(recursive_unicode(self.request.arguments)) + + +class TypeCheckHandler(RequestHandler): + def prepare(self): + self.errors = {} # type: Dict[str, str] + fields = [ + ("method", str), + ("uri", str), + ("version", str), + ("remote_ip", str), + ("protocol", str), + ("host", str), + ("path", str), + ("query", str), + ] + for field, expected_type in fields: + self.check_type(field, getattr(self.request, field), expected_type) + + self.check_type("header_key", list(self.request.headers.keys())[0], str) + self.check_type("header_value", list(self.request.headers.values())[0], str) + + self.check_type("cookie_key", list(self.request.cookies.keys())[0], str) + self.check_type( + "cookie_value", list(self.request.cookies.values())[0].value, str + ) + # secure cookies + + self.check_type("arg_key", list(self.request.arguments.keys())[0], str) + self.check_type("arg_value", list(self.request.arguments.values())[0][0], bytes) + + def post(self): + self.check_type("body", self.request.body, bytes) + self.write(self.errors) + + def get(self): + self.write(self.errors) + + def check_type(self, name, obj, expected_type): + actual_type = type(obj) + if expected_type != actual_type: + self.errors[name] = "expected %s, got %s" % (expected_type, actual_type) + + +class PostEchoHandler(RequestHandler): + def post(self, *path_args): + self.write(dict(echo=self.get_argument("data"))) + + +class PostEchoGBKHandler(PostEchoHandler): + def decode_argument(self, value, name=None): + try: + return value.decode("gbk") + except Exception: + raise HTTPError(400, "invalid gbk bytes: %r" % value) + + +class HTTPServerTest(AsyncHTTPTestCase): + def get_app(self): + return Application( + [ + ("/echo", EchoHandler), + ("/typecheck", TypeCheckHandler), + ("//doubleslash", EchoHandler), + ("/post_utf8", PostEchoHandler), + ("/post_gbk", PostEchoGBKHandler), + ] + ) + + def test_query_string_encoding(self): + response = self.fetch("/echo?foo=%C3%A9") + data = json_decode(response.body) + self.assertEqual(data, {"foo": ["\u00e9"]}) + + def test_empty_query_string(self): + response = self.fetch("/echo?foo=&foo=") + data = json_decode(response.body) + self.assertEqual(data, {"foo": ["", ""]}) + + def test_empty_post_parameters(self): + response = 
self.fetch("/echo", method="POST", body="foo=&bar=") + data = json_decode(response.body) + self.assertEqual(data, {"foo": [""], "bar": [""]}) + + def test_types(self): + headers = {"Cookie": "foo=bar"} + response = self.fetch("/typecheck?foo=bar", headers=headers) + data = json_decode(response.body) + self.assertEqual(data, {}) + + response = self.fetch( + "/typecheck", method="POST", body="foo=bar", headers=headers + ) + data = json_decode(response.body) + self.assertEqual(data, {}) + + def test_double_slash(self): + # urlparse.urlsplit (which tornado.httpserver used to use + # incorrectly) would parse paths beginning with "//" as + # protocol-relative urls. + response = self.fetch("//doubleslash") + self.assertEqual(200, response.code) + self.assertEqual(json_decode(response.body), {}) + + def test_post_encodings(self): + headers = {"Content-Type": "application/x-www-form-urlencoded"} + uni_text = "chinese: \u5f20\u4e09" + for enc in ("utf8", "gbk"): + for quote in (True, False): + with self.subTest(enc=enc, quote=quote): + bin_text = uni_text.encode(enc) + if quote: + bin_text = urllib.parse.quote(bin_text).encode("ascii") + response = self.fetch( + "/post_" + enc, + method="POST", + headers=headers, + body=(b"data=" + bin_text), + ) + self.assertEqual(json_decode(response.body), {"echo": uni_text}) + + +class HTTPServerRawTest(AsyncHTTPTestCase): + def get_app(self): + return Application([("/echo", EchoHandler)]) + + def setUp(self): + super().setUp() + self.stream = IOStream(socket.socket()) + self.io_loop.run_sync( + lambda: self.stream.connect(("127.0.0.1", self.get_http_port())) + ) + + def tearDown(self): + self.stream.close() + super().tearDown() + + def test_empty_request(self): + self.stream.close() + self.io_loop.add_timeout(datetime.timedelta(seconds=0.001), self.stop) + self.wait() + + def test_malformed_first_line_response(self): + with ExpectLog(gen_log, ".*Malformed HTTP request line", level=logging.INFO): + self.stream.write(b"asdf\r\n\r\n") + start_line, headers, response = self.io_loop.run_sync( + lambda: read_stream_body(self.stream) + ) + self.assertEqual("HTTP/1.1", start_line.version) + self.assertEqual(400, start_line.code) + self.assertEqual("Bad Request", start_line.reason) + + def test_malformed_first_line_log(self): + with ExpectLog(gen_log, ".*Malformed HTTP request line", level=logging.INFO): + self.stream.write(b"asdf\r\n\r\n") + # TODO: need an async version of ExpectLog so we don't need + # hard-coded timeouts here. + self.io_loop.add_timeout(datetime.timedelta(seconds=0.05), self.stop) + self.wait() + + def test_malformed_headers(self): + with ExpectLog( + gen_log, + ".*Malformed HTTP message.*no colon in header line", + level=logging.INFO, + ): + self.stream.write(b"GET / HTTP/1.0\r\nasdf\r\n\r\n") + self.io_loop.add_timeout(datetime.timedelta(seconds=0.05), self.stop) + self.wait() + + def test_chunked_request_body(self): + # Chunked requests are not widely supported and we don't have a way + # to generate them in AsyncHTTPClient, but HTTPServer will read them. + self.stream.write( + b"""\ +POST /echo HTTP/1.1 +Transfer-Encoding: chunked +Content-Type: application/x-www-form-urlencoded + +4 +foo= +3 +bar +0 + +""".replace( + b"\n", b"\r\n" + ) + ) + start_line, headers, response = self.io_loop.run_sync( + lambda: read_stream_body(self.stream) + ) + self.assertEqual(json_decode(response), {"foo": ["bar"]}) + + def test_chunked_request_uppercase(self): + # As per RFC 2616 section 3.6, "Transfer-Encoding" header's value is + # case-insensitive. 
+ self.stream.write( + b"""\ +POST /echo HTTP/1.1 +Transfer-Encoding: Chunked +Content-Type: application/x-www-form-urlencoded + +4 +foo= +3 +bar +0 + +""".replace( + b"\n", b"\r\n" + ) + ) + start_line, headers, response = self.io_loop.run_sync( + lambda: read_stream_body(self.stream) + ) + self.assertEqual(json_decode(response), {"foo": ["bar"]}) + + @gen_test + def test_invalid_content_length(self): + with ExpectLog( + gen_log, ".*Only integer Content-Length is allowed", level=logging.INFO + ): + self.stream.write( + b"""\ +POST /echo HTTP/1.1 +Content-Length: foo + +bar + +""".replace( + b"\n", b"\r\n" + ) + ) + yield self.stream.read_until_close() + + +class XHeaderTest(HandlerBaseTestCase): + class Handler(RequestHandler): + def get(self): + self.set_header("request-version", self.request.version) + self.write( + dict( + remote_ip=self.request.remote_ip, + remote_protocol=self.request.protocol, + ) + ) + + def get_httpserver_options(self): + return dict(xheaders=True, trusted_downstream=["5.5.5.5"]) + + def test_ip_headers(self): + self.assertEqual(self.fetch_json("/")["remote_ip"], "127.0.0.1") + + valid_ipv4 = {"X-Real-IP": "4.4.4.4"} + self.assertEqual( + self.fetch_json("/", headers=valid_ipv4)["remote_ip"], "4.4.4.4" + ) + + valid_ipv4_list = {"X-Forwarded-For": "127.0.0.1, 4.4.4.4"} + self.assertEqual( + self.fetch_json("/", headers=valid_ipv4_list)["remote_ip"], "4.4.4.4" + ) + + valid_ipv6 = {"X-Real-IP": "2620:0:1cfe:face:b00c::3"} + self.assertEqual( + self.fetch_json("/", headers=valid_ipv6)["remote_ip"], + "2620:0:1cfe:face:b00c::3", + ) + + valid_ipv6_list = {"X-Forwarded-For": "::1, 2620:0:1cfe:face:b00c::3"} + self.assertEqual( + self.fetch_json("/", headers=valid_ipv6_list)["remote_ip"], + "2620:0:1cfe:face:b00c::3", + ) + + invalid_chars = {"X-Real-IP": "4.4.4.4<script>"} + self.assertEqual( + self.fetch_json("/", headers=invalid_chars)["remote_ip"], "127.0.0.1" + ) + + invalid_chars_list = {"X-Forwarded-For": "4.4.4.4, 5.5.5.5<script>"} + self.assertEqual( + self.fetch_json("/", headers=invalid_chars_list)["remote_ip"], "127.0.0.1" + ) + + invalid_host = {"X-Real-IP": "www.google.com"} + self.assertEqual( + self.fetch_json("/", headers=invalid_host)["remote_ip"], "127.0.0.1" + ) + + def test_trusted_downstream(self): + valid_ipv4_list = {"X-Forwarded-For": "127.0.0.1, 4.4.4.4, 5.5.5.5"} + resp = self.fetch("/", headers=valid_ipv4_list) + if resp.headers["request-version"].startswith("HTTP/2"): + # This is a hack - there's nothing that fundamentally requires http/1 + # here but tornado_http2 doesn't support it yet. 
+ self.skipTest("requires HTTP/1.x") + result = json_decode(resp.body) + self.assertEqual(result["remote_ip"], "4.4.4.4") + + def test_scheme_headers(self): + self.assertEqual(self.fetch_json("/")["remote_protocol"], "http") + + https_scheme = {"X-Scheme": "https"} + self.assertEqual( + self.fetch_json("/", headers=https_scheme)["remote_protocol"], "https" + ) + + https_forwarded = {"X-Forwarded-Proto": "https"} + self.assertEqual( + self.fetch_json("/", headers=https_forwarded)["remote_protocol"], "https" + ) + + https_multi_forwarded = {"X-Forwarded-Proto": "https , http"} + self.assertEqual( + self.fetch_json("/", headers=https_multi_forwarded)["remote_protocol"], + "http", + ) + + http_multi_forwarded = {"X-Forwarded-Proto": "http,https"} + self.assertEqual( + self.fetch_json("/", headers=http_multi_forwarded)["remote_protocol"], + "https", + ) + + bad_forwarded = {"X-Forwarded-Proto": "unknown"} + self.assertEqual( + self.fetch_json("/", headers=bad_forwarded)["remote_protocol"], "http" + ) + + +class SSLXHeaderTest(AsyncHTTPSTestCase, HandlerBaseTestCase): + def get_app(self): + return Application([("/", XHeaderTest.Handler)]) + + def get_httpserver_options(self): + output = super().get_httpserver_options() + output["xheaders"] = True + return output + + def test_request_without_xprotocol(self): + self.assertEqual(self.fetch_json("/")["remote_protocol"], "https") + + http_scheme = {"X-Scheme": "http"} + self.assertEqual( + self.fetch_json("/", headers=http_scheme)["remote_protocol"], "http" + ) + + bad_scheme = {"X-Scheme": "unknown"} + self.assertEqual( + self.fetch_json("/", headers=bad_scheme)["remote_protocol"], "https" + ) + + +class ManualProtocolTest(HandlerBaseTestCase): + class Handler(RequestHandler): + def get(self): + self.write(dict(protocol=self.request.protocol)) + + def get_httpserver_options(self): + return dict(protocol="https") + + def test_manual_protocol(self): + self.assertEqual(self.fetch_json("/")["protocol"], "https") + + +@unittest.skipIf( + not hasattr(socket, "AF_UNIX") or sys.platform == "cygwin", + "unix sockets not supported on this platform", +) +class UnixSocketTest(AsyncTestCase): + """HTTPServers can listen on Unix sockets too. + + Why would you want to do this? Nginx can proxy to backends listening + on unix sockets, for one thing (and managing a namespace for unix + sockets can be easier than managing a bunch of TCP port numbers). + + Unfortunately, there's no way to specify a unix socket in a url for + an HTTP client, so we have to test this by hand. 
+ """ + + def setUp(self): + super().setUp() + self.tmpdir = tempfile.mkdtemp() + self.sockfile = os.path.join(self.tmpdir, "test.sock") + sock = netutil.bind_unix_socket(self.sockfile) + app = Application([("/hello", HelloWorldRequestHandler)]) + self.server = HTTPServer(app) + self.server.add_socket(sock) + self.stream = IOStream(socket.socket(socket.AF_UNIX)) + self.io_loop.run_sync(lambda: self.stream.connect(self.sockfile)) + + def tearDown(self): + self.stream.close() + self.io_loop.run_sync(self.server.close_all_connections) + self.server.stop() + shutil.rmtree(self.tmpdir) + super().tearDown() + + @gen_test + def test_unix_socket(self): + self.stream.write(b"GET /hello HTTP/1.0\r\n\r\n") + response = yield self.stream.read_until(b"\r\n") + self.assertEqual(response, b"HTTP/1.1 200 OK\r\n") + header_data = yield self.stream.read_until(b"\r\n\r\n") + headers = HTTPHeaders.parse(header_data.decode("latin1")) + body = yield self.stream.read_bytes(int(headers["Content-Length"])) + self.assertEqual(body, b"Hello world") + + @gen_test + def test_unix_socket_bad_request(self): + # Unix sockets don't have remote addresses so they just return an + # empty string. + with ExpectLog(gen_log, "Malformed HTTP message from", level=logging.INFO): + self.stream.write(b"garbage\r\n\r\n") + response = yield self.stream.read_until_close() + self.assertEqual(response, b"HTTP/1.1 400 Bad Request\r\n\r\n") + + +class KeepAliveTest(AsyncHTTPTestCase): + """Tests various scenarios for HTTP 1.1 keep-alive support. + + These tests don't use AsyncHTTPClient because we want to control + connection reuse and closing. + """ + + def get_app(self): + class HelloHandler(RequestHandler): + def get(self): + self.finish("Hello world") + + def post(self): + self.finish("Hello world") + + class LargeHandler(RequestHandler): + def get(self): + # 512KB should be bigger than the socket buffers so it will + # be written out in chunks. + self.write("".join(chr(i % 256) * 1024 for i in range(512))) + + class TransferEncodingChunkedHandler(RequestHandler): + @gen.coroutine + def head(self): + self.write("Hello world") + yield self.flush() + + class FinishOnCloseHandler(RequestHandler): + def initialize(self, cleanup_event): + self.cleanup_event = cleanup_event + + @gen.coroutine + def get(self): + self.flush() + yield self.cleanup_event.wait() + + def on_connection_close(self): + # This is not very realistic, but finishing the request + # from the close callback has the right timing to mimic + # some errors seen in the wild. + self.finish("closed") + + self.cleanup_event = Event() + return Application( + [ + ("/", HelloHandler), + ("/large", LargeHandler), + ("/chunked", TransferEncodingChunkedHandler), + ( + "/finish_on_close", + FinishOnCloseHandler, + dict(cleanup_event=self.cleanup_event), + ), + ] + ) + + def setUp(self): + super().setUp() + self.http_version = b"HTTP/1.1" + + def tearDown(self): + # We just closed the client side of the socket; let the IOLoop run + # once to make sure the server side got the message. 
+ self.io_loop.add_timeout(datetime.timedelta(seconds=0.001), self.stop) + self.wait() + + if hasattr(self, "stream"): + self.stream.close() + super().tearDown() + + # The next few methods are a crude manual http client + @gen.coroutine + def connect(self): + self.stream = IOStream(socket.socket()) + yield self.stream.connect(("127.0.0.1", self.get_http_port())) + + @gen.coroutine + def read_headers(self): + first_line = yield self.stream.read_until(b"\r\n") + self.assertTrue(first_line.startswith(b"HTTP/1.1 200"), first_line) + header_bytes = yield self.stream.read_until(b"\r\n\r\n") + headers = HTTPHeaders.parse(header_bytes.decode("latin1")) + raise gen.Return(headers) + + @gen.coroutine + def read_response(self): + self.headers = yield self.read_headers() + body = yield self.stream.read_bytes(int(self.headers["Content-Length"])) + self.assertEqual(b"Hello world", body) + + def close(self): + self.stream.close() + del self.stream + + @gen_test + def test_two_requests(self): + yield self.connect() + self.stream.write(b"GET / HTTP/1.1\r\n\r\n") + yield self.read_response() + self.stream.write(b"GET / HTTP/1.1\r\n\r\n") + yield self.read_response() + self.close() + + @gen_test + def test_request_close(self): + yield self.connect() + self.stream.write(b"GET / HTTP/1.1\r\nConnection: close\r\n\r\n") + yield self.read_response() + data = yield self.stream.read_until_close() + self.assertTrue(not data) + self.assertEqual(self.headers["Connection"], "close") + self.close() + + # keepalive is supported for http 1.0 too, but it's opt-in + @gen_test + def test_http10(self): + self.http_version = b"HTTP/1.0" + yield self.connect() + self.stream.write(b"GET / HTTP/1.0\r\n\r\n") + yield self.read_response() + data = yield self.stream.read_until_close() + self.assertTrue(not data) + self.assertTrue("Connection" not in self.headers) + self.close() + + @gen_test + def test_http10_keepalive(self): + self.http_version = b"HTTP/1.0" + yield self.connect() + self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n") + yield self.read_response() + self.assertEqual(self.headers["Connection"], "Keep-Alive") + self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n") + yield self.read_response() + self.assertEqual(self.headers["Connection"], "Keep-Alive") + self.close() + + @gen_test + def test_http10_keepalive_extra_crlf(self): + self.http_version = b"HTTP/1.0" + yield self.connect() + self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n\r\n") + yield self.read_response() + self.assertEqual(self.headers["Connection"], "Keep-Alive") + self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n") + yield self.read_response() + self.assertEqual(self.headers["Connection"], "Keep-Alive") + self.close() + + @gen_test + def test_pipelined_requests(self): + yield self.connect() + self.stream.write(b"GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n") + yield self.read_response() + yield self.read_response() + self.close() + + @gen_test + def test_pipelined_cancel(self): + yield self.connect() + self.stream.write(b"GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n") + # only read once + yield self.read_response() + self.close() + + @gen_test + def test_cancel_during_download(self): + yield self.connect() + self.stream.write(b"GET /large HTTP/1.1\r\n\r\n") + yield self.read_headers() + yield self.stream.read_bytes(1024) + self.close() + + @gen_test + def test_finish_while_closed(self): + yield self.connect() + self.stream.write(b"GET /finish_on_close HTTP/1.1\r\n\r\n") + yield 
self.read_headers() + self.close() + # Let the hanging coroutine clean up after itself + self.cleanup_event.set() + + @gen_test + def test_keepalive_chunked(self): + self.http_version = b"HTTP/1.0" + yield self.connect() + self.stream.write( + b"POST / HTTP/1.0\r\n" + b"Connection: keep-alive\r\n" + b"Transfer-Encoding: chunked\r\n" + b"\r\n" + b"0\r\n" + b"\r\n" + ) + yield self.read_response() + self.assertEqual(self.headers["Connection"], "Keep-Alive") + self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n") + yield self.read_response() + self.assertEqual(self.headers["Connection"], "Keep-Alive") + self.close() + + @gen_test + def test_keepalive_chunked_head_no_body(self): + yield self.connect() + self.stream.write(b"HEAD /chunked HTTP/1.1\r\n\r\n") + yield self.read_headers() + + self.stream.write(b"HEAD /chunked HTTP/1.1\r\n\r\n") + yield self.read_headers() + self.close() + + +class GzipBaseTest(AsyncHTTPTestCase): + def get_app(self): + return Application([("/", EchoHandler)]) + + def post_gzip(self, body): + bytesio = BytesIO() + gzip_file = gzip.GzipFile(mode="w", fileobj=bytesio) + gzip_file.write(utf8(body)) + gzip_file.close() + compressed_body = bytesio.getvalue() + return self.fetch( + "/", + method="POST", + body=compressed_body, + headers={"Content-Encoding": "gzip"}, + ) + + def test_uncompressed(self): + response = self.fetch("/", method="POST", body="foo=bar") + self.assertEqual(json_decode(response.body), {"foo": ["bar"]}) + + +class GzipTest(GzipBaseTest, AsyncHTTPTestCase): + def get_httpserver_options(self): + return dict(decompress_request=True) + + def test_gzip(self): + response = self.post_gzip("foo=bar") + self.assertEqual(json_decode(response.body), {"foo": ["bar"]}) + + def test_gzip_case_insensitive(self): + # https://datatracker.ietf.org/doc/html/rfc7231#section-3.1.2.1 + bytesio = BytesIO() + gzip_file = gzip.GzipFile(mode="w", fileobj=bytesio) + gzip_file.write(utf8("foo=bar")) + gzip_file.close() + compressed_body = bytesio.getvalue() + response = self.fetch( + "/", + method="POST", + body=compressed_body, + headers={"Content-Encoding": "GZIP"}, + ) + self.assertEqual(json_decode(response.body), {"foo": ["bar"]}) + + +class GzipUnsupportedTest(GzipBaseTest, AsyncHTTPTestCase): + def test_gzip_unsupported(self): + # Gzip support is opt-in; without it the server fails to parse + # the body (but parsing form bodies is currently just a log message, + # not a fatal error). + with ExpectLog(gen_log, "Unsupported Content-Encoding"): + response = self.post_gzip("foo=bar") + self.assertEqual(json_decode(response.body), {}) + + +class StreamingChunkSizeTest(AsyncHTTPTestCase): + # 50 characters long, and repetitive so it can be compressed. + BODY = b"01234567890123456789012345678901234567890123456789" + CHUNK_SIZE = 16 + + def get_http_client(self): + # body_producer doesn't work on curl_httpclient, so override the + # configured AsyncHTTPClient implementation. 
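+ # For illustration, a hypothetical standalone helper mirroring post_gzip
+ # above: any client can exercise a server built with decompress_request=True
+ # by gzipping the body itself.
+ def _gzip_body_sketch(body):
+     buf = BytesIO()
+     with gzip.GzipFile(mode="w", fileobj=buf) as f:
+         f.write(utf8(body))
+     return buf.getvalue()  # send with a "Content-Encoding: gzip" header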
+ return SimpleAsyncHTTPClient() + + def get_httpserver_options(self): + return dict(chunk_size=self.CHUNK_SIZE, decompress_request=True) + + class MessageDelegate(HTTPMessageDelegate): + def __init__(self, connection): + self.connection = connection + + def headers_received(self, start_line, headers): + self.chunk_lengths = [] # type: List[int] + + def data_received(self, chunk): + self.chunk_lengths.append(len(chunk)) + + def finish(self): + response_body = utf8(json_encode(self.chunk_lengths)) + self.connection.write_headers( + ResponseStartLine("HTTP/1.1", 200, "OK"), + HTTPHeaders({"Content-Length": str(len(response_body))}), + ) + self.connection.write(response_body) + self.connection.finish() + + def get_app(self): + class App(HTTPServerConnectionDelegate): + def start_request(self, server_conn, request_conn): + return StreamingChunkSizeTest.MessageDelegate(request_conn) + + return App() + + def fetch_chunk_sizes(self, **kwargs): + response = self.fetch("/", method="POST", **kwargs) + response.rethrow() + chunks = json_decode(response.body) + self.assertEqual(len(self.BODY), sum(chunks)) + for chunk_size in chunks: + self.assertLessEqual( + chunk_size, self.CHUNK_SIZE, "oversized chunk: " + str(chunks) + ) + self.assertGreater(chunk_size, 0, "empty chunk: " + str(chunks)) + return chunks + + def compress(self, body): + bytesio = BytesIO() + gzfile = gzip.GzipFile(mode="w", fileobj=bytesio) + gzfile.write(body) + gzfile.close() + compressed = bytesio.getvalue() + if len(compressed) >= len(body): + raise Exception("body did not shrink when compressed") + return compressed + + def test_regular_body(self): + chunks = self.fetch_chunk_sizes(body=self.BODY) + # Without compression we know exactly what to expect. + self.assertEqual([16, 16, 16, 2], chunks) + + def test_compressed_body(self): + self.fetch_chunk_sizes( + body=self.compress(self.BODY), headers={"Content-Encoding": "gzip"} + ) + # Compression creates irregular boundaries so the assertions + # in fetch_chunk_sizes are as specific as we can get. + + def test_chunked_body(self): + def body_producer(write): + write(self.BODY[:20]) + write(self.BODY[20:]) + + chunks = self.fetch_chunk_sizes(body_producer=body_producer) + # HTTP chunk boundaries translate to application-visible breaks + self.assertEqual([16, 4, 16, 14], chunks) + + def test_chunked_compressed(self): + compressed = self.compress(self.BODY) + self.assertGreater(len(compressed), 20) + + def body_producer(write): + write(compressed[:20]) + write(compressed[20:]) + + self.fetch_chunk_sizes( + body_producer=body_producer, headers={"Content-Encoding": "gzip"} + ) + + +class MaxHeaderSizeTest(AsyncHTTPTestCase): + def get_app(self): + return Application([("/", HelloWorldRequestHandler)]) + + def get_httpserver_options(self): + return dict(max_header_size=1024) + + def test_small_headers(self): + response = self.fetch("/", headers={"X-Filler": "a" * 100}) + response.rethrow() + self.assertEqual(response.body, b"Hello world") + + def test_large_headers(self): + with ExpectLog(gen_log, "Unsatisfiable read", required=False): + try: + self.fetch("/", headers={"X-Filler": "a" * 1000}, raise_error=True) + self.fail("did not raise expected exception") + except HTTPError as e: + # 431 is "Request Header Fields Too Large", defined in RFC + # 6585. However, many implementations just close the + # connection in this case, resulting in a missing response. 
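+ # For illustration: max_header_size is a plain HTTPServer keyword, so the
+ # cap under test could instead be raised at construction time (64 KiB here
+ # is an invented figure, not a Tornado default):
+ #     server = HTTPServer(app, max_header_size=64 * 1024)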
+ if e.response is not None: + self.assertIn(e.response.code, (431, 599)) + + +@skipOnTravis +class IdleTimeoutTest(AsyncHTTPTestCase): + def get_app(self): + return Application([("/", HelloWorldRequestHandler)]) + + def get_httpserver_options(self): + return dict(idle_connection_timeout=0.1) + + def setUp(self): + super().setUp() + self.streams = [] # type: List[IOStream] + + def tearDown(self): + super().tearDown() + for stream in self.streams: + stream.close() + + @gen.coroutine + def connect(self): + stream = IOStream(socket.socket()) + yield stream.connect(("127.0.0.1", self.get_http_port())) + self.streams.append(stream) + raise gen.Return(stream) + + @gen_test + def test_unused_connection(self): + stream = yield self.connect() + event = Event() + stream.set_close_callback(event.set) + yield event.wait() + + @gen_test + def test_idle_after_use(self): + stream = yield self.connect() + event = Event() + stream.set_close_callback(event.set) + + # Use the connection twice to make sure keep-alives are working + for i in range(2): + stream.write(b"GET / HTTP/1.1\r\n\r\n") + yield stream.read_until(b"\r\n\r\n") + data = yield stream.read_bytes(11) + self.assertEqual(data, b"Hello world") + + # Now let the timeout trigger and close the connection. + yield event.wait() + + +class BodyLimitsTest(AsyncHTTPTestCase): + def get_app(self): + class BufferedHandler(RequestHandler): + def put(self): + self.write(str(len(self.request.body))) + + @stream_request_body + class StreamingHandler(RequestHandler): + def initialize(self): + self.bytes_read = 0 + + def prepare(self): + conn = typing.cast(HTTP1Connection, self.request.connection) + if "expected_size" in self.request.arguments: + conn.set_max_body_size(int(self.get_argument("expected_size"))) + if "body_timeout" in self.request.arguments: + conn.set_body_timeout(float(self.get_argument("body_timeout"))) + + def data_received(self, data): + self.bytes_read += len(data) + + def put(self): + self.write(str(self.bytes_read)) + + return Application( + [("/buffered", BufferedHandler), ("/streaming", StreamingHandler)] + ) + + def get_httpserver_options(self): + return dict(body_timeout=3600, max_body_size=4096) + + def get_http_client(self): + # body_producer doesn't work on curl_httpclient, so override the + # configured AsyncHTTPClient implementation. + return SimpleAsyncHTTPClient() + + def test_small_body(self): + response = self.fetch("/buffered", method="PUT", body=b"a" * 4096) + self.assertEqual(response.body, b"4096") + response = self.fetch("/streaming", method="PUT", body=b"a" * 4096) + self.assertEqual(response.body, b"4096") + + def test_large_body_buffered(self): + with ExpectLog(gen_log, ".*Content-Length too long", level=logging.INFO): + response = self.fetch("/buffered", method="PUT", body=b"a" * 10240) + self.assertEqual(response.code, 400) + + @unittest.skipIf(os.name == "nt", "flaky on windows") + def test_large_body_buffered_chunked(self): + # This test is flaky on windows for unknown reasons. 
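+ # For illustration, the pattern StreamingHandler above demonstrates, as a
+ # standalone sketch (the 1 GiB cap is an invented figure): a streaming
+ # handler may lift the server-wide max_body_size for just its own requests.
+ @stream_request_body
+ class _BigUploadSketch(RequestHandler):
+     def prepare(self):
+         conn = typing.cast(HTTP1Connection, self.request.connection)
+         conn.set_max_body_size(1 << 30)
+         self.nbytes = 0
+     def data_received(self, data):
+         self.nbytes += len(data)
+     def put(self):
+         self.write(str(self.nbytes))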
+ with ExpectLog(gen_log, ".*chunked body too large", level=logging.INFO): + response = self.fetch( + "/buffered", + method="PUT", + body_producer=lambda write: write(b"a" * 10240), + ) + self.assertEqual(response.code, 400) + + def test_large_body_streaming(self): + with ExpectLog(gen_log, ".*Content-Length too long", level=logging.INFO): + response = self.fetch("/streaming", method="PUT", body=b"a" * 10240) + self.assertEqual(response.code, 400) + + @unittest.skipIf(os.name == "nt", "flaky on windows") + def test_large_body_streaming_chunked(self): + with ExpectLog(gen_log, ".*chunked body too large", level=logging.INFO): + response = self.fetch( + "/streaming", + method="PUT", + body_producer=lambda write: write(b"a" * 10240), + ) + self.assertEqual(response.code, 400) + + def test_large_body_streaming_override(self): + response = self.fetch( + "/streaming?expected_size=10240", method="PUT", body=b"a" * 10240 + ) + self.assertEqual(response.body, b"10240") + + def test_large_body_streaming_chunked_override(self): + response = self.fetch( + "/streaming?expected_size=10240", + method="PUT", + body_producer=lambda write: write(b"a" * 10240), + ) + self.assertEqual(response.body, b"10240") + + @gen_test + def test_timeout(self): + stream = IOStream(socket.socket()) + try: + yield stream.connect(("127.0.0.1", self.get_http_port())) + # Use a raw stream because AsyncHTTPClient won't let us read a + # response without finishing a body. + stream.write( + b"PUT /streaming?body_timeout=0.1 HTTP/1.0\r\n" + b"Content-Length: 42\r\n\r\n" + ) + with ExpectLog(gen_log, "Timeout reading body", level=logging.INFO): + response = yield stream.read_until_close() + self.assertEqual(response, b"") + finally: + stream.close() + + @gen_test + def test_body_size_override_reset(self): + # The max_body_size override is reset between requests. + stream = IOStream(socket.socket()) + try: + yield stream.connect(("127.0.0.1", self.get_http_port())) + # Use a raw stream so we can make sure it's all on one connection. + stream.write( + b"PUT /streaming?expected_size=10240 HTTP/1.1\r\n" + b"Content-Length: 10240\r\n\r\n" + ) + stream.write(b"a" * 10240) + start_line, headers, response = yield read_stream_body(stream) + self.assertEqual(response, b"10240") + # Without the ?expected_size parameter, we get the old default value + stream.write( + b"PUT /streaming HTTP/1.1\r\n" b"Content-Length: 10240\r\n\r\n" + ) + with ExpectLog(gen_log, ".*Content-Length too long", level=logging.INFO): + data = yield stream.read_until_close() + self.assertEqual(data, b"HTTP/1.1 400 Bad Request\r\n\r\n") + finally: + stream.close() + + +class LegacyInterfaceTest(AsyncHTTPTestCase): + def get_app(self): + # The old request_callback interface does not implement the + # delegate interface, and writes its response via request.write + # instead of request.connection.write_headers. + def handle_request(request): + self.http1 = request.version.startswith("HTTP/1.") + if not self.http1: + # This test will be skipped if we're using HTTP/2, + # so just close it out cleanly using the modern interface. 
+ request.connection.write_headers( + ResponseStartLine("", 200, "OK"), HTTPHeaders() + ) + request.connection.finish() + return + message = b"Hello world" + request.connection.write( + utf8("HTTP/1.1 200 OK\r\n" "Content-Length: %d\r\n\r\n" % len(message)) + ) + request.connection.write(message) + request.connection.finish() + + return handle_request + + def test_legacy_interface(self): + response = self.fetch("/") + if not self.http1: + self.skipTest("requires HTTP/1.x") + self.assertEqual(response.body, b"Hello world") diff --git a/venv/lib/python3.9/site-packages/tornado/test/httputil_test.py b/venv/lib/python3.9/site-packages/tornado/test/httputil_test.py new file mode 100644 index 00000000..8424491d --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/test/httputil_test.py @@ -0,0 +1,521 @@ +from tornado.httputil import ( + url_concat, + parse_multipart_form_data, + HTTPHeaders, + format_timestamp, + HTTPServerRequest, + parse_request_start_line, + parse_cookie, + qs_to_qsl, + HTTPInputError, + HTTPFile, +) +from tornado.escape import utf8, native_str +from tornado.log import gen_log +from tornado.testing import ExpectLog + +import copy +import datetime +import logging +import pickle +import time +import urllib.parse +import unittest + +from typing import Tuple, Dict, List + + +def form_data_args() -> Tuple[Dict[str, List[bytes]], Dict[str, List[HTTPFile]]]: + """Return two empty dicts suitable for use with parse_multipart_form_data. + + mypy insists on type annotations for dict literals, so this lets us avoid + the verbose types throughout this test. + """ + return {}, {} + + +class TestUrlConcat(unittest.TestCase): + def test_url_concat_no_query_params(self): + url = url_concat("https://localhost/path", [("y", "y"), ("z", "z")]) + self.assertEqual(url, "https://localhost/path?y=y&z=z") + + def test_url_concat_encode_args(self): + url = url_concat("https://localhost/path", [("y", "/y"), ("z", "z")]) + self.assertEqual(url, "https://localhost/path?y=%2Fy&z=z") + + def test_url_concat_trailing_q(self): + url = url_concat("https://localhost/path?", [("y", "y"), ("z", "z")]) + self.assertEqual(url, "https://localhost/path?y=y&z=z") + + def test_url_concat_q_with_no_trailing_amp(self): + url = url_concat("https://localhost/path?x", [("y", "y"), ("z", "z")]) + self.assertEqual(url, "https://localhost/path?x=&y=y&z=z") + + def test_url_concat_trailing_amp(self): + url = url_concat("https://localhost/path?x&", [("y", "y"), ("z", "z")]) + self.assertEqual(url, "https://localhost/path?x=&y=y&z=z") + + def test_url_concat_mult_params(self): + url = url_concat("https://localhost/path?a=1&b=2", [("y", "y"), ("z", "z")]) + self.assertEqual(url, "https://localhost/path?a=1&b=2&y=y&z=z") + + def test_url_concat_no_params(self): + url = url_concat("https://localhost/path?r=1&t=2", []) + self.assertEqual(url, "https://localhost/path?r=1&t=2") + + def test_url_concat_none_params(self): + url = url_concat("https://localhost/path?r=1&t=2", None) + self.assertEqual(url, "https://localhost/path?r=1&t=2") + + def test_url_concat_with_frag(self): + url = url_concat("https://localhost/path#tab", [("y", "y")]) + self.assertEqual(url, "https://localhost/path?y=y#tab") + + def test_url_concat_multi_same_params(self): + url = url_concat("https://localhost/path", [("y", "y1"), ("y", "y2")]) + self.assertEqual(url, "https://localhost/path?y=y1&y=y2") + + def test_url_concat_multi_same_query_params(self): + url = url_concat("https://localhost/path?r=1&r=2", [("y", "y")]) + self.assertEqual(url, 
"https://localhost/path?r=1&r=2&y=y") + + def test_url_concat_dict_params(self): + url = url_concat("https://localhost/path", dict(y="y")) + self.assertEqual(url, "https://localhost/path?y=y") + + +class QsParseTest(unittest.TestCase): + def test_parsing(self): + qsstring = "a=1&b=2&a=3" + qs = urllib.parse.parse_qs(qsstring) + qsl = list(qs_to_qsl(qs)) + self.assertIn(("a", "1"), qsl) + self.assertIn(("a", "3"), qsl) + self.assertIn(("b", "2"), qsl) + + +class MultipartFormDataTest(unittest.TestCase): + def test_file_upload(self): + data = b"""\ +--1234 +Content-Disposition: form-data; name="files"; filename="ab.txt" + +Foo +--1234--""".replace( + b"\n", b"\r\n" + ) + args, files = form_data_args() + parse_multipart_form_data(b"1234", data, args, files) + file = files["files"][0] + self.assertEqual(file["filename"], "ab.txt") + self.assertEqual(file["body"], b"Foo") + + def test_unquoted_names(self): + # quotes are optional unless special characters are present + data = b"""\ +--1234 +Content-Disposition: form-data; name=files; filename=ab.txt + +Foo +--1234--""".replace( + b"\n", b"\r\n" + ) + args, files = form_data_args() + parse_multipart_form_data(b"1234", data, args, files) + file = files["files"][0] + self.assertEqual(file["filename"], "ab.txt") + self.assertEqual(file["body"], b"Foo") + + def test_special_filenames(self): + filenames = [ + "a;b.txt", + 'a"b.txt', + 'a";b.txt', + 'a;"b.txt', + 'a";";.txt', + 'a\\"b.txt', + "a\\b.txt", + ] + for filename in filenames: + logging.debug("trying filename %r", filename) + str_data = """\ +--1234 +Content-Disposition: form-data; name="files"; filename="%s" + +Foo +--1234--""" % filename.replace( + "\\", "\\\\" + ).replace( + '"', '\\"' + ) + data = utf8(str_data.replace("\n", "\r\n")) + args, files = form_data_args() + parse_multipart_form_data(b"1234", data, args, files) + file = files["files"][0] + self.assertEqual(file["filename"], filename) + self.assertEqual(file["body"], b"Foo") + + def test_non_ascii_filename(self): + data = b"""\ +--1234 +Content-Disposition: form-data; name="files"; filename="ab.txt"; filename*=UTF-8''%C3%A1b.txt + +Foo +--1234--""".replace( + b"\n", b"\r\n" + ) + args, files = form_data_args() + parse_multipart_form_data(b"1234", data, args, files) + file = files["files"][0] + self.assertEqual(file["filename"], "áb.txt") + self.assertEqual(file["body"], b"Foo") + + def test_boundary_starts_and_ends_with_quotes(self): + data = b"""\ +--1234 +Content-Disposition: form-data; name="files"; filename="ab.txt" + +Foo +--1234--""".replace( + b"\n", b"\r\n" + ) + args, files = form_data_args() + parse_multipart_form_data(b'"1234"', data, args, files) + file = files["files"][0] + self.assertEqual(file["filename"], "ab.txt") + self.assertEqual(file["body"], b"Foo") + + def test_missing_headers(self): + data = b"""\ +--1234 + +Foo +--1234--""".replace( + b"\n", b"\r\n" + ) + args, files = form_data_args() + with ExpectLog(gen_log, "multipart/form-data missing headers"): + parse_multipart_form_data(b"1234", data, args, files) + self.assertEqual(files, {}) + + def test_invalid_content_disposition(self): + data = b"""\ +--1234 +Content-Disposition: invalid; name="files"; filename="ab.txt" + +Foo +--1234--""".replace( + b"\n", b"\r\n" + ) + args, files = form_data_args() + with ExpectLog(gen_log, "Invalid multipart/form-data"): + parse_multipart_form_data(b"1234", data, args, files) + self.assertEqual(files, {}) + + def test_line_does_not_end_with_correct_line_break(self): + data = b"""\ +--1234 +Content-Disposition: form-data; 
name="files"; filename="ab.txt" + +Foo--1234--""".replace( + b"\n", b"\r\n" + ) + args, files = form_data_args() + with ExpectLog(gen_log, "Invalid multipart/form-data"): + parse_multipart_form_data(b"1234", data, args, files) + self.assertEqual(files, {}) + + def test_content_disposition_header_without_name_parameter(self): + data = b"""\ +--1234 +Content-Disposition: form-data; filename="ab.txt" + +Foo +--1234--""".replace( + b"\n", b"\r\n" + ) + args, files = form_data_args() + with ExpectLog(gen_log, "multipart/form-data value missing name"): + parse_multipart_form_data(b"1234", data, args, files) + self.assertEqual(files, {}) + + def test_data_after_final_boundary(self): + # The spec requires that data after the final boundary be ignored. + # http://www.w3.org/Protocols/rfc1341/7_2_Multipart.html + # In practice, some libraries include an extra CRLF after the boundary. + data = b"""\ +--1234 +Content-Disposition: form-data; name="files"; filename="ab.txt" + +Foo +--1234-- +""".replace( + b"\n", b"\r\n" + ) + args, files = form_data_args() + parse_multipart_form_data(b"1234", data, args, files) + file = files["files"][0] + self.assertEqual(file["filename"], "ab.txt") + self.assertEqual(file["body"], b"Foo") + + +class HTTPHeadersTest(unittest.TestCase): + def test_multi_line(self): + # Lines beginning with whitespace are appended to the previous line + # with any leading whitespace replaced by a single space. + # Note that while multi-line headers are a part of the HTTP spec, + # their use is strongly discouraged. + data = """\ +Foo: bar + baz +Asdf: qwer +\tzxcv +Foo: even + more + lines +""".replace( + "\n", "\r\n" + ) + headers = HTTPHeaders.parse(data) + self.assertEqual(headers["asdf"], "qwer zxcv") + self.assertEqual(headers.get_list("asdf"), ["qwer zxcv"]) + self.assertEqual(headers["Foo"], "bar baz,even more lines") + self.assertEqual(headers.get_list("foo"), ["bar baz", "even more lines"]) + self.assertEqual( + sorted(list(headers.get_all())), + [("Asdf", "qwer zxcv"), ("Foo", "bar baz"), ("Foo", "even more lines")], + ) + + def test_malformed_continuation(self): + # If the first line starts with whitespace, it's a + # continuation line with nothing to continue, so reject it + # (with a proper error). + data = " Foo: bar" + self.assertRaises(HTTPInputError, HTTPHeaders.parse, data) + + def test_unicode_newlines(self): + # Ensure that only \r\n is recognized as a header separator, and not + # the other newline-like unicode characters. + # Characters that are likely to be problematic can be found in + # http://unicode.org/standard/reports/tr13/tr13-5.html + # and cpython's unicodeobject.c (which defines the implementation + # of unicode_type.splitlines(), and uses a different list than TR13). + newlines = [ + "\u001b", # VERTICAL TAB + "\u001c", # FILE SEPARATOR + "\u001d", # GROUP SEPARATOR + "\u001e", # RECORD SEPARATOR + "\u0085", # NEXT LINE + "\u2028", # LINE SEPARATOR + "\u2029", # PARAGRAPH SEPARATOR + ] + for newline in newlines: + # Try the utf8 and latin1 representations of each newline + for encoding in ["utf8", "latin1"]: + try: + try: + encoded = newline.encode(encoding) + except UnicodeEncodeError: + # Some chars cannot be represented in latin1 + continue + data = b"Cookie: foo=" + encoded + b"bar" + # parse() wants a native_str, so decode through latin1 + # in the same way the real parser does. 
+ headers = HTTPHeaders.parse(native_str(data.decode("latin1")))
+ expected = [
+ (
+ "Cookie",
+ "foo=" + native_str(encoded.decode("latin1")) + "bar",
+ )
+ ]
+ self.assertEqual(expected, list(headers.get_all()))
+ except Exception:
+ gen_log.warning("failed while trying %r in %s", newline, encoding)
+ raise
+
+ def test_optional_cr(self):
+ # Both CRLF and LF should be accepted as separators. CR should not be
+ # part of the data when followed by LF, but it is a normal char
+ # otherwise (or should bare CR be an error?)
+ headers = HTTPHeaders.parse("CRLF: crlf\r\nLF: lf\nCR: cr\rMore: more\r\n")
+ self.assertEqual(
+ sorted(headers.get_all()),
+ [("Cr", "cr\rMore: more"), ("Crlf", "crlf"), ("Lf", "lf")],
+ )
+
+ def test_copy(self):
+ all_pairs = [("A", "1"), ("A", "2"), ("B", "c")]
+ h1 = HTTPHeaders()
+ for k, v in all_pairs:
+ h1.add(k, v)
+ h2 = h1.copy()
+ h3 = copy.copy(h1)
+ h4 = copy.deepcopy(h1)
+ for headers in [h1, h2, h3, h4]:
+ # All the copies are identical, no matter how they were
+ # constructed.
+ self.assertEqual(list(sorted(headers.get_all())), all_pairs)
+ for headers in [h2, h3, h4]:
+ # Neither the dict nor its member lists are reused.
+ self.assertIsNot(headers, h1)
+ self.assertIsNot(headers.get_list("A"), h1.get_list("A"))
+
+ def test_pickle_roundtrip(self):
+ headers = HTTPHeaders()
+ headers.add("Set-Cookie", "a=b")
+ headers.add("Set-Cookie", "c=d")
+ headers.add("Content-Type", "text/html")
+ pickled = pickle.dumps(headers)
+ unpickled = pickle.loads(pickled)
+ self.assertEqual(sorted(headers.get_all()), sorted(unpickled.get_all()))
+ self.assertEqual(sorted(headers.items()), sorted(unpickled.items()))
+
+ def test_setdefault(self):
+ headers = HTTPHeaders()
+ headers["foo"] = "bar"
+ # If a value is present, setdefault returns it without changes.
+ self.assertEqual(headers.setdefault("foo", "baz"), "bar")
+ self.assertEqual(headers["foo"], "bar")
+ # If a value is not present, setdefault sets it for future use.
+ self.assertEqual(headers.setdefault("quux", "xyzzy"), "xyzzy")
+ self.assertEqual(headers["quux"], "xyzzy")
+ self.assertEqual(sorted(headers.get_all()), [("Foo", "bar"), ("Quux", "xyzzy")])
+
+ def test_string(self):
+ headers = HTTPHeaders()
+ headers.add("Foo", "1")
+ headers.add("Foo", "2")
+ headers.add("Foo", "3")
+ headers2 = HTTPHeaders.parse(str(headers))
+ self.assertEqual(headers, headers2)
+
+
+class FormatTimestampTest(unittest.TestCase):
+ # Make sure that all the input types are supported.
+ TIMESTAMP = 1359312200.503611
+ EXPECTED = "Sun, 27 Jan 2013 18:43:20 GMT"
+
+ def check(self, value):
+ self.assertEqual(format_timestamp(value), self.EXPECTED)
+
+ def test_unix_time_float(self):
+ self.check(self.TIMESTAMP)
+
+ def test_unix_time_int(self):
+ self.check(int(self.TIMESTAMP))
+
+ def test_struct_time(self):
+ self.check(time.gmtime(self.TIMESTAMP))
+
+ def test_time_tuple(self):
+ tup = tuple(time.gmtime(self.TIMESTAMP))
+ self.assertEqual(9, len(tup))
+ self.check(tup)
+
+ def test_datetime(self):
+ self.check(datetime.datetime.utcfromtimestamp(self.TIMESTAMP))
+
+
+# HTTPServerRequest is mainly tested incidentally to the server itself,
+# but this tests the parts of the class that can be tested in isolation.
+class HTTPServerRequestTest(unittest.TestCase):
+ def test_default_constructor(self):
+ # All parameters are formally optional, but uri is required
+ # (and has been for some time). This test ensures that no
+ # more required parameters slip in.
+ HTTPServerRequest(uri="/") + + def test_body_is_a_byte_string(self): + requets = HTTPServerRequest(uri="/") + self.assertIsInstance(requets.body, bytes) + + def test_repr_does_not_contain_headers(self): + request = HTTPServerRequest( + uri="/", headers=HTTPHeaders({"Canary": ["Coal Mine"]}) + ) + self.assertTrue("Canary" not in repr(request)) + + +class ParseRequestStartLineTest(unittest.TestCase): + METHOD = "GET" + PATH = "/foo" + VERSION = "HTTP/1.1" + + def test_parse_request_start_line(self): + start_line = " ".join([self.METHOD, self.PATH, self.VERSION]) + parsed_start_line = parse_request_start_line(start_line) + self.assertEqual(parsed_start_line.method, self.METHOD) + self.assertEqual(parsed_start_line.path, self.PATH) + self.assertEqual(parsed_start_line.version, self.VERSION) + + +class ParseCookieTest(unittest.TestCase): + # These tests copied from Django: + # https://github.com/django/django/pull/6277/commits/da810901ada1cae9fc1f018f879f11a7fb467b28 + def test_python_cookies(self): + """ + Test cases copied from Python's Lib/test/test_http_cookies.py + """ + self.assertEqual( + parse_cookie("chips=ahoy; vienna=finger"), + {"chips": "ahoy", "vienna": "finger"}, + ) + # Here parse_cookie() differs from Python's cookie parsing in that it + # treats all semicolons as delimiters, even within quotes. + self.assertEqual( + parse_cookie('keebler="E=mc2; L=\\"Loves\\"; fudge=\\012;"'), + {"keebler": '"E=mc2', "L": '\\"Loves\\"', "fudge": "\\012", "": '"'}, + ) + # Illegal cookies that have an '=' char in an unquoted value. + self.assertEqual(parse_cookie("keebler=E=mc2"), {"keebler": "E=mc2"}) + # Cookies with ':' character in their name. + self.assertEqual( + parse_cookie("key:term=value:term"), {"key:term": "value:term"} + ) + # Cookies with '[' and ']'. + self.assertEqual( + parse_cookie("a=b; c=[; d=r; f=h"), {"a": "b", "c": "[", "d": "r", "f": "h"} + ) + + def test_cookie_edgecases(self): + # Cookies that RFC6265 allows. + self.assertEqual( + parse_cookie("a=b; Domain=example.com"), {"a": "b", "Domain": "example.com"} + ) + # parse_cookie() has historically kept only the last cookie with the + # same name. + self.assertEqual(parse_cookie("a=b; h=i; a=c"), {"a": "c", "h": "i"}) + + def test_invalid_cookies(self): + """ + Cookie strings that go against RFC6265 but browsers will send if set + via document.cookie. + """ + # Chunks without an equals sign appear as unnamed values per + # https://bugzilla.mozilla.org/show_bug.cgi?id=169091 + self.assertIn( + "django_language", + parse_cookie("abc=def; unnamed; django_language=en").keys(), + ) + # Even a double quote may be an unamed value. + self.assertEqual(parse_cookie('a=b; "; c=d'), {"a": "b", "": '"', "c": "d"}) + # Spaces in names and values, and an equals sign in values. + self.assertEqual( + parse_cookie("a b c=d e = f; gh=i"), {"a b c": "d e = f", "gh": "i"} + ) + # More characters the spec forbids. + self.assertEqual( + parse_cookie('a b,c<>@:/[]?{}=d " =e,f g'), + {"a b,c<>@:/[]?{}": 'd " =e,f g'}, + ) + # Unicode characters. The spec only allows ASCII. + self.assertEqual( + parse_cookie("saint=André Bessette"), + {"saint": native_str("André Bessette")}, + ) + # Browsers don't send extra whitespace or semicolons in Cookie headers, + # but parse_cookie() should parse whitespace the same way + # document.cookie parses whitespace. 
+ self.assertEqual(
+ parse_cookie(" = b ; ; = ; c = ; "), {"": "b", "c": ""}
+ )
diff --git a/venv/lib/python3.9/site-packages/tornado/test/import_test.py b/venv/lib/python3.9/site-packages/tornado/test/import_test.py
new file mode 100644
index 00000000..1ff52206
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/tornado/test/import_test.py
@@ -0,0 +1,65 @@
+# flake8: noqa
+import subprocess
+import sys
+import unittest
+
+_import_everything = b"""
+# The event loop is not fork-safe, and it's easy to initialize an asyncio.Future
+# at startup, which in turn creates the default event loop and prevents forking.
+# Explicitly disallow the default event loop so that an error will be raised
+# if something tries to touch it.
+import asyncio
+asyncio.set_event_loop(None)
+
+import importlib
+import tornado
+
+for mod in tornado.__all__:
+ if mod == "curl_httpclient":
+ # This module has extra dependencies; skip it if they're not installed.
+ try:
+ import pycurl
+ except ImportError:
+ continue
+ importlib.import_module(f"tornado.{mod}")
+"""
+
+_import_lazy = b"""
+import sys
+import tornado
+
+if "tornado.web" in sys.modules:
+ raise Exception("unexpected eager import")
+
+# Trigger a lazy import by referring to something in a submodule.
+tornado.web.RequestHandler
+
+if "tornado.web" not in sys.modules:
+ raise Exception("lazy import did not update sys.modules")
+"""
+
+
+class ImportTest(unittest.TestCase):
+ def test_import_everything(self):
+ # Test that all Tornado modules can be imported without side effects,
+ # specifically without initializing the default asyncio event loop.
+ # Since we can't tell which modules may have already been imported
+ # in our process, do it in a subprocess for a clean slate.
+ proc = subprocess.Popen([sys.executable], stdin=subprocess.PIPE)
+ proc.communicate(_import_everything)
+ self.assertEqual(proc.returncode, 0)
+
+ def test_lazy_import(self):
+ # Test that submodules can be referenced lazily after "import tornado"
+ proc = subprocess.Popen([sys.executable], stdin=subprocess.PIPE)
+ proc.communicate(_import_lazy)
+ self.assertEqual(proc.returncode, 0)
+
+ def test_import_aliases(self):
+ # Ensure we don't delete formerly-documented aliases accidentally.
+ import tornado + import asyncio + + self.assertIs(tornado.ioloop.TimeoutError, tornado.util.TimeoutError) + self.assertIs(tornado.gen.TimeoutError, tornado.util.TimeoutError) + self.assertIs(tornado.util.TimeoutError, asyncio.TimeoutError) diff --git a/venv/lib/python3.9/site-packages/tornado/test/ioloop_test.py b/venv/lib/python3.9/site-packages/tornado/test/ioloop_test.py new file mode 100644 index 00000000..7de392f8 --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/test/ioloop_test.py @@ -0,0 +1,795 @@ +import asyncio +from concurrent.futures import ThreadPoolExecutor +from concurrent import futures +from collections.abc import Generator +import contextlib +import datetime +import functools +import socket +import subprocess +import sys +import threading +import time +import types +from unittest import mock +import unittest + +from tornado.escape import native_str +from tornado import gen +from tornado.ioloop import IOLoop, TimeoutError, PeriodicCallback +from tornado.log import app_log +from tornado.testing import ( + AsyncTestCase, + bind_unused_port, + ExpectLog, + gen_test, + setup_with_context_manager, +) +from tornado.test.util import ( + ignore_deprecation, + skipIfNonUnix, + skipOnTravis, +) +from tornado.concurrent import Future + +import typing + +if typing.TYPE_CHECKING: + from typing import List # noqa: F401 + + +class TestIOLoop(AsyncTestCase): + def test_add_callback_return_sequence(self): + # A callback returning {} or [] shouldn't spin the CPU, see Issue #1803. + self.calls = 0 + + loop = self.io_loop + test = self + old_add_callback = loop.add_callback + + def add_callback(self, callback, *args, **kwargs): + test.calls += 1 + old_add_callback(callback, *args, **kwargs) + + loop.add_callback = types.MethodType(add_callback, loop) # type: ignore + loop.add_callback(lambda: {}) # type: ignore + loop.add_callback(lambda: []) # type: ignore + loop.add_timeout(datetime.timedelta(milliseconds=50), loop.stop) + loop.start() + self.assertLess(self.calls, 10) + + @skipOnTravis + def test_add_callback_wakeup(self): + # Make sure that add_callback from inside a running IOLoop + # wakes up the IOLoop immediately instead of waiting for a timeout. 
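+ # For illustration, the pattern this test verifies, written as application
+ # code (do_work is an assumed blocking function):
+ def _thread_handoff_sketch(io_loop, do_work):
+     def worker():
+         result = do_work()
+         # add_callback is safe to call from other threads and wakes a loop
+         # that is blocked in its poll call.
+         io_loop.add_callback(app_log.info, "work finished: %r", result)
+     threading.Thread(target=worker).start()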
+ def callback():
+ self.called = True
+ self.stop()
+
+ def schedule_callback():
+ self.called = False
+ self.io_loop.add_callback(callback)
+ # Store away the time so we can check if we woke up immediately
+ self.start_time = time.time()
+
+ self.io_loop.add_timeout(self.io_loop.time(), schedule_callback)
+ self.wait()
+ self.assertAlmostEqual(time.time(), self.start_time, places=2)
+ self.assertTrue(self.called)
+
+ @skipOnTravis
+ def test_add_callback_wakeup_other_thread(self):
+ def target():
+ # sleep a bit to let the ioloop go into its poll loop
+ time.sleep(0.01)
+ self.stop_time = time.time()
+ self.io_loop.add_callback(self.stop)
+
+ thread = threading.Thread(target=target)
+ self.io_loop.add_callback(thread.start)
+ self.wait()
+ delta = time.time() - self.stop_time
+ self.assertLess(delta, 0.1)
+ thread.join()
+
+ def test_add_timeout_timedelta(self):
+ self.io_loop.add_timeout(datetime.timedelta(microseconds=1), self.stop)
+ self.wait()
+
+ def test_multiple_add(self):
+ sock, port = bind_unused_port()
+ try:
+ self.io_loop.add_handler(
+ sock.fileno(), lambda fd, events: None, IOLoop.READ
+ )
+ # Attempting to add the same handler twice fails
+ # (with a platform-dependent exception)
+ self.assertRaises(
+ Exception,
+ self.io_loop.add_handler,
+ sock.fileno(),
+ lambda fd, events: None,
+ IOLoop.READ,
+ )
+ finally:
+ self.io_loop.remove_handler(sock.fileno())
+ sock.close()
+
+ def test_remove_without_add(self):
+ # remove_handler should not throw an exception if called on an fd
+ # that was never added.
+ sock, port = bind_unused_port()
+ try:
+ self.io_loop.remove_handler(sock.fileno())
+ finally:
+ sock.close()
+
+ def test_add_callback_from_signal(self):
+ # cheat a little bit and just run this normally, since we can't
+ # easily simulate the races that happen with real signal handlers
+ self.io_loop.add_callback_from_signal(self.stop)
+ self.wait()
+
+ def test_add_callback_from_signal_other_thread(self):
+ # Very crude test, just to make sure that we cover this case.
+ # This also happens to be the first test where we run an IOLoop in
+ # a non-main thread.
+ other_ioloop = IOLoop()
+ thread = threading.Thread(target=other_ioloop.start)
+ thread.start()
+ other_ioloop.add_callback_from_signal(other_ioloop.stop)
+ thread.join()
+ other_ioloop.close()
+
+ def test_add_callback_while_closing(self):
+ # add_callback should not fail if it races with another thread
+ # closing the IOLoop. The callbacks are dropped silently
+ # without executing.
+ closing = threading.Event()
+
+ def target():
+ other_ioloop.add_callback(other_ioloop.stop)
+ other_ioloop.start()
+ closing.set()
+ other_ioloop.close(all_fds=True)
+
+ other_ioloop = IOLoop()
+ thread = threading.Thread(target=target)
+ thread.start()
+ closing.wait()
+ for i in range(1000):
+ other_ioloop.add_callback(lambda: None)
+
+ @skipIfNonUnix # just because socketpair is so convenient
+ def test_read_while_writeable(self):
+ # Ensure that write events don't come in while we're waiting for
+ # a read and haven't asked for writeability. (the reverse is
+ # difficult to test for)
+ client, server = socket.socketpair()
+ try:
+
+ def handler(fd, events):
+ self.assertEqual(events, IOLoop.READ)
+ self.stop()
+
+ self.io_loop.add_handler(client.fileno(), handler, IOLoop.READ)
+ self.io_loop.add_timeout(
+ self.io_loop.time() + 0.01, functools.partial(server.send, b"asdf")
+ )
+ self.wait()
+ self.io_loop.remove_handler(client.fileno())
+ finally:
+ client.close()
+ server.close()
+
+ def test_remove_timeout_after_fire(self):
+ # It is not an error to call remove_timeout after it has run.
+ handle = self.io_loop.add_timeout(self.io_loop.time(), self.stop)
+ self.wait()
+ self.io_loop.remove_timeout(handle)
+
+ def test_remove_timeout_cleanup(self):
+ # Add and remove enough callbacks to trigger cleanup.
+ # Not a very thorough test, but it ensures that the cleanup code
+ # gets executed and doesn't blow up. This test is only really useful
+ # on PollIOLoop subclasses, but it should run silently on any
+ # implementation.
+ for i in range(2000):
+ timeout = self.io_loop.add_timeout(self.io_loop.time() + 3600, lambda: None)
+ self.io_loop.remove_timeout(timeout)
+ # HACK: wait two IOLoop iterations for the GC to happen.
+ self.io_loop.add_callback(lambda: self.io_loop.add_callback(self.stop))
+ self.wait()
+
+ def test_remove_timeout_from_timeout(self):
+ calls = [False, False]
+
+ # Schedule several callbacks and wait for them all to come due at once.
+ # t2 should be cancelled by t1, even though it is already scheduled to
+ # be run before the ioloop even looks at it.
+ now = self.io_loop.time()
+
+ def t1():
+ calls[0] = True
+ self.io_loop.remove_timeout(t2_handle)
+
+ self.io_loop.add_timeout(now + 0.01, t1)
+
+ def t2():
+ calls[1] = True
+
+ t2_handle = self.io_loop.add_timeout(now + 0.02, t2)
+ self.io_loop.add_timeout(now + 0.03, self.stop)
+ time.sleep(0.03)
+ self.wait()
+ self.assertEqual(calls, [True, False])
+
+ def test_timeout_with_arguments(self):
+ # This tests that all the timeout methods pass through *args correctly.
+ results = [] # type: List[int]
+ self.io_loop.add_timeout(self.io_loop.time(), results.append, 1)
+ self.io_loop.add_timeout(datetime.timedelta(seconds=0), results.append, 2)
+ self.io_loop.call_at(self.io_loop.time(), results.append, 3)
+ self.io_loop.call_later(0, results.append, 4)
+ self.io_loop.call_later(0, self.stop)
+ self.wait()
+ # The asyncio event loop does not guarantee the order of these
+ # callbacks.
+ self.assertEqual(sorted(results), [1, 2, 3, 4])
+
+ def test_add_timeout_return(self):
+ # All the timeout methods return non-None handles that can be
+ # passed to remove_timeout.
+ handle = self.io_loop.add_timeout(self.io_loop.time(), lambda: None)
+ self.assertFalse(handle is None)
+ self.io_loop.remove_timeout(handle)
+
+ def test_call_at_return(self):
+ handle = self.io_loop.call_at(self.io_loop.time(), lambda: None)
+ self.assertFalse(handle is None)
+ self.io_loop.remove_timeout(handle)
+
+ def test_call_later_return(self):
+ handle = self.io_loop.call_later(0, lambda: None)
+ self.assertFalse(handle is None)
+ self.io_loop.remove_timeout(handle)
+
+ def test_close_file_object(self):
+ """When a file object is used instead of a numeric file descriptor,
+ the object should be closed (by IOLoop.close(all_fds=True)),
+ not just the fd.
+ """
+ # Use a socket since they are supported by IOLoop on all platforms.
+ # Unfortunately, sockets don't support the .closed attribute for
+ # inspecting their close status, so we must use a wrapper.
+ class SocketWrapper(object): + def __init__(self, sockobj): + self.sockobj = sockobj + self.closed = False + + def fileno(self): + return self.sockobj.fileno() + + def close(self): + self.closed = True + self.sockobj.close() + + sockobj, port = bind_unused_port() + socket_wrapper = SocketWrapper(sockobj) + io_loop = IOLoop() + io_loop.add_handler(socket_wrapper, lambda fd, events: None, IOLoop.READ) + io_loop.close(all_fds=True) + self.assertTrue(socket_wrapper.closed) + + def test_handler_callback_file_object(self): + """The handler callback receives the same fd object it passed in.""" + server_sock, port = bind_unused_port() + fds = [] + + def handle_connection(fd, events): + fds.append(fd) + conn, addr = server_sock.accept() + conn.close() + self.stop() + + self.io_loop.add_handler(server_sock, handle_connection, IOLoop.READ) + with contextlib.closing(socket.socket()) as client_sock: + client_sock.connect(("127.0.0.1", port)) + self.wait() + self.io_loop.remove_handler(server_sock) + self.io_loop.add_handler(server_sock.fileno(), handle_connection, IOLoop.READ) + with contextlib.closing(socket.socket()) as client_sock: + client_sock.connect(("127.0.0.1", port)) + self.wait() + self.assertIs(fds[0], server_sock) + self.assertEqual(fds[1], server_sock.fileno()) + self.io_loop.remove_handler(server_sock.fileno()) + server_sock.close() + + def test_mixed_fd_fileobj(self): + server_sock, port = bind_unused_port() + + def f(fd, events): + pass + + self.io_loop.add_handler(server_sock, f, IOLoop.READ) + with self.assertRaises(Exception): + # The exact error is unspecified - some implementations use + # IOError, others use ValueError. + self.io_loop.add_handler(server_sock.fileno(), f, IOLoop.READ) + self.io_loop.remove_handler(server_sock.fileno()) + server_sock.close() + + def test_reentrant(self): + """Calling start() twice should raise an error, not deadlock.""" + returned_from_start = [False] + got_exception = [False] + + def callback(): + try: + self.io_loop.start() + returned_from_start[0] = True + except Exception: + got_exception[0] = True + self.stop() + + self.io_loop.add_callback(callback) + self.wait() + self.assertTrue(got_exception[0]) + self.assertFalse(returned_from_start[0]) + + def test_exception_logging(self): + """Uncaught exceptions get logged by the IOLoop.""" + self.io_loop.add_callback(lambda: 1 / 0) + self.io_loop.add_callback(self.stop) + with ExpectLog(app_log, "Exception in callback"): + self.wait() + + def test_exception_logging_future(self): + """The IOLoop examines exceptions from Futures and logs them.""" + + @gen.coroutine + def callback(): + self.io_loop.add_callback(self.stop) + 1 / 0 + + self.io_loop.add_callback(callback) + with ExpectLog(app_log, "Exception in callback"): + self.wait() + + def test_exception_logging_native_coro(self): + """The IOLoop examines exceptions from awaitables and logs them.""" + + async def callback(): + # Stop the IOLoop two iterations after raising an exception + # to give the exception time to be logged. + self.io_loop.add_callback(self.io_loop.add_callback, self.stop) + 1 / 0 + + self.io_loop.add_callback(callback) + with ExpectLog(app_log, "Exception in callback"): + self.wait() + + def test_spawn_callback(self): + # Both add_callback and spawn_callback run directly on the IOLoop, + # so their errors are logged without stopping the test. 
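+ # For illustration: spawn_callback is the fire-and-forget entry point, e.g.
+ #     IOLoop.current().spawn_callback(background_task)
+ # (background_task is an assumed callable); its exceptions are reported via
+ # app_log exactly as exercised below.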
+ self.io_loop.add_callback(lambda: 1 / 0) + self.io_loop.add_callback(self.stop) + with ExpectLog(app_log, "Exception in callback"): + self.wait() + # A spawned callback is run directly on the IOLoop, so it will be + # logged without stopping the test. + self.io_loop.spawn_callback(lambda: 1 / 0) + self.io_loop.add_callback(self.stop) + with ExpectLog(app_log, "Exception in callback"): + self.wait() + + @skipIfNonUnix + def test_remove_handler_from_handler(self): + # Create two sockets with simultaneous read events. + client, server = socket.socketpair() + try: + client.send(b"abc") + server.send(b"abc") + + # After reading from one fd, remove the other from the IOLoop. + chunks = [] + + def handle_read(fd, events): + chunks.append(fd.recv(1024)) + if fd is client: + self.io_loop.remove_handler(server) + else: + self.io_loop.remove_handler(client) + + self.io_loop.add_handler(client, handle_read, self.io_loop.READ) + self.io_loop.add_handler(server, handle_read, self.io_loop.READ) + self.io_loop.call_later(0.1, self.stop) + self.wait() + + # Only one fd was read; the other was cleanly removed. + self.assertEqual(chunks, [b"abc"]) + finally: + client.close() + server.close() + + @skipIfNonUnix + @gen_test + def test_init_close_race(self): + # Regression test for #2367 + # + # Skipped on windows because of what looks like a bug in the + # proactor event loop when started and stopped on non-main + # threads. + def f(): + for i in range(10): + loop = IOLoop(make_current=False) + loop.close() + + yield gen.multi([self.io_loop.run_in_executor(None, f) for i in range(2)]) + + def test_explicit_asyncio_loop(self): + asyncio_loop = asyncio.new_event_loop() + loop = IOLoop(asyncio_loop=asyncio_loop, make_current=False) + assert loop.asyncio_loop is asyncio_loop # type: ignore + with self.assertRaises(RuntimeError): + # Can't register two IOLoops with the same asyncio_loop + IOLoop(asyncio_loop=asyncio_loop, make_current=False) + loop.close() + + +# Deliberately not a subclass of AsyncTestCase so the IOLoop isn't +# automatically set as current. +class TestIOLoopCurrent(unittest.TestCase): + def setUp(self): + setup_with_context_manager(self, ignore_deprecation()) + self.io_loop = None # type: typing.Optional[IOLoop] + IOLoop.clear_current() + + def tearDown(self): + if self.io_loop is not None: + self.io_loop.close() + + def test_non_current(self): + self.io_loop = IOLoop(make_current=False) + # The new IOLoop is not initially made current. + self.assertIsNone(IOLoop.current(instance=False)) + # Starting the IOLoop makes it current, and stopping the loop + # makes it non-current. This process is repeatable. + for i in range(3): + + def f(): + self.current_io_loop = IOLoop.current() + assert self.io_loop is not None + self.io_loop.stop() + + self.io_loop.add_callback(f) + self.io_loop.start() + self.assertIs(self.current_io_loop, self.io_loop) + # Now that the loop is stopped, it is no longer current. + self.assertIsNone(IOLoop.current(instance=False)) + + def test_force_current(self): + self.io_loop = IOLoop(make_current=True) + self.assertIs(self.io_loop, IOLoop.current()) + + +class TestIOLoopCurrentAsync(AsyncTestCase): + def setUp(self): + super().setUp() + setup_with_context_manager(self, ignore_deprecation()) + + @gen_test + def test_clear_without_current(self): + # If there is no current IOLoop, clear_current is a no-op (but + # should not fail). Use a thread so we see the threading.Local + # in a pristine state. 
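+ # For illustration, the run_in_executor pattern the next tests cover,
+ # written as application code (the hashing workload is an invented example):
+ async def _offload_sketch():
+     import hashlib
+     data = b"x" * (10 * 1024 * 1024)
+     # Runs on the default executor's worker thread; awaiting it keeps the
+     # IOLoop free to serve other callbacks.
+     digest = await IOLoop.current().run_in_executor(None, hashlib.sha256, data)
+     return digest.hexdigest()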
+ with ThreadPoolExecutor(1) as e:
+ yield e.submit(IOLoop.clear_current)
+
+
+class TestIOLoopFutures(AsyncTestCase):
+ def test_add_future_threads(self):
+ with futures.ThreadPoolExecutor(1) as pool:
+
+ def dummy():
+ pass
+
+ self.io_loop.add_future(
+ pool.submit(dummy), lambda future: self.stop(future)
+ )
+ future = self.wait()
+ self.assertTrue(future.done())
+ self.assertTrue(future.result() is None)
+
+ @gen_test
+ def test_run_in_executor_gen(self):
+ event1 = threading.Event()
+ event2 = threading.Event()
+
+ def sync_func(self_event, other_event):
+ self_event.set()
+ other_event.wait()
+ # Note that the return value doesn't actually do anything;
+ # it is just passed through to our final assertion to
+ # make sure it is passed through properly.
+ return self_event
+
+ # Run two synchronous functions, which would deadlock if not
+ # run in parallel.
+ res = yield [
+ IOLoop.current().run_in_executor(None, sync_func, event1, event2),
+ IOLoop.current().run_in_executor(None, sync_func, event2, event1),
+ ]
+
+ self.assertEqual([event1, event2], res)
+
+ @gen_test
+ def test_run_in_executor_native(self):
+ event1 = threading.Event()
+ event2 = threading.Event()
+
+ def sync_func(self_event, other_event):
+ self_event.set()
+ other_event.wait()
+ return self_event
+
+ # Go through an async wrapper to ensure that the result of
+ # run_in_executor works with await and not just gen.coroutine
+ # (simply passing the underlying concurrent future would do that).
+ async def async_wrapper(self_event, other_event):
+ return await IOLoop.current().run_in_executor(
+ None, sync_func, self_event, other_event
+ )
+
+ res = yield [async_wrapper(event1, event2), async_wrapper(event2, event1)]
+
+ self.assertEqual([event1, event2], res)
+
+ @gen_test
+ def test_set_default_executor(self):
+ count = [0]
+
+ class MyExecutor(futures.ThreadPoolExecutor):
+ def submit(self, func, *args):
+ count[0] += 1
+ return super().submit(func, *args)
+
+ event = threading.Event()
+
+ def sync_func():
+ event.set()
+
+ executor = MyExecutor(1)
+ loop = IOLoop.current()
+ loop.set_default_executor(executor)
+ yield loop.run_in_executor(None, sync_func)
+ self.assertEqual(1, count[0])
+ self.assertTrue(event.is_set())
+
+
+class TestIOLoopRunSync(unittest.TestCase):
+ def setUp(self):
+ self.io_loop = IOLoop(make_current=False)
+
+ def tearDown(self):
+ self.io_loop.close()
+
+ def test_sync_result(self):
+ with self.assertRaises(gen.BadYieldError):
+ self.io_loop.run_sync(lambda: 42)
+
+ def test_sync_exception(self):
+ with self.assertRaises(ZeroDivisionError):
+ self.io_loop.run_sync(lambda: 1 / 0)
+
+ def test_async_result(self):
+ @gen.coroutine
+ def f():
+ yield gen.moment
+ raise gen.Return(42)
+
+ self.assertEqual(self.io_loop.run_sync(f), 42)
+
+ def test_async_exception(self):
+ @gen.coroutine
+ def f():
+ yield gen.moment
+ 1 / 0
+
+ with self.assertRaises(ZeroDivisionError):
+ self.io_loop.run_sync(f)
+
+ def test_current(self):
+ def f():
+ self.assertIs(IOLoop.current(), self.io_loop)
+
+ self.io_loop.run_sync(f)
+
+ def test_timeout(self):
+ @gen.coroutine
+ def f():
+ yield gen.sleep(1)
+
+ self.assertRaises(TimeoutError, self.io_loop.run_sync, f, timeout=0.01)
+
+ def test_native_coroutine(self):
+ @gen.coroutine
+ def f1():
+ yield gen.moment
+
+ async def f2():
+ await f1()
+
+ self.io_loop.run_sync(f2)
+
+
+class TestPeriodicCallbackMath(unittest.TestCase):
+ def simulate_calls(self, pc, durations):
+ """Simulate a series of calls to the PeriodicCallback.
+ + Pass a list of call durations in seconds (negative values + work to simulate clock adjustments during the call, or more or + less equivalently, between calls). This method returns the + times at which each call would be made. + """ + calls = [] + now = 1000 + pc._next_timeout = now + for d in durations: + pc._update_next(now) + calls.append(pc._next_timeout) + now = pc._next_timeout + d + return calls + + def dummy(self): + pass + + def test_basic(self): + pc = PeriodicCallback(self.dummy, 10000) + self.assertEqual( + self.simulate_calls(pc, [0] * 5), [1010, 1020, 1030, 1040, 1050] + ) + + def test_overrun(self): + # If a call runs for too long, we skip entire cycles to get + # back on schedule. + call_durations = [9, 9, 10, 11, 20, 20, 35, 35, 0, 0, 0] + expected = [ + 1010, + 1020, + 1030, # first 3 calls on schedule + 1050, + 1070, # next 2 delayed one cycle + 1100, + 1130, # next 2 delayed 2 cycles + 1170, + 1210, # next 2 delayed 3 cycles + 1220, + 1230, # then back on schedule. + ] + + pc = PeriodicCallback(self.dummy, 10000) + self.assertEqual(self.simulate_calls(pc, call_durations), expected) + + def test_clock_backwards(self): + pc = PeriodicCallback(self.dummy, 10000) + # Backwards jumps are ignored, potentially resulting in a + # slightly slow schedule (although we assume that when + # time.time() and time.monotonic() are different, time.time() + # is getting adjusted by NTP and is therefore more accurate) + self.assertEqual( + self.simulate_calls(pc, [-2, -1, -3, -2, 0]), [1010, 1020, 1030, 1040, 1050] + ) + + # For big jumps, we should perhaps alter the schedule, but we + # don't currently. This trace shows that we run callbacks + # every 10s of time.time(), but the first and second calls are + # 110s of real time apart because the backwards jump is + # ignored. 
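+ # (Worked numbers for the trace below: the first call fires at t=1010 and
+ # the clock then jumps back 100s to 910; the schedule still targets 1020
+ # in time.time() terms, so the real-time wait is 1020 - 910 = 110s even
+ # though callback_time is 10s.)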
+ self.assertEqual(self.simulate_calls(pc, [-100, 0, 0]), [1010, 1020, 1030]) + + def test_jitter(self): + random_times = [0.5, 1, 0, 0.75] + expected = [1010, 1022.5, 1030, 1041.25] + call_durations = [0] * len(random_times) + pc = PeriodicCallback(self.dummy, 10000, jitter=0.5) + + def mock_random(): + return random_times.pop(0) + + with mock.patch("random.random", mock_random): + self.assertEqual(self.simulate_calls(pc, call_durations), expected) + + def test_timedelta(self): + pc = PeriodicCallback(lambda: None, datetime.timedelta(minutes=1, seconds=23)) + expected_callback_time = 83000 + self.assertEqual(pc.callback_time, expected_callback_time) + + +class TestPeriodicCallbackAsync(AsyncTestCase): + def test_periodic_plain(self): + count = 0 + + def callback() -> None: + nonlocal count + count += 1 + if count == 3: + self.stop() + + pc = PeriodicCallback(callback, 10) + pc.start() + self.wait() + pc.stop() + self.assertEqual(count, 3) + + def test_periodic_coro(self) -> None: + counts = [0, 0] + + @gen.coroutine + def callback() -> "Generator[Future[None], object, None]": + counts[0] += 1 + yield gen.sleep(0.025) + counts[1] += 1 + if counts[1] == 3: + pc.stop() + self.io_loop.add_callback(self.stop) + + pc = PeriodicCallback(callback, 10) + pc.start() + self.wait() + self.assertEqual(counts[0], 3) + self.assertEqual(counts[1], 3) + + def test_periodic_async(self) -> None: + counts = [0, 0] + + async def callback() -> None: + counts[0] += 1 + await gen.sleep(0.025) + counts[1] += 1 + if counts[1] == 3: + pc.stop() + self.io_loop.add_callback(self.stop) + + pc = PeriodicCallback(callback, 10) + pc.start() + self.wait() + self.assertEqual(counts[0], 3) + self.assertEqual(counts[1], 3) + + +class TestIOLoopConfiguration(unittest.TestCase): + def run_python(self, *statements): + stmt_list = [ + "from tornado.ioloop import IOLoop", + "classname = lambda x: x.__class__.__name__", + ] + list(statements) + args = [sys.executable, "-c", "; ".join(stmt_list)] + return native_str(subprocess.check_output(args)).strip() + + def test_default(self): + # When asyncio is available, it is used by default. 
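+ # For illustration, the same relationship checked in-process (a sketch;
+ # only valid where an asyncio loop is already running):
+ async def _loop_identity_sketch():
+     # IOLoop.current() inside a running asyncio loop wraps that very loop.
+     assert IOLoop.current().asyncio_loop is asyncio.get_running_loop()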
+ cls = self.run_python("print(classname(IOLoop.current()))") + self.assertEqual(cls, "AsyncIOMainLoop") + cls = self.run_python("print(classname(IOLoop()))") + self.assertEqual(cls, "AsyncIOLoop") + + def test_asyncio(self): + cls = self.run_python( + 'IOLoop.configure("tornado.platform.asyncio.AsyncIOLoop")', + "print(classname(IOLoop.current()))", + ) + self.assertEqual(cls, "AsyncIOMainLoop") + + def test_asyncio_main(self): + cls = self.run_python( + "from tornado.platform.asyncio import AsyncIOMainLoop", + "AsyncIOMainLoop().install()", + "print(classname(IOLoop.current()))", + ) + self.assertEqual(cls, "AsyncIOMainLoop") + + +if __name__ == "__main__": + unittest.main() diff --git a/venv/lib/python3.9/site-packages/tornado/test/iostream_test.py b/venv/lib/python3.9/site-packages/tornado/test/iostream_test.py new file mode 100644 index 00000000..e22e83e6 --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/test/iostream_test.py @@ -0,0 +1,1306 @@ +from tornado.concurrent import Future +from tornado import gen +from tornado import netutil +from tornado.ioloop import IOLoop +from tornado.iostream import ( + IOStream, + SSLIOStream, + PipeIOStream, + StreamClosedError, + _StreamBuffer, +) +from tornado.httputil import HTTPHeaders +from tornado.locks import Condition, Event +from tornado.log import gen_log +from tornado.netutil import ssl_options_to_context, ssl_wrap_socket +from tornado.platform.asyncio import AddThreadSelectorEventLoop +from tornado.tcpserver import TCPServer +from tornado.testing import ( + AsyncHTTPTestCase, + AsyncHTTPSTestCase, + AsyncTestCase, + bind_unused_port, + ExpectLog, + gen_test, +) +from tornado.test.util import ( + skipIfNonUnix, + refusing_port, + skipPypy3V58, + ignore_deprecation, +) +from tornado.web import RequestHandler, Application +import asyncio +import errno +import hashlib +import logging +import os +import platform +import random +import socket +import ssl +import typing +from unittest import mock +import unittest + + +def _server_ssl_options(): + return dict( + certfile=os.path.join(os.path.dirname(__file__), "test.crt"), + keyfile=os.path.join(os.path.dirname(__file__), "test.key"), + ) + + +class HelloHandler(RequestHandler): + def get(self): + self.write("Hello") + + +class TestIOStreamWebMixin(object): + def _make_client_iostream(self): + raise NotImplementedError() + + def get_app(self): + return Application([("/", HelloHandler)]) + + def test_connection_closed(self: typing.Any): + # When a server sends a response and then closes the connection, + # the client must be allowed to read the data before the IOStream + # closes itself. Epoll reports closed connections with a separate + # EPOLLRDHUP event delivered at the same time as the read event, + # while kqueue reports them as a second read/write event with an EOF + # flag. 
+ response = self.fetch("/", headers={"Connection": "close"}) + response.rethrow() + + @gen_test + def test_read_until_close(self: typing.Any): + stream = self._make_client_iostream() + yield stream.connect(("127.0.0.1", self.get_http_port())) + stream.write(b"GET / HTTP/1.0\r\n\r\n") + + data = yield stream.read_until_close() + self.assertTrue(data.startswith(b"HTTP/1.1 200")) + self.assertTrue(data.endswith(b"Hello")) + + @gen_test + def test_read_zero_bytes(self: typing.Any): + self.stream = self._make_client_iostream() + yield self.stream.connect(("127.0.0.1", self.get_http_port())) + self.stream.write(b"GET / HTTP/1.0\r\n\r\n") + + # normal read + data = yield self.stream.read_bytes(9) + self.assertEqual(data, b"HTTP/1.1 ") + + # zero bytes + data = yield self.stream.read_bytes(0) + self.assertEqual(data, b"") + + # another normal read + data = yield self.stream.read_bytes(3) + self.assertEqual(data, b"200") + + self.stream.close() + + @gen_test + def test_write_while_connecting(self: typing.Any): + stream = self._make_client_iostream() + connect_fut = stream.connect(("127.0.0.1", self.get_http_port())) + # unlike the previous tests, try to write before the connection + # is complete. + write_fut = stream.write(b"GET / HTTP/1.0\r\nConnection: close\r\n\r\n") + self.assertFalse(connect_fut.done()) + + # connect will always complete before write. + it = gen.WaitIterator(connect_fut, write_fut) + resolved_order = [] + while not it.done(): + yield it.next() + resolved_order.append(it.current_future) + self.assertEqual(resolved_order, [connect_fut, write_fut]) + + data = yield stream.read_until_close() + self.assertTrue(data.endswith(b"Hello")) + + stream.close() + + @gen_test + def test_future_interface(self: typing.Any): + """Basic test of IOStream's ability to return Futures.""" + stream = self._make_client_iostream() + connect_result = yield stream.connect(("127.0.0.1", self.get_http_port())) + self.assertIs(connect_result, stream) + yield stream.write(b"GET / HTTP/1.0\r\n\r\n") + first_line = yield stream.read_until(b"\r\n") + self.assertEqual(first_line, b"HTTP/1.1 200 OK\r\n") + # callback=None is equivalent to no callback. + header_data = yield stream.read_until(b"\r\n\r\n") + headers = HTTPHeaders.parse(header_data.decode("latin1")) + content_length = int(headers["Content-Length"]) + body = yield stream.read_bytes(content_length) + self.assertEqual(body, b"Hello") + stream.close() + + @gen_test + def test_future_close_while_reading(self: typing.Any): + stream = self._make_client_iostream() + yield stream.connect(("127.0.0.1", self.get_http_port())) + yield stream.write(b"GET / HTTP/1.0\r\n\r\n") + with self.assertRaises(StreamClosedError): + yield stream.read_bytes(1024 * 1024) + stream.close() + + @gen_test + def test_future_read_until_close(self: typing.Any): + # Ensure that the data comes through before the StreamClosedError. + stream = self._make_client_iostream() + yield stream.connect(("127.0.0.1", self.get_http_port())) + yield stream.write(b"GET / HTTP/1.0\r\nConnection: close\r\n\r\n") + yield stream.read_until(b"\r\n\r\n") + body = yield stream.read_until_close() + self.assertEqual(body, b"Hello") + + # Nothing else to read; the error comes immediately without waiting + # for yield. + with self.assertRaises(StreamClosedError): + stream.read_bytes(1) + + +class TestReadWriteMixin(object): + # Tests where one stream reads and the other writes. + # These should work for BaseIOStream implementations. 
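+    # Subclasses supply make_iostream_pair(); e.g. the socket-based
+    # TestIOStreamMixin accepts a connection on an unused port, while
+    # TestPipeIOStream (near the end of this file) wraps the two ends
+    # of os.pipe().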
+
+    def make_iostream_pair(self, **kwargs):
+        raise NotImplementedError
+
+    def iostream_pair(self, **kwargs):
+        """Like make_iostream_pair, but called by ``async with``.
+
+        In py37 this becomes simpler with contextlib.asynccontextmanager.
+        """
+
+        class IOStreamPairContext:
+            def __init__(self, test, kwargs):
+                self.test = test
+                self.kwargs = kwargs
+
+            async def __aenter__(self):
+                self.pair = await self.test.make_iostream_pair(**self.kwargs)
+                return self.pair
+
+            async def __aexit__(self, typ, value, tb):
+                for s in self.pair:
+                    s.close()
+
+        return IOStreamPairContext(self, kwargs)
+
+    @gen_test
+    def test_write_zero_bytes(self):
+        # Attempting to write zero bytes should run the callback without
+        # going into an infinite loop.
+        rs, ws = yield self.make_iostream_pair()
+        yield ws.write(b"")
+        ws.close()
+        rs.close()
+
+    @gen_test
+    def test_future_delayed_close_callback(self: typing.Any):
+        # Same as test_delayed_close_callback, but with the future interface.
+        rs, ws = yield self.make_iostream_pair()
+
+        try:
+            ws.write(b"12")
+            chunks = []
+            chunks.append((yield rs.read_bytes(1)))
+            ws.close()
+            chunks.append((yield rs.read_bytes(1)))
+            self.assertEqual(chunks, [b"1", b"2"])
+        finally:
+            ws.close()
+            rs.close()
+
+    @gen_test
+    def test_close_buffered_data(self: typing.Any):
+        # Similar to the previous test, but with data stored in the OS's
+        # socket buffers instead of the IOStream's read buffer. Out-of-band
+        # close notifications must be delayed until all data has been
+        # drained into the IOStream buffer. (epoll used to use out-of-band
+        # close events with EPOLLRDHUP, but no longer)
+        #
+        # This depends on the read_chunk_size being smaller than the
+        # OS socket buffer, so make it small.
+        rs, ws = yield self.make_iostream_pair(read_chunk_size=256)
+        try:
+            ws.write(b"A" * 512)
+            data = yield rs.read_bytes(256)
+            self.assertEqual(b"A" * 256, data)
+            ws.close()
+            # Allow the close to propagate to the `rs` side of the
+            # connection. Using add_callback instead of add_timeout
+            # doesn't seem to work, even with multiple iterations
+            yield gen.sleep(0.01)
+            data = yield rs.read_bytes(256)
+            self.assertEqual(b"A" * 256, data)
+        finally:
+            ws.close()
+            rs.close()
+
+    @gen_test
+    def test_read_until_close_after_close(self: typing.Any):
+        # Similar to test_delayed_close_callback, but read_until_close takes
+        # a separate code path so test it separately.
+        rs, ws = yield self.make_iostream_pair()
+        try:
+            ws.write(b"1234")
+            # Read one byte to make sure the client has received the data.
+            # It won't run the close callback as long as there is more buffered
+            # data that could satisfy a later read.
+            data = yield rs.read_bytes(1)
+            ws.close()
+            self.assertEqual(data, b"1")
+            data = yield rs.read_until_close()
+            self.assertEqual(data, b"234")
+        finally:
+            ws.close()
+            rs.close()
+
+    @gen_test
+    def test_large_read_until(self: typing.Any):
+        # Performance test: read_until used to have a quadratic component
+        # so a read_until of 4MB would take 8 seconds; now it takes 0.25
+        # seconds.
+        rs, ws = yield self.make_iostream_pair()
+        try:
+            # This test fails on pypy with ssl. I think it's because
+            # pypy's gc moves objects, breaking the
+            # "frozen write buffer" assumption.
+ if ( + isinstance(rs, SSLIOStream) + and platform.python_implementation() == "PyPy" + ): + raise unittest.SkipTest("pypy gc causes problems with openssl") + NUM_KB = 4096 + for i in range(NUM_KB): + ws.write(b"A" * 1024) + ws.write(b"\r\n") + data = yield rs.read_until(b"\r\n") + self.assertEqual(len(data), NUM_KB * 1024 + 2) + finally: + ws.close() + rs.close() + + @gen_test + async def test_read_until_with_close_after_second_packet(self): + # This is a regression test for a regression in Tornado 6.0 + # (maybe 6.0.3?) reported in + # https://github.com/tornadoweb/tornado/issues/2717 + # + # The data arrives in two chunks; the stream is closed at the + # same time that the second chunk is received. If the second + # chunk is larger than the first, it works, but when this bug + # existed it would fail if the second chunk were smaller than + # the first. This is due to the optimization that the + # read_until condition is only checked when the buffer doubles + # in size + async with self.iostream_pair() as (rs, ws): + rf = asyncio.ensure_future(rs.read_until(b"done")) + # We need to wait for the read_until to actually start. On + # windows that's tricky because the selector runs in + # another thread; sleeping is the simplest way. + await asyncio.sleep(0.1) + await ws.write(b"x" * 2048) + ws.write(b"done") + ws.close() + await rf + + @gen_test + async def test_read_until_unsatisfied_after_close(self: typing.Any): + # If a stream is closed while reading, it raises + # StreamClosedError instead of UnsatisfiableReadError (the + # latter should only be raised when byte limits are reached). + # The particular scenario tested here comes from #2717. + async with self.iostream_pair() as (rs, ws): + rf = asyncio.ensure_future(rs.read_until(b"done")) + await ws.write(b"x" * 2048) + ws.write(b"foo") + ws.close() + with self.assertRaises(StreamClosedError): + await rf + + @gen_test + def test_close_callback_with_pending_read(self: typing.Any): + # Regression test for a bug that was introduced in 2.3 + # where the IOStream._close_callback would never be called + # if there were pending reads. + OK = b"OK\r\n" + rs, ws = yield self.make_iostream_pair() + event = Event() + rs.set_close_callback(event.set) + try: + ws.write(OK) + res = yield rs.read_until(b"\r\n") + self.assertEqual(res, OK) + + ws.close() + rs.read_until(b"\r\n") + # If _close_callback (self.stop) is not called, + # an AssertionError: Async operation timed out after 5 seconds + # will be raised. + yield event.wait() + finally: + ws.close() + rs.close() + + @gen_test + def test_future_close_callback(self: typing.Any): + # Regression test for interaction between the Future read interfaces + # and IOStream._maybe_add_error_listener. 
+        rs, ws = yield self.make_iostream_pair()
+        closed = [False]
+        cond = Condition()
+
+        def close_callback():
+            closed[0] = True
+            cond.notify()
+
+        rs.set_close_callback(close_callback)
+        try:
+            ws.write(b"a")
+            res = yield rs.read_bytes(1)
+            self.assertEqual(res, b"a")
+            self.assertFalse(closed[0])
+            ws.close()
+            yield cond.wait()
+            self.assertTrue(closed[0])
+        finally:
+            rs.close()
+            ws.close()
+
+    @gen_test
+    def test_write_memoryview(self: typing.Any):
+        rs, ws = yield self.make_iostream_pair()
+        try:
+            fut = rs.read_bytes(4)
+            ws.write(memoryview(b"hello"))
+            data = yield fut
+            self.assertEqual(data, b"hell")
+        finally:
+            ws.close()
+            rs.close()
+
+    @gen_test
+    def test_read_bytes_partial(self: typing.Any):
+        rs, ws = yield self.make_iostream_pair()
+        try:
+            # Ask for more than is available with partial=True
+            fut = rs.read_bytes(50, partial=True)
+            ws.write(b"hello")
+            data = yield fut
+            self.assertEqual(data, b"hello")
+
+            # Ask for less than what is available; num_bytes is still
+            # respected.
+            fut = rs.read_bytes(3, partial=True)
+            ws.write(b"world")
+            data = yield fut
+            self.assertEqual(data, b"wor")
+
+            # Partial reads won't return an empty string, but read_bytes(0)
+            # will.
+            data = yield rs.read_bytes(0, partial=True)
+            self.assertEqual(data, b"")
+        finally:
+            ws.close()
+            rs.close()
+
+    @gen_test
+    def test_read_until_max_bytes(self: typing.Any):
+        rs, ws = yield self.make_iostream_pair()
+        closed = Event()
+        rs.set_close_callback(closed.set)
+        try:
+            # Extra room under the limit
+            fut = rs.read_until(b"def", max_bytes=50)
+            ws.write(b"abcdef")
+            data = yield fut
+            self.assertEqual(data, b"abcdef")
+
+            # Just enough space
+            fut = rs.read_until(b"def", max_bytes=6)
+            ws.write(b"abcdef")
+            data = yield fut
+            self.assertEqual(data, b"abcdef")
+
+            # Not enough space; by the time we find out, all we can do is
+            # log a warning and close the connection.
+            with ExpectLog(gen_log, "Unsatisfiable read", level=logging.INFO):
+                fut = rs.read_until(b"def", max_bytes=5)
+                ws.write(b"123456")
+                yield closed.wait()
+        finally:
+            ws.close()
+            rs.close()
+
+    @gen_test
+    def test_read_until_max_bytes_inline(self: typing.Any):
+        rs, ws = yield self.make_iostream_pair()
+        closed = Event()
+        rs.set_close_callback(closed.set)
+        try:
+            # Similar to the error case in the previous test, but the
+            # ws writes first so rs reads are satisfied
+            # inline. For consistency with the out-of-line case, we
+            # do not raise the error synchronously.
+            ws.write(b"123456")
+            with ExpectLog(gen_log, "Unsatisfiable read", level=logging.INFO):
+                with self.assertRaises(StreamClosedError):
+                    yield rs.read_until(b"def", max_bytes=5)
+                yield closed.wait()
+        finally:
+            ws.close()
+            rs.close()
+
+    @gen_test
+    def test_read_until_max_bytes_ignores_extra(self: typing.Any):
+        rs, ws = yield self.make_iostream_pair()
+        closed = Event()
+        rs.set_close_callback(closed.set)
+        try:
+            # Even though data that matches arrives in the same packet that
+            # puts us over the limit, we fail the request because it was not
+            # found within the limit.
+            ws.write(b"abcdef")
+            with ExpectLog(gen_log, "Unsatisfiable read", level=logging.INFO):
+                rs.read_until(b"def", max_bytes=5)
+                yield closed.wait()
+        finally:
+            ws.close()
+            rs.close()
+
+    @gen_test
+    def test_read_until_regex_max_bytes(self: typing.Any):
+        rs, ws = yield self.make_iostream_pair()
+        closed = Event()
+        rs.set_close_callback(closed.set)
+        try:
+            # Extra room under the limit
+            fut = rs.read_until_regex(b"def", max_bytes=50)
+            ws.write(b"abcdef")
+            data = yield fut
+            self.assertEqual(data, b"abcdef")
+
+            # Just enough space
+            fut = rs.read_until_regex(b"def", max_bytes=6)
+            ws.write(b"abcdef")
+            data = yield fut
+            self.assertEqual(data, b"abcdef")
+
+            # Not enough space; by the time we find out, all we can do is
+            # log a warning and close the connection.
+            with ExpectLog(gen_log, "Unsatisfiable read", level=logging.INFO):
+                rs.read_until_regex(b"def", max_bytes=5)
+                ws.write(b"123456")
+                yield closed.wait()
+        finally:
+            ws.close()
+            rs.close()
+
+    @gen_test
+    def test_read_until_regex_max_bytes_inline(self: typing.Any):
+        rs, ws = yield self.make_iostream_pair()
+        closed = Event()
+        rs.set_close_callback(closed.set)
+        try:
+            # Similar to the error case in the previous test, but the
+            # ws writes first so rs reads are satisfied
+            # inline. For consistency with the out-of-line case, we
+            # do not raise the error synchronously.
+            ws.write(b"123456")
+            with ExpectLog(gen_log, "Unsatisfiable read", level=logging.INFO):
+                rs.read_until_regex(b"def", max_bytes=5)
+                yield closed.wait()
+        finally:
+            ws.close()
+            rs.close()
+
+    @gen_test
+    def test_read_until_regex_max_bytes_ignores_extra(self):
+        rs, ws = yield self.make_iostream_pair()
+        closed = Event()
+        rs.set_close_callback(closed.set)
+        try:
+            # Even though data that matches arrives in the same packet that
+            # puts us over the limit, we fail the request because it was not
+            # found within the limit.
+            ws.write(b"abcdef")
+            with ExpectLog(gen_log, "Unsatisfiable read", level=logging.INFO):
+                rs.read_until_regex(b"def", max_bytes=5)
+                yield closed.wait()
+        finally:
+            ws.close()
+            rs.close()
+
+    @gen_test
+    def test_small_reads_from_large_buffer(self: typing.Any):
+        # 10KB buffer size, 100KB available to read.
+        # Read 1KB at a time and make sure that the buffer is not eagerly
+        # filled.
+        rs, ws = yield self.make_iostream_pair(max_buffer_size=10 * 1024)
+        try:
+            ws.write(b"a" * 1024 * 100)
+            for i in range(100):
+                data = yield rs.read_bytes(1024)
+                self.assertEqual(data, b"a" * 1024)
+        finally:
+            ws.close()
+            rs.close()
+
+    @gen_test
+    def test_small_read_untils_from_large_buffer(self: typing.Any):
+        # 10KB buffer size, 100KB available to read.
+        # Read 1KB at a time and make sure that the buffer is not eagerly
+        # filled.
+        rs, ws = yield self.make_iostream_pair(max_buffer_size=10 * 1024)
+        try:
+            ws.write((b"a" * 1023 + b"\n") * 100)
+            for i in range(100):
+                data = yield rs.read_until(b"\n", max_bytes=4096)
+                self.assertEqual(data, b"a" * 1023 + b"\n")
+        finally:
+            ws.close()
+            rs.close()
+
+    @gen_test
+    def test_flow_control(self):
+        MB = 1024 * 1024
+        rs, ws = yield self.make_iostream_pair(max_buffer_size=5 * MB)
+        try:
+            # Client writes more than the rs will accept.
+            ws.write(b"a" * 10 * MB)
+            # The rs pauses while reading.
+            yield rs.read_bytes(MB)
+            yield gen.sleep(0.1)
+            # The ws's writes have been blocked; the rs can
+            # continue to read gradually.
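+            # (10MB was written against a 5MB max_buffer_size; each 1MB
+            # read below frees buffer space, unblocking the writer.)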
+ for i in range(9): + yield rs.read_bytes(MB) + finally: + rs.close() + ws.close() + + @gen_test + def test_read_into(self: typing.Any): + rs, ws = yield self.make_iostream_pair() + + def sleep_some(): + self.io_loop.run_sync(lambda: gen.sleep(0.05)) + + try: + buf = bytearray(10) + fut = rs.read_into(buf) + ws.write(b"hello") + yield gen.sleep(0.05) + self.assertTrue(rs.reading()) + ws.write(b"world!!") + data = yield fut + self.assertFalse(rs.reading()) + self.assertEqual(data, 10) + self.assertEqual(bytes(buf), b"helloworld") + + # Existing buffer is fed into user buffer + fut = rs.read_into(buf) + yield gen.sleep(0.05) + self.assertTrue(rs.reading()) + ws.write(b"1234567890") + data = yield fut + self.assertFalse(rs.reading()) + self.assertEqual(data, 10) + self.assertEqual(bytes(buf), b"!!12345678") + + # Existing buffer can satisfy read immediately + buf = bytearray(4) + ws.write(b"abcdefghi") + data = yield rs.read_into(buf) + self.assertEqual(data, 4) + self.assertEqual(bytes(buf), b"90ab") + + data = yield rs.read_bytes(7) + self.assertEqual(data, b"cdefghi") + finally: + ws.close() + rs.close() + + @gen_test + def test_read_into_partial(self: typing.Any): + rs, ws = yield self.make_iostream_pair() + + try: + # Partial read + buf = bytearray(10) + fut = rs.read_into(buf, partial=True) + ws.write(b"hello") + data = yield fut + self.assertFalse(rs.reading()) + self.assertEqual(data, 5) + self.assertEqual(bytes(buf), b"hello\0\0\0\0\0") + + # Full read despite partial=True + ws.write(b"world!1234567890") + data = yield rs.read_into(buf, partial=True) + self.assertEqual(data, 10) + self.assertEqual(bytes(buf), b"world!1234") + + # Existing buffer can satisfy read immediately + data = yield rs.read_into(buf, partial=True) + self.assertEqual(data, 6) + self.assertEqual(bytes(buf), b"5678901234") + + finally: + ws.close() + rs.close() + + @gen_test + def test_read_into_zero_bytes(self: typing.Any): + rs, ws = yield self.make_iostream_pair() + try: + buf = bytearray() + fut = rs.read_into(buf) + self.assertEqual(fut.result(), 0) + finally: + ws.close() + rs.close() + + @gen_test + def test_many_mixed_reads(self): + # Stress buffer handling when going back and forth between + # read_bytes() (using an internal buffer) and read_into() + # (using a user-allocated buffer). 
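+        # A deterministic seed keeps the read/write interleaving
+        # reproducible, while the two SHA-1 digests verify that no bytes
+        # were lost or reordered in transit.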
+        r = random.Random(42)
+        nbytes = 1000000
+        rs, ws = yield self.make_iostream_pair()
+
+        produce_hash = hashlib.sha1()
+        consume_hash = hashlib.sha1()
+
+        @gen.coroutine
+        def produce():
+            remaining = nbytes
+            while remaining > 0:
+                size = r.randint(1, min(1000, remaining))
+                data = os.urandom(size)
+                produce_hash.update(data)
+                yield ws.write(data)
+                remaining -= size
+            assert remaining == 0
+
+        @gen.coroutine
+        def consume():
+            remaining = nbytes
+            while remaining > 0:
+                if r.random() > 0.5:
+                    # read_bytes()
+                    size = r.randint(1, min(1000, remaining))
+                    data = yield rs.read_bytes(size)
+                    consume_hash.update(data)
+                    remaining -= size
+                else:
+                    # read_into()
+                    size = r.randint(1, min(1000, remaining))
+                    buf = bytearray(size)
+                    n = yield rs.read_into(buf)
+                    assert n == size
+                    consume_hash.update(buf)
+                    remaining -= size
+            assert remaining == 0
+
+        try:
+            yield [produce(), consume()]
+            assert produce_hash.hexdigest() == consume_hash.hexdigest()
+        finally:
+            ws.close()
+            rs.close()
+
+
+class TestIOStreamMixin(TestReadWriteMixin):
+    def _make_server_iostream(self, connection, **kwargs):
+        raise NotImplementedError()
+
+    def _make_client_iostream(self, connection, **kwargs):
+        raise NotImplementedError()
+
+    @gen.coroutine
+    def make_iostream_pair(self: typing.Any, **kwargs):
+        listener, port = bind_unused_port()
+        server_stream_fut = Future()  # type: Future[IOStream]
+
+        def accept_callback(connection, address):
+            server_stream_fut.set_result(
+                self._make_server_iostream(connection, **kwargs)
+            )
+
+        netutil.add_accept_handler(listener, accept_callback)
+        client_stream = self._make_client_iostream(socket.socket(), **kwargs)
+        connect_fut = client_stream.connect(("127.0.0.1", port))
+        server_stream, client_stream = yield [server_stream_fut, connect_fut]
+        self.io_loop.remove_handler(listener.fileno())
+        listener.close()
+        raise gen.Return((server_stream, client_stream))
+
+    @gen_test
+    def test_connection_refused(self: typing.Any):
+        # When a connection is refused, the connect callback should not
+        # be run. (The kqueue IOLoop used to behave differently from the
+        # epoll IOLoop in this respect)
+        cleanup_func, port = refusing_port()
+        self.addCleanup(cleanup_func)
+        stream = IOStream(socket.socket())
+
+        stream.set_close_callback(self.stop)
+        # log messages vary by platform and ioloop implementation
+        with ExpectLog(gen_log, ".*", required=False):
+            with self.assertRaises(StreamClosedError):
+                yield stream.connect(("127.0.0.1", port))
+
+        self.assertTrue(isinstance(stream.error, ConnectionRefusedError), stream.error)
+
+    @gen_test
+    def test_gaierror(self: typing.Any):
+        # Test that IOStream sets its exc_info on getaddrinfo error.
+        # It's difficult to reliably trigger a getaddrinfo error;
+        # some resolvers won't even return errors for malformed names,
+        # so we mock it instead. If IOStream changes to call a Resolver
+        # before sock.connect, the mock target will need to change too.
+ s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) + stream = IOStream(s) + stream.set_close_callback(self.stop) + with mock.patch( + "socket.socket.connect", side_effect=socket.gaierror(errno.EIO, "boom") + ): + with self.assertRaises(StreamClosedError): + yield stream.connect(("localhost", 80)) + self.assertTrue(isinstance(stream.error, socket.gaierror)) + + @gen_test + def test_read_until_close_with_error(self: typing.Any): + server, client = yield self.make_iostream_pair() + try: + with mock.patch( + "tornado.iostream.BaseIOStream._try_inline_read", + side_effect=IOError("boom"), + ): + with self.assertRaisesRegex(IOError, "boom"): + client.read_until_close() + finally: + server.close() + client.close() + + @skipIfNonUnix + @skipPypy3V58 + @gen_test + def test_inline_read_error(self: typing.Any): + # An error on an inline read is raised without logging (on the + # assumption that it will eventually be noticed or logged further + # up the stack). + # + # This test is posix-only because windows os.close() doesn't work + # on socket FDs, but we can't close the socket object normally + # because we won't get the error we want if the socket knows + # it's closed. + # + # This test is also disabled when the + # AddThreadSelectorEventLoop is used, because a race between + # this thread closing the socket and the selector thread + # calling the select system call can make this test flaky. + # This event loop implementation is normally only used on + # windows, making this check redundant with skipIfNonUnix, but + # we sometimes enable it on other platforms for testing. + io_loop = IOLoop.current() + if isinstance( + io_loop.selector_loop, # type: ignore[attr-defined] + AddThreadSelectorEventLoop, + ): + self.skipTest("AddThreadSelectorEventLoop not supported") + server, client = yield self.make_iostream_pair() + try: + os.close(server.socket.fileno()) + with self.assertRaises(socket.error): + server.read_bytes(1) + finally: + server.close() + client.close() + + @skipPypy3V58 + @gen_test + def test_async_read_error_logging(self): + # Socket errors on asynchronous reads should be logged (but only + # once). + server, client = yield self.make_iostream_pair() + closed = Event() + server.set_close_callback(closed.set) + try: + # Start a read that will be fulfilled asynchronously. + server.read_bytes(1) + client.write(b"a") + # Stub out read_from_fd to make it fail. + + def fake_read_from_fd(): + os.close(server.socket.fileno()) + server.__class__.read_from_fd(server) + + server.read_from_fd = fake_read_from_fd + # This log message is from _handle_read (not read_from_fd). + with ExpectLog(gen_log, "error on read"): + yield closed.wait() + finally: + server.close() + client.close() + + @gen_test + def test_future_write(self): + """ + Test that write() Futures are never orphaned. + """ + # Run concurrent writers that will write enough bytes so as to + # clog the socket buffer and accumulate bytes in our write buffer. 
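+        # (10 producers x 1000 writes x 5000 bytes comes to 50MB in total.)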
+ m, n = 5000, 1000 + nproducers = 10 + total_bytes = m * n * nproducers + server, client = yield self.make_iostream_pair(max_buffer_size=total_bytes) + + @gen.coroutine + def produce(): + data = b"x" * m + for i in range(n): + yield server.write(data) + + @gen.coroutine + def consume(): + nread = 0 + while nread < total_bytes: + res = yield client.read_bytes(m) + nread += len(res) + + try: + yield [produce() for i in range(nproducers)] + [consume()] + finally: + server.close() + client.close() + + +class TestIOStreamWebHTTP(TestIOStreamWebMixin, AsyncHTTPTestCase): + def _make_client_iostream(self): + return IOStream(socket.socket()) + + +class TestIOStreamWebHTTPS(TestIOStreamWebMixin, AsyncHTTPSTestCase): + def _make_client_iostream(self): + return SSLIOStream(socket.socket(), ssl_options=dict(cert_reqs=ssl.CERT_NONE)) + + +class TestIOStream(TestIOStreamMixin, AsyncTestCase): + def _make_server_iostream(self, connection, **kwargs): + return IOStream(connection, **kwargs) + + def _make_client_iostream(self, connection, **kwargs): + return IOStream(connection, **kwargs) + + +class TestIOStreamSSL(TestIOStreamMixin, AsyncTestCase): + def _make_server_iostream(self, connection, **kwargs): + ssl_ctx = ssl_options_to_context(_server_ssl_options(), server_side=True) + connection = ssl_ctx.wrap_socket( + connection, + server_side=True, + do_handshake_on_connect=False, + ) + return SSLIOStream(connection, **kwargs) + + def _make_client_iostream(self, connection, **kwargs): + return SSLIOStream( + connection, ssl_options=dict(cert_reqs=ssl.CERT_NONE), **kwargs + ) + + +# This will run some tests that are basically redundant but it's the +# simplest way to make sure that it works to pass an SSLContext +# instead of an ssl_options dict to the SSLIOStream constructor. 
+class TestIOStreamSSLContext(TestIOStreamMixin, AsyncTestCase): + def _make_server_iostream(self, connection, **kwargs): + context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + context.load_cert_chain( + os.path.join(os.path.dirname(__file__), "test.crt"), + os.path.join(os.path.dirname(__file__), "test.key"), + ) + connection = ssl_wrap_socket( + connection, context, server_side=True, do_handshake_on_connect=False + ) + return SSLIOStream(connection, **kwargs) + + def _make_client_iostream(self, connection, **kwargs): + context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE + return SSLIOStream(connection, ssl_options=context, **kwargs) + + +class TestIOStreamStartTLS(AsyncTestCase): + def setUp(self): + try: + super().setUp() + self.listener, self.port = bind_unused_port() + self.server_stream = None + self.server_accepted = Future() # type: Future[None] + netutil.add_accept_handler(self.listener, self.accept) + self.client_stream = IOStream( + socket.socket() + ) # type: typing.Optional[IOStream] + self.io_loop.add_future( + self.client_stream.connect(("127.0.0.1", self.port)), self.stop + ) + self.wait() + self.io_loop.add_future(self.server_accepted, self.stop) + self.wait() + except Exception as e: + print(e) + raise + + def tearDown(self): + if self.server_stream is not None: + self.server_stream.close() + if self.client_stream is not None: + self.client_stream.close() + self.io_loop.remove_handler(self.listener.fileno()) + self.listener.close() + super().tearDown() + + def accept(self, connection, address): + if self.server_stream is not None: + self.fail("should only get one connection") + self.server_stream = IOStream(connection) + self.server_accepted.set_result(None) + + @gen.coroutine + def client_send_line(self, line): + assert self.client_stream is not None + self.client_stream.write(line) + assert self.server_stream is not None + recv_line = yield self.server_stream.read_until(b"\r\n") + self.assertEqual(line, recv_line) + + @gen.coroutine + def server_send_line(self, line): + assert self.server_stream is not None + self.server_stream.write(line) + assert self.client_stream is not None + recv_line = yield self.client_stream.read_until(b"\r\n") + self.assertEqual(line, recv_line) + + def client_start_tls(self, ssl_options=None, server_hostname=None): + assert self.client_stream is not None + client_stream = self.client_stream + self.client_stream = None + return client_stream.start_tls(False, ssl_options, server_hostname) + + def server_start_tls(self, ssl_options=None): + assert self.server_stream is not None + server_stream = self.server_stream + self.server_stream = None + return server_stream.start_tls(True, ssl_options) + + @gen_test + def test_start_tls_smtp(self): + # This flow is simplified from RFC 3207 section 5. + # We don't really need all of this, but it helps to make sure + # that after realistic back-and-forth traffic the buffers end up + # in a sane state. 
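+        # The exchange below follows that flow: plaintext greeting and EHLO,
+        # STARTTLS negotiation, a TLS handshake on both streams, and finally
+        # a second EHLO over the encrypted connection.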
+        yield self.server_send_line(b"220 mail.example.com ready\r\n")
+        yield self.client_send_line(b"EHLO mail.example.com\r\n")
+        yield self.server_send_line(b"250-mail.example.com welcome\r\n")
+        yield self.server_send_line(b"250 STARTTLS\r\n")
+        yield self.client_send_line(b"STARTTLS\r\n")
+        yield self.server_send_line(b"220 Go ahead\r\n")
+        client_future = self.client_start_tls(dict(cert_reqs=ssl.CERT_NONE))
+        server_future = self.server_start_tls(_server_ssl_options())
+        self.client_stream = yield client_future
+        self.server_stream = yield server_future
+        self.assertTrue(isinstance(self.client_stream, SSLIOStream))
+        self.assertTrue(isinstance(self.server_stream, SSLIOStream))
+        yield self.client_send_line(b"EHLO mail.example.com\r\n")
+        yield self.server_send_line(b"250 mail.example.com welcome\r\n")
+
+    @gen_test
+    def test_handshake_fail(self):
+        server_future = self.server_start_tls(_server_ssl_options())
+        # Certificates are verified with the default configuration.
+        with ExpectLog(gen_log, "SSL Error"):
+            client_future = self.client_start_tls(server_hostname="localhost")
+            with self.assertRaises(ssl.SSLError):
+                yield client_future
+            with self.assertRaises((ssl.SSLError, socket.error)):
+                yield server_future
+
+    @gen_test
+    def test_check_hostname(self):
+        # Test that the server_hostname parameter to start_tls is being used.
+        # The check_hostname functionality is only available in python 2.7
+        # and up and in python 3.4 and up.
+        server_future = self.server_start_tls(_server_ssl_options())
+        with ExpectLog(gen_log, "SSL Error"):
+            client_future = self.client_start_tls(
+                ssl.create_default_context(), server_hostname="127.0.0.1"
+            )
+            with self.assertRaises(ssl.SSLError):
+                # The client fails to connect with an SSL error.
+                yield client_future
+            with self.assertRaises(Exception):
+                # The server fails to connect, but the exact error is unspecified.
+                yield server_future
+
+    @gen_test
+    def test_typed_memoryview(self):
+        # Test support of memoryviews with an item size greater than 1 byte.
+        buf = memoryview(bytes(80)).cast("L")
+        assert self.server_stream is not None
+        yield self.server_stream.write(buf)
+        assert self.client_stream is not None
+        # This will time out if the calculation of the buffer size is incorrect
+        recv = yield self.client_stream.read_bytes(buf.nbytes)
+        self.assertEqual(bytes(recv), bytes(buf))
+
+
+class WaitForHandshakeTest(AsyncTestCase):
+    @gen.coroutine
+    def connect_to_server(self, server_cls):
+        server = client = None
+        try:
+            sock, port = bind_unused_port()
+            server = server_cls(ssl_options=_server_ssl_options())
+            server.add_socket(sock)
+
+            ssl_ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
+            ssl_ctx.check_hostname = False
+            ssl_ctx.verify_mode = ssl.CERT_NONE
+            # These tests fail with ConnectionAbortedErrors with TLS
+            # 1.3 on windows python 3.7.4 (which includes an upgrade
+            # to openssl 1.1.1c; other platforms might be affected with
+            # newer openssl too). Disable it until we figure out
+            # what's up.
+            # Update 2021-12-28: Still happening with Python 3.10 on
+            # Windows. OP_NO_TLSv1_3 now raises a DeprecationWarning.
+ with ignore_deprecation(): + ssl_ctx.options |= getattr(ssl, "OP_NO_TLSv1_3", 0) + client = SSLIOStream(socket.socket(), ssl_options=ssl_ctx) + yield client.connect(("127.0.0.1", port)) + self.assertIsNotNone(client.socket.cipher()) + finally: + if server is not None: + server.stop() + if client is not None: + client.close() + + @gen_test + def test_wait_for_handshake_future(self): + test = self + handshake_future = Future() # type: Future[None] + + class TestServer(TCPServer): + def handle_stream(self, stream, address): + test.assertIsNone(stream.socket.cipher()) + test.io_loop.spawn_callback(self.handle_connection, stream) + + @gen.coroutine + def handle_connection(self, stream): + yield stream.wait_for_handshake() + handshake_future.set_result(None) + + yield self.connect_to_server(TestServer) + yield handshake_future + + @gen_test + def test_wait_for_handshake_already_waiting_error(self): + test = self + handshake_future = Future() # type: Future[None] + + class TestServer(TCPServer): + @gen.coroutine + def handle_stream(self, stream, address): + fut = stream.wait_for_handshake() + test.assertRaises(RuntimeError, stream.wait_for_handshake) + yield fut + + handshake_future.set_result(None) + + yield self.connect_to_server(TestServer) + yield handshake_future + + @gen_test + def test_wait_for_handshake_already_connected(self): + handshake_future = Future() # type: Future[None] + + class TestServer(TCPServer): + @gen.coroutine + def handle_stream(self, stream, address): + yield stream.wait_for_handshake() + yield stream.wait_for_handshake() + handshake_future.set_result(None) + + yield self.connect_to_server(TestServer) + yield handshake_future + + +@skipIfNonUnix +class TestPipeIOStream(TestReadWriteMixin, AsyncTestCase): + @gen.coroutine + def make_iostream_pair(self, **kwargs): + r, w = os.pipe() + + return PipeIOStream(r, **kwargs), PipeIOStream(w, **kwargs) + + @gen_test + def test_pipe_iostream(self): + rs, ws = yield self.make_iostream_pair() + + ws.write(b"hel") + ws.write(b"lo world") + + data = yield rs.read_until(b" ") + self.assertEqual(data, b"hello ") + + data = yield rs.read_bytes(3) + self.assertEqual(data, b"wor") + + ws.close() + + data = yield rs.read_until_close() + self.assertEqual(data, b"ld") + + rs.close() + + @gen_test + def test_pipe_iostream_big_write(self): + rs, ws = yield self.make_iostream_pair() + + NUM_BYTES = 1048576 + + # Write 1MB of data, which should fill the buffer + ws.write(b"1" * NUM_BYTES) + + data = yield rs.read_bytes(NUM_BYTES) + self.assertEqual(data, b"1" * NUM_BYTES) + + ws.close() + rs.close() + + +class TestStreamBuffer(unittest.TestCase): + """ + Unit tests for the private _StreamBuffer class. 
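+
+    As exercised below, the interface is append(data) to queue a chunk,
+    peek(size) to view up to ``size`` bytes without consuming them,
+    advance(n) to discard n bytes, and len() for the bytes remaining.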
+ """ + + def setUp(self): + self.random = random.Random(42) + + def to_bytes(self, b): + if isinstance(b, (bytes, bytearray)): + return bytes(b) + elif isinstance(b, memoryview): + return b.tobytes() # For py2 + else: + raise TypeError(b) + + def make_streambuffer(self, large_buf_threshold=10): + buf = _StreamBuffer() + assert buf._large_buf_threshold + buf._large_buf_threshold = large_buf_threshold + return buf + + def check_peek(self, buf, expected): + size = 1 + while size < 2 * len(expected): + got = self.to_bytes(buf.peek(size)) + self.assertTrue(got) # Not empty + self.assertLessEqual(len(got), size) + self.assertTrue(expected.startswith(got), (expected, got)) + size = (size * 3 + 1) // 2 + + def check_append_all_then_skip_all(self, buf, objs, input_type): + self.assertEqual(len(buf), 0) + + expected = b"" + + for o in objs: + expected += o + buf.append(input_type(o)) + self.assertEqual(len(buf), len(expected)) + self.check_peek(buf, expected) + + while expected: + n = self.random.randrange(1, len(expected) + 1) + expected = expected[n:] + buf.advance(n) + self.assertEqual(len(buf), len(expected)) + self.check_peek(buf, expected) + + self.assertEqual(len(buf), 0) + + def test_small(self): + objs = [b"12", b"345", b"67", b"89a", b"bcde", b"fgh", b"ijklmn"] + + buf = self.make_streambuffer() + self.check_append_all_then_skip_all(buf, objs, bytes) + + buf = self.make_streambuffer() + self.check_append_all_then_skip_all(buf, objs, bytearray) + + buf = self.make_streambuffer() + self.check_append_all_then_skip_all(buf, objs, memoryview) + + # Test internal algorithm + buf = self.make_streambuffer(10) + for i in range(9): + buf.append(b"x") + self.assertEqual(len(buf._buffers), 1) + for i in range(9): + buf.append(b"x") + self.assertEqual(len(buf._buffers), 2) + buf.advance(10) + self.assertEqual(len(buf._buffers), 1) + buf.advance(8) + self.assertEqual(len(buf._buffers), 0) + self.assertEqual(len(buf), 0) + + def test_large(self): + objs = [ + b"12" * 5, + b"345" * 2, + b"67" * 20, + b"89a" * 12, + b"bcde" * 1, + b"fgh" * 7, + b"ijklmn" * 2, + ] + + buf = self.make_streambuffer() + self.check_append_all_then_skip_all(buf, objs, bytes) + + buf = self.make_streambuffer() + self.check_append_all_then_skip_all(buf, objs, bytearray) + + buf = self.make_streambuffer() + self.check_append_all_then_skip_all(buf, objs, memoryview) + + # Test internal algorithm + buf = self.make_streambuffer(10) + for i in range(3): + buf.append(b"x" * 11) + self.assertEqual(len(buf._buffers), 3) + buf.append(b"y") + self.assertEqual(len(buf._buffers), 4) + buf.append(b"z") + self.assertEqual(len(buf._buffers), 4) + buf.advance(33) + self.assertEqual(len(buf._buffers), 1) + buf.advance(2) + self.assertEqual(len(buf._buffers), 0) + self.assertEqual(len(buf), 0) diff --git a/venv/lib/python3.9/site-packages/tornado/test/locale_test.py b/venv/lib/python3.9/site-packages/tornado/test/locale_test.py new file mode 100644 index 00000000..ee74cb05 --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/test/locale_test.py @@ -0,0 +1,149 @@ +import datetime +import os +import shutil +import tempfile +import unittest + +import tornado.locale +from tornado.escape import utf8, to_unicode +from tornado.util import unicode_type + + +class TranslationLoaderTest(unittest.TestCase): + # TODO: less hacky way to get isolated tests + SAVE_VARS = ["_translations", "_supported_locales", "_use_gettext"] + + def clear_locale_cache(self): + tornado.locale.Locale._cache = {} + + def setUp(self): + self.saved = {} # type: dict + for 
var in TranslationLoaderTest.SAVE_VARS:
+            self.saved[var] = getattr(tornado.locale, var)
+        self.clear_locale_cache()
+
+    def tearDown(self):
+        for k, v in self.saved.items():
+            setattr(tornado.locale, k, v)
+        self.clear_locale_cache()
+
+    def test_csv(self):
+        tornado.locale.load_translations(
+            os.path.join(os.path.dirname(__file__), "csv_translations")
+        )
+        locale = tornado.locale.get("fr_FR")
+        self.assertTrue(isinstance(locale, tornado.locale.CSVLocale))
+        self.assertEqual(locale.translate("school"), "\u00e9cole")
+
+    def test_csv_bom(self):
+        with open(
+            os.path.join(os.path.dirname(__file__), "csv_translations", "fr_FR.csv"),
+            "rb",
+        ) as f:
+            char_data = to_unicode(f.read())
+        # Re-encode our input data (which is utf-8 without BOM) in
+        # encodings that use the BOM and ensure that we can still load
+        # it. Note that utf-16-le and utf-16-be do not write a BOM,
+        # so we only test whichever variant is native to our platform.
+        for encoding in ["utf-8-sig", "utf-16"]:
+            tmpdir = tempfile.mkdtemp()
+            try:
+                with open(os.path.join(tmpdir, "fr_FR.csv"), "wb") as f:
+                    f.write(char_data.encode(encoding))
+                tornado.locale.load_translations(tmpdir)
+                locale = tornado.locale.get("fr_FR")
+                self.assertIsInstance(locale, tornado.locale.CSVLocale)
+                self.assertEqual(locale.translate("school"), "\u00e9cole")
+            finally:
+                shutil.rmtree(tmpdir)
+
+    def test_gettext(self):
+        tornado.locale.load_gettext_translations(
+            os.path.join(os.path.dirname(__file__), "gettext_translations"),
+            "tornado_test",
+        )
+        locale = tornado.locale.get("fr_FR")
+        self.assertTrue(isinstance(locale, tornado.locale.GettextLocale))
+        self.assertEqual(locale.translate("school"), "\u00e9cole")
+        self.assertEqual(locale.pgettext("law", "right"), "le droit")
+        self.assertEqual(locale.pgettext("good", "right"), "le bien")
+        self.assertEqual(locale.pgettext("organization", "club", "clubs", 1), "le club")
+        self.assertEqual(
+            locale.pgettext("organization", "club", "clubs", 2), "les clubs"
+        )
+        self.assertEqual(locale.pgettext("stick", "club", "clubs", 1), "le b\xe2ton")
+        self.assertEqual(locale.pgettext("stick", "club", "clubs", 2), "les b\xe2tons")
+
+
+class LocaleDataTest(unittest.TestCase):
+    def test_non_ascii_name(self):
+        name = tornado.locale.LOCALE_NAMES["es_LA"]["name"]
+        self.assertTrue(isinstance(name, unicode_type))
+        self.assertEqual(name, "Espa\u00f1ol")
+        self.assertEqual(utf8(name), b"Espa\xc3\xb1ol")
+
+
+class EnglishTest(unittest.TestCase):
+    def test_format_date(self):
+        locale = tornado.locale.get("en_US")
+        date = datetime.datetime(2013, 4, 28, 18, 35)
+        self.assertEqual(
+            locale.format_date(date, full_format=True), "April 28, 2013 at 6:35 pm"
+        )
+
+        now = datetime.datetime.utcnow()
+
+        self.assertEqual(
+            locale.format_date(now - datetime.timedelta(seconds=2), full_format=False),
+            "2 seconds ago",
+        )
+        self.assertEqual(
+            locale.format_date(now - datetime.timedelta(minutes=2), full_format=False),
+            "2 minutes ago",
+        )
+        self.assertEqual(
+            locale.format_date(now - datetime.timedelta(hours=2), full_format=False),
+            "2 hours ago",
+        )
+
+        self.assertEqual(
+            locale.format_date(
+                now - datetime.timedelta(days=1), full_format=False, shorter=True
+            ),
+            "yesterday",
+        )
+
+        date = now - datetime.timedelta(days=2)
+        self.assertEqual(
+            locale.format_date(date, full_format=False, shorter=True),
+            locale._weekdays[date.weekday()],
+        )
+
+        date = now - datetime.timedelta(days=300)
+        self.assertEqual(
+            locale.format_date(date, full_format=False, shorter=True),
+            "%s %d" % (locale._months[date.month - 
1], date.day), + ) + + date = now - datetime.timedelta(days=500) + self.assertEqual( + locale.format_date(date, full_format=False, shorter=True), + "%s %d, %d" % (locale._months[date.month - 1], date.day, date.year), + ) + + def test_friendly_number(self): + locale = tornado.locale.get("en_US") + self.assertEqual(locale.friendly_number(1000000), "1,000,000") + + def test_list(self): + locale = tornado.locale.get("en_US") + self.assertEqual(locale.list([]), "") + self.assertEqual(locale.list(["A"]), "A") + self.assertEqual(locale.list(["A", "B"]), "A and B") + self.assertEqual(locale.list(["A", "B", "C"]), "A, B and C") + + def test_format_day(self): + locale = tornado.locale.get("en_US") + date = datetime.datetime(2013, 4, 28, 18, 35) + self.assertEqual(locale.format_day(date=date, dow=True), "Sunday, April 28") + self.assertEqual(locale.format_day(date=date, dow=False), "April 28") diff --git a/venv/lib/python3.9/site-packages/tornado/test/locks_test.py b/venv/lib/python3.9/site-packages/tornado/test/locks_test.py new file mode 100644 index 00000000..23e1c520 --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/test/locks_test.py @@ -0,0 +1,535 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import asyncio +from datetime import timedelta +import typing # noqa: F401 +import unittest + +from tornado import gen, locks +from tornado.gen import TimeoutError +from tornado.testing import gen_test, AsyncTestCase + + +class ConditionTest(AsyncTestCase): + def setUp(self): + super().setUp() + self.history = [] # type: typing.List[typing.Union[int, str]] + + def record_done(self, future, key): + """Record the resolution of a Future returned by Condition.wait.""" + + def callback(_): + if not future.result(): + # wait() resolved to False, meaning it timed out. + self.history.append("timeout") + else: + self.history.append(key) + + future.add_done_callback(callback) + + def loop_briefly(self): + """Run all queued callbacks on the IOLoop. + + In these tests, this method is used after calling notify() to + preserve the pre-5.0 behavior in which callbacks ran + synchronously. + """ + self.io_loop.add_callback(self.stop) + self.wait() + + def test_repr(self): + c = locks.Condition() + self.assertIn("Condition", repr(c)) + self.assertNotIn("waiters", repr(c)) + c.wait() + self.assertIn("waiters", repr(c)) + + @gen_test + def test_notify(self): + c = locks.Condition() + self.io_loop.call_later(0.01, c.notify) + yield c.wait() + + def test_notify_1(self): + c = locks.Condition() + self.record_done(c.wait(), "wait1") + self.record_done(c.wait(), "wait2") + c.notify(1) + self.loop_briefly() + self.history.append("notify1") + c.notify(1) + self.loop_briefly() + self.history.append("notify2") + self.assertEqual(["wait1", "notify1", "wait2", "notify2"], self.history) + + def test_notify_n(self): + c = locks.Condition() + for i in range(6): + self.record_done(c.wait(), i) + + c.notify(3) + self.loop_briefly() + + # Callbacks execute in the order they were registered. 
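+        # notify(3) wakes only waiters 0-2; waiters 3-5 stay pending until
+        # the notify(1) and notify(2) calls that follow.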
+        self.assertEqual(list(range(3)), self.history)
+        c.notify(1)
+        self.loop_briefly()
+        self.assertEqual(list(range(4)), self.history)
+        c.notify(2)
+        self.loop_briefly()
+        self.assertEqual(list(range(6)), self.history)
+
+    def test_notify_all(self):
+        c = locks.Condition()
+        for i in range(4):
+            self.record_done(c.wait(), i)
+
+        c.notify_all()
+        self.loop_briefly()
+        self.history.append("notify_all")
+
+        # Callbacks execute in the order they were registered.
+        self.assertEqual(list(range(4)) + ["notify_all"], self.history)  # type: ignore
+
+    @gen_test
+    def test_wait_timeout(self):
+        c = locks.Condition()
+        wait = c.wait(timedelta(seconds=0.01))
+        self.io_loop.call_later(0.02, c.notify)  # Too late.
+        yield gen.sleep(0.03)
+        self.assertFalse((yield wait))
+
+    @gen_test
+    def test_wait_timeout_preempted(self):
+        c = locks.Condition()
+
+        # This fires before the wait times out.
+        self.io_loop.call_later(0.01, c.notify)
+        wait = c.wait(timedelta(seconds=0.02))
+        yield gen.sleep(0.03)
+        yield wait  # No TimeoutError.
+
+    @gen_test
+    def test_notify_n_with_timeout(self):
+        # Register callbacks 0, 1, 2, and 3. Callback 1 has a timeout.
+        # Wait for that timeout to expire, then do notify(2) and make
+        # sure everyone runs. Verifies that a timed-out callback does
+        # not count against the 'n' argument to notify().
+        c = locks.Condition()
+        self.record_done(c.wait(), 0)
+        self.record_done(c.wait(timedelta(seconds=0.01)), 1)
+        self.record_done(c.wait(), 2)
+        self.record_done(c.wait(), 3)
+
+        # Wait for callback 1 to time out.
+        yield gen.sleep(0.02)
+        self.assertEqual(["timeout"], self.history)
+
+        c.notify(2)
+        yield gen.sleep(0.01)
+        self.assertEqual(["timeout", 0, 2], self.history)
+        c.notify()
+        yield
+        self.assertEqual(["timeout", 0, 2, 3], self.history)
+
+    @gen_test
+    def test_notify_all_with_timeout(self):
+        c = locks.Condition()
+        self.record_done(c.wait(), 0)
+        self.record_done(c.wait(timedelta(seconds=0.01)), 1)
+        self.record_done(c.wait(), 2)
+
+        # Wait for callback 1 to time out.
+        yield gen.sleep(0.02)
+        self.assertEqual(["timeout"], self.history)
+
+        c.notify_all()
+        yield
+        self.assertEqual(["timeout", 0, 2], self.history)
+
+    @gen_test
+    def test_nested_notify(self):
+        # Ensure no notifications are lost, even if notify() is reentered by a
+        # waiter calling notify().
+        c = locks.Condition()
+
+        # Three waiters.
+        futures = [asyncio.ensure_future(c.wait()) for _ in range(3)]
+
+        # First and second futures resolved. Second future reenters notify(),
+        # resolving third future.
+        futures[1].add_done_callback(lambda _: c.notify())
+        c.notify(2)
+        yield
+        self.assertTrue(all(f.done() for f in futures))
+
+    @gen_test
+    def test_garbage_collection(self):
+        # Test that timed-out waiters are occasionally cleaned from the queue.
+        c = locks.Condition()
+        for _ in range(101):
+            c.wait(timedelta(seconds=0.01))
+
+        future = asyncio.ensure_future(c.wait())
+        self.assertEqual(102, len(c._waiters))
+
+        # Let first 101 waiters time out, triggering a collection.
+        yield gen.sleep(0.02)
+        self.assertEqual(1, len(c._waiters))
+
+        # Final waiter is still active.
+ self.assertFalse(future.done()) + c.notify() + self.assertTrue(future.done()) + + +class EventTest(AsyncTestCase): + def test_repr(self): + event = locks.Event() + self.assertTrue("clear" in str(event)) + self.assertFalse("set" in str(event)) + event.set() + self.assertFalse("clear" in str(event)) + self.assertTrue("set" in str(event)) + + def test_event(self): + e = locks.Event() + future_0 = asyncio.ensure_future(e.wait()) + e.set() + future_1 = asyncio.ensure_future(e.wait()) + e.clear() + future_2 = asyncio.ensure_future(e.wait()) + + self.assertTrue(future_0.done()) + self.assertTrue(future_1.done()) + self.assertFalse(future_2.done()) + + @gen_test + def test_event_timeout(self): + e = locks.Event() + with self.assertRaises(TimeoutError): + yield e.wait(timedelta(seconds=0.01)) + + # After a timed-out waiter, normal operation works. + self.io_loop.add_timeout(timedelta(seconds=0.01), e.set) + yield e.wait(timedelta(seconds=1)) + + def test_event_set_multiple(self): + e = locks.Event() + e.set() + e.set() + self.assertTrue(e.is_set()) + + def test_event_wait_clear(self): + e = locks.Event() + f0 = asyncio.ensure_future(e.wait()) + e.clear() + f1 = asyncio.ensure_future(e.wait()) + e.set() + self.assertTrue(f0.done()) + self.assertTrue(f1.done()) + + +class SemaphoreTest(AsyncTestCase): + def test_negative_value(self): + self.assertRaises(ValueError, locks.Semaphore, value=-1) + + def test_repr(self): + sem = locks.Semaphore() + self.assertIn("Semaphore", repr(sem)) + self.assertIn("unlocked,value:1", repr(sem)) + sem.acquire() + self.assertIn("locked", repr(sem)) + self.assertNotIn("waiters", repr(sem)) + sem.acquire() + self.assertIn("waiters", repr(sem)) + + def test_acquire(self): + sem = locks.Semaphore() + f0 = asyncio.ensure_future(sem.acquire()) + self.assertTrue(f0.done()) + + # Wait for release(). + f1 = asyncio.ensure_future(sem.acquire()) + self.assertFalse(f1.done()) + f2 = asyncio.ensure_future(sem.acquire()) + sem.release() + self.assertTrue(f1.done()) + self.assertFalse(f2.done()) + sem.release() + self.assertTrue(f2.done()) + + sem.release() + # Now acquire() is instant. + self.assertTrue(asyncio.ensure_future(sem.acquire()).done()) + self.assertEqual(0, len(sem._waiters)) + + @gen_test + def test_acquire_timeout(self): + sem = locks.Semaphore(2) + yield sem.acquire() + yield sem.acquire() + acquire = sem.acquire(timedelta(seconds=0.01)) + self.io_loop.call_later(0.02, sem.release) # Too late. + yield gen.sleep(0.3) + with self.assertRaises(gen.TimeoutError): + yield acquire + + sem.acquire() + f = asyncio.ensure_future(sem.acquire()) + self.assertFalse(f.done()) + sem.release() + self.assertTrue(f.done()) + + @gen_test + def test_acquire_timeout_preempted(self): + sem = locks.Semaphore(1) + yield sem.acquire() + + # This fires before the wait times out. + self.io_loop.call_later(0.01, sem.release) + acquire = sem.acquire(timedelta(seconds=0.02)) + yield gen.sleep(0.03) + yield acquire # No TimeoutError. + + def test_release_unacquired(self): + # Unbounded releases are allowed, and increment the semaphore's value. + sem = locks.Semaphore() + sem.release() + sem.release() + + # Now the counter is 3. We can acquire three times before blocking. 
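+        # (That is, the initial value of 1 plus the two release() calls above.)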
+ self.assertTrue(asyncio.ensure_future(sem.acquire()).done()) + self.assertTrue(asyncio.ensure_future(sem.acquire()).done()) + self.assertTrue(asyncio.ensure_future(sem.acquire()).done()) + self.assertFalse(asyncio.ensure_future(sem.acquire()).done()) + + @gen_test + def test_garbage_collection(self): + # Test that timed-out waiters are occasionally cleaned from the queue. + sem = locks.Semaphore(value=0) + futures = [ + asyncio.ensure_future(sem.acquire(timedelta(seconds=0.01))) + for _ in range(101) + ] + + future = asyncio.ensure_future(sem.acquire()) + self.assertEqual(102, len(sem._waiters)) + + # Let first 101 waiters time out, triggering a collection. + yield gen.sleep(0.02) + self.assertEqual(1, len(sem._waiters)) + + # Final waiter is still active. + self.assertFalse(future.done()) + sem.release() + self.assertTrue(future.done()) + + # Prevent "Future exception was never retrieved" messages. + for future in futures: + self.assertRaises(TimeoutError, future.result) + + +class SemaphoreContextManagerTest(AsyncTestCase): + @gen_test + def test_context_manager(self): + sem = locks.Semaphore() + with (yield sem.acquire()) as yielded: + self.assertTrue(yielded is None) + + # Semaphore was released and can be acquired again. + self.assertTrue(asyncio.ensure_future(sem.acquire()).done()) + + @gen_test + def test_context_manager_async_await(self): + # Repeat the above test using 'async with'. + sem = locks.Semaphore() + + async def f(): + async with sem as yielded: + self.assertTrue(yielded is None) + + yield f() + + # Semaphore was released and can be acquired again. + self.assertTrue(asyncio.ensure_future(sem.acquire()).done()) + + @gen_test + def test_context_manager_exception(self): + sem = locks.Semaphore() + with self.assertRaises(ZeroDivisionError): + with (yield sem.acquire()): + 1 / 0 + + # Semaphore was released and can be acquired again. + self.assertTrue(asyncio.ensure_future(sem.acquire()).done()) + + @gen_test + def test_context_manager_timeout(self): + sem = locks.Semaphore() + with (yield sem.acquire(timedelta(seconds=0.01))): + pass + + # Semaphore was released and can be acquired again. + self.assertTrue(asyncio.ensure_future(sem.acquire()).done()) + + @gen_test + def test_context_manager_timeout_error(self): + sem = locks.Semaphore(value=0) + with self.assertRaises(gen.TimeoutError): + with (yield sem.acquire(timedelta(seconds=0.01))): + pass + + # Counter is still 0. + self.assertFalse(asyncio.ensure_future(sem.acquire()).done()) + + @gen_test + def test_context_manager_contended(self): + sem = locks.Semaphore() + history = [] + + @gen.coroutine + def f(index): + with (yield sem.acquire()): + history.append("acquired %d" % index) + yield gen.sleep(0.01) + history.append("release %d" % index) + + yield [f(i) for i in range(2)] + + expected_history = [] + for i in range(2): + expected_history.extend(["acquired %d" % i, "release %d" % i]) + + self.assertEqual(expected_history, history) + + @gen_test + def test_yield_sem(self): + # Ensure we catch a "with (yield sem)", which should be + # "with (yield sem.acquire())". + with self.assertRaises(gen.BadYieldError): + with (yield locks.Semaphore()): + pass + + def test_context_manager_misuse(self): + # Ensure we catch a "with sem", which should be + # "with (yield sem.acquire())". 
+ with self.assertRaises(RuntimeError): + with locks.Semaphore(): + pass + + +class BoundedSemaphoreTest(AsyncTestCase): + def test_release_unacquired(self): + sem = locks.BoundedSemaphore() + self.assertRaises(ValueError, sem.release) + # Value is 0. + sem.acquire() + # Block on acquire(). + future = asyncio.ensure_future(sem.acquire()) + self.assertFalse(future.done()) + sem.release() + self.assertTrue(future.done()) + # Value is 1. + sem.release() + self.assertRaises(ValueError, sem.release) + + +class LockTests(AsyncTestCase): + def test_repr(self): + lock = locks.Lock() + # No errors. + repr(lock) + lock.acquire() + repr(lock) + + def test_acquire_release(self): + lock = locks.Lock() + self.assertTrue(asyncio.ensure_future(lock.acquire()).done()) + future = asyncio.ensure_future(lock.acquire()) + self.assertFalse(future.done()) + lock.release() + self.assertTrue(future.done()) + + @gen_test + def test_acquire_fifo(self): + lock = locks.Lock() + self.assertTrue(asyncio.ensure_future(lock.acquire()).done()) + N = 5 + history = [] + + @gen.coroutine + def f(idx): + with (yield lock.acquire()): + history.append(idx) + + futures = [f(i) for i in range(N)] + self.assertFalse(any(future.done() for future in futures)) + lock.release() + yield futures + self.assertEqual(list(range(N)), history) + + @gen_test + def test_acquire_fifo_async_with(self): + # Repeat the above test using `async with lock:` + # instead of `with (yield lock.acquire()):`. + lock = locks.Lock() + self.assertTrue(asyncio.ensure_future(lock.acquire()).done()) + N = 5 + history = [] + + async def f(idx): + async with lock: + history.append(idx) + + futures = [f(i) for i in range(N)] + lock.release() + yield futures + self.assertEqual(list(range(N)), history) + + @gen_test + def test_acquire_timeout(self): + lock = locks.Lock() + lock.acquire() + with self.assertRaises(gen.TimeoutError): + yield lock.acquire(timeout=timedelta(seconds=0.01)) + + # Still locked. + self.assertFalse(asyncio.ensure_future(lock.acquire()).done()) + + def test_multi_release(self): + lock = locks.Lock() + self.assertRaises(RuntimeError, lock.release) + lock.acquire() + lock.release() + self.assertRaises(RuntimeError, lock.release) + + @gen_test + def test_yield_lock(self): + # Ensure we catch a "with (yield lock)", which should be + # "with (yield lock.acquire())". + with self.assertRaises(gen.BadYieldError): + with (yield locks.Lock()): + pass + + def test_context_manager_misuse(self): + # Ensure we catch a "with lock", which should be + # "with (yield lock.acquire())". + with self.assertRaises(RuntimeError): + with locks.Lock(): + pass + + +if __name__ == "__main__": + unittest.main() diff --git a/venv/lib/python3.9/site-packages/tornado/test/log_test.py b/venv/lib/python3.9/site-packages/tornado/test/log_test.py new file mode 100644 index 00000000..9130ae7e --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/test/log_test.py @@ -0,0 +1,245 @@ +# +# Copyright 2012 Facebook +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
+import contextlib +import glob +import logging +import os +import re +import subprocess +import sys +import tempfile +import unittest +import warnings + +from tornado.escape import utf8 +from tornado.log import LogFormatter, define_logging_options, enable_pretty_logging +from tornado.options import OptionParser +from tornado.util import basestring_type + + +@contextlib.contextmanager +def ignore_bytes_warning(): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=BytesWarning) + yield + + +class LogFormatterTest(unittest.TestCase): + # Matches the output of a single logging call (which may be multiple lines + # if a traceback was included, so we use the DOTALL option) + LINE_RE = re.compile( + b"(?s)\x01\\[E [0-9]{6} [0-9]{2}:[0-9]{2}:[0-9]{2} log_test:[0-9]+\\]\x02 (.*)" + ) + + def setUp(self): + self.formatter = LogFormatter(color=False) + # Fake color support. We can't guarantee anything about the $TERM + # variable when the tests are run, so just patch in some values + # for testing. (testing with color off fails to expose some potential + # encoding issues from the control characters) + self.formatter._colors = {logging.ERROR: "\u0001"} + self.formatter._normal = "\u0002" + # construct a Logger directly to bypass getLogger's caching + self.logger = logging.Logger("LogFormatterTest") + self.logger.propagate = False + self.tempdir = tempfile.mkdtemp() + self.filename = os.path.join(self.tempdir, "log.out") + self.handler = self.make_handler(self.filename) + self.handler.setFormatter(self.formatter) + self.logger.addHandler(self.handler) + + def tearDown(self): + self.handler.close() + os.unlink(self.filename) + os.rmdir(self.tempdir) + + def make_handler(self, filename): + # Base case: default setup without explicit encoding. + # In python 2, supports arbitrary byte strings and unicode objects + # that contain only ascii. In python 3, supports ascii-only unicode + # strings (but byte strings will be repr'd automatically). + return logging.FileHandler(filename) + + def get_output(self): + with open(self.filename, "rb") as f: + line = f.read().strip() + m = LogFormatterTest.LINE_RE.match(line) + if m: + return m.group(1) + else: + raise Exception("output didn't match regex: %r" % line) + + def test_basic_logging(self): + self.logger.error("foo") + self.assertEqual(self.get_output(), b"foo") + + def test_bytes_logging(self): + with ignore_bytes_warning(): + # This will be "\xe9" on python 2 or "b'\xe9'" on python 3 + self.logger.error(b"\xe9") + self.assertEqual(self.get_output(), utf8(repr(b"\xe9"))) + + def test_utf8_logging(self): + with ignore_bytes_warning(): + self.logger.error("\u00e9".encode("utf8")) + if issubclass(bytes, basestring_type): + # on python 2, utf8 byte strings (and by extension ascii byte + # strings) are passed through as-is. + self.assertEqual(self.get_output(), utf8("\u00e9")) + else: + # on python 3, byte strings always get repr'd even if + # they're ascii-only, so this degenerates into another + # copy of test_bytes_logging. + self.assertEqual(self.get_output(), utf8(repr(utf8("\u00e9")))) + + def test_bytes_exception_logging(self): + try: + raise Exception(b"\xe9") + except Exception: + self.logger.exception("caught exception") + # This will be "Exception: \xe9" on python 2 or + # "Exception: b'\xe9'" on python 3. + output = self.get_output() + self.assertRegex(output, rb"Exception.*\\xe9") + # The traceback contains newlines, which should not have been escaped. 
+ self.assertNotIn(rb"\n", output) + + +class UnicodeLogFormatterTest(LogFormatterTest): + def make_handler(self, filename): + # Adding an explicit encoding configuration allows non-ascii unicode + # strings in both python 2 and 3, without changing the behavior + # for byte strings. + return logging.FileHandler(filename, encoding="utf8") + + def test_unicode_logging(self): + self.logger.error("\u00e9") + self.assertEqual(self.get_output(), utf8("\u00e9")) + + +class EnablePrettyLoggingTest(unittest.TestCase): + def setUp(self): + super().setUp() + self.options = OptionParser() + define_logging_options(self.options) + self.logger = logging.Logger("tornado.test.log_test.EnablePrettyLoggingTest") + self.logger.propagate = False + + def test_log_file(self): + tmpdir = tempfile.mkdtemp() + try: + self.options.log_file_prefix = tmpdir + "/test_log" + enable_pretty_logging(options=self.options, logger=self.logger) + self.assertEqual(1, len(self.logger.handlers)) + self.logger.error("hello") + self.logger.handlers[0].flush() + filenames = glob.glob(tmpdir + "/test_log*") + self.assertEqual(1, len(filenames)) + with open(filenames[0]) as f: + self.assertRegex(f.read(), r"^\[E [^]]*\] hello$") + finally: + for handler in self.logger.handlers: + handler.flush() + handler.close() + for filename in glob.glob(tmpdir + "/test_log*"): + os.unlink(filename) + os.rmdir(tmpdir) + + def test_log_file_with_timed_rotating(self): + tmpdir = tempfile.mkdtemp() + try: + self.options.log_file_prefix = tmpdir + "/test_log" + self.options.log_rotate_mode = "time" + enable_pretty_logging(options=self.options, logger=self.logger) + self.logger.error("hello") + self.logger.handlers[0].flush() + filenames = glob.glob(tmpdir + "/test_log*") + self.assertEqual(1, len(filenames)) + with open(filenames[0]) as f: + self.assertRegex(f.read(), r"^\[E [^]]*\] hello$") + finally: + for handler in self.logger.handlers: + handler.flush() + handler.close() + for filename in glob.glob(tmpdir + "/test_log*"): + os.unlink(filename) + os.rmdir(tmpdir) + + def test_wrong_rotate_mode_value(self): + try: + self.options.log_file_prefix = "some_path" + self.options.log_rotate_mode = "wrong_mode" + self.assertRaises( + ValueError, + enable_pretty_logging, + options=self.options, + logger=self.logger, + ) + finally: + for handler in self.logger.handlers: + handler.flush() + handler.close() + + +class LoggingOptionTest(unittest.TestCase): + """Test the ability to enable and disable Tornado's logging hooks.""" + + def logs_present(self, statement, args=None): + # Each test may manipulate and/or parse the options and then logs + # a line at the 'info' level. This level is ignored in the + # logging module by default, but Tornado turns it on by default + # so it is the easiest way to tell whether tornado's logging hooks + # ran. 
+ IMPORT = "from tornado.options import options, parse_command_line" + LOG_INFO = 'import logging; logging.info("hello")' + program = ";".join([IMPORT, statement, LOG_INFO]) + proc = subprocess.Popen( + [sys.executable, "-c", program] + (args or []), + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) + stdout, stderr = proc.communicate() + self.assertEqual(proc.returncode, 0, "process failed: %r" % stdout) + return b"hello" in stdout + + def test_default(self): + self.assertFalse(self.logs_present("pass")) + + def test_tornado_default(self): + self.assertTrue(self.logs_present("parse_command_line()")) + + def test_disable_command_line(self): + self.assertFalse(self.logs_present("parse_command_line()", ["--logging=none"])) + + def test_disable_command_line_case_insensitive(self): + self.assertFalse(self.logs_present("parse_command_line()", ["--logging=None"])) + + def test_disable_code_string(self): + self.assertFalse( + self.logs_present('options.logging = "none"; parse_command_line()') + ) + + def test_disable_code_none(self): + self.assertFalse( + self.logs_present("options.logging = None; parse_command_line()") + ) + + def test_disable_override(self): + # command line trumps code defaults + self.assertTrue( + self.logs_present( + "options.logging = None; parse_command_line()", ["--logging=info"] + ) + ) diff --git a/venv/lib/python3.9/site-packages/tornado/test/netutil_test.py b/venv/lib/python3.9/site-packages/tornado/test/netutil_test.py new file mode 100644 index 00000000..b35b7947 --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/test/netutil_test.py @@ -0,0 +1,241 @@ +import errno +import os +import signal +import socket +from subprocess import Popen +import sys +import time +import unittest + +from tornado.netutil import ( + BlockingResolver, + OverrideResolver, + ThreadedResolver, + is_valid_ip, + bind_sockets, +) +from tornado.testing import AsyncTestCase, gen_test, bind_unused_port +from tornado.test.util import skipIfNoNetwork + +import typing + +if typing.TYPE_CHECKING: + from typing import List # noqa: F401 + +try: + import pycares # type: ignore +except ImportError: + pycares = None +else: + from tornado.platform.caresresolver import CaresResolver + +try: + import twisted # type: ignore + import twisted.names # type: ignore +except ImportError: + twisted = None +else: + from tornado.platform.twisted import TwistedResolver + + +class _ResolverTestMixin(object): + resolver = None # type: typing.Any + + @gen_test + def test_localhost(self: typing.Any): + addrinfo = yield self.resolver.resolve("localhost", 80, socket.AF_UNSPEC) + # Most of the time localhost resolves to either the ipv4 loopback + # address alone, or ipv4+ipv6. But some versions of pycares will only + # return the ipv6 version, so we have to check for either one alone. + self.assertTrue( + ((socket.AF_INET, ("127.0.0.1", 80)) in addrinfo) + or ((socket.AF_INET6, ("::1", 80)) in addrinfo), + f"loopback address not found in {addrinfo}", + ) + + +# It is impossible to quickly and consistently generate an error in name +# resolution, so test this case separately, using mocks as needed. 
+class _ResolverErrorTestMixin(object): + resolver = None # type: typing.Any + + @gen_test + def test_bad_host(self: typing.Any): + with self.assertRaises(IOError): + yield self.resolver.resolve("an invalid domain", 80, socket.AF_UNSPEC) + + +def _failing_getaddrinfo(*args): + """Dummy implementation of getaddrinfo for use in mocks""" + raise socket.gaierror(errno.EIO, "mock: lookup failed") + + +@skipIfNoNetwork +class BlockingResolverTest(AsyncTestCase, _ResolverTestMixin): + def setUp(self): + super().setUp() + self.resolver = BlockingResolver() + + +# getaddrinfo-based tests need mocking to reliably generate errors; +# some configurations are slow to produce errors and take longer than +# our default timeout. +class BlockingResolverErrorTest(AsyncTestCase, _ResolverErrorTestMixin): + def setUp(self): + super().setUp() + self.resolver = BlockingResolver() + self.real_getaddrinfo = socket.getaddrinfo + socket.getaddrinfo = _failing_getaddrinfo + + def tearDown(self): + socket.getaddrinfo = self.real_getaddrinfo + super().tearDown() + + +class OverrideResolverTest(AsyncTestCase, _ResolverTestMixin): + def setUp(self): + super().setUp() + mapping = { + ("google.com", 80): ("1.2.3.4", 80), + ("google.com", 80, socket.AF_INET): ("1.2.3.4", 80), + ("google.com", 80, socket.AF_INET6): ( + "2a02:6b8:7c:40c:c51e:495f:e23a:3", + 80, + ), + } + self.resolver = OverrideResolver(BlockingResolver(), mapping) + + @gen_test + def test_resolve_multiaddr(self): + result = yield self.resolver.resolve("google.com", 80, socket.AF_INET) + self.assertIn((socket.AF_INET, ("1.2.3.4", 80)), result) + + result = yield self.resolver.resolve("google.com", 80, socket.AF_INET6) + self.assertIn( + (socket.AF_INET6, ("2a02:6b8:7c:40c:c51e:495f:e23a:3", 80, 0, 0)), result + ) + + +@skipIfNoNetwork +class ThreadedResolverTest(AsyncTestCase, _ResolverTestMixin): + def setUp(self): + super().setUp() + self.resolver = ThreadedResolver() + + def tearDown(self): + self.resolver.close() + super().tearDown() + + +class ThreadedResolverErrorTest(AsyncTestCase, _ResolverErrorTestMixin): + def setUp(self): + super().setUp() + self.resolver = BlockingResolver() + self.real_getaddrinfo = socket.getaddrinfo + socket.getaddrinfo = _failing_getaddrinfo + + def tearDown(self): + socket.getaddrinfo = self.real_getaddrinfo + super().tearDown() + + +@skipIfNoNetwork +@unittest.skipIf(sys.platform == "win32", "preexec_fn not available on win32") +class ThreadedResolverImportTest(unittest.TestCase): + def test_import(self): + TIMEOUT = 5 + + # Test for a deadlock when importing a module that runs the + # ThreadedResolver at import-time. See resolve_test.py for + # full explanation. + command = [sys.executable, "-c", "import tornado.test.resolve_test_helper"] + + start = time.time() + popen = Popen(command, preexec_fn=lambda: signal.alarm(TIMEOUT)) + while time.time() - start < TIMEOUT: + return_code = popen.poll() + if return_code is not None: + self.assertEqual(0, return_code) + return # Success. + time.sleep(0.05) + + self.fail("import timed out") + + +# We do not test errors with CaresResolver: +# Some DNS-hijacking ISPs (e.g. Time Warner) return non-empty results +# with an NXDOMAIN status code. Most resolvers treat this as an error; +# C-ares returns the results, making the "bad_host" tests unreliable. +# C-ares will try to resolve even malformed names, such as the +# name with spaces used in this test. 
+@skipIfNoNetwork +@unittest.skipIf(pycares is None, "pycares module not present") +@unittest.skipIf(sys.platform == "win32", "pycares doesn't return loopback on windows") +@unittest.skipIf(sys.platform == "darwin", "pycares doesn't return 127.0.0.1 on darwin") +class CaresResolverTest(AsyncTestCase, _ResolverTestMixin): + def setUp(self): + super().setUp() + self.resolver = CaresResolver() + + +# TwistedResolver produces consistent errors in our test cases so we +# could test the regular and error cases in the same class. However, +# in the error cases it appears that cleanup of socket objects is +# handled asynchronously and occasionally results in "unclosed socket" +# warnings if not given time to shut down (and there is no way to +# explicitly shut it down). This makes the test flaky, so we do not +# test error cases here. +@skipIfNoNetwork +@unittest.skipIf(twisted is None, "twisted module not present") +@unittest.skipIf( + getattr(twisted, "__version__", "0.0") < "12.1", "old version of twisted" +) +@unittest.skipIf(sys.platform == "win32", "twisted resolver hangs on windows") +class TwistedResolverTest(AsyncTestCase, _ResolverTestMixin): + def setUp(self): + super().setUp() + self.resolver = TwistedResolver() + + +class IsValidIPTest(unittest.TestCase): + def test_is_valid_ip(self): + self.assertTrue(is_valid_ip("127.0.0.1")) + self.assertTrue(is_valid_ip("4.4.4.4")) + self.assertTrue(is_valid_ip("::1")) + self.assertTrue(is_valid_ip("2620:0:1cfe:face:b00c::3")) + self.assertTrue(not is_valid_ip("www.google.com")) + self.assertTrue(not is_valid_ip("localhost")) + self.assertTrue(not is_valid_ip("4.4.4.4<")) + self.assertTrue(not is_valid_ip(" 127.0.0.1")) + self.assertTrue(not is_valid_ip("")) + self.assertTrue(not is_valid_ip(" ")) + self.assertTrue(not is_valid_ip("\n")) + self.assertTrue(not is_valid_ip("\x00")) + self.assertTrue(not is_valid_ip("a" * 100)) + + +class TestPortAllocation(unittest.TestCase): + def test_same_port_allocation(self): + if "TRAVIS" in os.environ: + self.skipTest("dual-stack servers often have port conflicts on travis") + sockets = bind_sockets(0, "localhost") + try: + port = sockets[0].getsockname()[1] + self.assertTrue(all(s.getsockname()[1] == port for s in sockets[1:])) + finally: + for sock in sockets: + sock.close() + + @unittest.skipIf( + not hasattr(socket, "SO_REUSEPORT"), "SO_REUSEPORT is not supported" + ) + def test_reuse_port(self): + sockets = [] # type: List[socket.socket] + socket, port = bind_unused_port(reuse_port=True) + try: + sockets = bind_sockets(port, "127.0.0.1", reuse_port=True) + self.assertTrue(all(s.getsockname()[1] == port for s in sockets)) + finally: + socket.close() + for sock in sockets: + sock.close() diff --git a/venv/lib/python3.9/site-packages/tornado/test/options_test.cfg b/venv/lib/python3.9/site-packages/tornado/test/options_test.cfg new file mode 100644 index 00000000..4ead46a4 --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/test/options_test.cfg @@ -0,0 +1,7 @@ +port=443 +port=443 +username='李康' + +foo_bar='a' + +my_path = __file__ diff --git a/venv/lib/python3.9/site-packages/tornado/test/options_test.py b/venv/lib/python3.9/site-packages/tornado/test/options_test.py new file mode 100644 index 00000000..6f4021c6 --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/test/options_test.py @@ -0,0 +1,329 @@ +import datetime +from io import StringIO +import os +import sys +from unittest import mock +import unittest + +from tornado.options import OptionParser, Error +from tornado.util import 
basestring_type +from tornado.test.util import subTest + +import typing + +if typing.TYPE_CHECKING: + from typing import List # noqa: F401 + + +class Email(object): + def __init__(self, value): + if isinstance(value, str) and "@" in value: + self._value = value + else: + raise ValueError() + + @property + def value(self): + return self._value + + +class OptionsTest(unittest.TestCase): + def test_parse_command_line(self): + options = OptionParser() + options.define("port", default=80) + options.parse_command_line(["main.py", "--port=443"]) + self.assertEqual(options.port, 443) + + def test_parse_config_file(self): + options = OptionParser() + options.define("port", default=80) + options.define("username", default="foo") + options.define("my_path") + config_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "options_test.cfg" + ) + options.parse_config_file(config_path) + self.assertEqual(options.port, 443) + self.assertEqual(options.username, "李康") + self.assertEqual(options.my_path, config_path) + + def test_parse_callbacks(self): + options = OptionParser() + self.called = False + + def callback(): + self.called = True + + options.add_parse_callback(callback) + + # non-final parse doesn't run callbacks + options.parse_command_line(["main.py"], final=False) + self.assertFalse(self.called) + + # final parse does + options.parse_command_line(["main.py"]) + self.assertTrue(self.called) + + # callbacks can be run more than once on the same options + # object if there are multiple final parses + self.called = False + options.parse_command_line(["main.py"]) + self.assertTrue(self.called) + + def test_help(self): + options = OptionParser() + try: + orig_stderr = sys.stderr + sys.stderr = StringIO() + with self.assertRaises(SystemExit): + options.parse_command_line(["main.py", "--help"]) + usage = sys.stderr.getvalue() + finally: + sys.stderr = orig_stderr + self.assertIn("Usage:", usage) + + def test_subcommand(self): + base_options = OptionParser() + base_options.define("verbose", default=False) + sub_options = OptionParser() + sub_options.define("foo", type=str) + rest = base_options.parse_command_line( + ["main.py", "--verbose", "subcommand", "--foo=bar"] + ) + self.assertEqual(rest, ["subcommand", "--foo=bar"]) + self.assertTrue(base_options.verbose) + rest2 = sub_options.parse_command_line(rest) + self.assertEqual(rest2, []) + self.assertEqual(sub_options.foo, "bar") + + # the two option sets are distinct + try: + orig_stderr = sys.stderr + sys.stderr = StringIO() + with self.assertRaises(Error): + sub_options.parse_command_line(["subcommand", "--verbose"]) + finally: + sys.stderr = orig_stderr + + def test_setattr(self): + options = OptionParser() + options.define("foo", default=1, type=int) + options.foo = 2 + self.assertEqual(options.foo, 2) + + def test_setattr_type_check(self): + # setattr requires that options be the right type and doesn't + # parse from string formats. + options = OptionParser() + options.define("foo", default=1, type=int) + with self.assertRaises(Error): + options.foo = "2" + + def test_setattr_with_callback(self): + values = [] # type: List[int] + options = OptionParser() + options.define("foo", default=1, type=int, callback=values.append) + options.foo = 2 + self.assertEqual(values, [2]) + + def _sample_options(self): + options = OptionParser() + options.define("a", default=1) + options.define("b", default=2) + return options + + def test_iter(self): + options = self._sample_options() + # OptionParsers always define 'help'. 
+ self.assertEqual(set(["a", "b", "help"]), set(iter(options))) + + def test_getitem(self): + options = self._sample_options() + self.assertEqual(1, options["a"]) + + def test_setitem(self): + options = OptionParser() + options.define("foo", default=1, type=int) + options["foo"] = 2 + self.assertEqual(options["foo"], 2) + + def test_items(self): + options = self._sample_options() + # OptionParsers always define 'help'. + expected = [("a", 1), ("b", 2), ("help", options.help)] + actual = sorted(options.items()) + self.assertEqual(expected, actual) + + def test_as_dict(self): + options = self._sample_options() + expected = {"a": 1, "b": 2, "help": options.help} + self.assertEqual(expected, options.as_dict()) + + def test_group_dict(self): + options = OptionParser() + options.define("a", default=1) + options.define("b", group="b_group", default=2) + + frame = sys._getframe(0) + this_file = frame.f_code.co_filename + self.assertEqual(set(["b_group", "", this_file]), options.groups()) + + b_group_dict = options.group_dict("b_group") + self.assertEqual({"b": 2}, b_group_dict) + + self.assertEqual({}, options.group_dict("nonexistent")) + + def test_mock_patch(self): + # ensure that our setattr hooks don't interfere with mock.patch + options = OptionParser() + options.define("foo", default=1) + options.parse_command_line(["main.py", "--foo=2"]) + self.assertEqual(options.foo, 2) + + with mock.patch.object(options.mockable(), "foo", 3): + self.assertEqual(options.foo, 3) + self.assertEqual(options.foo, 2) + + # Try nested patches mixed with explicit sets + with mock.patch.object(options.mockable(), "foo", 4): + self.assertEqual(options.foo, 4) + options.foo = 5 + self.assertEqual(options.foo, 5) + with mock.patch.object(options.mockable(), "foo", 6): + self.assertEqual(options.foo, 6) + self.assertEqual(options.foo, 5) + self.assertEqual(options.foo, 2) + + def _define_options(self): + options = OptionParser() + options.define("str", type=str) + options.define("basestring", type=basestring_type) + options.define("int", type=int) + options.define("float", type=float) + options.define("datetime", type=datetime.datetime) + options.define("timedelta", type=datetime.timedelta) + options.define("email", type=Email) + options.define("list-of-int", type=int, multiple=True) + options.define("list-of-str", type=str, multiple=True) + return options + + def _check_options_values(self, options): + self.assertEqual(options.str, "asdf") + self.assertEqual(options.basestring, "qwer") + self.assertEqual(options.int, 42) + self.assertEqual(options.float, 1.5) + self.assertEqual(options.datetime, datetime.datetime(2013, 4, 28, 5, 16)) + self.assertEqual(options.timedelta, datetime.timedelta(seconds=45)) + self.assertEqual(options.email.value, "tornado@web.com") + self.assertTrue(isinstance(options.email, Email)) + self.assertEqual(options.list_of_int, [1, 2, 3]) + self.assertEqual(options.list_of_str, ["a", "b", "c"]) + + def test_types(self): + options = self._define_options() + options.parse_command_line( + [ + "main.py", + "--str=asdf", + "--basestring=qwer", + "--int=42", + "--float=1.5", + "--datetime=2013-04-28 05:16", + "--timedelta=45s", + "--email=tornado@web.com", + "--list-of-int=1,2,3", + "--list-of-str=a,b,c", + ] + ) + self._check_options_values(options) + + def test_types_with_conf_file(self): + for config_file_name in ( + "options_test_types.cfg", + "options_test_types_str.cfg", + ): + options = self._define_options() + options.parse_config_file( + os.path.join(os.path.dirname(__file__), 
config_file_name) + ) + self._check_options_values(options) + + def test_multiple_string(self): + options = OptionParser() + options.define("foo", type=str, multiple=True) + options.parse_command_line(["main.py", "--foo=a,b,c"]) + self.assertEqual(options.foo, ["a", "b", "c"]) + + def test_multiple_int(self): + options = OptionParser() + options.define("foo", type=int, multiple=True) + options.parse_command_line(["main.py", "--foo=1,3,5:7"]) + self.assertEqual(options.foo, [1, 3, 5, 6, 7]) + + def test_error_redefine(self): + options = OptionParser() + options.define("foo") + with self.assertRaises(Error) as cm: + options.define("foo") + self.assertRegex(str(cm.exception), "Option.*foo.*already defined") + + def test_error_redefine_underscore(self): + # Ensure that the dash/underscore normalization doesn't + # interfere with the redefinition error. + tests = [ + ("foo-bar", "foo-bar"), + ("foo_bar", "foo_bar"), + ("foo-bar", "foo_bar"), + ("foo_bar", "foo-bar"), + ] + for a, b in tests: + with subTest(self, a=a, b=b): + options = OptionParser() + options.define(a) + with self.assertRaises(Error) as cm: + options.define(b) + self.assertRegex(str(cm.exception), "Option.*foo.bar.*already defined") + + def test_dash_underscore_cli(self): + # Dashes and underscores should be interchangeable. + for defined_name in ["foo-bar", "foo_bar"]: + for flag in ["--foo-bar=a", "--foo_bar=a"]: + options = OptionParser() + options.define(defined_name) + options.parse_command_line(["main.py", flag]) + # Attr-style access always uses underscores. + self.assertEqual(options.foo_bar, "a") + # Dict-style access allows both. + self.assertEqual(options["foo-bar"], "a") + self.assertEqual(options["foo_bar"], "a") + + def test_dash_underscore_file(self): + # No matter how an option was defined, it can be set with underscores + # in a config file. + for defined_name in ["foo-bar", "foo_bar"]: + options = OptionParser() + options.define(defined_name) + options.parse_config_file( + os.path.join(os.path.dirname(__file__), "options_test.cfg") + ) + self.assertEqual(options.foo_bar, "a") + + def test_dash_underscore_introspection(self): + # Original names are preserved in introspection APIs. + options = OptionParser() + options.define("with-dash", group="g") + options.define("with_underscore", group="g") + all_options = ["help", "with-dash", "with_underscore"] + self.assertEqual(sorted(options), all_options) + self.assertEqual(sorted(k for (k, v) in options.items()), all_options) + self.assertEqual(sorted(options.as_dict().keys()), all_options) + + self.assertEqual( + sorted(options.group_dict("g")), ["with-dash", "with_underscore"] + ) + + # --help shows CLI-style names with dashes. 
+ buf = StringIO() + options.print_help(buf) + self.assertIn("--with-dash", buf.getvalue()) + self.assertIn("--with-underscore", buf.getvalue()) diff --git a/venv/lib/python3.9/site-packages/tornado/test/options_test_types.cfg b/venv/lib/python3.9/site-packages/tornado/test/options_test_types.cfg new file mode 100644 index 00000000..9dfd9220 --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/test/options_test_types.cfg @@ -0,0 +1,12 @@ +from datetime import datetime, timedelta +from tornado.test.options_test import Email + +str = 'asdf' +basestring = 'qwer' +int = 42 +float = 1.5 +datetime = datetime(2013, 4, 28, 5, 16) +timedelta = timedelta(0, 45) +email = Email('tornado@web.com') +list_of_int = [1, 2, 3] +list_of_str = ["a", "b", "c"] diff --git a/venv/lib/python3.9/site-packages/tornado/test/options_test_types_str.cfg b/venv/lib/python3.9/site-packages/tornado/test/options_test_types_str.cfg new file mode 100644 index 00000000..b07d6428 --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/test/options_test_types_str.cfg @@ -0,0 +1,9 @@ +str = 'asdf' +basestring = 'qwer' +int = 42 +float = 1.5 +datetime = '2013-04-28 05:16' +timedelta = '45s' +email = 'tornado@web.com' +list_of_int = '1,2,3' +list_of_str = 'a,b,c' diff --git a/venv/lib/python3.9/site-packages/tornado/test/process_test.py b/venv/lib/python3.9/site-packages/tornado/test/process_test.py new file mode 100644 index 00000000..ab290085 --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/test/process_test.py @@ -0,0 +1,274 @@ +import asyncio +import logging +import os +import signal +import subprocess +import sys +import time +import unittest + +from tornado.httpclient import HTTPClient, HTTPError +from tornado.httpserver import HTTPServer +from tornado.ioloop import IOLoop +from tornado.log import gen_log +from tornado.process import fork_processes, task_id, Subprocess +from tornado.simple_httpclient import SimpleAsyncHTTPClient +from tornado.testing import bind_unused_port, ExpectLog, AsyncTestCase, gen_test +from tornado.test.util import skipIfNonUnix +from tornado.web import RequestHandler, Application + + +# Not using AsyncHTTPTestCase because we need control over the IOLoop. +@skipIfNonUnix +class ProcessTest(unittest.TestCase): + def get_app(self): + class ProcessHandler(RequestHandler): + def get(self): + if self.get_argument("exit", None): + # must use os._exit instead of sys.exit so unittest's + # exception handler doesn't catch it + os._exit(int(self.get_argument("exit"))) + if self.get_argument("signal", None): + os.kill(os.getpid(), int(self.get_argument("signal"))) + self.write(str(os.getpid())) + + return Application([("/", ProcessHandler)]) + + def tearDown(self): + if task_id() is not None: + # We're in a child process, and probably got to this point + # via an uncaught exception. If we return now, both + # processes will continue with the rest of the test suite. + # Exit now so the parent process will restart the child + # (since we don't have a clean way to signal failure to + # the parent that won't restart) + logging.error("aborting child process from tearDown") + logging.shutdown() + os._exit(1) + # In the surviving process, clear the alarm we set earlier + signal.alarm(0) + super().tearDown() + + def test_multi_process(self): + # This test doesn't work on twisted because we use the global + # reactor and don't restore it to a sane state after the fork + # (asyncio has the same issue, but we have a special case in + # place for it). 
+ with ExpectLog( + gen_log, "(Starting .* processes|child .* exited|uncaught exception)" + ): + sock, port = bind_unused_port() + + def get_url(path): + return "http://127.0.0.1:%d%s" % (port, path) + + # ensure that none of these processes live too long + signal.alarm(5) # master process + try: + id = fork_processes(3, max_restarts=3) + self.assertTrue(id is not None) + signal.alarm(5) # child processes + except SystemExit as e: + # if we exit cleanly from fork_processes, all the child processes + # finished with status 0 + self.assertEqual(e.code, 0) + self.assertTrue(task_id() is None) + sock.close() + return + try: + if id in (0, 1): + self.assertEqual(id, task_id()) + + async def f(): + server = HTTPServer(self.get_app()) + server.add_sockets([sock]) + await asyncio.Event().wait() + + asyncio.run(f()) + elif id == 2: + self.assertEqual(id, task_id()) + sock.close() + # Always use SimpleAsyncHTTPClient here; the curl + # version appears to get confused sometimes if the + # connection gets closed before it's had a chance to + # switch from writing mode to reading mode. + client = HTTPClient(SimpleAsyncHTTPClient) + + def fetch(url, fail_ok=False): + try: + return client.fetch(get_url(url)) + except HTTPError as e: + if not (fail_ok and e.code == 599): + raise + + # Make two processes exit abnormally + fetch("/?exit=2", fail_ok=True) + fetch("/?exit=3", fail_ok=True) + + # They've been restarted, so a new fetch will work + int(fetch("/").body) + + # Now the same with signals + # Disabled because on the mac a process dying with a signal + # can trigger an "Application exited abnormally; send error + # report to Apple?" prompt. + # fetch("/?signal=%d" % signal.SIGTERM, fail_ok=True) + # fetch("/?signal=%d" % signal.SIGABRT, fail_ok=True) + # int(fetch("/").body) + + # Now kill them normally so they won't be restarted + fetch("/?exit=0", fail_ok=True) + # One process left; watch its pid change + pid = int(fetch("/").body) + fetch("/?exit=4", fail_ok=True) + pid2 = int(fetch("/").body) + self.assertNotEqual(pid, pid2) + + # Kill the last one so we shut down cleanly + fetch("/?exit=0", fail_ok=True) + + os._exit(0) + except Exception: + logging.error("exception in child process %d", id, exc_info=True) + raise + + +@skipIfNonUnix +class SubprocessTest(AsyncTestCase): + def term_and_wait(self, subproc): + subproc.proc.terminate() + subproc.proc.wait() + + @gen_test + def test_subprocess(self): + if IOLoop.configured_class().__name__.endswith("LayeredTwistedIOLoop"): + # This test fails non-deterministically with LayeredTwistedIOLoop. + # (the read_until('\n') returns '\n' instead of 'hello\n') + # This probably indicates a problem with either TornadoReactor + # or TwistedIOLoop, but I haven't been able to track it down + # and for now this is just causing spurious travis-ci failures.
+ raise unittest.SkipTest( + "Subprocess tests not compatible with " "LayeredTwistedIOLoop" + ) + subproc = Subprocess( + [sys.executable, "-u", "-i"], + stdin=Subprocess.STREAM, + stdout=Subprocess.STREAM, + stderr=subprocess.STDOUT, + ) + self.addCleanup(lambda: self.term_and_wait(subproc)) + self.addCleanup(subproc.stdout.close) + self.addCleanup(subproc.stdin.close) + yield subproc.stdout.read_until(b">>> ") + subproc.stdin.write(b"print('hello')\n") + data = yield subproc.stdout.read_until(b"\n") + self.assertEqual(data, b"hello\n") + + yield subproc.stdout.read_until(b">>> ") + subproc.stdin.write(b"raise SystemExit\n") + data = yield subproc.stdout.read_until_close() + self.assertEqual(data, b"") + + @gen_test + def test_close_stdin(self): + # Close the parent's stdin handle and see that the child recognizes it. + subproc = Subprocess( + [sys.executable, "-u", "-i"], + stdin=Subprocess.STREAM, + stdout=Subprocess.STREAM, + stderr=subprocess.STDOUT, + ) + self.addCleanup(lambda: self.term_and_wait(subproc)) + yield subproc.stdout.read_until(b">>> ") + subproc.stdin.close() + data = yield subproc.stdout.read_until_close() + self.assertEqual(data, b"\n") + + @gen_test + def test_stderr(self): + # This test is mysteriously flaky on twisted: it succeeds, but logs + # an error of EBADF on closing a file descriptor. + subproc = Subprocess( + [sys.executable, "-u", "-c", r"import sys; sys.stderr.write('hello\n')"], + stderr=Subprocess.STREAM, + ) + self.addCleanup(lambda: self.term_and_wait(subproc)) + data = yield subproc.stderr.read_until(b"\n") + self.assertEqual(data, b"hello\n") + # More mysterious EBADF: This fails if done with self.addCleanup instead of here. + subproc.stderr.close() + + def test_sigchild(self): + Subprocess.initialize() + self.addCleanup(Subprocess.uninitialize) + subproc = Subprocess([sys.executable, "-c", "pass"]) + subproc.set_exit_callback(self.stop) + ret = self.wait() + self.assertEqual(ret, 0) + self.assertEqual(subproc.returncode, ret) + + @gen_test + def test_sigchild_future(self): + Subprocess.initialize() + self.addCleanup(Subprocess.uninitialize) + subproc = Subprocess([sys.executable, "-c", "pass"]) + ret = yield subproc.wait_for_exit() + self.assertEqual(ret, 0) + self.assertEqual(subproc.returncode, ret) + + def test_sigchild_signal(self): + Subprocess.initialize() + self.addCleanup(Subprocess.uninitialize) + subproc = Subprocess( + [sys.executable, "-c", "import time; time.sleep(30)"], + stdout=Subprocess.STREAM, + ) + self.addCleanup(subproc.stdout.close) + subproc.set_exit_callback(self.stop) + + # For unclear reasons, killing a process too soon after + # creating it can result in an exit status corresponding to + # SIGKILL instead of the actual signal involved. This has been + # observed on macOS 10.15 with Python 3.8 installed via brew, + # but not with the system-installed Python 3.7. + time.sleep(0.1) + + os.kill(subproc.pid, signal.SIGTERM) + try: + ret = self.wait() + except AssertionError: + # We failed to get the termination signal. This test is + # occasionally flaky on pypy, so try to get a little more + # information: did the process close its stdout + # (indicating that the problem is in the parent process's + # signal handling) or did the child process somehow fail + # to terminate? 
+ fut = subproc.stdout.read_until_close() + fut.add_done_callback(lambda f: self.stop()) # type: ignore + try: + self.wait() + except AssertionError: + raise AssertionError("subprocess failed to terminate") + else: + raise AssertionError( + "subprocess closed stdout but failed to " "get termination signal" + ) + self.assertEqual(subproc.returncode, ret) + self.assertEqual(ret, -signal.SIGTERM) + + @gen_test + def test_wait_for_exit_raise(self): + Subprocess.initialize() + self.addCleanup(Subprocess.uninitialize) + subproc = Subprocess([sys.executable, "-c", "import sys; sys.exit(1)"]) + with self.assertRaises(subprocess.CalledProcessError) as cm: + yield subproc.wait_for_exit() + self.assertEqual(cm.exception.returncode, 1) + + @gen_test + def test_wait_for_exit_raise_disabled(self): + Subprocess.initialize() + self.addCleanup(Subprocess.uninitialize) + subproc = Subprocess([sys.executable, "-c", "import sys; sys.exit(1)"]) + ret = yield subproc.wait_for_exit(raise_error=False) + self.assertEqual(ret, 1) diff --git a/venv/lib/python3.9/site-packages/tornado/test/queues_test.py b/venv/lib/python3.9/site-packages/tornado/test/queues_test.py new file mode 100644 index 00000000..98a29a8d --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/test/queues_test.py @@ -0,0 +1,431 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import asyncio +from datetime import timedelta +from random import random +import unittest + +from tornado import gen, queues +from tornado.gen import TimeoutError +from tornado.testing import gen_test, AsyncTestCase + + +class QueueBasicTest(AsyncTestCase): + def test_repr_and_str(self): + q = queues.Queue(maxsize=1) # type: queues.Queue[None] + self.assertIn(hex(id(q)), repr(q)) + self.assertNotIn(hex(id(q)), str(q)) + q.get() + + for q_str in repr(q), str(q): + self.assertTrue(q_str.startswith("<Queue")) + self.assertIn("maxsize=1", q_str) + self.assertIn("getters[1]", q_str) + self.assertNotIn("putters", q_str) + self.assertNotIn("tasks", q_str) + + q.put(None) + q.put(None) + # Now the queue is full, this putter blocks. + q.put(None) + + for q_str in repr(q), str(q): + self.assertNotIn("getters", q_str) + self.assertIn("putters[1]", q_str) + self.assertIn("tasks=2", q_str) + + def test_order(self): + q = queues.Queue() # type: queues.Queue[int] + for i in [1, 3, 2]: + q.put_nowait(i) + + items = [q.get_nowait() for _ in range(3)] + self.assertEqual([1, 3, 2], items) + + @gen_test + def test_maxsize(self): + self.assertRaises(TypeError, queues.Queue, maxsize=None) + self.assertRaises(ValueError, queues.Queue, maxsize=-1) + + q = queues.Queue(maxsize=2) # type: queues.Queue[int] + self.assertTrue(q.empty()) + self.assertFalse(q.full()) + self.assertEqual(2, q.maxsize) + self.assertTrue(q.put(0).done()) + self.assertTrue(q.put(1).done()) + self.assertFalse(q.empty()) + self.assertTrue(q.full()) + put2 = q.put(2) + self.assertFalse(put2.done()) + self.assertEqual(0, (yield q.get())) # Make room. 
+ self.assertTrue(put2.done()) + self.assertFalse(q.empty()) + self.assertTrue(q.full()) + + +class QueueGetTest(AsyncTestCase): + @gen_test + def test_blocking_get(self): + q = queues.Queue() # type: queues.Queue[int] + q.put_nowait(0) + self.assertEqual(0, (yield q.get())) + + def test_nonblocking_get(self): + q = queues.Queue() # type: queues.Queue[int] + q.put_nowait(0) + self.assertEqual(0, q.get_nowait()) + + def test_nonblocking_get_exception(self): + q = queues.Queue() # type: queues.Queue[int] + self.assertRaises(queues.QueueEmpty, q.get_nowait) + + @gen_test + def test_get_with_putters(self): + q = queues.Queue(1) # type: queues.Queue[int] + q.put_nowait(0) + put = q.put(1) + self.assertEqual(0, (yield q.get())) + self.assertIsNone((yield put)) + + @gen_test + def test_blocking_get_wait(self): + q = queues.Queue() # type: queues.Queue[int] + q.put(0) + self.io_loop.call_later(0.01, q.put_nowait, 1) + self.io_loop.call_later(0.02, q.put_nowait, 2) + self.assertEqual(0, (yield q.get(timeout=timedelta(seconds=1)))) + self.assertEqual(1, (yield q.get(timeout=timedelta(seconds=1)))) + + @gen_test + def test_get_timeout(self): + q = queues.Queue() # type: queues.Queue[int] + get_timeout = q.get(timeout=timedelta(seconds=0.01)) + get = q.get() + with self.assertRaises(TimeoutError): + yield get_timeout + + q.put_nowait(0) + self.assertEqual(0, (yield get)) + + @gen_test + def test_get_timeout_preempted(self): + q = queues.Queue() # type: queues.Queue[int] + get = q.get(timeout=timedelta(seconds=0.01)) + q.put(0) + yield gen.sleep(0.02) + self.assertEqual(0, (yield get)) + + @gen_test + def test_get_clears_timed_out_putters(self): + q = queues.Queue(1) # type: queues.Queue[int] + # First putter succeeds, remainder block. + putters = [q.put(i, timedelta(seconds=0.01)) for i in range(10)] + put = q.put(10) + self.assertEqual(10, len(q._putters)) + yield gen.sleep(0.02) + self.assertEqual(10, len(q._putters)) + self.assertFalse(put.done()) # Final waiter is still active. + q.put(11) + self.assertEqual(0, (yield q.get())) # get() clears the waiters. + self.assertEqual(1, len(q._putters)) + for putter in putters[1:]: + self.assertRaises(TimeoutError, putter.result) + + @gen_test + def test_get_clears_timed_out_getters(self): + q = queues.Queue() # type: queues.Queue[int] + getters = [ + asyncio.ensure_future(q.get(timedelta(seconds=0.01))) for _ in range(10) + ] + get = asyncio.ensure_future(q.get()) + self.assertEqual(11, len(q._getters)) + yield gen.sleep(0.02) + self.assertEqual(11, len(q._getters)) + self.assertFalse(get.done()) # Final waiter is still active. + q.get() # get() clears the waiters. 
+ self.assertEqual(2, len(q._getters)) + for getter in getters: + self.assertRaises(TimeoutError, getter.result) + + @gen_test + def test_async_for(self): + q = queues.Queue() # type: queues.Queue[int] + for i in range(5): + q.put(i) + + async def f(): + results = [] + async for i in q: + results.append(i) + if i == 4: + return results + + results = yield f() + self.assertEqual(results, list(range(5))) + + +class QueuePutTest(AsyncTestCase): + @gen_test + def test_blocking_put(self): + q = queues.Queue() # type: queues.Queue[int] + q.put(0) + self.assertEqual(0, q.get_nowait()) + + def test_nonblocking_put_exception(self): + q = queues.Queue(1) # type: queues.Queue[int] + q.put(0) + self.assertRaises(queues.QueueFull, q.put_nowait, 1) + + @gen_test + def test_put_with_getters(self): + q = queues.Queue() # type: queues.Queue[int] + get0 = q.get() + get1 = q.get() + yield q.put(0) + self.assertEqual(0, (yield get0)) + yield q.put(1) + self.assertEqual(1, (yield get1)) + + @gen_test + def test_nonblocking_put_with_getters(self): + q = queues.Queue() # type: queues.Queue[int] + get0 = q.get() + get1 = q.get() + q.put_nowait(0) + # put_nowait does *not* immediately unblock getters. + yield gen.moment + self.assertEqual(0, (yield get0)) + q.put_nowait(1) + yield gen.moment + self.assertEqual(1, (yield get1)) + + @gen_test + def test_blocking_put_wait(self): + q = queues.Queue(1) # type: queues.Queue[int] + q.put_nowait(0) + + def get_and_discard(): + q.get() + + self.io_loop.call_later(0.01, get_and_discard) + self.io_loop.call_later(0.02, get_and_discard) + futures = [q.put(0), q.put(1)] + self.assertFalse(any(f.done() for f in futures)) + yield futures + + @gen_test + def test_put_timeout(self): + q = queues.Queue(1) # type: queues.Queue[int] + q.put_nowait(0) # Now it's full. + put_timeout = q.put(1, timeout=timedelta(seconds=0.01)) + put = q.put(2) + with self.assertRaises(TimeoutError): + yield put_timeout + + self.assertEqual(0, q.get_nowait()) + # 1 was never put in the queue. + self.assertEqual(2, (yield q.get())) + + # Final get() unblocked this putter. + yield put + + @gen_test + def test_put_timeout_preempted(self): + q = queues.Queue(1) # type: queues.Queue[int] + q.put_nowait(0) + put = q.put(1, timeout=timedelta(seconds=0.01)) + q.get() + yield gen.sleep(0.02) + yield put # No TimeoutError. + + @gen_test + def test_put_clears_timed_out_putters(self): + q = queues.Queue(1) # type: queues.Queue[int] + # First putter succeeds, remainder block. + putters = [q.put(i, timedelta(seconds=0.01)) for i in range(10)] + put = q.put(10) + self.assertEqual(10, len(q._putters)) + yield gen.sleep(0.02) + self.assertEqual(10, len(q._putters)) + self.assertFalse(put.done()) # Final waiter is still active. + q.put(11) # put() clears the waiters. + self.assertEqual(2, len(q._putters)) + for putter in putters[1:]: + self.assertRaises(TimeoutError, putter.result) + + @gen_test + def test_put_clears_timed_out_getters(self): + q = queues.Queue() # type: queues.Queue[int] + getters = [ + asyncio.ensure_future(q.get(timedelta(seconds=0.01))) for _ in range(10) + ] + get = asyncio.ensure_future(q.get()) + q.get() + self.assertEqual(12, len(q._getters)) + yield gen.sleep(0.02) + self.assertEqual(12, len(q._getters)) + self.assertFalse(get.done()) # Final waiters still active. + q.put(0) # put() clears the waiters. 
+ self.assertEqual(1, len(q._getters)) + self.assertEqual(0, (yield get)) + for getter in getters: + self.assertRaises(TimeoutError, getter.result) + + @gen_test + def test_float_maxsize(self): + # If a float is passed for maxsize, a reasonable limit should + # be enforced, instead of being treated as unlimited. + # It happens to be rounded up. + # http://bugs.python.org/issue21723 + q = queues.Queue(maxsize=1.3) # type: ignore + self.assertTrue(q.empty()) + self.assertFalse(q.full()) + q.put_nowait(0) + q.put_nowait(1) + self.assertFalse(q.empty()) + self.assertTrue(q.full()) + self.assertRaises(queues.QueueFull, q.put_nowait, 2) + self.assertEqual(0, q.get_nowait()) + self.assertFalse(q.empty()) + self.assertFalse(q.full()) + + yield q.put(2) + put = q.put(3) + self.assertFalse(put.done()) + self.assertEqual(1, (yield q.get())) + yield put + self.assertTrue(q.full()) + + +class QueueJoinTest(AsyncTestCase): + queue_class = queues.Queue + + def test_task_done_underflow(self): + q = self.queue_class() # type: queues.Queue + self.assertRaises(ValueError, q.task_done) + + @gen_test + def test_task_done(self): + q = self.queue_class() # type: queues.Queue + for i in range(100): + q.put_nowait(i) + + self.accumulator = 0 + + @gen.coroutine + def worker(): + while True: + item = yield q.get() + self.accumulator += item + q.task_done() + yield gen.sleep(random() * 0.01) + + # Two coroutines share work. + worker() + worker() + yield q.join() + self.assertEqual(sum(range(100)), self.accumulator) + + @gen_test + def test_task_done_delay(self): + # Verify it is task_done(), not get(), that unblocks join(). + q = self.queue_class() # type: queues.Queue + q.put_nowait(0) + join = asyncio.ensure_future(q.join()) + self.assertFalse(join.done()) + yield q.get() + self.assertFalse(join.done()) + yield gen.moment + self.assertFalse(join.done()) + q.task_done() + self.assertTrue(join.done()) + + @gen_test + def test_join_empty_queue(self): + q = self.queue_class() # type: queues.Queue + yield q.join() + yield q.join() + + @gen_test + def test_join_timeout(self): + q = self.queue_class() # type: queues.Queue + q.put(0) + with self.assertRaises(TimeoutError): + yield q.join(timeout=timedelta(seconds=0.01)) + + +class PriorityQueueJoinTest(QueueJoinTest): + queue_class = queues.PriorityQueue + + @gen_test + def test_order(self): + q = self.queue_class(maxsize=2) + q.put_nowait((1, "a")) + q.put_nowait((0, "b")) + self.assertTrue(q.full()) + q.put((3, "c")) + q.put((2, "d")) + self.assertEqual((0, "b"), q.get_nowait()) + self.assertEqual((1, "a"), (yield q.get())) + self.assertEqual((2, "d"), q.get_nowait()) + self.assertEqual((3, "c"), (yield q.get())) + self.assertTrue(q.empty()) + + +class LifoQueueJoinTest(QueueJoinTest): + queue_class = queues.LifoQueue + + @gen_test + def test_order(self): + q = self.queue_class(maxsize=2) + q.put_nowait(1) + q.put_nowait(0) + self.assertTrue(q.full()) + q.put(3) + q.put(2) + self.assertEqual(3, q.get_nowait()) + self.assertEqual(2, (yield q.get())) + self.assertEqual(0, q.get_nowait()) + self.assertEqual(1, (yield q.get())) + self.assertTrue(q.empty()) + + +class ProducerConsumerTest(AsyncTestCase): + @gen_test + def test_producer_consumer(self): + q = queues.Queue(maxsize=3) # type: queues.Queue[int] + history = [] + + # We don't yield between get() and task_done(), so get() must wait for + # the next tick. Otherwise we'd immediately call task_done and unblock + # join() before q.put() resumes, and we'd only process the first four + # items. 
+ @gen.coroutine + def consumer(): + while True: + history.append((yield q.get())) + q.task_done() + + @gen.coroutine + def producer(): + for item in range(10): + yield q.put(item) + + consumer() + yield producer() + yield q.join() + self.assertEqual(list(range(10)), history) + + +if __name__ == "__main__": + unittest.main() diff --git a/venv/lib/python3.9/site-packages/tornado/test/resolve_test_helper.py b/venv/lib/python3.9/site-packages/tornado/test/resolve_test_helper.py new file mode 100644 index 00000000..b720a411 --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/test/resolve_test_helper.py @@ -0,0 +1,10 @@ +from tornado.ioloop import IOLoop +from tornado.netutil import ThreadedResolver + +# When this module is imported, it runs getaddrinfo on a thread. Since +# the hostname is unicode, getaddrinfo attempts to import encodings.idna +# but blocks on the import lock. Verify that ThreadedResolver avoids +# this deadlock. + +resolver = ThreadedResolver() +IOLoop.current().run_sync(lambda: resolver.resolve("localhost", 80)) diff --git a/venv/lib/python3.9/site-packages/tornado/test/routing_test.py b/venv/lib/python3.9/site-packages/tornado/test/routing_test.py new file mode 100644 index 00000000..6e02697e --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/test/routing_test.py @@ -0,0 +1,276 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
+ +from tornado.httputil import ( + HTTPHeaders, + HTTPMessageDelegate, + HTTPServerConnectionDelegate, + ResponseStartLine, +) +from tornado.routing import ( + HostMatches, + PathMatches, + ReversibleRouter, + Router, + Rule, + RuleRouter, +) +from tornado.testing import AsyncHTTPTestCase +from tornado.web import Application, HTTPError, RequestHandler +from tornado.wsgi import WSGIContainer + +import typing # noqa: F401 + + +class BasicRouter(Router): + def find_handler(self, request, **kwargs): + class MessageDelegate(HTTPMessageDelegate): + def __init__(self, connection): + self.connection = connection + + def finish(self): + self.connection.write_headers( + ResponseStartLine("HTTP/1.1", 200, "OK"), + HTTPHeaders({"Content-Length": "2"}), + b"OK", + ) + self.connection.finish() + + return MessageDelegate(request.connection) + + +class BasicRouterTestCase(AsyncHTTPTestCase): + def get_app(self): + return BasicRouter() + + def test_basic_router(self): + response = self.fetch("/any_request") + self.assertEqual(response.body, b"OK") + + +resources = {} # type: typing.Dict[str, bytes] + + +class GetResource(RequestHandler): + def get(self, path): + if path not in resources: + raise HTTPError(404) + + self.finish(resources[path]) + + +class PostResource(RequestHandler): + def post(self, path): + resources[path] = self.request.body + + +class HTTPMethodRouter(Router): + def __init__(self, app): + self.app = app + + def find_handler(self, request, **kwargs): + handler = GetResource if request.method == "GET" else PostResource + return self.app.get_handler_delegate(request, handler, path_args=[request.path]) + + +class HTTPMethodRouterTestCase(AsyncHTTPTestCase): + def get_app(self): + return HTTPMethodRouter(Application()) + + def test_http_method_router(self): + response = self.fetch("/post_resource", method="POST", body="data") + self.assertEqual(response.code, 200) + + response = self.fetch("/get_resource") + self.assertEqual(response.code, 404) + + response = self.fetch("/post_resource") + self.assertEqual(response.code, 200) + self.assertEqual(response.body, b"data") + + +def _get_named_handler(handler_name): + class Handler(RequestHandler): + def get(self, *args, **kwargs): + if self.application.settings.get("app_name") is not None: + self.write(self.application.settings["app_name"] + ": ") + + self.finish(handler_name + ": " + self.reverse_url(handler_name)) + + return Handler + + +FirstHandler = _get_named_handler("first_handler") +SecondHandler = _get_named_handler("second_handler") + + +class CustomRouter(ReversibleRouter): + def __init__(self): + super().__init__() + self.routes = {} # type: typing.Dict[str, typing.Any] + + def add_routes(self, routes): + self.routes.update(routes) + + def find_handler(self, request, **kwargs): + if request.path in self.routes: + app, handler = self.routes[request.path] + return app.get_handler_delegate(request, handler) + + def reverse_url(self, name, *args): + handler_path = "/" + name + return handler_path if handler_path in self.routes else None + + +class CustomRouterTestCase(AsyncHTTPTestCase): + def get_app(self): + router = CustomRouter() + + class CustomApplication(Application): + def reverse_url(self, name, *args): + return router.reverse_url(name, *args) + + app1 = CustomApplication(app_name="app1") + app2 = CustomApplication(app_name="app2") + + router.add_routes( + { + "/first_handler": (app1, FirstHandler), + "/second_handler": (app2, SecondHandler), + "/first_handler_second_app": (app2, FirstHandler), + } + ) + + return router + + def 
test_custom_router(self): + response = self.fetch("/first_handler") + self.assertEqual(response.body, b"app1: first_handler: /first_handler") + response = self.fetch("/second_handler") + self.assertEqual(response.body, b"app2: second_handler: /second_handler") + response = self.fetch("/first_handler_second_app") + self.assertEqual(response.body, b"app2: first_handler: /first_handler") + + +class ConnectionDelegate(HTTPServerConnectionDelegate): + def start_request(self, server_conn, request_conn): + class MessageDelegate(HTTPMessageDelegate): + def __init__(self, connection): + self.connection = connection + + def finish(self): + response_body = b"OK" + self.connection.write_headers( + ResponseStartLine("HTTP/1.1", 200, "OK"), + HTTPHeaders({"Content-Length": str(len(response_body))}), + ) + self.connection.write(response_body) + self.connection.finish() + + return MessageDelegate(request_conn) + + +class RuleRouterTest(AsyncHTTPTestCase): + def get_app(self): + app = Application() + + def request_callable(request): + request.connection.write_headers( + ResponseStartLine("HTTP/1.1", 200, "OK"), + HTTPHeaders({"Content-Length": "2"}), + ) + request.connection.write(b"OK") + request.connection.finish() + + router = CustomRouter() + router.add_routes( + {"/nested_handler": (app, _get_named_handler("nested_handler"))} + ) + + app.add_handlers( + ".*", + [ + ( + HostMatches("www.example.com"), + [ + ( + PathMatches("/first_handler"), + "tornado.test.routing_test.SecondHandler", + {}, + "second_handler", + ) + ], + ), + Rule(PathMatches("/.*handler"), router), + Rule(PathMatches("/first_handler"), FirstHandler, name="first_handler"), + Rule(PathMatches("/request_callable"), request_callable), + ("/connection_delegate", ConnectionDelegate()), + ], + ) + + return app + + def test_rule_based_router(self): + response = self.fetch("/first_handler") + self.assertEqual(response.body, b"first_handler: /first_handler") + + response = self.fetch("/first_handler", headers={"Host": "www.example.com"}) + self.assertEqual(response.body, b"second_handler: /first_handler") + + response = self.fetch("/nested_handler") + self.assertEqual(response.body, b"nested_handler: /nested_handler") + + response = self.fetch("/nested_not_found_handler") + self.assertEqual(response.code, 404) + + response = self.fetch("/connection_delegate") + self.assertEqual(response.body, b"OK") + + response = self.fetch("/request_callable") + self.assertEqual(response.body, b"OK") + + response = self.fetch("/404") + self.assertEqual(response.code, 404) + + +class WSGIContainerTestCase(AsyncHTTPTestCase): + def get_app(self): + wsgi_app = WSGIContainer(self.wsgi_app) + + class Handler(RequestHandler): + def get(self, *args, **kwargs): + self.finish(self.reverse_url("tornado")) + + return RuleRouter( + [ + ( + PathMatches("/tornado.*"), + Application([(r"/tornado/test", Handler, {}, "tornado")]), + ), + (PathMatches("/wsgi"), wsgi_app), + ] + ) + + def wsgi_app(self, environ, start_response): + start_response("200 OK", []) + return [b"WSGI"] + + def test_wsgi_container(self): + response = self.fetch("/tornado/test") + self.assertEqual(response.body, b"/tornado/test") + + response = self.fetch("/wsgi") + self.assertEqual(response.body, b"WSGI") + + def test_delegate_not_found(self): + response = self.fetch("/404") + self.assertEqual(response.code, 404) diff --git a/venv/lib/python3.9/site-packages/tornado/test/runtests.py b/venv/lib/python3.9/site-packages/tornado/test/runtests.py new file mode 100644 index 00000000..6075b1e2 --- /dev/null 
+++ b/venv/lib/python3.9/site-packages/tornado/test/runtests.py @@ -0,0 +1,241 @@ +from functools import reduce +import gc +import io +import locale # system locale module, not tornado.locale +import logging +import operator +import textwrap +import sys +import unittest +import warnings + +from tornado.httpclient import AsyncHTTPClient +from tornado.httpserver import HTTPServer +from tornado.netutil import Resolver +from tornado.options import define, add_parse_callback, options + + +TEST_MODULES = [ + "tornado.httputil.doctests", + "tornado.iostream.doctests", + "tornado.util.doctests", + "tornado.test.asyncio_test", + "tornado.test.auth_test", + "tornado.test.autoreload_test", + "tornado.test.concurrent_test", + "tornado.test.curl_httpclient_test", + "tornado.test.escape_test", + "tornado.test.gen_test", + "tornado.test.http1connection_test", + "tornado.test.httpclient_test", + "tornado.test.httpserver_test", + "tornado.test.httputil_test", + "tornado.test.import_test", + "tornado.test.ioloop_test", + "tornado.test.iostream_test", + "tornado.test.locale_test", + "tornado.test.locks_test", + "tornado.test.netutil_test", + "tornado.test.log_test", + "tornado.test.options_test", + "tornado.test.process_test", + "tornado.test.queues_test", + "tornado.test.routing_test", + "tornado.test.simple_httpclient_test", + "tornado.test.tcpclient_test", + "tornado.test.tcpserver_test", + "tornado.test.template_test", + "tornado.test.testing_test", + "tornado.test.twisted_test", + "tornado.test.util_test", + "tornado.test.web_test", + "tornado.test.websocket_test", + "tornado.test.wsgi_test", +] + + +def all(): + return unittest.defaultTestLoader.loadTestsFromNames(TEST_MODULES) + + +def test_runner_factory(stderr): + class TornadoTextTestRunner(unittest.TextTestRunner): + def __init__(self, *args, **kwargs): + kwargs["stream"] = stderr + super().__init__(*args, **kwargs) + + def run(self, test): + result = super().run(test) + if result.skipped: + skip_reasons = set(reason for (test, reason) in result.skipped) + self.stream.write( # type: ignore + textwrap.fill( + "Some tests were skipped because: %s" + % ", ".join(sorted(skip_reasons)) + ) + ) + self.stream.write("\n") # type: ignore + return result + + return TornadoTextTestRunner + + +class LogCounter(logging.Filter): + """Counts the number of WARNING or higher log records.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.info_count = self.warning_count = self.error_count = 0 + + def filter(self, record): + if record.levelno >= logging.ERROR: + self.error_count += 1 + elif record.levelno >= logging.WARNING: + self.warning_count += 1 + elif record.levelno >= logging.INFO: + self.info_count += 1 + return True + + +class CountingStderr(io.IOBase): + def __init__(self, real): + self.real = real + self.byte_count = 0 + + def write(self, data): + self.byte_count += len(data) + return self.real.write(data) + + def flush(self): + return self.real.flush() + + +def main(): + # Be strict about most warnings (This is set in our test running + # scripts to catch import-time warnings, but set it again here to + # be sure). This also turns on warnings that are ignored by + # default, including DeprecationWarnings and python 3.2's + # ResourceWarnings. + warnings.filterwarnings("error") + # setuptools sometimes gives ImportWarnings about things that are on + # sys.path even if they're not being used. 
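+    # (Each later filterwarnings() call is prepended to the filter list and
+    # matched first, so the more specific filters below take precedence over
+    # the blanket "error" filter above.)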
+    warnings.filterwarnings("ignore", category=ImportWarning)
+    # Tornado generally shouldn't use anything deprecated, but some of
+    # our dependencies do (last match wins).
+    warnings.filterwarnings("ignore", category=DeprecationWarning)
+    warnings.filterwarnings("error", category=DeprecationWarning, module=r"tornado\..*")
+    warnings.filterwarnings("ignore", category=PendingDeprecationWarning)
+    warnings.filterwarnings(
+        "error", category=PendingDeprecationWarning, module=r"tornado\..*"
+    )
+    # The unittest module is aggressive about deprecating redundant methods,
+    # leaving some without non-deprecated spellings that work on both
+    # 2.7 and 3.2
+    warnings.filterwarnings(
+        "ignore", category=DeprecationWarning, message="Please use assert.* instead"
+    )
+    warnings.filterwarnings(
+        "ignore",
+        category=PendingDeprecationWarning,
+        message="Please use assert.* instead",
+    )
+    # Twisted 15.0.0 triggers some warnings on py3 with -bb.
+    warnings.filterwarnings("ignore", category=BytesWarning, module=r"twisted\..*")
+    if (3,) < sys.version_info < (3, 6):
+        # Prior to 3.6, async ResourceWarnings were rather noisy
+        # and even
+        # `python3.4 -W error -c 'import asyncio; asyncio.get_event_loop()'`
+        # would generate a warning.
+        warnings.filterwarnings(
+            "ignore", category=ResourceWarning, module=r"asyncio\..*"
+        )
+    # This deprecation warning is introduced in Python 3.8 and is
+    # triggered by pycurl. Unfortunately, because it is raised in the C
+    # layer it can't be filtered by module and we must match the
+    # message text instead (Tornado's C module uses PY_SSIZE_T_CLEAN
+    # so it's not at risk of running into this issue).
+    warnings.filterwarnings(
+        "ignore",
+        category=DeprecationWarning,
+        message="PY_SSIZE_T_CLEAN will be required",
+    )
+
+    logging.getLogger("tornado.access").setLevel(logging.CRITICAL)
+
+    define(
+        "httpclient",
+        type=str,
+        default=None,
+        callback=lambda s: AsyncHTTPClient.configure(
+            s, defaults=dict(allow_ipv6=False)
+        ),
+    )
+    define("httpserver", type=str, default=None, callback=HTTPServer.configure)
+    define("resolver", type=str, default=None, callback=Resolver.configure)
+    define(
+        "debug_gc",
+        type=str,
+        multiple=True,
+        help="A comma-separated list of gc module debug constants, "
+        "e.g. DEBUG_STATS or DEBUG_COLLECTABLE,DEBUG_OBJECTS",
+        callback=lambda values: gc.set_debug(
+            reduce(operator.or_, (getattr(gc, v) for v in values))
+        ),
+    )
+    define(
+        "fail-if-logs",
+        default=True,
+        help="If true, fail the tests if any log output is produced (unless captured by ExpectLog)",
+    )
+
+    def set_locale(x):
+        locale.setlocale(locale.LC_ALL, x)
+
+    define("locale", type=str, default=None, callback=set_locale)
+
+    log_counter = LogCounter()
+    add_parse_callback(lambda: logging.getLogger().handlers[0].addFilter(log_counter))
+
+    # Certain errors (especially "unclosed resource" errors raised in
+    # destructors) go directly to stderr instead of logging. Count
+    # anything written by anything but the test runner as an error.
+    orig_stderr = sys.stderr
+    counting_stderr = CountingStderr(orig_stderr)
+    sys.stderr = counting_stderr  # type: ignore
+
+    import tornado.testing
+
+    kwargs = {}
+
+    # HACK: unittest.main will make its own changes to the warning
+    # configuration, which may conflict with the settings above
+    # or command-line flags like -bb. Passing warnings=False
+    # suppresses this behavior, although this looks like an implementation
+    # detail.
http://bugs.python.org/issue15626 + kwargs["warnings"] = False + + kwargs["testRunner"] = test_runner_factory(orig_stderr) + try: + tornado.testing.main(**kwargs) + finally: + # The tests should run clean; consider it a failure if they + # logged anything at info level or above. + if ( + log_counter.info_count > 0 + or log_counter.warning_count > 0 + or log_counter.error_count > 0 + or counting_stderr.byte_count > 0 + ): + logging.error( + "logged %d infos, %d warnings, %d errors, and %d bytes to stderr", + log_counter.info_count, + log_counter.warning_count, + log_counter.error_count, + counting_stderr.byte_count, + ) + if options.fail_if_logs: + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/venv/lib/python3.9/site-packages/tornado/test/simple_httpclient_test.py b/venv/lib/python3.9/site-packages/tornado/test/simple_httpclient_test.py new file mode 100644 index 00000000..62bd4830 --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/test/simple_httpclient_test.py @@ -0,0 +1,837 @@ +import collections +from contextlib import closing +import errno +import logging +import os +import re +import socket +import ssl +import sys +import typing # noqa: F401 + +from tornado.escape import to_unicode, utf8 +from tornado import gen, version +from tornado.httpclient import AsyncHTTPClient +from tornado.httputil import HTTPHeaders, ResponseStartLine +from tornado.ioloop import IOLoop +from tornado.iostream import UnsatisfiableReadError +from tornado.locks import Event +from tornado.log import gen_log +from tornado.netutil import Resolver, bind_sockets +from tornado.simple_httpclient import ( + SimpleAsyncHTTPClient, + HTTPStreamClosedError, + HTTPTimeoutError, +) +from tornado.test.httpclient_test import ( + ChunkHandler, + CountdownHandler, + HelloWorldHandler, + RedirectHandler, + UserAgentHandler, +) +from tornado.test import httpclient_test +from tornado.testing import ( + AsyncHTTPTestCase, + AsyncHTTPSTestCase, + AsyncTestCase, + ExpectLog, + gen_test, +) +from tornado.test.util import skipOnTravis, skipIfNoIPv6, refusing_port +from tornado.web import RequestHandler, Application, url, stream_request_body + + +class SimpleHTTPClientCommonTestCase(httpclient_test.HTTPClientCommonTestCase): + def get_http_client(self): + client = SimpleAsyncHTTPClient(force_instance=True) + self.assertTrue(isinstance(client, SimpleAsyncHTTPClient)) + return client + + +class TriggerHandler(RequestHandler): + def initialize(self, queue, wake_callback): + self.queue = queue + self.wake_callback = wake_callback + + @gen.coroutine + def get(self): + logging.debug("queuing trigger") + event = Event() + self.queue.append(event.set) + if self.get_argument("wake", "true") == "true": + self.wake_callback() + yield event.wait() + + +class ContentLengthHandler(RequestHandler): + def get(self): + self.stream = self.detach() + IOLoop.current().spawn_callback(self.write_response) + + @gen.coroutine + def write_response(self): + yield self.stream.write( + utf8( + "HTTP/1.0 200 OK\r\nContent-Length: %s\r\n\r\nok" + % self.get_argument("value") + ) + ) + self.stream.close() + + +class HeadHandler(RequestHandler): + def head(self): + self.set_header("Content-Length", "7") + + +class OptionsHandler(RequestHandler): + def options(self): + self.set_header("Access-Control-Allow-Origin", "*") + self.write("ok") + + +class NoContentHandler(RequestHandler): + def get(self): + self.set_status(204) + self.finish() + + +class SeeOtherPostHandler(RequestHandler): + def post(self): + redirect_code = 
int(self.request.body) + assert redirect_code in (302, 303), "unexpected body %r" % self.request.body + self.set_header("Location", "/see_other_get") + self.set_status(redirect_code) + + +class SeeOtherGetHandler(RequestHandler): + def get(self): + if self.request.body: + raise Exception("unexpected body %r" % self.request.body) + self.write("ok") + + +class HostEchoHandler(RequestHandler): + def get(self): + self.write(self.request.headers["Host"]) + + +class NoContentLengthHandler(RequestHandler): + def get(self): + if self.request.version.startswith("HTTP/1"): + # Emulate the old HTTP/1.0 behavior of returning a body with no + # content-length. Tornado handles content-length at the framework + # level so we have to go around it. + stream = self.detach() + stream.write(b"HTTP/1.0 200 OK\r\n\r\n" b"hello") + stream.close() + else: + self.finish("HTTP/1 required") + + +class EchoPostHandler(RequestHandler): + def post(self): + self.write(self.request.body) + + +@stream_request_body +class RespondInPrepareHandler(RequestHandler): + def prepare(self): + self.set_status(403) + self.finish("forbidden") + + +class SimpleHTTPClientTestMixin(object): + def create_client(self, **kwargs): + raise NotImplementedError() + + def get_app(self: typing.Any): + # callable objects to finish pending /trigger requests + self.triggers = ( + collections.deque() + ) # type: typing.Deque[typing.Callable[[], None]] + return Application( + [ + url( + "/trigger", + TriggerHandler, + dict(queue=self.triggers, wake_callback=self.stop), + ), + url("/chunk", ChunkHandler), + url("/countdown/([0-9]+)", CountdownHandler, name="countdown"), + url("/hello", HelloWorldHandler), + url("/content_length", ContentLengthHandler), + url("/head", HeadHandler), + url("/options", OptionsHandler), + url("/no_content", NoContentHandler), + url("/see_other_post", SeeOtherPostHandler), + url("/see_other_get", SeeOtherGetHandler), + url("/host_echo", HostEchoHandler), + url("/no_content_length", NoContentLengthHandler), + url("/echo_post", EchoPostHandler), + url("/respond_in_prepare", RespondInPrepareHandler), + url("/redirect", RedirectHandler), + url("/user_agent", UserAgentHandler), + ], + gzip=True, + ) + + def test_singleton(self: typing.Any): + # Class "constructor" reuses objects on the same IOLoop + self.assertTrue(SimpleAsyncHTTPClient() is SimpleAsyncHTTPClient()) + # unless force_instance is used + self.assertTrue( + SimpleAsyncHTTPClient() is not SimpleAsyncHTTPClient(force_instance=True) + ) + # different IOLoops use different objects + with closing(IOLoop()) as io_loop2: + + async def make_client(): + await gen.sleep(0) + return SimpleAsyncHTTPClient() + + client1 = self.io_loop.run_sync(make_client) + client2 = io_loop2.run_sync(make_client) + self.assertTrue(client1 is not client2) + + def test_connection_limit(self: typing.Any): + with closing(self.create_client(max_clients=2)) as client: + self.assertEqual(client.max_clients, 2) + seen = [] + # Send 4 requests. 
Two can be sent immediately, while the others
+            # will be queued
+            for i in range(4):
+
+                def cb(fut, i=i):
+                    seen.append(i)
+                    self.stop()
+
+                client.fetch(self.get_url("/trigger")).add_done_callback(cb)
+            self.wait(condition=lambda: len(self.triggers) == 2)
+            self.assertEqual(len(client.queue), 2)
+
+            # Finish the first two requests and let the next two through
+            self.triggers.popleft()()
+            self.triggers.popleft()()
+            self.wait(condition=lambda: (len(self.triggers) == 2 and len(seen) == 2))
+            self.assertEqual(set(seen), set([0, 1]))
+            self.assertEqual(len(client.queue), 0)
+
+            # Finish all the pending requests
+            self.triggers.popleft()()
+            self.triggers.popleft()()
+            self.wait(condition=lambda: len(seen) == 4)
+            self.assertEqual(set(seen), set([0, 1, 2, 3]))
+            self.assertEqual(len(self.triggers), 0)
+
+    @gen_test
+    def test_redirect_connection_limit(self: typing.Any):
+        # following redirects should not consume additional connections
+        with closing(self.create_client(max_clients=1)) as client:
+            response = yield client.fetch(self.get_url("/countdown/3"), max_redirects=3)
+            response.rethrow()
+
+    def test_max_redirects(self: typing.Any):
+        response = self.fetch("/countdown/5", max_redirects=3)
+        self.assertEqual(302, response.code)
+        # We requested 5, followed three redirects for 4, 3, 2, then the last
+        # unfollowed redirect is to 1.
+        self.assertTrue(response.request.url.endswith("/countdown/5"))
+        self.assertTrue(response.effective_url.endswith("/countdown/2"))
+        self.assertTrue(response.headers["Location"].endswith("/countdown/1"))
+
+    def test_header_reuse(self: typing.Any):
+        # Apps may reuse a headers object if they are only passing in constant
+        # headers like user-agent. The header object should not be modified.
+        headers = HTTPHeaders({"User-Agent": "Foo"})
+        self.fetch("/hello", headers=headers)
+        self.assertEqual(list(headers.get_all()), [("User-Agent", "Foo")])
+
+    def test_default_user_agent(self: typing.Any):
+        response = self.fetch("/user_agent", method="GET")
+        self.assertEqual(200, response.code)
+        self.assertEqual(response.body.decode(), "Tornado/{}".format(version))
+
+    def test_see_other_redirect(self: typing.Any):
+        for code in (302, 303):
+            response = self.fetch("/see_other_post", method="POST", body="%d" % code)
+            self.assertEqual(200, response.code)
+            self.assertTrue(response.request.url.endswith("/see_other_post"))
+            self.assertTrue(response.effective_url.endswith("/see_other_get"))
+            # request is the original request and is still a POST
+            self.assertEqual("POST", response.request.method)
+
+    @skipOnTravis
+    @gen_test
+    def test_connect_timeout(self: typing.Any):
+        timeout = 0.1
+
+        cleanup_event = Event()
+        test = self
+
+        class TimeoutResolver(Resolver):
+            async def resolve(self, *args, **kwargs):
+                await cleanup_event.wait()
+                # Return something valid so the test doesn't raise during shutdown.
+                return [(socket.AF_INET, ("127.0.0.1", test.get_http_port()))]
+
+        with closing(self.create_client(resolver=TimeoutResolver())) as client:
+            with self.assertRaises(HTTPTimeoutError):
+                yield client.fetch(
+                    self.get_url("/hello"),
+                    connect_timeout=timeout,
+                    request_timeout=3600,
+                    raise_error=True,
+                )
+
+        # Let the hanging coroutine clean up after itself. We need to
+        # wait more than a single IOLoop iteration for the SSL case,
+        # which logs errors on unexpected EOF.
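+        # (Setting the event below unblocks TimeoutResolver.resolve above so
+        # the pending connect attempt can finish cleanly.)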
+ cleanup_event.set() + yield gen.sleep(0.2) + + @skipOnTravis + def test_request_timeout(self: typing.Any): + timeout = 0.1 + if os.name == "nt": + timeout = 0.5 + + with self.assertRaises(HTTPTimeoutError): + self.fetch("/trigger?wake=false", request_timeout=timeout, raise_error=True) + # trigger the hanging request to let it clean up after itself + self.triggers.popleft()() + self.io_loop.run_sync(lambda: gen.sleep(0)) + + @skipIfNoIPv6 + def test_ipv6(self: typing.Any): + [sock] = bind_sockets(0, "::1", family=socket.AF_INET6) + port = sock.getsockname()[1] + self.http_server.add_socket(sock) + url = "%s://[::1]:%d/hello" % (self.get_protocol(), port) + + # ipv6 is currently enabled by default but can be disabled + with self.assertRaises(Exception): + self.fetch(url, allow_ipv6=False, raise_error=True) + + response = self.fetch(url) + self.assertEqual(response.body, b"Hello world!") + + def test_multiple_content_length_accepted(self: typing.Any): + response = self.fetch("/content_length?value=2,2") + self.assertEqual(response.body, b"ok") + response = self.fetch("/content_length?value=2,%202,2") + self.assertEqual(response.body, b"ok") + + with ExpectLog( + gen_log, ".*Multiple unequal Content-Lengths", level=logging.INFO + ): + with self.assertRaises(HTTPStreamClosedError): + self.fetch("/content_length?value=2,4", raise_error=True) + with self.assertRaises(HTTPStreamClosedError): + self.fetch("/content_length?value=2,%202,3", raise_error=True) + + def test_head_request(self: typing.Any): + response = self.fetch("/head", method="HEAD") + self.assertEqual(response.code, 200) + self.assertEqual(response.headers["content-length"], "7") + self.assertFalse(response.body) + + def test_options_request(self: typing.Any): + response = self.fetch("/options", method="OPTIONS") + self.assertEqual(response.code, 200) + self.assertEqual(response.headers["content-length"], "2") + self.assertEqual(response.headers["access-control-allow-origin"], "*") + self.assertEqual(response.body, b"ok") + + def test_no_content(self: typing.Any): + response = self.fetch("/no_content") + self.assertEqual(response.code, 204) + # 204 status shouldn't have a content-length + # + # Tests with a content-length header are included below + # in HTTP204NoContentTestCase. + self.assertNotIn("Content-Length", response.headers) + + def test_host_header(self: typing.Any): + host_re = re.compile(b"^127.0.0.1:[0-9]+$") + response = self.fetch("/host_echo") + self.assertTrue(host_re.match(response.body)) + + url = self.get_url("/host_echo").replace("http://", "http://me:secret@") + response = self.fetch(url) + self.assertTrue(host_re.match(response.body), response.body) + + def test_connection_refused(self: typing.Any): + cleanup_func, port = refusing_port() + self.addCleanup(cleanup_func) + with ExpectLog(gen_log, ".*", required=False): + with self.assertRaises(socket.error) as cm: + self.fetch("http://127.0.0.1:%d/" % port, raise_error=True) + + if sys.platform != "cygwin": + # cygwin returns EPERM instead of ECONNREFUSED here + contains_errno = str(errno.ECONNREFUSED) in str(cm.exception) + if not contains_errno and hasattr(errno, "WSAECONNREFUSED"): + contains_errno = str(errno.WSAECONNREFUSED) in str( # type: ignore + cm.exception + ) + self.assertTrue(contains_errno, cm.exception) + # This is usually "Connection refused". + # On windows, strerror is broken and returns "Unknown error". 
+ expected_message = os.strerror(errno.ECONNREFUSED) + self.assertTrue(expected_message in str(cm.exception), cm.exception) + + def test_queue_timeout(self: typing.Any): + with closing(self.create_client(max_clients=1)) as client: + # Wait for the trigger request to block, not complete. + fut1 = client.fetch(self.get_url("/trigger"), request_timeout=10) + self.wait() + with self.assertRaises(HTTPTimeoutError) as cm: + self.io_loop.run_sync( + lambda: client.fetch( + self.get_url("/hello"), connect_timeout=0.1, raise_error=True + ) + ) + + self.assertEqual(str(cm.exception), "Timeout in request queue") + self.triggers.popleft()() + self.io_loop.run_sync(lambda: fut1) + + def test_no_content_length(self: typing.Any): + response = self.fetch("/no_content_length") + if response.body == b"HTTP/1 required": + self.skipTest("requires HTTP/1.x") + else: + self.assertEqual(b"hello", response.body) + + def sync_body_producer(self, write): + write(b"1234") + write(b"5678") + + @gen.coroutine + def async_body_producer(self, write): + yield write(b"1234") + yield gen.moment + yield write(b"5678") + + def test_sync_body_producer_chunked(self: typing.Any): + response = self.fetch( + "/echo_post", method="POST", body_producer=self.sync_body_producer + ) + response.rethrow() + self.assertEqual(response.body, b"12345678") + + def test_sync_body_producer_content_length(self: typing.Any): + response = self.fetch( + "/echo_post", + method="POST", + body_producer=self.sync_body_producer, + headers={"Content-Length": "8"}, + ) + response.rethrow() + self.assertEqual(response.body, b"12345678") + + def test_async_body_producer_chunked(self: typing.Any): + response = self.fetch( + "/echo_post", method="POST", body_producer=self.async_body_producer + ) + response.rethrow() + self.assertEqual(response.body, b"12345678") + + def test_async_body_producer_content_length(self: typing.Any): + response = self.fetch( + "/echo_post", + method="POST", + body_producer=self.async_body_producer, + headers={"Content-Length": "8"}, + ) + response.rethrow() + self.assertEqual(response.body, b"12345678") + + def test_native_body_producer_chunked(self: typing.Any): + async def body_producer(write): + await write(b"1234") + import asyncio + + await asyncio.sleep(0) + await write(b"5678") + + response = self.fetch("/echo_post", method="POST", body_producer=body_producer) + response.rethrow() + self.assertEqual(response.body, b"12345678") + + def test_native_body_producer_content_length(self: typing.Any): + async def body_producer(write): + await write(b"1234") + import asyncio + + await asyncio.sleep(0) + await write(b"5678") + + response = self.fetch( + "/echo_post", + method="POST", + body_producer=body_producer, + headers={"Content-Length": "8"}, + ) + response.rethrow() + self.assertEqual(response.body, b"12345678") + + def test_100_continue(self: typing.Any): + response = self.fetch( + "/echo_post", method="POST", body=b"1234", expect_100_continue=True + ) + self.assertEqual(response.body, b"1234") + + def test_100_continue_early_response(self: typing.Any): + def body_producer(write): + raise Exception("should not be called") + + response = self.fetch( + "/respond_in_prepare", + method="POST", + body_producer=body_producer, + expect_100_continue=True, + ) + self.assertEqual(response.code, 403) + + def test_streaming_follow_redirects(self: typing.Any): + # When following redirects, header and streaming callbacks + # should only be called for the final result. 
+        # TODO(bdarnell): this test belongs in httpclient_test instead of
+        # simple_httpclient_test, but it fails with the version of libcurl
+        # available on travis-ci. Move it when that has been upgraded
+        # or we have a better framework to skip tests based on curl version.
+        headers = []  # type: typing.List[str]
+        chunk_bytes = []  # type: typing.List[bytes]
+        self.fetch(
+            "/redirect?url=/hello",
+            header_callback=headers.append,
+            streaming_callback=chunk_bytes.append,
+        )
+        chunks = list(map(to_unicode, chunk_bytes))
+        self.assertEqual(chunks, ["Hello world!"])
+        # Make sure we only got one set of headers.
+        num_start_lines = len([h for h in headers if h.startswith("HTTP/")])
+        self.assertEqual(num_start_lines, 1)
+
+
+class SimpleHTTPClientTestCase(SimpleHTTPClientTestMixin, AsyncHTTPTestCase):
+    def setUp(self):
+        super().setUp()
+        self.http_client = self.create_client()
+
+    def create_client(self, **kwargs):
+        return SimpleAsyncHTTPClient(force_instance=True, **kwargs)
+
+
+class SimpleHTTPSClientTestCase(SimpleHTTPClientTestMixin, AsyncHTTPSTestCase):
+    def setUp(self):
+        super().setUp()
+        self.http_client = self.create_client()
+
+    def create_client(self, **kwargs):
+        return SimpleAsyncHTTPClient(
+            force_instance=True, defaults=dict(validate_cert=False), **kwargs
+        )
+
+    def test_ssl_options(self):
+        resp = self.fetch("/hello", ssl_options={"cert_reqs": ssl.CERT_NONE})
+        self.assertEqual(resp.body, b"Hello world!")
+
+    def test_ssl_context(self):
+        ssl_ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
+        ssl_ctx.check_hostname = False
+        ssl_ctx.verify_mode = ssl.CERT_NONE
+        resp = self.fetch("/hello", ssl_options=ssl_ctx)
+        self.assertEqual(resp.body, b"Hello world!")
+
+    def test_ssl_options_handshake_fail(self):
+        with ExpectLog(gen_log, "SSL Error|Uncaught exception", required=False):
+            with self.assertRaises(ssl.SSLError):
+                self.fetch(
+                    "/hello",
+                    ssl_options=dict(cert_reqs=ssl.CERT_REQUIRED),
+                    raise_error=True,
+                )
+
+    def test_ssl_context_handshake_fail(self):
+        with ExpectLog(gen_log, "SSL Error|Uncaught exception"):
+            # CERT_REQUIRED is set by default.
+            ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
+            with self.assertRaises(ssl.SSLError):
+                self.fetch("/hello", ssl_options=ctx, raise_error=True)
+
+    def test_error_logging(self):
+        # No stack traces are logged for SSL errors (in this case,
+        # failure to validate the testing self-signed cert); the error is
+        # still exposed to the caller as an ssl.SSLError.
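+        # (ExpectLog with the ".*" pattern swallows whatever is logged here;
+        # its logged_stack attribute records whether any captured record
+        # included a traceback.)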
+        with ExpectLog(gen_log, ".*") as expect_log:
+            with self.assertRaises(ssl.SSLError):
+                self.fetch("/", validate_cert=True, raise_error=True)
+        self.assertFalse(expect_log.logged_stack)
+
+
+class CreateAsyncHTTPClientTestCase(AsyncTestCase):
+    def setUp(self):
+        super().setUp()
+        self.saved = AsyncHTTPClient._save_configuration()
+
+    def tearDown(self):
+        AsyncHTTPClient._restore_configuration(self.saved)
+        super().tearDown()
+
+    def test_max_clients(self):
+        AsyncHTTPClient.configure(SimpleAsyncHTTPClient)
+        with closing(AsyncHTTPClient(force_instance=True)) as client:
+            self.assertEqual(client.max_clients, 10)  # type: ignore
+        with closing(AsyncHTTPClient(max_clients=11, force_instance=True)) as client:
+            self.assertEqual(client.max_clients, 11)  # type: ignore
+
+        # Now configure max_clients statically and try overriding it
+        # with each way max_clients can be passed
+        AsyncHTTPClient.configure(SimpleAsyncHTTPClient, max_clients=12)
+        with closing(AsyncHTTPClient(force_instance=True)) as client:
+            self.assertEqual(client.max_clients, 12)  # type: ignore
+        with closing(AsyncHTTPClient(max_clients=13, force_instance=True)) as client:
+            self.assertEqual(client.max_clients, 13)  # type: ignore
+        with closing(AsyncHTTPClient(max_clients=14, force_instance=True)) as client:
+            self.assertEqual(client.max_clients, 14)  # type: ignore
+
+
+class HTTP100ContinueTestCase(AsyncHTTPTestCase):
+    def respond_100(self, request):
+        self.http1 = request.version.startswith("HTTP/1.")
+        if not self.http1:
+            request.connection.write_headers(
+                ResponseStartLine("", 200, "OK"), HTTPHeaders()
+            )
+            request.connection.finish()
+            return
+        self.request = request
+        fut = self.request.connection.stream.write(b"HTTP/1.1 100 CONTINUE\r\n\r\n")
+        fut.add_done_callback(self.respond_200)
+
+    def respond_200(self, fut):
+        fut.result()
+        fut = self.request.connection.stream.write(
+            b"HTTP/1.1 200 OK\r\nContent-Length: 1\r\n\r\nA"
+        )
+        fut.add_done_callback(lambda f: self.request.connection.stream.close())
+
+    def get_app(self):
+        # Not a full Application, but works as an HTTPServer callback
+        return self.respond_100
+
+    def test_100_continue(self):
+        res = self.fetch("/")
+        if not self.http1:
+            self.skipTest("requires HTTP/1.x")
+        self.assertEqual(res.body, b"A")
+
+
+class HTTP204NoContentTestCase(AsyncHTTPTestCase):
+    def respond_204(self, request):
+        self.http1 = request.version.startswith("HTTP/1.")
+        if not self.http1:
+            # Close the request cleanly in HTTP/2; it will be skipped anyway.
+            request.connection.write_headers(
+                ResponseStartLine("", 200, "OK"), HTTPHeaders()
+            )
+            request.connection.finish()
+            return
+
+        # A 204 response never has a body, even if it doesn't have a
+        # content-length (which would otherwise mean read-until-close).
+        # Here we simulate a server that sends no content length and does
+        # not close the connection.
+        #
+        # Tests of a 204 response with no Content-Length header are included
+        # in SimpleHTTPClientTestMixin.
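+        # detach() takes the raw IOStream away from the HTTP layer so the
+        # test can hand-write a response that violates normal framing rules.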
+ stream = request.connection.detach() + stream.write(b"HTTP/1.1 204 No content\r\n") + if request.arguments.get("error", [False])[-1]: + stream.write(b"Content-Length: 5\r\n") + else: + stream.write(b"Content-Length: 0\r\n") + stream.write(b"\r\n") + stream.close() + + def get_app(self): + return self.respond_204 + + def test_204_no_content(self): + resp = self.fetch("/") + if not self.http1: + self.skipTest("requires HTTP/1.x") + self.assertEqual(resp.code, 204) + self.assertEqual(resp.body, b"") + + def test_204_invalid_content_length(self): + # 204 status with non-zero content length is malformed + with ExpectLog( + gen_log, ".*Response with code 204 should not have body", level=logging.INFO + ): + with self.assertRaises(HTTPStreamClosedError): + self.fetch("/?error=1", raise_error=True) + if not self.http1: + self.skipTest("requires HTTP/1.x") + if self.http_client.configured_class != SimpleAsyncHTTPClient: + self.skipTest("curl client accepts invalid headers") + + +class HostnameMappingTestCase(AsyncHTTPTestCase): + def setUp(self): + super().setUp() + self.http_client = SimpleAsyncHTTPClient( + hostname_mapping={ + "www.example.com": "127.0.0.1", + ("foo.example.com", 8000): ("127.0.0.1", self.get_http_port()), + } + ) + + def get_app(self): + return Application([url("/hello", HelloWorldHandler)]) + + def test_hostname_mapping(self): + response = self.fetch("http://www.example.com:%d/hello" % self.get_http_port()) + response.rethrow() + self.assertEqual(response.body, b"Hello world!") + + def test_port_mapping(self): + response = self.fetch("http://foo.example.com:8000/hello") + response.rethrow() + self.assertEqual(response.body, b"Hello world!") + + +class ResolveTimeoutTestCase(AsyncHTTPTestCase): + def setUp(self): + self.cleanup_event = Event() + test = self + + # Dummy Resolver subclass that never finishes. + class BadResolver(Resolver): + @gen.coroutine + def resolve(self, *args, **kwargs): + yield test.cleanup_event.wait() + # Return something valid so the test doesn't raise during cleanup. 
+ return [(socket.AF_INET, ("127.0.0.1", test.get_http_port()))] + + super().setUp() + self.http_client = SimpleAsyncHTTPClient(resolver=BadResolver()) + + def get_app(self): + return Application([url("/hello", HelloWorldHandler)]) + + def test_resolve_timeout(self): + with self.assertRaises(HTTPTimeoutError): + self.fetch("/hello", connect_timeout=0.1, raise_error=True) + + # Let the hanging coroutine clean up after itself + self.cleanup_event.set() + self.io_loop.run_sync(lambda: gen.sleep(0)) + + +class MaxHeaderSizeTest(AsyncHTTPTestCase): + def get_app(self): + class SmallHeaders(RequestHandler): + def get(self): + self.set_header("X-Filler", "a" * 100) + self.write("ok") + + class LargeHeaders(RequestHandler): + def get(self): + self.set_header("X-Filler", "a" * 1000) + self.write("ok") + + return Application([("/small", SmallHeaders), ("/large", LargeHeaders)]) + + def get_http_client(self): + return SimpleAsyncHTTPClient(max_header_size=1024) + + def test_small_headers(self): + response = self.fetch("/small") + response.rethrow() + self.assertEqual(response.body, b"ok") + + def test_large_headers(self): + with ExpectLog(gen_log, "Unsatisfiable read", level=logging.INFO): + with self.assertRaises(UnsatisfiableReadError): + self.fetch("/large", raise_error=True) + + +class MaxBodySizeTest(AsyncHTTPTestCase): + def get_app(self): + class SmallBody(RequestHandler): + def get(self): + self.write("a" * 1024 * 64) + + class LargeBody(RequestHandler): + def get(self): + self.write("a" * 1024 * 100) + + return Application([("/small", SmallBody), ("/large", LargeBody)]) + + def get_http_client(self): + return SimpleAsyncHTTPClient(max_body_size=1024 * 64) + + def test_small_body(self): + response = self.fetch("/small") + response.rethrow() + self.assertEqual(response.body, b"a" * 1024 * 64) + + def test_large_body(self): + with ExpectLog( + gen_log, + "Malformed HTTP message from None: Content-Length too long", + level=logging.INFO, + ): + with self.assertRaises(HTTPStreamClosedError): + self.fetch("/large", raise_error=True) + + +class MaxBufferSizeTest(AsyncHTTPTestCase): + def get_app(self): + class LargeBody(RequestHandler): + def get(self): + self.write("a" * 1024 * 100) + + return Application([("/large", LargeBody)]) + + def get_http_client(self): + # 100KB body with 64KB buffer + return SimpleAsyncHTTPClient( + max_body_size=1024 * 100, max_buffer_size=1024 * 64 + ) + + def test_large_body(self): + response = self.fetch("/large") + response.rethrow() + self.assertEqual(response.body, b"a" * 1024 * 100) + + +class ChunkedWithContentLengthTest(AsyncHTTPTestCase): + def get_app(self): + class ChunkedWithContentLength(RequestHandler): + def get(self): + # Add an invalid Transfer-Encoding to the response + self.set_header("Transfer-Encoding", "chunked") + self.write("Hello world") + + return Application([("/chunkwithcl", ChunkedWithContentLength)]) + + def get_http_client(self): + return SimpleAsyncHTTPClient() + + def test_chunked_with_content_length(self): + # Make sure the invalid headers are detected + with ExpectLog( + gen_log, + ( + "Malformed HTTP message from None: Response " + "with both Transfer-Encoding and Content-Length" + ), + level=logging.INFO, + ): + with self.assertRaises(HTTPStreamClosedError): + self.fetch("/chunkwithcl", raise_error=True) diff --git a/venv/lib/python3.9/site-packages/tornado/test/static/dir/index.html b/venv/lib/python3.9/site-packages/tornado/test/static/dir/index.html new file mode 100644 index 00000000..e1cd9d8a --- /dev/null +++ 
b/venv/lib/python3.9/site-packages/tornado/test/static/dir/index.html
@@ -0,0 +1 @@
+this is the index
diff --git a/venv/lib/python3.9/site-packages/tornado/test/static/robots.txt b/venv/lib/python3.9/site-packages/tornado/test/static/robots.txt
new file mode 100644
index 00000000..1f53798b
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/tornado/test/static/robots.txt
@@ -0,0 +1,2 @@
+User-agent: *
+Disallow: /
diff --git a/venv/lib/python3.9/site-packages/tornado/test/static/sample.xml b/venv/lib/python3.9/site-packages/tornado/test/static/sample.xml
new file mode 100644
index 00000000..35ea0e29
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/tornado/test/static/sample.xml
@@ -0,0 +1,23 @@
+<?xml version="1.0"?>
+<data>
+    <country name="Liechtenstein">
+        <rank>1</rank>
+        <year>2008</year>
+        <gdppc>141100</gdppc>
+        <neighbor name="Austria" direction="E"/>
+        <neighbor name="Switzerland" direction="W"/>
+    </country>
+    <country name="Singapore">
+        <rank>4</rank>
+        <year>2011</year>
+        <gdppc>59900</gdppc>
+        <neighbor name="Malaysia" direction="N"/>
+    </country>
+    <country name="Panama">
+        <rank>68</rank>
+        <year>2011</year>
+        <gdppc>13600</gdppc>
+        <neighbor name="Costa Rica" direction="W"/>
+        <neighbor name="Colombia" direction="E"/>
+    </country>
+</data>
diff --git a/venv/lib/python3.9/site-packages/tornado/test/static/sample.xml.bz2 b/venv/lib/python3.9/site-packages/tornado/test/static/sample.xml.bz2
Binary files differ
new file mode 100644
index 00000000..44dc6633
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/tornado/test/static/sample.xml.bz2
diff --git a/venv/lib/python3.9/site-packages/tornado/test/static/sample.xml.gz b/venv/lib/python3.9/site-packages/tornado/test/static/sample.xml.gz
Binary files differ
new file mode 100644
index 00000000..c0fd5e6f
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/tornado/test/static/sample.xml.gz
diff --git a/venv/lib/python3.9/site-packages/tornado/test/static_foo.txt b/venv/lib/python3.9/site-packages/tornado/test/static_foo.txt
new file mode 100644
index 00000000..bdb44f39
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/tornado/test/static_foo.txt
@@ -0,0 +1,2 @@
+This file should not be served by StaticFileHandler even though
+its name starts with "static".
diff --git a/venv/lib/python3.9/site-packages/tornado/test/tcpclient_test.py b/venv/lib/python3.9/site-packages/tornado/test/tcpclient_test.py
new file mode 100644
index 00000000..ecfc82d6
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/tornado/test/tcpclient_test.py
@@ -0,0 +1,439 @@
+#
+# Copyright 2014 Facebook
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
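+# A minimal, hedged sketch of the client API these tests exercise; it is not
+# part of the upstream suite, and the host and port are placeholder
+# assumptions. TCPClient.connect() resolves the address and yields a
+# connected IOStream.
+async def _example_tcpclient_usage(host="127.0.0.1", port=8888):
+    # Local import so the sketch stays self-contained; the module's own
+    # imports follow below.
+    from tornado.tcpclient import TCPClient
+
+    stream = await TCPClient().connect(host, port)
+    try:
+        await stream.write(b"hello")  # write returns a Future; await it
+        return await stream.read_bytes(5)  # read exactly five bytes back
+    finally:
+        stream.close()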
+from contextlib import closing +import getpass +import os +import socket +import unittest + +from tornado.concurrent import Future +from tornado.netutil import bind_sockets, Resolver +from tornado.queues import Queue +from tornado.tcpclient import TCPClient, _Connector +from tornado.tcpserver import TCPServer +from tornado.testing import AsyncTestCase, gen_test +from tornado.test.util import skipIfNoIPv6, refusing_port, skipIfNonUnix +from tornado.gen import TimeoutError + +import typing + +if typing.TYPE_CHECKING: + from tornado.iostream import IOStream # noqa: F401 + from typing import List, Dict, Tuple # noqa: F401 + +# Fake address families for testing. Used in place of AF_INET +# and AF_INET6 because some installations do not have AF_INET6. +AF1, AF2 = 1, 2 + + +class TestTCPServer(TCPServer): + def __init__(self, family): + super().__init__() + self.streams = [] # type: List[IOStream] + self.queue = Queue() # type: Queue[IOStream] + sockets = bind_sockets(0, "localhost", family) + self.add_sockets(sockets) + self.port = sockets[0].getsockname()[1] + + def handle_stream(self, stream, address): + self.streams.append(stream) + self.queue.put(stream) + + def stop(self): + super().stop() + for stream in self.streams: + stream.close() + + +class TCPClientTest(AsyncTestCase): + def setUp(self): + super().setUp() + self.server = None + self.client = TCPClient() + + def start_server(self, family): + if family == socket.AF_UNSPEC and "TRAVIS" in os.environ: + self.skipTest("dual-stack servers often have port conflicts on travis") + self.server = TestTCPServer(family) + return self.server.port + + def stop_server(self): + if self.server is not None: + self.server.stop() + self.server = None + + def tearDown(self): + self.client.close() + self.stop_server() + super().tearDown() + + def skipIfLocalhostV4(self): + # The port used here doesn't matter, but some systems require it + # to be non-zero if we do not also pass AI_PASSIVE. 
+ addrinfo = self.io_loop.run_sync(lambda: Resolver().resolve("localhost", 80)) + families = set(addr[0] for addr in addrinfo) + if socket.AF_INET6 not in families: + self.skipTest("localhost does not resolve to ipv6") + + @gen_test + def do_test_connect(self, family, host, source_ip=None, source_port=None): + port = self.start_server(family) + stream = yield self.client.connect( + host, + port, + source_ip=source_ip, + source_port=source_port, + af=family, + ) + assert self.server is not None + server_stream = yield self.server.queue.get() + with closing(stream): + stream.write(b"hello") + data = yield server_stream.read_bytes(5) + self.assertEqual(data, b"hello") + + def test_connect_ipv4_ipv4(self): + self.do_test_connect(socket.AF_INET, "127.0.0.1") + + def test_connect_ipv4_dual(self): + self.do_test_connect(socket.AF_INET, "localhost") + + @skipIfNoIPv6 + def test_connect_ipv6_ipv6(self): + self.skipIfLocalhostV4() + self.do_test_connect(socket.AF_INET6, "::1") + + @skipIfNoIPv6 + def test_connect_ipv6_dual(self): + self.skipIfLocalhostV4() + if Resolver.configured_class().__name__.endswith("TwistedResolver"): + self.skipTest("TwistedResolver does not support multiple addresses") + self.do_test_connect(socket.AF_INET6, "localhost") + + def test_connect_unspec_ipv4(self): + self.do_test_connect(socket.AF_UNSPEC, "127.0.0.1") + + @skipIfNoIPv6 + def test_connect_unspec_ipv6(self): + self.skipIfLocalhostV4() + self.do_test_connect(socket.AF_UNSPEC, "::1") + + def test_connect_unspec_dual(self): + self.do_test_connect(socket.AF_UNSPEC, "localhost") + + @gen_test + def test_refused_ipv4(self): + cleanup_func, port = refusing_port() + self.addCleanup(cleanup_func) + with self.assertRaises(IOError): + yield self.client.connect("127.0.0.1", port) + + def test_source_ip_fail(self): + """Fail when trying to use the source IP Address '8.8.8.8'.""" + self.assertRaises( + socket.error, + self.do_test_connect, + socket.AF_INET, + "127.0.0.1", + source_ip="8.8.8.8", + ) + + def test_source_ip_success(self): + """Success when trying to use the source IP Address '127.0.0.1'.""" + self.do_test_connect(socket.AF_INET, "127.0.0.1", source_ip="127.0.0.1") + + @skipIfNonUnix + def test_source_port_fail(self): + """Fail when trying to use source port 1.""" + if getpass.getuser() == "root": + # Root can use any port so we can't easily force this to fail. + # This is mainly relevant for docker. + self.skipTest("running as root") + self.assertRaises( + socket.error, + self.do_test_connect, + socket.AF_INET, + "127.0.0.1", + source_port=1, + ) + + @gen_test + def test_connect_timeout(self): + timeout = 0.05 + + class TimeoutResolver(Resolver): + def resolve(self, *args, **kwargs): + return Future() # never completes + + with self.assertRaises(TimeoutError): + yield TCPClient(resolver=TimeoutResolver()).connect( + "1.2.3.4", 12345, timeout=timeout + ) + + +class TestConnectorSplit(unittest.TestCase): + def test_one_family(self): + # These addresses aren't in the right format, but split doesn't care. 
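+        # _Connector.split() partitions an addrinfo list by address family:
+        # entries sharing the family of the first entry become the primary
+        # list and everything else the secondary list, order preserved, e.g.
+        #   split([(AF1, "a"), (AF2, "b")]) -> ([(AF1, "a")], [(AF2, "b")])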
+ primary, secondary = _Connector.split([(AF1, "a"), (AF1, "b")]) + self.assertEqual(primary, [(AF1, "a"), (AF1, "b")]) + self.assertEqual(secondary, []) + + def test_mixed(self): + primary, secondary = _Connector.split( + [(AF1, "a"), (AF2, "b"), (AF1, "c"), (AF2, "d")] + ) + self.assertEqual(primary, [(AF1, "a"), (AF1, "c")]) + self.assertEqual(secondary, [(AF2, "b"), (AF2, "d")]) + + +class ConnectorTest(AsyncTestCase): + class FakeStream(object): + def __init__(self): + self.closed = False + + def close(self): + self.closed = True + + def setUp(self): + super().setUp() + self.connect_futures = ( + {} + ) # type: Dict[Tuple[int, typing.Any], Future[ConnectorTest.FakeStream]] + self.streams = {} # type: Dict[typing.Any, ConnectorTest.FakeStream] + self.addrinfo = [(AF1, "a"), (AF1, "b"), (AF2, "c"), (AF2, "d")] + + def tearDown(self): + # Unless explicitly checked (and popped) in the test, we shouldn't + # be closing any streams + for stream in self.streams.values(): + self.assertFalse(stream.closed) + super().tearDown() + + def create_stream(self, af, addr): + stream = ConnectorTest.FakeStream() + self.streams[addr] = stream + future = Future() # type: Future[ConnectorTest.FakeStream] + self.connect_futures[(af, addr)] = future + return stream, future + + def assert_pending(self, *keys): + self.assertEqual(sorted(self.connect_futures.keys()), sorted(keys)) + + def resolve_connect(self, af, addr, success): + future = self.connect_futures.pop((af, addr)) + if success: + future.set_result(self.streams[addr]) + else: + self.streams.pop(addr) + future.set_exception(IOError()) + # Run the loop to allow callbacks to be run. + self.io_loop.add_callback(self.stop) + self.wait() + + def assert_connector_streams_closed(self, conn): + for stream in conn.streams: + self.assertTrue(stream.closed) + + def start_connect(self, addrinfo): + conn = _Connector(addrinfo, self.create_stream) + # Give it a huge timeout; we'll trigger timeouts manually. + future = conn.start(3600, connect_timeout=self.io_loop.time() + 3600) + return conn, future + + def test_immediate_success(self): + conn, future = self.start_connect(self.addrinfo) + self.assertEqual(list(self.connect_futures.keys()), [(AF1, "a")]) + self.resolve_connect(AF1, "a", True) + self.assertEqual(future.result(), (AF1, "a", self.streams["a"])) + + def test_immediate_failure(self): + # Fail with just one address. + conn, future = self.start_connect([(AF1, "a")]) + self.assert_pending((AF1, "a")) + self.resolve_connect(AF1, "a", False) + self.assertRaises(IOError, future.result) + + def test_one_family_second_try(self): + conn, future = self.start_connect([(AF1, "a"), (AF1, "b")]) + self.assert_pending((AF1, "a")) + self.resolve_connect(AF1, "a", False) + self.assert_pending((AF1, "b")) + self.resolve_connect(AF1, "b", True) + self.assertEqual(future.result(), (AF1, "b", self.streams["b"])) + + def test_one_family_second_try_failure(self): + conn, future = self.start_connect([(AF1, "a"), (AF1, "b")]) + self.assert_pending((AF1, "a")) + self.resolve_connect(AF1, "a", False) + self.assert_pending((AF1, "b")) + self.resolve_connect(AF1, "b", False) + self.assertRaises(IOError, future.result) + + def test_one_family_second_try_timeout(self): + conn, future = self.start_connect([(AF1, "a"), (AF1, "b")]) + self.assert_pending((AF1, "a")) + # trigger the timeout while the first lookup is pending; + # nothing happens. 
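+        # (with a single family the secondary address list is empty, so
+        # on_timeout() has nothing to start)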
+ conn.on_timeout() + self.assert_pending((AF1, "a")) + self.resolve_connect(AF1, "a", False) + self.assert_pending((AF1, "b")) + self.resolve_connect(AF1, "b", True) + self.assertEqual(future.result(), (AF1, "b", self.streams["b"])) + + def test_two_families_immediate_failure(self): + conn, future = self.start_connect(self.addrinfo) + self.assert_pending((AF1, "a")) + self.resolve_connect(AF1, "a", False) + self.assert_pending((AF1, "b"), (AF2, "c")) + self.resolve_connect(AF1, "b", False) + self.resolve_connect(AF2, "c", True) + self.assertEqual(future.result(), (AF2, "c", self.streams["c"])) + + def test_two_families_timeout(self): + conn, future = self.start_connect(self.addrinfo) + self.assert_pending((AF1, "a")) + conn.on_timeout() + self.assert_pending((AF1, "a"), (AF2, "c")) + self.resolve_connect(AF2, "c", True) + self.assertEqual(future.result(), (AF2, "c", self.streams["c"])) + # resolving 'a' after the connection has completed doesn't start 'b' + self.resolve_connect(AF1, "a", False) + self.assert_pending() + + def test_success_after_timeout(self): + conn, future = self.start_connect(self.addrinfo) + self.assert_pending((AF1, "a")) + conn.on_timeout() + self.assert_pending((AF1, "a"), (AF2, "c")) + self.resolve_connect(AF1, "a", True) + self.assertEqual(future.result(), (AF1, "a", self.streams["a"])) + # resolving 'c' after completion closes the connection. + self.resolve_connect(AF2, "c", True) + self.assertTrue(self.streams.pop("c").closed) + + def test_all_fail(self): + conn, future = self.start_connect(self.addrinfo) + self.assert_pending((AF1, "a")) + conn.on_timeout() + self.assert_pending((AF1, "a"), (AF2, "c")) + self.resolve_connect(AF2, "c", False) + self.assert_pending((AF1, "a"), (AF2, "d")) + self.resolve_connect(AF2, "d", False) + # one queue is now empty + self.assert_pending((AF1, "a")) + self.resolve_connect(AF1, "a", False) + self.assert_pending((AF1, "b")) + self.assertFalse(future.done()) + self.resolve_connect(AF1, "b", False) + self.assertRaises(IOError, future.result) + + def test_one_family_timeout_after_connect_timeout(self): + conn, future = self.start_connect([(AF1, "a"), (AF1, "b")]) + self.assert_pending((AF1, "a")) + conn.on_connect_timeout() + # the connector will close all streams on connect timeout, we + # should explicitly pop the connect_future. + self.connect_futures.pop((AF1, "a")) + self.assertTrue(self.streams.pop("a").closed) + conn.on_timeout() + # if the future is set with TimeoutError, we will not iterate next + # possible address. 
+        self.assert_pending()
+        self.assertEqual(len(conn.streams), 1)
+        self.assert_connector_streams_closed(conn)
+        self.assertRaises(TimeoutError, future.result)
+
+    def test_one_family_success_before_connect_timeout(self):
+        conn, future = self.start_connect([(AF1, "a"), (AF1, "b")])
+        self.assert_pending((AF1, "a"))
+        self.resolve_connect(AF1, "a", True)
+        conn.on_connect_timeout()
+        self.assert_pending()
+        self.assertEqual(self.streams["a"].closed, False)
+        # the successful stream has already been popped from conn.streams
+        self.assertEqual(len(conn.streams), 0)
+        # streams left in the connector should be closed after the connect timeout
+        self.assert_connector_streams_closed(conn)
+        self.assertEqual(future.result(), (AF1, "a", self.streams["a"]))
+
+    def test_one_family_second_try_after_connect_timeout(self):
+        conn, future = self.start_connect([(AF1, "a"), (AF1, "b")])
+        self.assert_pending((AF1, "a"))
+        self.resolve_connect(AF1, "a", False)
+        self.assert_pending((AF1, "b"))
+        conn.on_connect_timeout()
+        self.connect_futures.pop((AF1, "b"))
+        self.assertTrue(self.streams.pop("b").closed)
+        self.assert_pending()
+        self.assertEqual(len(conn.streams), 2)
+        self.assert_connector_streams_closed(conn)
+        self.assertRaises(TimeoutError, future.result)
+
+    def test_one_family_second_try_failure_before_connect_timeout(self):
+        conn, future = self.start_connect([(AF1, "a"), (AF1, "b")])
+        self.assert_pending((AF1, "a"))
+        self.resolve_connect(AF1, "a", False)
+        self.assert_pending((AF1, "b"))
+        self.resolve_connect(AF1, "b", False)
+        conn.on_connect_timeout()
+        self.assert_pending()
+        self.assertEqual(len(conn.streams), 2)
+        self.assert_connector_streams_closed(conn)
+        self.assertRaises(IOError, future.result)
+
+    def test_two_family_timeout_before_connect_timeout(self):
+        conn, future = self.start_connect(self.addrinfo)
+        self.assert_pending((AF1, "a"))
+        conn.on_timeout()
+        self.assert_pending((AF1, "a"), (AF2, "c"))
+        conn.on_connect_timeout()
+        self.connect_futures.pop((AF1, "a"))
+        self.assertTrue(self.streams.pop("a").closed)
+        self.connect_futures.pop((AF2, "c"))
+        self.assertTrue(self.streams.pop("c").closed)
+        self.assert_pending()
+        self.assertEqual(len(conn.streams), 2)
+        self.assert_connector_streams_closed(conn)
+        self.assertRaises(TimeoutError, future.result)
+
+    def test_two_family_success_after_timeout(self):
+        conn, future = self.start_connect(self.addrinfo)
+        self.assert_pending((AF1, "a"))
+        conn.on_timeout()
+        self.assert_pending((AF1, "a"), (AF2, "c"))
+        self.resolve_connect(AF1, "a", True)
+        # if one of the streams succeeds, the connector closes all the others
+        self.connect_futures.pop((AF2, "c"))
+        self.assertTrue(self.streams.pop("c").closed)
+        self.assert_pending()
+        self.assertEqual(len(conn.streams), 1)
+        self.assert_connector_streams_closed(conn)
+        self.assertEqual(future.result(), (AF1, "a", self.streams["a"]))
+
+    def test_two_family_timeout_after_connect_timeout(self):
+        conn, future = self.start_connect(self.addrinfo)
+        self.assert_pending((AF1, "a"))
+        conn.on_connect_timeout()
+        self.connect_futures.pop((AF1, "a"))
+        self.assertTrue(self.streams.pop("a").closed)
+        self.assert_pending()
+        conn.on_timeout()
+        # once the future has been set with TimeoutError, the connector will
+        # not try the secondary address
+ self.assert_pending() + self.assertEqual(len(conn.streams), 1) + self.assert_connector_streams_closed(conn) + self.assertRaises(TimeoutError, future.result) diff --git a/venv/lib/python3.9/site-packages/tornado/test/tcpserver_test.py b/venv/lib/python3.9/site-packages/tornado/test/tcpserver_test.py new file mode 100644 index 00000000..c636c858 --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/test/tcpserver_test.py @@ -0,0 +1,230 @@ +import socket +import subprocess +import sys +import textwrap +import unittest + +from tornado import gen +from tornado.iostream import IOStream +from tornado.log import app_log +from tornado.tcpserver import TCPServer +from tornado.test.util import skipIfNonUnix +from tornado.testing import AsyncTestCase, ExpectLog, bind_unused_port, gen_test + +from typing import Tuple + + +class TCPServerTest(AsyncTestCase): + @gen_test + def test_handle_stream_coroutine_logging(self): + # handle_stream may be a coroutine and any exception in its + # Future will be logged. + class TestServer(TCPServer): + @gen.coroutine + def handle_stream(self, stream, address): + yield stream.read_bytes(len(b"hello")) + stream.close() + 1 / 0 + + server = client = None + try: + sock, port = bind_unused_port() + server = TestServer() + server.add_socket(sock) + client = IOStream(socket.socket()) + with ExpectLog(app_log, "Exception in callback"): + yield client.connect(("localhost", port)) + yield client.write(b"hello") + yield client.read_until_close() + yield gen.moment + finally: + if server is not None: + server.stop() + if client is not None: + client.close() + + @gen_test + def test_handle_stream_native_coroutine(self): + # handle_stream may be a native coroutine. + + class TestServer(TCPServer): + async def handle_stream(self, stream, address): + stream.write(b"data") + stream.close() + + sock, port = bind_unused_port() + server = TestServer() + server.add_socket(sock) + client = IOStream(socket.socket()) + yield client.connect(("localhost", port)) + result = yield client.read_until_close() + self.assertEqual(result, b"data") + server.stop() + client.close() + + def test_stop_twice(self): + sock, port = bind_unused_port() + server = TCPServer() + server.add_socket(sock) + server.stop() + server.stop() + + @gen_test + def test_stop_in_callback(self): + # Issue #2069: calling server.stop() in a loop callback should not + # raise EBADF when the loop handles other server connection + # requests in the same loop iteration + + class TestServer(TCPServer): + @gen.coroutine + def handle_stream(self, stream, address): + server.stop() # type: ignore + yield stream.read_until_close() + + sock, port = bind_unused_port() + server = TestServer() + server.add_socket(sock) + server_addr = ("localhost", port) + N = 40 + clients = [IOStream(socket.socket()) for i in range(N)] + connected_clients = [] + + @gen.coroutine + def connect(c): + try: + yield c.connect(server_addr) + except EnvironmentError: + pass + else: + connected_clients.append(c) + + yield [connect(c) for c in clients] + + self.assertGreater(len(connected_clients), 0, "all clients failed connecting") + try: + if len(connected_clients) == N: + # Ideally we'd make the test deterministic, but we're testing + # for a race condition in combination with the system's TCP stack... 
+ self.skipTest( + "at least one client should fail connecting " + "for the test to be meaningful" + ) + finally: + for c in connected_clients: + c.close() + + # Here tearDown() would re-raise the EBADF encountered in the IO loop + + +@skipIfNonUnix +class TestMultiprocess(unittest.TestCase): + # These tests verify that the two multiprocess examples from the + # TCPServer docs work. Both tests start a server with three worker + # processes, each of which prints its task id to stdout (a single + # byte, so we don't have to worry about atomicity of the shared + # stdout stream) and then exits. + def run_subproc(self, code: str) -> Tuple[str, str]: + try: + result = subprocess.run( + [sys.executable, "-Werror::DeprecationWarning"], + capture_output=True, + input=code, + encoding="utf8", + check=True, + ) + except subprocess.CalledProcessError as e: + raise RuntimeError( + f"Process returned {e.returncode} stdout={e.stdout} stderr={e.stderr}" + ) from e + return result.stdout, result.stderr + + def test_listen_single(self): + # As a sanity check, run the single-process version through this test + # harness too. + code = textwrap.dedent( + """ + import asyncio + from tornado.tcpserver import TCPServer + + async def main(): + server = TCPServer() + server.listen(0, address='127.0.0.1') + + asyncio.run(main()) + print('012', end='') + """ + ) + out, err = self.run_subproc(code) + self.assertEqual("".join(sorted(out)), "012") + self.assertEqual(err, "") + + def test_bind_start(self): + code = textwrap.dedent( + """ + import warnings + + from tornado.ioloop import IOLoop + from tornado.process import task_id + from tornado.tcpserver import TCPServer + + warnings.simplefilter("ignore", DeprecationWarning) + + server = TCPServer() + server.bind(0, address='127.0.0.1') + server.start(3) + IOLoop.current().run_sync(lambda: None) + print(task_id(), end='') + """ + ) + out, err = self.run_subproc(code) + self.assertEqual("".join(sorted(out)), "012") + self.assertEqual(err, "") + + def test_add_sockets(self): + code = textwrap.dedent( + """ + import asyncio + from tornado.netutil import bind_sockets + from tornado.process import fork_processes, task_id + from tornado.ioloop import IOLoop + from tornado.tcpserver import TCPServer + + sockets = bind_sockets(0, address='127.0.0.1') + fork_processes(3) + async def post_fork_main(): + server = TCPServer() + server.add_sockets(sockets) + asyncio.run(post_fork_main()) + print(task_id(), end='') + """ + ) + out, err = self.run_subproc(code) + self.assertEqual("".join(sorted(out)), "012") + self.assertEqual(err, "") + + def test_listen_multi_reuse_port(self): + code = textwrap.dedent( + """ + import asyncio + import socket + from tornado.netutil import bind_sockets + from tornado.process import task_id, fork_processes + from tornado.tcpserver import TCPServer + + # Pick an unused port which we will be able to bind to multiple times. 
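+ # (Roughly what reuse_port=True implies at the socket level, sketched
+ # with the stdlib only -- not Tornado's actual implementation:
+ #     s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ #     s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
+ #     s.bind(('127.0.0.1', 0))
+ # SO_REUSEPORT is what lets every forked child bind the same port.)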
+ (sock,) = bind_sockets(0, address='127.0.0.1', + family=socket.AF_INET, reuse_port=True) + port = sock.getsockname()[1] + + fork_processes(3) + + async def main(): + server = TCPServer() + server.listen(port, address='127.0.0.1', reuse_port=True) + asyncio.run(main()) + print(task_id(), end='') + """ + ) + out, err = self.run_subproc(code) + self.assertEqual("".join(sorted(out)), "012") + self.assertEqual(err, "") diff --git a/venv/lib/python3.9/site-packages/tornado/test/template_test.py b/venv/lib/python3.9/site-packages/tornado/test/template_test.py new file mode 100644 index 00000000..801de50b --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/test/template_test.py @@ -0,0 +1,536 @@ +import os +import traceback +import unittest + +from tornado.escape import utf8, native_str, to_unicode +from tornado.template import Template, DictLoader, ParseError, Loader +from tornado.util import ObjectDict + +import typing # noqa: F401 + + +class TemplateTest(unittest.TestCase): + def test_simple(self): + template = Template("Hello {{ name }}!") + self.assertEqual(template.generate(name="Ben"), b"Hello Ben!") + + def test_bytes(self): + template = Template("Hello {{ name }}!") + self.assertEqual(template.generate(name=utf8("Ben")), b"Hello Ben!") + + def test_expressions(self): + template = Template("2 + 2 = {{ 2 + 2 }}") + self.assertEqual(template.generate(), b"2 + 2 = 4") + + def test_comment(self): + template = Template("Hello{# TODO i18n #} {{ name }}!") + self.assertEqual(template.generate(name=utf8("Ben")), b"Hello Ben!") + + def test_include(self): + loader = DictLoader( + { + "index.html": '{% include "header.html" %}\nbody text', + "header.html": "header text", + } + ) + self.assertEqual( + loader.load("index.html").generate(), b"header text\nbody text" + ) + + def test_extends(self): + loader = DictLoader( + { + "base.html": """\ +<title>{% block title %}default title{% end %}</title> +<body>{% block body %}default body{% end %}</body> +""", + "page.html": """\ +{% extends "base.html" %} +{% block title %}page title{% end %} +{% block body %}page body{% end %} +""", + } + ) + self.assertEqual( + loader.load("page.html").generate(), + b"<title>page title</title>\n<body>page body</body>\n", + ) + + def test_relative_load(self): + loader = DictLoader( + { + "a/1.html": "{% include '2.html' %}", + "a/2.html": "{% include '../b/3.html' %}", + "b/3.html": "ok", + } + ) + self.assertEqual(loader.load("a/1.html").generate(), b"ok") + + def test_escaping(self): + self.assertRaises(ParseError, lambda: Template("{{")) + self.assertRaises(ParseError, lambda: Template("{%")) + self.assertEqual(Template("{{!").generate(), b"{{") + self.assertEqual(Template("{%!").generate(), b"{%") + self.assertEqual(Template("{#!").generate(), b"{#") + self.assertEqual( + Template("{{ 'expr' }} {{!jquery expr}}").generate(), + b"expr {{jquery expr}}", + ) + + def test_unicode_template(self): + template = Template(utf8("\u00e9")) + self.assertEqual(template.generate(), utf8("\u00e9")) + + def test_unicode_literal_expression(self): + # Unicode literals should be usable in templates. Note that this + # test simulates unicode characters appearing directly in the + # template file (with utf8 encoding), i.e. \u escapes would not + # be used in the template file itself. 
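+ # (For example, utf8("\u00e9") produces the two-byte UTF-8 sequence
+ # b"\xc3\xa9", which is exactly what an editor saving the template
+ # file as UTF-8 would write.)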
+ template = Template(utf8('{{ "\u00e9" }}')) + self.assertEqual(template.generate(), utf8("\u00e9")) + + def test_custom_namespace(self): + loader = DictLoader( + {"test.html": "{{ inc(5) }}"}, namespace={"inc": lambda x: x + 1} + ) + self.assertEqual(loader.load("test.html").generate(), b"6") + + def test_apply(self): + def upper(s): + return s.upper() + + template = Template(utf8("{% apply upper %}foo{% end %}")) + self.assertEqual(template.generate(upper=upper), b"FOO") + + def test_unicode_apply(self): + def upper(s): + return to_unicode(s).upper() + + template = Template(utf8("{% apply upper %}foo \u00e9{% end %}")) + self.assertEqual(template.generate(upper=upper), utf8("FOO \u00c9")) + + def test_bytes_apply(self): + def upper(s): + return utf8(to_unicode(s).upper()) + + template = Template(utf8("{% apply upper %}foo \u00e9{% end %}")) + self.assertEqual(template.generate(upper=upper), utf8("FOO \u00c9")) + + def test_if(self): + template = Template(utf8("{% if x > 4 %}yes{% else %}no{% end %}")) + self.assertEqual(template.generate(x=5), b"yes") + self.assertEqual(template.generate(x=3), b"no") + + def test_if_empty_body(self): + template = Template(utf8("{% if True %}{% else %}{% end %}")) + self.assertEqual(template.generate(), b"") + + def test_try(self): + template = Template( + utf8( + """{% try %} +try{% set y = 1/x %} +{% except %}-except +{% else %}-else +{% finally %}-finally +{% end %}""" + ) + ) + self.assertEqual(template.generate(x=1), b"\ntry\n-else\n-finally\n") + self.assertEqual(template.generate(x=0), b"\ntry-except\n-finally\n") + + def test_comment_directive(self): + template = Template(utf8("{% comment blah blah %}foo")) + self.assertEqual(template.generate(), b"foo") + + def test_break_continue(self): + template = Template( + utf8( + """\ +{% for i in range(10) %} + {% if i == 2 %} + {% continue %} + {% end %} + {{ i }} + {% if i == 6 %} + {% break %} + {% end %} +{% end %}""" + ) + ) + result = template.generate() + # remove extraneous whitespace + result = b"".join(result.split()) + self.assertEqual(result, b"013456") + + def test_break_outside_loop(self): + try: + Template(utf8("{% break %}")) + raise Exception("Did not get expected exception") + except ParseError: + pass + + def test_break_in_apply(self): + # This test verifies current behavior, although of course it would + # be nice if apply didn't cause seemingly unrelated breakage + try: + Template( + utf8("{% for i in [] %}{% apply foo %}{% break %}{% end %}{% end %}") + ) + raise Exception("Did not get expected exception") + except ParseError: + pass + + @unittest.skip("no testable future imports") + def test_no_inherit_future(self): + # TODO(bdarnell): make a test like this for one of the future + # imports available in python 3. Unfortunately they're harder + # to use in a template than division was. + + # This file has from __future__ import division... 
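+ # (On Python 3 true division is always in effect: 1 / 2 == 0.5, while
+ # floor division is spelled 1 // 2 == 0 -- so the assertion below holds
+ # with or without the legacy __future__ import.)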
+ self.assertEqual(1 / 2, 0.5) + # ...but the template doesn't + template = Template("{{ 1 / 2 }}") + self.assertEqual(template.generate(), "0") + + def test_non_ascii_name(self): + loader = DictLoader({"t\u00e9st.html": "hello"}) + self.assertEqual(loader.load("t\u00e9st.html").generate(), b"hello") + + +class StackTraceTest(unittest.TestCase): + def test_error_line_number_expression(self): + loader = DictLoader( + { + "test.html": """one +two{{1/0}} +three + """ + } + ) + try: + loader.load("test.html").generate() + self.fail("did not get expected exception") + except ZeroDivisionError: + self.assertTrue("# test.html:2" in traceback.format_exc()) + + def test_error_line_number_directive(self): + loader = DictLoader( + { + "test.html": """one +two{%if 1/0%} +three{%end%} + """ + } + ) + try: + loader.load("test.html").generate() + self.fail("did not get expected exception") + except ZeroDivisionError: + self.assertTrue("# test.html:2" in traceback.format_exc()) + + def test_error_line_number_module(self): + loader = None # type: typing.Optional[DictLoader] + + def load_generate(path, **kwargs): + assert loader is not None + return loader.load(path).generate(**kwargs) + + loader = DictLoader( + {"base.html": "{% module Template('sub.html') %}", "sub.html": "{{1/0}}"}, + namespace={"_tt_modules": ObjectDict(Template=load_generate)}, + ) + try: + loader.load("base.html").generate() + self.fail("did not get expected exception") + except ZeroDivisionError: + exc_stack = traceback.format_exc() + self.assertTrue("# base.html:1" in exc_stack) + self.assertTrue("# sub.html:1" in exc_stack) + + def test_error_line_number_include(self): + loader = DictLoader( + {"base.html": "{% include 'sub.html' %}", "sub.html": "{{1/0}}"} + ) + try: + loader.load("base.html").generate() + self.fail("did not get expected exception") + except ZeroDivisionError: + self.assertTrue("# sub.html:1 (via base.html:1)" in traceback.format_exc()) + + def test_error_line_number_extends_base_error(self): + loader = DictLoader( + {"base.html": "{{1/0}}", "sub.html": "{% extends 'base.html' %}"} + ) + try: + loader.load("sub.html").generate() + self.fail("did not get expected exception") + except ZeroDivisionError: + exc_stack = traceback.format_exc() + self.assertTrue("# base.html:1" in exc_stack) + + def test_error_line_number_extends_sub_error(self): + loader = DictLoader( + { + "base.html": "{% block 'block' %}{% end %}", + "sub.html": """ +{% extends 'base.html' %} +{% block 'block' %} +{{1/0}} +{% end %} + """, + } + ) + try: + loader.load("sub.html").generate() + self.fail("did not get expected exception") + except ZeroDivisionError: + self.assertTrue("# sub.html:4 (via base.html:1)" in traceback.format_exc()) + + def test_multi_includes(self): + loader = DictLoader( + { + "a.html": "{% include 'b.html' %}", + "b.html": "{% include 'c.html' %}", + "c.html": "{{1/0}}", + } + ) + try: + loader.load("a.html").generate() + self.fail("did not get expected exception") + except ZeroDivisionError: + self.assertTrue( + "# c.html:1 (via b.html:1, a.html:1)" in traceback.format_exc() + ) + + +class ParseErrorDetailTest(unittest.TestCase): + def test_details(self): + loader = DictLoader({"foo.html": "\n\n{{"}) + with self.assertRaises(ParseError) as cm: + loader.load("foo.html") + self.assertEqual("Missing end expression }} at foo.html:3", str(cm.exception)) + self.assertEqual("foo.html", cm.exception.filename) + self.assertEqual(3, cm.exception.lineno) + + def test_custom_parse_error(self): + # Make sure that ParseErrors remain 
compatible with their
+ # pre-4.3 signature.
+ self.assertEqual("asdf at None:0", str(ParseError("asdf")))
+
+
+class AutoEscapeTest(unittest.TestCase):
+ def setUp(self):
+ self.templates = {
+ "escaped.html": "{% autoescape xhtml_escape %}{{ name }}",
+ "unescaped.html": "{% autoescape None %}{{ name }}",
+ "default.html": "{{ name }}",
+ "include.html": """\
+escaped: {% include 'escaped.html' %}
+unescaped: {% include 'unescaped.html' %}
+default: {% include 'default.html' %}
+""",
+ "escaped_block.html": """\
+{% autoescape xhtml_escape %}\
+{% block name %}base: {{ name }}{% end %}""",
+ "unescaped_block.html": """\
+{% autoescape None %}\
+{% block name %}base: {{ name }}{% end %}""",
+ # Extend a base template with different autoescape policy,
+ # with and without overriding the base's blocks
+ "escaped_extends_unescaped.html": """\
+{% autoescape xhtml_escape %}\
+{% extends "unescaped_block.html" %}""",
+ "escaped_overrides_unescaped.html": """\
+{% autoescape xhtml_escape %}\
+{% extends "unescaped_block.html" %}\
+{% block name %}extended: {{ name }}{% end %}""",
+ "unescaped_extends_escaped.html": """\
+{% autoescape None %}\
+{% extends "escaped_block.html" %}""",
+ "unescaped_overrides_escaped.html": """\
+{% autoescape None %}\
+{% extends "escaped_block.html" %}\
+{% block name %}extended: {{ name }}{% end %}""",
+ "raw_expression.html": """\
+{% autoescape xhtml_escape %}\
+expr: {{ name }}
+raw: {% raw name %}""",
+ }
+
+ def test_default_off(self):
+ loader = DictLoader(self.templates, autoescape=None)
+ name = "Bobby <table>s"
+ self.assertEqual(
+ loader.load("escaped.html").generate(name=name), b"Bobby &lt;table&gt;s"
+ )
+ self.assertEqual(
+ loader.load("unescaped.html").generate(name=name), b"Bobby <table>s"
+ )
+ self.assertEqual(
+ loader.load("default.html").generate(name=name), b"Bobby <table>s"
+ )
+
+ self.assertEqual(
+ loader.load("include.html").generate(name=name),
+ b"escaped: Bobby &lt;table&gt;s\n"
+ b"unescaped: Bobby <table>s\n"
+ b"default: Bobby <table>s\n",
+ )
+
+ def test_default_on(self):
+ loader = DictLoader(self.templates, autoescape="xhtml_escape")
+ name = "Bobby <table>s"
+ self.assertEqual(
+ loader.load("escaped.html").generate(name=name), b"Bobby &lt;table&gt;s"
+ )
+ self.assertEqual(
+ loader.load("unescaped.html").generate(name=name), b"Bobby <table>s"
+ )
+ self.assertEqual(
+ loader.load("default.html").generate(name=name), b"Bobby &lt;table&gt;s"
+ )
+
+ self.assertEqual(
+ loader.load("include.html").generate(name=name),
+ b"escaped: Bobby &lt;table&gt;s\n"
+ b"unescaped: Bobby <table>s\n"
+ b"default: Bobby &lt;table&gt;s\n",
+ )
+
+ def test_unextended_block(self):
+ loader = DictLoader(self.templates)
+ name = "<script>"
+ self.assertEqual(
+ loader.load("escaped_block.html").generate(name=name),
+ b"base: &lt;script&gt;",
+ )
+ self.assertEqual(
+ loader.load("unescaped_block.html").generate(name=name), b"base: <script>"
+ )
+
+ def test_extended_block(self):
+ loader = DictLoader(self.templates)
+
+ def render(name):
+ return loader.load(name).generate(name="<script>")
+
+ self.assertEqual(render("escaped_extends_unescaped.html"), b"base: <script>")
+ self.assertEqual(
+ render("escaped_overrides_unescaped.html"), b"extended: &lt;script&gt;"
+ )
+
+ self.assertEqual(
+ render("unescaped_extends_escaped.html"), b"base: &lt;script&gt;"
+ )
+ self.assertEqual(
+ render("unescaped_overrides_escaped.html"), b"extended: <script>"
+ )
+
+ def test_raw_expression(self):
+ loader = DictLoader(self.templates)
+
+ def render(name):
+ return loader.load(name).generate(name='<>&"')
+
+ 
self.assertEqual(
+ render("raw_expression.html"), b"expr: &lt;&gt;&amp;&quot;\n" b'raw: <>&"'
+ )
+
+ def test_custom_escape(self):
+ loader = DictLoader({"foo.py": "{% autoescape py_escape %}s = {{ name }}\n"})
+
+ def py_escape(s):
+ self.assertEqual(type(s), bytes)
+ return repr(native_str(s))
+
+ def render(template, name):
+ return loader.load(template).generate(py_escape=py_escape, name=name)
+
+ self.assertEqual(render("foo.py", "<html>"), b"s = '<html>'\n")
+ self.assertEqual(render("foo.py", "';sys.exit()"), b"""s = "';sys.exit()"\n""")
+ self.assertEqual(
+ render("foo.py", ["not a string"]), b"""s = "['not a string']"\n"""
+ )
+
+ def test_manual_minimize_whitespace(self):
+ # Whitespace including newlines is allowed within template tags
+ # and directives, and this is one way to avoid long lines while
+ # keeping extra whitespace out of the rendered output.
+ loader = DictLoader(
+ {
+ "foo.txt": """\
+{% for i in items
+ %}{% if i > 0 %}, {% end %}{#
+ #}{{i
+ }}{% end
+%}"""
+ }
+ )
+ self.assertEqual(
+ loader.load("foo.txt").generate(items=range(5)), b"0, 1, 2, 3, 4"
+ )
+
+ def test_whitespace_by_filename(self):
+ # Default whitespace handling depends on the template filename.
+ loader = DictLoader(
+ {
+ "foo.html": " \n\t\n asdf\t ",
+ "bar.js": " \n\n\n\t qwer ",
+ "baz.txt": "\t zxcv\n\n",
+ "include.html": " {% include baz.txt %} \n ",
+ "include.txt": "\t\t{% include foo.html %} ",
+ }
+ )
+
+ # HTML and JS files have whitespace compressed by default.
+ self.assertEqual(loader.load("foo.html").generate(), b"\nasdf ")
+ self.assertEqual(loader.load("bar.js").generate(), b"\nqwer ")
+ # TXT files do not.
+ self.assertEqual(loader.load("baz.txt").generate(), b"\t zxcv\n\n")
+
+ # Each file maintains its own status even when included in
+ # a file of the other type.
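+ # (e.g. baz.txt keeps its literal tab and trailing newlines below even
+ # though the file including it is an HTML template whose own whitespace
+ # gets compressed.)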
+ self.assertEqual(loader.load("include.html").generate(), b" \t zxcv\n\n\n") + self.assertEqual(loader.load("include.txt").generate(), b"\t\t\nasdf ") + + def test_whitespace_by_loader(self): + templates = {"foo.html": "\t\tfoo\n\n", "bar.txt": "\t\tbar\n\n"} + loader = DictLoader(templates, whitespace="all") + self.assertEqual(loader.load("foo.html").generate(), b"\t\tfoo\n\n") + self.assertEqual(loader.load("bar.txt").generate(), b"\t\tbar\n\n") + + loader = DictLoader(templates, whitespace="single") + self.assertEqual(loader.load("foo.html").generate(), b" foo\n") + self.assertEqual(loader.load("bar.txt").generate(), b" bar\n") + + loader = DictLoader(templates, whitespace="oneline") + self.assertEqual(loader.load("foo.html").generate(), b" foo ") + self.assertEqual(loader.load("bar.txt").generate(), b" bar ") + + def test_whitespace_directive(self): + loader = DictLoader( + { + "foo.html": """\ +{% whitespace oneline %} + {% for i in range(3) %} + {{ i }} + {% end %} +{% whitespace all %} + pre\tformatted +""" + } + ) + self.assertEqual( + loader.load("foo.html").generate(), b" 0 1 2 \n pre\tformatted\n" + ) + + +class TemplateLoaderTest(unittest.TestCase): + def setUp(self): + self.loader = Loader(os.path.join(os.path.dirname(__file__), "templates")) + + def test_utf8_in_file(self): + tmpl = self.loader.load("utf8.html") + result = tmpl.generate() + self.assertEqual(to_unicode(result).strip(), "H\u00e9llo") diff --git a/venv/lib/python3.9/site-packages/tornado/test/templates/utf8.html b/venv/lib/python3.9/site-packages/tornado/test/templates/utf8.html new file mode 100644 index 00000000..c5253dfa --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/test/templates/utf8.html @@ -0,0 +1 @@ +Héllo diff --git a/venv/lib/python3.9/site-packages/tornado/test/test.crt b/venv/lib/python3.9/site-packages/tornado/test/test.crt new file mode 100644 index 00000000..ffc49b06 --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/test/test.crt @@ -0,0 +1,20 @@ +-----BEGIN CERTIFICATE----- +MIIDWzCCAkOgAwIBAgIUV4spou0CenmvKqa7Hml/MC+JKiAwDQYJKoZIhvcNAQEL +BQAwPTELMAkGA1UEBhMCVVMxEzARBgNVBAgMCkNhbGlmb3JuaWExGTAXBgNVBAoM +EFRvcm5hZG8gV2ViIFRlc3QwHhcNMTgwOTI5MTM1NjQ1WhcNMjgwOTI2MTM1NjQ1 +WjA9MQswCQYDVQQGEwJVUzETMBEGA1UECAwKQ2FsaWZvcm5pYTEZMBcGA1UECgwQ +VG9ybmFkbyBXZWIgVGVzdDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB +AKT0LdyI8tW5uwP3ahE8BFSz+j3SsKBDv/0cKvqxVVE6sLEST2s3HjArZvIIG5sb +iBkWDrqnZ6UKDvB4jlobLGAkepxDbrxHWxK53n0C28XXGLqJQ01TlTZ5rpjttMeg +5SKNjHbxpOvpUwwQS4br4WjZKKyTGiXpFkFUty+tYVU35/U2yyvreWHmzpHx/25t +H7O2RBARVwJYKOGPtlH62lQjpIWfVfklY4Ip8Hjl3B6rBxPyBULmVQw0qgoZn648 +oa4oLjs0wnYBz01gVjNMDHej52SsB/ieH7W1TxFMzqOlcvHh41uFbQJPgcXsruSS +9Z4twzSWkUp2vk/C//4Sz38CAwEAAaNTMFEwHQYDVR0OBBYEFLf8fQ5+u8sDWAd3 +r5ZjZ5MmDWJeMB8GA1UdIwQYMBaAFLf8fQ5+u8sDWAd3r5ZjZ5MmDWJeMA8GA1Ud +EwEB/wQFMAMBAf8wDQYJKoZIhvcNAQELBQADggEBADkkm3pIb9IeqVNmQ2uhQOgw +UwyToTYUHNTb/Nm5lzBTBqC8gbXAS24RQ30AB/7G115Uxeo+YMKfITxm/CgR+vhF +F59/YrzwXj+G8bdbuVl/UbB6f9RSp+Zo93rUZAtPWr77gxLUrcwSRzzDwxFjC2nC +6eigbkvt1OQY775RwnFAt7HKPclE0Out+cGJIboJuO1f3r57ZdyFH0GzbZEff/7K +atGXohijWJjYvU4mk0KFHORZrcBpsv9cfkFbmgVmiRwxRJ1tLauHM3Ne+VfqYE5M +4rTStSyz3ASqVKJ2iFMQueNR/tUOuDlfRt+0nhJMuYSSkW+KTgnwyOGU9cv+mxA= +-----END CERTIFICATE----- diff --git a/venv/lib/python3.9/site-packages/tornado/test/test.key b/venv/lib/python3.9/site-packages/tornado/test/test.key new file mode 100644 index 00000000..7cb7d8d2 --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/test/test.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- 
+MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCk9C3ciPLVubsD +92oRPARUs/o90rCgQ7/9HCr6sVVROrCxEk9rNx4wK2byCBubG4gZFg66p2elCg7w +eI5aGyxgJHqcQ268R1sSud59AtvF1xi6iUNNU5U2ea6Y7bTHoOUijYx28aTr6VMM +EEuG6+Fo2Siskxol6RZBVLcvrWFVN+f1Nssr63lh5s6R8f9ubR+ztkQQEVcCWCjh +j7ZR+tpUI6SFn1X5JWOCKfB45dweqwcT8gVC5lUMNKoKGZ+uPKGuKC47NMJ2Ac9N +YFYzTAx3o+dkrAf4nh+1tU8RTM6jpXLx4eNbhW0CT4HF7K7kkvWeLcM0lpFKdr5P +wv/+Es9/AgMBAAECggEABi6AaXtYXloPgB6NgwfUwbfc8OQsalUfpMShd7OdluW0 +KW6eO05de0ClIvzay/1EJGyHMMeFQtIVrT1XWFkcWJ4FWkXMqJGkABenFtg8lDVz +X8o1E3jGZrw4ptKBq9mDvL/BO9PiclTUH+ecbPn6AIvi0lTQ7grGIryiAM9mjmLy +jpCwoutF2LD4RPNg8vqWe/Z1rQw5lp8FOHhRwPooHHeoq1bSrp8dqvVAwAam7Mmf +uFgI8jrNycPgr2cwEEtbq2TQ625MhVnCpwT+kErmAStfbXXuqv1X1ZZgiNxf+61C +OL0bhPRVIHmmjiK/5qHRuN4Q5u9/Yp2SJ4W5xadSQQKBgQDR7dnOlYYQiaoPJeD/ +7jcLVJbWwbr7bE19O/QpYAtkA/FtGlKr+hQxPhK6OYp+in8eHf+ga/NSAjCWRBoh +MNAVCJtiirHo2tFsLFOmlJpGL9n3sX8UnkJN90oHfWrzJ8BZnXaSw2eOuyw8LLj+ +Q+ISl6Go8/xfsuy3EDv4AP1wCwKBgQDJJ4vEV3Kr+bc6N/xeu+G0oHvRAWwuQpcx +9D+XpnqbJbFDnWKNE7oGsDCs8Qjr0CdFUN1pm1ppITDZ5N1cWuDg/47ZAXqEK6D1 +z13S7O0oQPlnsPL7mHs2Vl73muAaBPAojFvceHHfccr7Z94BXqKsiyfaWz6kclT/ +Nl4JTdsC3QKBgQCeYgozL2J/da2lUhnIXcyPstk+29kbueFYu/QBh2HwqnzqqLJ4 +5+t2H3P3plQUFp/DdDSZrvhcBiTsKiNgqThEtkKtfSCvIvBf4a2W/4TJsW6MzxCm +2KQDuK/UqM4Y+APKWN/N6Lln2VWNbNyBkWuuRVKFatccyJyJnSjxeqW7cwKBgGyN +idCYPIrwROAHLItXKvOWE5t0ABRq3TsZC2RkdA/b5HCPs4pclexcEriRjvXrK/Yt +MH94Ve8b+UftSUQ4ytjBMS6MrLg87y0YDhLwxv8NKUq65DXAUOW+8JsAmmWQOqY3 +MK+m1BT4TMklgVoN3w3sPsKIsSJ/jLz5cv/kYweFAoGAG4iWU1378tI2Ts/Fngsv +7eoWhoda77Y9D0Yoy20aN9VdMHzIYCBOubtRPEuwgaReNwbUBWap01J63yY/fF3K +8PTz6covjoOJqxQJOvM7nM0CsJawG9ccw3YXyd9KgRIdSt6ooEhb7N8W2EXYoKl3 +g1i2t41Q/SC3HUGC5mJjpO8= +-----END PRIVATE KEY----- diff --git a/venv/lib/python3.9/site-packages/tornado/test/testing_test.py b/venv/lib/python3.9/site-packages/tornado/test/testing_test.py new file mode 100644 index 00000000..0429feee --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/test/testing_test.py @@ -0,0 +1,345 @@ +from tornado import gen, ioloop +from tornado.httpserver import HTTPServer +from tornado.locks import Event +from tornado.testing import AsyncHTTPTestCase, AsyncTestCase, bind_unused_port, gen_test +from tornado.web import Application +import asyncio +import contextlib +import inspect +import gc +import os +import platform +import sys +import traceback +import unittest +import warnings + + +@contextlib.contextmanager +def set_environ(name, value): + old_value = os.environ.get(name) + os.environ[name] = value + + try: + yield + finally: + if old_value is None: + del os.environ[name] + else: + os.environ[name] = old_value + + +class AsyncTestCaseTest(AsyncTestCase): + def test_wait_timeout(self): + time = self.io_loop.time + + # Accept default 5-second timeout, no error + self.io_loop.add_timeout(time() + 0.01, self.stop) + self.wait() + + # Timeout passed to wait() + self.io_loop.add_timeout(time() + 1, self.stop) + with self.assertRaises(self.failureException): + self.wait(timeout=0.01) + + # Timeout set with environment variable + self.io_loop.add_timeout(time() + 1, self.stop) + with set_environ("ASYNC_TEST_TIMEOUT", "0.01"): + with self.assertRaises(self.failureException): + self.wait() + + def test_subsequent_wait_calls(self): + """ + This test makes sure that a second call to wait() + clears the first timeout. 
+ """ + # The first wait ends with time left on the clock + self.io_loop.add_timeout(self.io_loop.time() + 0.00, self.stop) + self.wait(timeout=0.1) + # The second wait has enough time for itself but would fail if the + # first wait's deadline were still in effect. + self.io_loop.add_timeout(self.io_loop.time() + 0.2, self.stop) + self.wait(timeout=0.4) + + +class LeakTest(AsyncTestCase): + def tearDown(self): + super().tearDown() + # Trigger a gc to make warnings more deterministic. + gc.collect() + + def test_leaked_coroutine(self): + # This test verifies that "leaked" coroutines are shut down + # without triggering warnings like "task was destroyed but it + # is pending". If this test were to fail, it would fail + # because runtests.py detected unexpected output to stderr. + event = Event() + + async def callback(): + try: + await event.wait() + except asyncio.CancelledError: + pass + + self.io_loop.add_callback(callback) + self.io_loop.add_callback(self.stop) + self.wait() + + +class AsyncHTTPTestCaseTest(AsyncHTTPTestCase): + def setUp(self): + super().setUp() + # Bind a second port. + sock, port = bind_unused_port() + app = Application() + server = HTTPServer(app, **self.get_httpserver_options()) + server.add_socket(sock) + self.second_port = port + self.second_server = server + + def get_app(self): + return Application() + + def test_fetch_segment(self): + path = "/path" + response = self.fetch(path) + self.assertEqual(response.request.url, self.get_url(path)) + + def test_fetch_full_http_url(self): + # Ensure that self.fetch() recognizes absolute urls and does + # not transform them into references to our main test server. + path = "http://127.0.0.1:%d/path" % self.second_port + + response = self.fetch(path) + self.assertEqual(response.request.url, path) + + def tearDown(self): + self.second_server.stop() + super().tearDown() + + +class AsyncTestCaseWrapperTest(unittest.TestCase): + def test_undecorated_generator(self): + class Test(AsyncTestCase): + def test_gen(self): + yield + + test = Test("test_gen") + result = unittest.TestResult() + test.run(result) + self.assertEqual(len(result.errors), 1) + self.assertIn("should be decorated", result.errors[0][1]) + + @unittest.skipIf( + platform.python_implementation() == "PyPy", + "pypy destructor warnings cannot be silenced", + ) + @unittest.skipIf( + sys.version_info >= (3, 12), "py312 has its own check for test case returns" + ) + def test_undecorated_coroutine(self): + class Test(AsyncTestCase): + async def test_coro(self): + pass + + test = Test("test_coro") + result = unittest.TestResult() + + # Silence "RuntimeWarning: coroutine 'test_coro' was never awaited". 
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ test.run(result)
+
+ self.assertEqual(len(result.errors), 1)
+ self.assertIn("should be decorated", result.errors[0][1])
+
+ def test_undecorated_generator_with_skip(self):
+ class Test(AsyncTestCase):
+ @unittest.skip("don't run this")
+ def test_gen(self):
+ yield
+
+ test = Test("test_gen")
+ result = unittest.TestResult()
+ test.run(result)
+ self.assertEqual(len(result.errors), 0)
+ self.assertEqual(len(result.skipped), 1)
+
+ def test_other_return(self):
+ class Test(AsyncTestCase):
+ def test_other_return(self):
+ return 42
+
+ test = Test("test_other_return")
+ result = unittest.TestResult()
+ test.run(result)
+ self.assertEqual(len(result.errors), 1)
+ self.assertIn("Return value from test method ignored", result.errors[0][1])
+
+ def test_unwrap(self):
+ class Test(AsyncTestCase):
+ def test_foo(self):
+ pass
+
+ test = Test("test_foo")
+ self.assertIs(
+ inspect.unwrap(test.test_foo),
+ test.test_foo.orig_method, # type: ignore[attr-defined]
+ )
+
+
+class SetUpTearDownTest(unittest.TestCase):
+ def test_set_up_tear_down(self):
+ """
+ This test makes sure that AsyncTestCase calls super methods for
+ setUp and tearDown.
+
+ InheritBoth is a subclass of both AsyncTestCase and
+ SetUpTearDown, with the ordering so that the super of
+ AsyncTestCase will be SetUpTearDown.
+ """
+ events = []
+ result = unittest.TestResult()
+
+ class SetUpTearDown(unittest.TestCase):
+ def setUp(self):
+ events.append("setUp")
+
+ def tearDown(self):
+ events.append("tearDown")
+
+ class InheritBoth(AsyncTestCase, SetUpTearDown):
+ def test(self):
+ events.append("test")
+
+ InheritBoth("test").run(result)
+ expected = ["setUp", "test", "tearDown"]
+ self.assertEqual(expected, events)
+
+
+class AsyncHTTPTestCaseSetUpTearDownTest(unittest.TestCase):
+ def test_tear_down_releases_app_and_http_server(self):
+ result = unittest.TestResult()
+
+ class SetUpTearDown(AsyncHTTPTestCase):
+ def get_app(self):
+ return Application()
+
+ def test(self):
+ self.assertTrue(hasattr(self, "_app"))
+ self.assertTrue(hasattr(self, "http_server"))
+
+ test = SetUpTearDown("test")
+ test.run(result)
+ self.assertFalse(hasattr(test, "_app"))
+ self.assertFalse(hasattr(test, "http_server"))
+
+
+class GenTest(AsyncTestCase):
+ def setUp(self):
+ super().setUp()
+ self.finished = False
+
+ def tearDown(self):
+ self.assertTrue(self.finished)
+ super().tearDown()
+
+ @gen_test
+ def test_sync(self):
+ self.finished = True
+
+ @gen_test
+ def test_async(self):
+ yield gen.moment
+ self.finished = True
+
+ def test_timeout(self):
+ # Set a short timeout and exceed it.
+ @gen_test(timeout=0.1)
+ def test(self):
+ yield gen.sleep(1)
+
+ # This can't use assertRaises because we need to inspect the
+ # exc_info triple (and not just the exception object)
+ try:
+ test(self)
+ self.fail("did not get expected exception")
+ except ioloop.TimeoutError:
+ # The stack trace should blame the "yield gen.sleep(1)" line, not
+ # just unrelated IOLoop/testing internals.
+ self.assertIn("gen.sleep(1)", traceback.format_exc())
+
+ self.finished = True
+
+ def test_no_timeout(self):
+ # A test that does not exceed its timeout should succeed.
+ @gen_test(timeout=1)
+ def test(self):
+ yield gen.sleep(0.1)
+
+ test(self)
+ self.finished = True
+
+ def test_timeout_environment_variable(self):
+ @gen_test(timeout=0.5)
+ def test_long_timeout(self):
+ yield gen.sleep(0.25)
+
+ # Uses provided timeout of 0.5 seconds, doesn't time out.
+ with set_environ("ASYNC_TEST_TIMEOUT", "0.1"): + test_long_timeout(self) + + self.finished = True + + def test_no_timeout_environment_variable(self): + @gen_test(timeout=0.01) + def test_short_timeout(self): + yield gen.sleep(1) + + # Uses environment-variable timeout of 0.1, times out. + with set_environ("ASYNC_TEST_TIMEOUT", "0.1"): + with self.assertRaises(ioloop.TimeoutError): + test_short_timeout(self) + + self.finished = True + + def test_with_method_args(self): + @gen_test + def test_with_args(self, *args): + self.assertEqual(args, ("test",)) + yield gen.moment + + test_with_args(self, "test") + self.finished = True + + def test_with_method_kwargs(self): + @gen_test + def test_with_kwargs(self, **kwargs): + self.assertDictEqual(kwargs, {"test": "test"}) + yield gen.moment + + test_with_kwargs(self, test="test") + self.finished = True + + def test_native_coroutine(self): + @gen_test + async def test(self): + self.finished = True + + test(self) + + def test_native_coroutine_timeout(self): + # Set a short timeout and exceed it. + @gen_test(timeout=0.1) + async def test(self): + await gen.sleep(1) + + try: + test(self) + self.fail("did not get expected exception") + except ioloop.TimeoutError: + self.finished = True + + +if __name__ == "__main__": + unittest.main() diff --git a/venv/lib/python3.9/site-packages/tornado/test/twisted_test.py b/venv/lib/python3.9/site-packages/tornado/test/twisted_test.py new file mode 100644 index 00000000..7f983a73 --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/test/twisted_test.py @@ -0,0 +1,64 @@ +# Author: Ovidiu Predescu +# Date: July 2011 +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import unittest + +from tornado.testing import AsyncTestCase, gen_test + +try: + from twisted.internet.defer import ( # type: ignore + inlineCallbacks, + returnValue, + ) + + have_twisted = True +except ImportError: + have_twisted = False +else: + # Not used directly but needed for `yield deferred` to work. + import tornado.platform.twisted # noqa: F401 + +skipIfNoTwisted = unittest.skipUnless(have_twisted, "twisted module not present") + + +@skipIfNoTwisted +class ConvertDeferredTest(AsyncTestCase): + @gen_test + def test_success(self): + @inlineCallbacks + def fn(): + if False: + # inlineCallbacks doesn't work with regular functions; + # must have a yield even if it's unreachable. 
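+ # (The bare yield in an unreachable branch is the usual trick to make
+ # Python compile fn as a generator, which inlineCallbacks requires.)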
+ yield + returnValue(42) + + res = yield fn() + self.assertEqual(res, 42) + + @gen_test + def test_failure(self): + @inlineCallbacks + def fn(): + if False: + yield + 1 / 0 + + with self.assertRaises(ZeroDivisionError): + yield fn() + + +if __name__ == "__main__": + unittest.main() diff --git a/venv/lib/python3.9/site-packages/tornado/test/util.py b/venv/lib/python3.9/site-packages/tornado/test/util.py new file mode 100644 index 00000000..bcb9bbde --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/test/util.py @@ -0,0 +1,114 @@ +import contextlib +import os +import platform +import socket +import sys +import textwrap +import typing # noqa: F401 +import unittest +import warnings + +from tornado.testing import bind_unused_port + +skipIfNonUnix = unittest.skipIf( + os.name != "posix" or sys.platform == "cygwin", "non-unix platform" +) + +# travis-ci.org runs our tests in an overworked virtual machine, which makes +# timing-related tests unreliable. +skipOnTravis = unittest.skipIf( + "TRAVIS" in os.environ, "timing tests unreliable on travis" +) + +# Set the environment variable NO_NETWORK=1 to disable any tests that +# depend on an external network. +skipIfNoNetwork = unittest.skipIf("NO_NETWORK" in os.environ, "network access disabled") + +skipNotCPython = unittest.skipIf( + platform.python_implementation() != "CPython", "Not CPython implementation" +) + +# Used for tests affected by +# https://bitbucket.org/pypy/pypy/issues/2616/incomplete-error-handling-in +# TODO: remove this after pypy3 5.8 is obsolete. +skipPypy3V58 = unittest.skipIf( + platform.python_implementation() == "PyPy" + and sys.version_info > (3,) + and sys.pypy_version_info < (5, 9), # type: ignore + "pypy3 5.8 has buggy ssl module", +) + + +def _detect_ipv6(): + if not socket.has_ipv6: + # socket.has_ipv6 check reports whether ipv6 was present at compile + # time. It's usually true even when ipv6 doesn't work for other reasons. + return False + sock = None + try: + sock = socket.socket(socket.AF_INET6) + sock.bind(("::1", 0)) + except socket.error: + return False + finally: + if sock is not None: + sock.close() + return True + + +skipIfNoIPv6 = unittest.skipIf(not _detect_ipv6(), "ipv6 support not present") + + +def refusing_port(): + """Returns a local port number that will refuse all connections. + + Return value is (cleanup_func, port); the cleanup function + must be called to free the port to be reused. + """ + # On travis-ci, port numbers are reassigned frequently. To avoid + # collisions with other tests, we use an open client-side socket's + # ephemeral port number to ensure that nothing can listen on that + # port. + server_socket, port = bind_unused_port() + server_socket.setblocking(True) + client_socket = socket.socket() + client_socket.connect(("127.0.0.1", port)) + conn, client_addr = server_socket.accept() + conn.close() + server_socket.close() + return (client_socket.close, client_addr[1]) + + +def exec_test(caller_globals, caller_locals, s): + """Execute ``s`` in a given context and return the result namespace. + + Used to define functions for tests in particular python + versions that would be syntax errors in older versions. + """ + # Flatten the real global and local namespace into our fake + # globals: it's all global from the perspective of code defined + # in s. 
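+ # (Hypothetical usage, not called anywhere in this module:
+ #     ns = exec_test(globals(), {"x": 1}, "def f(): return x")
+ #     ns["f"]()  # -> 1, because x was folded into f's globals)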
+ global_namespace = dict(caller_globals, **caller_locals) # type: ignore + local_namespace = {} # type: typing.Dict[str, typing.Any] + exec(textwrap.dedent(s), global_namespace, local_namespace) + return local_namespace + + +def subTest(test, *args, **kwargs): + """Compatibility shim for unittest.TestCase.subTest. + + Usage: ``with tornado.test.util.subTest(self, x=x):`` + """ + try: + subTest = test.subTest # py34+ + except AttributeError: + subTest = contextlib.contextmanager(lambda *a, **kw: (yield)) + return subTest(*args, **kwargs) + + +@contextlib.contextmanager +def ignore_deprecation(): + """Context manager to ignore deprecation warnings.""" + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + yield diff --git a/venv/lib/python3.9/site-packages/tornado/test/util_test.py b/venv/lib/python3.9/site-packages/tornado/test/util_test.py new file mode 100644 index 00000000..02cf0c19 --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/test/util_test.py @@ -0,0 +1,308 @@ +from io import StringIO +import re +import sys +import datetime +import unittest + +import tornado +from tornado.escape import utf8 +from tornado.util import ( + raise_exc_info, + Configurable, + exec_in, + ArgReplacer, + timedelta_to_seconds, + import_object, + re_unescape, + is_finalizing, +) + +import typing +from typing import cast + +if typing.TYPE_CHECKING: + from typing import Dict, Any # noqa: F401 + + +class RaiseExcInfoTest(unittest.TestCase): + def test_two_arg_exception(self): + # This test would fail on python 3 if raise_exc_info were simply + # a three-argument raise statement, because TwoArgException + # doesn't have a "copy constructor" + class TwoArgException(Exception): + def __init__(self, a, b): + super().__init__() + self.a, self.b = a, b + + try: + raise TwoArgException(1, 2) + except TwoArgException: + exc_info = sys.exc_info() + try: + raise_exc_info(exc_info) + self.fail("didn't get expected exception") + except TwoArgException as e: + self.assertIs(e, exc_info[1]) + + +class TestConfigurable(Configurable): + @classmethod + def configurable_base(cls): + return TestConfigurable + + @classmethod + def configurable_default(cls): + return TestConfig1 + + +class TestConfig1(TestConfigurable): + def initialize(self, pos_arg=None, a=None): + self.a = a + self.pos_arg = pos_arg + + +class TestConfig2(TestConfigurable): + def initialize(self, pos_arg=None, b=None): + self.b = b + self.pos_arg = pos_arg + + +class TestConfig3(TestConfigurable): + # TestConfig3 is a configuration option that is itself configurable. 
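+ # (The two-level selection exercised in test_config_multi_level below:
+ #     TestConfigurable.configure(TestConfig3)  # outer level picks this branch
+ #     TestConfig3.configure(TestConfig3B)      # inner level picks the leaf
+ #     TestConfigurable()                       # -> a TestConfig3B instance)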
+ @classmethod + def configurable_base(cls): + return TestConfig3 + + @classmethod + def configurable_default(cls): + return TestConfig3A + + +class TestConfig3A(TestConfig3): + def initialize(self, a=None): + self.a = a + + +class TestConfig3B(TestConfig3): + def initialize(self, b=None): + self.b = b + + +class ConfigurableTest(unittest.TestCase): + def setUp(self): + self.saved = TestConfigurable._save_configuration() + self.saved3 = TestConfig3._save_configuration() + + def tearDown(self): + TestConfigurable._restore_configuration(self.saved) + TestConfig3._restore_configuration(self.saved3) + + def checkSubclasses(self): + # no matter how the class is configured, it should always be + # possible to instantiate the subclasses directly + self.assertIsInstance(TestConfig1(), TestConfig1) + self.assertIsInstance(TestConfig2(), TestConfig2) + + obj = TestConfig1(a=1) + self.assertEqual(obj.a, 1) + obj2 = TestConfig2(b=2) + self.assertEqual(obj2.b, 2) + + def test_default(self): + # In these tests we combine a typing.cast to satisfy mypy with + # a runtime type-assertion. Without the cast, mypy would only + # let us access attributes of the base class. + obj = cast(TestConfig1, TestConfigurable()) + self.assertIsInstance(obj, TestConfig1) + self.assertIs(obj.a, None) + + obj = cast(TestConfig1, TestConfigurable(a=1)) + self.assertIsInstance(obj, TestConfig1) + self.assertEqual(obj.a, 1) + + self.checkSubclasses() + + def test_config_class(self): + TestConfigurable.configure(TestConfig2) + obj = cast(TestConfig2, TestConfigurable()) + self.assertIsInstance(obj, TestConfig2) + self.assertIs(obj.b, None) + + obj = cast(TestConfig2, TestConfigurable(b=2)) + self.assertIsInstance(obj, TestConfig2) + self.assertEqual(obj.b, 2) + + self.checkSubclasses() + + def test_config_str(self): + TestConfigurable.configure("tornado.test.util_test.TestConfig2") + obj = cast(TestConfig2, TestConfigurable()) + self.assertIsInstance(obj, TestConfig2) + self.assertIs(obj.b, None) + + obj = cast(TestConfig2, TestConfigurable(b=2)) + self.assertIsInstance(obj, TestConfig2) + self.assertEqual(obj.b, 2) + + self.checkSubclasses() + + def test_config_args(self): + TestConfigurable.configure(None, a=3) + obj = cast(TestConfig1, TestConfigurable()) + self.assertIsInstance(obj, TestConfig1) + self.assertEqual(obj.a, 3) + + obj = cast(TestConfig1, TestConfigurable(42, a=4)) + self.assertIsInstance(obj, TestConfig1) + self.assertEqual(obj.a, 4) + self.assertEqual(obj.pos_arg, 42) + + self.checkSubclasses() + # args bound in configure don't apply when using the subclass directly + obj = TestConfig1() + self.assertIs(obj.a, None) + + def test_config_class_args(self): + TestConfigurable.configure(TestConfig2, b=5) + obj = cast(TestConfig2, TestConfigurable()) + self.assertIsInstance(obj, TestConfig2) + self.assertEqual(obj.b, 5) + + obj = cast(TestConfig2, TestConfigurable(42, b=6)) + self.assertIsInstance(obj, TestConfig2) + self.assertEqual(obj.b, 6) + self.assertEqual(obj.pos_arg, 42) + + self.checkSubclasses() + # args bound in configure don't apply when using the subclass directly + obj = TestConfig2() + self.assertIs(obj.b, None) + + def test_config_multi_level(self): + TestConfigurable.configure(TestConfig3, a=1) + obj = cast(TestConfig3A, TestConfigurable()) + self.assertIsInstance(obj, TestConfig3A) + self.assertEqual(obj.a, 1) + + TestConfigurable.configure(TestConfig3) + TestConfig3.configure(TestConfig3B, b=2) + obj2 = cast(TestConfig3B, TestConfigurable()) + self.assertIsInstance(obj2, TestConfig3B) + 
self.assertEqual(obj2.b, 2) + + def test_config_inner_level(self): + # The inner level can be used even when the outer level + # doesn't point to it. + obj = TestConfig3() + self.assertIsInstance(obj, TestConfig3A) + + TestConfig3.configure(TestConfig3B) + obj = TestConfig3() + self.assertIsInstance(obj, TestConfig3B) + + # Configuring the base doesn't configure the inner. + obj2 = TestConfigurable() + self.assertIsInstance(obj2, TestConfig1) + TestConfigurable.configure(TestConfig2) + + obj3 = TestConfigurable() + self.assertIsInstance(obj3, TestConfig2) + + obj = TestConfig3() + self.assertIsInstance(obj, TestConfig3B) + + +class UnicodeLiteralTest(unittest.TestCase): + def test_unicode_escapes(self): + self.assertEqual(utf8("\u00e9"), b"\xc3\xa9") + + +class ExecInTest(unittest.TestCase): + # TODO(bdarnell): make a version of this test for one of the new + # future imports available in python 3. + @unittest.skip("no testable future imports") + def test_no_inherit_future(self): + # This file has from __future__ import print_function... + f = StringIO() + print("hello", file=f) + # ...but the template doesn't + exec_in('print >> f, "world"', dict(f=f)) + self.assertEqual(f.getvalue(), "hello\nworld\n") + + +class ArgReplacerTest(unittest.TestCase): + def setUp(self): + def function(x, y, callback=None, z=None): + pass + + self.replacer = ArgReplacer(function, "callback") + + def test_omitted(self): + args = (1, 2) + kwargs = dict() # type: Dict[str, Any] + self.assertIs(self.replacer.get_old_value(args, kwargs), None) + self.assertEqual( + self.replacer.replace("new", args, kwargs), + (None, (1, 2), dict(callback="new")), + ) + + def test_position(self): + args = (1, 2, "old", 3) + kwargs = dict() # type: Dict[str, Any] + self.assertEqual(self.replacer.get_old_value(args, kwargs), "old") + self.assertEqual( + self.replacer.replace("new", args, kwargs), + ("old", [1, 2, "new", 3], dict()), + ) + + def test_keyword(self): + args = (1,) + kwargs = dict(y=2, callback="old", z=3) + self.assertEqual(self.replacer.get_old_value(args, kwargs), "old") + self.assertEqual( + self.replacer.replace("new", args, kwargs), + ("old", (1,), dict(y=2, callback="new", z=3)), + ) + + +class TimedeltaToSecondsTest(unittest.TestCase): + def test_timedelta_to_seconds(self): + time_delta = datetime.timedelta(hours=1) + self.assertEqual(timedelta_to_seconds(time_delta), 3600.0) + + +class ImportObjectTest(unittest.TestCase): + def test_import_member(self): + self.assertIs(import_object("tornado.escape.utf8"), utf8) + + def test_import_member_unicode(self): + self.assertIs(import_object("tornado.escape.utf8"), utf8) + + def test_import_module(self): + self.assertIs(import_object("tornado.escape"), tornado.escape) + + def test_import_module_unicode(self): + # The internal implementation of __import__ differs depending on + # whether the thing being imported is a module or not. + # This variant requires a byte string in python 2. 
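+ # (On Python 3 every str is unicode, so this is effectively the same as
+ # test_import_module above; it appears to be kept for historical coverage.)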
+ self.assertIs(import_object("tornado.escape"), tornado.escape) + + +class ReUnescapeTest(unittest.TestCase): + def test_re_unescape(self): + test_strings = ("/favicon.ico", "index.html", "Hello, World!", "!$@#%;") + for string in test_strings: + self.assertEqual(string, re_unescape(re.escape(string))) + + def test_re_unescape_raises_error_on_invalid_input(self): + with self.assertRaises(ValueError): + re_unescape("\\d") + with self.assertRaises(ValueError): + re_unescape("\\b") + with self.assertRaises(ValueError): + re_unescape("\\Z") + + +class IsFinalizingTest(unittest.TestCase): + def test_basic(self): + self.assertFalse(is_finalizing()) diff --git a/venv/lib/python3.9/site-packages/tornado/test/web_test.py b/venv/lib/python3.9/site-packages/tornado/test/web_test.py new file mode 100644 index 00000000..c2d057c5 --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/test/web_test.py @@ -0,0 +1,3262 @@ +from tornado.concurrent import Future +from tornado import gen +from tornado.escape import ( + json_decode, + utf8, + to_unicode, + recursive_unicode, + native_str, + to_basestring, +) +from tornado.httpclient import HTTPClientError +from tornado.httputil import format_timestamp +from tornado.iostream import IOStream +from tornado import locale +from tornado.locks import Event +from tornado.log import app_log, gen_log +from tornado.simple_httpclient import SimpleAsyncHTTPClient +from tornado.template import DictLoader +from tornado.testing import AsyncHTTPTestCase, AsyncTestCase, ExpectLog, gen_test +from tornado.test.util import ignore_deprecation +from tornado.util import ObjectDict, unicode_type +from tornado.web import ( + Application, + RequestHandler, + StaticFileHandler, + RedirectHandler as WebRedirectHandler, + HTTPError, + MissingArgumentError, + ErrorHandler, + authenticated, + url, + _create_signature_v1, + create_signed_value, + decode_signed_value, + get_signature_key_version, + UIModule, + Finish, + stream_request_body, + removeslash, + addslash, + GZipContentEncoding, +) + +import binascii +import contextlib +import copy +import datetime +import email.utils +import gzip +from io import BytesIO +import itertools +import logging +import os +import re +import socket +import typing # noqa: F401 +import unittest +import urllib.parse + + +def relpath(*a): + return os.path.join(os.path.dirname(__file__), *a) + + +class WebTestCase(AsyncHTTPTestCase): + """Base class for web tests that also supports WSGI mode. + + Override get_handlers and get_app_kwargs instead of get_app. + This class is deprecated since WSGI mode is no longer supported. + """ + + def get_app(self): + self.app = Application(self.get_handlers(), **self.get_app_kwargs()) + return self.app + + def get_handlers(self): + raise NotImplementedError() + + def get_app_kwargs(self): + return {} + + +class SimpleHandlerTestCase(WebTestCase): + """Simplified base class for tests that work with a single handler class. + + To use, define a nested class named ``Handler``. 
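+
+ A sketch of the intended usage (illustrative only)::
+
+ class EchoTest(SimpleHandlerTestCase):
+ class Handler(RequestHandler):
+ def get(self):
+ self.write("echo")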
+ """ + + Handler = None + + def get_handlers(self): + return [("/", self.Handler)] + + +class HelloHandler(RequestHandler): + def get(self): + self.write("hello") + + +class CookieTestRequestHandler(RequestHandler): + # stub out enough methods to make the signed_cookie functions work + def __init__(self, cookie_secret="0123456789", key_version=None): + # don't call super.__init__ + self._cookies = {} # type: typing.Dict[str, bytes] + if key_version is None: + self.application = ObjectDict( # type: ignore + settings=dict(cookie_secret=cookie_secret) + ) + else: + self.application = ObjectDict( # type: ignore + settings=dict(cookie_secret=cookie_secret, key_version=key_version) + ) + + def get_cookie(self, name): + return self._cookies.get(name) + + def set_cookie(self, name, value, expires_days=None): + self._cookies[name] = value + + +# See SignedValueTest below for more. +class SecureCookieV1Test(unittest.TestCase): + def test_round_trip(self): + handler = CookieTestRequestHandler() + handler.set_signed_cookie("foo", b"bar", version=1) + self.assertEqual(handler.get_signed_cookie("foo", min_version=1), b"bar") + + def test_cookie_tampering_future_timestamp(self): + handler = CookieTestRequestHandler() + # this string base64-encodes to '12345678' + handler.set_signed_cookie("foo", binascii.a2b_hex(b"d76df8e7aefc"), version=1) + cookie = handler._cookies["foo"] + match = re.match(rb"12345678\|([0-9]+)\|([0-9a-f]+)", cookie) + assert match is not None + timestamp = match.group(1) + sig = match.group(2) + self.assertEqual( + _create_signature_v1( + handler.application.settings["cookie_secret"], + "foo", + "12345678", + timestamp, + ), + sig, + ) + # shifting digits from payload to timestamp doesn't alter signature + # (this is not desirable behavior, just confirming that that's how it + # works) + self.assertEqual( + _create_signature_v1( + handler.application.settings["cookie_secret"], + "foo", + "1234", + b"5678" + timestamp, + ), + sig, + ) + # tamper with the cookie + handler._cookies["foo"] = utf8( + "1234|5678%s|%s" % (to_basestring(timestamp), to_basestring(sig)) + ) + # it gets rejected + with ExpectLog(gen_log, "Cookie timestamp in future"): + self.assertTrue(handler.get_signed_cookie("foo", min_version=1) is None) + + def test_arbitrary_bytes(self): + # Secure cookies accept arbitrary data (which is base64 encoded). + # Note that normal cookies accept only a subset of ascii. + handler = CookieTestRequestHandler() + handler.set_signed_cookie("foo", b"\xe9", version=1) + self.assertEqual(handler.get_signed_cookie("foo", min_version=1), b"\xe9") + + +# See SignedValueTest below for more. 
+class SecureCookieV2Test(unittest.TestCase): + KEY_VERSIONS = {0: "ajklasdf0ojaisdf", 1: "aslkjasaolwkjsdf"} + + def test_round_trip(self): + handler = CookieTestRequestHandler() + handler.set_signed_cookie("foo", b"bar", version=2) + self.assertEqual(handler.get_signed_cookie("foo", min_version=2), b"bar") + + def test_key_version_roundtrip(self): + handler = CookieTestRequestHandler( + cookie_secret=self.KEY_VERSIONS, key_version=0 + ) + handler.set_signed_cookie("foo", b"bar") + self.assertEqual(handler.get_signed_cookie("foo"), b"bar") + + def test_key_version_roundtrip_differing_version(self): + handler = CookieTestRequestHandler( + cookie_secret=self.KEY_VERSIONS, key_version=1 + ) + handler.set_signed_cookie("foo", b"bar") + self.assertEqual(handler.get_signed_cookie("foo"), b"bar") + + def test_key_version_increment_version(self): + handler = CookieTestRequestHandler( + cookie_secret=self.KEY_VERSIONS, key_version=0 + ) + handler.set_signed_cookie("foo", b"bar") + new_handler = CookieTestRequestHandler( + cookie_secret=self.KEY_VERSIONS, key_version=1 + ) + new_handler._cookies = handler._cookies + self.assertEqual(new_handler.get_signed_cookie("foo"), b"bar") + + def test_key_version_invalidate_version(self): + handler = CookieTestRequestHandler( + cookie_secret=self.KEY_VERSIONS, key_version=0 + ) + handler.set_signed_cookie("foo", b"bar") + new_key_versions = self.KEY_VERSIONS.copy() + new_key_versions.pop(0) + new_handler = CookieTestRequestHandler( + cookie_secret=new_key_versions, key_version=1 + ) + new_handler._cookies = handler._cookies + self.assertEqual(new_handler.get_signed_cookie("foo"), None) + + +class FinalReturnTest(WebTestCase): + final_return = None # type: Future + + def get_handlers(self): + test = self + + class FinishHandler(RequestHandler): + @gen.coroutine + def get(self): + test.final_return = self.finish() + yield test.final_return + + @gen.coroutine + def post(self): + self.write("hello,") + yield self.flush() + test.final_return = self.finish("world") + yield test.final_return + + class RenderHandler(RequestHandler): + def create_template_loader(self, path): + return DictLoader({"foo.html": "hi"}) + + @gen.coroutine + def get(self): + test.final_return = self.render("foo.html") + + return [("/finish", FinishHandler), ("/render", RenderHandler)] + + def get_app_kwargs(self): + return dict(template_path="FinalReturnTest") + + def test_finish_method_return_future(self): + response = self.fetch(self.get_url("/finish")) + self.assertEqual(response.code, 200) + self.assertIsInstance(self.final_return, Future) + self.assertTrue(self.final_return.done()) + + response = self.fetch(self.get_url("/finish"), method="POST", body=b"") + self.assertEqual(response.code, 200) + self.assertIsInstance(self.final_return, Future) + self.assertTrue(self.final_return.done()) + + def test_render_method_return_future(self): + response = self.fetch(self.get_url("/render")) + self.assertEqual(response.code, 200) + self.assertIsInstance(self.final_return, Future) + + +class CookieTest(WebTestCase): + def get_handlers(self): + class SetCookieHandler(RequestHandler): + def get(self): + # Try setting cookies with different argument types + # to ensure that everything gets encoded correctly + self.set_cookie("str", "asdf") + self.set_cookie("unicode", "qwer") + self.set_cookie("bytes", b"zxcv") + + class GetCookieHandler(RequestHandler): + def get(self): + cookie = self.get_cookie("foo", "default") + assert cookie is not None + self.write(cookie) + + class 
SetCookieDomainHandler(RequestHandler): + def get(self): + # unicode domain and path arguments shouldn't break things + # either (see bug #285) + self.set_cookie("unicode_args", "blah", domain="foo.com", path="/foo") + + class SetCookieSpecialCharHandler(RequestHandler): + def get(self): + self.set_cookie("equals", "a=b") + self.set_cookie("semicolon", "a;b") + self.set_cookie("quote", 'a"b') + + class SetCookieOverwriteHandler(RequestHandler): + def get(self): + self.set_cookie("a", "b", domain="example.com") + self.set_cookie("c", "d", domain="example.com") + # A second call with the same name clobbers the first. + # Attributes from the first call are not carried over. + self.set_cookie("a", "e") + + class SetCookieMaxAgeHandler(RequestHandler): + def get(self): + self.set_cookie("foo", "bar", max_age=10) + + class SetCookieExpiresDaysHandler(RequestHandler): + def get(self): + self.set_cookie("foo", "bar", expires_days=10) + + class SetCookieFalsyFlags(RequestHandler): + def get(self): + self.set_cookie("a", "1", secure=True) + self.set_cookie("b", "1", secure=False) + self.set_cookie("c", "1", httponly=True) + self.set_cookie("d", "1", httponly=False) + + class SetCookieDeprecatedArgs(RequestHandler): + def get(self): + # Mixed case is supported, but deprecated + self.set_cookie("a", "b", HttpOnly=True, pATH="/foo") + + return [ + ("/set", SetCookieHandler), + ("/get", GetCookieHandler), + ("/set_domain", SetCookieDomainHandler), + ("/special_char", SetCookieSpecialCharHandler), + ("/set_overwrite", SetCookieOverwriteHandler), + ("/set_max_age", SetCookieMaxAgeHandler), + ("/set_expires_days", SetCookieExpiresDaysHandler), + ("/set_falsy_flags", SetCookieFalsyFlags), + ("/set_deprecated", SetCookieDeprecatedArgs), + ] + + def test_set_cookie(self): + response = self.fetch("/set") + self.assertEqual( + sorted(response.headers.get_list("Set-Cookie")), + ["bytes=zxcv; Path=/", "str=asdf; Path=/", "unicode=qwer; Path=/"], + ) + + def test_get_cookie(self): + response = self.fetch("/get", headers={"Cookie": "foo=bar"}) + self.assertEqual(response.body, b"bar") + + response = self.fetch("/get", headers={"Cookie": 'foo="bar"'}) + self.assertEqual(response.body, b"bar") + + response = self.fetch("/get", headers={"Cookie": "/=exception;"}) + self.assertEqual(response.body, b"default") + + def test_set_cookie_domain(self): + response = self.fetch("/set_domain") + self.assertEqual( + response.headers.get_list("Set-Cookie"), + ["unicode_args=blah; Domain=foo.com; Path=/foo"], + ) + + def test_cookie_special_char(self): + response = self.fetch("/special_char") + headers = sorted(response.headers.get_list("Set-Cookie")) + self.assertEqual(len(headers), 3) + self.assertEqual(headers[0], 'equals="a=b"; Path=/') + self.assertEqual(headers[1], 'quote="a\\"b"; Path=/') + # python 2.7 octal-escapes the semicolon; older versions leave it alone + self.assertTrue( + headers[2] in ('semicolon="a;b"; Path=/', 'semicolon="a\\073b"; Path=/'), + headers[2], + ) + + data = [ + ("foo=a=b", "a=b"), + ('foo="a=b"', "a=b"), + ('foo="a;b"', '"a'), # even quoted, ";" is a delimiter + ("foo=a\\073b", "a\\073b"), # escapes only decoded in quotes + ('foo="a\\073b"', "a;b"), + ('foo="a\\"b"', 'a"b'), + ] + for header, expected in data: + logging.debug("trying %r", header) + response = self.fetch("/get", headers={"Cookie": header}) + self.assertEqual(response.body, utf8(expected)) + + def test_set_cookie_overwrite(self): + response = self.fetch("/set_overwrite") + headers = response.headers.get_list("Set-Cookie") + 
self.assertEqual( + sorted(headers), ["a=e; Path=/", "c=d; Domain=example.com; Path=/"] + ) + + def test_set_cookie_max_age(self): + response = self.fetch("/set_max_age") + headers = response.headers.get_list("Set-Cookie") + self.assertEqual(sorted(headers), ["foo=bar; Max-Age=10; Path=/"]) + + def test_set_cookie_expires_days(self): + response = self.fetch("/set_expires_days") + header = response.headers.get("Set-Cookie") + assert header is not None + match = re.match("foo=bar; expires=(?P<expires>.+); Path=/", header) + assert match is not None + + expires = datetime.datetime.utcnow() + datetime.timedelta(days=10) + parsed = email.utils.parsedate(match.groupdict()["expires"]) + assert parsed is not None + header_expires = datetime.datetime(*parsed[:6]) + self.assertTrue(abs((expires - header_expires).total_seconds()) < 10) + + def test_set_cookie_false_flags(self): + response = self.fetch("/set_falsy_flags") + headers = sorted(response.headers.get_list("Set-Cookie")) + # The secure and httponly headers are capitalized in py35 and + # lowercase in older versions. + self.assertEqual(headers[0].lower(), "a=1; path=/; secure") + self.assertEqual(headers[1].lower(), "b=1; path=/") + self.assertEqual(headers[2].lower(), "c=1; httponly; path=/") + self.assertEqual(headers[3].lower(), "d=1; path=/") + + def test_set_cookie_deprecated(self): + with ignore_deprecation(): + response = self.fetch("/set_deprecated") + header = response.headers.get("Set-Cookie") + self.assertEqual(header, "a=b; HttpOnly; Path=/foo") + + +class AuthRedirectRequestHandler(RequestHandler): + def initialize(self, login_url): + self.login_url = login_url + + def get_login_url(self): + return self.login_url + + @authenticated + def get(self): + # we'll never actually get here because the test doesn't follow redirects + self.send_error(500) + + +class AuthRedirectTest(WebTestCase): + def get_handlers(self): + return [ + ("/relative", AuthRedirectRequestHandler, dict(login_url="/login")), + ( + "/absolute", + AuthRedirectRequestHandler, + dict(login_url="http://example.com/login"), + ), + ] + + def test_relative_auth_redirect(self): + response = self.fetch(self.get_url("/relative"), follow_redirects=False) + self.assertEqual(response.code, 302) + self.assertEqual(response.headers["Location"], "/login?next=%2Frelative") + + def test_absolute_auth_redirect(self): + response = self.fetch(self.get_url("/absolute"), follow_redirects=False) + self.assertEqual(response.code, 302) + self.assertTrue( + re.match( + r"http://example.com/login\?next=http%3A%2F%2F127.0.0.1%3A[0-9]+%2Fabsolute", + response.headers["Location"], + ), + response.headers["Location"], + ) + + +class ConnectionCloseHandler(RequestHandler): + def initialize(self, test): + self.test = test + + @gen.coroutine + def get(self): + self.test.on_handler_waiting() + yield self.test.cleanup_event.wait() + + def on_connection_close(self): + self.test.on_connection_close() + + +class ConnectionCloseTest(WebTestCase): + def get_handlers(self): + self.cleanup_event = Event() + return [("/", ConnectionCloseHandler, dict(test=self))] + + def test_connection_close(self): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) + s.connect(("127.0.0.1", self.get_http_port())) + self.stream = IOStream(s) + self.stream.write(b"GET / HTTP/1.0\r\n\r\n") + self.wait() + # Let the hanging coroutine clean up after itself + self.cleanup_event.set() + self.io_loop.run_sync(lambda: gen.sleep(0)) + + def on_handler_waiting(self): + logging.debug("handler waiting") + self.stream.close() + 
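+    # Sequence under test: the handler's GET parks on cleanup_event.wait(),
+    # on_handler_waiting() above closes the client-side stream, and the server
+    # notices the broken connection and fires ConnectionCloseHandler's
+    # on_connection_close(), which reports back here and releases self.wait().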
+ def on_connection_close(self): + logging.debug("connection closed") + self.stop() + + +class EchoHandler(RequestHandler): + def get(self, *path_args): + # Type checks: web.py interfaces convert argument values to + # unicode strings (by default, but see also decode_argument). + # In httpserver.py (i.e. self.request.arguments), they're left + # as bytes. Keys are always native strings. + for key in self.request.arguments: + if type(key) != str: + raise Exception("incorrect type for key: %r" % type(key)) + for bvalue in self.request.arguments[key]: + if type(bvalue) != bytes: + raise Exception("incorrect type for value: %r" % type(bvalue)) + for svalue in self.get_arguments(key): + if type(svalue) != unicode_type: + raise Exception("incorrect type for value: %r" % type(svalue)) + for arg in path_args: + if type(arg) != unicode_type: + raise Exception("incorrect type for path arg: %r" % type(arg)) + self.write( + dict( + path=self.request.path, + path_args=path_args, + args=recursive_unicode(self.request.arguments), + ) + ) + + +class RequestEncodingTest(WebTestCase): + def get_handlers(self): + return [("/group/(.*)", EchoHandler), ("/slashes/([^/]*)/([^/]*)", EchoHandler)] + + def fetch_json(self, path): + return json_decode(self.fetch(path).body) + + def test_group_question_mark(self): + # Ensure that url-encoded question marks are handled properly + self.assertEqual( + self.fetch_json("/group/%3F"), + dict(path="/group/%3F", path_args=["?"], args={}), + ) + self.assertEqual( + self.fetch_json("/group/%3F?%3F=%3F"), + dict(path="/group/%3F", path_args=["?"], args={"?": ["?"]}), + ) + + def test_group_encoding(self): + # Path components and query arguments should be decoded the same way + self.assertEqual( + self.fetch_json("/group/%C3%A9?arg=%C3%A9"), + { + "path": "/group/%C3%A9", + "path_args": ["\u00e9"], + "args": {"arg": ["\u00e9"]}, + }, + ) + + def test_slashes(self): + # Slashes may be escaped to appear as a single "directory" in the path, + # but they are then unescaped when passed to the get() method. + self.assertEqual( + self.fetch_json("/slashes/foo/bar"), + dict(path="/slashes/foo/bar", path_args=["foo", "bar"], args={}), + ) + self.assertEqual( + self.fetch_json("/slashes/a%2Fb/c%2Fd"), + dict(path="/slashes/a%2Fb/c%2Fd", path_args=["a/b", "c/d"], args={}), + ) + + def test_error(self): + # Percent signs (encoded as %25) should not mess up printf-style + # messages in logs + with ExpectLog(gen_log, ".*Invalid unicode"): + self.fetch("/group/?arg=%25%e9") + + +class TypeCheckHandler(RequestHandler): + def prepare(self): + self.errors = {} # type: typing.Dict[str, str] + + self.check_type("status", self.get_status(), int) + + # get_argument is an exception from the general rule of using + # type str for non-body data mainly for historical reasons. + self.check_type("argument", self.get_argument("foo"), unicode_type) + self.check_type("cookie_key", list(self.cookies.keys())[0], str) + self.check_type("cookie_value", list(self.cookies.values())[0].value, str) + + # Secure cookies return bytes because they can contain arbitrary + # data, but regular cookies are native strings. 
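+        # Concretely (as the checks below assert):
+        #   self.get_signed_cookie("asdf")  -> bytes (or None if invalid)
+        #   self.get_cookie("asdf")         -> str   (or the given default)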
+ if list(self.cookies.keys()) != ["asdf"]: + raise Exception( + "unexpected values for cookie keys: %r" % self.cookies.keys() + ) + self.check_type("get_signed_cookie", self.get_signed_cookie("asdf"), bytes) + self.check_type("get_cookie", self.get_cookie("asdf"), str) + + self.check_type("xsrf_token", self.xsrf_token, bytes) + self.check_type("xsrf_form_html", self.xsrf_form_html(), str) + + self.check_type("reverse_url", self.reverse_url("typecheck", "foo"), str) + + self.check_type("request_summary", self._request_summary(), str) + + def get(self, path_component): + # path_component uses type unicode instead of str for consistency + # with get_argument() + self.check_type("path_component", path_component, unicode_type) + self.write(self.errors) + + def post(self, path_component): + self.check_type("path_component", path_component, unicode_type) + self.write(self.errors) + + def check_type(self, name, obj, expected_type): + actual_type = type(obj) + if expected_type != actual_type: + self.errors[name] = "expected %s, got %s" % (expected_type, actual_type) + + +class DecodeArgHandler(RequestHandler): + def decode_argument(self, value, name=None): + if type(value) != bytes: + raise Exception("unexpected type for value: %r" % type(value)) + # use self.request.arguments directly to avoid recursion + if "encoding" in self.request.arguments: + return value.decode(to_unicode(self.request.arguments["encoding"][0])) + else: + return value + + def get(self, arg): + def describe(s): + if type(s) == bytes: + return ["bytes", native_str(binascii.b2a_hex(s))] + elif type(s) == unicode_type: + return ["unicode", s] + raise Exception("unknown type") + + self.write({"path": describe(arg), "query": describe(self.get_argument("foo"))}) + + +class LinkifyHandler(RequestHandler): + def get(self): + self.render("linkify.html", message="http://example.com") + + +class UIModuleResourceHandler(RequestHandler): + def get(self): + self.render("page.html", entries=[1, 2]) + + +class OptionalPathHandler(RequestHandler): + def get(self, path): + self.write({"path": path}) + + +class MultiHeaderHandler(RequestHandler): + def get(self): + self.set_header("x-overwrite", "1") + self.set_header("X-Overwrite", 2) + self.add_header("x-multi", 3) + self.add_header("X-Multi", "4") + + +class RedirectHandler(RequestHandler): + def get(self): + if self.get_argument("permanent", None) is not None: + self.redirect("/", permanent=bool(int(self.get_argument("permanent")))) + elif self.get_argument("status", None) is not None: + self.redirect("/", status=int(self.get_argument("status"))) + else: + raise Exception("didn't get permanent or status arguments") + + +class EmptyFlushCallbackHandler(RequestHandler): + @gen.coroutine + def get(self): + # Ensure that the flush callback is run whether or not there + # was any output. The gen.Task and direct yield forms are + # equivalent. 
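+        # flush() returns a Future that resolves once the pending output (for
+        # the first call, just the response headers) has been written out, so
+        # each ``yield`` below is a genuine wait, even with an empty buffer.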
+ yield self.flush() # "empty" flush, but writes headers + yield self.flush() # empty flush + self.write("o") + yield self.flush() # flushes the "o" + yield self.flush() # empty flush + self.finish("k") + + +class HeaderInjectionHandler(RequestHandler): + def get(self): + try: + self.set_header("X-Foo", "foo\r\nX-Bar: baz") + raise Exception("Didn't get expected exception") + except ValueError as e: + if "Unsafe header value" in str(e): + self.finish(b"ok") + else: + raise + + +class GetArgumentHandler(RequestHandler): + def prepare(self): + if self.get_argument("source", None) == "query": + method = self.get_query_argument + elif self.get_argument("source", None) == "body": + method = self.get_body_argument + else: + method = self.get_argument # type: ignore + self.finish(method("foo", "default")) + + +class GetArgumentsHandler(RequestHandler): + def prepare(self): + self.finish( + dict( + default=self.get_arguments("foo"), + query=self.get_query_arguments("foo"), + body=self.get_body_arguments("foo"), + ) + ) + + +# This test was shared with wsgi_test.py; now the name is meaningless. +class WSGISafeWebTest(WebTestCase): + COOKIE_SECRET = "WebTest.COOKIE_SECRET" + + def get_app_kwargs(self): + loader = DictLoader( + { + "linkify.html": "{% module linkify(message) %}", + "page.html": """\ +<html><head></head><body> +{% for e in entries %} +{% module Template("entry.html", entry=e) %} +{% end %} +</body></html>""", + "entry.html": """\ +{{ set_resources(embedded_css=".entry { margin-bottom: 1em; }", + embedded_javascript="js_embed()", + css_files=["/base.css", "/foo.css"], + javascript_files="/common.js", + html_head="<meta>", + html_body='<script src="/analytics.js"/>') }} +<div class="entry">...</div>""", + } + ) + return dict( + template_loader=loader, + autoescape="xhtml_escape", + cookie_secret=self.COOKIE_SECRET, + ) + + def tearDown(self): + super().tearDown() + RequestHandler._template_loaders.clear() + + def get_handlers(self): + urls = [ + url("/typecheck/(.*)", TypeCheckHandler, name="typecheck"), + url("/decode_arg/(.*)", DecodeArgHandler, name="decode_arg"), + url("/decode_arg_kw/(?P<arg>.*)", DecodeArgHandler), + url("/linkify", LinkifyHandler), + url("/uimodule_resources", UIModuleResourceHandler), + url("/optional_path/(.+)?", OptionalPathHandler), + url("/multi_header", MultiHeaderHandler), + url("/redirect", RedirectHandler), + url( + "/web_redirect_permanent", + WebRedirectHandler, + {"url": "/web_redirect_newpath"}, + ), + url( + "/web_redirect", + WebRedirectHandler, + {"url": "/web_redirect_newpath", "permanent": False}, + ), + url( + "//web_redirect_double_slash", + WebRedirectHandler, + {"url": "/web_redirect_newpath"}, + ), + url("/header_injection", HeaderInjectionHandler), + url("/get_argument", GetArgumentHandler), + url("/get_arguments", GetArgumentsHandler), + ] + return urls + + def fetch_json(self, *args, **kwargs): + response = self.fetch(*args, **kwargs) + response.rethrow() + return json_decode(response.body) + + def test_types(self): + cookie_value = to_unicode( + create_signed_value(self.COOKIE_SECRET, "asdf", "qwer") + ) + response = self.fetch( + "/typecheck/asdf?foo=bar", headers={"Cookie": "asdf=" + cookie_value} + ) + data = json_decode(response.body) + self.assertEqual(data, {}) + + response = self.fetch( + "/typecheck/asdf?foo=bar", + method="POST", + headers={"Cookie": "asdf=" + cookie_value}, + body="foo=bar", + ) + + def test_decode_argument(self): + # These urls all decode to the same thing + urls = [ + 
"/decode_arg/%C3%A9?foo=%C3%A9&encoding=utf-8", + "/decode_arg/%E9?foo=%E9&encoding=latin1", + "/decode_arg_kw/%E9?foo=%E9&encoding=latin1", + ] + for req_url in urls: + response = self.fetch(req_url) + response.rethrow() + data = json_decode(response.body) + self.assertEqual( + data, + {"path": ["unicode", "\u00e9"], "query": ["unicode", "\u00e9"]}, + ) + + response = self.fetch("/decode_arg/%C3%A9?foo=%C3%A9") + response.rethrow() + data = json_decode(response.body) + self.assertEqual(data, {"path": ["bytes", "c3a9"], "query": ["bytes", "c3a9"]}) + + def test_decode_argument_invalid_unicode(self): + # test that invalid unicode in URLs causes 400, not 500 + with ExpectLog(gen_log, ".*Invalid unicode.*"): + response = self.fetch("/typecheck/invalid%FF") + self.assertEqual(response.code, 400) + response = self.fetch("/typecheck/invalid?foo=%FF") + self.assertEqual(response.code, 400) + + def test_decode_argument_plus(self): + # These urls are all equivalent. + urls = [ + "/decode_arg/1%20%2B%201?foo=1%20%2B%201&encoding=utf-8", + "/decode_arg/1%20+%201?foo=1+%2B+1&encoding=utf-8", + ] + for req_url in urls: + response = self.fetch(req_url) + response.rethrow() + data = json_decode(response.body) + self.assertEqual( + data, + {"path": ["unicode", "1 + 1"], "query": ["unicode", "1 + 1"]}, + ) + + def test_reverse_url(self): + self.assertEqual(self.app.reverse_url("decode_arg", "foo"), "/decode_arg/foo") + self.assertEqual(self.app.reverse_url("decode_arg", 42), "/decode_arg/42") + self.assertEqual(self.app.reverse_url("decode_arg", b"\xe9"), "/decode_arg/%E9") + self.assertEqual( + self.app.reverse_url("decode_arg", "\u00e9"), "/decode_arg/%C3%A9" + ) + self.assertEqual( + self.app.reverse_url("decode_arg", "1 + 1"), "/decode_arg/1%20%2B%201" + ) + + def test_uimodule_unescaped(self): + response = self.fetch("/linkify") + self.assertEqual( + response.body, b'<a href="http://example.com">http://example.com</a>' + ) + + def test_uimodule_resources(self): + response = self.fetch("/uimodule_resources") + self.assertEqual( + response.body, + b"""\ +<html><head><link href="/base.css" type="text/css" rel="stylesheet"/><link href="/foo.css" type="text/css" rel="stylesheet"/> +<style type="text/css"> +.entry { margin-bottom: 1em; } +</style> +<meta> +</head><body> + + +<div class="entry">...</div> + + +<div class="entry">...</div> + +<script src="/common.js" type="text/javascript"></script> +<script type="text/javascript"> +//<![CDATA[ +js_embed() +//]]> +</script> +<script src="/analytics.js"/> +</body></html>""", # noqa: E501 + ) + + def test_optional_path(self): + self.assertEqual(self.fetch_json("/optional_path/foo"), {"path": "foo"}) + self.assertEqual(self.fetch_json("/optional_path/"), {"path": None}) + + def test_multi_header(self): + response = self.fetch("/multi_header") + self.assertEqual(response.headers["x-overwrite"], "2") + self.assertEqual(response.headers.get_list("x-multi"), ["3", "4"]) + + def test_redirect(self): + response = self.fetch("/redirect?permanent=1", follow_redirects=False) + self.assertEqual(response.code, 301) + response = self.fetch("/redirect?permanent=0", follow_redirects=False) + self.assertEqual(response.code, 302) + response = self.fetch("/redirect?status=307", follow_redirects=False) + self.assertEqual(response.code, 307) + + def test_web_redirect(self): + response = self.fetch("/web_redirect_permanent", follow_redirects=False) + self.assertEqual(response.code, 301) + self.assertEqual(response.headers["Location"], "/web_redirect_newpath") + response = 
self.fetch("/web_redirect", follow_redirects=False) + self.assertEqual(response.code, 302) + self.assertEqual(response.headers["Location"], "/web_redirect_newpath") + + def test_web_redirect_double_slash(self): + response = self.fetch("//web_redirect_double_slash", follow_redirects=False) + self.assertEqual(response.code, 301) + self.assertEqual(response.headers["Location"], "/web_redirect_newpath") + + def test_header_injection(self): + response = self.fetch("/header_injection") + self.assertEqual(response.body, b"ok") + + def test_get_argument(self): + response = self.fetch("/get_argument?foo=bar") + self.assertEqual(response.body, b"bar") + response = self.fetch("/get_argument?foo=") + self.assertEqual(response.body, b"") + response = self.fetch("/get_argument") + self.assertEqual(response.body, b"default") + + # Test merging of query and body arguments. + # In singular form, body arguments take precedence over query arguments. + body = urllib.parse.urlencode(dict(foo="hello")) + response = self.fetch("/get_argument?foo=bar", method="POST", body=body) + self.assertEqual(response.body, b"hello") + # In plural methods they are merged. + response = self.fetch("/get_arguments?foo=bar", method="POST", body=body) + self.assertEqual( + json_decode(response.body), + dict(default=["bar", "hello"], query=["bar"], body=["hello"]), + ) + + def test_get_query_arguments(self): + # send as a post so we can ensure the separation between query + # string and body arguments. + body = urllib.parse.urlencode(dict(foo="hello")) + response = self.fetch( + "/get_argument?source=query&foo=bar", method="POST", body=body + ) + self.assertEqual(response.body, b"bar") + response = self.fetch( + "/get_argument?source=query&foo=", method="POST", body=body + ) + self.assertEqual(response.body, b"") + response = self.fetch("/get_argument?source=query", method="POST", body=body) + self.assertEqual(response.body, b"default") + + def test_get_body_arguments(self): + body = urllib.parse.urlencode(dict(foo="bar")) + response = self.fetch( + "/get_argument?source=body&foo=hello", method="POST", body=body + ) + self.assertEqual(response.body, b"bar") + + body = urllib.parse.urlencode(dict(foo="")) + response = self.fetch( + "/get_argument?source=body&foo=hello", method="POST", body=body + ) + self.assertEqual(response.body, b"") + + body = urllib.parse.urlencode(dict()) + response = self.fetch( + "/get_argument?source=body&foo=hello", method="POST", body=body + ) + self.assertEqual(response.body, b"default") + + def test_no_gzip(self): + response = self.fetch("/get_argument") + self.assertNotIn("Accept-Encoding", response.headers.get("Vary", "")) + self.assertNotIn("gzip", response.headers.get("Content-Encoding", "")) + + +class NonWSGIWebTests(WebTestCase): + def get_handlers(self): + return [("/empty_flush", EmptyFlushCallbackHandler)] + + def test_empty_flush(self): + response = self.fetch("/empty_flush") + self.assertEqual(response.body, b"ok") + + +class ErrorResponseTest(WebTestCase): + def get_handlers(self): + class DefaultHandler(RequestHandler): + def get(self): + if self.get_argument("status", None): + raise HTTPError(int(self.get_argument("status"))) + 1 / 0 + + class WriteErrorHandler(RequestHandler): + def get(self): + if self.get_argument("status", None): + self.send_error(int(self.get_argument("status"))) + else: + 1 / 0 + + def write_error(self, status_code, **kwargs): + self.set_header("Content-Type", "text/plain") + if "exc_info" in kwargs: + self.write("Exception: %s" % kwargs["exc_info"][0].__name__) + 
else: + self.write("Status: %d" % status_code) + + class FailedWriteErrorHandler(RequestHandler): + def get(self): + 1 / 0 + + def write_error(self, status_code, **kwargs): + raise Exception("exception in write_error") + + return [ + url("/default", DefaultHandler), + url("/write_error", WriteErrorHandler), + url("/failed_write_error", FailedWriteErrorHandler), + ] + + def test_default(self): + with ExpectLog(app_log, "Uncaught exception"): + response = self.fetch("/default") + self.assertEqual(response.code, 500) + self.assertTrue(b"500: Internal Server Error" in response.body) + + response = self.fetch("/default?status=503") + self.assertEqual(response.code, 503) + self.assertTrue(b"503: Service Unavailable" in response.body) + + response = self.fetch("/default?status=435") + self.assertEqual(response.code, 435) + self.assertTrue(b"435: Unknown" in response.body) + + def test_write_error(self): + with ExpectLog(app_log, "Uncaught exception"): + response = self.fetch("/write_error") + self.assertEqual(response.code, 500) + self.assertEqual(b"Exception: ZeroDivisionError", response.body) + + response = self.fetch("/write_error?status=503") + self.assertEqual(response.code, 503) + self.assertEqual(b"Status: 503", response.body) + + def test_failed_write_error(self): + with ExpectLog(app_log, "Uncaught exception"): + response = self.fetch("/failed_write_error") + self.assertEqual(response.code, 500) + self.assertEqual(b"", response.body) + + +class StaticFileTest(WebTestCase): + # The expected SHA-512 hash of robots.txt, used in tests that call + # StaticFileHandler.get_version + robots_txt_hash = ( + b"63a36e950e134b5217e33c763e88840c10a07d80e6057d92b9ac97508de7fb1f" + b"a6f0e9b7531e169657165ea764e8963399cb6d921ffe6078425aaafe54c04563" + ) + static_dir = os.path.join(os.path.dirname(__file__), "static") + + def get_handlers(self): + class StaticUrlHandler(RequestHandler): + def get(self, path): + with_v = int(self.get_argument("include_version", "1")) + self.write(self.static_url(path, include_version=with_v)) + + class AbsoluteStaticUrlHandler(StaticUrlHandler): + include_host = True + + class OverrideStaticUrlHandler(RequestHandler): + def get(self, path): + do_include = bool(self.get_argument("include_host")) + self.include_host = not do_include + + regular_url = self.static_url(path) + override_url = self.static_url(path, include_host=do_include) + if override_url == regular_url: + return self.write(str(False)) + + protocol = self.request.protocol + "://" + protocol_length = len(protocol) + check_regular = regular_url.find(protocol, 0, protocol_length) + check_override = override_url.find(protocol, 0, protocol_length) + + if do_include: + result = check_override == 0 and check_regular == -1 + else: + result = check_override == -1 and check_regular == 0 + self.write(str(result)) + + return [ + ("/static_url/(.*)", StaticUrlHandler), + ("/abs_static_url/(.*)", AbsoluteStaticUrlHandler), + ("/override_static_url/(.*)", OverrideStaticUrlHandler), + ("/root_static/(.*)", StaticFileHandler, dict(path="/")), + ] + + def get_app_kwargs(self): + return dict(static_path=relpath("static")) + + def test_static_files(self): + response = self.fetch("/robots.txt") + self.assertTrue(b"Disallow: /" in response.body) + + response = self.fetch("/static/robots.txt") + self.assertTrue(b"Disallow: /" in response.body) + self.assertEqual(response.headers.get("Content-Type"), "text/plain") + + def test_static_compressed_files(self): + response = self.fetch("/static/sample.xml.gz") + 
self.assertEqual(response.headers.get("Content-Type"), "application/gzip") + response = self.fetch("/static/sample.xml.bz2") + self.assertEqual( + response.headers.get("Content-Type"), "application/octet-stream" + ) + # make sure the uncompressed file still has the correct type + response = self.fetch("/static/sample.xml") + self.assertTrue( + response.headers.get("Content-Type") in set(("text/xml", "application/xml")) + ) + + def test_static_url(self): + response = self.fetch("/static_url/robots.txt") + self.assertEqual(response.body, b"/static/robots.txt?v=" + self.robots_txt_hash) + + def test_absolute_static_url(self): + response = self.fetch("/abs_static_url/robots.txt") + self.assertEqual( + response.body, + (utf8(self.get_url("/")) + b"static/robots.txt?v=" + self.robots_txt_hash), + ) + + def test_relative_version_exclusion(self): + response = self.fetch("/static_url/robots.txt?include_version=0") + self.assertEqual(response.body, b"/static/robots.txt") + + def test_absolute_version_exclusion(self): + response = self.fetch("/abs_static_url/robots.txt?include_version=0") + self.assertEqual(response.body, utf8(self.get_url("/") + "static/robots.txt")) + + def test_include_host_override(self): + self._trigger_include_host_check(False) + self._trigger_include_host_check(True) + + def _trigger_include_host_check(self, include_host): + path = "/override_static_url/robots.txt?include_host=%s" + response = self.fetch(path % int(include_host)) + self.assertEqual(response.body, utf8(str(True))) + + def get_and_head(self, *args, **kwargs): + """Performs a GET and HEAD request and returns the GET response. + + Fails if any ``Content-*`` headers returned by the two requests + differ. + """ + head_response = self.fetch(*args, method="HEAD", **kwargs) + get_response = self.fetch(*args, method="GET", **kwargs) + content_headers = set() + for h in itertools.chain(head_response.headers, get_response.headers): + if h.startswith("Content-"): + content_headers.add(h) + for h in content_headers: + self.assertEqual( + head_response.headers.get(h), + get_response.headers.get(h), + "%s differs between GET (%s) and HEAD (%s)" + % (h, head_response.headers.get(h), get_response.headers.get(h)), + ) + return get_response + + def test_static_304_if_modified_since(self): + response1 = self.get_and_head("/static/robots.txt") + response2 = self.get_and_head( + "/static/robots.txt", + headers={"If-Modified-Since": response1.headers["Last-Modified"]}, + ) + self.assertEqual(response2.code, 304) + self.assertTrue("Content-Length" not in response2.headers) + + def test_static_304_if_none_match(self): + response1 = self.get_and_head("/static/robots.txt") + response2 = self.get_and_head( + "/static/robots.txt", headers={"If-None-Match": response1.headers["Etag"]} + ) + self.assertEqual(response2.code, 304) + + def test_static_304_etag_modified_bug(self): + response1 = self.get_and_head("/static/robots.txt") + response2 = self.get_and_head( + "/static/robots.txt", + headers={ + "If-None-Match": '"MISMATCH"', + "If-Modified-Since": response1.headers["Last-Modified"], + }, + ) + self.assertEqual(response2.code, 200) + + def test_static_if_modified_since_pre_epoch(self): + # On windows, the functions that work with time_t do not accept + # negative values, and at least one client (processing.js) seems + # to use if-modified-since 1/1/1960 as a cache-busting technique. 
+ response = self.get_and_head( + "/static/robots.txt", + headers={"If-Modified-Since": "Fri, 01 Jan 1960 00:00:00 GMT"}, + ) + self.assertEqual(response.code, 200) + + def test_static_if_modified_since_time_zone(self): + # Instead of the value from Last-Modified, make requests with times + # chosen just before and after the known modification time + # of the file to ensure that the right time zone is being used + # when parsing If-Modified-Since. + stat = os.stat(relpath("static/robots.txt")) + + response = self.get_and_head( + "/static/robots.txt", + headers={"If-Modified-Since": format_timestamp(stat.st_mtime - 1)}, + ) + self.assertEqual(response.code, 200) + response = self.get_and_head( + "/static/robots.txt", + headers={"If-Modified-Since": format_timestamp(stat.st_mtime + 1)}, + ) + self.assertEqual(response.code, 304) + + def test_static_etag(self): + response = self.get_and_head("/static/robots.txt") + self.assertEqual( + utf8(response.headers.get("Etag")), b'"' + self.robots_txt_hash + b'"' + ) + + def test_static_with_range(self): + response = self.get_and_head( + "/static/robots.txt", headers={"Range": "bytes=0-9"} + ) + self.assertEqual(response.code, 206) + self.assertEqual(response.body, b"User-agent") + self.assertEqual( + utf8(response.headers.get("Etag")), b'"' + self.robots_txt_hash + b'"' + ) + self.assertEqual(response.headers.get("Content-Length"), "10") + self.assertEqual(response.headers.get("Content-Range"), "bytes 0-9/26") + + def test_static_with_range_full_file(self): + response = self.get_and_head( + "/static/robots.txt", headers={"Range": "bytes=0-"} + ) + # Note: Chrome refuses to play audio if it gets an HTTP 206 in response + # to ``Range: bytes=0-`` :( + self.assertEqual(response.code, 200) + robots_file_path = os.path.join(self.static_dir, "robots.txt") + with open(robots_file_path) as f: + self.assertEqual(response.body, utf8(f.read())) + self.assertEqual(response.headers.get("Content-Length"), "26") + self.assertEqual(response.headers.get("Content-Range"), None) + + def test_static_with_range_full_past_end(self): + response = self.get_and_head( + "/static/robots.txt", headers={"Range": "bytes=0-10000000"} + ) + self.assertEqual(response.code, 200) + robots_file_path = os.path.join(self.static_dir, "robots.txt") + with open(robots_file_path) as f: + self.assertEqual(response.body, utf8(f.read())) + self.assertEqual(response.headers.get("Content-Length"), "26") + self.assertEqual(response.headers.get("Content-Range"), None) + + def test_static_with_range_partial_past_end(self): + response = self.get_and_head( + "/static/robots.txt", headers={"Range": "bytes=1-10000000"} + ) + self.assertEqual(response.code, 206) + robots_file_path = os.path.join(self.static_dir, "robots.txt") + with open(robots_file_path) as f: + self.assertEqual(response.body, utf8(f.read()[1:])) + self.assertEqual(response.headers.get("Content-Length"), "25") + self.assertEqual(response.headers.get("Content-Range"), "bytes 1-25/26") + + def test_static_with_range_end_edge(self): + response = self.get_and_head( + "/static/robots.txt", headers={"Range": "bytes=22-"} + ) + self.assertEqual(response.body, b": /\n") + self.assertEqual(response.headers.get("Content-Length"), "4") + self.assertEqual(response.headers.get("Content-Range"), "bytes 22-25/26") + + def test_static_with_range_neg_end(self): + response = self.get_and_head( + "/static/robots.txt", headers={"Range": "bytes=-4"} + ) + self.assertEqual(response.body, b": /\n") + self.assertEqual(response.headers.get("Content-Length"), 
"4") + self.assertEqual(response.headers.get("Content-Range"), "bytes 22-25/26") + + def test_static_with_range_neg_past_start(self): + response = self.get_and_head( + "/static/robots.txt", headers={"Range": "bytes=-1000000"} + ) + self.assertEqual(response.code, 200) + robots_file_path = os.path.join(self.static_dir, "robots.txt") + with open(robots_file_path) as f: + self.assertEqual(response.body, utf8(f.read())) + self.assertEqual(response.headers.get("Content-Length"), "26") + self.assertEqual(response.headers.get("Content-Range"), None) + + def test_static_invalid_range(self): + response = self.get_and_head("/static/robots.txt", headers={"Range": "asdf"}) + self.assertEqual(response.code, 200) + + def test_static_unsatisfiable_range_zero_suffix(self): + response = self.get_and_head( + "/static/robots.txt", headers={"Range": "bytes=-0"} + ) + self.assertEqual(response.headers.get("Content-Range"), "bytes */26") + self.assertEqual(response.code, 416) + + def test_static_unsatisfiable_range_invalid_start(self): + response = self.get_and_head( + "/static/robots.txt", headers={"Range": "bytes=26"} + ) + self.assertEqual(response.code, 416) + self.assertEqual(response.headers.get("Content-Range"), "bytes */26") + + def test_static_unsatisfiable_range_end_less_than_start(self): + response = self.get_and_head( + "/static/robots.txt", headers={"Range": "bytes=10-3"} + ) + self.assertEqual(response.code, 416) + self.assertEqual(response.headers.get("Content-Range"), "bytes */26") + + def test_static_head(self): + response = self.fetch("/static/robots.txt", method="HEAD") + self.assertEqual(response.code, 200) + # No body was returned, but we did get the right content length. + self.assertEqual(response.body, b"") + self.assertEqual(response.headers["Content-Length"], "26") + self.assertEqual( + utf8(response.headers["Etag"]), b'"' + self.robots_txt_hash + b'"' + ) + + def test_static_head_range(self): + response = self.fetch( + "/static/robots.txt", method="HEAD", headers={"Range": "bytes=1-4"} + ) + self.assertEqual(response.code, 206) + self.assertEqual(response.body, b"") + self.assertEqual(response.headers["Content-Length"], "4") + self.assertEqual( + utf8(response.headers["Etag"]), b'"' + self.robots_txt_hash + b'"' + ) + + def test_static_range_if_none_match(self): + response = self.get_and_head( + "/static/robots.txt", + headers={ + "Range": "bytes=1-4", + "If-None-Match": b'"' + self.robots_txt_hash + b'"', + }, + ) + self.assertEqual(response.code, 304) + self.assertEqual(response.body, b"") + self.assertTrue("Content-Length" not in response.headers) + self.assertEqual( + utf8(response.headers["Etag"]), b'"' + self.robots_txt_hash + b'"' + ) + + def test_static_404(self): + response = self.get_and_head("/static/blarg") + self.assertEqual(response.code, 404) + + def test_path_traversal_protection(self): + # curl_httpclient processes ".." on the client side, so we + # must test this with simple_httpclient. + self.http_client.close() + self.http_client = SimpleAsyncHTTPClient() + with ExpectLog(gen_log, ".*not in root static directory"): + response = self.get_and_head("/static/../static_foo.txt") + # Attempted path traversal should result in 403, not 200 + # (which means the check failed and the file was served) + # or 404 (which means that the file didn't exist and + # is probably a packaging error). 
+ self.assertEqual(response.code, 403) + + @unittest.skipIf(os.name != "posix", "non-posix OS") + def test_root_static_path(self): + # Sometimes people set the StaticFileHandler's path to '/' + # to disable Tornado's path validation (in conjunction with + # their own validation in get_absolute_path). Make sure + # that the stricter validation in 4.2.1 doesn't break them. + path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "static/robots.txt" + ) + response = self.get_and_head("/root_static" + urllib.parse.quote(path)) + self.assertEqual(response.code, 200) + + +class StaticDefaultFilenameTest(WebTestCase): + def get_app_kwargs(self): + return dict( + static_path=relpath("static"), + static_handler_args=dict(default_filename="index.html"), + ) + + def get_handlers(self): + return [] + + def test_static_default_filename(self): + response = self.fetch("/static/dir/", follow_redirects=False) + self.assertEqual(response.code, 200) + self.assertEqual(b"this is the index\n", response.body) + + def test_static_default_redirect(self): + response = self.fetch("/static/dir", follow_redirects=False) + self.assertEqual(response.code, 301) + self.assertTrue(response.headers["Location"].endswith("/static/dir/")) + + +class StaticFileWithPathTest(WebTestCase): + def get_app_kwargs(self): + return dict( + static_path=relpath("static"), + static_handler_args=dict(default_filename="index.html"), + ) + + def get_handlers(self): + return [("/foo/(.*)", StaticFileHandler, {"path": relpath("templates/")})] + + def test_serve(self): + response = self.fetch("/foo/utf8.html") + self.assertEqual(response.body, b"H\xc3\xa9llo\n") + + +class CustomStaticFileTest(WebTestCase): + def get_handlers(self): + class MyStaticFileHandler(StaticFileHandler): + @classmethod + def make_static_url(cls, settings, path): + version_hash = cls.get_version(settings, path) + extension_index = path.rindex(".") + before_version = path[:extension_index] + after_version = path[(extension_index + 1) :] + return "/static/%s.%s.%s" % ( + before_version, + version_hash, + after_version, + ) + + def parse_url_path(self, url_path): + extension_index = url_path.rindex(".") + version_index = url_path.rindex(".", 0, extension_index) + return "%s%s" % (url_path[:version_index], url_path[extension_index:]) + + @classmethod + def get_absolute_path(cls, settings, path): + return "CustomStaticFileTest:" + path + + def validate_absolute_path(self, root, absolute_path): + return absolute_path + + @classmethod + def get_content(self, path, start=None, end=None): + assert start is None and end is None + if path == "CustomStaticFileTest:foo.txt": + return b"bar" + raise Exception("unexpected path %r" % path) + + def get_content_size(self): + if self.absolute_path == "CustomStaticFileTest:foo.txt": + return 3 + raise Exception("unexpected path %r" % self.absolute_path) + + def get_modified_time(self): + return None + + @classmethod + def get_version(cls, settings, path): + return "42" + + class StaticUrlHandler(RequestHandler): + def get(self, path): + self.write(self.static_url(path)) + + self.static_handler_class = MyStaticFileHandler + + return [("/static_url/(.*)", StaticUrlHandler)] + + def get_app_kwargs(self): + return dict(static_path="dummy", static_handler_class=self.static_handler_class) + + def test_serve(self): + response = self.fetch("/static/foo.42.txt") + self.assertEqual(response.body, b"bar") + + def test_static_url(self): + with ExpectLog(gen_log, "Could not open static file", required=False): + response = 
self.fetch("/static_url/foo.txt") + self.assertEqual(response.body, b"/static/foo.42.txt") + + +class HostMatchingTest(WebTestCase): + class Handler(RequestHandler): + def initialize(self, reply): + self.reply = reply + + def get(self): + self.write(self.reply) + + def get_handlers(self): + return [("/foo", HostMatchingTest.Handler, {"reply": "wildcard"})] + + def test_host_matching(self): + self.app.add_handlers( + "www.example.com", [("/foo", HostMatchingTest.Handler, {"reply": "[0]"})] + ) + self.app.add_handlers( + r"www\.example\.com", [("/bar", HostMatchingTest.Handler, {"reply": "[1]"})] + ) + self.app.add_handlers( + "www.example.com", [("/baz", HostMatchingTest.Handler, {"reply": "[2]"})] + ) + self.app.add_handlers( + "www.e.*e.com", [("/baz", HostMatchingTest.Handler, {"reply": "[3]"})] + ) + + response = self.fetch("/foo") + self.assertEqual(response.body, b"wildcard") + response = self.fetch("/bar") + self.assertEqual(response.code, 404) + response = self.fetch("/baz") + self.assertEqual(response.code, 404) + + response = self.fetch("/foo", headers={"Host": "www.example.com"}) + self.assertEqual(response.body, b"[0]") + response = self.fetch("/bar", headers={"Host": "www.example.com"}) + self.assertEqual(response.body, b"[1]") + response = self.fetch("/baz", headers={"Host": "www.example.com"}) + self.assertEqual(response.body, b"[2]") + response = self.fetch("/baz", headers={"Host": "www.exe.com"}) + self.assertEqual(response.body, b"[3]") + + +class DefaultHostMatchingTest(WebTestCase): + def get_handlers(self): + return [] + + def get_app_kwargs(self): + return {"default_host": "www.example.com"} + + def test_default_host_matching(self): + self.app.add_handlers( + "www.example.com", [("/foo", HostMatchingTest.Handler, {"reply": "[0]"})] + ) + self.app.add_handlers( + r"www\.example\.com", [("/bar", HostMatchingTest.Handler, {"reply": "[1]"})] + ) + self.app.add_handlers( + "www.test.com", [("/baz", HostMatchingTest.Handler, {"reply": "[2]"})] + ) + + response = self.fetch("/foo") + self.assertEqual(response.body, b"[0]") + response = self.fetch("/bar") + self.assertEqual(response.body, b"[1]") + response = self.fetch("/baz") + self.assertEqual(response.code, 404) + + response = self.fetch("/foo", headers={"X-Real-Ip": "127.0.0.1"}) + self.assertEqual(response.code, 404) + + self.app.default_host = "www.test.com" + + response = self.fetch("/baz") + self.assertEqual(response.body, b"[2]") + + +class NamedURLSpecGroupsTest(WebTestCase): + def get_handlers(self): + class EchoHandler(RequestHandler): + def get(self, path): + self.write(path) + + return [ + ("/str/(?P<path>.*)", EchoHandler), + ("/unicode/(?P<path>.*)", EchoHandler), + ] + + def test_named_urlspec_groups(self): + response = self.fetch("/str/foo") + self.assertEqual(response.body, b"foo") + + response = self.fetch("/unicode/bar") + self.assertEqual(response.body, b"bar") + + +class ClearHeaderTest(SimpleHandlerTestCase): + class Handler(RequestHandler): + def get(self): + self.set_header("h1", "foo") + self.set_header("h2", "bar") + self.clear_header("h1") + self.clear_header("nonexistent") + + def test_clear_header(self): + response = self.fetch("/") + self.assertTrue("h1" not in response.headers) + self.assertEqual(response.headers["h2"], "bar") + + +class Header204Test(SimpleHandlerTestCase): + class Handler(RequestHandler): + def get(self): + self.set_status(204) + self.finish() + + def test_204_headers(self): + response = self.fetch("/") + self.assertEqual(response.code, 204) + 
self.assertNotIn("Content-Length", response.headers) + self.assertNotIn("Transfer-Encoding", response.headers) + + +class Header304Test(SimpleHandlerTestCase): + class Handler(RequestHandler): + def get(self): + self.set_header("Content-Language", "en_US") + self.write("hello") + + def test_304_headers(self): + response1 = self.fetch("/") + self.assertEqual(response1.headers["Content-Length"], "5") + self.assertEqual(response1.headers["Content-Language"], "en_US") + + response2 = self.fetch( + "/", headers={"If-None-Match": response1.headers["Etag"]} + ) + self.assertEqual(response2.code, 304) + self.assertTrue("Content-Length" not in response2.headers) + self.assertTrue("Content-Language" not in response2.headers) + # Not an entity header, but should not be added to 304s by chunking + self.assertTrue("Transfer-Encoding" not in response2.headers) + + +class StatusReasonTest(SimpleHandlerTestCase): + class Handler(RequestHandler): + def get(self): + reason = self.request.arguments.get("reason", []) + self.set_status( + int(self.get_argument("code")), + reason=to_unicode(reason[0]) if reason else None, + ) + + def get_http_client(self): + # simple_httpclient only: curl doesn't expose the reason string + return SimpleAsyncHTTPClient() + + def test_status(self): + response = self.fetch("/?code=304") + self.assertEqual(response.code, 304) + self.assertEqual(response.reason, "Not Modified") + response = self.fetch("/?code=304&reason=Foo") + self.assertEqual(response.code, 304) + self.assertEqual(response.reason, "Foo") + response = self.fetch("/?code=682&reason=Bar") + self.assertEqual(response.code, 682) + self.assertEqual(response.reason, "Bar") + response = self.fetch("/?code=682") + self.assertEqual(response.code, 682) + self.assertEqual(response.reason, "Unknown") + + +class DateHeaderTest(SimpleHandlerTestCase): + class Handler(RequestHandler): + def get(self): + self.write("hello") + + def test_date_header(self): + response = self.fetch("/") + parsed = email.utils.parsedate(response.headers["Date"]) + assert parsed is not None + header_date = datetime.datetime(*parsed[:6]) + self.assertTrue( + header_date - datetime.datetime.utcnow() < datetime.timedelta(seconds=2) + ) + + +class RaiseWithReasonTest(SimpleHandlerTestCase): + class Handler(RequestHandler): + def get(self): + raise HTTPError(682, reason="Foo") + + def get_http_client(self): + # simple_httpclient only: curl doesn't expose the reason string + return SimpleAsyncHTTPClient() + + def test_raise_with_reason(self): + response = self.fetch("/") + self.assertEqual(response.code, 682) + self.assertEqual(response.reason, "Foo") + self.assertIn(b"682: Foo", response.body) + + def test_httperror_str(self): + self.assertEqual(str(HTTPError(682, reason="Foo")), "HTTP 682: Foo") + + def test_httperror_str_from_httputil(self): + self.assertEqual(str(HTTPError(682)), "HTTP 682: Unknown") + + +class ErrorHandlerXSRFTest(WebTestCase): + def get_handlers(self): + # note that if the handlers list is empty we get the default_host + # redirect fallback instead of a 404, so test with both an + # explicitly defined error handler and an implicit 404. 
+ return [("/error", ErrorHandler, dict(status_code=417))] + + def get_app_kwargs(self): + return dict(xsrf_cookies=True) + + def test_error_xsrf(self): + response = self.fetch("/error", method="POST", body="") + self.assertEqual(response.code, 417) + + def test_404_xsrf(self): + response = self.fetch("/404", method="POST", body="") + self.assertEqual(response.code, 404) + + +class GzipTestCase(SimpleHandlerTestCase): + class Handler(RequestHandler): + def get(self): + for v in self.get_arguments("vary"): + self.add_header("Vary", v) + # Must write at least MIN_LENGTH bytes to activate compression. + self.write("hello world" + ("!" * GZipContentEncoding.MIN_LENGTH)) + + def get_app_kwargs(self): + return dict( + gzip=True, static_path=os.path.join(os.path.dirname(__file__), "static") + ) + + def assert_compressed(self, response): + # simple_httpclient renames the content-encoding header; + # curl_httpclient doesn't. + self.assertEqual( + response.headers.get( + "Content-Encoding", response.headers.get("X-Consumed-Content-Encoding") + ), + "gzip", + ) + + def test_gzip(self): + response = self.fetch("/") + self.assert_compressed(response) + self.assertEqual(response.headers["Vary"], "Accept-Encoding") + + def test_gzip_static(self): + # The streaming responses in StaticFileHandler have subtle + # interactions with the gzip output so test this case separately. + response = self.fetch("/robots.txt") + self.assert_compressed(response) + self.assertEqual(response.headers["Vary"], "Accept-Encoding") + + def test_gzip_not_requested(self): + response = self.fetch("/", use_gzip=False) + self.assertNotIn("Content-Encoding", response.headers) + self.assertEqual(response.headers["Vary"], "Accept-Encoding") + + def test_vary_already_present(self): + response = self.fetch("/?vary=Accept-Language") + self.assert_compressed(response) + self.assertEqual( + [s.strip() for s in response.headers["Vary"].split(",")], + ["Accept-Language", "Accept-Encoding"], + ) + + def test_vary_already_present_multiple(self): + # Regression test for https://github.com/tornadoweb/tornado/issues/1670 + response = self.fetch("/?vary=Accept-Language&vary=Cookie") + self.assert_compressed(response) + self.assertEqual( + [s.strip() for s in response.headers["Vary"].split(",")], + ["Accept-Language", "Cookie", "Accept-Encoding"], + ) + + +class PathArgsInPrepareTest(WebTestCase): + class Handler(RequestHandler): + def prepare(self): + self.write(dict(args=self.path_args, kwargs=self.path_kwargs)) + + def get(self, path): + assert path == "foo" + self.finish() + + def get_handlers(self): + return [("/pos/(.*)", self.Handler), ("/kw/(?P<path>.*)", self.Handler)] + + def test_pos(self): + response = self.fetch("/pos/foo") + response.rethrow() + data = json_decode(response.body) + self.assertEqual(data, {"args": ["foo"], "kwargs": {}}) + + def test_kw(self): + response = self.fetch("/kw/foo") + response.rethrow() + data = json_decode(response.body) + self.assertEqual(data, {"args": [], "kwargs": {"path": "foo"}}) + + +class ClearAllCookiesTest(SimpleHandlerTestCase): + class Handler(RequestHandler): + def get(self): + self.clear_all_cookies() + self.write("ok") + + def test_clear_all_cookies(self): + response = self.fetch("/", headers={"Cookie": "foo=bar; baz=xyzzy"}) + set_cookies = sorted(response.headers.get_list("Set-Cookie")) + # Python 3.5 sends 'baz="";'; older versions use 'baz=;' + self.assertTrue( + set_cookies[0].startswith("baz=;") or set_cookies[0].startswith('baz="";') + ) + self.assertTrue( + 
set_cookies[1].startswith("foo=;") or set_cookies[1].startswith('foo="";') + ) + + +class PermissionError(Exception): + pass + + +class ExceptionHandlerTest(SimpleHandlerTestCase): + class Handler(RequestHandler): + def get(self): + exc = self.get_argument("exc") + if exc == "http": + raise HTTPError(410, "no longer here") + elif exc == "zero": + 1 / 0 + elif exc == "permission": + raise PermissionError("not allowed") + + def write_error(self, status_code, **kwargs): + if "exc_info" in kwargs: + typ, value, tb = kwargs["exc_info"] + if isinstance(value, PermissionError): + self.set_status(403) + self.write("PermissionError") + return + RequestHandler.write_error(self, status_code, **kwargs) + + def log_exception(self, typ, value, tb): + if isinstance(value, PermissionError): + app_log.warning("custom logging for PermissionError: %s", value.args[0]) + else: + RequestHandler.log_exception(self, typ, value, tb) + + def test_http_error(self): + # HTTPErrors are logged as warnings with no stack trace. + # TODO: extend ExpectLog to test this more precisely + with ExpectLog(gen_log, ".*no longer here"): + response = self.fetch("/?exc=http") + self.assertEqual(response.code, 410) + + def test_unknown_error(self): + # Unknown errors are logged as errors with a stack trace. + with ExpectLog(app_log, "Uncaught exception"): + response = self.fetch("/?exc=zero") + self.assertEqual(response.code, 500) + + def test_known_error(self): + # log_exception can override logging behavior, and write_error + # can override the response. + with ExpectLog(app_log, "custom logging for PermissionError: not allowed"): + response = self.fetch("/?exc=permission") + self.assertEqual(response.code, 403) + + +class BuggyLoggingTest(SimpleHandlerTestCase): + class Handler(RequestHandler): + def get(self): + 1 / 0 + + def log_exception(self, typ, value, tb): + 1 / 0 + + def test_buggy_log_exception(self): + # Something gets logged even though the application's + # logger is broken. + with ExpectLog(app_log, ".*"): + self.fetch("/") + + +class UIMethodUIModuleTest(SimpleHandlerTestCase): + """Test that UI methods and modules are created correctly and + associated with the handler. + """ + + class Handler(RequestHandler): + def get(self): + self.render("foo.html") + + def value(self): + return self.get_argument("value") + + def get_app_kwargs(self): + def my_ui_method(handler, x): + return "In my_ui_method(%s) with handler value %s." % (x, handler.value()) + + class MyModule(UIModule): + def render(self, x): + return "In MyModule(%s) with handler value %s." % ( + x, + typing.cast(UIMethodUIModuleTest.Handler, self.handler).value(), + ) + + loader = DictLoader( + {"foo.html": "{{ my_ui_method(42) }} {% module MyModule(123) %}"} + ) + return dict( + template_loader=loader, + ui_methods={"my_ui_method": my_ui_method}, + ui_modules={"MyModule": MyModule}, + ) + + def tearDown(self): + super().tearDown() + # TODO: fix template loader caching so this isn't necessary. + RequestHandler._template_loaders.clear() + + def test_ui_method(self): + response = self.fetch("/?value=asdf") + self.assertEqual( + response.body, + b"In my_ui_method(42) with handler value asdf. 
" + b"In MyModule(123) with handler value asdf.", + ) + + +class GetArgumentErrorTest(SimpleHandlerTestCase): + class Handler(RequestHandler): + def get(self): + try: + self.get_argument("foo") + self.write({}) + except MissingArgumentError as e: + self.write({"arg_name": e.arg_name, "log_message": e.log_message}) + + def test_catch_error(self): + response = self.fetch("/") + self.assertEqual( + json_decode(response.body), + {"arg_name": "foo", "log_message": "Missing argument foo"}, + ) + + +class SetLazyPropertiesTest(SimpleHandlerTestCase): + class Handler(RequestHandler): + def prepare(self): + self.current_user = "Ben" + self.locale = locale.get("en_US") + + def get_user_locale(self): + raise NotImplementedError() + + def get_current_user(self): + raise NotImplementedError() + + def get(self): + self.write("Hello %s (%s)" % (self.current_user, self.locale.code)) + + def test_set_properties(self): + # Ensure that current_user can be assigned to normally for apps + # that want to forgo the lazy get_current_user property + response = self.fetch("/") + self.assertEqual(response.body, b"Hello Ben (en_US)") + + +class GetCurrentUserTest(WebTestCase): + def get_app_kwargs(self): + class WithoutUserModule(UIModule): + def render(self): + return "" + + class WithUserModule(UIModule): + def render(self): + return str(self.current_user) + + loader = DictLoader( + { + "without_user.html": "", + "with_user.html": "{{ current_user }}", + "without_user_module.html": "{% module WithoutUserModule() %}", + "with_user_module.html": "{% module WithUserModule() %}", + } + ) + return dict( + template_loader=loader, + ui_modules={ + "WithUserModule": WithUserModule, + "WithoutUserModule": WithoutUserModule, + }, + ) + + def tearDown(self): + super().tearDown() + RequestHandler._template_loaders.clear() + + def get_handlers(self): + class CurrentUserHandler(RequestHandler): + def prepare(self): + self.has_loaded_current_user = False + + def get_current_user(self): + self.has_loaded_current_user = True + return "" + + class WithoutUserHandler(CurrentUserHandler): + def get(self): + self.render_string("without_user.html") + self.finish(str(self.has_loaded_current_user)) + + class WithUserHandler(CurrentUserHandler): + def get(self): + self.render_string("with_user.html") + self.finish(str(self.has_loaded_current_user)) + + class CurrentUserModuleHandler(CurrentUserHandler): + def get_template_namespace(self): + # If RequestHandler.get_template_namespace is called, then + # get_current_user is evaluated. Until #820 is fixed, this + # is a small hack to circumvent the issue. + return self.ui + + class WithoutUserModuleHandler(CurrentUserModuleHandler): + def get(self): + self.render_string("without_user_module.html") + self.finish(str(self.has_loaded_current_user)) + + class WithUserModuleHandler(CurrentUserModuleHandler): + def get(self): + self.render_string("with_user_module.html") + self.finish(str(self.has_loaded_current_user)) + + return [ + ("/without_user", WithoutUserHandler), + ("/with_user", WithUserHandler), + ("/without_user_module", WithoutUserModuleHandler), + ("/with_user_module", WithUserModuleHandler), + ] + + @unittest.skip("needs fix") + def test_get_current_user_is_lazy(self): + # TODO: Make this test pass. See #820. 
+ response = self.fetch("/without_user") + self.assertEqual(response.body, b"False") + + def test_get_current_user_works(self): + response = self.fetch("/with_user") + self.assertEqual(response.body, b"True") + + def test_get_current_user_from_ui_module_is_lazy(self): + response = self.fetch("/without_user_module") + self.assertEqual(response.body, b"False") + + def test_get_current_user_from_ui_module_works(self): + response = self.fetch("/with_user_module") + self.assertEqual(response.body, b"True") + + +class UnimplementedHTTPMethodsTest(SimpleHandlerTestCase): + class Handler(RequestHandler): + pass + + def test_unimplemented_standard_methods(self): + for method in ["HEAD", "GET", "DELETE", "OPTIONS"]: + response = self.fetch("/", method=method) + self.assertEqual(response.code, 405) + for method in ["POST", "PUT"]: + response = self.fetch("/", method=method, body=b"") + self.assertEqual(response.code, 405) + + +class UnimplementedNonStandardMethodsTest(SimpleHandlerTestCase): + class Handler(RequestHandler): + def other(self): + # Even though this method exists, it won't get called automatically + # because it is not in SUPPORTED_METHODS. + self.write("other") + + def test_unimplemented_patch(self): + # PATCH is recently standardized; Tornado supports it by default + # but wsgiref.validate doesn't like it. + response = self.fetch("/", method="PATCH", body=b"") + self.assertEqual(response.code, 405) + + def test_unimplemented_other(self): + response = self.fetch("/", method="OTHER", allow_nonstandard_methods=True) + self.assertEqual(response.code, 405) + + +class AllHTTPMethodsTest(SimpleHandlerTestCase): + class Handler(RequestHandler): + def method(self): + assert self.request.method is not None + self.write(self.request.method) + + get = delete = options = post = put = method # type: ignore + + def test_standard_methods(self): + response = self.fetch("/", method="HEAD") + self.assertEqual(response.body, b"") + for method in ["GET", "DELETE", "OPTIONS"]: + response = self.fetch("/", method=method) + self.assertEqual(response.body, utf8(method)) + for method in ["POST", "PUT"]: + response = self.fetch("/", method=method, body=b"") + self.assertEqual(response.body, utf8(method)) + + +class PatchMethodTest(SimpleHandlerTestCase): + class Handler(RequestHandler): + SUPPORTED_METHODS = RequestHandler.SUPPORTED_METHODS + ( # type: ignore + "OTHER", + ) + + def patch(self): + self.write("patch") + + def other(self): + self.write("other") + + def test_patch(self): + response = self.fetch("/", method="PATCH", body=b"") + self.assertEqual(response.body, b"patch") + + def test_other(self): + response = self.fetch("/", method="OTHER", allow_nonstandard_methods=True) + self.assertEqual(response.body, b"other") + + +class FinishInPrepareTest(SimpleHandlerTestCase): + class Handler(RequestHandler): + def prepare(self): + self.finish("done") + + def get(self): + # It's difficult to assert for certain that a method did not + # or will not be called in an asynchronous context, but this + # will be logged noisily if it is reached. + raise Exception("should not reach this method") + + def test_finish_in_prepare(self): + response = self.fetch("/") + self.assertEqual(response.body, b"done") + + +class Default404Test(WebTestCase): + def get_handlers(self): + # If there are no handlers at all a default redirect handler gets added. 
+ return [("/foo", RequestHandler)] + + def test_404(self): + response = self.fetch("/") + self.assertEqual(response.code, 404) + self.assertEqual( + response.body, + b"<html><title>404: Not Found</title>" + b"<body>404: Not Found</body></html>", + ) + + +class Custom404Test(WebTestCase): + def get_handlers(self): + return [("/foo", RequestHandler)] + + def get_app_kwargs(self): + class Custom404Handler(RequestHandler): + def get(self): + self.set_status(404) + self.write("custom 404 response") + + return dict(default_handler_class=Custom404Handler) + + def test_404(self): + response = self.fetch("/") + self.assertEqual(response.code, 404) + self.assertEqual(response.body, b"custom 404 response") + + +class DefaultHandlerArgumentsTest(WebTestCase): + def get_handlers(self): + return [("/foo", RequestHandler)] + + def get_app_kwargs(self): + return dict( + default_handler_class=ErrorHandler, + default_handler_args=dict(status_code=403), + ) + + def test_403(self): + response = self.fetch("/") + self.assertEqual(response.code, 403) + + +class HandlerByNameTest(WebTestCase): + def get_handlers(self): + # All three are equivalent. + return [ + ("/hello1", HelloHandler), + ("/hello2", "tornado.test.web_test.HelloHandler"), + url("/hello3", "tornado.test.web_test.HelloHandler"), + ] + + def test_handler_by_name(self): + resp = self.fetch("/hello1") + self.assertEqual(resp.body, b"hello") + resp = self.fetch("/hello2") + self.assertEqual(resp.body, b"hello") + resp = self.fetch("/hello3") + self.assertEqual(resp.body, b"hello") + + +class StreamingRequestBodyTest(WebTestCase): + def get_handlers(self): + @stream_request_body + class StreamingBodyHandler(RequestHandler): + def initialize(self, test): + self.test = test + + def prepare(self): + self.test.prepared.set_result(None) + + def data_received(self, data): + self.test.data.set_result(data) + + def get(self): + self.test.finished.set_result(None) + self.write({}) + + @stream_request_body + class EarlyReturnHandler(RequestHandler): + def prepare(self): + # If we finish the response in prepare, it won't continue to + # the (non-existent) data_received. + raise HTTPError(401) + + @stream_request_body + class CloseDetectionHandler(RequestHandler): + def initialize(self, test): + self.test = test + + def on_connection_close(self): + super().on_connection_close() + self.test.close_future.set_result(None) + + return [ + ("/stream_body", StreamingBodyHandler, dict(test=self)), + ("/early_return", EarlyReturnHandler), + ("/close_detection", CloseDetectionHandler, dict(test=self)), + ] + + def connect(self, url, connection_close): + # Use a raw connection so we can control the sending of data. + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) + s.connect(("127.0.0.1", self.get_http_port())) + stream = IOStream(s) + stream.write(b"GET " + url + b" HTTP/1.1\r\n") + if connection_close: + stream.write(b"Connection: close\r\n") + stream.write(b"Transfer-Encoding: chunked\r\n\r\n") + return stream + + @gen_test + def test_streaming_body(self): + self.prepared = Future() # type: Future[None] + self.data = Future() # type: Future[bytes] + self.finished = Future() # type: Future[None] + + stream = self.connect(b"/stream_body", connection_close=True) + yield self.prepared + stream.write(b"4\r\nasdf\r\n") + # Ensure the first chunk is received before we send the second. 
+ data = yield self.data + self.assertEqual(data, b"asdf") + self.data = Future() + stream.write(b"4\r\nqwer\r\n") + data = yield self.data + self.assertEqual(data, b"qwer") + stream.write(b"0\r\n\r\n") + yield self.finished + data = yield stream.read_until_close() + # This would ideally use an HTTP1Connection to read the response. + self.assertTrue(data.endswith(b"{}")) + stream.close() + + @gen_test + def test_early_return(self): + stream = self.connect(b"/early_return", connection_close=False) + data = yield stream.read_until_close() + self.assertTrue(data.startswith(b"HTTP/1.1 401")) + + @gen_test + def test_early_return_with_data(self): + stream = self.connect(b"/early_return", connection_close=False) + stream.write(b"4\r\nasdf\r\n") + data = yield stream.read_until_close() + self.assertTrue(data.startswith(b"HTTP/1.1 401")) + + @gen_test + def test_close_during_upload(self): + self.close_future = Future() # type: Future[None] + stream = self.connect(b"/close_detection", connection_close=False) + stream.close() + yield self.close_future + + +# Each method in this handler returns a yieldable object and yields to the +# IOLoop so the future is not immediately ready. Ensure that the +# yieldables are respected and no method is called before the previous +# one has completed. +@stream_request_body +class BaseFlowControlHandler(RequestHandler): + def initialize(self, test): + self.test = test + self.method = None + self.methods = [] # type: typing.List[str] + + @contextlib.contextmanager + def in_method(self, method): + if self.method is not None: + self.test.fail("entered method %s while in %s" % (method, self.method)) + self.method = method + self.methods.append(method) + try: + yield + finally: + self.method = None + + @gen.coroutine + def prepare(self): + # Note that asynchronous prepare() does not block data_received, + # so we don't use in_method here. + self.methods.append("prepare") + yield gen.moment + + @gen.coroutine + def post(self): + with self.in_method("post"): + yield gen.moment + self.write(dict(methods=self.methods)) + + +class BaseStreamingRequestFlowControlTest(object): + def get_httpserver_options(self): + # Use a small chunk size so flow control is relevant even though + # all the data arrives at once. + return dict(chunk_size=10, decompress_request=True) + + def get_http_client(self): + # simple_httpclient only: curl doesn't support body_producer. + return SimpleAsyncHTTPClient() + + # Test all the slightly different code paths for fixed, chunked, etc bodies. 
+ def test_flow_control_fixed_body(self: typing.Any): + response = self.fetch("/", body="abcdefghijklmnopqrstuvwxyz", method="POST") + response.rethrow() + self.assertEqual( + json_decode(response.body), + dict( + methods=[ + "prepare", + "data_received", + "data_received", + "data_received", + "post", + ] + ), + ) + + def test_flow_control_chunked_body(self: typing.Any): + chunks = [b"abcd", b"efgh", b"ijkl"] + + @gen.coroutine + def body_producer(write): + for i in chunks: + yield write(i) + + response = self.fetch("/", body_producer=body_producer, method="POST") + response.rethrow() + self.assertEqual( + json_decode(response.body), + dict( + methods=[ + "prepare", + "data_received", + "data_received", + "data_received", + "post", + ] + ), + ) + + def test_flow_control_compressed_body(self: typing.Any): + bytesio = BytesIO() + gzip_file = gzip.GzipFile(mode="w", fileobj=bytesio) + gzip_file.write(b"abcdefghijklmnopqrstuvwxyz") + gzip_file.close() + compressed_body = bytesio.getvalue() + response = self.fetch( + "/", + body=compressed_body, + method="POST", + headers={"Content-Encoding": "gzip"}, + ) + response.rethrow() + self.assertEqual( + json_decode(response.body), + dict( + methods=[ + "prepare", + "data_received", + "data_received", + "data_received", + "post", + ] + ), + ) + + +class DecoratedStreamingRequestFlowControlTest( + BaseStreamingRequestFlowControlTest, WebTestCase +): + def get_handlers(self): + class DecoratedFlowControlHandler(BaseFlowControlHandler): + @gen.coroutine + def data_received(self, data): + with self.in_method("data_received"): + yield gen.moment + + return [("/", DecoratedFlowControlHandler, dict(test=self))] + + +class NativeStreamingRequestFlowControlTest( + BaseStreamingRequestFlowControlTest, WebTestCase +): + def get_handlers(self): + class NativeFlowControlHandler(BaseFlowControlHandler): + async def data_received(self, data): + with self.in_method("data_received"): + import asyncio + + await asyncio.sleep(0) + + return [("/", NativeFlowControlHandler, dict(test=self))] + + +class IncorrectContentLengthTest(SimpleHandlerTestCase): + def get_handlers(self): + test = self + self.server_error = None + + # Manually set a content-length that doesn't match the actual content. + class TooHigh(RequestHandler): + def get(self): + self.set_header("Content-Length", "42") + try: + self.finish("ok") + except Exception as e: + test.server_error = e + raise + + class TooLow(RequestHandler): + def get(self): + self.set_header("Content-Length", "2") + try: + self.finish("hello") + except Exception as e: + test.server_error = e + raise + + return [("/high", TooHigh), ("/low", TooLow)] + + def test_content_length_too_high(self): + # When the content-length is too high, the connection is simply + # closed without completing the response. An error is logged on + # the server. + with ExpectLog(app_log, "(Uncaught exception|Exception in callback)"): + with ExpectLog( + gen_log, + "(Cannot send error response after headers written" + "|Failed to flush partial response)", + ): + with self.assertRaises(HTTPClientError): + self.fetch("/high", raise_error=True) + self.assertEqual( + str(self.server_error), "Tried to write 40 bytes less than Content-Length" + ) + + def test_content_length_too_low(self): + # When the content-length is too low, the connection is closed + # without writing the last chunk, so the client never sees the request + # complete (which would be a framing error). 
+ with ExpectLog(app_log, "(Uncaught exception|Exception in callback)"): + with ExpectLog( + gen_log, + "(Cannot send error response after headers written" + "|Failed to flush partial response)", + ): + with self.assertRaises(HTTPClientError): + self.fetch("/low", raise_error=True) + self.assertEqual( + str(self.server_error), "Tried to write more data than Content-Length" + ) + + +class ClientCloseTest(SimpleHandlerTestCase): + class Handler(RequestHandler): + def get(self): + if self.request.version.startswith("HTTP/1"): + # Simulate a connection closed by the client during + # request processing. The client will see an error, but the + # server should respond gracefully (without logging errors + # because we were unable to write out as many bytes as + # Content-Length said we would) + self.request.connection.stream.close() # type: ignore + self.write("hello") + else: + # TODO: add a HTTP2-compatible version of this test. + self.write("requires HTTP/1.x") + + def test_client_close(self): + with self.assertRaises((HTTPClientError, unittest.SkipTest)): # type: ignore + response = self.fetch("/", raise_error=True) + if response.body == b"requires HTTP/1.x": + self.skipTest("requires HTTP/1.x") + self.assertEqual(response.code, 599) + + +class SignedValueTest(unittest.TestCase): + SECRET = "It's a secret to everybody" + SECRET_DICT = {0: "asdfbasdf", 1: "12312312", 2: "2342342"} + + def past(self): + return self.present() - 86400 * 32 + + def present(self): + return 1300000000 + + def test_known_values(self): + signed_v1 = create_signed_value( + SignedValueTest.SECRET, "key", "value", version=1, clock=self.present + ) + self.assertEqual( + signed_v1, b"dmFsdWU=|1300000000|31c934969f53e48164c50768b40cbd7e2daaaa4f" + ) + + signed_v2 = create_signed_value( + SignedValueTest.SECRET, "key", "value", version=2, clock=self.present + ) + self.assertEqual( + signed_v2, + b"2|1:0|10:1300000000|3:key|8:dmFsdWU=|" + b"3d4e60b996ff9c5d5788e333a0cba6f238a22c6c0f94788870e1a9ecd482e152", + ) + + signed_default = create_signed_value( + SignedValueTest.SECRET, "key", "value", clock=self.present + ) + self.assertEqual(signed_default, signed_v2) + + decoded_v1 = decode_signed_value( + SignedValueTest.SECRET, "key", signed_v1, min_version=1, clock=self.present + ) + self.assertEqual(decoded_v1, b"value") + + decoded_v2 = decode_signed_value( + SignedValueTest.SECRET, "key", signed_v2, min_version=2, clock=self.present + ) + self.assertEqual(decoded_v2, b"value") + + def test_name_swap(self): + signed1 = create_signed_value( + SignedValueTest.SECRET, "key1", "value", clock=self.present + ) + signed2 = create_signed_value( + SignedValueTest.SECRET, "key2", "value", clock=self.present + ) + # Try decoding each string with the other's "name" + decoded1 = decode_signed_value( + SignedValueTest.SECRET, "key2", signed1, clock=self.present + ) + self.assertIs(decoded1, None) + decoded2 = decode_signed_value( + SignedValueTest.SECRET, "key1", signed2, clock=self.present + ) + self.assertIs(decoded2, None) + + def test_expired(self): + signed = create_signed_value( + SignedValueTest.SECRET, "key1", "value", clock=self.past + ) + decoded_past = decode_signed_value( + SignedValueTest.SECRET, "key1", signed, clock=self.past + ) + self.assertEqual(decoded_past, b"value") + decoded_present = decode_signed_value( + SignedValueTest.SECRET, "key1", signed, clock=self.present + ) + self.assertIs(decoded_present, None) + + def test_payload_tampering(self): + # These cookies are variants of the one in test_known_values. 
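+ # v2 layout: version | key_version | timestamp | name | base64(value),
+ # with every field after the version length-prefixed as "<len>:<data>"
+ # and the whole prefix signed by a hex HMAC-SHA256 appended at the end.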
+ sig = "3d4e60b996ff9c5d5788e333a0cba6f238a22c6c0f94788870e1a9ecd482e152" + + def validate(prefix): + return b"value" == decode_signed_value( + SignedValueTest.SECRET, "key", prefix + sig, clock=self.present + ) + + self.assertTrue(validate("2|1:0|10:1300000000|3:key|8:dmFsdWU=|")) + # Change key version + self.assertFalse(validate("2|1:1|10:1300000000|3:key|8:dmFsdWU=|")) + # length mismatch (field too short) + self.assertFalse(validate("2|1:0|10:130000000|3:key|8:dmFsdWU=|")) + # length mismatch (field too long) + self.assertFalse(validate("2|1:0|10:1300000000|3:keey|8:dmFsdWU=|")) + + def test_signature_tampering(self): + prefix = "2|1:0|10:1300000000|3:key|8:dmFsdWU=|" + + def validate(sig): + return b"value" == decode_signed_value( + SignedValueTest.SECRET, "key", prefix + sig, clock=self.present + ) + + self.assertTrue( + validate("3d4e60b996ff9c5d5788e333a0cba6f238a22c6c0f94788870e1a9ecd482e152") + ) + # All zeros + self.assertFalse(validate("0" * 32)) + # Change one character + self.assertFalse( + validate("4d4e60b996ff9c5d5788e333a0cba6f238a22c6c0f94788870e1a9ecd482e152") + ) + # Change another character + self.assertFalse( + validate("3d4e60b996ff9c5d5788e333a0cba6f238a22c6c0f94788870e1a9ecd482e153") + ) + # Truncate + self.assertFalse( + validate("3d4e60b996ff9c5d5788e333a0cba6f238a22c6c0f94788870e1a9ecd482e15") + ) + # Lengthen + self.assertFalse( + validate( + "3d4e60b996ff9c5d5788e333a0cba6f238a22c6c0f94788870e1a9ecd482e1538" + ) + ) + + def test_non_ascii(self): + value = b"\xe9" + signed = create_signed_value( + SignedValueTest.SECRET, "key", value, clock=self.present + ) + decoded = decode_signed_value( + SignedValueTest.SECRET, "key", signed, clock=self.present + ) + self.assertEqual(value, decoded) + + def test_key_versioning_read_write_default_key(self): + value = b"\xe9" + signed = create_signed_value( + SignedValueTest.SECRET_DICT, "key", value, clock=self.present, key_version=0 + ) + decoded = decode_signed_value( + SignedValueTest.SECRET_DICT, "key", signed, clock=self.present + ) + self.assertEqual(value, decoded) + + def test_key_versioning_read_write_non_default_key(self): + value = b"\xe9" + signed = create_signed_value( + SignedValueTest.SECRET_DICT, "key", value, clock=self.present, key_version=1 + ) + decoded = decode_signed_value( + SignedValueTest.SECRET_DICT, "key", signed, clock=self.present + ) + self.assertEqual(value, decoded) + + def test_key_versioning_invalid_key(self): + value = b"\xe9" + signed = create_signed_value( + SignedValueTest.SECRET_DICT, "key", value, clock=self.present, key_version=0 + ) + newkeys = SignedValueTest.SECRET_DICT.copy() + newkeys.pop(0) + decoded = decode_signed_value(newkeys, "key", signed, clock=self.present) + self.assertEqual(None, decoded) + + def test_key_version_retrieval(self): + value = b"\xe9" + signed = create_signed_value( + SignedValueTest.SECRET_DICT, "key", value, clock=self.present, key_version=1 + ) + key_version = get_signature_key_version(signed) + self.assertEqual(1, key_version) + + +class XSRFTest(SimpleHandlerTestCase): + class Handler(RequestHandler): + def get(self): + version = int(self.get_argument("version", "2")) + # This would be a bad idea in a real app, but in this test + # it's fine. 
+ self.settings["xsrf_cookie_version"] = version + self.write(self.xsrf_token) + + def post(self): + self.write("ok") + + def get_app_kwargs(self): + return dict(xsrf_cookies=True) + + def setUp(self): + super().setUp() + self.xsrf_token = self.get_token() + + def get_token(self, old_token=None, version=None): + if old_token is not None: + headers = self.cookie_headers(old_token) + else: + headers = None + response = self.fetch( + "/" if version is None else ("/?version=%d" % version), headers=headers + ) + response.rethrow() + return native_str(response.body) + + def cookie_headers(self, token=None): + if token is None: + token = self.xsrf_token + return {"Cookie": "_xsrf=" + token} + + def test_xsrf_fail_no_token(self): + with ExpectLog(gen_log, ".*'_xsrf' argument missing"): + response = self.fetch("/", method="POST", body=b"") + self.assertEqual(response.code, 403) + + def test_xsrf_fail_body_no_cookie(self): + with ExpectLog(gen_log, ".*XSRF cookie does not match POST"): + response = self.fetch( + "/", + method="POST", + body=urllib.parse.urlencode(dict(_xsrf=self.xsrf_token)), + ) + self.assertEqual(response.code, 403) + + def test_xsrf_fail_argument_invalid_format(self): + with ExpectLog(gen_log, ".*'_xsrf' argument has invalid format"): + response = self.fetch( + "/", + method="POST", + headers=self.cookie_headers(), + body=urllib.parse.urlencode(dict(_xsrf="3|")), + ) + self.assertEqual(response.code, 403) + + def test_xsrf_fail_cookie_invalid_format(self): + with ExpectLog(gen_log, ".*XSRF cookie does not match POST"): + response = self.fetch( + "/", + method="POST", + headers=self.cookie_headers(token="3|"), + body=urllib.parse.urlencode(dict(_xsrf=self.xsrf_token)), + ) + self.assertEqual(response.code, 403) + + def test_xsrf_fail_cookie_no_body(self): + with ExpectLog(gen_log, ".*'_xsrf' argument missing"): + response = self.fetch( + "/", method="POST", body=b"", headers=self.cookie_headers() + ) + self.assertEqual(response.code, 403) + + def test_xsrf_success_short_token(self): + response = self.fetch( + "/", + method="POST", + body=urllib.parse.urlencode(dict(_xsrf="deadbeef")), + headers=self.cookie_headers(token="deadbeef"), + ) + self.assertEqual(response.code, 200) + + def test_xsrf_success_non_hex_token(self): + response = self.fetch( + "/", + method="POST", + body=urllib.parse.urlencode(dict(_xsrf="xoxo")), + headers=self.cookie_headers(token="xoxo"), + ) + self.assertEqual(response.code, 200) + + def test_xsrf_success_post_body(self): + response = self.fetch( + "/", + method="POST", + body=urllib.parse.urlencode(dict(_xsrf=self.xsrf_token)), + headers=self.cookie_headers(), + ) + self.assertEqual(response.code, 200) + + def test_xsrf_success_query_string(self): + response = self.fetch( + "/?" + urllib.parse.urlencode(dict(_xsrf=self.xsrf_token)), + method="POST", + body=b"", + headers=self.cookie_headers(), + ) + self.assertEqual(response.code, 200) + + def test_xsrf_success_header(self): + response = self.fetch( + "/", + method="POST", + body=b"", + headers=dict( + {"X-Xsrftoken": self.xsrf_token}, # type: ignore + **self.cookie_headers() + ), + ) + self.assertEqual(response.code, 200) + + def test_distinct_tokens(self): + # Every request gets a distinct token. + NUM_TOKENS = 10 + tokens = set() + for i in range(NUM_TOKENS): + tokens.add(self.get_token()) + self.assertEqual(len(tokens), NUM_TOKENS) + + def test_cross_user(self): + token2 = self.get_token() + # Each token can be used to authenticate its own request. 
+ for token in (self.xsrf_token, token2): + response = self.fetch( + "/", + method="POST", + body=urllib.parse.urlencode(dict(_xsrf=token)), + headers=self.cookie_headers(token), + ) + self.assertEqual(response.code, 200) + # Sending one in the cookie and the other in the body is not allowed. + for cookie_token, body_token in ( + (self.xsrf_token, token2), + (token2, self.xsrf_token), + ): + with ExpectLog(gen_log, ".*XSRF cookie does not match POST"): + response = self.fetch( + "/", + method="POST", + body=urllib.parse.urlencode(dict(_xsrf=body_token)), + headers=self.cookie_headers(cookie_token), + ) + self.assertEqual(response.code, 403) + + def test_refresh_token(self): + token = self.xsrf_token + tokens_seen = set([token]) + # A user's token is stable over time. Refreshing the page in one tab + # might update the cookie while an older tab still has the old cookie + # in its DOM. Simulate this scenario by passing a constant token + # in the body and re-querying for the token. + for i in range(5): + token = self.get_token(token) + # Tokens are encoded uniquely each time + tokens_seen.add(token) + response = self.fetch( + "/", + method="POST", + body=urllib.parse.urlencode(dict(_xsrf=self.xsrf_token)), + headers=self.cookie_headers(token), + ) + self.assertEqual(response.code, 200) + self.assertEqual(len(tokens_seen), 6) + + def test_versioning(self): + # Version 1 still produces distinct tokens per request. + self.assertNotEqual(self.get_token(version=1), self.get_token(version=1)) + + # Refreshed v1 tokens are all identical. + v1_token = self.get_token(version=1) + for i in range(5): + self.assertEqual(self.get_token(v1_token, version=1), v1_token) + + # Upgrade to a v2 version of the same token + v2_token = self.get_token(v1_token) + self.assertNotEqual(v1_token, v2_token) + # Each v1 token can map to many v2 tokens. + self.assertNotEqual(v2_token, self.get_token(v1_token)) + + # The tokens are cross-compatible. 
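+ # (Both versions decode to the same underlying random token; the v2
+ # encoding only adds a random mask and a timestamp, which is why a
+ # v1 cookie can validate a v2 body token and vice versa.)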
+ for cookie_token, body_token in ((v1_token, v2_token), (v2_token, v1_token)): + response = self.fetch( + "/", + method="POST", + body=urllib.parse.urlencode(dict(_xsrf=body_token)), + headers=self.cookie_headers(cookie_token), + ) + self.assertEqual(response.code, 200) + + +# A subset of the previous test with a different cookie name +class XSRFCookieNameTest(SimpleHandlerTestCase): + class Handler(RequestHandler): + def get(self): + self.write(self.xsrf_token) + + def post(self): + self.write("ok") + + def get_app_kwargs(self): + return dict( + xsrf_cookies=True, + xsrf_cookie_name="__Host-xsrf", + xsrf_cookie_kwargs={"secure": True}, + ) + + def setUp(self): + super().setUp() + self.xsrf_token = self.get_token() + + def get_token(self, old_token=None): + if old_token is not None: + headers = self.cookie_headers(old_token) + else: + headers = None + response = self.fetch("/", headers=headers) + response.rethrow() + return native_str(response.body) + + def cookie_headers(self, token=None): + if token is None: + token = self.xsrf_token + return {"Cookie": "__Host-xsrf=" + token} + + def test_xsrf_fail_no_token(self): + with ExpectLog(gen_log, ".*'_xsrf' argument missing"): + response = self.fetch("/", method="POST", body=b"") + self.assertEqual(response.code, 403) + + def test_xsrf_fail_body_no_cookie(self): + with ExpectLog(gen_log, ".*XSRF cookie does not match POST"): + response = self.fetch( + "/", + method="POST", + body=urllib.parse.urlencode(dict(_xsrf=self.xsrf_token)), + ) + self.assertEqual(response.code, 403) + + def test_xsrf_success_post_body(self): + response = self.fetch( + "/", + method="POST", + # Note that renaming the cookie doesn't rename the POST param + body=urllib.parse.urlencode(dict(_xsrf=self.xsrf_token)), + headers=self.cookie_headers(), + ) + self.assertEqual(response.code, 200) + + +class XSRFCookieKwargsTest(SimpleHandlerTestCase): + class Handler(RequestHandler): + def get(self): + self.write(self.xsrf_token) + + def get_app_kwargs(self): + return dict( + xsrf_cookies=True, xsrf_cookie_kwargs=dict(httponly=True, expires_days=2) + ) + + def test_xsrf_httponly(self): + response = self.fetch("/") + self.assertIn("httponly;", response.headers["Set-Cookie"].lower()) + self.assertIn("expires=", response.headers["Set-Cookie"].lower()) + header = response.headers.get("Set-Cookie") + assert header is not None + match = re.match(".*; expires=(?P<expires>.+);.*", header) + assert match is not None + + expires = datetime.datetime.utcnow() + datetime.timedelta(days=2) + parsed = email.utils.parsedate(match.groupdict()["expires"]) + assert parsed is not None + header_expires = datetime.datetime(*parsed[:6]) + self.assertTrue(abs((expires - header_expires).total_seconds()) < 10) + + +class FinishExceptionTest(SimpleHandlerTestCase): + class Handler(RequestHandler): + def get(self): + self.set_status(401) + self.set_header("WWW-Authenticate", 'Basic realm="something"') + if self.get_argument("finish_value", ""): + raise Finish("authentication required") + else: + self.write("authentication required") + raise Finish() + + def test_finish_exception(self): + for u in ["/", "/?finish_value=1"]: + response = self.fetch(u) + self.assertEqual(response.code, 401) + self.assertEqual( + 'Basic realm="something"', response.headers.get("WWW-Authenticate") + ) + self.assertEqual(b"authentication required", response.body) + + +class DecoratorTest(WebTestCase): + def get_handlers(self): + class RemoveSlashHandler(RequestHandler): + @removeslash + def get(self): + pass + + class 
AddSlashHandler(RequestHandler): + @addslash + def get(self): + pass + + return [("/removeslash/", RemoveSlashHandler), ("/addslash", AddSlashHandler)] + + def test_removeslash(self): + response = self.fetch("/removeslash/", follow_redirects=False) + self.assertEqual(response.code, 301) + self.assertEqual(response.headers["Location"], "/removeslash") + + response = self.fetch("/removeslash/?foo=bar", follow_redirects=False) + self.assertEqual(response.code, 301) + self.assertEqual(response.headers["Location"], "/removeslash?foo=bar") + + def test_addslash(self): + response = self.fetch("/addslash", follow_redirects=False) + self.assertEqual(response.code, 301) + self.assertEqual(response.headers["Location"], "/addslash/") + + response = self.fetch("/addslash?foo=bar", follow_redirects=False) + self.assertEqual(response.code, 301) + self.assertEqual(response.headers["Location"], "/addslash/?foo=bar") + + +class CacheTest(WebTestCase): + def get_handlers(self): + class EtagHandler(RequestHandler): + def get(self, computed_etag): + self.write(computed_etag) + + def compute_etag(self): + return self._write_buffer[0] + + return [("/etag/(.*)", EtagHandler)] + + def test_wildcard_etag(self): + computed_etag = '"xyzzy"' + etags = "*" + self._test_etag(computed_etag, etags, 304) + + def test_strong_etag_match(self): + computed_etag = '"xyzzy"' + etags = '"xyzzy"' + self._test_etag(computed_etag, etags, 304) + + def test_multiple_strong_etag_match(self): + computed_etag = '"xyzzy1"' + etags = '"xyzzy1", "xyzzy2"' + self._test_etag(computed_etag, etags, 304) + + def test_strong_etag_not_match(self): + computed_etag = '"xyzzy"' + etags = '"xyzzy1"' + self._test_etag(computed_etag, etags, 200) + + def test_multiple_strong_etag_not_match(self): + computed_etag = '"xyzzy"' + etags = '"xyzzy1", "xyzzy2"' + self._test_etag(computed_etag, etags, 200) + + def test_weak_etag_match(self): + computed_etag = '"xyzzy1"' + etags = 'W/"xyzzy1"' + self._test_etag(computed_etag, etags, 304) + + def test_multiple_weak_etag_match(self): + computed_etag = '"xyzzy2"' + etags = 'W/"xyzzy1", W/"xyzzy2"' + self._test_etag(computed_etag, etags, 304) + + def test_weak_etag_not_match(self): + computed_etag = '"xyzzy2"' + etags = 'W/"xyzzy1"' + self._test_etag(computed_etag, etags, 200) + + def test_multiple_weak_etag_not_match(self): + computed_etag = '"xyzzy3"' + etags = 'W/"xyzzy1", W/"xyzzy2"' + self._test_etag(computed_etag, etags, 200) + + def _test_etag(self, computed_etag, etags, status_code): + response = self.fetch( + "/etag/" + computed_etag, headers={"If-None-Match": etags} + ) + self.assertEqual(response.code, status_code) + + +class RequestSummaryTest(SimpleHandlerTestCase): + class Handler(RequestHandler): + def get(self): + # remote_ip is optional, although it's set by + # both HTTPServer and WSGIAdapter. + # Clobber it to make sure it doesn't break logging. 
+ self.request.remote_ip = None + self.finish(self._request_summary()) + + def test_missing_remote_ip(self): + resp = self.fetch("/") + self.assertEqual(resp.body, b"GET / (None)") + + +class HTTPErrorTest(unittest.TestCase): + def test_copy(self): + e = HTTPError(403, reason="Go away") + e2 = copy.copy(e) + self.assertIsNot(e, e2) + self.assertEqual(e.status_code, e2.status_code) + self.assertEqual(e.reason, e2.reason) + + +class ApplicationTest(AsyncTestCase): + def test_listen(self): + app = Application([]) + server = app.listen(0, address="127.0.0.1") + server.stop() + + +class URLSpecReverseTest(unittest.TestCase): + def test_reverse(self): + self.assertEqual("/favicon.ico", url(r"/favicon\.ico", None).reverse()) + self.assertEqual("/favicon.ico", url(r"^/favicon\.ico$", None).reverse()) + + def test_non_reversible(self): + # URLSpecs are non-reversible if they include non-constant + # regex features outside capturing groups. Currently, this is + # only strictly enforced for backslash-escaped character + # classes. + paths = [r"^/api/v\d+/foo/(\w+)$"] + for path in paths: + # A URLSpec can still be created even if it cannot be reversed. + url_spec = url(path, None) + try: + result = url_spec.reverse() + self.fail( + "did not get expected exception when reversing %s. " + "result: %s" % (path, result) + ) + except ValueError: + pass + + def test_reverse_arguments(self): + self.assertEqual( + "/api/v1/foo/bar", url(r"^/api/v1/foo/(\w+)$", None).reverse("bar") + ) + self.assertEqual( + "/api.v1/foo/5/icon.png", + url(r"/api\.v1/foo/([0-9]+)/icon\.png", None).reverse(5), + ) + + +class RedirectHandlerTest(WebTestCase): + def get_handlers(self): + return [ + ("/src", WebRedirectHandler, {"url": "/dst"}), + ("/src2", WebRedirectHandler, {"url": "/dst2?foo=bar"}), + (r"/(.*?)/(.*?)/(.*)", WebRedirectHandler, {"url": "/{1}/{0}/{2}"}), + ] + + def test_basic_redirect(self): + response = self.fetch("/src", follow_redirects=False) + self.assertEqual(response.code, 301) + self.assertEqual(response.headers["Location"], "/dst") + + def test_redirect_with_argument(self): + response = self.fetch("/src?foo=bar", follow_redirects=False) + self.assertEqual(response.code, 301) + self.assertEqual(response.headers["Location"], "/dst?foo=bar") + + def test_redirect_with_appending_argument(self): + response = self.fetch("/src2?foo2=bar2", follow_redirects=False) + self.assertEqual(response.code, 301) + self.assertEqual(response.headers["Location"], "/dst2?foo=bar&foo2=bar2") + + def test_redirect_pattern(self): + response = self.fetch("/a/b/c", follow_redirects=False) + self.assertEqual(response.code, 301) + self.assertEqual(response.headers["Location"], "/b/a/c") + + +class AcceptLanguageTest(WebTestCase): + """Test evaluation of Accept-Language header""" + + def get_handlers(self): + locale.load_gettext_translations( + os.path.join(os.path.dirname(__file__), "gettext_translations"), + "tornado_test", + ) + + class AcceptLanguageHandler(RequestHandler): + def get(self): + self.set_header( + "Content-Language", self.get_browser_locale().code.replace("_", "-") + ) + self.finish(b"") + + return [ + ("/", AcceptLanguageHandler), + ] + + def test_accept_language(self): + response = self.fetch("/", headers={"Accept-Language": "fr-FR;q=0.9"}) + self.assertEqual(response.headers["Content-Language"], "fr-FR") + + response = self.fetch("/", headers={"Accept-Language": "fr-FR; q=0.9"}) + self.assertEqual(response.headers["Content-Language"], "fr-FR") + + def test_accept_language_ignore(self): + response = 
self.fetch("/", headers={"Accept-Language": "fr-FR;q=0"}) + self.assertEqual(response.headers["Content-Language"], "en-US") + + def test_accept_language_invalid(self): + response = self.fetch("/", headers={"Accept-Language": "fr-FR;q=-1"}) + self.assertEqual(response.headers["Content-Language"], "en-US") diff --git a/venv/lib/python3.9/site-packages/tornado/test/websocket_test.py b/venv/lib/python3.9/site-packages/tornado/test/websocket_test.py new file mode 100644 index 00000000..f90c5f2c --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/test/websocket_test.py @@ -0,0 +1,858 @@ +import asyncio +import functools +import socket +import traceback +import typing +import unittest + +from tornado.concurrent import Future +from tornado import gen +from tornado.httpclient import HTTPError, HTTPRequest +from tornado.locks import Event +from tornado.log import gen_log, app_log +from tornado.netutil import Resolver +from tornado.simple_httpclient import SimpleAsyncHTTPClient +from tornado.template import DictLoader +from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog +from tornado.web import Application, RequestHandler + +try: + import tornado.websocket # noqa: F401 + from tornado.util import _websocket_mask_python +except ImportError: + # The unittest module presents misleading errors on ImportError + # (it acts as if websocket_test could not be found, hiding the underlying + # error). If we get an ImportError here (which could happen due to + # TORNADO_EXTENSION=1), print some extra information before failing. + traceback.print_exc() + raise + +from tornado.websocket import ( + WebSocketHandler, + websocket_connect, + WebSocketError, + WebSocketClosedError, +) + +try: + from tornado import speedups +except ImportError: + speedups = None # type: ignore + + +class TestWebSocketHandler(WebSocketHandler): + """Base class for testing handlers that exposes the on_close event. + + This allows for tests to see the close code and reason on the + server side. + + """ + + def initialize(self, close_future=None, compression_options=None): + self.close_future = close_future + self.compression_options = compression_options + + def get_compression_options(self): + return self.compression_options + + def on_close(self): + if self.close_future is not None: + self.close_future.set_result((self.close_code, self.close_reason)) + + +class EchoHandler(TestWebSocketHandler): + @gen.coroutine + def on_message(self, message): + try: + yield self.write_message(message, isinstance(message, bytes)) + except asyncio.CancelledError: + pass + except WebSocketClosedError: + pass + + +class ErrorInOnMessageHandler(TestWebSocketHandler): + def on_message(self, message): + 1 / 0 + + +class HeaderHandler(TestWebSocketHandler): + def open(self): + methods_to_test = [ + functools.partial(self.write, "This should not work"), + functools.partial(self.redirect, "http://localhost/elsewhere"), + functools.partial(self.set_header, "X-Test", ""), + functools.partial(self.set_cookie, "Chocolate", "Chip"), + functools.partial(self.set_status, 503), + self.flush, + self.finish, + ] + for method in methods_to_test: + try: + # In a websocket context, many RequestHandler methods + # raise RuntimeErrors. 
+ method() # type: ignore + raise Exception("did not get expected exception") + except RuntimeError: + pass + self.write_message(self.request.headers.get("X-Test", "")) + + +class HeaderEchoHandler(TestWebSocketHandler): + def set_default_headers(self): + self.set_header("X-Extra-Response-Header", "Extra-Response-Value") + + def prepare(self): + for k, v in self.request.headers.get_all(): + if k.lower().startswith("x-test"): + self.set_header(k, v) + + +class NonWebSocketHandler(RequestHandler): + def get(self): + self.write("ok") + + +class RedirectHandler(RequestHandler): + def get(self): + self.redirect("/echo") + + +class CloseReasonHandler(TestWebSocketHandler): + def open(self): + self.on_close_called = False + self.close(1001, "goodbye") + + +class AsyncPrepareHandler(TestWebSocketHandler): + @gen.coroutine + def prepare(self): + yield gen.moment + + def on_message(self, message): + self.write_message(message) + + +class PathArgsHandler(TestWebSocketHandler): + def open(self, arg): + self.write_message(arg) + + +class CoroutineOnMessageHandler(TestWebSocketHandler): + def initialize(self, **kwargs): + super().initialize(**kwargs) + self.sleeping = 0 + + @gen.coroutine + def on_message(self, message): + if self.sleeping > 0: + self.write_message("another coroutine is already sleeping") + self.sleeping += 1 + yield gen.sleep(0.01) + self.sleeping -= 1 + self.write_message(message) + + +class RenderMessageHandler(TestWebSocketHandler): + def on_message(self, message): + self.write_message(self.render_string("message.html", message=message)) + + +class SubprotocolHandler(TestWebSocketHandler): + def initialize(self, **kwargs): + super().initialize(**kwargs) + self.select_subprotocol_called = False + + def select_subprotocol(self, subprotocols): + if self.select_subprotocol_called: + raise Exception("select_subprotocol called twice") + self.select_subprotocol_called = True + if "goodproto" in subprotocols: + return "goodproto" + return None + + def open(self): + if not self.select_subprotocol_called: + raise Exception("select_subprotocol not called") + self.write_message("subprotocol=%s" % self.selected_subprotocol) + + +class OpenCoroutineHandler(TestWebSocketHandler): + def initialize(self, test, **kwargs): + super().initialize(**kwargs) + self.test = test + self.open_finished = False + + @gen.coroutine + def open(self): + yield self.test.message_sent.wait() + yield gen.sleep(0.010) + self.open_finished = True + + def on_message(self, message): + if not self.open_finished: + raise Exception("on_message called before open finished") + self.write_message("ok") + + +class ErrorInOpenHandler(TestWebSocketHandler): + def open(self): + raise Exception("boom") + + +class ErrorInAsyncOpenHandler(TestWebSocketHandler): + async def open(self): + await asyncio.sleep(0) + raise Exception("boom") + + +class NoDelayHandler(TestWebSocketHandler): + def open(self): + self.set_nodelay(True) + self.write_message("hello") + + +class WebSocketBaseTestCase(AsyncHTTPTestCase): + @gen.coroutine + def ws_connect(self, path, **kwargs): + ws = yield websocket_connect( + "ws://127.0.0.1:%d%s" % (self.get_http_port(), path), **kwargs + ) + raise gen.Return(ws) + + +class WebSocketTest(WebSocketBaseTestCase): + def get_app(self): + self.close_future = Future() # type: Future[None] + return Application( + [ + ("/echo", EchoHandler, dict(close_future=self.close_future)), + ("/non_ws", NonWebSocketHandler), + ("/redirect", RedirectHandler), + ("/header", HeaderHandler, dict(close_future=self.close_future)), + ( + 
"/header_echo", + HeaderEchoHandler, + dict(close_future=self.close_future), + ), + ( + "/close_reason", + CloseReasonHandler, + dict(close_future=self.close_future), + ), + ( + "/error_in_on_message", + ErrorInOnMessageHandler, + dict(close_future=self.close_future), + ), + ( + "/async_prepare", + AsyncPrepareHandler, + dict(close_future=self.close_future), + ), + ( + "/path_args/(.*)", + PathArgsHandler, + dict(close_future=self.close_future), + ), + ( + "/coroutine", + CoroutineOnMessageHandler, + dict(close_future=self.close_future), + ), + ("/render", RenderMessageHandler, dict(close_future=self.close_future)), + ( + "/subprotocol", + SubprotocolHandler, + dict(close_future=self.close_future), + ), + ( + "/open_coroutine", + OpenCoroutineHandler, + dict(close_future=self.close_future, test=self), + ), + ("/error_in_open", ErrorInOpenHandler), + ("/error_in_async_open", ErrorInAsyncOpenHandler), + ("/nodelay", NoDelayHandler), + ], + template_loader=DictLoader({"message.html": "<b>{{ message }}</b>"}), + ) + + def get_http_client(self): + # These tests require HTTP/1; force the use of SimpleAsyncHTTPClient. + return SimpleAsyncHTTPClient() + + def tearDown(self): + super().tearDown() + RequestHandler._template_loaders.clear() + + def test_http_request(self): + # WS server, HTTP client. + response = self.fetch("/echo") + self.assertEqual(response.code, 400) + + def test_missing_websocket_key(self): + response = self.fetch( + "/echo", + headers={ + "Connection": "Upgrade", + "Upgrade": "WebSocket", + "Sec-WebSocket-Version": "13", + }, + ) + self.assertEqual(response.code, 400) + + def test_bad_websocket_version(self): + response = self.fetch( + "/echo", + headers={ + "Connection": "Upgrade", + "Upgrade": "WebSocket", + "Sec-WebSocket-Version": "12", + }, + ) + self.assertEqual(response.code, 426) + + @gen_test + def test_websocket_gen(self): + ws = yield self.ws_connect("/echo") + yield ws.write_message("hello") + response = yield ws.read_message() + self.assertEqual(response, "hello") + + def test_websocket_callbacks(self): + websocket_connect( + "ws://127.0.0.1:%d/echo" % self.get_http_port(), callback=self.stop + ) + ws = self.wait().result() + ws.write_message("hello") + ws.read_message(self.stop) + response = self.wait().result() + self.assertEqual(response, "hello") + self.close_future.add_done_callback(lambda f: self.stop()) + ws.close() + self.wait() + + @gen_test + def test_binary_message(self): + ws = yield self.ws_connect("/echo") + ws.write_message(b"hello \xe9", binary=True) + response = yield ws.read_message() + self.assertEqual(response, b"hello \xe9") + + @gen_test + def test_unicode_message(self): + ws = yield self.ws_connect("/echo") + ws.write_message("hello \u00e9") + response = yield ws.read_message() + self.assertEqual(response, "hello \u00e9") + + @gen_test + def test_error_in_closed_client_write_message(self): + ws = yield self.ws_connect("/echo") + ws.close() + with self.assertRaises(WebSocketClosedError): + ws.write_message("hello \u00e9") + + @gen_test + def test_render_message(self): + ws = yield self.ws_connect("/render") + ws.write_message("hello") + response = yield ws.read_message() + self.assertEqual(response, "<b>hello</b>") + + @gen_test + def test_error_in_on_message(self): + ws = yield self.ws_connect("/error_in_on_message") + ws.write_message("hello") + with ExpectLog(app_log, "Uncaught exception"): + response = yield ws.read_message() + self.assertIs(response, None) + + @gen_test + def test_websocket_http_fail(self): + with 
self.assertRaises(HTTPError) as cm: + yield self.ws_connect("/notfound") + self.assertEqual(cm.exception.code, 404) + + @gen_test + def test_websocket_http_success(self): + with self.assertRaises(WebSocketError): + yield self.ws_connect("/non_ws") + + @gen_test + def test_websocket_http_redirect(self): + with self.assertRaises(HTTPError): + yield self.ws_connect("/redirect") + + @gen_test + def test_websocket_network_fail(self): + sock, port = bind_unused_port() + sock.close() + with self.assertRaises(IOError): + with ExpectLog(gen_log, ".*", required=False): + yield websocket_connect( + "ws://127.0.0.1:%d/" % port, connect_timeout=3600 + ) + + @gen_test + def test_websocket_close_buffered_data(self): + ws = yield websocket_connect("ws://127.0.0.1:%d/echo" % self.get_http_port()) + ws.write_message("hello") + ws.write_message("world") + # Close the underlying stream. + ws.stream.close() + + @gen_test + def test_websocket_headers(self): + # Ensure that arbitrary headers can be passed through websocket_connect. + ws = yield websocket_connect( + HTTPRequest( + "ws://127.0.0.1:%d/header" % self.get_http_port(), + headers={"X-Test": "hello"}, + ) + ) + response = yield ws.read_message() + self.assertEqual(response, "hello") + + @gen_test + def test_websocket_header_echo(self): + # Ensure that headers can be returned in the response. + # Specifically, that arbitrary headers passed through websocket_connect + # can be returned. + ws = yield websocket_connect( + HTTPRequest( + "ws://127.0.0.1:%d/header_echo" % self.get_http_port(), + headers={"X-Test-Hello": "hello"}, + ) + ) + self.assertEqual(ws.headers.get("X-Test-Hello"), "hello") + self.assertEqual( + ws.headers.get("X-Extra-Response-Header"), "Extra-Response-Value" + ) + + @gen_test + def test_server_close_reason(self): + ws = yield self.ws_connect("/close_reason") + msg = yield ws.read_message() + # A message of None means the other side closed the connection. + self.assertIs(msg, None) + self.assertEqual(ws.close_code, 1001) + self.assertEqual(ws.close_reason, "goodbye") + # The on_close callback is called no matter which side closed. + code, reason = yield self.close_future + # The client echoed the close code it received to the server, + # so the server's close code (returned via close_future) is + # the same. + self.assertEqual(code, 1001) + + @gen_test + def test_client_close_reason(self): + ws = yield self.ws_connect("/echo") + ws.close(1001, "goodbye") + code, reason = yield self.close_future + self.assertEqual(code, 1001) + self.assertEqual(reason, "goodbye") + + @gen_test + def test_write_after_close(self): + ws = yield self.ws_connect("/close_reason") + msg = yield ws.read_message() + self.assertIs(msg, None) + with self.assertRaises(WebSocketClosedError): + ws.write_message("hello") + + @gen_test + def test_async_prepare(self): + # Previously, an async prepare method triggered a bug that would + # result in a timeout on test shutdown (and a memory leak). + ws = yield self.ws_connect("/async_prepare") + ws.write_message("hello") + res = yield ws.read_message() + self.assertEqual(res, "hello") + + @gen_test + def test_path_args(self): + ws = yield self.ws_connect("/path_args/hello") + res = yield ws.read_message() + self.assertEqual(res, "hello") + + @gen_test + def test_coroutine(self): + ws = yield self.ws_connect("/coroutine") + # Send both messages immediately, coroutine must process one at a time. 
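+ # (If the messages were processed concurrently, the handler's
+ # sleeping counter would trip and it would reply "another coroutine
+ # is already sleeping" instead of echoing.)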
+ yield ws.write_message("hello1") + yield ws.write_message("hello2") + res = yield ws.read_message() + self.assertEqual(res, "hello1") + res = yield ws.read_message() + self.assertEqual(res, "hello2") + + @gen_test + def test_check_origin_valid_no_path(self): + port = self.get_http_port() + + url = "ws://127.0.0.1:%d/echo" % port + headers = {"Origin": "http://127.0.0.1:%d" % port} + + ws = yield websocket_connect(HTTPRequest(url, headers=headers)) + ws.write_message("hello") + response = yield ws.read_message() + self.assertEqual(response, "hello") + + @gen_test + def test_check_origin_valid_with_path(self): + port = self.get_http_port() + + url = "ws://127.0.0.1:%d/echo" % port + headers = {"Origin": "http://127.0.0.1:%d/something" % port} + + ws = yield websocket_connect(HTTPRequest(url, headers=headers)) + ws.write_message("hello") + response = yield ws.read_message() + self.assertEqual(response, "hello") + + @gen_test + def test_check_origin_invalid_partial_url(self): + port = self.get_http_port() + + url = "ws://127.0.0.1:%d/echo" % port + headers = {"Origin": "127.0.0.1:%d" % port} + + with self.assertRaises(HTTPError) as cm: + yield websocket_connect(HTTPRequest(url, headers=headers)) + self.assertEqual(cm.exception.code, 403) + + @gen_test + def test_check_origin_invalid(self): + port = self.get_http_port() + + url = "ws://127.0.0.1:%d/echo" % port + # Host is 127.0.0.1, which should not be accessible from some other + # domain + headers = {"Origin": "http://somewhereelse.com"} + + with self.assertRaises(HTTPError) as cm: + yield websocket_connect(HTTPRequest(url, headers=headers)) + + self.assertEqual(cm.exception.code, 403) + + @gen_test + def test_check_origin_invalid_subdomains(self): + port = self.get_http_port() + + # CaresResolver may return ipv6-only results for localhost, but our + # server is only running on ipv4. Test for this edge case and skip + # the test if it happens. + addrinfo = yield Resolver().resolve("localhost", port) + families = set(addr[0] for addr in addrinfo) + if socket.AF_INET not in families: + self.skipTest("localhost does not resolve to ipv4") + return + + url = "ws://localhost:%d/echo" % port + # Subdomains should be disallowed by default. If we could pass a + # resolver to websocket_connect we could test sibling domains as well. 
+ headers = {"Origin": "http://subtenant.localhost"} + + with self.assertRaises(HTTPError) as cm: + yield websocket_connect(HTTPRequest(url, headers=headers)) + + self.assertEqual(cm.exception.code, 403) + + @gen_test + def test_subprotocols(self): + ws = yield self.ws_connect( + "/subprotocol", subprotocols=["badproto", "goodproto"] + ) + self.assertEqual(ws.selected_subprotocol, "goodproto") + res = yield ws.read_message() + self.assertEqual(res, "subprotocol=goodproto") + + @gen_test + def test_subprotocols_not_offered(self): + ws = yield self.ws_connect("/subprotocol") + self.assertIs(ws.selected_subprotocol, None) + res = yield ws.read_message() + self.assertEqual(res, "subprotocol=None") + + @gen_test + def test_open_coroutine(self): + self.message_sent = Event() + ws = yield self.ws_connect("/open_coroutine") + yield ws.write_message("hello") + self.message_sent.set() + res = yield ws.read_message() + self.assertEqual(res, "ok") + + @gen_test + def test_error_in_open(self): + with ExpectLog(app_log, "Uncaught exception"): + ws = yield self.ws_connect("/error_in_open") + res = yield ws.read_message() + self.assertIsNone(res) + + @gen_test + def test_error_in_async_open(self): + with ExpectLog(app_log, "Uncaught exception"): + ws = yield self.ws_connect("/error_in_async_open") + res = yield ws.read_message() + self.assertIsNone(res) + + @gen_test + def test_nodelay(self): + ws = yield self.ws_connect("/nodelay") + res = yield ws.read_message() + self.assertEqual(res, "hello") + + +class NativeCoroutineOnMessageHandler(TestWebSocketHandler): + def initialize(self, **kwargs): + super().initialize(**kwargs) + self.sleeping = 0 + + async def on_message(self, message): + if self.sleeping > 0: + self.write_message("another coroutine is already sleeping") + self.sleeping += 1 + await gen.sleep(0.01) + self.sleeping -= 1 + self.write_message(message) + + +class WebSocketNativeCoroutineTest(WebSocketBaseTestCase): + def get_app(self): + return Application([("/native", NativeCoroutineOnMessageHandler)]) + + @gen_test + def test_native_coroutine(self): + ws = yield self.ws_connect("/native") + # Send both messages immediately, coroutine must process one at a time. + yield ws.write_message("hello1") + yield ws.write_message("hello2") + res = yield ws.read_message() + self.assertEqual(res, "hello1") + res = yield ws.read_message() + self.assertEqual(res, "hello2") + + +class CompressionTestMixin(object): + MESSAGE = "Hello world. Testing 123 123" + + def get_app(self): + class LimitedHandler(TestWebSocketHandler): + @property + def max_message_size(self): + return 1024 + + def on_message(self, message): + self.write_message(str(len(message))) + + return Application( + [ + ( + "/echo", + EchoHandler, + dict(compression_options=self.get_server_compression_options()), + ), + ( + "/limited", + LimitedHandler, + dict(compression_options=self.get_server_compression_options()), + ), + ] + ) + + def get_server_compression_options(self): + return None + + def get_client_compression_options(self): + return None + + def verify_wire_bytes(self, bytes_in: int, bytes_out: int) -> None: + raise NotImplementedError() + + @gen_test + def test_message_sizes(self: typing.Any): + ws = yield self.ws_connect( + "/echo", compression_options=self.get_client_compression_options() + ) + # Send the same message three times so we can measure the + # effect of the context_takeover options. 
+ for i in range(3): + ws.write_message(self.MESSAGE) + response = yield ws.read_message() + self.assertEqual(response, self.MESSAGE) + self.assertEqual(ws.protocol._message_bytes_out, len(self.MESSAGE) * 3) + self.assertEqual(ws.protocol._message_bytes_in, len(self.MESSAGE) * 3) + self.verify_wire_bytes(ws.protocol._wire_bytes_in, ws.protocol._wire_bytes_out) + + @gen_test + def test_size_limit(self: typing.Any): + ws = yield self.ws_connect( + "/limited", compression_options=self.get_client_compression_options() + ) + # Small messages pass through. + ws.write_message("a" * 128) + response = yield ws.read_message() + self.assertEqual(response, "128") + # This message is too big after decompression, but it compresses + # down to a size that will pass the initial checks. + ws.write_message("a" * 2048) + response = yield ws.read_message() + self.assertIsNone(response) + + +class UncompressedTestMixin(CompressionTestMixin): + """Specialization of CompressionTestMixin when we expect no compression.""" + + def verify_wire_bytes(self: typing.Any, bytes_in, bytes_out): + # Bytes out includes the 4-byte mask key per message. + self.assertEqual(bytes_out, 3 * (len(self.MESSAGE) + 6)) + self.assertEqual(bytes_in, 3 * (len(self.MESSAGE) + 2)) + + +class NoCompressionTest(UncompressedTestMixin, WebSocketBaseTestCase): + pass + + +# If only one side tries to compress, the extension is not negotiated. +class ServerOnlyCompressionTest(UncompressedTestMixin, WebSocketBaseTestCase): + def get_server_compression_options(self): + return {} + + +class ClientOnlyCompressionTest(UncompressedTestMixin, WebSocketBaseTestCase): + def get_client_compression_options(self): + return {} + + +class DefaultCompressionTest(CompressionTestMixin, WebSocketBaseTestCase): + def get_server_compression_options(self): + return {} + + def get_client_compression_options(self): + return {} + + def verify_wire_bytes(self, bytes_in, bytes_out): + self.assertLess(bytes_out, 3 * (len(self.MESSAGE) + 6)) + self.assertLess(bytes_in, 3 * (len(self.MESSAGE) + 2)) + # Bytes out includes the 4 bytes mask key per message. + self.assertEqual(bytes_out, bytes_in + 12) + + +class MaskFunctionMixin(object): + # Subclasses should define self.mask(mask, data) + def mask(self, mask: bytes, data: bytes) -> bytes: + raise NotImplementedError() + + def test_mask(self: typing.Any): + self.assertEqual(self.mask(b"abcd", b""), b"") + self.assertEqual(self.mask(b"abcd", b"b"), b"\x03") + self.assertEqual(self.mask(b"abcd", b"54321"), b"TVPVP") + self.assertEqual(self.mask(b"ZXCV", b"98765432"), b"c`t`olpd") + # Include test cases with \x00 bytes (to ensure that the C + # extension isn't depending on null-terminated strings) and + # bytes with the high bit set (to smoke out signedness issues). 
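+ # (The masking transform itself is out[i] = data[i] ^ mask[i % 4],
+ # per RFC 6455 section 5.3.)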
+ self.assertEqual( + self.mask(b"\x00\x01\x02\x03", b"\xff\xfb\xfd\xfc\xfe\xfa"), + b"\xff\xfa\xff\xff\xfe\xfb", + ) + self.assertEqual( + self.mask(b"\xff\xfb\xfd\xfc", b"\x00\x01\x02\x03\x04\x05"), + b"\xff\xfa\xff\xff\xfb\xfe", + ) + + +class PythonMaskFunctionTest(MaskFunctionMixin, unittest.TestCase): + def mask(self, mask, data): + return _websocket_mask_python(mask, data) + + +@unittest.skipIf(speedups is None, "tornado.speedups module not present") +class CythonMaskFunctionTest(MaskFunctionMixin, unittest.TestCase): + def mask(self, mask, data): + return speedups.websocket_mask(mask, data) + + +class ServerPeriodicPingTest(WebSocketBaseTestCase): + def get_app(self): + class PingHandler(TestWebSocketHandler): + def on_pong(self, data): + self.write_message("got pong") + + return Application([("/", PingHandler)], websocket_ping_interval=0.01) + + @gen_test + def test_server_ping(self): + ws = yield self.ws_connect("/") + for i in range(3): + response = yield ws.read_message() + self.assertEqual(response, "got pong") + # TODO: test that the connection gets closed if ping responses stop. + + +class ClientPeriodicPingTest(WebSocketBaseTestCase): + def get_app(self): + class PingHandler(TestWebSocketHandler): + def on_ping(self, data): + self.write_message("got ping") + + return Application([("/", PingHandler)]) + + @gen_test + def test_client_ping(self): + ws = yield self.ws_connect("/", ping_interval=0.01) + for i in range(3): + response = yield ws.read_message() + self.assertEqual(response, "got ping") + # TODO: test that the connection gets closed if ping responses stop. + + +class ManualPingTest(WebSocketBaseTestCase): + def get_app(self): + class PingHandler(TestWebSocketHandler): + def on_ping(self, data): + self.write_message(data, binary=isinstance(data, bytes)) + + return Application([("/", PingHandler)]) + + @gen_test + def test_manual_ping(self): + ws = yield self.ws_connect("/") + + self.assertRaises(ValueError, ws.ping, "a" * 126) + + ws.ping("hello") + resp = yield ws.read_message() + # on_ping always sees bytes. + self.assertEqual(resp, b"hello") + + ws.ping(b"binary hello") + resp = yield ws.read_message() + self.assertEqual(resp, b"binary hello") + + +class MaxMessageSizeTest(WebSocketBaseTestCase): + def get_app(self): + return Application([("/", EchoHandler)], websocket_max_message_size=1024) + + @gen_test + def test_large_message(self): + ws = yield self.ws_connect("/") + + # Write a message that is allowed. + msg = "a" * 1024 + ws.write_message(msg) + resp = yield ws.read_message() + self.assertEqual(resp, msg) + + # Write a message that is too large. + ws.write_message(msg + "b") + resp = yield ws.read_message() + # A message of None means the other side closed the connection. + self.assertIs(resp, None) + self.assertEqual(ws.close_code, 1009) + self.assertEqual(ws.close_reason, "message too big") + # TODO: Needs tests of messages split over multiple + # continuation frames. 
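+ # (A sketch of such a test: drive a raw IOStream client that sends
+ # a text frame with FIN=0 followed by continuation frames with
+ # opcode 0x0, and assert that websocket_max_message_size applies to
+ # the reassembled message rather than to each fragment.)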
diff --git a/venv/lib/python3.9/site-packages/tornado/test/wsgi_test.py b/venv/lib/python3.9/site-packages/tornado/test/wsgi_test.py new file mode 100644 index 00000000..9fbc744e --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/test/wsgi_test.py @@ -0,0 +1,116 @@ +import asyncio +import concurrent.futures +import threading + +from wsgiref.validate import validator + +from tornado.routing import RuleRouter +from tornado.testing import AsyncHTTPTestCase, gen_test +from tornado.wsgi import WSGIContainer + + +class WSGIAppMixin: + # TODO: Now that WSGIAdapter is gone, this is a pretty weak test. + def get_executor(self): + raise NotImplementedError() + + def get_app(self): + executor = self.get_executor() + # The barrier test in DummyExecutorTest will always wait the full + # value of this timeout, so we don't want it to be too high. + self.barrier = threading.Barrier(2, timeout=0.3) + + def make_container(app): + return WSGIContainer(validator(app), executor=executor) + + return RuleRouter( + [ + ("/simple", make_container(self.simple_wsgi_app)), + ("/barrier", make_container(self.barrier_wsgi_app)), + ("/streaming_barrier", make_container(self.streaming_barrier_wsgi_app)), + ] + ) + + def respond_plain(self, start_response): + status = "200 OK" + response_headers = [("Content-Type", "text/plain")] + start_response(status, response_headers) + + def simple_wsgi_app(self, environ, start_response): + self.respond_plain(start_response) + return [b"Hello world!"] + + def barrier_wsgi_app(self, environ, start_response): + self.respond_plain(start_response) + try: + n = self.barrier.wait() + except threading.BrokenBarrierError: + return [b"broken barrier"] + else: + return [b"ok %d" % n] + + def streaming_barrier_wsgi_app(self, environ, start_response): + self.respond_plain(start_response) + yield b"ok " + try: + n = self.barrier.wait() + except threading.BrokenBarrierError: + yield b"broken barrier" + else: + yield b"%d" % n + + +class WSGIContainerDummyExecutorTest(WSGIAppMixin, AsyncHTTPTestCase): + def get_executor(self): + return None + + def test_simple(self): + response = self.fetch("/simple") + self.assertEqual(response.body, b"Hello world!") + + @gen_test + async def test_concurrent_barrier(self): + self.barrier.reset() + resps = await asyncio.gather( + self.http_client.fetch(self.get_url("/barrier")), + self.http_client.fetch(self.get_url("/barrier")), + ) + for resp in resps: + self.assertEqual(resp.body, b"broken barrier") + + @gen_test + async def test_concurrent_streaming_barrier(self): + self.barrier.reset() + resps = await asyncio.gather( + self.http_client.fetch(self.get_url("/streaming_barrier")), + self.http_client.fetch(self.get_url("/streaming_barrier")), + ) + for resp in resps: + self.assertEqual(resp.body, b"ok broken barrier") + + +class WSGIContainerThreadPoolTest(WSGIAppMixin, AsyncHTTPTestCase): + def get_executor(self): + return concurrent.futures.ThreadPoolExecutor() + + def test_simple(self): + response = self.fetch("/simple") + self.assertEqual(response.body, b"Hello world!") + + @gen_test + async def test_concurrent_barrier(self): + self.barrier.reset() + resps = await asyncio.gather( + self.http_client.fetch(self.get_url("/barrier")), + self.http_client.fetch(self.get_url("/barrier")), + ) + self.assertEqual([b"ok 0", b"ok 1"], sorted([resp.body for resp in resps])) + + @gen_test + async def test_concurrent_streaming_barrier(self): + self.barrier.reset() + resps = await asyncio.gather( + self.http_client.fetch(self.get_url("/streaming_barrier")), + 
self.http_client.fetch(self.get_url("/streaming_barrier")), + ) + self.assertEqual([b"ok 0", b"ok 1"], sorted([resp.body for resp in resps])) diff --git a/venv/lib/python3.9/site-packages/tornado/testing.py b/venv/lib/python3.9/site-packages/tornado/testing.py new file mode 100644 index 00000000..9bfadf45 --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/testing.py @@ -0,0 +1,872 @@ +"""Support classes for automated testing. + +* `AsyncTestCase` and `AsyncHTTPTestCase`: Subclasses of unittest.TestCase + with additional support for testing asynchronous (`.IOLoop`-based) code. + +* `ExpectLog`: Make test logs less spammy. + +* `main()`: A simple test runner (wrapper around unittest.main()) with support + for the tornado.autoreload module to rerun the tests when code changes. +""" + +import asyncio +from collections.abc import Generator +import functools +import inspect +import logging +import os +import re +import signal +import socket +import sys +import unittest +import warnings + +from tornado import gen +from tornado.httpclient import AsyncHTTPClient, HTTPResponse +from tornado.httpserver import HTTPServer +from tornado.ioloop import IOLoop, TimeoutError +from tornado import netutil +from tornado.platform.asyncio import AsyncIOMainLoop +from tornado.process import Subprocess +from tornado.log import app_log +from tornado.util import raise_exc_info, basestring_type +from tornado.web import Application + +import typing +from typing import Tuple, Any, Callable, Type, Dict, Union, Optional, Coroutine +from types import TracebackType + +if typing.TYPE_CHECKING: + _ExcInfoTuple = Tuple[ + Optional[Type[BaseException]], Optional[BaseException], Optional[TracebackType] + ] + + +_NON_OWNED_IOLOOPS = AsyncIOMainLoop + + +def bind_unused_port( + reuse_port: bool = False, address: str = "127.0.0.1" +) -> Tuple[socket.socket, int]: + """Binds a server socket to an available port on localhost. + + Returns a tuple (socket, port). + + .. versionchanged:: 4.4 + Always binds to ``127.0.0.1`` without resolving the name + ``localhost``. + + .. versionchanged:: 6.2 + Added optional ``address`` argument to + override the default "127.0.0.1". + """ + sock = netutil.bind_sockets( + 0, address, family=socket.AF_INET, reuse_port=reuse_port + )[0] + port = sock.getsockname()[1] + return sock, port + + +def get_async_test_timeout() -> float: + """Get the global timeout setting for async tests. + + Returns a float, the timeout in seconds. + + .. versionadded:: 3.1 + """ + env = os.environ.get("ASYNC_TEST_TIMEOUT") + if env is not None: + try: + return float(env) + except ValueError: + pass + return 5 + + +class _TestMethodWrapper(object): + """Wraps a test method to raise an error if it returns a value. + + This is mainly used to detect undecorated generators (if a test + method yields it must use a decorator to consume the generator), + but will also detect other kinds of return values (these are not + necessarily errors, but we alert anyway since there is no good + reason to return a value from a test). 
+ """ + + def __init__(self, orig_method: Callable) -> None: + self.orig_method = orig_method + self.__wrapped__ = orig_method + + def __call__(self, *args: Any, **kwargs: Any) -> None: + result = self.orig_method(*args, **kwargs) + if isinstance(result, Generator) or inspect.iscoroutine(result): + raise TypeError( + "Generator and coroutine test methods should be" + " decorated with tornado.testing.gen_test" + ) + elif result is not None: + raise ValueError("Return value from test method ignored: %r" % result) + + def __getattr__(self, name: str) -> Any: + """Proxy all unknown attributes to the original method. + + This is important for some of the decorators in the `unittest` + module, such as `unittest.skipIf`. + """ + return getattr(self.orig_method, name) + + +class AsyncTestCase(unittest.TestCase): + """`~unittest.TestCase` subclass for testing `.IOLoop`-based + asynchronous code. + + The unittest framework is synchronous, so the test must be + complete by the time the test method returns. This means that + asynchronous code cannot be used in quite the same way as usual + and must be adapted to fit. To write your tests with coroutines, + decorate your test methods with `tornado.testing.gen_test` instead + of `tornado.gen.coroutine`. + + This class also provides the (deprecated) `stop()` and `wait()` + methods for a more manual style of testing. The test method itself + must call ``self.wait()``, and asynchronous callbacks should call + ``self.stop()`` to signal completion. + + By default, a new `.IOLoop` is constructed for each test and is available + as ``self.io_loop``. If the code being tested requires a + reused global `.IOLoop`, subclasses should override `get_new_ioloop` to return it, + although this is deprecated as of Tornado 6.3. + + The `.IOLoop`'s ``start`` and ``stop`` methods should not be + called directly. Instead, use `self.stop <stop>` and `self.wait + <wait>`. Arguments passed to ``self.stop`` are returned from + ``self.wait``. It is possible to have multiple ``wait``/``stop`` + cycles in the same test. + + Example:: + + # This test uses coroutine style. + class MyTestCase(AsyncTestCase): + @tornado.testing.gen_test + def test_http_fetch(self): + client = AsyncHTTPClient() + response = yield client.fetch("http://www.tornadoweb.org") + # Test contents of response + self.assertIn("FriendFeed", response.body) + + # This test uses argument passing between self.stop and self.wait. + class MyTestCase2(AsyncTestCase): + def test_http_fetch(self): + client = AsyncHTTPClient() + client.fetch("http://www.tornadoweb.org/", self.stop) + response = self.wait() + # Test contents of response + self.assertIn("FriendFeed", response.body) + """ + + def __init__(self, methodName: str = "runTest") -> None: + super().__init__(methodName) + self.__stopped = False + self.__running = False + self.__failure = None # type: Optional[_ExcInfoTuple] + self.__stop_args = None # type: Any + self.__timeout = None # type: Optional[object] + + # It's easy to forget the @gen_test decorator, but if you do + # the test will silently be ignored because nothing will consume + # the generator. Replace the test method with a wrapper that will + # make sure it's not an undecorated generator. 
+        setattr(self, methodName, _TestMethodWrapper(getattr(self, methodName)))
+
+        # Not used in this class itself, but used by @gen_test
+        self._test_generator = None  # type: Optional[Union[Generator, Coroutine]]
+
+    def setUp(self) -> None:
+        py_ver = sys.version_info
+        if ((3, 10, 0) <= py_ver < (3, 10, 9)) or ((3, 11, 0) <= py_ver <= (3, 11, 1)):
+            # Early releases in the Python 3.10 and 3.11 series had deprecation
+            # warnings that were later reverted; we must suppress them here.
+            setup_with_context_manager(self, warnings.catch_warnings())
+            warnings.filterwarnings(
+                "ignore",
+                message="There is no current event loop",
+                category=DeprecationWarning,
+                module=r"tornado\..*",
+            )
+        super().setUp()
+        if type(self).get_new_ioloop is not AsyncTestCase.get_new_ioloop:
+            warnings.warn("get_new_ioloop is deprecated", DeprecationWarning)
+        self.io_loop = self.get_new_ioloop()
+        asyncio.set_event_loop(self.io_loop.asyncio_loop)  # type: ignore[attr-defined]
+
+    def tearDown(self) -> None:
+        # Native coroutines tend to produce warnings if they're not
+        # allowed to run to completion. It's difficult to ensure that
+        # this always happens in tests, so cancel any tasks that are
+        # still pending by the time we get here.
+        asyncio_loop = self.io_loop.asyncio_loop  # type: ignore
+        if hasattr(asyncio, "all_tasks"):  # py37
+            tasks = asyncio.all_tasks(asyncio_loop)  # type: ignore
+        else:
+            tasks = asyncio.Task.all_tasks(asyncio_loop)
+        # Tasks that are done may still appear here and may contain
+        # non-cancellation exceptions, so filter them out.
+        tasks = [t for t in tasks if not t.done()]  # type: ignore
+        for t in tasks:
+            t.cancel()
+        # Allow the tasks to run and finalize themselves (which means
+        # raising a CancelledError inside the coroutine). This may
+        # just transform the "task was destroyed but it is pending"
+        # warning into an "uncaught CancelledError" warning, but
+        # catching CancelledErrors in coroutines that may leak is
+        # simpler than ensuring that no coroutines leak.
+        if tasks:
+            done, pending = self.io_loop.run_sync(lambda: asyncio.wait(tasks))
+            assert not pending
+            # If any task failed with anything but a CancelledError, raise it.
+            for f in done:
+                try:
+                    f.result()
+                except asyncio.CancelledError:
+                    pass
+
+        # Clean up Subprocess, so it can be used again with a new ioloop.
+        Subprocess.uninitialize()
+        asyncio.set_event_loop(None)
+        if not isinstance(self.io_loop, _NON_OWNED_IOLOOPS):
+            # Try to clean up any file descriptors left open in the ioloop.
+            # This avoids leaks, especially when tests are run repeatedly
+            # in the same process with autoreload (because curl does not
+            # set FD_CLOEXEC on its file descriptors)
+            self.io_loop.close(all_fds=True)
+        super().tearDown()
+        # In case an exception escaped or the StackContext caught an exception
+        # when there wasn't a wait() to re-raise it, do so here.
+        # This is our last chance to raise an exception in a way that the
+        # unittest machinery understands.
+        self.__rethrow()
+
+    def get_new_ioloop(self) -> IOLoop:
+        """Returns the `.IOLoop` to use for this test.
+
+        By default, a new `.IOLoop` is created for each test.
+        Subclasses may override this method to return
+        `.IOLoop.current()` if it is not appropriate to use a new
+        `.IOLoop` in each test (for example, if there are global
+        singletons using the default `.IOLoop`) or if a per-test event
+        loop is being provided by another system (such as
+        ``pytest-asyncio``).
+
+        .. deprecated:: 6.3
+           This method will be removed in Tornado 7.0.
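+
+        For example, a subclass that must reuse a process-global loop
+        could override this with (a sketch of the override described
+        above)::
+
+            def get_new_ioloop(self):
+                return IOLoop.current()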
+ """ + return IOLoop(make_current=False) + + def _handle_exception( + self, typ: Type[Exception], value: Exception, tb: TracebackType + ) -> bool: + if self.__failure is None: + self.__failure = (typ, value, tb) + else: + app_log.error( + "multiple unhandled exceptions in test", exc_info=(typ, value, tb) + ) + self.stop() + return True + + def __rethrow(self) -> None: + if self.__failure is not None: + failure = self.__failure + self.__failure = None + raise_exc_info(failure) + + def run( + self, result: Optional[unittest.TestResult] = None + ) -> Optional[unittest.TestResult]: + ret = super().run(result) + # As a last resort, if an exception escaped super.run() and wasn't + # re-raised in tearDown, raise it here. This will cause the + # unittest run to fail messily, but that's better than silently + # ignoring an error. + self.__rethrow() + return ret + + def stop(self, _arg: Any = None, **kwargs: Any) -> None: + """Stops the `.IOLoop`, causing one pending (or future) call to `wait()` + to return. + + Keyword arguments or a single positional argument passed to `stop()` are + saved and will be returned by `wait()`. + + .. deprecated:: 5.1 + + `stop` and `wait` are deprecated; use ``@gen_test`` instead. + """ + assert _arg is None or not kwargs + self.__stop_args = kwargs or _arg + if self.__running: + self.io_loop.stop() + self.__running = False + self.__stopped = True + + def wait( + self, + condition: Optional[Callable[..., bool]] = None, + timeout: Optional[float] = None, + ) -> Any: + """Runs the `.IOLoop` until stop is called or timeout has passed. + + In the event of a timeout, an exception will be thrown. The + default timeout is 5 seconds; it may be overridden with a + ``timeout`` keyword argument or globally with the + ``ASYNC_TEST_TIMEOUT`` environment variable. + + If ``condition`` is not ``None``, the `.IOLoop` will be restarted + after `stop()` until ``condition()`` returns ``True``. + + .. versionchanged:: 3.1 + Added the ``ASYNC_TEST_TIMEOUT`` environment variable. + + .. deprecated:: 5.1 + + `stop` and `wait` are deprecated; use ``@gen_test`` instead. + """ + if timeout is None: + timeout = get_async_test_timeout() + + if not self.__stopped: + if timeout: + + def timeout_func() -> None: + try: + raise self.failureException( + "Async operation timed out after %s seconds" % timeout + ) + except Exception: + self.__failure = sys.exc_info() + self.stop() + + self.__timeout = self.io_loop.add_timeout( + self.io_loop.time() + timeout, timeout_func + ) + while True: + self.__running = True + self.io_loop.start() + if self.__failure is not None or condition is None or condition(): + break + if self.__timeout is not None: + self.io_loop.remove_timeout(self.__timeout) + self.__timeout = None + assert self.__stopped + self.__stopped = False + self.__rethrow() + result = self.__stop_args + self.__stop_args = None + return result + + +class AsyncHTTPTestCase(AsyncTestCase): + """A test case that starts up an HTTP server. + + Subclasses must override `get_app()`, which returns the + `tornado.web.Application` (or other `.HTTPServer` callback) to be tested. + Tests will typically use the provided ``self.http_client`` to fetch + URLs from this server. 
+
+    Example, assuming the "Hello, world" example from the user guide is in
+    ``hello.py``::
+
+        import hello
+
+        class TestHelloApp(AsyncHTTPTestCase):
+            def get_app(self):
+                return hello.make_app()
+
+            def test_homepage(self):
+                response = self.fetch('/')
+                self.assertEqual(response.code, 200)
+                self.assertEqual(response.body, b'Hello, world')
+
+    That call to ``self.fetch()`` is equivalent to ::
+
+        self.http_client.fetch(self.get_url('/'), self.stop)
+        response = self.wait()
+
+    which illustrates how AsyncTestCase can turn an asynchronous operation,
+    like ``http_client.fetch()``, into a synchronous operation. If you need
+    to do other asynchronous operations in tests, you'll probably need to use
+    ``stop()`` and ``wait()`` yourself.
+    """
+
+    def setUp(self) -> None:
+        super().setUp()
+        sock, port = bind_unused_port()
+        self.__port = port
+
+        self.http_client = self.get_http_client()
+        self._app = self.get_app()
+        self.http_server = self.get_http_server()
+        self.http_server.add_sockets([sock])
+
+    def get_http_client(self) -> AsyncHTTPClient:
+        return AsyncHTTPClient()
+
+    def get_http_server(self) -> HTTPServer:
+        return HTTPServer(self._app, **self.get_httpserver_options())
+
+    def get_app(self) -> Application:
+        """Should be overridden by subclasses to return a
+        `tornado.web.Application` or other `.HTTPServer` callback.
+        """
+        raise NotImplementedError()
+
+    def fetch(
+        self, path: str, raise_error: bool = False, **kwargs: Any
+    ) -> HTTPResponse:
+        """Convenience method to synchronously fetch a URL.
+
+        The given path will be appended to the local server's host and
+        port. Any additional keyword arguments will be passed directly to
+        `.AsyncHTTPClient.fetch` (and so could be used to pass
+        ``method="POST"``, ``body="..."``, etc).
+
+        If the path begins with http:// or https://, it will be treated as a
+        full URL and will be fetched as-is.
+
+        If ``raise_error`` is ``True``, a `tornado.httpclient.HTTPError` will
+        be raised if the response code is not 200. This is the same behavior
+        as the ``raise_error`` argument to `.AsyncHTTPClient.fetch`, but
+        the default is ``False`` here (it's ``True`` in `.AsyncHTTPClient`)
+        because tests often need to deal with non-200 response codes.
+
+        .. versionchanged:: 5.0
+           Added support for absolute URLs.
+
+        .. versionchanged:: 5.1
+
+           Added the ``raise_error`` argument.
+
+        .. deprecated:: 5.1
+
+           This method currently turns any exception into an
+           `.HTTPResponse` with status code 599. In Tornado 6.0,
+           errors other than `tornado.httpclient.HTTPError` will be
+           passed through, and ``raise_error=False`` will only
+           suppress errors that would be raised due to non-200
+           response codes.
+
+        """
+        if path.lower().startswith(("http://", "https://")):
+            url = path
+        else:
+            url = self.get_url(path)
+        return self.io_loop.run_sync(
+            lambda: self.http_client.fetch(url, raise_error=raise_error, **kwargs),
+            timeout=get_async_test_timeout(),
+        )
+
+    def get_httpserver_options(self) -> Dict[str, Any]:
+        """May be overridden by subclasses to return additional
+        keyword arguments for the server.
+        """
+        return {}
+
+    def get_http_port(self) -> int:
+        """Returns the port used by the server.
+
+        A new port is chosen for each test.
+ """ + return self.__port + + def get_protocol(self) -> str: + return "http" + + def get_url(self, path: str) -> str: + """Returns an absolute url for the given path on the test server.""" + return "%s://127.0.0.1:%s%s" % (self.get_protocol(), self.get_http_port(), path) + + def tearDown(self) -> None: + self.http_server.stop() + self.io_loop.run_sync( + self.http_server.close_all_connections, timeout=get_async_test_timeout() + ) + self.http_client.close() + del self.http_server + del self._app + super().tearDown() + + +class AsyncHTTPSTestCase(AsyncHTTPTestCase): + """A test case that starts an HTTPS server. + + Interface is generally the same as `AsyncHTTPTestCase`. + """ + + def get_http_client(self) -> AsyncHTTPClient: + return AsyncHTTPClient(force_instance=True, defaults=dict(validate_cert=False)) + + def get_httpserver_options(self) -> Dict[str, Any]: + return dict(ssl_options=self.get_ssl_options()) + + def get_ssl_options(self) -> Dict[str, Any]: + """May be overridden by subclasses to select SSL options. + + By default includes a self-signed testing certificate. + """ + return AsyncHTTPSTestCase.default_ssl_options() + + @staticmethod + def default_ssl_options() -> Dict[str, Any]: + # Testing keys were generated with: + # openssl req -new -keyout tornado/test/test.key \ + # -out tornado/test/test.crt -nodes -days 3650 -x509 + module_dir = os.path.dirname(__file__) + return dict( + certfile=os.path.join(module_dir, "test", "test.crt"), + keyfile=os.path.join(module_dir, "test", "test.key"), + ) + + def get_protocol(self) -> str: + return "https" + + +@typing.overload +def gen_test( + *, timeout: Optional[float] = None +) -> Callable[[Callable[..., Union[Generator, "Coroutine"]]], Callable[..., None]]: + pass + + +@typing.overload # noqa: F811 +def gen_test(func: Callable[..., Union[Generator, "Coroutine"]]) -> Callable[..., None]: + pass + + +def gen_test( # noqa: F811 + func: Optional[Callable[..., Union[Generator, "Coroutine"]]] = None, + timeout: Optional[float] = None, +) -> Union[ + Callable[..., None], + Callable[[Callable[..., Union[Generator, "Coroutine"]]], Callable[..., None]], +]: + """Testing equivalent of ``@gen.coroutine``, to be applied to test methods. + + ``@gen.coroutine`` cannot be used on tests because the `.IOLoop` is not + already running. ``@gen_test`` should be applied to test methods + on subclasses of `AsyncTestCase`. + + Example:: + + class MyTest(AsyncHTTPTestCase): + @gen_test + def test_something(self): + response = yield self.http_client.fetch(self.get_url('/')) + + By default, ``@gen_test`` times out after 5 seconds. The timeout may be + overridden globally with the ``ASYNC_TEST_TIMEOUT`` environment variable, + or for each test with the ``timeout`` keyword argument:: + + class MyTest(AsyncHTTPTestCase): + @gen_test(timeout=10) + def test_something_slow(self): + response = yield self.http_client.fetch(self.get_url('/')) + + Note that ``@gen_test`` is incompatible with `AsyncTestCase.stop`, + `AsyncTestCase.wait`, and `AsyncHTTPTestCase.fetch`. Use ``yield + self.http_client.fetch(self.get_url())`` as shown above instead. + + .. versionadded:: 3.1 + The ``timeout`` argument and ``ASYNC_TEST_TIMEOUT`` environment + variable. + + .. versionchanged:: 4.0 + The wrapper now passes along ``*args, **kwargs`` so it can be used + on functions with arguments. 
+ + """ + if timeout is None: + timeout = get_async_test_timeout() + + def wrap(f: Callable[..., Union[Generator, "Coroutine"]]) -> Callable[..., None]: + # Stack up several decorators to allow us to access the generator + # object itself. In the innermost wrapper, we capture the generator + # and save it in an attribute of self. Next, we run the wrapped + # function through @gen.coroutine. Finally, the coroutine is + # wrapped again to make it synchronous with run_sync. + # + # This is a good case study arguing for either some sort of + # extensibility in the gen decorators or cancellation support. + @functools.wraps(f) + def pre_coroutine(self, *args, **kwargs): + # type: (AsyncTestCase, *Any, **Any) -> Union[Generator, Coroutine] + # Type comments used to avoid pypy3 bug. + result = f(self, *args, **kwargs) + if isinstance(result, Generator) or inspect.iscoroutine(result): + self._test_generator = result + else: + self._test_generator = None + return result + + if inspect.iscoroutinefunction(f): + coro = pre_coroutine + else: + coro = gen.coroutine(pre_coroutine) # type: ignore[assignment] + + @functools.wraps(coro) + def post_coroutine(self, *args, **kwargs): + # type: (AsyncTestCase, *Any, **Any) -> None + try: + return self.io_loop.run_sync( + functools.partial(coro, self, *args, **kwargs), timeout=timeout + ) + except TimeoutError as e: + # run_sync raises an error with an unhelpful traceback. + # If the underlying generator is still running, we can throw the + # exception back into it so the stack trace is replaced by the + # point where the test is stopped. The only reason the generator + # would not be running would be if it were cancelled, which means + # a native coroutine, so we can rely on the cr_running attribute. + if self._test_generator is not None and getattr( + self._test_generator, "cr_running", True + ): + self._test_generator.throw(e) + # In case the test contains an overly broad except + # clause, we may get back here. + # Coroutine was stopped or didn't raise a useful stack trace, + # so re-raise the original exception which is better than nothing. + raise + + return post_coroutine + + if func is not None: + # Used like: + # @gen_test + # def f(self): + # pass + return wrap(func) + else: + # Used like @gen_test(timeout=10) + return wrap + + +# Without this attribute, nosetests will try to run gen_test as a test +# anywhere it is imported. +gen_test.__test__ = False # type: ignore + + +class ExpectLog(logging.Filter): + """Context manager to capture and suppress expected log output. + + Useful to make tests of error conditions less noisy, while still + leaving unexpected log entries visible. *Not thread safe.* + + The attribute ``logged_stack`` is set to ``True`` if any exception + stack trace was logged. + + Usage:: + + with ExpectLog('tornado.application', "Uncaught exception"): + error_response = self.fetch("/some_page") + + .. versionchanged:: 4.3 + Added the ``logged_stack`` attribute. + """ + + def __init__( + self, + logger: Union[logging.Logger, basestring_type], + regex: str, + required: bool = True, + level: Optional[int] = None, + ) -> None: + """Constructs an ExpectLog context manager. + + :param logger: Logger object (or name of logger) to watch. Pass an + empty string to watch the root logger. + :param regex: Regular expression to match. Any log entries on the + specified logger that match this regex will be suppressed. 
+        :param required: If true, an exception will be raised if the end of the
+            ``with`` statement is reached without matching any log entries.
+        :param level: A constant from the ``logging`` module indicating the
+            expected log level. If this parameter is provided, only log messages
+            at this level will be considered to match. Additionally, the
+            supplied ``logger`` will have its level adjusted if necessary (for
+            the duration of the ``ExpectLog``) to enable the expected message.
+
+        .. versionchanged:: 6.1
+           Added the ``level`` parameter.
+
+        .. deprecated:: 6.3
+           In Tornado 7.0, only ``WARNING`` and higher logging levels will be
+           matched by default. To match ``INFO`` and lower levels, the ``level``
+           argument must be used. This is changing to minimize differences
+           between ``tornado.testing.main`` (which enables ``INFO`` logs by
+           default) and most other test runners (including those in IDEs)
+           which have ``INFO`` logs disabled by default.
+        """
+        if isinstance(logger, basestring_type):
+            logger = logging.getLogger(logger)
+        self.logger = logger
+        self.regex = re.compile(regex)
+        self.required = required
+        # matched and deprecated_level_matched are counters for the
+        # respective events.
+        self.matched = 0
+        self.deprecated_level_matched = 0
+        self.logged_stack = False
+        self.level = level
+        self.orig_level = None  # type: Optional[int]
+
+    def filter(self, record: logging.LogRecord) -> bool:
+        if record.exc_info:
+            self.logged_stack = True
+        message = record.getMessage()
+        if self.regex.match(message):
+            if self.level is None and record.levelno < logging.WARNING:
+                # We're inside the logging machinery here so generating a DeprecationWarning
+                # here won't be reported cleanly (if warnings-as-errors is enabled, the error
+                # just gets swallowed by the logging module), and even if it were it would
+                # have the wrong stack trace. Just remember this fact and report it in
+                # __exit__ instead.
+                self.deprecated_level_matched += 1
+            if self.level is not None and record.levelno != self.level:
+                app_log.warning(
+                    "Got expected log message %r at unexpected level (%s vs %s)"
+                    % (message, logging.getLevelName(self.level), record.levelname)
+                )
+                return True
+            self.matched += 1
+            return False
+        return True
+
+    def __enter__(self) -> "ExpectLog":
+        if self.level is not None and self.level < self.logger.getEffectiveLevel():
+            self.orig_level = self.logger.level
+            self.logger.setLevel(self.level)
+        self.logger.addFilter(self)
+        return self
+
+    def __exit__(
+        self,
+        typ: "Optional[Type[BaseException]]",
+        value: Optional[BaseException],
+        tb: Optional[TracebackType],
+    ) -> None:
+        if self.orig_level is not None:
+            self.logger.setLevel(self.orig_level)
+        self.logger.removeFilter(self)
+        if not typ and self.required and not self.matched:
+            raise Exception("did not get expected log message")
+        if (
+            not typ
+            and self.required
+            and (self.deprecated_level_matched >= self.matched)
+        ):
+            warnings.warn(
+                "ExpectLog matched at INFO or below without level argument",
+                DeprecationWarning,
+            )
+
+
+# From https://nedbatchelder.com/blog/201508/using_context_managers_in_test_setup.html
+def setup_with_context_manager(testcase: unittest.TestCase, cm: Any) -> Any:
+    """Use a contextmanager to setUp a test case."""
+    val = cm.__enter__()
+    testcase.addCleanup(cm.__exit__, None, None, None)
+    return val
+
+
+def main(**kwargs: Any) -> None:
+    """A simple test runner.
+ + This test runner is essentially equivalent to `unittest.main` from + the standard library, but adds support for Tornado-style option + parsing and log formatting. It is *not* necessary to use this + `main` function to run tests using `AsyncTestCase`; these tests + are self-contained and can run with any test runner. + + The easiest way to run a test is via the command line:: + + python -m tornado.testing tornado.test.web_test + + See the standard library ``unittest`` module for ways in which + tests can be specified. + + Projects with many tests may wish to define a test script like + ``tornado/test/runtests.py``. This script should define a method + ``all()`` which returns a test suite and then call + `tornado.testing.main()`. Note that even when a test script is + used, the ``all()`` test suite may be overridden by naming a + single test on the command line:: + + # Runs all tests + python -m tornado.test.runtests + # Runs one test + python -m tornado.test.runtests tornado.test.web_test + + Additional keyword arguments passed through to ``unittest.main()``. + For example, use ``tornado.testing.main(verbosity=2)`` + to show many test details as they are run. + See http://docs.python.org/library/unittest.html#unittest.main + for full argument list. + + .. versionchanged:: 5.0 + + This function produces no output of its own; only that produced + by the `unittest` module (previously it would add a PASS or FAIL + log message). + """ + from tornado.options import define, options, parse_command_line + + define( + "exception_on_interrupt", + type=bool, + default=True, + help=( + "If true (default), ctrl-c raises a KeyboardInterrupt " + "exception. This prints a stack trace but cannot interrupt " + "certain operations. If false, the process is more reliably " + "killed, but does not print a stack trace." + ), + ) + + # support the same options as unittest's command-line interface + define("verbose", type=bool) + define("quiet", type=bool) + define("failfast", type=bool) + define("catch", type=bool) + define("buffer", type=bool) + + argv = [sys.argv[0]] + parse_command_line(sys.argv) + + if not options.exception_on_interrupt: + signal.signal(signal.SIGINT, signal.SIG_DFL) + + if options.verbose is not None: + kwargs["verbosity"] = 2 + if options.quiet is not None: + kwargs["verbosity"] = 0 + if options.failfast is not None: + kwargs["failfast"] = True + if options.catch is not None: + kwargs["catchbreak"] = True + if options.buffer is not None: + kwargs["buffer"] = True + + if __name__ == "__main__" and len(argv) == 1: + print("No tests specified", file=sys.stderr) + sys.exit(1) + # In order to be able to run tests by their fully-qualified name + # on the command line without importing all tests here, + # module must be set to None. Python 3.2's unittest.main ignores + # defaultTest if no module is given (it tries to do its own + # test discovery, which is incompatible with auto2to3), so don't + # set module if we're not asking for a specific test. + if len(argv) > 1: + unittest.main(module=None, argv=argv, **kwargs) # type: ignore + else: + unittest.main(defaultTest="all", argv=argv, **kwargs) + + +if __name__ == "__main__": + main() diff --git a/venv/lib/python3.9/site-packages/tornado/util.py b/venv/lib/python3.9/site-packages/tornado/util.py new file mode 100644 index 00000000..3a3a52f1 --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/util.py @@ -0,0 +1,462 @@ +"""Miscellaneous utility functions and classes. + +This module is used internally by Tornado. 
It is not necessarily expected
+that the functions and classes defined here will be useful to other
+applications, but they are documented here in case they are.
+
+The one public-facing part of this module is the `Configurable` class
+and its `~Configurable.configure` method, which becomes a part of the
+interface of its subclasses, including `.AsyncHTTPClient`, `.IOLoop`,
+and `.Resolver`.
+"""
+
+import array
+import asyncio
+import atexit
+from inspect import getfullargspec
+import os
+import re
+import typing
+import zlib
+
+from typing import (
+    Any,
+    Optional,
+    Dict,
+    Mapping,
+    List,
+    Tuple,
+    Match,
+    Callable,
+    Type,
+    Sequence,
+)
+
+if typing.TYPE_CHECKING:
+    # Additional imports only used in type comments.
+    # This lets us make these imports lazy.
+    import datetime  # noqa: F401
+    from types import TracebackType  # noqa: F401
+    from typing import Union  # noqa: F401
+    import unittest  # noqa: F401
+
+# Aliases for types that are spelled differently in different Python
+# versions. bytes_type is deprecated and no longer used in Tornado
+# itself but is left in case anyone outside Tornado is using it.
+bytes_type = bytes
+unicode_type = str
+basestring_type = str
+
+try:
+    from sys import is_finalizing
+except ImportError:
+    # Emulate it
+    def _get_emulated_is_finalizing() -> Callable[[], bool]:
+        L = []  # type: List[None]
+        atexit.register(lambda: L.append(None))
+
+        def is_finalizing() -> bool:
+            # Not referencing any globals here
+            return L != []
+
+        return is_finalizing
+
+    is_finalizing = _get_emulated_is_finalizing()
+
+
+# versionchanged:: 6.2
+# no longer our own TimeoutError, use standard asyncio class
+TimeoutError = asyncio.TimeoutError
+
+
+class ObjectDict(Dict[str, Any]):
+    """Makes a dictionary behave like an object, with attribute-style access."""
+
+    def __getattr__(self, name: str) -> Any:
+        try:
+            return self[name]
+        except KeyError:
+            raise AttributeError(name)
+
+    def __setattr__(self, name: str, value: Any) -> None:
+        self[name] = value
+
+
+class GzipDecompressor(object):
+    """Streaming gzip decompressor.
+
+    The interface is like that of `zlib.decompressobj` (without some of the
+    optional arguments), but it understands gzip headers and checksums.
+    """
+
+    def __init__(self) -> None:
+        # Magic parameter makes zlib module understand gzip header
+        # http://stackoverflow.com/questions/1838699/how-can-i-decompress-a-gzip-stream-with-zlib
+        # This works on cpython and pypy, but not jython.
+        self.decompressobj = zlib.decompressobj(16 + zlib.MAX_WBITS)
+
+    def decompress(self, value: bytes, max_length: int = 0) -> bytes:
+        """Decompress a chunk, returning newly-available data.
+
+        Some data may be buffered for later processing; `flush` must
+        be called when there is no more input data to ensure that
+        all data was processed.
+
+        If ``max_length`` is given, some input data may be left over
+        in ``unconsumed_tail``; you must retrieve this value and pass
+        it back to a future call to `decompress` if it is not empty.
+        """
+        return self.decompressobj.decompress(value, max_length)
+
+    @property
+    def unconsumed_tail(self) -> bytes:
+        """Returns the unconsumed portion left over"""
+        return self.decompressobj.unconsumed_tail
+
+    def flush(self) -> bytes:
+        """Return any remaining buffered data not yet returned by decompress.
+
+        Also checks for errors such as truncated input.
+        No other methods may be called on this object after `flush`.
+        """
+        return self.decompressobj.flush()
+
+
+def import_object(name: str) -> Any:
+    """Imports an object by name.
+ + ``import_object('x')`` is equivalent to ``import x``. + ``import_object('x.y.z')`` is equivalent to ``from x.y import z``. + + >>> import tornado.escape + >>> import_object('tornado.escape') is tornado.escape + True + >>> import_object('tornado.escape.utf8') is tornado.escape.utf8 + True + >>> import_object('tornado') is tornado + True + >>> import_object('tornado.missing_module') + Traceback (most recent call last): + ... + ImportError: No module named missing_module + """ + if name.count(".") == 0: + return __import__(name) + + parts = name.split(".") + obj = __import__(".".join(parts[:-1]), fromlist=[parts[-1]]) + try: + return getattr(obj, parts[-1]) + except AttributeError: + raise ImportError("No module named %s" % parts[-1]) + + +def exec_in( + code: Any, glob: Dict[str, Any], loc: Optional[Optional[Mapping[str, Any]]] = None +) -> None: + if isinstance(code, str): + # exec(string) inherits the caller's future imports; compile + # the string first to prevent that. + code = compile(code, "<string>", "exec", dont_inherit=True) + exec(code, glob, loc) + + +def raise_exc_info( + exc_info: Tuple[Optional[type], Optional[BaseException], Optional["TracebackType"]] +) -> typing.NoReturn: + try: + if exc_info[1] is not None: + raise exc_info[1].with_traceback(exc_info[2]) + else: + raise TypeError("raise_exc_info called with no exception") + finally: + # Clear the traceback reference from our stack frame to + # minimize circular references that slow down GC. + exc_info = (None, None, None) + + +def errno_from_exception(e: BaseException) -> Optional[int]: + """Provides the errno from an Exception object. + + There are cases that the errno attribute was not set so we pull + the errno out of the args but if someone instantiates an Exception + without any args you will get a tuple error. So this function + abstracts all that behavior to give you a safe way to get the + errno. + """ + + if hasattr(e, "errno"): + return e.errno # type: ignore + elif e.args: + return e.args[0] + else: + return None + + +_alphanum = frozenset("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789") + + +def _re_unescape_replacement(match: Match[str]) -> str: + group = match.group(1) + if group[0] in _alphanum: + raise ValueError("cannot unescape '\\\\%s'" % group[0]) + return group + + +_re_unescape_pattern = re.compile(r"\\(.)", re.DOTALL) + + +def re_unescape(s: str) -> str: + r"""Unescape a string escaped by `re.escape`. + + May raise ``ValueError`` for regular expressions which could not + have been produced by `re.escape` (for example, strings containing + ``\d`` cannot be unescaped). + + .. versionadded:: 4.4 + """ + return _re_unescape_pattern.sub(_re_unescape_replacement, s) + + +class Configurable(object): + """Base class for configurable interfaces. + + A configurable interface is an (abstract) class whose constructor + acts as a factory function for one of its implementation subclasses. + The implementation subclass as well as optional keyword arguments to + its initializer can be set globally at runtime with `configure`. + + By using the constructor as the factory method, the interface + looks like a normal class, `isinstance` works as usual, etc. This + pattern is most useful when the choice of implementation is likely + to be a global decision (e.g. when `~select.epoll` is available, + always use it instead of `~select.select`), or when a + previously-monolithic class has been split into specialized + subclasses. 
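+
+    For example, selecting the curl-based HTTP client implementation goes
+    through this machinery (an illustrative use, not something introduced
+    by this commit)::
+
+        AsyncHTTPClient.configure("tornado.curl_httpclient.CurlAsyncHTTPClient")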
+
+    Configurable subclasses must define the class methods
+    `configurable_base` and `configurable_default`, and use the instance
+    method `initialize` instead of ``__init__``.
+
+    .. versionchanged:: 5.0
+
+       It is now possible for configuration to be specified at
+       multiple levels of a class hierarchy.
+
+    """
+
+    # Type annotations on this class are mostly done with comments
+    # because they need to refer to Configurable, which isn't defined
+    # until after the class definition block. These can use regular
+    # annotations when our minimum python version is 3.7.
+    #
+    # There may be a clever way to use generics here to get more
+    # precise types (i.e. for a particular Configurable subclass T,
+    # all the types are subclasses of T, not just Configurable).
+    __impl_class = None  # type: Optional[Type[Configurable]]
+    __impl_kwargs = None  # type: Dict[str, Any]
+
+    def __new__(cls, *args: Any, **kwargs: Any) -> Any:
+        base = cls.configurable_base()
+        init_kwargs = {}  # type: Dict[str, Any]
+        if cls is base:
+            impl = cls.configured_class()
+            if base.__impl_kwargs:
+                init_kwargs.update(base.__impl_kwargs)
+        else:
+            impl = cls
+        init_kwargs.update(kwargs)
+        if impl.configurable_base() is not base:
+            # The impl class is itself configurable, so recurse.
+            return impl(*args, **init_kwargs)
+        instance = super(Configurable, cls).__new__(impl)
+        # initialize vs __init__ chosen for compatibility with AsyncHTTPClient
+        # singleton magic. If we get rid of that we can switch to __init__
+        # here too.
+        instance.initialize(*args, **init_kwargs)
+        return instance
+
+    @classmethod
+    def configurable_base(cls):
+        # type: () -> Type[Configurable]
+        """Returns the base class of a configurable hierarchy.
+
+        This will normally return the class in which it is defined
+        (which is *not* necessarily the same as the ``cls`` classmethod
+        parameter).
+
+        """
+        raise NotImplementedError()
+
+    @classmethod
+    def configurable_default(cls):
+        # type: () -> Type[Configurable]
+        """Returns the implementation class to be used if none is configured."""
+        raise NotImplementedError()
+
+    def _initialize(self) -> None:
+        pass
+
+    initialize = _initialize  # type: Callable[..., None]
+    """Initialize a `Configurable` subclass instance.
+
+    Configurable classes should use `initialize` instead of ``__init__``.
+
+    .. versionchanged:: 4.2
+       Now accepts positional arguments in addition to keyword arguments.
+    """
+
+    @classmethod
+    def configure(cls, impl, **kwargs):
+        # type: (Union[None, str, Type[Configurable]], Any) -> None
+        """Sets the class to use when the base class is instantiated.
+
+        Keyword arguments will be saved and added to the arguments passed
+        to the constructor. This can be used to set global defaults for
+        some parameters.
+        """
+        base = cls.configurable_base()
+        if isinstance(impl, str):
+            impl = typing.cast(Type[Configurable], import_object(impl))
+        if impl is not None and not issubclass(impl, cls):
+            raise ValueError("Invalid subclass of %s" % cls)
+        base.__impl_class = impl
+        base.__impl_kwargs = kwargs
+
+    @classmethod
+    def configured_class(cls):
+        # type: () -> Type[Configurable]
+        """Returns the currently configured class."""
+        base = cls.configurable_base()
+        # Manually mangle the private name to see whether this base
+        # has been configured (and not another base higher in the
+        # hierarchy).
+ if base.__dict__.get("_Configurable__impl_class") is None: + base.__impl_class = cls.configurable_default() + if base.__impl_class is not None: + return base.__impl_class + else: + # Should be impossible, but mypy wants an explicit check. + raise ValueError("configured class not found") + + @classmethod + def _save_configuration(cls): + # type: () -> Tuple[Optional[Type[Configurable]], Dict[str, Any]] + base = cls.configurable_base() + return (base.__impl_class, base.__impl_kwargs) + + @classmethod + def _restore_configuration(cls, saved): + # type: (Tuple[Optional[Type[Configurable]], Dict[str, Any]]) -> None + base = cls.configurable_base() + base.__impl_class = saved[0] + base.__impl_kwargs = saved[1] + + +class ArgReplacer(object): + """Replaces one value in an ``args, kwargs`` pair. + + Inspects the function signature to find an argument by name + whether it is passed by position or keyword. For use in decorators + and similar wrappers. + """ + + def __init__(self, func: Callable, name: str) -> None: + self.name = name + try: + self.arg_pos = self._getargnames(func).index(name) # type: Optional[int] + except ValueError: + # Not a positional parameter + self.arg_pos = None + + def _getargnames(self, func: Callable) -> List[str]: + try: + return getfullargspec(func).args + except TypeError: + if hasattr(func, "func_code"): + # Cython-generated code has all the attributes needed + # by inspect.getfullargspec, but the inspect module only + # works with ordinary functions. Inline the portion of + # getfullargspec that we need here. Note that for static + # functions the @cython.binding(True) decorator must + # be used (for methods it works out of the box). + code = func.func_code # type: ignore + return code.co_varnames[: code.co_argcount] + raise + + def get_old_value( + self, args: Sequence[Any], kwargs: Dict[str, Any], default: Any = None + ) -> Any: + """Returns the old value of the named argument without replacing it. + + Returns ``default`` if the argument is not present. + """ + if self.arg_pos is not None and len(args) > self.arg_pos: + return args[self.arg_pos] + else: + return kwargs.get(self.name, default) + + def replace( + self, new_value: Any, args: Sequence[Any], kwargs: Dict[str, Any] + ) -> Tuple[Any, Sequence[Any], Dict[str, Any]]: + """Replace the named argument in ``args, kwargs`` with ``new_value``. + + Returns ``(old_value, args, kwargs)``. The returned ``args`` and + ``kwargs`` objects may not be the same as the input objects, or + the input objects may be mutated. + + If the named argument was not found, ``new_value`` will be added + to ``kwargs`` and None will be returned as ``old_value``. + """ + if self.arg_pos is not None and len(args) > self.arg_pos: + # The arg to replace is passed positionally + old_value = args[self.arg_pos] + args = list(args) # *args is normally a tuple + args[self.arg_pos] = new_value + else: + # The arg to replace is either omitted or passed by keyword. + old_value = kwargs.get(self.name) + kwargs[self.name] = new_value + return old_value, args, kwargs + + +def timedelta_to_seconds(td): + # type: (datetime.timedelta) -> float + """Equivalent to ``td.total_seconds()`` (introduced in Python 2.7).""" + return td.total_seconds() + + +def _websocket_mask_python(mask: bytes, data: bytes) -> bytes: + """Websocket masking function. + + `mask` is a `bytes` object of length 4; `data` is a `bytes` object of any length. + Returns a `bytes` object of the same length as `data` with the mask applied + as specified in section 5.3 of RFC 6455. 
+ + This pure-python implementation may be replaced by an optimized version when available. + """ + mask_arr = array.array("B", mask) + unmasked_arr = array.array("B", data) + for i in range(len(data)): + unmasked_arr[i] = unmasked_arr[i] ^ mask_arr[i % 4] + return unmasked_arr.tobytes() + + +if os.environ.get("TORNADO_NO_EXTENSION") or os.environ.get("TORNADO_EXTENSION") == "0": + # These environment variables exist to make it easier to do performance + # comparisons; they are not guaranteed to remain supported in the future. + _websocket_mask = _websocket_mask_python +else: + try: + from tornado.speedups import websocket_mask as _websocket_mask + except ImportError: + if os.environ.get("TORNADO_EXTENSION") == "1": + raise + _websocket_mask = _websocket_mask_python + + +def doctests(): + # type: () -> unittest.TestSuite + import doctest + + return doctest.DocTestSuite() diff --git a/venv/lib/python3.9/site-packages/tornado/web.py b/venv/lib/python3.9/site-packages/tornado/web.py new file mode 100644 index 00000000..3b676e3c --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/web.py @@ -0,0 +1,3696 @@ +# +# Copyright 2009 Facebook +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""``tornado.web`` provides a simple web framework with asynchronous +features that allow it to scale to large numbers of open connections, +making it ideal for `long polling +<http://en.wikipedia.org/wiki/Push_technology#Long_polling>`_. + +Here is a simple "Hello, world" example app: + +.. testcode:: + + import asyncio + import tornado + + class MainHandler(tornado.web.RequestHandler): + def get(self): + self.write("Hello, world") + + async def main(): + application = tornado.web.Application([ + (r"/", MainHandler), + ]) + application.listen(8888) + await asyncio.Event().wait() + + if __name__ == "__main__": + asyncio.run(main()) + +.. testoutput:: + :hide: + + +See the :doc:`guide` for additional information. + +Thread-safety notes +------------------- + +In general, methods on `RequestHandler` and elsewhere in Tornado are +not thread-safe. In particular, methods such as +`~RequestHandler.write()`, `~RequestHandler.finish()`, and +`~RequestHandler.flush()` must only be called from the main thread. If +you use multiple threads it is important to use `.IOLoop.add_callback` +to transfer control back to the main thread before finishing the +request, or to limit your use of other threads to +`.IOLoop.run_in_executor` and ensure that your callbacks running in +the executor do not refer to Tornado objects. 
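+
+For example, a handler that runs blocking work on a thread and only touches
+the response from the main thread might look like this (a sketch;
+``compute_report`` stands in for any blocking function)::
+
+    class ReportHandler(tornado.web.RequestHandler):
+        async def get(self):
+            loop = tornado.ioloop.IOLoop.current()
+            # The blocking call runs on the default executor; awaiting the
+            # returned future resumes us on the main thread before write().
+            report = await loop.run_in_executor(None, compute_report)
+            self.write(report)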
+ +""" + +import base64 +import binascii +import datetime +import email.utils +import functools +import gzip +import hashlib +import hmac +import http.cookies +from inspect import isclass +from io import BytesIO +import mimetypes +import numbers +import os.path +import re +import socket +import sys +import threading +import time +import warnings +import tornado +import traceback +import types +import urllib.parse +from urllib.parse import urlencode + +from tornado.concurrent import Future, future_set_result_unless_cancelled +from tornado import escape +from tornado import gen +from tornado.httpserver import HTTPServer +from tornado import httputil +from tornado import iostream +from tornado import locale +from tornado.log import access_log, app_log, gen_log +from tornado import template +from tornado.escape import utf8, _unicode +from tornado.routing import ( + AnyMatches, + DefaultHostMatches, + HostMatches, + ReversibleRouter, + Rule, + ReversibleRuleRouter, + URLSpec, + _RuleList, +) +from tornado.util import ObjectDict, unicode_type, _websocket_mask + +url = URLSpec + +from typing import ( + Dict, + Any, + Union, + Optional, + Awaitable, + Tuple, + List, + Callable, + Iterable, + Generator, + Type, + TypeVar, + cast, + overload, +) +from types import TracebackType +import typing + +if typing.TYPE_CHECKING: + from typing import Set # noqa: F401 + + +# The following types are accepted by RequestHandler.set_header +# and related methods. +_HeaderTypes = Union[bytes, unicode_type, int, numbers.Integral, datetime.datetime] + +_CookieSecretTypes = Union[str, bytes, Dict[int, str], Dict[int, bytes]] + + +MIN_SUPPORTED_SIGNED_VALUE_VERSION = 1 +"""The oldest signed value version supported by this version of Tornado. + +Signed values older than this version cannot be decoded. + +.. versionadded:: 3.2.1 +""" + +MAX_SUPPORTED_SIGNED_VALUE_VERSION = 2 +"""The newest signed value version supported by this version of Tornado. + +Signed values newer than this version cannot be decoded. + +.. versionadded:: 3.2.1 +""" + +DEFAULT_SIGNED_VALUE_VERSION = 2 +"""The signed value version produced by `.RequestHandler.create_signed_value`. + +May be overridden by passing a ``version`` keyword argument. + +.. versionadded:: 3.2.1 +""" + +DEFAULT_SIGNED_VALUE_MIN_VERSION = 1 +"""The oldest signed value accepted by `.RequestHandler.get_signed_cookie`. + +May be overridden by passing a ``min_version`` keyword argument. + +.. versionadded:: 3.2.1 +""" + + +class _ArgDefaultMarker: + pass + + +_ARG_DEFAULT = _ArgDefaultMarker() + + +class RequestHandler(object): + """Base class for HTTP request handlers. + + Subclasses must define at least one of the methods defined in the + "Entry points" section below. + + Applications should not construct `RequestHandler` objects + directly and subclasses should not override ``__init__`` (override + `~RequestHandler.initialize` instead). + + """ + + SUPPORTED_METHODS = ("GET", "HEAD", "POST", "DELETE", "PATCH", "PUT", "OPTIONS") + + _template_loaders = {} # type: Dict[str, template.BaseLoader] + _template_loader_lock = threading.Lock() + _remove_control_chars_regex = re.compile(r"[\x00-\x08\x0e-\x1f]") + + _stream_request_body = False + + # Will be set in _execute. 
+ _transforms = None # type: List[OutputTransform] + path_args = None # type: List[str] + path_kwargs = None # type: Dict[str, str] + + def __init__( + self, + application: "Application", + request: httputil.HTTPServerRequest, + **kwargs: Any, + ) -> None: + super().__init__() + + self.application = application + self.request = request + self._headers_written = False + self._finished = False + self._auto_finish = True + self._prepared_future = None + self.ui = ObjectDict( + (n, self._ui_method(m)) for n, m in application.ui_methods.items() + ) + # UIModules are available as both `modules` and `_tt_modules` in the + # template namespace. Historically only `modules` was available + # but could be clobbered by user additions to the namespace. + # The template {% module %} directive looks in `_tt_modules` to avoid + # possible conflicts. + self.ui["_tt_modules"] = _UIModuleNamespace(self, application.ui_modules) + self.ui["modules"] = self.ui["_tt_modules"] + self.clear() + assert self.request.connection is not None + # TODO: need to add set_close_callback to HTTPConnection interface + self.request.connection.set_close_callback( # type: ignore + self.on_connection_close + ) + self.initialize(**kwargs) # type: ignore + + def _initialize(self) -> None: + pass + + initialize = _initialize # type: Callable[..., None] + """Hook for subclass initialization. Called for each request. + + A dictionary passed as the third argument of a ``URLSpec`` will be + supplied as keyword arguments to ``initialize()``. + + Example:: + + class ProfileHandler(RequestHandler): + def initialize(self, database): + self.database = database + + def get(self, username): + ... + + app = Application([ + (r'/user/(.*)', ProfileHandler, dict(database=database)), + ]) + """ + + @property + def settings(self) -> Dict[str, Any]: + """An alias for `self.application.settings <Application.settings>`.""" + return self.application.settings + + def _unimplemented_method(self, *args: str, **kwargs: str) -> None: + raise HTTPError(405) + + head = _unimplemented_method # type: Callable[..., Optional[Awaitable[None]]] + get = _unimplemented_method # type: Callable[..., Optional[Awaitable[None]]] + post = _unimplemented_method # type: Callable[..., Optional[Awaitable[None]]] + delete = _unimplemented_method # type: Callable[..., Optional[Awaitable[None]]] + patch = _unimplemented_method # type: Callable[..., Optional[Awaitable[None]]] + put = _unimplemented_method # type: Callable[..., Optional[Awaitable[None]]] + options = _unimplemented_method # type: Callable[..., Optional[Awaitable[None]]] + + def prepare(self) -> Optional[Awaitable[None]]: + """Called at the beginning of a request before `get`/`post`/etc. + + Override this method to perform common initialization regardless + of the request method. + + Asynchronous support: Use ``async def`` or decorate this method with + `.gen.coroutine` to make it asynchronous. + If this method returns an ``Awaitable`` execution will not proceed + until the ``Awaitable`` is done. + + .. versionadded:: 3.1 + Asynchronous support. + """ + pass + + def on_finish(self) -> None: + """Called after the end of a request. + + Override this method to perform cleanup, logging, etc. + This method is a counterpart to `prepare`. ``on_finish`` may + not produce any output, as it is called after the response + has been sent to the client. + """ + pass + + def on_connection_close(self) -> None: + """Called in async handlers if the client closed the connection. 
+ + Override this to clean up resources associated with + long-lived connections. Note that this method is called only if + the connection was closed during asynchronous processing; if you + need to do cleanup after every request override `on_finish` + instead. + + Proxies may keep a connection open for a time (perhaps + indefinitely) after the client has gone away, so this method + may not be called promptly after the end user closes their + connection. + """ + if _has_stream_request_body(self.__class__): + if not self.request._body_future.done(): + self.request._body_future.set_exception(iostream.StreamClosedError()) + self.request._body_future.exception() + + def clear(self) -> None: + """Resets all headers and content for this response.""" + self._headers = httputil.HTTPHeaders( + { + "Server": "TornadoServer/%s" % tornado.version, + "Content-Type": "text/html; charset=UTF-8", + "Date": httputil.format_timestamp(time.time()), + } + ) + self.set_default_headers() + self._write_buffer = [] # type: List[bytes] + self._status_code = 200 + self._reason = httputil.responses[200] + + def set_default_headers(self) -> None: + """Override this to set HTTP headers at the beginning of the request. + + For example, this is the place to set a custom ``Server`` header. + Note that setting such headers in the normal flow of request + processing may not do what you want, since headers may be reset + during error handling. + """ + pass + + def set_status(self, status_code: int, reason: Optional[str] = None) -> None: + """Sets the status code for our response. + + :arg int status_code: Response status code. + :arg str reason: Human-readable reason phrase describing the status + code. If ``None``, it will be filled in from + `http.client.responses` or "Unknown". + + .. versionchanged:: 5.0 + + No longer validates that the response code is in + `http.client.responses`. + """ + self._status_code = status_code + if reason is not None: + self._reason = escape.native_str(reason) + else: + self._reason = httputil.responses.get(status_code, "Unknown") + + def get_status(self) -> int: + """Returns the status code for our response.""" + return self._status_code + + def set_header(self, name: str, value: _HeaderTypes) -> None: + """Sets the given response header name and value. + + All header values are converted to strings (`datetime` objects + are formatted according to the HTTP specification for the + ``Date`` header). + + """ + self._headers[name] = self._convert_header_value(value) + + def add_header(self, name: str, value: _HeaderTypes) -> None: + """Adds the given response header and value. + + Unlike `set_header`, `add_header` may be called multiple times + to return multiple values for the same header. + """ + self._headers.add(name, self._convert_header_value(value)) + + def clear_header(self, name: str) -> None: + """Clears an outgoing header, undoing a previous `set_header` call. + + Note that this method does not apply to multi-valued headers + set by `add_header`. + """ + if name in self._headers: + del self._headers[name] + + _INVALID_HEADER_CHAR_RE = re.compile(r"[\x00-\x1f]") + + def _convert_header_value(self, value: _HeaderTypes) -> str: + # Convert the input value to a str. This type check is a bit + # subtle: The bytes case only executes on python 3, and the + # unicode case only executes on python 2, because the other + # cases are covered by the first match for str. 
+ if isinstance(value, str): + retval = value + elif isinstance(value, bytes): + # Non-ascii characters in headers are not well supported, + # but if you pass bytes, use latin1 so they pass through as-is. + retval = value.decode("latin1") + elif isinstance(value, numbers.Integral): + # return immediately since we know the converted value will be safe + return str(value) + elif isinstance(value, datetime.datetime): + return httputil.format_timestamp(value) + else: + raise TypeError("Unsupported header value %r" % value) + # If \n is allowed into the header, it is possible to inject + # additional headers or split the request. + if RequestHandler._INVALID_HEADER_CHAR_RE.search(retval): + raise ValueError("Unsafe header value %r" % retval) + return retval + + @overload + def get_argument(self, name: str, default: str, strip: bool = True) -> str: + pass + + @overload + def get_argument( # noqa: F811 + self, name: str, default: _ArgDefaultMarker = _ARG_DEFAULT, strip: bool = True + ) -> str: + pass + + @overload + def get_argument( # noqa: F811 + self, name: str, default: None, strip: bool = True + ) -> Optional[str]: + pass + + def get_argument( # noqa: F811 + self, + name: str, + default: Union[None, str, _ArgDefaultMarker] = _ARG_DEFAULT, + strip: bool = True, + ) -> Optional[str]: + """Returns the value of the argument with the given name. + + If default is not provided, the argument is considered to be + required, and we raise a `MissingArgumentError` if it is missing. + + If the argument appears in the request more than once, we return the + last value. + + This method searches both the query and body arguments. + """ + return self._get_argument(name, default, self.request.arguments, strip) + + def get_arguments(self, name: str, strip: bool = True) -> List[str]: + """Returns a list of the arguments with the given name. + + If the argument is not present, returns an empty list. + + This method searches both the query and body arguments. + """ + + # Make sure `get_arguments` isn't accidentally being called with a + # positional argument that's assumed to be a default (like in + # `get_argument`). + assert isinstance(strip, bool) + + return self._get_arguments(name, self.request.arguments, strip) + + def get_body_argument( + self, + name: str, + default: Union[None, str, _ArgDefaultMarker] = _ARG_DEFAULT, + strip: bool = True, + ) -> Optional[str]: + """Returns the value of the argument with the given name + from the request body. + + If default is not provided, the argument is considered to be + required, and we raise a `MissingArgumentError` if it is missing. + + If the argument appears in the request body more than once, we return the + last value. + + .. versionadded:: 3.2 + """ + return self._get_argument(name, default, self.request.body_arguments, strip) + + def get_body_arguments(self, name: str, strip: bool = True) -> List[str]: + """Returns a list of the body arguments with the given name. + + If the argument is not present, returns an empty list. + + .. versionadded:: 3.2 + """ + return self._get_arguments(name, self.request.body_arguments, strip) + + def get_query_argument( + self, + name: str, + default: Union[None, str, _ArgDefaultMarker] = _ARG_DEFAULT, + strip: bool = True, + ) -> Optional[str]: + """Returns the value of the argument with the given name + from the request query string. + + If default is not provided, the argument is considered to be + required, and we raise a `MissingArgumentError` if it is missing.
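A minimal usage sketch of the argument accessors (handler and parameter names are hypothetical)::

    import tornado.web

    class SearchHandler(tornado.web.RequestHandler):
        def get(self):
            q = self.get_argument("q")             # required; MissingArgumentError (HTTP 400) if absent
            page = self.get_argument("page", "1")  # optional, falls back to the default
            tags = self.get_arguments("tag")       # every value, possibly an empty list
            self.write({"q": q, "page": int(page), "tags": tags})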
+ + If the argument appears in the url more than once, we return the + last value. + + .. versionadded:: 3.2 + """ + return self._get_argument(name, default, self.request.query_arguments, strip) + + def get_query_arguments(self, name: str, strip: bool = True) -> List[str]: + """Returns a list of the query arguments with the given name. + + If the argument is not present, returns an empty list. + + .. versionadded:: 3.2 + """ + return self._get_arguments(name, self.request.query_arguments, strip) + + def _get_argument( + self, + name: str, + default: Union[None, str, _ArgDefaultMarker], + source: Dict[str, List[bytes]], + strip: bool = True, + ) -> Optional[str]: + args = self._get_arguments(name, source, strip=strip) + if not args: + if isinstance(default, _ArgDefaultMarker): + raise MissingArgumentError(name) + return default + return args[-1] + + def _get_arguments( + self, name: str, source: Dict[str, List[bytes]], strip: bool = True + ) -> List[str]: + values = [] + for v in source.get(name, []): + s = self.decode_argument(v, name=name) + if isinstance(s, unicode_type): + # Get rid of any weird control chars (unless decoding gave + # us bytes, in which case leave it alone) + s = RequestHandler._remove_control_chars_regex.sub(" ", s) + if strip: + s = s.strip() + values.append(s) + return values + + def decode_argument(self, value: bytes, name: Optional[str] = None) -> str: + """Decodes an argument from the request. + + The argument has been percent-decoded and is now a byte string. + By default, this method decodes the argument as utf-8 and returns + a unicode string, but this may be overridden in subclasses. + + This method is used as a filter both for `get_argument()` and for + values extracted from the url and passed to `get()`/`post()`/etc. + + The name of the argument is provided if known, but may be None + (e.g. for unnamed groups in the url regex). + """ + try: + return _unicode(value) + except UnicodeDecodeError: + raise HTTPError( + 400, "Invalid unicode in %s: %r" % (name or "url", value[:40]) + ) + + @property + def cookies(self) -> Dict[str, http.cookies.Morsel]: + """An alias for + `self.request.cookies <.httputil.HTTPServerRequest.cookies>`.""" + return self.request.cookies + + def get_cookie(self, name: str, default: Optional[str] = None) -> Optional[str]: + """Returns the value of the request cookie with the given name. + + If the named cookie is not present, returns ``default``. + + This method only returns cookies that were present in the request. + It does not see the outgoing cookies set by `set_cookie` in this + handler. + """ + if self.request.cookies is not None and name in self.request.cookies: + return self.request.cookies[name].value + return default + + def set_cookie( + self, + name: str, + value: Union[str, bytes], + domain: Optional[str] = None, + expires: Optional[Union[float, Tuple, datetime.datetime]] = None, + path: str = "/", + expires_days: Optional[float] = None, + # Keyword-only args start here for historical reasons. + *, + max_age: Optional[int] = None, + httponly: bool = False, + secure: bool = False, + samesite: Optional[str] = None, + **kwargs: Any, + ) -> None: + """Sets an outgoing cookie name/value with the given options. + + Newly-set cookies are not immediately visible via `get_cookie`; + they are not present until the next request. + + Most arguments are passed directly to `http.cookies.Morsel`. + See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie + for more information.
+ + ``expires`` may be a numeric timestamp as returned by `time.time`, + a time tuple as returned by `time.gmtime`, or a + `datetime.datetime` object. ``expires_days`` is provided as a convenience + to set an expiration time in days from today (if both are set, ``expires`` + is used). + + .. deprecated:: 6.3 + Keyword arguments are currently accepted case-insensitively. + In Tornado 7.0 this will be changed to only accept lowercase + arguments. + """ + # The cookie library only accepts type str, in both python 2 and 3 + name = escape.native_str(name) + value = escape.native_str(value) + if re.search(r"[\x00-\x20]", name + value): + # Don't let us accidentally inject bad stuff + raise ValueError("Invalid cookie %r: %r" % (name, value)) + if not hasattr(self, "_new_cookie"): + self._new_cookie = ( + http.cookies.SimpleCookie() + ) # type: http.cookies.SimpleCookie + if name in self._new_cookie: + del self._new_cookie[name] + self._new_cookie[name] = value + morsel = self._new_cookie[name] + if domain: + morsel["domain"] = domain + if expires_days is not None and not expires: + expires = datetime.datetime.utcnow() + datetime.timedelta(days=expires_days) + if expires: + morsel["expires"] = httputil.format_timestamp(expires) + if path: + morsel["path"] = path + if max_age: + # Note change from _ to -. + morsel["max-age"] = str(max_age) + if httponly: + # Note that SimpleCookie ignores the value here. The presence of an + # httponly (or secure) key is treated as true. + morsel["httponly"] = True + if secure: + morsel["secure"] = True + if samesite: + morsel["samesite"] = samesite + if kwargs: + # The setitem interface is case-insensitive, so continue to support + # kwargs for backwards compatibility until we can remove deprecated + # features. + for k, v in kwargs.items(): + morsel[k] = v + warnings.warn( + f"Deprecated arguments to set_cookie: {set(kwargs.keys())} " + "(should be lowercase)", + DeprecationWarning, + ) + + def clear_cookie(self, name: str, **kwargs: Any) -> None: + """Deletes the cookie with the given name. + + This method accepts the same arguments as `set_cookie`, except for + ``expires`` and ``max_age``. Clearing a cookie requires the same + ``domain`` and ``path`` arguments as when it was set. In some cases the + ``samesite`` and ``secure`` arguments are also required to match. Other + arguments are ignored. + + Similar to `set_cookie`, the effect of this method will not be + seen until the following request. + + .. versionchanged:: 6.3 + + Now accepts all keyword arguments that ``set_cookie`` does. + The ``samesite`` and ``secure`` flags have recently become + required for clearing ``samesite="none"`` cookies. + """ + for excluded_arg in ["expires", "max_age"]: + if excluded_arg in kwargs: + raise TypeError( + f"clear_cookie() got an unexpected keyword argument '{excluded_arg}'" + ) + expires = datetime.datetime.utcnow() - datetime.timedelta(days=365) + self.set_cookie(name, value="", expires=expires, **kwargs) + + def clear_all_cookies(self, **kwargs: Any) -> None: + """Attempt to delete all the cookies the user sent with this request. + + See `clear_cookie` for more information on keyword arguments. Due to + limitations of the cookie protocol, it is impossible to determine on the + server side which values are necessary for the ``domain``, ``path``, + ``samesite``, or ``secure`` arguments; this method can only be + successful if you consistently use the same values for these arguments + when setting cookies.
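A minimal round-trip sketch of ``set_cookie``/``get_cookie``/``clear_cookie`` (the names and options are illustrative; remember a newly set cookie is only visible on the next request)::

    import tornado.web

    class VisitHandler(tornado.web.RequestHandler):
        def get(self):
            if self.get_cookie("seen") is None:
                self.set_cookie("seen", "1", expires_days=7, httponly=True, samesite="Lax")
                self.write("first visit")
            else:
                # Must use the same path/domain as when the cookie was set.
                self.clear_cookie("seen")
                self.write("welcome back")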
+ + Similar to `set_cookie`, the effect of this method will not be seen + until the following request. + + .. versionchanged:: 3.2 + + Added the ``path`` and ``domain`` parameters. + + .. versionchanged:: 6.3 + + Now accepts all keyword arguments that ``set_cookie`` does. + + .. deprecated:: 6.3 + + The increasingly complex rules governing cookies have made it + impossible for a ``clear_all_cookies`` method to work reliably + since all we know about cookies are their names. Applications + should generally use ``clear_cookie`` one at a time instead. + """ + for name in self.request.cookies: + self.clear_cookie(name, **kwargs) + + def set_signed_cookie( + self, + name: str, + value: Union[str, bytes], + expires_days: Optional[float] = 30, + version: Optional[int] = None, + **kwargs: Any, + ) -> None: + """Signs and timestamps a cookie so it cannot be forged. + + You must specify the ``cookie_secret`` setting in your Application + to use this method. It should be a long, random sequence of bytes + to be used as the HMAC secret for the signature. + + To read a cookie set with this method, use `get_signed_cookie()`. + + Note that the ``expires_days`` parameter sets the lifetime of the + cookie in the browser, but is independent of the ``max_age_days`` + parameter to `get_signed_cookie`. + A value of None limits the lifetime to the current browser session. + + Secure cookies may contain arbitrary byte values, not just unicode + strings (unlike regular cookies) + + Similar to `set_cookie`, the effect of this method will not be + seen until the following request. + + .. versionchanged:: 3.2.1 + + Added the ``version`` argument. Introduced cookie version 2 + and made it the default. + + .. versionchanged:: 6.3 + + Renamed from ``set_secure_cookie`` to ``set_signed_cookie`` to + avoid confusion with other uses of "secure" in cookie attributes + and prefixes. The old name remains as an alias. + """ + self.set_cookie( + name, + self.create_signed_value(name, value, version=version), + expires_days=expires_days, + **kwargs, + ) + + set_secure_cookie = set_signed_cookie + + def create_signed_value( + self, name: str, value: Union[str, bytes], version: Optional[int] = None + ) -> bytes: + """Signs and timestamps a string so it cannot be forged. + + Normally used via set_signed_cookie, but provided as a separate + method for non-cookie uses. To decode a value not stored + as a cookie use the optional value argument to get_signed_cookie. + + .. versionchanged:: 3.2.1 + + Added the ``version`` argument. Introduced cookie version 2 + and made it the default. + """ + self.require_setting("cookie_secret", "secure cookies") + secret = self.application.settings["cookie_secret"] + key_version = None + if isinstance(secret, dict): + if self.application.settings.get("key_version") is None: + raise Exception("key_version setting must be used for secret_key dicts") + key_version = self.application.settings["key_version"] + + return create_signed_value( + secret, name, value, version=version, key_version=key_version + ) + + def get_signed_cookie( + self, + name: str, + value: Optional[str] = None, + max_age_days: float = 31, + min_version: Optional[int] = None, + ) -> Optional[bytes]: + """Returns the given signed cookie if it validates, or None. + + The decoded cookie value is returned as a byte string (unlike + `get_cookie`). + + Similar to `get_cookie`, this method only returns cookies that + were present in the request. It does not see outgoing cookies set by + `set_signed_cookie` in this handler. + + .. 
versionchanged:: 3.2.1 + + Added the ``min_version`` argument. Introduced cookie version 2; + both versions 1 and 2 are accepted by default. + + .. versionchanged:: 6.3 + + Renamed from ``get_secure_cookie`` to ``get_signed_cookie`` to + avoid confusion with other uses of "secure" in cookie attributes + and prefixes. The old name remains as an alias. + + """ + self.require_setting("cookie_secret", "secure cookies") + if value is None: + value = self.get_cookie(name) + return decode_signed_value( + self.application.settings["cookie_secret"], + name, + value, + max_age_days=max_age_days, + min_version=min_version, + ) + + get_secure_cookie = get_signed_cookie + + def get_signed_cookie_key_version( + self, name: str, value: Optional[str] = None + ) -> Optional[int]: + """Returns the signing key version of the secure cookie. + + The version is returned as int. + + .. versionchanged:: 6.3 + + Renamed from ``get_secure_cookie_key_version`` to + ``get_signed_cookie_key_version`` to avoid confusion with other + uses of "secure" in cookie attributes and prefixes. The old name + remains as an alias. + + """ + self.require_setting("cookie_secret", "secure cookies") + if value is None: + value = self.get_cookie(name) + if value is None: + return None + return get_signature_key_version(value) + + get_secure_cookie_key_version = get_signed_cookie_key_version + + def redirect( + self, url: str, permanent: bool = False, status: Optional[int] = None + ) -> None: + """Sends a redirect to the given (optionally relative) URL. + + If the ``status`` argument is specified, that value is used as the + HTTP status code; otherwise either 301 (permanent) or 302 + (temporary) is chosen based on the ``permanent`` argument. + The default is 302 (temporary). + """ + if self._headers_written: + raise Exception("Cannot redirect after headers have been written") + if status is None: + status = 301 if permanent else 302 + else: + assert isinstance(status, int) and 300 <= status <= 399 + self.set_status(status) + self.set_header("Location", utf8(url)) + self.finish() + + def write(self, chunk: Union[str, bytes, dict]) -> None: + """Writes the given chunk to the output buffer. + + To write the output to the network, use the `flush()` method below. + + If the given chunk is a dictionary, we write it as JSON and set + the Content-Type of the response to be ``application/json``. + (If you want to send JSON as a different ``Content-Type``, call + ``set_header`` *after* calling ``write()``). + + Note that lists are not converted to JSON because of a potential + cross-site security vulnerability. All JSON output should be + wrapped in a dictionary. More details at + http://haacked.com/archive/2009/06/25/json-hijacking.aspx/ and + https://github.com/facebook/tornado/issues/1009 + """ + if self._finished: + raise RuntimeError("Cannot write() after finish()") + if not isinstance(chunk, (bytes, unicode_type, dict)): + message = "write() only accepts bytes, unicode, and dict objects" + if isinstance(chunk, list): + message += ( + ". Lists not accepted for security reasons; see " + + "http://www.tornadoweb.org/en/stable/web.html#tornado.web.RequestHandler.write" # noqa: E501 + ) + raise TypeError(message) + if isinstance(chunk, dict): + chunk = escape.json_encode(chunk) + self.set_header("Content-Type", "application/json; charset=UTF-8") + chunk = utf8(chunk) + self._write_buffer.append(chunk) + + def render(self, template_name: str, **kwargs: Any) -> "Future[None]": + """Renders the template with the given arguments as the response.
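A minimal sketch of ``write`` with a dict, which is serialized as JSON and sent with an ``application/json`` Content-Type (the payload shape is illustrative)::

    import tornado.web

    class StatusHandler(tornado.web.RequestHandler):
        def get(self):
            # Bare lists are rejected for security reasons; wrap them in a dict.
            self.write({"status": "ok", "items": [1, 2, 3]})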
+ + ``render()`` calls ``finish()``, so no other output methods can be called + after it. + + Returns a `.Future` with the same semantics as the one returned by `finish`. + Awaiting this `.Future` is optional. + + .. versionchanged:: 5.1 + + Now returns a `.Future` instead of ``None``. + """ + if self._finished: + raise RuntimeError("Cannot render() after finish()") + html = self.render_string(template_name, **kwargs) + + # Insert the additional JS and CSS added by the modules on the page + js_embed = [] + js_files = [] + css_embed = [] + css_files = [] + html_heads = [] + html_bodies = [] + for module in getattr(self, "_active_modules", {}).values(): + embed_part = module.embedded_javascript() + if embed_part: + js_embed.append(utf8(embed_part)) + file_part = module.javascript_files() + if file_part: + if isinstance(file_part, (unicode_type, bytes)): + js_files.append(_unicode(file_part)) + else: + js_files.extend(file_part) + embed_part = module.embedded_css() + if embed_part: + css_embed.append(utf8(embed_part)) + file_part = module.css_files() + if file_part: + if isinstance(file_part, (unicode_type, bytes)): + css_files.append(_unicode(file_part)) + else: + css_files.extend(file_part) + head_part = module.html_head() + if head_part: + html_heads.append(utf8(head_part)) + body_part = module.html_body() + if body_part: + html_bodies.append(utf8(body_part)) + + if js_files: + # Maintain order of JavaScript files given by modules + js = self.render_linked_js(js_files) + sloc = html.rindex(b"</body>") + html = html[:sloc] + utf8(js) + b"\n" + html[sloc:] + if js_embed: + js_bytes = self.render_embed_js(js_embed) + sloc = html.rindex(b"</body>") + html = html[:sloc] + js_bytes + b"\n" + html[sloc:] + if css_files: + css = self.render_linked_css(css_files) + hloc = html.index(b"</head>") + html = html[:hloc] + utf8(css) + b"\n" + html[hloc:] + if css_embed: + css_bytes = self.render_embed_css(css_embed) + hloc = html.index(b"</head>") + html = html[:hloc] + css_bytes + b"\n" + html[hloc:] + if html_heads: + hloc = html.index(b"</head>") + html = html[:hloc] + b"".join(html_heads) + b"\n" + html[hloc:] + if html_bodies: + hloc = html.index(b"</body>") + html = html[:hloc] + b"".join(html_bodies) + b"\n" + html[hloc:] + return self.finish(html) + + def render_linked_js(self, js_files: Iterable[str]) -> str: + """Default method used to render the final js links for the + rendered webpage. + + Override this method in a sub-classed controller to change the output. + """ + paths = [] + unique_paths = set() # type: Set[str] + + for path in js_files: + if not is_absolute(path): + path = self.static_url(path) + if path not in unique_paths: + paths.append(path) + unique_paths.add(path) + + return "".join( + '<script src="' + + escape.xhtml_escape(p) + + '" type="text/javascript"></script>' + for p in paths + ) + + def render_embed_js(self, js_embed: Iterable[bytes]) -> bytes: + """Default method used to render the final embedded js for the + rendered webpage. + + Override this method in a sub-classed controller to change the output. + """ + return ( + b'<script type="text/javascript">\n//<![CDATA[\n' + + b"\n".join(js_embed) + + b"\n//]]>\n</script>" + ) + + def render_linked_css(self, css_files: Iterable[str]) -> str: + """Default method used to render the final css links for the + rendered webpage. + + Override this method in a sub-classed controller to change the output. 
+ """ + paths = [] + unique_paths = set() # type: Set[str] + + for path in css_files: + if not is_absolute(path): + path = self.static_url(path) + if path not in unique_paths: + paths.append(path) + unique_paths.add(path) + + return "".join( + '<link href="' + escape.xhtml_escape(p) + '" ' + 'type="text/css" rel="stylesheet"/>' + for p in paths + ) + + def render_embed_css(self, css_embed: Iterable[bytes]) -> bytes: + """Default method used to render the final embedded css for the + rendered webpage. + + Override this method in a sub-classed controller to change the output. + """ + return b'<style type="text/css">\n' + b"\n".join(css_embed) + b"\n</style>" + + def render_string(self, template_name: str, **kwargs: Any) -> bytes: + """Generate the given template with the given arguments. + + We return the generated byte string (in utf8). To generate and + write a template as a response, use render() above. + """ + # If no template_path is specified, use the path of the calling file + template_path = self.get_template_path() + if not template_path: + frame = sys._getframe(0) + web_file = frame.f_code.co_filename + while frame.f_code.co_filename == web_file and frame.f_back is not None: + frame = frame.f_back + assert frame.f_code.co_filename is not None + template_path = os.path.dirname(frame.f_code.co_filename) + with RequestHandler._template_loader_lock: + if template_path not in RequestHandler._template_loaders: + loader = self.create_template_loader(template_path) + RequestHandler._template_loaders[template_path] = loader + else: + loader = RequestHandler._template_loaders[template_path] + t = loader.load(template_name) + namespace = self.get_template_namespace() + namespace.update(kwargs) + return t.generate(**namespace) + + def get_template_namespace(self) -> Dict[str, Any]: + """Returns a dictionary to be used as the default template namespace. + + May be overridden by subclasses to add or modify values. + + The results of this method will be combined with additional + defaults in the `tornado.template` module and keyword arguments + to `render` or `render_string`. + """ + namespace = dict( + handler=self, + request=self.request, + current_user=self.current_user, + locale=self.locale, + _=self.locale.translate, + pgettext=self.locale.pgettext, + static_url=self.static_url, + xsrf_form_html=self.xsrf_form_html, + reverse_url=self.reverse_url, + ) + namespace.update(self.ui) + return namespace + + def create_template_loader(self, template_path: str) -> template.BaseLoader: + """Returns a new template loader for the given path. + + May be overridden by subclasses. By default returns a + directory-based loader on the given path, using the + ``autoescape`` and ``template_whitespace`` application + settings. If a ``template_loader`` application setting is + supplied, uses that instead. + """ + settings = self.application.settings + if "template_loader" in settings: + return settings["template_loader"] + kwargs = {} + if "autoescape" in settings: + # autoescape=None means "no escaping", so we have to be sure + # to only pass this kwarg if the user asked for it. + kwargs["autoescape"] = settings["autoescape"] + if "template_whitespace" in settings: + kwargs["whitespace"] = settings["template_whitespace"] + return template.Loader(template_path, **kwargs) + + def flush(self, include_footers: bool = False) -> "Future[None]": + """Flushes the current output buffer to the network. + + .. versionchanged:: 4.0 + Now returns a `.Future` if no callback is given. + + .. 
versionchanged:: 6.0 + + The ``callback`` argument was removed. + """ + assert self.request.connection is not None + chunk = b"".join(self._write_buffer) + self._write_buffer = [] + if not self._headers_written: + self._headers_written = True + for transform in self._transforms: + assert chunk is not None + ( + self._status_code, + self._headers, + chunk, + ) = transform.transform_first_chunk( + self._status_code, self._headers, chunk, include_footers + ) + # Ignore the chunk and only write the headers for HEAD requests + if self.request.method == "HEAD": + chunk = b"" + + # Finalize the cookie headers (which have been stored in a side + # object so an outgoing cookie could be overwritten before it + # is sent). + if hasattr(self, "_new_cookie"): + for cookie in self._new_cookie.values(): + self.add_header("Set-Cookie", cookie.OutputString(None)) + + start_line = httputil.ResponseStartLine("", self._status_code, self._reason) + return self.request.connection.write_headers( + start_line, self._headers, chunk + ) + else: + for transform in self._transforms: + chunk = transform.transform_chunk(chunk, include_footers) + # Ignore the chunk and only write the headers for HEAD requests + if self.request.method != "HEAD": + return self.request.connection.write(chunk) + else: + future = Future() # type: Future[None] + future.set_result(None) + return future + + def finish(self, chunk: Optional[Union[str, bytes, dict]] = None) -> "Future[None]": + """Finishes this response, ending the HTTP request. + + Passing a ``chunk`` to ``finish()`` is equivalent to passing that + chunk to ``write()`` and then calling ``finish()`` with no arguments. + + Returns a `.Future` which may optionally be awaited to track the sending + of the response to the client. This `.Future` resolves when all the response + data has been sent, and raises an error if the connection is closed before all + data can be sent. + + .. versionchanged:: 5.1 + + Now returns a `.Future` instead of ``None``. + """ + if self._finished: + raise RuntimeError("finish() called twice") + + if chunk is not None: + self.write(chunk) + + # Automatically support ETags and add the Content-Length header if + # we have not flushed any content yet. + if not self._headers_written: + if ( + self._status_code == 200 + and self.request.method in ("GET", "HEAD") + and "Etag" not in self._headers + ): + self.set_etag_header() + if self.check_etag_header(): + self._write_buffer = [] + self.set_status(304) + if self._status_code in (204, 304) or (100 <= self._status_code < 200): + assert not self._write_buffer, ( + "Cannot send body with %s" % self._status_code + ) + self._clear_representation_headers() + elif "Content-Length" not in self._headers: + content_length = sum(len(part) for part in self._write_buffer) + self.set_header("Content-Length", content_length) + + assert self.request.connection is not None + # Now that the request is finished, clear the callback we + # set on the HTTPConnection (which would otherwise prevent the + # garbage collection of the RequestHandler when there + # are keepalive connections) + self.request.connection.set_close_callback(None) # type: ignore + + future = self.flush(include_footers=True) + self.request.connection.finish() + self._log() + self._finished = True + self.on_finish() + self._break_cycles() + return future + + def detach(self) -> iostream.IOStream: + """Take control of the underlying stream. + + Returns the underlying `.IOStream` object and stops all + further HTTP processing. 
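A minimal sketch of incremental output with ``flush`` and an awaited ``finish`` (the chunking is illustrative)::

    import tornado.web

    class StreamHandler(tornado.web.RequestHandler):
        async def get(self):
            self.set_header("Content-Type", "text/plain")
            for i in range(3):
                self.write("chunk %d\n" % i)
                await self.flush()  # push the buffer to the network now
            await self.finish()     # optional await; resolves once fully sent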
Intended for implementing protocols + like websockets that tunnel over an HTTP handshake. + + This method is only supported when HTTP/1.1 is used. + + .. versionadded:: 5.1 + """ + self._finished = True + # TODO: add detach to HTTPConnection? + return self.request.connection.detach() # type: ignore + + def _break_cycles(self) -> None: + # Break up a reference cycle between this handler and the + # _ui_module closures to allow for faster GC on CPython. + self.ui = None # type: ignore + + def send_error(self, status_code: int = 500, **kwargs: Any) -> None: + """Sends the given HTTP error code to the browser. + + If `flush()` has already been called, it is not possible to send + an error, so this method will simply terminate the response. + If output has been written but not yet flushed, it will be discarded + and replaced with the error page. + + Override `write_error()` to customize the error page that is returned. + Additional keyword arguments are passed through to `write_error`. + """ + if self._headers_written: + gen_log.error("Cannot send error response after headers written") + if not self._finished: + # If we get an error between writing headers and finishing, + # we are unlikely to be able to finish due to a + # Content-Length mismatch. Try anyway to release the + # socket. + try: + self.finish() + except Exception: + gen_log.error("Failed to flush partial response", exc_info=True) + return + self.clear() + + reason = kwargs.get("reason") + if "exc_info" in kwargs: + exception = kwargs["exc_info"][1] + if isinstance(exception, HTTPError) and exception.reason: + reason = exception.reason + self.set_status(status_code, reason=reason) + try: + self.write_error(status_code, **kwargs) + except Exception: + app_log.error("Uncaught exception in write_error", exc_info=True) + if not self._finished: + self.finish() + + def write_error(self, status_code: int, **kwargs: Any) -> None: + """Override to implement custom error pages. + + ``write_error`` may call `write`, `render`, `set_header`, etc. + to produce output as usual. + + If this error was caused by an uncaught exception (including + HTTPError), an ``exc_info`` triple will be available as + ``kwargs["exc_info"]``. Note that this exception may not be + the "current" exception for purposes of methods like + ``sys.exc_info()`` or ``traceback.format_exc``. + """ + if self.settings.get("serve_traceback") and "exc_info" in kwargs: + # in debug mode, try to send a traceback + self.set_header("Content-Type", "text/plain") + for line in traceback.format_exception(*kwargs["exc_info"]): + self.write(line) + self.finish() + else: + self.finish( + "<html><title>%(code)d: %(message)s</title>" + "<body>%(code)d: %(message)s</body></html>" + % {"code": status_code, "message": self._reason} + ) + + @property + def locale(self) -> tornado.locale.Locale: + """The locale for the current session. + + Determined by either `get_user_locale`, which you can override to + set the locale based on, e.g., a user preference stored in a + database, or `get_browser_locale`, which uses the ``Accept-Language`` + header. + + .. versionchanged:: 4.1 + Added a property setter.
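A minimal sketch of a JSON error page via ``write_error`` (the field names are illustrative)::

    import tornado.web

    class APIHandler(tornado.web.RequestHandler):
        def write_error(self, status_code, **kwargs):
            # send_error() has already called clear() and set_status().
            self.set_header("Content-Type", "application/json")
            self.finish({"code": status_code, "reason": self._reason})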
+ """ + if not hasattr(self, "_locale"): + loc = self.get_user_locale() + if loc is not None: + self._locale = loc + else: + self._locale = self.get_browser_locale() + assert self._locale + return self._locale + + @locale.setter + def locale(self, value: tornado.locale.Locale) -> None: + self._locale = value + + def get_user_locale(self) -> Optional[tornado.locale.Locale]: + """Override to determine the locale from the authenticated user. + + If None is returned, we fall back to `get_browser_locale()`. + + This method should return a `tornado.locale.Locale` object, + most likely obtained via a call like ``tornado.locale.get("en")`` + """ + return None + + def get_browser_locale(self, default: str = "en_US") -> tornado.locale.Locale: + """Determines the user's locale from ``Accept-Language`` header. + + See http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.4 + """ + if "Accept-Language" in self.request.headers: + languages = self.request.headers["Accept-Language"].split(",") + locales = [] + for language in languages: + parts = language.strip().split(";") + if len(parts) > 1 and parts[1].strip().startswith("q="): + try: + score = float(parts[1].strip()[2:]) + if score < 0: + raise ValueError() + except (ValueError, TypeError): + score = 0.0 + else: + score = 1.0 + if score > 0: + locales.append((parts[0], score)) + if locales: + locales.sort(key=lambda pair: pair[1], reverse=True) + codes = [loc[0] for loc in locales] + return locale.get(*codes) + return locale.get(default) + + @property + def current_user(self) -> Any: + """The authenticated user for this request. + + This is set in one of two ways: + + * A subclass may override `get_current_user()`, which will be called + automatically the first time ``self.current_user`` is accessed. + `get_current_user()` will only be called once per request, + and is cached for future access:: + + def get_current_user(self): + user_cookie = self.get_signed_cookie("user") + if user_cookie: + return json.loads(user_cookie) + return None + + * It may be set as a normal variable, typically from an overridden + `prepare()`:: + + @gen.coroutine + def prepare(self): + user_id_cookie = self.get_signed_cookie("user_id") + if user_id_cookie: + self.current_user = yield load_user(user_id_cookie) + + Note that `prepare()` may be a coroutine while `get_current_user()` + may not, so the latter form is necessary if loading the user requires + asynchronous operations. + + The user object may be any type of the application's choosing. + """ + if not hasattr(self, "_current_user"): + self._current_user = self.get_current_user() + return self._current_user + + @current_user.setter + def current_user(self, value: Any) -> None: + self._current_user = value + + def get_current_user(self) -> Any: + """Override to determine the current user from, e.g., a cookie. + + This method may not be a coroutine. + """ + return None + + def get_login_url(self) -> str: + """Override to customize the login URL based on the request. + + By default, we use the ``login_url`` application setting. + """ + self.require_setting("login_url", "@tornado.web.authenticated") + return self.application.settings["login_url"] + + def get_template_path(self) -> Optional[str]: + """Override to customize template path for each handler. + + By default, we use the ``template_path`` application setting. + Return None to load templates relative to the calling file. 
+ """ + return self.application.settings.get("template_path") + + @property + def xsrf_token(self) -> bytes: + """The XSRF-prevention token for the current user/session. + + To prevent cross-site request forgery, we set an '_xsrf' cookie + and include the same '_xsrf' value as an argument with all POST + requests. If the two do not match, we reject the form submission + as a potential forgery. + + See http://en.wikipedia.org/wiki/Cross-site_request_forgery + + This property is of type `bytes`, but it contains only ASCII + characters. If a character string is required, there is no + need to base64-encode it; just decode the byte string as + UTF-8. + + .. versionchanged:: 3.2.2 + The xsrf token will now be have a random mask applied in every + request, which makes it safe to include the token in pages + that are compressed. See http://breachattack.com for more + information on the issue fixed by this change. Old (version 1) + cookies will be converted to version 2 when this method is called + unless the ``xsrf_cookie_version`` `Application` setting is + set to 1. + + .. versionchanged:: 4.3 + The ``xsrf_cookie_kwargs`` `Application` setting may be + used to supply additional cookie options (which will be + passed directly to `set_cookie`). For example, + ``xsrf_cookie_kwargs=dict(httponly=True, secure=True)`` + will set the ``secure`` and ``httponly`` flags on the + ``_xsrf`` cookie. + """ + if not hasattr(self, "_xsrf_token"): + version, token, timestamp = self._get_raw_xsrf_token() + output_version = self.settings.get("xsrf_cookie_version", 2) + cookie_kwargs = self.settings.get("xsrf_cookie_kwargs", {}) + if output_version == 1: + self._xsrf_token = binascii.b2a_hex(token) + elif output_version == 2: + mask = os.urandom(4) + self._xsrf_token = b"|".join( + [ + b"2", + binascii.b2a_hex(mask), + binascii.b2a_hex(_websocket_mask(mask, token)), + utf8(str(int(timestamp))), + ] + ) + else: + raise ValueError("unknown xsrf cookie version %d", output_version) + if version is None: + if self.current_user and "expires_days" not in cookie_kwargs: + cookie_kwargs["expires_days"] = 30 + cookie_name = self.settings.get("xsrf_cookie_name", "_xsrf") + self.set_cookie(cookie_name, self._xsrf_token, **cookie_kwargs) + return self._xsrf_token + + def _get_raw_xsrf_token(self) -> Tuple[Optional[int], bytes, float]: + """Read or generate the xsrf token in its raw form. + + The raw_xsrf_token is a tuple containing: + + * version: the version of the cookie from which this token was read, + or None if we generated a new token in this request. + * token: the raw token data; random (non-ascii) bytes. + * timestamp: the time this token was generated (will not be accurate + for version 1 cookies) + """ + if not hasattr(self, "_raw_xsrf_token"): + cookie_name = self.settings.get("xsrf_cookie_name", "_xsrf") + cookie = self.get_cookie(cookie_name) + if cookie: + version, token, timestamp = self._decode_xsrf_token(cookie) + else: + version, token, timestamp = None, None, None + if token is None: + version = None + token = os.urandom(16) + timestamp = time.time() + assert token is not None + assert timestamp is not None + self._raw_xsrf_token = (version, token, timestamp) + return self._raw_xsrf_token + + def _decode_xsrf_token( + self, cookie: str + ) -> Tuple[Optional[int], Optional[bytes], Optional[float]]: + """Convert a cookie string into a the tuple form returned by + _get_raw_xsrf_token. 
+ """ + + try: + m = _signed_value_version_re.match(utf8(cookie)) + + if m: + version = int(m.group(1)) + if version == 2: + _, mask_str, masked_token, timestamp_str = cookie.split("|") + + mask = binascii.a2b_hex(utf8(mask_str)) + token = _websocket_mask(mask, binascii.a2b_hex(utf8(masked_token))) + timestamp = int(timestamp_str) + return version, token, timestamp + else: + # Treat unknown versions as not present instead of failing. + raise Exception("Unknown xsrf cookie version") + else: + version = 1 + try: + token = binascii.a2b_hex(utf8(cookie)) + except (binascii.Error, TypeError): + token = utf8(cookie) + # We don't have a usable timestamp in older versions. + timestamp = int(time.time()) + return (version, token, timestamp) + except Exception: + # Catch exceptions and return nothing instead of failing. + gen_log.debug("Uncaught exception in _decode_xsrf_token", exc_info=True) + return None, None, None + + def check_xsrf_cookie(self) -> None: + """Verifies that the ``_xsrf`` cookie matches the ``_xsrf`` argument. + + To prevent cross-site request forgery, we set an ``_xsrf`` + cookie and include the same value as a non-cookie + field with all ``POST`` requests. If the two do not match, we + reject the form submission as a potential forgery. + + The ``_xsrf`` value may be set as either a form field named ``_xsrf`` + or in a custom HTTP header named ``X-XSRFToken`` or ``X-CSRFToken`` + (the latter is accepted for compatibility with Django). + + See http://en.wikipedia.org/wiki/Cross-site_request_forgery + + .. versionchanged:: 3.2.2 + Added support for cookie version 2. Both versions 1 and 2 are + supported. + """ + # Prior to release 1.1.1, this check was ignored if the HTTP header + # ``X-Requested-With: XMLHTTPRequest`` was present. This exception + # has been shown to be insecure and has been removed. For more + # information please see + # http://www.djangoproject.com/weblog/2011/feb/08/security/ + # http://weblog.rubyonrails.org/2011/2/8/csrf-protection-bypass-in-ruby-on-rails + token = ( + self.get_argument("_xsrf", None) + or self.request.headers.get("X-Xsrftoken") + or self.request.headers.get("X-Csrftoken") + ) + if not token: + raise HTTPError(403, "'_xsrf' argument missing from POST") + _, token, _ = self._decode_xsrf_token(token) + _, expected_token, _ = self._get_raw_xsrf_token() + if not token: + raise HTTPError(403, "'_xsrf' argument has invalid format") + if not hmac.compare_digest(utf8(token), utf8(expected_token)): + raise HTTPError(403, "XSRF cookie does not match POST argument") + + def xsrf_form_html(self) -> str: + """An HTML ``<input/>`` element to be included with all POST forms. + + It defines the ``_xsrf`` input value, which we check on all POST + requests to prevent cross-site request forgery. If you have set + the ``xsrf_cookies`` application setting, you must include this + HTML within all of your HTML forms. + + In a template, this method should be called with ``{% module + xsrf_form_html() %}`` + + See `check_xsrf_cookie()` above for more information. + """ + return ( + '<input type="hidden" name="_xsrf" value="' + + escape.xhtml_escape(self.xsrf_token) + + '"/>' + ) + + def static_url( + self, path: str, include_host: Optional[bool] = None, **kwargs: Any + ) -> str: + """Returns a static URL for the given relative static file path. + + This method requires you set the ``static_path`` setting in your + application (which specifies the root directory of your static + files). 
+ + This method returns a versioned url (by default appending + ``?v=<signature>``), which allows the static files to be + cached indefinitely. This can be disabled by passing + ``include_version=False`` (in the default implementation; + other static file implementations are not required to support + this, but they may support other options). + + By default this method returns URLs relative to the current + host, but if ``include_host`` is true the URL returned will be + absolute. If this handler has an ``include_host`` attribute, + that value will be used as the default for all `static_url` + calls that do not pass ``include_host`` as a keyword argument. + + """ + self.require_setting("static_path", "static_url") + get_url = self.settings.get( + "static_handler_class", StaticFileHandler + ).make_static_url + + if include_host is None: + include_host = getattr(self, "include_host", False) + + if include_host: + base = self.request.protocol + "://" + self.request.host + else: + base = "" + + return base + get_url(self.settings, path, **kwargs) + + def require_setting(self, name: str, feature: str = "this feature") -> None: + """Raises an exception if the given app setting is not defined.""" + if not self.application.settings.get(name): + raise Exception( + "You must define the '%s' setting in your " + "application to use %s" % (name, feature) + ) + + def reverse_url(self, name: str, *args: Any) -> str: + """Alias for `Application.reverse_url`.""" + return self.application.reverse_url(name, *args) + + def compute_etag(self) -> Optional[str]: + """Computes the etag header to be used for this request. + + By default uses a hash of the content written so far. + + May be overridden to provide custom etag implementations, + or may return None to disable tornado's default etag support. + """ + hasher = hashlib.sha1() + for part in self._write_buffer: + hasher.update(part) + return '"%s"' % hasher.hexdigest() + + def set_etag_header(self) -> None: + """Sets the response's Etag header using ``self.compute_etag()``. + + Note: no header will be set if ``compute_etag()`` returns ``None``. + + This method is called automatically when the request is finished. + """ + etag = self.compute_etag() + if etag is not None: + self.set_header("Etag", etag) + + def check_etag_header(self) -> bool: + """Checks the ``Etag`` header against requests's ``If-None-Match``. + + Returns ``True`` if the request's Etag matches and a 304 should be + returned. For example:: + + self.set_etag_header() + if self.check_etag_header(): + self.set_status(304) + return + + This method is called automatically when the request is finished, + but may be called earlier for applications that override + `compute_etag` and want to do an early check for ``If-None-Match`` + before completing the request. The ``Etag`` header should be set + (perhaps with `set_etag_header`) before calling this method. + """ + computed_etag = utf8(self._headers.get("Etag", "")) + # Find all weak and strong etag values from If-None-Match header + # because RFC 7232 allows multiple etag values in a single header. + etags = re.findall( + rb'\*|(?:W/)?"[^"]*"', utf8(self.request.headers.get("If-None-Match", "")) + ) + if not computed_etag or not etags: + return False + + match = False + if etags[0] == b"*": + match = True + else: + # Use a weak comparison when comparing entity-tags. 
+ def val(x: bytes) -> bytes: + return x[2:] if x.startswith(b"W/") else x + + for etag in etags: + if val(etag) == val(computed_etag): + match = True + break + return match + + async def _execute( + self, transforms: List["OutputTransform"], *args: bytes, **kwargs: bytes + ) -> None: + """Executes this request with the given output transforms.""" + self._transforms = transforms + try: + if self.request.method not in self.SUPPORTED_METHODS: + raise HTTPError(405) + self.path_args = [self.decode_argument(arg) for arg in args] + self.path_kwargs = dict( + (k, self.decode_argument(v, name=k)) for (k, v) in kwargs.items() + ) + # If XSRF cookies are turned on, reject form submissions without + # the proper cookie + if self.request.method not in ( + "GET", + "HEAD", + "OPTIONS", + ) and self.application.settings.get("xsrf_cookies"): + self.check_xsrf_cookie() + + result = self.prepare() + if result is not None: + result = await result # type: ignore + if self._prepared_future is not None: + # Tell the Application we've finished with prepare() + # and are ready for the body to arrive. + future_set_result_unless_cancelled(self._prepared_future, None) + if self._finished: + return + + if _has_stream_request_body(self.__class__): + # In streaming mode request.body is a Future that signals + # the body has been completely received. The Future has no + # result; the data has been passed to self.data_received + # instead. + try: + await self.request._body_future + except iostream.StreamClosedError: + return + + method = getattr(self, self.request.method.lower()) + result = method(*self.path_args, **self.path_kwargs) + if result is not None: + result = await result + if self._auto_finish and not self._finished: + self.finish() + except Exception as e: + try: + self._handle_request_exception(e) + except Exception: + app_log.error("Exception in exception handler", exc_info=True) + finally: + # Unset result to avoid circular references + result = None + if self._prepared_future is not None and not self._prepared_future.done(): + # In case we failed before setting _prepared_future, do it + # now (to unblock the HTTP server). Note that this is not + # in a finally block to avoid GC issues prior to Python 3.4. + self._prepared_future.set_result(None) + + def data_received(self, chunk: bytes) -> Optional[Awaitable[None]]: + """Implement this method to handle streamed request data. + + Requires the `.stream_request_body` decorator. + + May be a coroutine for flow control. + """ + raise NotImplementedError() + + def _log(self) -> None: + """Logs the current request. + + Sort of deprecated since this functionality was moved to the + Application, but left in place for the benefit of existing apps + that have overridden this method. + """ + self.application.log_request(self) + + def _request_summary(self) -> str: + return "%s %s (%s)" % ( + self.request.method, + self.request.uri, + self.request.remote_ip, + ) + + def _handle_request_exception(self, e: BaseException) -> None: + if isinstance(e, Finish): + # Not an error; just finish the request without logging. + if not self._finished: + self.finish(*e.args) + return + try: + self.log_exception(*sys.exc_info()) + except Exception: + # An error here should still get a best-effort send_error() + # to avoid leaking the connection. + app_log.error("Error in exception logger", exc_info=True) + if self._finished: + # Extra errors after the request has been finished should + # be logged, but there is no reason to continue to try and + # send a response. 
+ return + if isinstance(e, HTTPError): + self.send_error(e.status_code, exc_info=sys.exc_info()) + else: + self.send_error(500, exc_info=sys.exc_info()) + + def log_exception( + self, + typ: "Optional[Type[BaseException]]", + value: Optional[BaseException], + tb: Optional[TracebackType], + ) -> None: + """Override to customize logging of uncaught exceptions. + + By default logs instances of `HTTPError` as warnings without + stack traces (on the ``tornado.general`` logger), and all + other exceptions as errors with stack traces (on the + ``tornado.application`` logger). + + .. versionadded:: 3.1 + """ + if isinstance(value, HTTPError): + if value.log_message: + format = "%d %s: " + value.log_message + args = [value.status_code, self._request_summary()] + list(value.args) + gen_log.warning(format, *args) + else: + app_log.error( + "Uncaught exception %s\n%r", + self._request_summary(), + self.request, + exc_info=(typ, value, tb), # type: ignore + ) + + def _ui_module(self, name: str, module: Type["UIModule"]) -> Callable[..., str]: + def render(*args, **kwargs) -> str: # type: ignore + if not hasattr(self, "_active_modules"): + self._active_modules = {} # type: Dict[str, UIModule] + if name not in self._active_modules: + self._active_modules[name] = module(self) + rendered = self._active_modules[name].render(*args, **kwargs) + return rendered + + return render + + def _ui_method(self, method: Callable[..., str]) -> Callable[..., str]: + return lambda *args, **kwargs: method(self, *args, **kwargs) + + def _clear_representation_headers(self) -> None: + # 304 responses should not contain representation metadata + # headers (defined in + # https://tools.ietf.org/html/rfc7231#section-3.1) + # not explicitly allowed by + # https://tools.ietf.org/html/rfc7232#section-4.1 + headers = ["Content-Encoding", "Content-Language", "Content-Type"] + for h in headers: + self.clear_header(h) + + +_RequestHandlerType = TypeVar("_RequestHandlerType", bound=RequestHandler) + + +def stream_request_body(cls: Type[_RequestHandlerType]) -> Type[_RequestHandlerType]: + """Apply to `RequestHandler` subclasses to enable streaming body support. + + This decorator implies the following changes: + + * `.HTTPServerRequest.body` is undefined, and body arguments will not + be included in `RequestHandler.get_argument`. + * `RequestHandler.prepare` is called when the request headers have been + read instead of after the entire body has been read. + * The subclass must define a method ``data_received(self, data):``, which + will be called zero or more times as data is available. Note that + if the request has an empty body, ``data_received`` may not be called. + * ``prepare`` and ``data_received`` may return Futures (such as via + ``@gen.coroutine``), in which case the next method will not be called + until those futures have completed. + * The regular HTTP method (``post``, ``put``, etc.) will be called after + the entire body has been read. + + See the `file receiver demo <https://github.com/tornadoweb/tornado/tree/stable/demos/file_upload/>`_ + for example usage.
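A compact sketch of a streaming upload handler built on this decorator (the route and byte accounting are illustrative)::

    import tornado.web

    @tornado.web.stream_request_body
    class UploadHandler(tornado.web.RequestHandler):
        def prepare(self):
            self.bytes_read = 0

        def data_received(self, chunk):
            # Called repeatedly as body data arrives.
            self.bytes_read += len(chunk)

        def put(self):
            # Runs after the entire body has been received.
            self.write("uploaded %d bytes" % self.bytes_read)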
+ """ # noqa: E501 + if not issubclass(cls, RequestHandler): + raise TypeError("expected subclass of RequestHandler, got %r", cls) + cls._stream_request_body = True + return cls + + +def _has_stream_request_body(cls: Type[RequestHandler]) -> bool: + if not issubclass(cls, RequestHandler): + raise TypeError("expected subclass of RequestHandler, got %r", cls) + return cls._stream_request_body + + +def removeslash( + method: Callable[..., Optional[Awaitable[None]]] +) -> Callable[..., Optional[Awaitable[None]]]: + """Use this decorator to remove trailing slashes from the request path. + + For example, a request to ``/foo/`` would redirect to ``/foo`` with this + decorator. Your request handler mapping should use a regular expression + like ``r'/foo/*'`` in conjunction with using the decorator. + """ + + @functools.wraps(method) + def wrapper( # type: ignore + self: RequestHandler, *args, **kwargs + ) -> Optional[Awaitable[None]]: + if self.request.path.endswith("/"): + if self.request.method in ("GET", "HEAD"): + uri = self.request.path.rstrip("/") + if uri: # don't try to redirect '/' to '' + if self.request.query: + uri += "?" + self.request.query + self.redirect(uri, permanent=True) + return None + else: + raise HTTPError(404) + return method(self, *args, **kwargs) + + return wrapper + + +def addslash( + method: Callable[..., Optional[Awaitable[None]]] +) -> Callable[..., Optional[Awaitable[None]]]: + """Use this decorator to add a missing trailing slash to the request path. + + For example, a request to ``/foo`` would redirect to ``/foo/`` with this + decorator. Your request handler mapping should use a regular expression + like ``r'/foo/?'`` in conjunction with using the decorator. + """ + + @functools.wraps(method) + def wrapper( # type: ignore + self: RequestHandler, *args, **kwargs + ) -> Optional[Awaitable[None]]: + if not self.request.path.endswith("/"): + if self.request.method in ("GET", "HEAD"): + uri = self.request.path + "/" + if self.request.query: + uri += "?" + self.request.query + self.redirect(uri, permanent=True) + return None + raise HTTPError(404) + return method(self, *args, **kwargs) + + return wrapper + + +class _ApplicationRouter(ReversibleRuleRouter): + """Routing implementation used internally by `Application`. + + Provides a binding between `Application` and `RequestHandler`. + This implementation extends `~.routing.ReversibleRuleRouter` in a couple of ways: + * it allows to use `RequestHandler` subclasses as `~.routing.Rule` target and + * it allows to use a list/tuple of rules as `~.routing.Rule` target. + ``process_rule`` implementation will substitute this list with an appropriate + `_ApplicationRouter` instance. 
+ """ + + def __init__( + self, application: "Application", rules: Optional[_RuleList] = None + ) -> None: + assert isinstance(application, Application) + self.application = application + super().__init__(rules) + + def process_rule(self, rule: Rule) -> Rule: + rule = super().process_rule(rule) + + if isinstance(rule.target, (list, tuple)): + rule.target = _ApplicationRouter( + self.application, rule.target # type: ignore + ) + + return rule + + def get_target_delegate( + self, target: Any, request: httputil.HTTPServerRequest, **target_params: Any + ) -> Optional[httputil.HTTPMessageDelegate]: + if isclass(target) and issubclass(target, RequestHandler): + return self.application.get_handler_delegate( + request, target, **target_params + ) + + return super().get_target_delegate(target, request, **target_params) + + +class Application(ReversibleRouter): + r"""A collection of request handlers that make up a web application. + + Instances of this class are callable and can be passed directly to + HTTPServer to serve the application:: + + application = web.Application([ + (r"/", MainPageHandler), + ]) + http_server = httpserver.HTTPServer(application) + http_server.listen(8080) + + The constructor for this class takes in a list of `~.routing.Rule` + objects or tuples of values corresponding to the arguments of + `~.routing.Rule` constructor: ``(matcher, target, [target_kwargs], [name])``, + the values in square brackets being optional. The default matcher is + `~.routing.PathMatches`, so ``(regexp, target)`` tuples can also be used + instead of ``(PathMatches(regexp), target)``. + + A common routing target is a `RequestHandler` subclass, but you can also + use lists of rules as a target, which create a nested routing configuration:: + + application = web.Application([ + (HostMatches("example.com"), [ + (r"/", MainPageHandler), + (r"/feed", FeedHandler), + ]), + ]) + + In addition to this you can use nested `~.routing.Router` instances, + `~.httputil.HTTPMessageDelegate` subclasses and callables as routing targets + (see `~.routing` module docs for more information). + + When we receive requests, we iterate over the list in order and + instantiate an instance of the first request class whose regexp + matches the request path. The request class can be specified as + either a class object or a (fully-qualified) name. + + A dictionary may be passed as the third element (``target_kwargs``) + of the tuple, which will be used as keyword arguments to the handler's + constructor and `~RequestHandler.initialize` method. This pattern + is used for the `StaticFileHandler` in this example (note that a + `StaticFileHandler` can be installed automatically with the + static_path setting described below):: + + application = web.Application([ + (r"/static/(.*)", web.StaticFileHandler, {"path": "/var/www"}), + ]) + + We support virtual hosts with the `add_handlers` method, which takes in + a host regular expression as the first argument:: + + application.add_handlers(r"www\.myhost\.com", [ + (r"/article/([0-9]+)", ArticleHandler), + ]) + + If there's no match for the current request's host, then ``default_host`` + parameter value is matched against host regular expressions. + + + .. warning:: + + Applications that do not use TLS may be vulnerable to :ref:`DNS + rebinding <dnsrebinding>` attacks. This attack is especially + relevant to applications that only listen on ``127.0.0.1`` or + other private networks. Appropriate host patterns must be used + (instead of the default of ``r'.*'``) to prevent this risk. 
The + ``default_host`` argument must not be used in applications that + may be vulnerable to DNS rebinding. + + You can serve static files by sending the ``static_path`` setting + as a keyword argument. We will serve those files from the + ``/static/`` URI (this is configurable with the + ``static_url_prefix`` setting), and we will serve ``/favicon.ico`` + and ``/robots.txt`` from the same directory. A custom subclass of + `StaticFileHandler` can be specified with the + ``static_handler_class`` setting. + + .. versionchanged:: 4.5 + Integration with the new `tornado.routing` module. + + """ + + def __init__( + self, + handlers: Optional[_RuleList] = None, + default_host: Optional[str] = None, + transforms: Optional[List[Type["OutputTransform"]]] = None, + **settings: Any, + ) -> None: + if transforms is None: + self.transforms = [] # type: List[Type[OutputTransform]] + if settings.get("compress_response") or settings.get("gzip"): + self.transforms.append(GZipContentEncoding) + else: + self.transforms = transforms + self.default_host = default_host + self.settings = settings + self.ui_modules = { + "linkify": _linkify, + "xsrf_form_html": _xsrf_form_html, + "Template": TemplateModule, + } + self.ui_methods = {} # type: Dict[str, Callable[..., str]] + self._load_ui_modules(settings.get("ui_modules", {})) + self._load_ui_methods(settings.get("ui_methods", {})) + if self.settings.get("static_path"): + path = self.settings["static_path"] + handlers = list(handlers or []) + static_url_prefix = settings.get("static_url_prefix", "/static/") + static_handler_class = settings.get( + "static_handler_class", StaticFileHandler + ) + static_handler_args = settings.get("static_handler_args", {}) + static_handler_args["path"] = path + for pattern in [ + re.escape(static_url_prefix) + r"(.*)", + r"/(favicon\.ico)", + r"/(robots\.txt)", + ]: + handlers.insert(0, (pattern, static_handler_class, static_handler_args)) + + if self.settings.get("debug"): + self.settings.setdefault("autoreload", True) + self.settings.setdefault("compiled_template_cache", False) + self.settings.setdefault("static_hash_cache", False) + self.settings.setdefault("serve_traceback", True) + + self.wildcard_router = _ApplicationRouter(self, handlers) + self.default_router = _ApplicationRouter( + self, [Rule(AnyMatches(), self.wildcard_router)] + ) + + # Automatically reload modified modules + if self.settings.get("autoreload"): + from tornado import autoreload + + autoreload.start() + + def listen( + self, + port: int, + address: Optional[str] = None, + *, + family: socket.AddressFamily = socket.AF_UNSPEC, + backlog: int = tornado.netutil._DEFAULT_BACKLOG, + flags: Optional[int] = None, + reuse_port: bool = False, + **kwargs: Any, + ) -> HTTPServer: + """Starts an HTTP server for this application on the given port. + + This is a convenience alias for creating an `.HTTPServer` object and + calling its listen method. Keyword arguments not supported by + `HTTPServer.listen <.TCPServer.listen>` are passed to the `.HTTPServer` + constructor. For advanced uses (e.g. multi-process mode), do not use + this method; create an `.HTTPServer` and call its + `.TCPServer.bind`/`.TCPServer.start` methods directly. + + Note that after calling this method you still need to call + ``IOLoop.current().start()`` (or run within ``asyncio.run``) to start + the server. + + Returns the `.HTTPServer` object. + + .. versionchanged:: 4.3 + Now returns the `.HTTPServer` object. + + .. 
versionchanged:: 6.2 + Added support for new keyword arguments in `.TCPServer.listen`, + including ``reuse_port``. + """ + server = HTTPServer(self, **kwargs) + server.listen( + port, + address=address, + family=family, + backlog=backlog, + flags=flags, + reuse_port=reuse_port, + ) + return server + + def add_handlers(self, host_pattern: str, host_handlers: _RuleList) -> None: + """Appends the given handlers to our handler list. + + Host patterns are processed sequentially in the order they were + added. All matching patterns will be considered. + """ + host_matcher = HostMatches(host_pattern) + rule = Rule(host_matcher, _ApplicationRouter(self, host_handlers)) + + self.default_router.rules.insert(-1, rule) + + if self.default_host is not None: + self.wildcard_router.add_rules( + [(DefaultHostMatches(self, host_matcher.host_pattern), host_handlers)] + ) + + def add_transform(self, transform_class: Type["OutputTransform"]) -> None: + self.transforms.append(transform_class) + + def _load_ui_methods(self, methods: Any) -> None: + if isinstance(methods, types.ModuleType): + self._load_ui_methods(dict((n, getattr(methods, n)) for n in dir(methods))) + elif isinstance(methods, list): + for m in methods: + self._load_ui_methods(m) + else: + for name, fn in methods.items(): + if ( + not name.startswith("_") + and hasattr(fn, "__call__") + and name[0].lower() == name[0] + ): + self.ui_methods[name] = fn + + def _load_ui_modules(self, modules: Any) -> None: + if isinstance(modules, types.ModuleType): + self._load_ui_modules(dict((n, getattr(modules, n)) for n in dir(modules))) + elif isinstance(modules, list): + for m in modules: + self._load_ui_modules(m) + else: + assert isinstance(modules, dict) + for name, cls in modules.items(): + try: + if issubclass(cls, UIModule): + self.ui_modules[name] = cls + except TypeError: + pass + + def __call__( + self, request: httputil.HTTPServerRequest + ) -> Optional[Awaitable[None]]: + # Legacy HTTPServer interface + dispatcher = self.find_handler(request) + return dispatcher.execute() + + def find_handler( + self, request: httputil.HTTPServerRequest, **kwargs: Any + ) -> "_HandlerDelegate": + route = self.default_router.find_handler(request) + if route is not None: + return cast("_HandlerDelegate", route) + + if self.settings.get("default_handler_class"): + return self.get_handler_delegate( + request, + self.settings["default_handler_class"], + self.settings.get("default_handler_args", {}), + ) + + return self.get_handler_delegate(request, ErrorHandler, {"status_code": 404}) + + def get_handler_delegate( + self, + request: httputil.HTTPServerRequest, + target_class: Type[RequestHandler], + target_kwargs: Optional[Dict[str, Any]] = None, + path_args: Optional[List[bytes]] = None, + path_kwargs: Optional[Dict[str, bytes]] = None, + ) -> "_HandlerDelegate": + """Returns `~.httputil.HTTPMessageDelegate` that can serve a request + for application and `RequestHandler` subclass. + + :arg httputil.HTTPServerRequest request: current HTTP request. + :arg RequestHandler target_class: a `RequestHandler` class. + :arg dict target_kwargs: keyword arguments for ``target_class`` constructor. + :arg list path_args: positional arguments for ``target_class`` HTTP method that + will be executed while handling a request (``get``, ``post`` or any other). + :arg dict path_kwargs: keyword arguments for ``target_class`` HTTP method. 
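+
+        A minimal sketch of direct use (handler name and arguments are
+        hypothetical)::
+
+            delegate = application.get_handler_delegate(
+                request, ArticleHandler, path_args=[b"42"]
+            )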
+ """ + return _HandlerDelegate( + self, request, target_class, target_kwargs, path_args, path_kwargs + ) + + def reverse_url(self, name: str, *args: Any) -> str: + """Returns a URL path for handler named ``name`` + + The handler must be added to the application as a named `URLSpec`. + + Args will be substituted for capturing groups in the `URLSpec` regex. + They will be converted to strings if necessary, encoded as utf8, + and url-escaped. + """ + reversed_url = self.default_router.reverse_url(name, *args) + if reversed_url is not None: + return reversed_url + + raise KeyError("%s not found in named urls" % name) + + def log_request(self, handler: RequestHandler) -> None: + """Writes a completed HTTP request to the logs. + + By default writes to the python root logger. To change + this behavior either subclass Application and override this method, + or pass a function in the application settings dictionary as + ``log_function``. + """ + if "log_function" in self.settings: + self.settings["log_function"](handler) + return + if handler.get_status() < 400: + log_method = access_log.info + elif handler.get_status() < 500: + log_method = access_log.warning + else: + log_method = access_log.error + request_time = 1000.0 * handler.request.request_time() + log_method( + "%d %s %.2fms", + handler.get_status(), + handler._request_summary(), + request_time, + ) + + +class _HandlerDelegate(httputil.HTTPMessageDelegate): + def __init__( + self, + application: Application, + request: httputil.HTTPServerRequest, + handler_class: Type[RequestHandler], + handler_kwargs: Optional[Dict[str, Any]], + path_args: Optional[List[bytes]], + path_kwargs: Optional[Dict[str, bytes]], + ) -> None: + self.application = application + self.connection = request.connection + self.request = request + self.handler_class = handler_class + self.handler_kwargs = handler_kwargs or {} + self.path_args = path_args or [] + self.path_kwargs = path_kwargs or {} + self.chunks = [] # type: List[bytes] + self.stream_request_body = _has_stream_request_body(self.handler_class) + + def headers_received( + self, + start_line: Union[httputil.RequestStartLine, httputil.ResponseStartLine], + headers: httputil.HTTPHeaders, + ) -> Optional[Awaitable[None]]: + if self.stream_request_body: + self.request._body_future = Future() + return self.execute() + return None + + def data_received(self, data: bytes) -> Optional[Awaitable[None]]: + if self.stream_request_body: + return self.handler.data_received(data) + else: + self.chunks.append(data) + return None + + def finish(self) -> None: + if self.stream_request_body: + future_set_result_unless_cancelled(self.request._body_future, None) + else: + self.request.body = b"".join(self.chunks) + self.request._parse_body() + self.execute() + + def on_connection_close(self) -> None: + if self.stream_request_body: + self.handler.on_connection_close() + else: + self.chunks = None # type: ignore + + def execute(self) -> Optional[Awaitable[None]]: + # If template cache is disabled (usually in the debug mode), + # re-compile templates and reload static files on every + # request so you don't need to restart to see changes + if not self.application.settings.get("compiled_template_cache", True): + with RequestHandler._template_loader_lock: + for loader in RequestHandler._template_loaders.values(): + loader.reset() + if not self.application.settings.get("static_hash_cache", True): + static_handler_class = self.application.settings.get( + "static_handler_class", StaticFileHandler + ) + static_handler_class.reset() + 
+ self.handler = self.handler_class( + self.application, self.request, **self.handler_kwargs + ) + transforms = [t(self.request) for t in self.application.transforms] + + if self.stream_request_body: + self.handler._prepared_future = Future() + # Note that if an exception escapes handler._execute it will be + # trapped in the Future it returns (which we are ignoring here, + # leaving it to be logged when the Future is GC'd). + # However, that shouldn't happen because _execute has a blanket + # except handler, and we cannot easily access the IOLoop here to + # call add_future (because of the requirement to remain compatible + # with WSGI) + fut = gen.convert_yielded( + self.handler._execute(transforms, *self.path_args, **self.path_kwargs) + ) + fut.add_done_callback(lambda f: f.result()) + # If we are streaming the request body, then execute() is finished + # when the handler has prepared to receive the body. If not, + # it doesn't matter when execute() finishes (so we return None) + return self.handler._prepared_future + + +class HTTPError(Exception): + """An exception that will turn into an HTTP error response. + + Raising an `HTTPError` is a convenient alternative to calling + `RequestHandler.send_error` since it automatically ends the + current function. + + To customize the response sent with an `HTTPError`, override + `RequestHandler.write_error`. + + :arg int status_code: HTTP status code. Must be listed in + `httplib.responses <http.client.responses>` unless the ``reason`` + keyword argument is given. + :arg str log_message: Message to be written to the log for this error + (will not be shown to the user unless the `Application` is in debug + mode). May contain ``%s``-style placeholders, which will be filled + in with remaining positional parameters. + :arg str reason: Keyword-only argument. The HTTP "reason" phrase + to pass in the status line along with ``status_code``. Normally + determined automatically from ``status_code``, but can be used + to use a non-standard numeric code. + """ + + def __init__( + self, + status_code: int = 500, + log_message: Optional[str] = None, + *args: Any, + **kwargs: Any, + ) -> None: + self.status_code = status_code + self.log_message = log_message + self.args = args + self.reason = kwargs.get("reason", None) + if log_message and not args: + self.log_message = log_message.replace("%", "%%") + + def __str__(self) -> str: + message = "HTTP %d: %s" % ( + self.status_code, + self.reason or httputil.responses.get(self.status_code, "Unknown"), + ) + if self.log_message: + return message + " (" + (self.log_message % self.args) + ")" + else: + return message + + +class Finish(Exception): + """An exception that ends the request without producing an error response. + + When `Finish` is raised in a `RequestHandler`, the request will + end (calling `RequestHandler.finish` if it hasn't already been + called), but the error-handling methods (including + `RequestHandler.write_error`) will not be called. + + If `Finish()` was created with no arguments, the pending response + will be sent as-is. If `Finish()` was given an argument, that + argument will be passed to `RequestHandler.finish()`. + + This can be a more convenient way to implement custom error pages + than overriding ``write_error`` (especially in library code):: + + if self.current_user is None: + self.set_status(401) + self.set_header('WWW-Authenticate', 'Basic realm="something"') + raise Finish() + + .. 
versionchanged:: 4.3 + Arguments passed to ``Finish()`` will be passed on to + `RequestHandler.finish`. + """ + + pass + + +class MissingArgumentError(HTTPError): + """Exception raised by `RequestHandler.get_argument`. + + This is a subclass of `HTTPError`, so if it is uncaught a 400 response + code will be used instead of 500 (and a stack trace will not be logged). + + .. versionadded:: 3.1 + """ + + def __init__(self, arg_name: str) -> None: + super().__init__(400, "Missing argument %s" % arg_name) + self.arg_name = arg_name + + +class ErrorHandler(RequestHandler): + """Generates an error response with ``status_code`` for all requests.""" + + def initialize(self, status_code: int) -> None: + self.set_status(status_code) + + def prepare(self) -> None: + raise HTTPError(self._status_code) + + def check_xsrf_cookie(self) -> None: + # POSTs to an ErrorHandler don't actually have side effects, + # so we don't need to check the xsrf token. This allows POSTs + # to the wrong url to return a 404 instead of 403. + pass + + +class RedirectHandler(RequestHandler): + """Redirects the client to the given URL for all GET requests. + + You should provide the keyword argument ``url`` to the handler, e.g.:: + + application = web.Application([ + (r"/oldpath", web.RedirectHandler, {"url": "/newpath"}), + ]) + + `RedirectHandler` supports regular expression substitutions. E.g., to + swap the first and second parts of a path while preserving the remainder:: + + application = web.Application([ + (r"/(.*?)/(.*?)/(.*)", web.RedirectHandler, {"url": "/{1}/{0}/{2}"}), + ]) + + The final URL is formatted with `str.format` and the substrings that match + the capturing groups. In the above example, a request to "/a/b/c" would be + formatted like:: + + str.format("/{1}/{0}/{2}", "a", "b", "c") # -> "/b/a/c" + + Use Python's :ref:`format string syntax <formatstrings>` to customize how + values are substituted. + + .. versionchanged:: 4.5 + Added support for substitutions into the destination URL. + + .. versionchanged:: 5.0 + If any query arguments are present, they will be copied to the + destination URL. + """ + + def initialize(self, url: str, permanent: bool = True) -> None: + self._url = url + self._permanent = permanent + + def get(self, *args: Any, **kwargs: Any) -> None: + to_url = self._url.format(*args, **kwargs) + if self.request.query_arguments: + # TODO: figure out typing for the next line. + to_url = httputil.url_concat( + to_url, + list(httputil.qs_to_qsl(self.request.query_arguments)), # type: ignore + ) + self.redirect(to_url, permanent=self._permanent) + + +class StaticFileHandler(RequestHandler): + """A simple handler that can serve static content from a directory. + + A `StaticFileHandler` is configured automatically if you pass the + ``static_path`` keyword argument to `Application`. This handler + can be customized with the ``static_url_prefix``, ``static_handler_class``, + and ``static_handler_args`` settings. + + To map an additional path to this handler for a static data directory + you would add a line to your application like:: + + application = web.Application([ + (r"/content/(.*)", web.StaticFileHandler, {"path": "/var/www"}), + ]) + + The handler constructor requires a ``path`` argument, which specifies the + local root directory of the content to be served. + + Note that a capture group in the regex is required to parse the value for + the ``path`` argument to the get() method (different than the constructor + argument above); see `URLSpec` for details. 
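+
+    Alternatively, a `StaticFileHandler` can be configured implicitly through
+    `Application` settings; a minimal sketch (paths and the ``handlers`` list
+    are hypothetical)::
+
+        application = web.Application(
+            handlers,
+            static_path="/var/www/static",
+            static_url_prefix="/assets/",
+        )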
+ + To serve a file like ``index.html`` automatically when a directory is + requested, set ``static_handler_args=dict(default_filename="index.html")`` + in your application settings, or add ``default_filename`` as an initializer + argument for your ``StaticFileHandler``. + + To maximize the effectiveness of browser caching, this class supports + versioned urls (by default using the argument ``?v=``). If a version + is given, we instruct the browser to cache this file indefinitely. + `make_static_url` (also available as `RequestHandler.static_url`) can + be used to construct a versioned url. + + This handler is intended primarily for use in development and light-duty + file serving; for heavy traffic it will be more efficient to use + a dedicated static file server (such as nginx or Apache). We support + the HTTP ``Accept-Ranges`` mechanism to return partial content (because + some browsers require this functionality to be present to seek in + HTML5 audio or video). + + **Subclassing notes** + + This class is designed to be extensible by subclassing, but because + of the way static urls are generated with class methods rather than + instance methods, the inheritance patterns are somewhat unusual. + Be sure to use the ``@classmethod`` decorator when overriding a + class method. Instance methods may use the attributes ``self.path`` + ``self.absolute_path``, and ``self.modified``. + + Subclasses should only override methods discussed in this section; + overriding other methods is error-prone. Overriding + ``StaticFileHandler.get`` is particularly problematic due to the + tight coupling with ``compute_etag`` and other methods. + + To change the way static urls are generated (e.g. to match the behavior + of another server or CDN), override `make_static_url`, `parse_url_path`, + `get_cache_time`, and/or `get_version`. + + To replace all interaction with the filesystem (e.g. to serve + static content from a database), override `get_content`, + `get_content_size`, `get_modified_time`, `get_absolute_path`, and + `validate_absolute_path`. + + .. versionchanged:: 3.1 + Many of the methods for subclasses were added in Tornado 3.1. + """ + + CACHE_MAX_AGE = 86400 * 365 * 10 # 10 years + + _static_hashes = {} # type: Dict[str, Optional[str]] + _lock = threading.Lock() # protects _static_hashes + + def initialize(self, path: str, default_filename: Optional[str] = None) -> None: + self.root = path + self.default_filename = default_filename + + @classmethod + def reset(cls) -> None: + with cls._lock: + cls._static_hashes = {} + + def head(self, path: str) -> Awaitable[None]: + return self.get(path, include_body=False) + + async def get(self, path: str, include_body: bool = True) -> None: + # Set up our path instance variables. + self.path = self.parse_url_path(path) + del path # make sure we don't refer to path instead of self.path again + absolute_path = self.get_absolute_path(self.root, self.path) + self.absolute_path = self.validate_absolute_path(self.root, absolute_path) + if self.absolute_path is None: + return + + self.modified = self.get_modified_time() + self.set_headers() + + if self.should_return_304(): + self.set_status(304) + return + + request_range = None + range_header = self.request.headers.get("Range") + if range_header: + # As per RFC 2616 14.16, if an invalid Range header is specified, + # the request will be treated as if the header didn't exist. 
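+            # For example, "Range: bytes=0-499" parses here to the half-open
+            # pair (0, 500), while an unparseable header yields None.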
+ request_range = httputil._parse_request_range(range_header) + + size = self.get_content_size() + if request_range: + start, end = request_range + if start is not None and start < 0: + start += size + if start < 0: + start = 0 + if ( + start is not None + and (start >= size or (end is not None and start >= end)) + ) or end == 0: + # As per RFC 2616 14.35.1, a range is not satisfiable only: if + # the first requested byte is equal to or greater than the + # content, or when a suffix with length 0 is specified. + # https://tools.ietf.org/html/rfc7233#section-2.1 + # A byte-range-spec is invalid if the last-byte-pos value is present + # and less than the first-byte-pos. + self.set_status(416) # Range Not Satisfiable + self.set_header("Content-Type", "text/plain") + self.set_header("Content-Range", "bytes */%s" % (size,)) + return + if end is not None and end > size: + # Clients sometimes blindly use a large range to limit their + # download size; cap the endpoint at the actual file size. + end = size + # Note: only return HTTP 206 if less than the entire range has been + # requested. Not only is this semantically correct, but Chrome + # refuses to play audio if it gets an HTTP 206 in response to + # ``Range: bytes=0-``. + if size != (end or size) - (start or 0): + self.set_status(206) # Partial Content + self.set_header( + "Content-Range", httputil._get_content_range(start, end, size) + ) + else: + start = end = None + + if start is not None and end is not None: + content_length = end - start + elif end is not None: + content_length = end + elif start is not None: + content_length = size - start + else: + content_length = size + self.set_header("Content-Length", content_length) + + if include_body: + content = self.get_content(self.absolute_path, start, end) + if isinstance(content, bytes): + content = [content] + for chunk in content: + try: + self.write(chunk) + await self.flush() + except iostream.StreamClosedError: + return + else: + assert self.request.method == "HEAD" + + def compute_etag(self) -> Optional[str]: + """Sets the ``Etag`` header based on static url version. + + This allows efficient ``If-None-Match`` checks against cached + versions, and sends the correct ``Etag`` for a partial response + (i.e. the same ``Etag`` as the full file). + + .. versionadded:: 3.1 + """ + assert self.absolute_path is not None + version_hash = self._get_cached_version(self.absolute_path) + if not version_hash: + return None + return '"%s"' % (version_hash,) + + def set_headers(self) -> None: + """Sets the content and caching headers on the response. + + .. versionadded:: 3.1 + """ + self.set_header("Accept-Ranges", "bytes") + self.set_etag_header() + + if self.modified is not None: + self.set_header("Last-Modified", self.modified) + + content_type = self.get_content_type() + if content_type: + self.set_header("Content-Type", content_type) + + cache_time = self.get_cache_time(self.path, self.modified, content_type) + if cache_time > 0: + self.set_header( + "Expires", + datetime.datetime.utcnow() + datetime.timedelta(seconds=cache_time), + ) + self.set_header("Cache-Control", "max-age=" + str(cache_time)) + + self.set_extra_headers(self.path) + + def should_return_304(self) -> bool: + """Returns True if the headers indicate that we should return 304. + + .. 
versionadded:: 3.1 + """ + # If client sent If-None-Match, use it, ignore If-Modified-Since + if self.request.headers.get("If-None-Match"): + return self.check_etag_header() + + # Check the If-Modified-Since, and don't send the result if the + # content has not been modified + ims_value = self.request.headers.get("If-Modified-Since") + if ims_value is not None: + date_tuple = email.utils.parsedate(ims_value) + if date_tuple is not None: + if_since = datetime.datetime(*date_tuple[:6]) + assert self.modified is not None + if if_since >= self.modified: + return True + + return False + + @classmethod + def get_absolute_path(cls, root: str, path: str) -> str: + """Returns the absolute location of ``path`` relative to ``root``. + + ``root`` is the path configured for this `StaticFileHandler` + (in most cases the ``static_path`` `Application` setting). + + This class method may be overridden in subclasses. By default + it returns a filesystem path, but other strings may be used + as long as they are unique and understood by the subclass's + overridden `get_content`. + + .. versionadded:: 3.1 + """ + abspath = os.path.abspath(os.path.join(root, path)) + return abspath + + def validate_absolute_path(self, root: str, absolute_path: str) -> Optional[str]: + """Validate and return the absolute path. + + ``root`` is the configured path for the `StaticFileHandler`, + and ``path`` is the result of `get_absolute_path` + + This is an instance method called during request processing, + so it may raise `HTTPError` or use methods like + `RequestHandler.redirect` (return None after redirecting to + halt further processing). This is where 404 errors for missing files + are generated. + + This method may modify the path before returning it, but note that + any such modifications will not be understood by `make_static_url`. + + In instance methods, this method's result is available as + ``self.absolute_path``. + + .. versionadded:: 3.1 + """ + # os.path.abspath strips a trailing /. + # We must add it back to `root` so that we only match files + # in a directory named `root` instead of files starting with + # that prefix. + root = os.path.abspath(root) + if not root.endswith(os.path.sep): + # abspath always removes a trailing slash, except when + # root is '/'. This is an unusual case, but several projects + # have independently discovered this technique to disable + # Tornado's path validation and (hopefully) do their own, + # so we need to support it. + root += os.path.sep + # The trailing slash also needs to be temporarily added back + # the requested path so a request to root/ will match. 
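+        # For example, with root "/var/www/static/", a resolved path of
+        # "/var/www/static-private/key.pem" fails the prefix check below,
+        # while "/var/www/static/css/site.css" passes.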
+ if not (absolute_path + os.path.sep).startswith(root): + raise HTTPError(403, "%s is not in root static directory", self.path) + if os.path.isdir(absolute_path) and self.default_filename is not None: + # need to look at the request.path here for when path is empty + # but there is some prefix to the path that was already + # trimmed by the routing + if not self.request.path.endswith("/"): + self.redirect(self.request.path + "/", permanent=True) + return None + absolute_path = os.path.join(absolute_path, self.default_filename) + if not os.path.exists(absolute_path): + raise HTTPError(404) + if not os.path.isfile(absolute_path): + raise HTTPError(403, "%s is not a file", self.path) + return absolute_path + + @classmethod + def get_content( + cls, abspath: str, start: Optional[int] = None, end: Optional[int] = None + ) -> Generator[bytes, None, None]: + """Retrieve the content of the requested resource which is located + at the given absolute path. + + This class method may be overridden by subclasses. Note that its + signature is different from other overridable class methods + (no ``settings`` argument); this is deliberate to ensure that + ``abspath`` is able to stand on its own as a cache key. + + This method should either return a byte string or an iterator + of byte strings. The latter is preferred for large files + as it helps reduce memory fragmentation. + + .. versionadded:: 3.1 + """ + with open(abspath, "rb") as file: + if start is not None: + file.seek(start) + if end is not None: + remaining = end - (start or 0) # type: Optional[int] + else: + remaining = None + while True: + chunk_size = 64 * 1024 + if remaining is not None and remaining < chunk_size: + chunk_size = remaining + chunk = file.read(chunk_size) + if chunk: + if remaining is not None: + remaining -= len(chunk) + yield chunk + else: + if remaining is not None: + assert remaining == 0 + return + + @classmethod + def get_content_version(cls, abspath: str) -> str: + """Returns a version string for the resource at the given path. + + This class method may be overridden by subclasses. The + default implementation is a SHA-512 hash of the file's contents. + + .. versionadded:: 3.1 + """ + data = cls.get_content(abspath) + hasher = hashlib.sha512() + if isinstance(data, bytes): + hasher.update(data) + else: + for chunk in data: + hasher.update(chunk) + return hasher.hexdigest() + + def _stat(self) -> os.stat_result: + assert self.absolute_path is not None + if not hasattr(self, "_stat_result"): + self._stat_result = os.stat(self.absolute_path) + return self._stat_result + + def get_content_size(self) -> int: + """Retrieve the total size of the resource at the given path. + + This method may be overridden by subclasses. + + .. versionadded:: 3.1 + + .. versionchanged:: 4.0 + This method is now always called, instead of only when + partial results are requested. + """ + stat_result = self._stat() + return stat_result.st_size + + def get_modified_time(self) -> Optional[datetime.datetime]: + """Returns the time that ``self.absolute_path`` was last modified. + + May be overridden in subclasses. Should return a `~datetime.datetime` + object or None. + + .. versionadded:: 3.1 + """ + stat_result = self._stat() + # NOTE: Historically, this used stat_result[stat.ST_MTIME], + # which truncates the fractional portion of the timestamp. It + # was changed from that form to stat_result.st_mtime to + # satisfy mypy (which disallows the bracket operator), but the + # latter form returns a float instead of an int. 
For + # consistency with the past (and because we have a unit test + # that relies on this), we truncate the float here, although + # I'm not sure that's the right thing to do. + modified = datetime.datetime.utcfromtimestamp(int(stat_result.st_mtime)) + return modified + + def get_content_type(self) -> str: + """Returns the ``Content-Type`` header to be used for this request. + + .. versionadded:: 3.1 + """ + assert self.absolute_path is not None + mime_type, encoding = mimetypes.guess_type(self.absolute_path) + # per RFC 6713, use the appropriate type for a gzip compressed file + if encoding == "gzip": + return "application/gzip" + # As of 2015-07-21 there is no bzip2 encoding defined at + # http://www.iana.org/assignments/media-types/media-types.xhtml + # So for that (and any other encoding), use octet-stream. + elif encoding is not None: + return "application/octet-stream" + elif mime_type is not None: + return mime_type + # if mime_type not detected, use application/octet-stream + else: + return "application/octet-stream" + + def set_extra_headers(self, path: str) -> None: + """For subclass to add extra headers to the response""" + pass + + def get_cache_time( + self, path: str, modified: Optional[datetime.datetime], mime_type: str + ) -> int: + """Override to customize cache control behavior. + + Return a positive number of seconds to make the result + cacheable for that amount of time or 0 to mark resource as + cacheable for an unspecified amount of time (subject to + browser heuristics). + + By default returns cache expiry of 10 years for resources requested + with ``v`` argument. + """ + return self.CACHE_MAX_AGE if "v" in self.request.arguments else 0 + + @classmethod + def make_static_url( + cls, settings: Dict[str, Any], path: str, include_version: bool = True + ) -> str: + """Constructs a versioned url for the given path. + + This method may be overridden in subclasses (but note that it + is a class method rather than an instance method). Subclasses + are only required to implement the signature + ``make_static_url(cls, settings, path)``; other keyword + arguments may be passed through `~RequestHandler.static_url` + but are not standard. + + ``settings`` is the `Application.settings` dictionary. ``path`` + is the static path being requested. The url returned should be + relative to the current host. + + ``include_version`` determines whether the generated URL should + include the query string containing the version hash of the + file corresponding to the given ``path``. + + """ + url = settings.get("static_url_prefix", "/static/") + path + if not include_version: + return url + + version_hash = cls.get_version(settings, path) + if not version_hash: + return url + + return "%s?v=%s" % (url, version_hash) + + def parse_url_path(self, url_path: str) -> str: + """Converts a static URL path into a filesystem path. + + ``url_path`` is the path component of the URL with + ``static_url_prefix`` removed. The return value should be + filesystem path relative to ``static_path``. + + This is the inverse of `make_static_url`. + """ + if os.path.sep != "/": + url_path = url_path.replace("/", os.path.sep) + return url_path + + @classmethod + def get_version(cls, settings: Dict[str, Any], path: str) -> Optional[str]: + """Generate the version string to be used in static URLs. + + ``settings`` is the `Application.settings` dictionary and ``path`` + is the relative location of the requested asset on the filesystem. 
+ The returned value should be a string, or ``None`` if no version + could be determined. + + .. versionchanged:: 3.1 + This method was previously recommended for subclasses to override; + `get_content_version` is now preferred as it allows the base + class to handle caching of the result. + """ + abs_path = cls.get_absolute_path(settings["static_path"], path) + return cls._get_cached_version(abs_path) + + @classmethod + def _get_cached_version(cls, abs_path: str) -> Optional[str]: + with cls._lock: + hashes = cls._static_hashes + if abs_path not in hashes: + try: + hashes[abs_path] = cls.get_content_version(abs_path) + except Exception: + gen_log.error("Could not open static file %r", abs_path) + hashes[abs_path] = None + hsh = hashes.get(abs_path) + if hsh: + return hsh + return None + + +class FallbackHandler(RequestHandler): + """A `RequestHandler` that wraps another HTTP server callback. + + The fallback is a callable object that accepts an + `~.httputil.HTTPServerRequest`, such as an `Application` or + `tornado.wsgi.WSGIContainer`. This is most useful to use both + Tornado ``RequestHandlers`` and WSGI in the same server. Typical + usage:: + + wsgi_app = tornado.wsgi.WSGIContainer( + django.core.handlers.wsgi.WSGIHandler()) + application = tornado.web.Application([ + (r"/foo", FooHandler), + (r".*", FallbackHandler, dict(fallback=wsgi_app), + ]) + """ + + def initialize( + self, fallback: Callable[[httputil.HTTPServerRequest], None] + ) -> None: + self.fallback = fallback + + def prepare(self) -> None: + self.fallback(self.request) + self._finished = True + self.on_finish() + + +class OutputTransform(object): + """A transform modifies the result of an HTTP request (e.g., GZip encoding) + + Applications are not expected to create their own OutputTransforms + or interact with them directly; the framework chooses which transforms + (if any) to apply. + """ + + def __init__(self, request: httputil.HTTPServerRequest) -> None: + pass + + def transform_first_chunk( + self, + status_code: int, + headers: httputil.HTTPHeaders, + chunk: bytes, + finishing: bool, + ) -> Tuple[int, httputil.HTTPHeaders, bytes]: + return status_code, headers, chunk + + def transform_chunk(self, chunk: bytes, finishing: bool) -> bytes: + return chunk + + +class GZipContentEncoding(OutputTransform): + """Applies the gzip content encoding to the response. + + See http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.11 + + .. versionchanged:: 4.0 + Now compresses all mime types beginning with ``text/``, instead + of just a whitelist. (the whitelist is still used for certain + non-text mime types). + """ + + # Whitelist of compressible mime types (in addition to any types + # beginning with "text/"). + CONTENT_TYPES = set( + [ + "application/javascript", + "application/x-javascript", + "application/xml", + "application/atom+xml", + "application/json", + "application/xhtml+xml", + "image/svg+xml", + ] + ) + # Python's GzipFile defaults to level 9, while most other gzip + # tools (including gzip itself) default to 6, which is probably a + # better CPU/size tradeoff. + GZIP_LEVEL = 6 + # Responses that are too short are unlikely to benefit from gzipping + # after considering the "Content-Encoding: gzip" header and the header + # inside the gzip encoding. + # Note that responses written in multiple chunks will be compressed + # regardless of size. 
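+    # (1024 is a heuristic threshold: below roughly that size, the ~20 bytes
+    # of gzip framing plus the added headers can outweigh any savings.)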
+ MIN_LENGTH = 1024 + + def __init__(self, request: httputil.HTTPServerRequest) -> None: + self._gzipping = "gzip" in request.headers.get("Accept-Encoding", "") + + def _compressible_type(self, ctype: str) -> bool: + return ctype.startswith("text/") or ctype in self.CONTENT_TYPES + + def transform_first_chunk( + self, + status_code: int, + headers: httputil.HTTPHeaders, + chunk: bytes, + finishing: bool, + ) -> Tuple[int, httputil.HTTPHeaders, bytes]: + # TODO: can/should this type be inherited from the superclass? + if "Vary" in headers: + headers["Vary"] += ", Accept-Encoding" + else: + headers["Vary"] = "Accept-Encoding" + if self._gzipping: + ctype = _unicode(headers.get("Content-Type", "")).split(";")[0] + self._gzipping = ( + self._compressible_type(ctype) + and (not finishing or len(chunk) >= self.MIN_LENGTH) + and ("Content-Encoding" not in headers) + ) + if self._gzipping: + headers["Content-Encoding"] = "gzip" + self._gzip_value = BytesIO() + self._gzip_file = gzip.GzipFile( + mode="w", fileobj=self._gzip_value, compresslevel=self.GZIP_LEVEL + ) + chunk = self.transform_chunk(chunk, finishing) + if "Content-Length" in headers: + # The original content length is no longer correct. + # If this is the last (and only) chunk, we can set the new + # content-length; otherwise we remove it and fall back to + # chunked encoding. + if finishing: + headers["Content-Length"] = str(len(chunk)) + else: + del headers["Content-Length"] + return status_code, headers, chunk + + def transform_chunk(self, chunk: bytes, finishing: bool) -> bytes: + if self._gzipping: + self._gzip_file.write(chunk) + if finishing: + self._gzip_file.close() + else: + self._gzip_file.flush() + chunk = self._gzip_value.getvalue() + self._gzip_value.truncate(0) + self._gzip_value.seek(0) + return chunk + + +def authenticated( + method: Callable[..., Optional[Awaitable[None]]] +) -> Callable[..., Optional[Awaitable[None]]]: + """Decorate methods with this to require that the user be logged in. + + If the user is not logged in, they will be redirected to the configured + `login url <RequestHandler.get_login_url>`. + + If you configure a login url with a query parameter, Tornado will + assume you know what you're doing and use it as-is. If not, it + will add a `next` parameter so the login page knows where to send + you once you're logged in. + """ + + @functools.wraps(method) + def wrapper( # type: ignore + self: RequestHandler, *args, **kwargs + ) -> Optional[Awaitable[None]]: + if not self.current_user: + if self.request.method in ("GET", "HEAD"): + url = self.get_login_url() + if "?" not in url: + if urllib.parse.urlsplit(url).scheme: + # if login url is absolute, make next absolute too + next_url = self.request.full_url() + else: + assert self.request.uri is not None + next_url = self.request.uri + url += "?" + urlencode(dict(next=next_url)) + self.redirect(url) + return None + raise HTTPError(403) + return method(self, *args, **kwargs) + + return wrapper + + +class UIModule(object): + """A re-usable, modular UI unit on a page. + + UI modules often execute additional queries, and they can include + additional CSS and JavaScript that will be included in the output + page, which is automatically inserted on page render. + + Subclasses of UIModule must override the `render` method. 
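+
+    A minimal sketch of a subclass and its template invocation (names are
+    hypothetical)::
+
+        class Entry(UIModule):
+            def render(self, entry):
+                return self.render_string("modules/entry.html", entry=entry)
+
+    used in a template as ``{% module Entry(entry) %}``.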
+ """ + + def __init__(self, handler: RequestHandler) -> None: + self.handler = handler + self.request = handler.request + self.ui = handler.ui + self.locale = handler.locale + + @property + def current_user(self) -> Any: + return self.handler.current_user + + def render(self, *args: Any, **kwargs: Any) -> str: + """Override in subclasses to return this module's output.""" + raise NotImplementedError() + + def embedded_javascript(self) -> Optional[str]: + """Override to return a JavaScript string + to be embedded in the page.""" + return None + + def javascript_files(self) -> Optional[Iterable[str]]: + """Override to return a list of JavaScript files needed by this module. + + If the return values are relative paths, they will be passed to + `RequestHandler.static_url`; otherwise they will be used as-is. + """ + return None + + def embedded_css(self) -> Optional[str]: + """Override to return a CSS string + that will be embedded in the page.""" + return None + + def css_files(self) -> Optional[Iterable[str]]: + """Override to returns a list of CSS files required by this module. + + If the return values are relative paths, they will be passed to + `RequestHandler.static_url`; otherwise they will be used as-is. + """ + return None + + def html_head(self) -> Optional[str]: + """Override to return an HTML string that will be put in the <head/> + element. + """ + return None + + def html_body(self) -> Optional[str]: + """Override to return an HTML string that will be put at the end of + the <body/> element. + """ + return None + + def render_string(self, path: str, **kwargs: Any) -> bytes: + """Renders a template and returns it as a string.""" + return self.handler.render_string(path, **kwargs) + + +class _linkify(UIModule): + def render(self, text: str, **kwargs: Any) -> str: # type: ignore + return escape.linkify(text, **kwargs) + + +class _xsrf_form_html(UIModule): + def render(self) -> str: # type: ignore + return self.handler.xsrf_form_html() + + +class TemplateModule(UIModule): + """UIModule that simply renders the given template. + + {% module Template("foo.html") %} is similar to {% include "foo.html" %}, + but the module version gets its own namespace (with kwargs passed to + Template()) instead of inheriting the outer template's namespace. + + Templates rendered through this module also get access to UIModule's + automatic JavaScript/CSS features. Simply call set_resources + inside the template and give it keyword arguments corresponding to + the methods on UIModule: {{ set_resources(js_files=static_url("my.js")) }} + Note that these resources are output once per template file, not once + per instantiation of the template, so they must not depend on + any arguments to the template. 
+ """ + + def __init__(self, handler: RequestHandler) -> None: + super().__init__(handler) + # keep resources in both a list and a dict to preserve order + self._resource_list = [] # type: List[Dict[str, Any]] + self._resource_dict = {} # type: Dict[str, Dict[str, Any]] + + def render(self, path: str, **kwargs: Any) -> bytes: # type: ignore + def set_resources(**kwargs) -> str: # type: ignore + if path not in self._resource_dict: + self._resource_list.append(kwargs) + self._resource_dict[path] = kwargs + else: + if self._resource_dict[path] != kwargs: + raise ValueError( + "set_resources called with different " + "resources for the same template" + ) + return "" + + return self.render_string(path, set_resources=set_resources, **kwargs) + + def _get_resources(self, key: str) -> Iterable[str]: + return (r[key] for r in self._resource_list if key in r) + + def embedded_javascript(self) -> str: + return "\n".join(self._get_resources("embedded_javascript")) + + def javascript_files(self) -> Iterable[str]: + result = [] + for f in self._get_resources("javascript_files"): + if isinstance(f, (unicode_type, bytes)): + result.append(f) + else: + result.extend(f) + return result + + def embedded_css(self) -> str: + return "\n".join(self._get_resources("embedded_css")) + + def css_files(self) -> Iterable[str]: + result = [] + for f in self._get_resources("css_files"): + if isinstance(f, (unicode_type, bytes)): + result.append(f) + else: + result.extend(f) + return result + + def html_head(self) -> str: + return "".join(self._get_resources("html_head")) + + def html_body(self) -> str: + return "".join(self._get_resources("html_body")) + + +class _UIModuleNamespace(object): + """Lazy namespace which creates UIModule proxies bound to a handler.""" + + def __init__( + self, handler: RequestHandler, ui_modules: Dict[str, Type[UIModule]] + ) -> None: + self.handler = handler + self.ui_modules = ui_modules + + def __getitem__(self, key: str) -> Callable[..., str]: + return self.handler._ui_module(key, self.ui_modules[key]) + + def __getattr__(self, key: str) -> Callable[..., str]: + try: + return self[key] + except KeyError as e: + raise AttributeError(str(e)) + + +def create_signed_value( + secret: _CookieSecretTypes, + name: str, + value: Union[str, bytes], + version: Optional[int] = None, + clock: Optional[Callable[[], float]] = None, + key_version: Optional[int] = None, +) -> bytes: + if version is None: + version = DEFAULT_SIGNED_VALUE_VERSION + if clock is None: + clock = time.time + + timestamp = utf8(str(int(clock()))) + value = base64.b64encode(utf8(value)) + if version == 1: + assert not isinstance(secret, dict) + signature = _create_signature_v1(secret, name, value, timestamp) + value = b"|".join([value, timestamp, signature]) + return value + elif version == 2: + # The v2 format consists of a version number and a series of + # length-prefixed fields "%d:%s", the last of which is a + # signature, all separated by pipes. All numbers are in + # decimal format with no leading zeros. The signature is an + # HMAC-SHA256 of the whole string up to that point, including + # the final pipe. + # + # The fields are: + # - format version (i.e. 
2; no length prefix) + # - key version (integer, default is 0) + # - timestamp (integer seconds since epoch) + # - name (not encoded; assumed to be ~alphanumeric) + # - value (base64-encoded) + # - signature (hex-encoded; no length prefix) + def format_field(s: Union[str, bytes]) -> bytes: + return utf8("%d:" % len(s)) + utf8(s) + + to_sign = b"|".join( + [ + b"2", + format_field(str(key_version or 0)), + format_field(timestamp), + format_field(name), + format_field(value), + b"", + ] + ) + + if isinstance(secret, dict): + assert ( + key_version is not None + ), "Key version must be set when sign key dict is used" + assert version >= 2, "Version must be at least 2 for key version support" + secret = secret[key_version] + + signature = _create_signature_v2(secret, to_sign) + return to_sign + signature + else: + raise ValueError("Unsupported version %d" % version) + + +# A leading version number in decimal +# with no leading zeros, followed by a pipe. +_signed_value_version_re = re.compile(rb"^([1-9][0-9]*)\|(.*)$") + + +def _get_version(value: bytes) -> int: + # Figures out what version value is. Version 1 did not include an + # explicit version field and started with arbitrary base64 data, + # which makes this tricky. + m = _signed_value_version_re.match(value) + if m is None: + version = 1 + else: + try: + version = int(m.group(1)) + if version > 999: + # Certain payloads from the version-less v1 format may + # be parsed as valid integers. Due to base64 padding + # restrictions, this can only happen for numbers whose + # length is a multiple of 4, so we can treat all + # numbers up to 999 as versions, and for the rest we + # fall back to v1 format. + version = 1 + except ValueError: + version = 1 + return version + + +def decode_signed_value( + secret: _CookieSecretTypes, + name: str, + value: Union[None, str, bytes], + max_age_days: float = 31, + clock: Optional[Callable[[], float]] = None, + min_version: Optional[int] = None, +) -> Optional[bytes]: + if clock is None: + clock = time.time + if min_version is None: + min_version = DEFAULT_SIGNED_VALUE_MIN_VERSION + if min_version > 2: + raise ValueError("Unsupported min_version %d" % min_version) + if not value: + return None + + value = utf8(value) + version = _get_version(value) + + if version < min_version: + return None + if version == 1: + assert not isinstance(secret, dict) + return _decode_signed_value_v1(secret, name, value, max_age_days, clock) + elif version == 2: + return _decode_signed_value_v2(secret, name, value, max_age_days, clock) + else: + return None + + +def _decode_signed_value_v1( + secret: Union[str, bytes], + name: str, + value: bytes, + max_age_days: float, + clock: Callable[[], float], +) -> Optional[bytes]: + parts = utf8(value).split(b"|") + if len(parts) != 3: + return None + signature = _create_signature_v1(secret, name, parts[0], parts[1]) + if not hmac.compare_digest(parts[2], signature): + gen_log.warning("Invalid cookie signature %r", value) + return None + timestamp = int(parts[1]) + if timestamp < clock() - max_age_days * 86400: + gen_log.warning("Expired cookie %r", value) + return None + if timestamp > clock() + 31 * 86400: + # _cookie_signature does not hash a delimiter between the + # parts of the cookie, so an attacker could transfer trailing + # digits from the payload to the timestamp without altering the + # signature. For backwards compatibility, sanity-check timestamp + # here instead of modifying _cookie_signature. 
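+    # For example, HMAC(name + b"abc1" + b"2345...") is byte-for-byte the
+    # same as HMAC(name + b"abc" + b"12345..."), since the v1 signature
+    # covers the bare concatenation of its parts.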
+ gen_log.warning("Cookie timestamp in future; possible tampering %r", value) + return None + if parts[1].startswith(b"0"): + gen_log.warning("Tampered cookie %r", value) + return None + try: + return base64.b64decode(parts[0]) + except Exception: + return None + + +def _decode_fields_v2(value: bytes) -> Tuple[int, bytes, bytes, bytes, bytes]: + def _consume_field(s: bytes) -> Tuple[bytes, bytes]: + length, _, rest = s.partition(b":") + n = int(length) + field_value = rest[:n] + # In python 3, indexing bytes returns small integers; we must + # use a slice to get a byte string as in python 2. + if rest[n : n + 1] != b"|": + raise ValueError("malformed v2 signed value field") + rest = rest[n + 1 :] + return field_value, rest + + rest = value[2:] # remove version number + key_version, rest = _consume_field(rest) + timestamp, rest = _consume_field(rest) + name_field, rest = _consume_field(rest) + value_field, passed_sig = _consume_field(rest) + return int(key_version), timestamp, name_field, value_field, passed_sig + + +def _decode_signed_value_v2( + secret: _CookieSecretTypes, + name: str, + value: bytes, + max_age_days: float, + clock: Callable[[], float], +) -> Optional[bytes]: + try: + ( + key_version, + timestamp_bytes, + name_field, + value_field, + passed_sig, + ) = _decode_fields_v2(value) + except ValueError: + return None + signed_string = value[: -len(passed_sig)] + + if isinstance(secret, dict): + try: + secret = secret[key_version] + except KeyError: + return None + + expected_sig = _create_signature_v2(secret, signed_string) + if not hmac.compare_digest(passed_sig, expected_sig): + return None + if name_field != utf8(name): + return None + timestamp = int(timestamp_bytes) + if timestamp < clock() - max_age_days * 86400: + # The signature has expired. + return None + try: + return base64.b64decode(value_field) + except Exception: + return None + + +def get_signature_key_version(value: Union[str, bytes]) -> Optional[int]: + value = utf8(value) + version = _get_version(value) + if version < 2: + return None + try: + key_version, _, _, _, _ = _decode_fields_v2(value) + except ValueError: + return None + + return key_version + + +def _create_signature_v1(secret: Union[str, bytes], *parts: Union[str, bytes]) -> bytes: + hash = hmac.new(utf8(secret), digestmod=hashlib.sha1) + for part in parts: + hash.update(utf8(part)) + return utf8(hash.hexdigest()) + + +def _create_signature_v2(secret: Union[str, bytes], s: bytes) -> bytes: + hash = hmac.new(utf8(secret), digestmod=hashlib.sha256) + hash.update(utf8(s)) + return utf8(hash.hexdigest()) + + +def is_absolute(path: str) -> bool: + return any(path.startswith(x) for x in ["/", "http:", "https:"]) diff --git a/venv/lib/python3.9/site-packages/tornado/websocket.py b/venv/lib/python3.9/site-packages/tornado/websocket.py new file mode 100644 index 00000000..d0abd425 --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/websocket.py @@ -0,0 +1,1659 @@ +"""Implementation of the WebSocket protocol. + +`WebSockets <http://dev.w3.org/html5/websockets/>`_ allow for bidirectional +communication between the browser and server. WebSockets are supported in the +current versions of all major browsers. + +This module implements the final version of the WebSocket protocol as +defined in `RFC 6455 <http://tools.ietf.org/html/rfc6455>`_. + +.. versionchanged:: 4.0 + Removed support for the draft 76 protocol version. 
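+
+A minimal client-side sketch using `websocket_connect` (URL hypothetical)::
+
+    async def hello():
+        conn = await websocket_connect("ws://localhost:8888/echo")
+        await conn.write_message("hello")
+        reply = await conn.read_message()  # None once the connection closes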
+""" + +import abc +import asyncio +import base64 +import hashlib +import os +import sys +import struct +import tornado +from urllib.parse import urlparse +import zlib + +from tornado.concurrent import Future, future_set_result_unless_cancelled +from tornado.escape import utf8, native_str, to_unicode +from tornado import gen, httpclient, httputil +from tornado.ioloop import IOLoop, PeriodicCallback +from tornado.iostream import StreamClosedError, IOStream +from tornado.log import gen_log, app_log +from tornado.netutil import Resolver +from tornado import simple_httpclient +from tornado.queues import Queue +from tornado.tcpclient import TCPClient +from tornado.util import _websocket_mask + +from typing import ( + TYPE_CHECKING, + cast, + Any, + Optional, + Dict, + Union, + List, + Awaitable, + Callable, + Tuple, + Type, +) +from types import TracebackType + +if TYPE_CHECKING: + from typing_extensions import Protocol + + # The zlib compressor types aren't actually exposed anywhere + # publicly, so declare protocols for the portions we use. + class _Compressor(Protocol): + def compress(self, data: bytes) -> bytes: + pass + + def flush(self, mode: int) -> bytes: + pass + + class _Decompressor(Protocol): + unconsumed_tail = b"" # type: bytes + + def decompress(self, data: bytes, max_length: int) -> bytes: + pass + + class _WebSocketDelegate(Protocol): + # The common base interface implemented by WebSocketHandler on + # the server side and WebSocketClientConnection on the client + # side. + def on_ws_connection_close( + self, close_code: Optional[int] = None, close_reason: Optional[str] = None + ) -> None: + pass + + def on_message(self, message: Union[str, bytes]) -> Optional["Awaitable[None]"]: + pass + + def on_ping(self, data: bytes) -> None: + pass + + def on_pong(self, data: bytes) -> None: + pass + + def log_exception( + self, + typ: Optional[Type[BaseException]], + value: Optional[BaseException], + tb: Optional[TracebackType], + ) -> None: + pass + + +_default_max_message_size = 10 * 1024 * 1024 + + +class WebSocketError(Exception): + pass + + +class WebSocketClosedError(WebSocketError): + """Raised by operations on a closed connection. + + .. versionadded:: 3.2 + """ + + pass + + +class _DecompressTooLargeError(Exception): + pass + + +class _WebSocketParams(object): + def __init__( + self, + ping_interval: Optional[float] = None, + ping_timeout: Optional[float] = None, + max_message_size: int = _default_max_message_size, + compression_options: Optional[Dict[str, Any]] = None, + ) -> None: + self.ping_interval = ping_interval + self.ping_timeout = ping_timeout + self.max_message_size = max_message_size + self.compression_options = compression_options + + +class WebSocketHandler(tornado.web.RequestHandler): + """Subclass this class to create a basic WebSocket handler. + + Override `on_message` to handle incoming messages, and use + `write_message` to send messages to the client. You can also + override `open` and `on_close` to handle opened and closed + connections. + + Custom upgrade response headers can be sent by overriding + `~tornado.web.RequestHandler.set_default_headers` or + `~tornado.web.RequestHandler.prepare`. + + See http://dev.w3.org/html5/websockets/ for details on the + JavaScript interface. The protocol is specified at + http://tools.ietf.org/html/rfc6455. + + Here is an example WebSocket handler that echos back all received messages + back to the client: + + .. 
testcode:: + + class EchoWebSocket(tornado.websocket.WebSocketHandler): + def open(self): + print("WebSocket opened") + + def on_message(self, message): + self.write_message(u"You said: " + message) + + def on_close(self): + print("WebSocket closed") + + .. testoutput:: + :hide: + + WebSockets are not standard HTTP connections. The "handshake" is + HTTP, but after the handshake, the protocol is + message-based. Consequently, most of the Tornado HTTP facilities + are not available in handlers of this type. The only communication + methods available to you are `write_message()`, `ping()`, and + `close()`. Likewise, your request handler class should implement + `open()` method rather than ``get()`` or ``post()``. + + If you map the handler above to ``/websocket`` in your application, you can + invoke it in JavaScript with:: + + var ws = new WebSocket("ws://localhost:8888/websocket"); + ws.onopen = function() { + ws.send("Hello, world"); + }; + ws.onmessage = function (evt) { + alert(evt.data); + }; + + This script pops up an alert box that says "You said: Hello, world". + + Web browsers allow any site to open a websocket connection to any other, + instead of using the same-origin policy that governs other network + access from JavaScript. This can be surprising and is a potential + security hole, so since Tornado 4.0 `WebSocketHandler` requires + applications that wish to receive cross-origin websockets to opt in + by overriding the `~WebSocketHandler.check_origin` method (see that + method's docs for details). Failure to do so is the most likely + cause of 403 errors when making a websocket connection. + + When using a secure websocket connection (``wss://``) with a self-signed + certificate, the connection from a browser may fail because it wants + to show the "accept this certificate" dialog but has nowhere to show it. + You must first visit a regular HTML page using the same certificate + to accept it before the websocket connection will succeed. + + If the application setting ``websocket_ping_interval`` has a non-zero + value, a ping will be sent periodically, and the connection will be + closed if a response is not received before the ``websocket_ping_timeout``. + + Messages larger than the ``websocket_max_message_size`` application setting + (default 10MiB) will not be accepted. + + .. versionchanged:: 4.5 + Added ``websocket_ping_interval``, ``websocket_ping_timeout``, and + ``websocket_max_message_size``. + """ + + def __init__( + self, + application: tornado.web.Application, + request: httputil.HTTPServerRequest, + **kwargs: Any + ) -> None: + super().__init__(application, request, **kwargs) + self.ws_connection = None # type: Optional[WebSocketProtocol] + self.close_code = None # type: Optional[int] + self.close_reason = None # type: Optional[str] + self._on_close_called = False + + async def get(self, *args: Any, **kwargs: Any) -> None: + self.open_args = args + self.open_kwargs = kwargs + + # Upgrade header should be present and should be equal to WebSocket + if self.request.headers.get("Upgrade", "").lower() != "websocket": + self.set_status(400) + log_msg = 'Can "Upgrade" only to "WebSocket".' + self.finish(log_msg) + gen_log.debug(log_msg) + return + + # Connection header should be upgrade. + # Some proxy servers/load balancers + # might mess with it. 
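+        # For example, a proxied request may arrive with
+        # "Connection: keep-alive, Upgrade"; splitting on commas and
+        # lowercasing below still finds the "upgrade" token.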
+ headers = self.request.headers + connection = map( + lambda s: s.strip().lower(), headers.get("Connection", "").split(",") + ) + if "upgrade" not in connection: + self.set_status(400) + log_msg = '"Connection" must be "Upgrade".' + self.finish(log_msg) + gen_log.debug(log_msg) + return + + # Handle WebSocket Origin naming convention differences + # The difference between version 8 and 13 is that in 8 the + # client sends a "Sec-Websocket-Origin" header and in 13 it's + # simply "Origin". + if "Origin" in self.request.headers: + origin = self.request.headers.get("Origin") + else: + origin = self.request.headers.get("Sec-Websocket-Origin", None) + + # If there was an origin header, check to make sure it matches + # according to check_origin. When the origin is None, we assume it + # did not come from a browser and that it can be passed on. + if origin is not None and not self.check_origin(origin): + self.set_status(403) + log_msg = "Cross origin websockets not allowed" + self.finish(log_msg) + gen_log.debug(log_msg) + return + + self.ws_connection = self.get_websocket_protocol() + if self.ws_connection: + await self.ws_connection.accept_connection(self) + else: + self.set_status(426, "Upgrade Required") + self.set_header("Sec-WebSocket-Version", "7, 8, 13") + + @property + def ping_interval(self) -> Optional[float]: + """The interval for websocket keep-alive pings. + + Set websocket_ping_interval = 0 to disable pings. + """ + return self.settings.get("websocket_ping_interval", None) + + @property + def ping_timeout(self) -> Optional[float]: + """If no ping is received in this many seconds, + close the websocket connection (VPNs, etc. can fail to cleanly close ws connections). + Default is max of 3 pings or 30 seconds. + """ + return self.settings.get("websocket_ping_timeout", None) + + @property + def max_message_size(self) -> int: + """Maximum allowed message size. + + If the remote peer sends a message larger than this, the connection + will be closed. + + Default is 10MiB. + """ + return self.settings.get( + "websocket_max_message_size", _default_max_message_size + ) + + def write_message( + self, message: Union[bytes, str, Dict[str, Any]], binary: bool = False + ) -> "Future[None]": + """Sends the given message to the client of this Web Socket. + + The message may be either a string or a dict (which will be + encoded as json). If the ``binary`` argument is false, the + message will be sent as utf8; in binary mode any byte string + is allowed. + + If the connection is already closed, raises `WebSocketClosedError`. + Returns a `.Future` which can be used for flow control. + + .. versionchanged:: 3.2 + `WebSocketClosedError` was added (previously a closed connection + would raise an `AttributeError`) + + .. versionchanged:: 4.3 + Returns a `.Future` which can be used for flow control. + + .. versionchanged:: 5.0 + Consistently raises `WebSocketClosedError`. Previously could + sometimes raise `.StreamClosedError`. + """ + if self.ws_connection is None or self.ws_connection.is_closing(): + raise WebSocketClosedError() + if isinstance(message, dict): + message = tornado.escape.json_encode(message) + return self.ws_connection.write_message(message, binary=binary) + + def select_subprotocol(self, subprotocols: List[str]) -> Optional[str]: + """Override to implement subprotocol negotiation. + + ``subprotocols`` is a list of strings identifying the + subprotocols proposed by the client. 
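+        For example, a browser client created with
+        ``new WebSocket(url, ["graphql-ws", "chat"])`` proposes
+        ``["graphql-ws", "chat"]`` here.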
This method may be + overridden to return one of those strings to select it, or + ``None`` to not select a subprotocol. + + Failure to select a subprotocol does not automatically abort + the connection, although clients may close the connection if + none of their proposed subprotocols was selected. + + The list may be empty, in which case this method must return + None. This method is always called exactly once even if no + subprotocols were proposed so that the handler can be advised + of this fact. + + .. versionchanged:: 5.1 + + Previously, this method was called with a list containing + an empty string instead of an empty list if no subprotocols + were proposed by the client. + """ + return None + + @property + def selected_subprotocol(self) -> Optional[str]: + """The subprotocol returned by `select_subprotocol`. + + .. versionadded:: 5.1 + """ + assert self.ws_connection is not None + return self.ws_connection.selected_subprotocol + + def get_compression_options(self) -> Optional[Dict[str, Any]]: + """Override to return compression options for the connection. + + If this method returns None (the default), compression will + be disabled. If it returns a dict (even an empty one), it + will be enabled. The contents of the dict may be used to + control the following compression options: + + ``compression_level`` specifies the compression level. + + ``mem_level`` specifies the amount of memory used for the internal compression state. + + These parameters are documented in details here: + https://docs.python.org/3.6/library/zlib.html#zlib.compressobj + + .. versionadded:: 4.1 + + .. versionchanged:: 4.5 + + Added ``compression_level`` and ``mem_level``. + """ + # TODO: Add wbits option. + return None + + def open(self, *args: str, **kwargs: str) -> Optional[Awaitable[None]]: + """Invoked when a new WebSocket is opened. + + The arguments to `open` are extracted from the `tornado.web.URLSpec` + regular expression, just like the arguments to + `tornado.web.RequestHandler.get`. + + `open` may be a coroutine. `on_message` will not be called until + `open` has returned. + + .. versionchanged:: 5.1 + + ``open`` may be a coroutine. + """ + pass + + def on_message(self, message: Union[str, bytes]) -> Optional[Awaitable[None]]: + """Handle incoming messages on the WebSocket + + This method must be overridden. + + .. versionchanged:: 4.5 + + ``on_message`` can be a coroutine. + """ + raise NotImplementedError + + def ping(self, data: Union[str, bytes] = b"") -> None: + """Send ping frame to the remote end. + + The data argument allows a small amount of data (up to 125 + bytes) to be sent as a part of the ping message. Note that not + all websocket implementations expose this data to + applications. + + Consider using the ``websocket_ping_interval`` application + setting instead of sending pings manually. + + .. versionchanged:: 5.1 + + The data argument is now optional. + + """ + data = utf8(data) + if self.ws_connection is None or self.ws_connection.is_closing(): + raise WebSocketClosedError() + self.ws_connection.write_ping(data) + + def on_pong(self, data: bytes) -> None: + """Invoked when the response to a ping frame is received.""" + pass + + def on_ping(self, data: bytes) -> None: + """Invoked when the a ping frame is received.""" + pass + + def on_close(self) -> None: + """Invoked when the WebSocket is closed. 
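+
+        A handler might override this to release per-connection state;
+        a hypothetical sketch (``waiters`` is an application-defined
+        set of live handlers)::
+
+            def on_close(self):
+                waiters.discard(self)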
+ + If the connection was closed cleanly and a status code or reason + phrase was supplied, these values will be available as the attributes + ``self.close_code`` and ``self.close_reason``. + + .. versionchanged:: 4.0 + + Added ``close_code`` and ``close_reason`` attributes. + """ + pass + + def close(self, code: Optional[int] = None, reason: Optional[str] = None) -> None: + """Closes this Web Socket. + + Once the close handshake is successful the socket will be closed. + + ``code`` may be a numeric status code, taken from the values + defined in `RFC 6455 section 7.4.1 + <https://tools.ietf.org/html/rfc6455#section-7.4.1>`_. + ``reason`` may be a textual message about why the connection is + closing. These values are made available to the client, but are + not otherwise interpreted by the websocket protocol. + + .. versionchanged:: 4.0 + + Added the ``code`` and ``reason`` arguments. + """ + if self.ws_connection: + self.ws_connection.close(code, reason) + self.ws_connection = None + + def check_origin(self, origin: str) -> bool: + """Override to enable support for allowing alternate origins. + + The ``origin`` argument is the value of the ``Origin`` HTTP + header, the url responsible for initiating this request. This + method is not called for clients that do not send this header; + such requests are always allowed (because all browsers that + implement WebSockets support this header, and non-browser + clients do not have the same cross-site security concerns). + + Should return ``True`` to accept the request or ``False`` to + reject it. By default, rejects all requests with an origin on + a host other than this one. + + This is a security protection against cross site scripting attacks on + browsers, since WebSockets are allowed to bypass the usual same-origin + policies and don't use CORS headers. + + .. warning:: + + This is an important security measure; don't disable it + without understanding the security implications. In + particular, if your authentication is cookie-based, you + must either restrict the origins allowed by + ``check_origin()`` or implement your own XSRF-like + protection for websocket connections. See `these + <https://www.christian-schneider.net/CrossSiteWebSocketHijacking.html>`_ + `articles + <https://devcenter.heroku.com/articles/websocket-security>`_ + for more. + + To accept all cross-origin traffic (which was the default prior to + Tornado 4.0), simply override this method to always return ``True``:: + + def check_origin(self, origin): + return True + + To allow connections from any subdomain of your site, you might + do something like:: + + def check_origin(self, origin): + parsed_origin = urllib.parse.urlparse(origin) + return parsed_origin.netloc.endswith(".mydomain.com") + + .. versionadded:: 4.0 + + """ + parsed_origin = urlparse(origin) + origin = parsed_origin.netloc + origin = origin.lower() + + host = self.request.headers.get("Host") + + # Check to see that origin matches host directly, including ports + return origin == host + + def set_nodelay(self, value: bool) -> None: + """Set the no-delay flag for this stream. + + By default, small messages may be delayed and/or combined to minimize + the number of packets sent. This can sometimes cause 200-500ms delays + due to the interaction between Nagle's algorithm and TCP delayed + ACKs. To reduce this delay (at the expense of possibly increasing + bandwidth usage), call ``self.set_nodelay(True)`` once the websocket + connection is established. + + See `.BaseIOStream.set_nodelay` for additional details. 
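+
+        For example, a handler streaming latency-sensitive updates might
+        enable it as soon as the connection is established (a sketch)::
+
+            def open(self):
+                self.set_nodelay(True)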
+ + .. versionadded:: 3.1 + """ + assert self.ws_connection is not None + self.ws_connection.set_nodelay(value) + + def on_connection_close(self) -> None: + if self.ws_connection: + self.ws_connection.on_connection_close() + self.ws_connection = None + if not self._on_close_called: + self._on_close_called = True + self.on_close() + self._break_cycles() + + def on_ws_connection_close( + self, close_code: Optional[int] = None, close_reason: Optional[str] = None + ) -> None: + self.close_code = close_code + self.close_reason = close_reason + self.on_connection_close() + + def _break_cycles(self) -> None: + # WebSocketHandlers call finish() early, but we don't want to + # break up reference cycles (which makes it impossible to call + # self.render_string) until after we've really closed the + # connection (if it was established in the first place, + # indicated by status code 101). + if self.get_status() != 101 or self._on_close_called: + super()._break_cycles() + + def get_websocket_protocol(self) -> Optional["WebSocketProtocol"]: + websocket_version = self.request.headers.get("Sec-WebSocket-Version") + if websocket_version in ("7", "8", "13"): + params = _WebSocketParams( + ping_interval=self.ping_interval, + ping_timeout=self.ping_timeout, + max_message_size=self.max_message_size, + compression_options=self.get_compression_options(), + ) + return WebSocketProtocol13(self, False, params) + return None + + def _detach_stream(self) -> IOStream: + # disable non-WS methods + for method in [ + "write", + "redirect", + "set_header", + "set_cookie", + "set_status", + "flush", + "finish", + ]: + setattr(self, method, _raise_not_supported_for_websockets) + return self.detach() + + +def _raise_not_supported_for_websockets(*args: Any, **kwargs: Any) -> None: + raise RuntimeError("Method not supported for Web Sockets") + + +class WebSocketProtocol(abc.ABC): + """Base class for WebSocket protocol versions.""" + + def __init__(self, handler: "_WebSocketDelegate") -> None: + self.handler = handler + self.stream = None # type: Optional[IOStream] + self.client_terminated = False + self.server_terminated = False + + def _run_callback( + self, callback: Callable, *args: Any, **kwargs: Any + ) -> "Optional[Future[Any]]": + """Runs the given callback with exception handling. + + If the callback is a coroutine, returns its Future. On error, aborts the + websocket connection and returns None. 
+ """ + try: + result = callback(*args, **kwargs) + except Exception: + self.handler.log_exception(*sys.exc_info()) + self._abort() + return None + else: + if result is not None: + result = gen.convert_yielded(result) + assert self.stream is not None + self.stream.io_loop.add_future(result, lambda f: f.result()) + return result + + def on_connection_close(self) -> None: + self._abort() + + def _abort(self) -> None: + """Instantly aborts the WebSocket connection by closing the socket""" + self.client_terminated = True + self.server_terminated = True + if self.stream is not None: + self.stream.close() # forcibly tear down the connection + self.close() # let the subclass cleanup + + @abc.abstractmethod + def close(self, code: Optional[int] = None, reason: Optional[str] = None) -> None: + raise NotImplementedError() + + @abc.abstractmethod + def is_closing(self) -> bool: + raise NotImplementedError() + + @abc.abstractmethod + async def accept_connection(self, handler: WebSocketHandler) -> None: + raise NotImplementedError() + + @abc.abstractmethod + def write_message( + self, message: Union[str, bytes, Dict[str, Any]], binary: bool = False + ) -> "Future[None]": + raise NotImplementedError() + + @property + @abc.abstractmethod + def selected_subprotocol(self) -> Optional[str]: + raise NotImplementedError() + + @abc.abstractmethod + def write_ping(self, data: bytes) -> None: + raise NotImplementedError() + + # The entry points below are used by WebSocketClientConnection, + # which was introduced after we only supported a single version of + # WebSocketProtocol. The WebSocketProtocol/WebSocketProtocol13 + # boundary is currently pretty ad-hoc. + @abc.abstractmethod + def _process_server_headers( + self, key: Union[str, bytes], headers: httputil.HTTPHeaders + ) -> None: + raise NotImplementedError() + + @abc.abstractmethod + def start_pinging(self) -> None: + raise NotImplementedError() + + @abc.abstractmethod + async def _receive_frame_loop(self) -> None: + raise NotImplementedError() + + @abc.abstractmethod + def set_nodelay(self, x: bool) -> None: + raise NotImplementedError() + + +class _PerMessageDeflateCompressor(object): + def __init__( + self, + persistent: bool, + max_wbits: Optional[int], + compression_options: Optional[Dict[str, Any]] = None, + ) -> None: + if max_wbits is None: + max_wbits = zlib.MAX_WBITS + # There is no symbolic constant for the minimum wbits value. 
+ if not (8 <= max_wbits <= zlib.MAX_WBITS): + raise ValueError( + "Invalid max_wbits value %r; allowed range 8-%d", + max_wbits, + zlib.MAX_WBITS, + ) + self._max_wbits = max_wbits + + if ( + compression_options is None + or "compression_level" not in compression_options + ): + self._compression_level = tornado.web.GZipContentEncoding.GZIP_LEVEL + else: + self._compression_level = compression_options["compression_level"] + + if compression_options is None or "mem_level" not in compression_options: + self._mem_level = 8 + else: + self._mem_level = compression_options["mem_level"] + + if persistent: + self._compressor = self._create_compressor() # type: Optional[_Compressor] + else: + self._compressor = None + + def _create_compressor(self) -> "_Compressor": + return zlib.compressobj( + self._compression_level, zlib.DEFLATED, -self._max_wbits, self._mem_level + ) + + def compress(self, data: bytes) -> bytes: + compressor = self._compressor or self._create_compressor() + data = compressor.compress(data) + compressor.flush(zlib.Z_SYNC_FLUSH) + assert data.endswith(b"\x00\x00\xff\xff") + return data[:-4] + + +class _PerMessageDeflateDecompressor(object): + def __init__( + self, + persistent: bool, + max_wbits: Optional[int], + max_message_size: int, + compression_options: Optional[Dict[str, Any]] = None, + ) -> None: + self._max_message_size = max_message_size + if max_wbits is None: + max_wbits = zlib.MAX_WBITS + if not (8 <= max_wbits <= zlib.MAX_WBITS): + raise ValueError( + "Invalid max_wbits value %r; allowed range 8-%d", + max_wbits, + zlib.MAX_WBITS, + ) + self._max_wbits = max_wbits + if persistent: + self._decompressor = ( + self._create_decompressor() + ) # type: Optional[_Decompressor] + else: + self._decompressor = None + + def _create_decompressor(self) -> "_Decompressor": + return zlib.decompressobj(-self._max_wbits) + + def decompress(self, data: bytes) -> bytes: + decompressor = self._decompressor or self._create_decompressor() + result = decompressor.decompress( + data + b"\x00\x00\xff\xff", self._max_message_size + ) + if decompressor.unconsumed_tail: + raise _DecompressTooLargeError() + return result + + +class WebSocketProtocol13(WebSocketProtocol): + """Implementation of the WebSocket protocol from RFC 6455. + + This class supports versions 7 and 8 of the protocol in addition to the + final version 13. + """ + + # Bit masks for the first byte of a frame. + FIN = 0x80 + RSV1 = 0x40 + RSV2 = 0x20 + RSV3 = 0x10 + RSV_MASK = RSV1 | RSV2 | RSV3 + OPCODE_MASK = 0x0F + + stream = None # type: IOStream + + def __init__( + self, + handler: "_WebSocketDelegate", + mask_outgoing: bool, + params: _WebSocketParams, + ) -> None: + WebSocketProtocol.__init__(self, handler) + self.mask_outgoing = mask_outgoing + self.params = params + self._final_frame = False + self._frame_opcode = None + self._masked_frame = None + self._frame_mask = None # type: Optional[bytes] + self._frame_length = None + self._fragmented_message_buffer = None # type: Optional[bytearray] + self._fragmented_message_opcode = None + self._waiting = None # type: object + self._compression_options = params.compression_options + self._decompressor = None # type: Optional[_PerMessageDeflateDecompressor] + self._compressor = None # type: Optional[_PerMessageDeflateCompressor] + self._frame_compressed = None # type: Optional[bool] + # The total uncompressed size of all messages received or sent. + # Unicode messages are encoded to utf8. + # Only for testing; subject to change. 
+ self._message_bytes_in = 0 + self._message_bytes_out = 0 + # The total size of all packets received or sent. Includes + # the effect of compression, frame overhead, and control frames. + self._wire_bytes_in = 0 + self._wire_bytes_out = 0 + self.ping_callback = None # type: Optional[PeriodicCallback] + self.last_ping = 0.0 + self.last_pong = 0.0 + self.close_code = None # type: Optional[int] + self.close_reason = None # type: Optional[str] + + # Use a property for this to satisfy the abc. + @property + def selected_subprotocol(self) -> Optional[str]: + return self._selected_subprotocol + + @selected_subprotocol.setter + def selected_subprotocol(self, value: Optional[str]) -> None: + self._selected_subprotocol = value + + async def accept_connection(self, handler: WebSocketHandler) -> None: + try: + self._handle_websocket_headers(handler) + except ValueError: + handler.set_status(400) + log_msg = "Missing/Invalid WebSocket headers" + handler.finish(log_msg) + gen_log.debug(log_msg) + return + + try: + await self._accept_connection(handler) + except asyncio.CancelledError: + self._abort() + return + except ValueError: + gen_log.debug("Malformed WebSocket request received", exc_info=True) + self._abort() + return + + def _handle_websocket_headers(self, handler: WebSocketHandler) -> None: + """Verifies all invariant- and required headers + + If a header is missing or have an incorrect value ValueError will be + raised + """ + fields = ("Host", "Sec-Websocket-Key", "Sec-Websocket-Version") + if not all(map(lambda f: handler.request.headers.get(f), fields)): + raise ValueError("Missing/Invalid WebSocket headers") + + @staticmethod + def compute_accept_value(key: Union[str, bytes]) -> str: + """Computes the value for the Sec-WebSocket-Accept header, + given the value for Sec-WebSocket-Key. + """ + sha1 = hashlib.sha1() + sha1.update(utf8(key)) + sha1.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11") # Magic value + return native_str(base64.b64encode(sha1.digest())) + + def _challenge_response(self, handler: WebSocketHandler) -> str: + return WebSocketProtocol13.compute_accept_value( + cast(str, handler.request.headers.get("Sec-Websocket-Key")) + ) + + async def _accept_connection(self, handler: WebSocketHandler) -> None: + subprotocol_header = handler.request.headers.get("Sec-WebSocket-Protocol") + if subprotocol_header: + subprotocols = [s.strip() for s in subprotocol_header.split(",")] + else: + subprotocols = [] + self.selected_subprotocol = handler.select_subprotocol(subprotocols) + if self.selected_subprotocol: + assert self.selected_subprotocol in subprotocols + handler.set_header("Sec-WebSocket-Protocol", self.selected_subprotocol) + + extensions = self._parse_extensions_header(handler.request.headers) + for ext in extensions: + if ext[0] == "permessage-deflate" and self._compression_options is not None: + # TODO: negotiate parameters if compression_options + # specifies limits. + self._create_compressors("server", ext[1], self._compression_options) + if ( + "client_max_window_bits" in ext[1] + and ext[1]["client_max_window_bits"] is None + ): + # Don't echo an offered client_max_window_bits + # parameter with no value. 
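+                        # (RFC 7692 lets a client offer the bare
+                        # parameter, but a server response must always
+                        # carry a value.)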
+ del ext[1]["client_max_window_bits"] + handler.set_header( + "Sec-WebSocket-Extensions", + httputil._encode_header("permessage-deflate", ext[1]), + ) + break + + handler.clear_header("Content-Type") + handler.set_status(101) + handler.set_header("Upgrade", "websocket") + handler.set_header("Connection", "Upgrade") + handler.set_header("Sec-WebSocket-Accept", self._challenge_response(handler)) + handler.finish() + + self.stream = handler._detach_stream() + + self.start_pinging() + try: + open_result = handler.open(*handler.open_args, **handler.open_kwargs) + if open_result is not None: + await open_result + except Exception: + handler.log_exception(*sys.exc_info()) + self._abort() + return + + await self._receive_frame_loop() + + def _parse_extensions_header( + self, headers: httputil.HTTPHeaders + ) -> List[Tuple[str, Dict[str, str]]]: + extensions = headers.get("Sec-WebSocket-Extensions", "") + if extensions: + return [httputil._parse_header(e.strip()) for e in extensions.split(",")] + return [] + + def _process_server_headers( + self, key: Union[str, bytes], headers: httputil.HTTPHeaders + ) -> None: + """Process the headers sent by the server to this client connection. + + 'key' is the websocket handshake challenge/response key. + """ + assert headers["Upgrade"].lower() == "websocket" + assert headers["Connection"].lower() == "upgrade" + accept = self.compute_accept_value(key) + assert headers["Sec-Websocket-Accept"] == accept + + extensions = self._parse_extensions_header(headers) + for ext in extensions: + if ext[0] == "permessage-deflate" and self._compression_options is not None: + self._create_compressors("client", ext[1]) + else: + raise ValueError("unsupported extension %r", ext) + + self.selected_subprotocol = headers.get("Sec-WebSocket-Protocol", None) + + def _get_compressor_options( + self, + side: str, + agreed_parameters: Dict[str, Any], + compression_options: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + """Converts a websocket agreed_parameters set to keyword arguments + for our compressor objects. 
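+
+        For example, on the ``"server"`` side an agreed parameter set of
+        ``{"server_no_context_takeover": None}`` yields
+        ``persistent=False`` and ``max_wbits=zlib.MAX_WBITS``.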
+ """ + options = dict( + persistent=(side + "_no_context_takeover") not in agreed_parameters + ) # type: Dict[str, Any] + wbits_header = agreed_parameters.get(side + "_max_window_bits", None) + if wbits_header is None: + options["max_wbits"] = zlib.MAX_WBITS + else: + options["max_wbits"] = int(wbits_header) + options["compression_options"] = compression_options + return options + + def _create_compressors( + self, + side: str, + agreed_parameters: Dict[str, Any], + compression_options: Optional[Dict[str, Any]] = None, + ) -> None: + # TODO: handle invalid parameters gracefully + allowed_keys = set( + [ + "server_no_context_takeover", + "client_no_context_takeover", + "server_max_window_bits", + "client_max_window_bits", + ] + ) + for key in agreed_parameters: + if key not in allowed_keys: + raise ValueError("unsupported compression parameter %r" % key) + other_side = "client" if (side == "server") else "server" + self._compressor = _PerMessageDeflateCompressor( + **self._get_compressor_options(side, agreed_parameters, compression_options) + ) + self._decompressor = _PerMessageDeflateDecompressor( + max_message_size=self.params.max_message_size, + **self._get_compressor_options( + other_side, agreed_parameters, compression_options + ) + ) + + def _write_frame( + self, fin: bool, opcode: int, data: bytes, flags: int = 0 + ) -> "Future[None]": + data_len = len(data) + if opcode & 0x8: + # All control frames MUST have a payload length of 125 + # bytes or less and MUST NOT be fragmented. + if not fin: + raise ValueError("control frames may not be fragmented") + if data_len > 125: + raise ValueError("control frame payloads may not exceed 125 bytes") + if fin: + finbit = self.FIN + else: + finbit = 0 + frame = struct.pack("B", finbit | opcode | flags) + if self.mask_outgoing: + mask_bit = 0x80 + else: + mask_bit = 0 + if data_len < 126: + frame += struct.pack("B", data_len | mask_bit) + elif data_len <= 0xFFFF: + frame += struct.pack("!BH", 126 | mask_bit, data_len) + else: + frame += struct.pack("!BQ", 127 | mask_bit, data_len) + if self.mask_outgoing: + mask = os.urandom(4) + data = mask + _websocket_mask(mask, data) + frame += data + self._wire_bytes_out += len(frame) + return self.stream.write(frame) + + def write_message( + self, message: Union[str, bytes, Dict[str, Any]], binary: bool = False + ) -> "Future[None]": + """Sends the given message to the client of this Web Socket.""" + if binary: + opcode = 0x2 + else: + opcode = 0x1 + if isinstance(message, dict): + message = tornado.escape.json_encode(message) + message = tornado.escape.utf8(message) + assert isinstance(message, bytes) + self._message_bytes_out += len(message) + flags = 0 + if self._compressor: + message = self._compressor.compress(message) + flags |= self.RSV1 + # For historical reasons, write methods in Tornado operate in a semi-synchronous + # mode in which awaiting the Future they return is optional (But errors can + # still be raised). This requires us to go through an awkward dance here + # to transform the errors that may be returned while presenting the same + # semi-synchronous interface. 
+ try: + fut = self._write_frame(True, opcode, message, flags=flags) + except StreamClosedError: + raise WebSocketClosedError() + + async def wrapper() -> None: + try: + await fut + except StreamClosedError: + raise WebSocketClosedError() + + return asyncio.ensure_future(wrapper()) + + def write_ping(self, data: bytes) -> None: + """Send ping frame.""" + assert isinstance(data, bytes) + self._write_frame(True, 0x9, data) + + async def _receive_frame_loop(self) -> None: + try: + while not self.client_terminated: + await self._receive_frame() + except StreamClosedError: + self._abort() + self.handler.on_ws_connection_close(self.close_code, self.close_reason) + + async def _read_bytes(self, n: int) -> bytes: + data = await self.stream.read_bytes(n) + self._wire_bytes_in += n + return data + + async def _receive_frame(self) -> None: + # Read the frame header. + data = await self._read_bytes(2) + header, mask_payloadlen = struct.unpack("BB", data) + is_final_frame = header & self.FIN + reserved_bits = header & self.RSV_MASK + opcode = header & self.OPCODE_MASK + opcode_is_control = opcode & 0x8 + if self._decompressor is not None and opcode != 0: + # Compression flag is present in the first frame's header, + # but we can't decompress until we have all the frames of + # the message. + self._frame_compressed = bool(reserved_bits & self.RSV1) + reserved_bits &= ~self.RSV1 + if reserved_bits: + # client is using as-yet-undefined extensions; abort + self._abort() + return + is_masked = bool(mask_payloadlen & 0x80) + payloadlen = mask_payloadlen & 0x7F + + # Parse and validate the length. + if opcode_is_control and payloadlen >= 126: + # control frames must have payload < 126 + self._abort() + return + if payloadlen < 126: + self._frame_length = payloadlen + elif payloadlen == 126: + data = await self._read_bytes(2) + payloadlen = struct.unpack("!H", data)[0] + elif payloadlen == 127: + data = await self._read_bytes(8) + payloadlen = struct.unpack("!Q", data)[0] + new_len = payloadlen + if self._fragmented_message_buffer is not None: + new_len += len(self._fragmented_message_buffer) + if new_len > self.params.max_message_size: + self.close(1009, "message too big") + self._abort() + return + + # Read the payload, unmasking if necessary. + if is_masked: + self._frame_mask = await self._read_bytes(4) + data = await self._read_bytes(payloadlen) + if is_masked: + assert self._frame_mask is not None + data = _websocket_mask(self._frame_mask, data) + + # Decide what to do with this frame. 
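+        # Opcodes per RFC 6455 section 5.2: 0x0 continuation, 0x1 text,
+        # 0x2 binary, 0x8 close, 0x9 ping, 0xA pong.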
+ if opcode_is_control: + # control frames may be interleaved with a series of fragmented + # data frames, so control frames must not interact with + # self._fragmented_* + if not is_final_frame: + # control frames must not be fragmented + self._abort() + return + elif opcode == 0: # continuation frame + if self._fragmented_message_buffer is None: + # nothing to continue + self._abort() + return + self._fragmented_message_buffer.extend(data) + if is_final_frame: + opcode = self._fragmented_message_opcode + data = bytes(self._fragmented_message_buffer) + self._fragmented_message_buffer = None + else: # start of new data message + if self._fragmented_message_buffer is not None: + # can't start new message until the old one is finished + self._abort() + return + if not is_final_frame: + self._fragmented_message_opcode = opcode + self._fragmented_message_buffer = bytearray(data) + + if is_final_frame: + handled_future = self._handle_message(opcode, data) + if handled_future is not None: + await handled_future + + def _handle_message(self, opcode: int, data: bytes) -> "Optional[Future[None]]": + """Execute on_message, returning its Future if it is a coroutine.""" + if self.client_terminated: + return None + + if self._frame_compressed: + assert self._decompressor is not None + try: + data = self._decompressor.decompress(data) + except _DecompressTooLargeError: + self.close(1009, "message too big after decompression") + self._abort() + return None + + if opcode == 0x1: + # UTF-8 data + self._message_bytes_in += len(data) + try: + decoded = data.decode("utf-8") + except UnicodeDecodeError: + self._abort() + return None + return self._run_callback(self.handler.on_message, decoded) + elif opcode == 0x2: + # Binary data + self._message_bytes_in += len(data) + return self._run_callback(self.handler.on_message, data) + elif opcode == 0x8: + # Close + self.client_terminated = True + if len(data) >= 2: + self.close_code = struct.unpack(">H", data[:2])[0] + if len(data) > 2: + self.close_reason = to_unicode(data[2:]) + # Echo the received close code, if any (RFC 6455 section 5.5.1). + self.close(self.close_code) + elif opcode == 0x9: + # Ping + try: + self._write_frame(True, 0xA, data) + except StreamClosedError: + self._abort() + self._run_callback(self.handler.on_ping, data) + elif opcode == 0xA: + # Pong + self.last_pong = IOLoop.current().time() + return self._run_callback(self.handler.on_pong, data) + else: + self._abort() + return None + + def close(self, code: Optional[int] = None, reason: Optional[str] = None) -> None: + """Closes the WebSocket connection.""" + if not self.server_terminated: + if not self.stream.closed(): + if code is None and reason is not None: + code = 1000 # "normal closure" status code + if code is None: + close_data = b"" + else: + close_data = struct.pack(">H", code) + if reason is not None: + close_data += utf8(reason) + try: + self._write_frame(True, 0x8, close_data) + except StreamClosedError: + self._abort() + self.server_terminated = True + if self.client_terminated: + if self._waiting is not None: + self.stream.io_loop.remove_timeout(self._waiting) + self._waiting = None + self.stream.close() + elif self._waiting is None: + # Give the client a few seconds to complete a clean shutdown, + # otherwise just close the connection. 
+ self._waiting = self.stream.io_loop.add_timeout( + self.stream.io_loop.time() + 5, self._abort + ) + if self.ping_callback: + self.ping_callback.stop() + self.ping_callback = None + + def is_closing(self) -> bool: + """Return ``True`` if this connection is closing. + + The connection is considered closing if either side has + initiated its closing handshake or if the stream has been + shut down uncleanly. + """ + return self.stream.closed() or self.client_terminated or self.server_terminated + + @property + def ping_interval(self) -> Optional[float]: + interval = self.params.ping_interval + if interval is not None: + return interval + return 0 + + @property + def ping_timeout(self) -> Optional[float]: + timeout = self.params.ping_timeout + if timeout is not None: + return timeout + assert self.ping_interval is not None + return max(3 * self.ping_interval, 30) + + def start_pinging(self) -> None: + """Start sending periodic pings to keep the connection alive""" + assert self.ping_interval is not None + if self.ping_interval > 0: + self.last_ping = self.last_pong = IOLoop.current().time() + self.ping_callback = PeriodicCallback( + self.periodic_ping, self.ping_interval * 1000 + ) + self.ping_callback.start() + + def periodic_ping(self) -> None: + """Send a ping to keep the websocket alive + + Called periodically if the websocket_ping_interval is set and non-zero. + """ + if self.is_closing() and self.ping_callback is not None: + self.ping_callback.stop() + return + + # Check for timeout on pong. Make sure that we really have + # sent a recent ping in case the machine with both server and + # client has been suspended since the last ping. + now = IOLoop.current().time() + since_last_pong = now - self.last_pong + since_last_ping = now - self.last_ping + assert self.ping_interval is not None + assert self.ping_timeout is not None + if ( + since_last_ping < 2 * self.ping_interval + and since_last_pong > self.ping_timeout + ): + self.close() + return + + self.write_ping(b"") + self.last_ping = now + + def set_nodelay(self, x: bool) -> None: + self.stream.set_nodelay(x) + + +class WebSocketClientConnection(simple_httpclient._HTTPConnection): + """WebSocket client connection. + + This class should not be instantiated directly; use the + `websocket_connect` function instead. 
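+
+    A minimal usage sketch (``url`` is assumed to point at a running
+    websocket endpoint)::
+
+        conn = await websocket_connect(url)
+        await conn.write_message("hello")
+        reply = await conn.read_message()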
+ """ + + protocol = None # type: WebSocketProtocol + + def __init__( + self, + request: httpclient.HTTPRequest, + on_message_callback: Optional[Callable[[Union[None, str, bytes]], None]] = None, + compression_options: Optional[Dict[str, Any]] = None, + ping_interval: Optional[float] = None, + ping_timeout: Optional[float] = None, + max_message_size: int = _default_max_message_size, + subprotocols: Optional[List[str]] = [], + resolver: Optional[Resolver] = None, + ) -> None: + self.connect_future = Future() # type: Future[WebSocketClientConnection] + self.read_queue = Queue(1) # type: Queue[Union[None, str, bytes]] + self.key = base64.b64encode(os.urandom(16)) + self._on_message_callback = on_message_callback + self.close_code = None # type: Optional[int] + self.close_reason = None # type: Optional[str] + self.params = _WebSocketParams( + ping_interval=ping_interval, + ping_timeout=ping_timeout, + max_message_size=max_message_size, + compression_options=compression_options, + ) + + scheme, sep, rest = request.url.partition(":") + scheme = {"ws": "http", "wss": "https"}[scheme] + request.url = scheme + sep + rest + request.headers.update( + { + "Upgrade": "websocket", + "Connection": "Upgrade", + "Sec-WebSocket-Key": self.key, + "Sec-WebSocket-Version": "13", + } + ) + if subprotocols is not None: + request.headers["Sec-WebSocket-Protocol"] = ",".join(subprotocols) + if compression_options is not None: + # Always offer to let the server set our max_wbits (and even though + # we don't offer it, we will accept a client_no_context_takeover + # from the server). + # TODO: set server parameters for deflate extension + # if requested in self.compression_options. + request.headers[ + "Sec-WebSocket-Extensions" + ] = "permessage-deflate; client_max_window_bits" + + # Websocket connection is currently unable to follow redirects + request.follow_redirects = False + + self.tcp_client = TCPClient(resolver=resolver) + super().__init__( + None, + request, + lambda: None, + self._on_http_response, + 104857600, + self.tcp_client, + 65536, + 104857600, + ) + + def close(self, code: Optional[int] = None, reason: Optional[str] = None) -> None: + """Closes the websocket connection. + + ``code`` and ``reason`` are documented under + `WebSocketHandler.close`. + + .. versionadded:: 3.2 + + .. versionchanged:: 4.0 + + Added the ``code`` and ``reason`` arguments. 
+ """ + if self.protocol is not None: + self.protocol.close(code, reason) + self.protocol = None # type: ignore + + def on_connection_close(self) -> None: + if not self.connect_future.done(): + self.connect_future.set_exception(StreamClosedError()) + self._on_message(None) + self.tcp_client.close() + super().on_connection_close() + + def on_ws_connection_close( + self, close_code: Optional[int] = None, close_reason: Optional[str] = None + ) -> None: + self.close_code = close_code + self.close_reason = close_reason + self.on_connection_close() + + def _on_http_response(self, response: httpclient.HTTPResponse) -> None: + if not self.connect_future.done(): + if response.error: + self.connect_future.set_exception(response.error) + else: + self.connect_future.set_exception( + WebSocketError("Non-websocket response") + ) + + async def headers_received( + self, + start_line: Union[httputil.RequestStartLine, httputil.ResponseStartLine], + headers: httputil.HTTPHeaders, + ) -> None: + assert isinstance(start_line, httputil.ResponseStartLine) + if start_line.code != 101: + await super().headers_received(start_line, headers) + return + + if self._timeout is not None: + self.io_loop.remove_timeout(self._timeout) + self._timeout = None + + self.headers = headers + self.protocol = self.get_websocket_protocol() + self.protocol._process_server_headers(self.key, self.headers) + self.protocol.stream = self.connection.detach() + + IOLoop.current().add_callback(self.protocol._receive_frame_loop) + self.protocol.start_pinging() + + # Once we've taken over the connection, clear the final callback + # we set on the http request. This deactivates the error handling + # in simple_httpclient that would otherwise interfere with our + # ability to see exceptions. + self.final_callback = None # type: ignore + + future_set_result_unless_cancelled(self.connect_future, self) + + def write_message( + self, message: Union[str, bytes, Dict[str, Any]], binary: bool = False + ) -> "Future[None]": + """Sends a message to the WebSocket server. + + If the stream is closed, raises `WebSocketClosedError`. + Returns a `.Future` which can be used for flow control. + + .. versionchanged:: 5.0 + Exception raised on a closed stream changed from `.StreamClosedError` + to `WebSocketClosedError`. + """ + if self.protocol is None: + raise WebSocketClosedError("Client connection has been closed") + return self.protocol.write_message(message, binary=binary) + + def read_message( + self, + callback: Optional[Callable[["Future[Union[None, str, bytes]]"], None]] = None, + ) -> Awaitable[Union[None, str, bytes]]: + """Reads a message from the WebSocket server. + + If on_message_callback was specified at WebSocket + initialization, this function will never return messages + + Returns a future whose result is the message, or None + if the connection is closed. If a callback argument + is given it will be called with the future when it is + ready. + """ + + awaitable = self.read_queue.get() + if callback is not None: + self.io_loop.add_future(asyncio.ensure_future(awaitable), callback) + return awaitable + + def on_message(self, message: Union[str, bytes]) -> Optional[Awaitable[None]]: + return self._on_message(message) + + def _on_message( + self, message: Union[None, str, bytes] + ) -> Optional[Awaitable[None]]: + if self._on_message_callback: + self._on_message_callback(message) + return None + else: + return self.read_queue.put(message) + + def ping(self, data: bytes = b"") -> None: + """Send ping frame to the remote end. 
+ + The data argument allows a small amount of data (up to 125 + bytes) to be sent as a part of the ping message. Note that not + all websocket implementations expose this data to + applications. + + Consider using the ``ping_interval`` argument to + `websocket_connect` instead of sending pings manually. + + .. versionadded:: 5.1 + + """ + data = utf8(data) + if self.protocol is None: + raise WebSocketClosedError() + self.protocol.write_ping(data) + + def on_pong(self, data: bytes) -> None: + pass + + def on_ping(self, data: bytes) -> None: + pass + + def get_websocket_protocol(self) -> WebSocketProtocol: + return WebSocketProtocol13(self, mask_outgoing=True, params=self.params) + + @property + def selected_subprotocol(self) -> Optional[str]: + """The subprotocol selected by the server. + + .. versionadded:: 5.1 + """ + return self.protocol.selected_subprotocol + + def log_exception( + self, + typ: "Optional[Type[BaseException]]", + value: Optional[BaseException], + tb: Optional[TracebackType], + ) -> None: + assert typ is not None + assert value is not None + app_log.error("Uncaught exception %s", value, exc_info=(typ, value, tb)) + + +def websocket_connect( + url: Union[str, httpclient.HTTPRequest], + callback: Optional[Callable[["Future[WebSocketClientConnection]"], None]] = None, + connect_timeout: Optional[float] = None, + on_message_callback: Optional[Callable[[Union[None, str, bytes]], None]] = None, + compression_options: Optional[Dict[str, Any]] = None, + ping_interval: Optional[float] = None, + ping_timeout: Optional[float] = None, + max_message_size: int = _default_max_message_size, + subprotocols: Optional[List[str]] = None, + resolver: Optional[Resolver] = None, +) -> "Awaitable[WebSocketClientConnection]": + """Client-side websocket support. + + Takes a url and returns a Future whose result is a + `WebSocketClientConnection`. + + ``compression_options`` is interpreted in the same way as the + return value of `.WebSocketHandler.get_compression_options`. + + The connection supports two styles of operation. In the coroutine + style, the application typically calls + `~.WebSocketClientConnection.read_message` in a loop:: + + conn = yield websocket_connect(url) + while True: + msg = yield conn.read_message() + if msg is None: break + # Do something with msg + + In the callback style, pass an ``on_message_callback`` to + ``websocket_connect``. In both styles, a message of ``None`` + indicates that the connection has been closed. + + ``subprotocols`` may be a list of strings specifying proposed + subprotocols. The selected protocol may be found on the + ``selected_subprotocol`` attribute of the connection object + when the connection is complete. + + .. versionchanged:: 3.2 + Also accepts ``HTTPRequest`` objects in place of urls. + + .. versionchanged:: 4.1 + Added ``compression_options`` and ``on_message_callback``. + + .. versionchanged:: 4.5 + Added the ``ping_interval``, ``ping_timeout``, and ``max_message_size`` + arguments, which have the same meaning as in `WebSocketHandler`. + + .. versionchanged:: 5.0 + The ``io_loop`` argument (deprecated since version 4.1) has been removed. + + .. versionchanged:: 5.1 + Added the ``subprotocols`` argument. + + .. versionchanged:: 6.3 + Added the ``resolver`` argument. 
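+
+    With native coroutines, the ``yield`` loop above can be written as::
+
+        conn = await websocket_connect(url)
+        while True:
+            msg = await conn.read_message()
+            if msg is None:
+                break
+            # Do something with msg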
+ """ + if isinstance(url, httpclient.HTTPRequest): + assert connect_timeout is None + request = url + # Copy and convert the headers dict/object (see comments in + # AsyncHTTPClient.fetch) + request.headers = httputil.HTTPHeaders(request.headers) + else: + request = httpclient.HTTPRequest(url, connect_timeout=connect_timeout) + request = cast( + httpclient.HTTPRequest, + httpclient._RequestProxy(request, httpclient.HTTPRequest._DEFAULTS), + ) + conn = WebSocketClientConnection( + request, + on_message_callback=on_message_callback, + compression_options=compression_options, + ping_interval=ping_interval, + ping_timeout=ping_timeout, + max_message_size=max_message_size, + subprotocols=subprotocols, + resolver=resolver, + ) + if callback is not None: + IOLoop.current().add_future(conn.connect_future, callback) + return conn.connect_future diff --git a/venv/lib/python3.9/site-packages/tornado/wsgi.py b/venv/lib/python3.9/site-packages/tornado/wsgi.py new file mode 100644 index 00000000..32641be3 --- /dev/null +++ b/venv/lib/python3.9/site-packages/tornado/wsgi.py @@ -0,0 +1,268 @@ +# +# Copyright 2009 Facebook +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""WSGI support for the Tornado web framework. + +WSGI is the Python standard for web servers, and allows for interoperability +between Tornado and other Python web frameworks and servers. + +This module provides WSGI support via the `WSGIContainer` class, which +makes it possible to run applications using other WSGI frameworks on +the Tornado HTTP server. The reverse is not supported; the Tornado +`.Application` and `.RequestHandler` classes are designed for use with +the Tornado `.HTTPServer` and cannot be used in a generic WSGI +container. + +""" + +import concurrent.futures +from io import BytesIO +import tornado +import sys + +from tornado.concurrent import dummy_executor +from tornado import escape +from tornado import httputil +from tornado.ioloop import IOLoop +from tornado.log import access_log + +from typing import List, Tuple, Optional, Callable, Any, Dict, Text +from types import TracebackType +import typing + +if typing.TYPE_CHECKING: + from typing import Type # noqa: F401 + from _typeshed.wsgi import WSGIApplication as WSGIAppType # noqa: F401 + + +# PEP 3333 specifies that WSGI on python 3 generally deals with byte strings +# that are smuggled inside objects of type unicode (via the latin1 encoding). +# This function is like those in the tornado.escape module, but defined +# here to minimize the temptation to use it in non-wsgi contexts. +def to_wsgi_str(s: bytes) -> str: + assert isinstance(s, bytes) + return s.decode("latin1") + + +class WSGIContainer(object): + r"""Makes a WSGI-compatible application runnable on Tornado's HTTP server. + + .. warning:: + + WSGI is a *synchronous* interface, while Tornado's concurrency model + is based on single-threaded *asynchronous* execution. Many of Tornado's + distinguishing features are not available in WSGI mode, including efficient + long-polling and websockets. 
The primary purpose of `WSGIContainer` is
+    to support both WSGI applications and native Tornado ``RequestHandlers`` in
+    a single process. WSGI-only applications are likely to be better off
+    with a dedicated WSGI server such as ``gunicorn`` or ``uwsgi``.
+
+    Wrap a WSGI application in a `WSGIContainer` to make it implement the Tornado
+    `.HTTPServer` ``request_callback`` interface. The `WSGIContainer` object can
+    then be passed to classes from the `tornado.routing` module,
+    `tornado.web.FallbackHandler`, or to `.HTTPServer` directly.
+
+    This class is intended to let other frameworks (Django, Flask, etc.)
+    run on the Tornado HTTP server and I/O loop.
+
+    Realistic usage will be more complicated, but the simplest possible example uses a
+    hand-written WSGI application with `.HTTPServer`::
+
+        def simple_app(environ, start_response):
+            status = "200 OK"
+            response_headers = [("Content-type", "text/plain")]
+            start_response(status, response_headers)
+            return [b"Hello world!\n"]
+
+        async def main():
+            container = tornado.wsgi.WSGIContainer(simple_app)
+            http_server = tornado.httpserver.HTTPServer(container)
+            http_server.listen(8888)
+            await asyncio.Event().wait()
+
+        asyncio.run(main())
+
+    The recommended pattern is to use the `tornado.routing` module to set up routing
+    rules between your WSGI application and, typically, a `tornado.web.Application`.
+    Alternatively, `tornado.web.Application` can be used as the top-level router
+    and `tornado.web.FallbackHandler` can embed a `WSGIContainer` within it.
+
+    If the ``executor`` argument is provided, the WSGI application will be executed
+    on that executor. This must be an instance of `concurrent.futures.Executor`,
+    typically a ``ThreadPoolExecutor`` (``ProcessPoolExecutor`` is not supported).
+    If no ``executor`` is given, the application will run on the event loop thread in
+    Tornado 6.3; this will change to use an internal thread pool by default in
+    Tornado 7.0.
+
+    .. warning::
+       By default, the WSGI application is executed on the event loop's thread. This
+       limits the server to one request at a time (per process), making it less scalable
+       than most other WSGI servers. It is therefore highly recommended that you pass
+       a ``ThreadPoolExecutor`` when constructing the `WSGIContainer`, after verifying
+       that your application is thread-safe. The default will change to use a
+       ``ThreadPoolExecutor`` in Tornado 7.0.
+
+    .. versionadded:: 6.3
+       The ``executor`` parameter.
+
+    .. deprecated:: 6.3
+       The default behavior of running the WSGI application on the event loop thread
+       is deprecated and will change in Tornado 7.0 to use a thread pool by default.
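+
+    A sketch of the recommended setup (``max_workers`` is an arbitrary
+    choice; tune it for your application)::
+
+        from concurrent.futures import ThreadPoolExecutor
+
+        container = tornado.wsgi.WSGIContainer(
+            simple_app, executor=ThreadPoolExecutor(max_workers=8)
+        )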
+ """ + + def __init__( + self, + wsgi_application: "WSGIAppType", + executor: Optional[concurrent.futures.Executor] = None, + ) -> None: + self.wsgi_application = wsgi_application + self.executor = dummy_executor if executor is None else executor + + def __call__(self, request: httputil.HTTPServerRequest) -> None: + IOLoop.current().spawn_callback(self.handle_request, request) + + async def handle_request(self, request: httputil.HTTPServerRequest) -> None: + data = {} # type: Dict[str, Any] + response = [] # type: List[bytes] + + def start_response( + status: str, + headers: List[Tuple[str, str]], + exc_info: Optional[ + Tuple[ + "Optional[Type[BaseException]]", + Optional[BaseException], + Optional[TracebackType], + ] + ] = None, + ) -> Callable[[bytes], Any]: + data["status"] = status + data["headers"] = headers + return response.append + + loop = IOLoop.current() + app_response = await loop.run_in_executor( + self.executor, + self.wsgi_application, + self.environ(request), + start_response, + ) + try: + app_response_iter = iter(app_response) + + def next_chunk() -> Optional[bytes]: + try: + return next(app_response_iter) + except StopIteration: + # StopIteration is special and is not allowed to pass through + # coroutines normally. + return None + + while True: + chunk = await loop.run_in_executor(self.executor, next_chunk) + if chunk is None: + break + response.append(chunk) + finally: + if hasattr(app_response, "close"): + app_response.close() # type: ignore + body = b"".join(response) + if not data: + raise Exception("WSGI app did not call start_response") + + status_code_str, reason = data["status"].split(" ", 1) + status_code = int(status_code_str) + headers = data["headers"] # type: List[Tuple[str, str]] + header_set = set(k.lower() for (k, v) in headers) + body = escape.utf8(body) + if status_code != 304: + if "content-length" not in header_set: + headers.append(("Content-Length", str(len(body)))) + if "content-type" not in header_set: + headers.append(("Content-Type", "text/html; charset=UTF-8")) + if "server" not in header_set: + headers.append(("Server", "TornadoServer/%s" % tornado.version)) + + start_line = httputil.ResponseStartLine("HTTP/1.1", status_code, reason) + header_obj = httputil.HTTPHeaders() + for key, value in headers: + header_obj.add(key, value) + assert request.connection is not None + request.connection.write_headers(start_line, header_obj, chunk=body) + request.connection.finish() + self._log(status_code, request) + + def environ(self, request: httputil.HTTPServerRequest) -> Dict[Text, Any]: + """Converts a `tornado.httputil.HTTPServerRequest` to a WSGI environment. + + .. versionchanged:: 6.3 + No longer a static method. 
+ """ + hostport = request.host.split(":") + if len(hostport) == 2: + host = hostport[0] + port = int(hostport[1]) + else: + host = request.host + port = 443 if request.protocol == "https" else 80 + environ = { + "REQUEST_METHOD": request.method, + "SCRIPT_NAME": "", + "PATH_INFO": to_wsgi_str( + escape.url_unescape(request.path, encoding=None, plus=False) + ), + "QUERY_STRING": request.query, + "REMOTE_ADDR": request.remote_ip, + "SERVER_NAME": host, + "SERVER_PORT": str(port), + "SERVER_PROTOCOL": request.version, + "wsgi.version": (1, 0), + "wsgi.url_scheme": request.protocol, + "wsgi.input": BytesIO(escape.utf8(request.body)), + "wsgi.errors": sys.stderr, + "wsgi.multithread": self.executor is not dummy_executor, + "wsgi.multiprocess": True, + "wsgi.run_once": False, + } + if "Content-Type" in request.headers: + environ["CONTENT_TYPE"] = request.headers.pop("Content-Type") + if "Content-Length" in request.headers: + environ["CONTENT_LENGTH"] = request.headers.pop("Content-Length") + for key, value in request.headers.items(): + environ["HTTP_" + key.replace("-", "_").upper()] = value + return environ + + def _log(self, status_code: int, request: httputil.HTTPServerRequest) -> None: + if status_code < 400: + log_method = access_log.info + elif status_code < 500: + log_method = access_log.warning + else: + log_method = access_log.error + request_time = 1000.0 * request.request_time() + assert request.method is not None + assert request.uri is not None + summary = ( + request.method # type: ignore[operator] + + " " + + request.uri + + " (" + + request.remote_ip + + ")" + ) + log_method("%d %s %.2fms", status_code, summary, request_time) + + +HTTPRequest = httputil.HTTPServerRequest |