From 126496d3cacd06a4fa8cbb4e5bde417ce6bb5b4a Mon Sep 17 00:00:00 2001 From: Heiner Lohaus Date: Fri, 25 Aug 2023 06:41:32 +0200 Subject: Add OpenaiChat and Hugchat Provider Add tests for providers with auth Improve async support / 2x faster Shared get_cookies by domain function --- g4f/Provider/Hugchat.py | 67 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 g4f/Provider/Hugchat.py (limited to 'g4f/Provider/Hugchat.py') diff --git a/g4f/Provider/Hugchat.py b/g4f/Provider/Hugchat.py new file mode 100644 index 00000000..cedf8402 --- /dev/null +++ b/g4f/Provider/Hugchat.py @@ -0,0 +1,67 @@ +has_module = False +try: + from hugchat.hugchat import ChatBot +except ImportError: + has_module = False + +from .base_provider import BaseProvider, get_cookies +from g4f.typing import CreateResult + +class Hugchat(BaseProvider): + url = "https://huggingface.co/chat/" + needs_auth = True + working = has_module + llms = ['OpenAssistant/oasst-sft-6-llama-30b-xor', 'meta-llama/Llama-2-70b-chat-hf'] + + @classmethod + def create_completion( + cls, + model: str, + messages: list[dict[str, str]], + stream: bool = False, + proxy: str = None, + cookies: str = get_cookies(".huggingface.co"), + **kwargs + ) -> CreateResult: + bot = ChatBot( + cookies=cookies + ) + + if proxy and "://" not in proxy: + proxy = f"http://{proxy}" + bot.session.proxies = {"http": proxy, "https": proxy} + + if model: + try: + if not isinstance(model, int): + model = cls.llms.index(model) + bot.switch_llm(model) + except: + raise RuntimeError(f"Model are not supported: {model}") + + if len(messages) > 1: + formatted = "\n".join( + ["%s: %s" % (message["role"], message["content"]) for message in messages] + ) + prompt = f"{formatted}\nAssistant:" + else: + prompt = messages.pop()["content"] + + try: + yield bot.chat(prompt, **kwargs) + finally: + bot.delete_conversation(bot.current_conversation) + bot.current_conversation = "" + pass + + @classmethod + @property + def params(cls): + params = [ + ("model", "str"), + ("messages", "list[dict[str, str]]"), + ("stream", "bool"), + ("proxy", "str"), + ] + param = ", ".join([": ".join(p) for p in params]) + return f"g4f.provider.{cls.__name__} supports: ({param})" -- cgit v1.2.3