diff options
-rw-r--r-- | gpt4free/quora/api.py | 35 | ||||
-rw-r--r-- | gpt4free/quora/tests/__init__.py | 0 | ||||
-rw-r--r-- | gpt4free/quora/tests/test_api.py | 38 |
3 files changed, 59 insertions, 14 deletions
diff --git a/gpt4free/quora/api.py b/gpt4free/quora/api.py index d388baee..9e3c0b91 100644 --- a/gpt4free/quora/api.py +++ b/gpt4free/quora/api.py @@ -56,18 +56,25 @@ def generate_payload(query_name, variables): return {"query": queries[query_name], "variables": variables} -def request_with_retries(method, *args, **kwargs): - attempts = kwargs.get("attempts") or 10 +def retry_request(method, *args, **kwargs): + """Retry a request with 10 attempts by default, delay increases exponentially""" + max_attempts: int = kwargs.pop("max_attempts", 10) + delay = kwargs.pop("delay", 1) url = args[0] - for i in range(attempts): - r = method(*args, **kwargs) - if r.status_code == 200: - return r - logger.warn( - f"Server returned a status code of {r.status_code} while downloading {url}. Retrying ({i + 1}/{attempts})..." - ) - raise RuntimeError(f"Failed to download {url} too many times.") + for attempt in range(1, max_attempts + 1): + try: + response = method(*args, **kwargs) + response.raise_for_status() + return response + except Exception as error: + logger.warning( + f"Attempt {attempt}/{max_attempts} failed with error: {error}. " + f"Retrying in {delay} seconds..." + ) + time.sleep(delay) + delay *= 2 + raise RuntimeError(f"Failed to download {url} after {max_attempts} attempts.") class Client: @@ -134,7 +141,7 @@ class Client: def get_next_data(self, overwrite_vars=False): logger.info("Downloading next_data...") - r = request_with_retries(self.session.get, self.home_url) + r = retry_request(self.session.get, self.home_url) json_regex = r'<script id="__NEXT_DATA__" type="application\/json">(.+?)</script>' json_text = re.search(json_regex, r.text).group(1) next_data = json.loads(json_text) @@ -149,7 +156,7 @@ class Client: def get_bot(self, display_name): url = f'https://poe.com/_next/data/{self.next_data["buildId"]}/{display_name}.json' - r = request_with_retries(self.session.get, url) + r = retry_request(self.session.get, url) chat_data = r.json()["pageProps"]["payload"]["chatOfBotDisplayName"] return chat_data @@ -198,7 +205,7 @@ class Client: def get_channel_data(self, channel=None): logger.info("Downloading channel data...") - r = request_with_retries(self.session.get, self.settings_url) + r = retry_request(self.session.get, self.settings_url) data = r.json() return data["tchannelData"] @@ -222,7 +229,7 @@ class Client: } headers = {**self.gql_headers, **headers} - r = request_with_retries(self.session.post, self.gql_url, data=payload, headers=headers) + r = retry_request(self.session.post, self.gql_url, data=payload, headers=headers) data = r.json() if data["data"] is None: diff --git a/gpt4free/quora/tests/__init__.py b/gpt4free/quora/tests/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/gpt4free/quora/tests/__init__.py diff --git a/gpt4free/quora/tests/test_api.py b/gpt4free/quora/tests/test_api.py new file mode 100644 index 00000000..2a4bb41b --- /dev/null +++ b/gpt4free/quora/tests/test_api.py @@ -0,0 +1,38 @@ +import unittest +import requests +from unittest.mock import MagicMock +from gpt4free.quora.api import retry_request + + +class TestRetryRequest(unittest.TestCase): + def test_successful_request(self): + # Mock a successful request with a 200 status code + mock_response = MagicMock() + mock_response.status_code = 200 + requests.get = MagicMock(return_value=mock_response) + + # Call the function and assert that it returns the response + response = retry_request(requests.get, "http://example.com", max_attempts=3) + self.assertEqual(response.status_code, 200) + + def test_exponential_backoff(self): + # Mock a failed request that succeeds after two retries + mock_response = MagicMock() + mock_response.status_code = 200 + requests.get = MagicMock(side_effect=[requests.exceptions.RequestException] * 2 + [mock_response]) + + # Call the function and assert that it retries with exponential backoff + with self.assertLogs() as logs: + response = retry_request(requests.get, "http://example.com", max_attempts=3, delay=1) + self.assertEqual(response.status_code, 200) + self.assertGreaterEqual(len(logs.output), 2) + self.assertIn("Retrying in 1 seconds...", logs.output[0]) + self.assertIn("Retrying in 2 seconds...", logs.output[1]) + + def test_too_many_attempts(self): + # Mock a failed request that never succeeds + requests.get = MagicMock(side_effect=requests.exceptions.RequestException) + + # Call the function and assert that it raises an exception after the maximum number of attempts + with self.assertRaises(RuntimeError): + retry_request(requests.get, "http://example.com", max_attempts=3) |