# NOTE: removed non-code residue from the original extraction (file-size line,
# commit hash, and line-number gutter from a web file view).
from types import ModuleType
from typing import Optional, List, Any, Mapping, Union
import g4f
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
class G4F(LLM):
    """LangChain ``LLM`` wrapper around the ``g4f`` ChatCompletion API.

    Configuration is held on the instance (``model``, ``provider``, ``auth``,
    ``create_kwargs``) and forwarded to ``g4f.ChatCompletion.create`` on each
    call.
    """

    # Model to use: a g4f model object or a model-name string
    # (original comment: "Model.model or str").
    model: Union[type, str]
    # Optional g4f provider module (original comment: "Provider.Provider").
    provider: Optional[ModuleType] = None
    # Optional auth token or flag forwarded to the provider.
    auth: Optional[Union[str, bool]] = None
    # Extra default keyword arguments for ``g4f.ChatCompletion.create``.
    create_kwargs: Optional[dict] = None

    @property
    def _llm_type(self) -> str:
        """Return the LLM-type identifier LangChain uses for this class."""
        return "custom"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Send ``prompt`` to g4f and return the generated text.

        Args:
            prompt: The user message to send.
            stop: Optional stop tokens; the response is truncated at the
                first occurrence of any of them.
            run_manager: LangChain callback manager (accepted for interface
                compatibility; not used here).
            **kwargs: Per-call options forwarded to
                ``g4f.ChatCompletion.create`` (e.g. temperature).

        Returns:
            The completion text.
        """
        # Copy so repeated calls never mutate the instance-level defaults.
        create_kwargs = {} if self.create_kwargs is None else self.create_kwargs.copy()
        # BUGFIX: the original accepted **kwargs but silently dropped them;
        # forward per-call options so they actually take effect.  Instance
        # attributes below still override, preserving the original precedence.
        create_kwargs.update(kwargs)
        if self.model is not None:
            create_kwargs["model"] = self.model
        if self.provider is not None:
            create_kwargs["provider"] = self.provider
        if self.auth is not None:
            create_kwargs["auth"] = self.auth

        text = g4f.ChatCompletion.create(
            messages=[{"role": "user", "content": prompt}],
            **create_kwargs,
        )

        if stop is not None:
            # Truncate at the first stop token, mirroring other LangChain LLMs.
            text = enforce_stop_tokens(text, stop)

        return text

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "model": self.model,
            "provider": self.provider,
            "auth": self.auth,
            "create_kwargs": self.create_kwargs,
        }