From b11cf3ab4babb0493856c194a542b7b70d1a7728 Mon Sep 17 00:00:00 2001 From: kqlio67 Date: Wed, 30 Oct 2024 14:09:16 +0200 Subject: feat(g4f/client/client.py): integrate ModelUtils for model retrieval --- g4f/client/client.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) (limited to 'g4f/client') diff --git a/g4f/client/client.py b/g4f/client/client.py index 8e195213..07db107a 100644 --- a/g4f/client/client.py +++ b/g4f/client/client.py @@ -184,8 +184,12 @@ class Completions: ignore_stream: bool = False, **kwargs ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]: - model, provider = get_model_and_provider( - model, + # We use ModelUtils to obtain the model object. + model_instance = ModelUtils.get_model(model) + + # We receive the model and the provider. + model_name, provider = get_model_and_provider( + model_instance.name, # We use the model name from the object. self.provider if provider is None else provider, stream, ignored, @@ -196,9 +200,8 @@ class Completions: stop = [stop] if isinstance(stop, str) else stop if asyncio.iscoroutinefunction(provider.create_completion): - # Run the asynchronous function in an event loop response = asyncio.run(provider.create_completion( - model, + model_name, # We use a model based on the object. messages, stream=stream, **filter_none( @@ -211,7 +214,7 @@ class Completions: )) else: response = provider.create_completion( - model, + model_name, # We use a model from the object. messages, stream=stream, **filter_none( @@ -225,21 +228,19 @@ class Completions: if stream: if hasattr(response, '__aiter__'): - # It's an async generator, wrap it into a sync iterator response = to_sync_iter(response) - # Now 'response' is an iterator response = iter_response(response, stream, response_format, max_tokens, stop) response = iter_append_model_and_provider(response) return response else: if hasattr(response, '__aiter__'): - # If response is an async generator, collect it into a list response = list(to_sync_iter(response)) response = iter_response(response, stream, response_format, max_tokens, stop) response = iter_append_model_and_provider(response) return next(response) + async def async_create( self, messages: Messages, -- cgit v1.2.3