summaryrefslogtreecommitdiffstats
path: root/g4f/client
diff options
context:
space:
mode:
Diffstat (limited to 'g4f/client')
-rw-r--r--g4f/client/client.py17
1 files changed, 9 insertions, 8 deletions
diff --git a/g4f/client/client.py b/g4f/client/client.py
index 8e195213..07db107a 100644
--- a/g4f/client/client.py
+++ b/g4f/client/client.py
@@ -184,8 +184,12 @@ class Completions:
ignore_stream: bool = False,
**kwargs
) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
- model, provider = get_model_and_provider(
- model,
+ # We use ModelUtils to obtain the model object.
+ model_instance = ModelUtils.get_model(model)
+
+ # We receive the model and the provider.
+ model_name, provider = get_model_and_provider(
+ model_instance.name, # We use the model name from the object.
self.provider if provider is None else provider,
stream,
ignored,
@@ -196,9 +200,8 @@ class Completions:
stop = [stop] if isinstance(stop, str) else stop
if asyncio.iscoroutinefunction(provider.create_completion):
- # Run the asynchronous function in an event loop
response = asyncio.run(provider.create_completion(
- model,
+ model_name, # We use a model based on the object.
messages,
stream=stream,
**filter_none(
@@ -211,7 +214,7 @@ class Completions:
))
else:
response = provider.create_completion(
- model,
+ model_name, # We use a model from the object.
messages,
stream=stream,
**filter_none(
@@ -225,21 +228,19 @@ class Completions:
if stream:
if hasattr(response, '__aiter__'):
- # It's an async generator, wrap it into a sync iterator
response = to_sync_iter(response)
- # Now 'response' is an iterator
response = iter_response(response, stream, response_format, max_tokens, stop)
response = iter_append_model_and_provider(response)
return response
else:
if hasattr(response, '__aiter__'):
- # If response is an async generator, collect it into a list
response = list(to_sync_iter(response))
response = iter_response(response, stream, response_format, max_tokens, stop)
response = iter_append_model_and_provider(response)
return next(response)
+
async def async_create(
self,
messages: Messages,