from __future__ import annotations

from typing import Callable, Any

from pydantic import BaseModel, Field


class Model(BaseModel):
    """
    A Pydantic model used to define a custom language model.

    Attributes:
        gen_func (Callable[[Any], str]): A callable that generates the response from
            the language model. The function should accept any arguments and return a string.
        kwargs (dict[str, Any]): A dictionary of arguments to pass to the callable,
            such as the model name, API key, etc.

    Example usage:
        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]})

    In this example, 'openai_complete_if_cache' is the callable that generates the
    response from the OpenAI model, and the 'kwargs' dictionary contains the model
    name and API key to pass to it.
    """

    gen_func: Callable[[Any], str] = Field(
        ...,
        description="A function that generates the response from the LLM. The response must be a string.",
    )
    kwargs: dict[str, Any] = Field(
        ...,
        description="The arguments to pass to the callable function, e.g. the API key, model name, etc.",
    )

    class Config:
        arbitrary_types_allowed = True


class MultiModel:
    """
    Distributes the load across multiple language models. Useful for circumventing
    low rate limits with certain API providers, especially if you are on the free tier.
    It can also be used to split requests across different models or providers.

    Attributes:
        models (list[Model]): A list of language models to be used.

    Usage example:
    ```python
    models = [
        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]}),
        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_2"]}),
        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_3"]}),
        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_4"]}),
        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_5"]}),
    ]
    multi_model = MultiModel(models)
    rag = LightRAG(
        llm_model_func=multi_model.llm_model_func,
        # ...other args
    )
    ```
    """

    def __init__(self, models: list[Model]):
        self._models = models
        self._current_model = 0

    def _next_model(self):
        # Advance the round-robin index and return the newly selected model.
        self._current_model = (self._current_model + 1) % len(self._models)
        return self._models[self._current_model]

    async def llm_model_func(
        self,
        prompt: str,
        system_prompt: str | None = None,
        history_messages: list[dict[str, Any]] = [],
        **kwargs: Any,
    ) -> str:
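        """Route a single request to the next model in the round-robin rotation."""
        # Drop kwargs that are configured per model ("model") or that LightRAG
        # may pass through ("keyword_extraction", "mode"), so they don't clash
        # with next_model.kwargs when the call arguments are merged below.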
kwargs.pop("model", None) |
|
|
kwargs.pop("keyword_extraction", None) |
|
|
kwargs.pop("mode", None) |
|
|
next_model = self._next_model() |
|
|
args = dict( |
|
|
prompt=prompt, |
|
|
system_prompt=system_prompt, |
|
|
history_messages=history_messages, |
|
|
**kwargs, |
|
|
**next_model.kwargs, |
|
|
) |
|
|
|
|
|
        return await next_model.gen_func(**args)


if __name__ == "__main__": |
|
|
import asyncio |
|
|
|
|
|
async def main(): |
|
|
from lightrag.llm.openai import gpt_4o_mini_complete |
|
|
|
|
|
result = await gpt_4o_mini_complete("How are you?") |
|
|
print(result) |
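
        # A minimal sketch (not part of the original demo) showing MultiModel
        # rotating between two OpenAI keys, mirroring the class docstring. It
        # assumes OPENAI_API_KEY_1 and OPENAI_API_KEY_2 are set in the
        # environment and reuses openai_complete_if_cache from lightrag.llm.openai.
        import os

        from lightrag.llm.openai import openai_complete_if_cache

        models = [
            Model(
                gen_func=openai_complete_if_cache,
                kwargs={"model": "gpt-4o-mini", "api_key": os.environ["OPENAI_API_KEY_1"]},
            ),
            Model(
                gen_func=openai_complete_if_cache,
                kwargs={"model": "gpt-4o-mini", "api_key": os.environ["OPENAI_API_KEY_2"]},
            ),
        ]
        multi_model = MultiModel(models)
        # Each call is routed to the next model in the rotation.
        print(await multi_model.llm_model_func("How are you?"))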

    asyncio.run(main())