From 08085d345b919f99deb0ea1b8338c868002b7334 Mon Sep 17 00:00:00 2001 From: Heiner Lohaus Date: Tue, 19 Nov 2024 10:59:49 +0100 Subject: Fix loading models in Airforce provider --- g4f/Provider/airforce/AirforceChat.py | 11 ++++++++--- g4f/Provider/airforce/AirforceImage.py | 32 +++++++++++++++----------------- 2 files changed, 23 insertions(+), 20 deletions(-) (limited to 'g4f') diff --git a/g4f/Provider/airforce/AirforceChat.py b/g4f/Provider/airforce/AirforceChat.py index e94dd0a8..1efe0026 100644 --- a/g4f/Provider/airforce/AirforceChat.py +++ b/g4f/Provider/airforce/AirforceChat.py @@ -4,6 +4,7 @@ import json import requests from aiohttp import ClientSession from typing import List +import logging from ...typing import AsyncResult, Messages from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin @@ -54,9 +55,13 @@ class AirforceChat(AsyncGeneratorProvider, ProviderModelMixin): @classmethod def get_models(cls) -> list: if not cls.models: - response = requests.get('https://api.airforce/models') - data = response.json() - cls.models = [model['id'] for model in data['data']] + try: + response = requests.get('https://api.airforce/models', verify=False) + data = response.json() + cls.models = [model['id'] for model in data['data']] + except Exception as e: + logging.exception(e) + cls.models = [cls.default_model] model_aliases = { # openchat diff --git a/g4f/Provider/airforce/AirforceImage.py b/g4f/Provider/airforce/AirforceImage.py index b74bc364..a5bd113f 100644 --- a/g4f/Provider/airforce/AirforceImage.py +++ b/g4f/Provider/airforce/AirforceImage.py @@ -4,39 +4,37 @@ from aiohttp import ClientSession from urllib.parse import urlencode import random import requests +import logging from ...typing import AsyncResult, Messages from ...image import ImageResponse from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin - class AirforceImage(AsyncGeneratorProvider, ProviderModelMixin): label = "Airforce Image" - #url = "https://api.airforce" + url = "https://api.airforce" api_endpoint = "https://api.airforce/imagine2" - #working = True + working = False default_model = 'flux' - - response = requests.get('https://api.airforce/imagine/models') - data = response.json() - - image_models = data - - models = [*image_models, "stable-diffusion-xl-base", "stable-diffusion-xl-lightning", "Flux-1.1-Pro"] - + additional_models = ["stable-diffusion-xl-base", "stable-diffusion-xl-lightning", "Flux-1.1-Pro"] model_aliases = { "sdxl": "stable-diffusion-xl-base", "sdxl": "stable-diffusion-xl-lightning", "flux-pro": "Flux-1.1-Pro", } - + @classmethod - def get_model(cls, model: str) -> str: - if model in cls.models: - return model - else: - return cls.default_model + def get_models(cls) -> list: + if not cls.models: + try: + response = requests.get('https://api.airforce/imagine/models', verify=False) + response.raise_for_status() + cls.models = [*response.json(), *cls.additional_models] + except Exception as e: + logging.exception(e) + cls.models = [cls.default_model] + return cls.models @classmethod async def create_async_generator( -- cgit v1.2.3