summaryrefslogtreecommitdiffstats
path: root/g4f
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--g4f/Provider/base_provider.py1
-rw-r--r--g4f/Provider/needs_auth/Groq.py23
-rw-r--r--g4f/Provider/needs_auth/Openai.py74
-rw-r--r--g4f/Provider/needs_auth/__init__.py4
-rw-r--r--g4f/client.py5
-rw-r--r--g4f/providers/base_provider.py37
-rw-r--r--g4f/providers/types.py6
-rw-r--r--g4f/requests/aiohttp.py16
8 files changed, 141 insertions, 25 deletions
diff --git a/g4f/Provider/base_provider.py b/g4f/Provider/base_provider.py
index 8e761dba..4c0157f3 100644
--- a/g4f/Provider/base_provider.py
+++ b/g4f/Provider/base_provider.py
@@ -1,2 +1,3 @@
from ..providers.base_provider import *
+from ..providers.types import FinishReason
from .helper import get_cookies, format_prompt \ No newline at end of file
diff --git a/g4f/Provider/needs_auth/Groq.py b/g4f/Provider/needs_auth/Groq.py
new file mode 100644
index 00000000..87e87e60
--- /dev/null
+++ b/g4f/Provider/needs_auth/Groq.py
@@ -0,0 +1,23 @@
+from __future__ import annotations
+
+from .Openai import Openai
+from ...typing import AsyncResult, Messages
+
+class Groq(Openai):
+ url = "https://console.groq.com/playground"
+ working = True
+ default_model = "mixtral-8x7b-32768"
+ models = ["mixtral-8x7b-32768", "llama2-70b-4096", "gemma-7b-it"]
+ model_aliases = {"mixtral-8x7b": "mixtral-8x7b-32768", "llama2-70b": "llama2-70b-4096"}
+
+ @classmethod
+ def create_async_generator(
+ cls,
+ model: str,
+ messages: Messages,
+ api_base: str = "https://api.groq.com/openai/v1",
+ **kwargs
+ ) -> AsyncResult:
+ return super().create_async_generator(
+ model, messages, api_base=api_base, **kwargs
+ ) \ No newline at end of file
diff --git a/g4f/Provider/needs_auth/Openai.py b/g4f/Provider/needs_auth/Openai.py
new file mode 100644
index 00000000..b876cd0b
--- /dev/null
+++ b/g4f/Provider/needs_auth/Openai.py
@@ -0,0 +1,74 @@
+from __future__ import annotations
+
+import json
+
+from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, FinishReason
+from ...typing import AsyncResult, Messages
+from ...requests.raise_for_status import raise_for_status
+from ...requests import StreamSession
+from ...errors import MissingAuthError
+
+class Openai(AsyncGeneratorProvider, ProviderModelMixin):
+ url = "https://openai.com"
+ working = True
+ needs_auth = True
+ supports_message_history = True
+ supports_system_message = True
+
+ @classmethod
+ async def create_async_generator(
+ cls,
+ model: str,
+ messages: Messages,
+ proxy: str = None,
+ timeout: int = 120,
+ api_key: str = None,
+ api_base: str = "https://api.openai.com/v1",
+ temperature: float = None,
+ max_tokens: int = None,
+ top_p: float = None,
+ stop: str = None,
+ stream: bool = False,
+ **kwargs
+ ) -> AsyncResult:
+ if api_key is None:
+ raise MissingAuthError('Add a "api_key"')
+ async with StreamSession(
+ proxies={"all": proxy},
+ headers=cls.get_headers(api_key),
+ timeout=timeout
+ ) as session:
+ data = {
+ "messages": messages,
+ "model": cls.get_model(model),
+ "temperature": temperature,
+ "max_tokens": max_tokens,
+ "top_p": top_p,
+ "stop": stop,
+ "stream": stream,
+ }
+ async with session.post(f"{api_base.rstrip('/')}/chat/completions", json=data) as response:
+ await raise_for_status(response)
+ async for line in response.iter_lines():
+ if line.startswith(b"data: ") or not stream:
+ async for chunk in cls.read_line(line[6:] if stream else line, stream):
+ yield chunk
+
+ @staticmethod
+ async def read_line(line: str, stream: bool):
+ if line == b"[DONE]":
+ return
+ choice = json.loads(line)["choices"][0]
+ if stream and "content" in choice["delta"] and choice["delta"]["content"]:
+ yield choice["delta"]["content"]
+ elif not stream and "content" in choice["message"]:
+ yield choice["message"]["content"]
+ if "finish_reason" in choice and choice["finish_reason"] is not None:
+ yield FinishReason(choice["finish_reason"])
+
+ @staticmethod
+ def get_headers(api_key: str) -> dict:
+ return {
+ "Authorization": f"Bearer {api_key}",
+ "Content-Type": "application/json",
+ } \ No newline at end of file
diff --git a/g4f/Provider/needs_auth/__init__.py b/g4f/Provider/needs_auth/__init__.py
index 5eb1b2eb..92fa165b 100644
--- a/g4f/Provider/needs_auth/__init__.py
+++ b/g4f/Provider/needs_auth/__init__.py
@@ -4,4 +4,6 @@ from .Theb import Theb
from .ThebApi import ThebApi
from .OpenaiChat import OpenaiChat
from .OpenAssistant import OpenAssistant
-from .Poe import Poe \ No newline at end of file
+from .Poe import Poe
+from .Openai import Openai
+from .Groq import Groq \ No newline at end of file
diff --git a/g4f/client.py b/g4f/client.py
index d7ceb009..2c4fe788 100644
--- a/g4f/client.py
+++ b/g4f/client.py
@@ -8,7 +8,7 @@ import string
from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse
from .typing import Union, Iterator, Messages, ImageType
-from .providers.types import BaseProvider, ProviderType
+from .providers.types import BaseProvider, ProviderType, FinishReason
from .image import ImageResponse as ImageProviderResponse
from .errors import NoImageResponseError, RateLimitError, MissingAuthError
from . import get_model_and_provider, get_last_provider
@@ -47,6 +47,9 @@ def iter_response(
finish_reason = None
completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28))
for idx, chunk in enumerate(response):
+ if isinstance(chunk, FinishReason):
+ finish_reason = chunk.reason
+ break
content += str(chunk)
if max_tokens is not None and idx + 1 >= max_tokens:
finish_reason = "length"
diff --git a/g4f/providers/base_provider.py b/g4f/providers/base_provider.py
index ee5bcbb8..37f4af15 100644
--- a/g4f/providers/base_provider.py
+++ b/g4f/providers/base_provider.py
@@ -6,9 +6,10 @@ from asyncio import AbstractEventLoop
from concurrent.futures import ThreadPoolExecutor
from abc import abstractmethod
from inspect import signature, Parameter
-from ..typing import CreateResult, AsyncResult, Messages, Union
-from .types import BaseProvider
-from ..errors import NestAsyncioError, ModelNotSupportedError
+from typing import Callable, Union
+from ..typing import CreateResult, AsyncResult, Messages
+from .types import BaseProvider, FinishReason
+from ..errors import NestAsyncioError, ModelNotSupportedError, MissingRequirementsError
from .. import debug
if sys.version_info < (3, 10):
@@ -21,17 +22,23 @@ if sys.platform == 'win32':
if isinstance(asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy):
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
-def get_running_loop() -> Union[AbstractEventLoop, None]:
+def get_running_loop(check_nested: bool) -> Union[AbstractEventLoop, None]:
try:
loop = asyncio.get_running_loop()
- if not hasattr(loop.__class__, "_nest_patched"):
- raise NestAsyncioError(
- 'Use "create_async" instead of "create" function in a running event loop. Or use "nest_asyncio" package.'
- )
+ if check_nested and not hasattr(loop.__class__, "_nest_patched"):
+ try:
+ import nest_asyncio
+ nest_asyncio.apply(loop)
+ except ImportError:
+ raise MissingRequirementsError('Install "nest_asyncio" package')
return loop
except RuntimeError:
pass
+# Fix for RuntimeError: async generator ignored GeneratorExit
+async def await_callback(callback: Callable):
+ return await callback()
+
class AbstractProvider(BaseProvider):
"""
Abstract class for providing asynchronous functionality to derived classes.
@@ -132,7 +139,7 @@ class AsyncProvider(AbstractProvider):
Returns:
CreateResult: The result of the completion creation.
"""
- get_running_loop()
+ get_running_loop(check_nested=True)
yield asyncio.run(cls.create_async(model, messages, **kwargs))
@staticmethod
@@ -158,7 +165,6 @@ class AsyncProvider(AbstractProvider):
"""
raise NotImplementedError()
-
class AsyncGeneratorProvider(AsyncProvider):
"""
Provides asynchronous generator functionality for streaming results.
@@ -187,9 +193,9 @@ class AsyncGeneratorProvider(AsyncProvider):
Returns:
CreateResult: The result of the streaming completion creation.
"""
- loop = get_running_loop()
+ loop = get_running_loop(check_nested=True)
new_loop = False
- if not loop:
+ if loop is None:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
new_loop = True
@@ -197,16 +203,11 @@ class AsyncGeneratorProvider(AsyncProvider):
generator = cls.create_async_generator(model, messages, stream=stream, **kwargs)
gen = generator.__aiter__()
- # Fix for RuntimeError: async generator ignored GeneratorExit
- async def await_callback(callback):
- return await callback()
-
try:
while True:
yield loop.run_until_complete(await_callback(gen.__anext__))
except StopAsyncIteration:
...
- # Fix for: ResourceWarning: unclosed event loop
finally:
if new_loop:
loop.close()
@@ -233,7 +234,7 @@ class AsyncGeneratorProvider(AsyncProvider):
"""
return "".join([
chunk async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs)
- if not isinstance(chunk, Exception)
+ if not isinstance(chunk, (Exception, FinishReason))
])
@staticmethod
diff --git a/g4f/providers/types.py b/g4f/providers/types.py
index 67340958..a3eeb99e 100644
--- a/g4f/providers/types.py
+++ b/g4f/providers/types.py
@@ -97,4 +97,8 @@ class BaseRetryProvider(BaseProvider):
__name__: str = "RetryProvider"
supports_stream: bool = True
-ProviderType = Union[Type[BaseProvider], BaseRetryProvider] \ No newline at end of file
+ProviderType = Union[Type[BaseProvider], BaseRetryProvider]
+
+class FinishReason():
+ def __init__(self, reason: str):
+ self.reason = reason \ No newline at end of file
diff --git a/g4f/requests/aiohttp.py b/g4f/requests/aiohttp.py
index 16b052eb..71e7bde7 100644
--- a/g4f/requests/aiohttp.py
+++ b/g4f/requests/aiohttp.py
@@ -15,11 +15,19 @@ class StreamResponse(ClientResponse):
async for chunk in self.content.iter_any():
yield chunk
- async def json(self) -> Any:
- return await super().json(content_type=None)
+ async def json(self, content_type: str = None) -> Any:
+ return await super().json(content_type=content_type)
class StreamSession(ClientSession):
- def __init__(self, headers: dict = {}, timeout: int = None, proxies: dict = {}, impersonate = None, **kwargs):
+ def __init__(
+ self,
+ headers: dict = {},
+ timeout: int = None,
+ connector: BaseConnector = None,
+ proxies: dict = {},
+ impersonate = None,
+ **kwargs
+ ):
if impersonate:
headers = {
**DEFAULT_HEADERS,
@@ -29,7 +37,7 @@ class StreamSession(ClientSession):
**kwargs,
timeout=ClientTimeout(timeout) if timeout else None,
response_class=StreamResponse,
- connector=get_connector(kwargs.get("connector"), proxies.get("https")),
+ connector=get_connector(connector, proxies.get("all", proxies.get("https"))),
headers=headers
)