diff options
Diffstat (limited to 'g4f/Provider/H2o.py')
-rw-r--r-- | g4f/Provider/H2o.py | 76 |
1 files changed, 33 insertions, 43 deletions
diff --git a/g4f/Provider/H2o.py b/g4f/Provider/H2o.py index 305a0bbf..ea4d0825 100644 --- a/g4f/Provider/H2o.py +++ b/g4f/Provider/H2o.py @@ -1,25 +1,21 @@ -import json -import uuid +import json, uuid, requests -import requests - -from ..typing import Any, CreateResult +from ..typing import Any, CreateResult from .base_provider import BaseProvider class H2o(BaseProvider): - url = "https://gpt-gm.h2o.ai" - working = True + url = "https://gpt-gm.h2o.ai" + working = True supports_stream = True - model = "h2oai/h2ogpt-gm-oasst1-en-2048-falcon-40b-v1" + model = "h2oai/h2ogpt-gm-oasst1-en-2048-falcon-40b-v1" @staticmethod def create_completion( model: str, messages: list[dict[str, str]], - stream: bool, - **kwargs: Any, - ) -> CreateResult: + stream: bool, **kwargs: Any) -> CreateResult: + conversation = "" for message in messages: conversation += "%s: %s\n" % (message["role"], message["content"]) @@ -29,58 +25,52 @@ class H2o(BaseProvider): headers = {"Referer": "https://gpt-gm.h2o.ai/r/jGfKSwU"} data = { - "ethicsModalAccepted": "true", + "ethicsModalAccepted" : "true", "shareConversationsWithModelAuthors": "true", - "ethicsModalAcceptedAt": "", - "activeModel": model, - "searchEnabled": "true", + "ethicsModalAcceptedAt" : "", + "activeModel" : model, + "searchEnabled" : "true", } - session.post( - "https://gpt-gm.h2o.ai/settings", - headers=headers, - data=data, - ) + + session.post("https://gpt-gm.h2o.ai/settings", + headers=headers, data=data) headers = {"Referer": "https://gpt-gm.h2o.ai/"} - data = {"model": model} + data = {"model": model} - response = session.post( - "https://gpt-gm.h2o.ai/conversation", - headers=headers, - json=data, - ).json() + response = session.post("https://gpt-gm.h2o.ai/conversation", + headers=headers, json=data).json() + if "conversationId" not in response: return data = { "inputs": conversation, "parameters": { - "temperature": kwargs.get("temperature", 0.4), - "truncate": kwargs.get("truncate", 2048), - "max_new_tokens": kwargs.get("max_new_tokens", 1024), - "do_sample": kwargs.get("do_sample", True), + "temperature" : kwargs.get("temperature", 0.4), + "truncate" : kwargs.get("truncate", 2048), + "max_new_tokens" : kwargs.get("max_new_tokens", 1024), + "do_sample" : kwargs.get("do_sample", True), "repetition_penalty": kwargs.get("repetition_penalty", 1.2), - "return_full_text": kwargs.get("return_full_text", False), + "return_full_text" : kwargs.get("return_full_text", False), }, - "stream": True, + "stream" : True, "options": { - "id": kwargs.get("id", str(uuid.uuid4())), - "response_id": kwargs.get("response_id", str(uuid.uuid4())), - "is_retry": False, - "use_cache": False, + "id" : kwargs.get("id", str(uuid.uuid4())), + "response_id" : kwargs.get("response_id", str(uuid.uuid4())), + "is_retry" : False, + "use_cache" : False, "web_search_id": "", }, } - response = session.post( - f"https://gpt-gm.h2o.ai/conversation/{response['conversationId']}", - headers=headers, - json=data, - ) + response = session.post(f"https://gpt-gm.h2o.ai/conversation/{response['conversationId']}", + headers=headers, json=data) + response.raise_for_status() response.encoding = "utf-8" - generated_text = response.text.replace("\n", "").split("data:") - generated_text = json.loads(generated_text[-1]) + generated_text = response.text.replace("\n", "").split("data:") + generated_text = json.loads(generated_text[-1]) yield generated_text["generated_text"] |