added a class to use multiple models

lightrag/llm.py  +69 -0
CHANGED
@@ -13,6 +13,8 @@ from tenacity import (
 )
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
+from pydantic import BaseModel, Field
+from typing import List, Dict, Callable, Any
 from .base import BaseKVStorage
 from .utils import compute_args_hash, wrap_embedding_func_with_attrs
 
@@ -423,6 +425,73 @@ async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray:
 
     return embed_text
 
+class Model(BaseModel):
+    """
+    A Pydantic model that defines a custom language model.
+
+    Attributes:
+        gen_func (Callable[[Any], str]): A callable function that generates the response from the language model.
+            The function should take any argument and return a string.
+        kwargs (Dict[str, Any]): A dictionary that contains the arguments to pass to the callable function.
+            This could include parameters such as the model name, API key, etc.
+
+    Example usage:
+        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]})
+
+    In this example, 'openai_complete_if_cache' is the callable function that generates the response from the OpenAI model.
+    The 'kwargs' dictionary contains the model name and API key to be passed to the function.
+    """
+
+    gen_func: Callable[[Any], str] = Field(..., description="A function that generates the response from the LLM. The response must be a string.")
+    kwargs: Dict[str, Any] = Field(..., description="The arguments to pass to the callable function, e.g. the API key, model name, etc.")
+
+    class Config:
+        arbitrary_types_allowed = True
+
+
+class MultiModel:
+    """
+    Distributes the load across multiple language models. Useful for circumventing low rate limits with certain API providers, especially if you are on the free tier.
+    Could also be used for splitting across different models or providers.
+
+    Attributes:
+        models (List[Model]): A list of language models to be used.
+
+    Usage example:
+        ```python
+        models = [
+            Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]}),
+            Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_2"]}),
+            Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_3"]}),
+            Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_4"]}),
+            Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_5"]}),
+        ]
+        multi_model = MultiModel(models)
+        rag = LightRAG(
+            llm_model_func=multi_model.llm_model_func,
+            # ...other args
+        )
+        ```
+    """
+    def __init__(self, models: List[Model]):
+        self._models = models
+        self._current_model = 0
+
+    def _next_model(self):
+        self._current_model = (self._current_model + 1) % len(self._models)
+        return self._models[self._current_model]
+
+    async def llm_model_func(
+        self,
+        prompt, system_prompt=None, history_messages=[], **kwargs
+    ) -> str:
+        kwargs.pop("model", None)  # stop kwargs from overwriting the custom model name
+        next_model = self._next_model()
+        args = dict(prompt=prompt, system_prompt=system_prompt, history_messages=history_messages, **kwargs, **next_model.kwargs)
+
+        return await next_model.gen_func(
+            **args
+        )
 
 if __name__ == "__main__":
     import asyncio
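
For reference, a minimal sketch of the round-robin behavior, assuming this commit is applied so that `Model` and `MultiModel` can be imported from `lightrag.llm`. The `echo_gen_func` below is a hypothetical stand-in for a real completion function such as `openai_complete_if_cache`; no API keys are needed to run it:

```python
import asyncio

from lightrag.llm import Model, MultiModel  # assumes this commit is applied


async def echo_gen_func(prompt, system_prompt=None, history_messages=[], **kwargs) -> str:
    # Hypothetical stand-in for a real LLM call: reports which "model" served the request.
    return f"[{kwargs.get('model')}] {prompt}"


async def main():
    models = [
        Model(gen_func=echo_gen_func, kwargs={"model": "model-a"}),
        Model(gen_func=echo_gen_func, kwargs={"model": "model-b"}),
    ]
    multi_model = MultiModel(models)

    # Each call advances the round-robin pointer, alternating between the two models.
    for prompt in ["first", "second", "third"]:
        print(await multi_model.llm_model_func(prompt))
    # -> [model-b] first
    # -> [model-a] second
    # -> [model-b] third


if __name__ == "__main__":
    asyncio.run(main())
```

Because `_next_model` advances the pointer before returning, the first call is served by the second model in the list; over many calls the load still spreads evenly across all models.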