summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorkqlio67 <kqlio67@users.noreply.github.com>2024-10-30 15:25:55 +0100
committerkqlio67 <kqlio67@users.noreply.github.com>2024-10-30 15:25:55 +0100
commite6627d8d30fe7dfcf2a111b444f2abb5c4ead1ac (patch)
tree5a06635657ff1f8d2ab2156898ff0bb1fff025f4
parentfeat(g4f/models.py): add versioning support for model retrieval (diff)
downloadgpt4free-e6627d8d30fe7dfcf2a111b444f2abb5c4ead1ac.tar
gpt4free-e6627d8d30fe7dfcf2a111b444f2abb5c4ead1ac.tar.gz
gpt4free-e6627d8d30fe7dfcf2a111b444f2abb5c4ead1ac.tar.bz2
gpt4free-e6627d8d30fe7dfcf2a111b444f2abb5c4ead1ac.tar.lz
gpt4free-e6627d8d30fe7dfcf2a111b444f2abb5c4ead1ac.tar.xz
gpt4free-e6627d8d30fe7dfcf2a111b444f2abb5c4ead1ac.tar.zst
gpt4free-e6627d8d30fe7dfcf2a111b444f2abb5c4ead1ac.zip
-rw-r--r--g4f/client/client.py17
-rw-r--r--g4f/models.py42
2 files changed, 8 insertions, 51 deletions
diff --git a/g4f/client/client.py b/g4f/client/client.py
index 07db107a..8e195213 100644
--- a/g4f/client/client.py
+++ b/g4f/client/client.py
@@ -184,12 +184,8 @@ class Completions:
ignore_stream: bool = False,
**kwargs
) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
- # We use ModelUtils to obtain the model object.
- model_instance = ModelUtils.get_model(model)
-
- # We receive the model and the provider.
- model_name, provider = get_model_and_provider(
- model_instance.name, # We use the model name from the object.
+ model, provider = get_model_and_provider(
+ model,
self.provider if provider is None else provider,
stream,
ignored,
@@ -200,8 +196,9 @@ class Completions:
stop = [stop] if isinstance(stop, str) else stop
if asyncio.iscoroutinefunction(provider.create_completion):
+ # Run the asynchronous function in an event loop
response = asyncio.run(provider.create_completion(
- model_name, # We use a model based on the object.
+ model,
messages,
stream=stream,
**filter_none(
@@ -214,7 +211,7 @@ class Completions:
))
else:
response = provider.create_completion(
- model_name, # We use a model from the object.
+ model,
messages,
stream=stream,
**filter_none(
@@ -228,19 +225,21 @@ class Completions:
if stream:
if hasattr(response, '__aiter__'):
+ # It's an async generator, wrap it into a sync iterator
response = to_sync_iter(response)
+ # Now 'response' is an iterator
response = iter_response(response, stream, response_format, max_tokens, stop)
response = iter_append_model_and_provider(response)
return response
else:
if hasattr(response, '__aiter__'):
+ # If response is an async generator, collect it into a list
response = list(to_sync_iter(response))
response = iter_response(response, stream, response_format, max_tokens, stop)
response = iter_append_model_and_provider(response)
return next(response)
-
async def async_create(
self,
messages: Messages,
diff --git a/g4f/models.py b/g4f/models.py
index 2378079b..bea09f28 100644
--- a/g4f/models.py
+++ b/g4f/models.py
@@ -891,17 +891,6 @@ any_dark = Model(
)
-
-class ModelVersions:
- # Global Prefixes for All Models
- GLOBAL_PREFIXES = [":latest"]
-
- # Specific Prefixes for Particular Models
- MODEL_SPECIFIC_PREFIXES = {
- #frozenset(["gpt-3.5-turbo", "gpt-4"]): [":custom1", ":custom2"]
- #frozenset(["gpt-3.5-turbo"]): [":custom"],
- }
-
class ModelUtils:
"""
Utility class for mapping string identifiers to Model instances.
@@ -1174,35 +1163,4 @@ class ModelUtils:
'any-dark': any_dark,
}
- @classmethod
- def get_model(cls, model_name: str) -> Model:
- # Checking for specific prefixes
- for model_set, specific_prefixes in ModelVersions.MODEL_SPECIFIC_PREFIXES.items():
- for prefix in specific_prefixes:
- if model_name.endswith(prefix):
- base_name = model_name[:-len(prefix)]
- if base_name in model_set:
- return cls.convert.get(base_name, None)
-
- # Check for global prefixes
- for prefix in ModelVersions.GLOBAL_PREFIXES:
- if model_name.endswith(prefix):
- base_name = model_name[:-len(prefix)]
- return cls.convert.get(base_name, None)
-
- # Check without prefix
- if model_name in cls.convert:
- return cls.convert[model_name]
-
- raise KeyError(f"Model {model_name} not found")
-
- @classmethod
- def get_available_versions(cls, model_name: str) -> list[str]:
- # Obtaining prefixes for a specific model
- prefixes = ModelVersions.GLOBAL_PREFIXES.copy()
- for model_set, specific_prefixes in ModelVersions.MODEL_SPECIFIC_PREFIXES.items():
- if model_name in model_set:
- prefixes.extend(specific_prefixes)
- return prefixes
-
_all_models = list(ModelUtils.convert.keys())