diff options
Diffstat (limited to '')
-rw-r--r-- | g4f/requests.py | 32 |
1 files changed, 23 insertions, 9 deletions
diff --git a/g4f/requests.py b/g4f/requests.py index 736442e3..f6f2383b 100644 --- a/g4f/requests.py +++ b/g4f/requests.py @@ -1,6 +1,6 @@ from __future__ import annotations -import json, sys +import json, sys, asyncio from functools import partialmethod from aiohttp import StreamReader @@ -8,6 +8,9 @@ from aiohttp.base_protocol import BaseProtocol from curl_cffi.requests import AsyncSession as BaseSession from curl_cffi.requests import Response +from curl_cffi import AsyncCurl + +is_newer_0_5_9 = hasattr(AsyncCurl, "remove_handle") class StreamResponse: @@ -35,7 +38,7 @@ class StreamResponse: class StreamRequest: def __init__(self, session: AsyncSession, method: str, url: str, **kwargs): self.session = session - self.loop = session.loop + self.loop = session.loop if session.loop else asyncio.get_running_loop() self.content = StreamReader( BaseProtocol(session.loop), sys.maxsize, @@ -51,10 +54,9 @@ class StreamRequest: self.content.feed_data(data) def on_done(self, task): + if not self.enter.done(): + self.enter.set_result(None) self.content.feed_eof() - self.curl.clean_after_perform() - self.curl.reset() - self.session.push_curl(self.curl) async def __aenter__(self) -> StreamResponse: self.curl = await self.session.pop_curl() @@ -66,18 +68,30 @@ class StreamRequest: content_callback=self.on_content, **self.options ) - await self.session.acurl.add_handle(self.curl, False) - self.handle = self.session.acurl._curl2future[self.curl] + if is_newer_0_5_9: + self.handle = self.session.acurl.add_handle(self.curl) + else: + await self.session.acurl.add_handle(self.curl, False) + self.handle = self.session.acurl._curl2future[self.curl] self.handle.add_done_callback(self.on_done) await self.enter + if is_newer_0_5_9: + response = self.session._parse_response(self.curl, _, header_buffer) + response.request = request + else: + response = self.session._parse_response(self.curl, request, _, header_buffer) return StreamResponse( - self.session._parse_response(self.curl, request, _, header_buffer), + response, self.content, request ) async def __aexit__(self, exc_type, exc, tb): - pass + if not self.handle.done(): + self.session.acurl.set_result(self.curl) + self.curl.clean_after_perform() + self.curl.reset() + self.session.push_curl(self.curl) class AsyncSession(BaseSession): def request( |