summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorHeiner Lohaus <hlohaus@users.noreply.github.com>2024-04-13 18:02:47 +0200
committerHeiner Lohaus <hlohaus@users.noreply.github.com>2024-04-13 18:02:47 +0200
commit0a3fe0b7df7b2779ae795c8590d2a707d6672345 (patch)
treeec033e0233d18846cee37196e2a2615377af7c49
parentFix GPT4All import error (diff)
downloadgpt4free-0a3fe0b7df7b2779ae795c8590d2a707d6672345.tar
gpt4free-0a3fe0b7df7b2779ae795c8590d2a707d6672345.tar.gz
gpt4free-0a3fe0b7df7b2779ae795c8590d2a707d6672345.tar.bz2
gpt4free-0a3fe0b7df7b2779ae795c8590d2a707d6672345.tar.lz
gpt4free-0a3fe0b7df7b2779ae795c8590d2a707d6672345.tar.xz
gpt4free-0a3fe0b7df7b2779ae795c8590d2a707d6672345.tar.zst
gpt4free-0a3fe0b7df7b2779ae795c8590d2a707d6672345.zip
-rw-r--r--g4f/locals/provider.py7
-rw-r--r--g4f/providers/base_provider.py9
2 files changed, 12 insertions, 4 deletions
diff --git a/g4f/locals/provider.py b/g4f/locals/provider.py
index 45041539..d9d73455 100644
--- a/g4f/locals/provider.py
+++ b/g4f/locals/provider.py
@@ -66,9 +66,12 @@ class LocalProvider:
if message["role"] != "system"
) + "\nASSISTANT: "
+ def should_not_stop(token_id: int, token: str):
+ return "USER" not in token
+
with model.chat_session(system_message, prompt_template):
if stream:
- for token in model.generate(conversation, streaming=True):
+ for token in model.generate(conversation, streaming=True, callback=should_not_stop):
yield token
else:
- yield model.generate(conversation) \ No newline at end of file
+ yield model.generate(conversation, callback=should_not_stop) \ No newline at end of file
diff --git a/g4f/providers/base_provider.py b/g4f/providers/base_provider.py
index 86789ec2..cb60d78f 100644
--- a/g4f/providers/base_provider.py
+++ b/g4f/providers/base_provider.py
@@ -19,8 +19,13 @@ else:
# Set Windows event loop policy for better compatibility with asyncio and curl_cffi
if sys.platform == 'win32':
- if isinstance(asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy):
- asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
+ try:
+ from curl_cffi import aio
+ if not hasattr(aio, "_get_selector"):
+ if isinstance(asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy):
+ asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
+ except ImportError:
+ pass
def get_running_loop(check_nested: bool) -> Union[AbstractEventLoop, None]:
try: