# Page-scrape header (not part of the module): file size 3,963 Bytes, id 58e1001.
from __future__ import annotations
import asyncio
from asyncio import SelectorEventLoop
from abc import ABC, abstractmethod
import browser_cookie3
from ..typing import AsyncGenerator, CreateResult
class _ClassProperty:
    """Read-only descriptor whose value is computed from the owning class.

    Stands in for stacking ``@classmethod`` on top of ``@property``; that
    combination was deprecated in Python 3.11 and removed in Python 3.13.
    """

    def __init__(self, fget):
        self.fget = fget
        self.__doc__ = fget.__doc__

    def __get__(self, instance, owner=None):
        # Works for both ``Class.attr`` and ``instance.attr`` access.
        if owner is None:
            owner = type(instance)
        return self.fget(owner)


class BaseProvider(ABC):
    """Abstract base class for providers.

    Declares the ``create_completion`` entry point that every provider must
    implement, plus a set of class-level flags.
    """

    url: str
    # Class-level flags; every one defaults to False on this base class.
    # NOTE(review): meanings are inferred from the names only -- confirm
    # against the concrete providers that override them.
    working = False
    needs_auth = False
    supports_stream = False
    supports_gpt_35_turbo = False
    supports_gpt_4 = False

    @staticmethod
    @abstractmethod
    def create_completion(
        model: str,
        messages: list[dict[str, str]],
        stream: bool,
        **kwargs
    ) -> CreateResult:
        """Produce a completion for ``messages``; must be overridden.

        Args:
            model: Name of the model to use.
            messages: Conversation as a list of string-to-string dicts.
            stream: Streaming flag passed by the caller.
            **kwargs: Additional provider-specific options.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError()

    @_ClassProperty
    def params(cls):
        """Return a one-line summary of the supported parameters.

        The parameter list is hard-coded here rather than introspected from
        ``create_completion``. Accessible on the class itself, e.g.
        ``SomeProvider.params``.
        """
        params = [
            ("model", "str"),
            ("messages", "list[dict[str, str]]"),
            ("stream", "bool"),
        ]
        param = ", ".join(": ".join(p) for p in params)
        return f"g4f.provider.{cls.__name__} supports: ({param})"
class AsyncProvider(BaseProvider):
    """Provider base whose whole answer comes from one awaitable call.

    Subclasses implement ``create_async``; ``create_completion`` exposes that
    coroutine through the synchronous generator interface.
    """

    @classmethod
    def create_completion(
        cls,
        model: str,
        messages: list[dict[str, str]],
        stream: bool = False,
        **kwargs
    ) -> CreateResult:
        """Yield the result of ``create_async`` as a single item.

        This is a generator, so no work happens until it is iterated.
        ``stream`` is accepted but not forwarded to ``create_async``.

        Args:
            model: Name of the model to use.
            messages: Conversation as a list of string-to-string dicts.
            stream: Unused by this implementation.
            **kwargs: Forwarded unchanged to ``create_async``.

        Yields:
            The value returned by ``create_async``.
        """
        event_loop = create_event_loop()
        try:
            pending = cls.create_async(model, messages, **kwargs)
            answer = event_loop.run_until_complete(pending)
            yield answer
        finally:
            # The private loop is closed once the generator finishes or is closed.
            event_loop.close()

    @staticmethod
    @abstractmethod
    async def create_async(
        model: str,
        messages: list[dict[str, str]],
        **kwargs
    ) -> str:
        """Return the full completion text; must be overridden.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError()
class AsyncGeneratorProvider(AsyncProvider):
    """Provider base whose answer is produced by an async generator.

    Subclasses implement ``create_async_generator``. This class derives both
    the synchronous ``create_completion`` generator and the awaitable
    ``create_async`` from it.
    """

    supports_stream = True

    @classmethod
    def create_completion(
        cls,
        model: str,
        messages: list[dict[str, str]],
        stream: bool = True,
        **kwargs
    ) -> CreateResult:
        """Yield each item of ``create_async_generator`` synchronously.

        Every item is obtained by driving one ``__anext__`` step to
        completion on a private event loop, which is closed when this
        generator finishes or is closed.

        Args:
            model: Name of the model to use.
            messages: Conversation as a list of string-to-string dicts.
            stream: Forwarded to ``create_async_generator`` as ``stream``.
            **kwargs: Forwarded unchanged to ``create_async_generator``.
        """
        event_loop = create_event_loop()
        try:
            chunk_iter = cls.create_async_generator(
                model,
                messages,
                stream=stream,
                **kwargs
            ).__aiter__()
            try:
                while True:
                    yield event_loop.run_until_complete(chunk_iter.__anext__())
            except StopAsyncIteration:
                # Normal exhaustion of the async generator ends the stream.
                pass
        finally:
            event_loop.close()

    @classmethod
    async def create_async(
        cls,
        model: str,
        messages: list[dict[str, str]],
        **kwargs
    ) -> str:
        """Collect every generated chunk and return them joined together.

        ``create_async_generator`` is always called with ``stream=False``.
        """
        pieces = []
        async for piece in cls.create_async_generator(
            model,
            messages,
            stream=False,
            **kwargs
        ):
            pieces.append(piece)
        return "".join(pieces)

    @staticmethod
    @abstractmethod
    def create_async_generator(
        model: str,
        messages: list[dict[str, str]],
        **kwargs
    ) -> AsyncGenerator:
        """Return an async generator of response chunks; must be overridden.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError()
def create_event_loop() -> SelectorEventLoop:
    """Return a fresh ``SelectorEventLoop`` for synchronous callers.

    A new loop must not be created while an async loop is already running,
    so that case is rejected. The selector loop is constructed explicitly,
    which forces it on Windows; Linux uses it anyway.

    Raises:
        RuntimeError: If called while an event loop is running.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # get_running_loop() raises when no loop is running in this thread.
        inside_running_loop = False
    else:
        inside_running_loop = True

    if inside_running_loop:
        raise RuntimeError(
            'Use "create_async" instead of "create" function in a async loop.')
    return SelectorEventLoop()
# Module-level cache: cookie domain -> {cookie name: cookie value}.
_cookies: dict[str, dict[str, str]] = {}


def get_cookies(cookie_domain: str) -> dict:
    """Return the cookies for ``cookie_domain``, loading them at most once.

    On the first call for a domain the cookies are read through
    ``browser_cookie3.load`` and stored in the module-level cache. Later
    calls return the same cached dict without reloading.

    Loading is best effort: if it fails, whatever was collected so far
    (possibly nothing) is kept and returned.

    Args:
        cookie_domain: Domain passed to ``browser_cookie3.load``.

    Returns:
        The cached mapping of cookie name to cookie value for the domain.
    """
    if cookie_domain not in _cookies:
        # The entry is registered before loading, so a failed load still
        # leaves a cached (possibly empty) dict for later calls.
        domain_cookies = _cookies[cookie_domain] = {}
        try:
            for cookie in browser_cookie3.load(cookie_domain):
                domain_cookies[cookie.name] = cookie.value
        except Exception:
            # Deliberately ignored (best effort). Unlike a bare ``except``,
            # this lets KeyboardInterrupt and SystemExit propagate.
            pass
    return _cookies[cookie_domain]
def format_prompt(messages: list[dict[str, str]], add_special_tokens=False):
    """Flatten a message list into a single prompt string.

    A lone message is returned as its bare ``content`` unless
    ``add_special_tokens`` is truthy. Otherwise every message is rendered as
    ``"<Role>: <content>"`` (role capitalized), the lines are joined with
    newlines, and a trailing ``"Assistant:"`` line is appended.

    Args:
        messages: Dicts with ``"role"`` and ``"content"`` keys.
        add_special_tokens: Force the role-prefixed format even when there
            is only one message.

    Returns:
        The prompt text.

    Raises:
        IndexError: If ``messages`` is empty and ``add_special_tokens`` is
            falsy.
    """
    if not add_special_tokens and len(messages) <= 1:
        return messages[0]["content"]

    lines = []
    for message in messages:
        speaker = message["role"].capitalize()
        text = message["content"]
        # ``!s`` keeps the str() conversion that ``%s`` performed.
        lines.append(f"{speaker!s}: {text!s}")
    transcript = "\n".join(lines)
    return f"{transcript}\nAssistant:"