Soumil30 commited on
Commit
02fb41e
·
1 Parent(s): a7f6abf

added a class to use multiple models

Browse files
Files changed (1) hide show
  1. lightrag/llm.py +69 -0
lightrag/llm.py CHANGED
@@ -13,6 +13,8 @@ from tenacity import (
13
  )
14
  from transformers import AutoTokenizer, AutoModelForCausalLM
15
  import torch
 
 
16
  from .base import BaseKVStorage
17
  from .utils import compute_args_hash, wrap_embedding_func_with_attrs
18
 
@@ -423,6 +425,73 @@ async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray:
423
 
424
  return embed_text
425
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
426
 
427
  if __name__ == "__main__":
428
  import asyncio
 
13
  )
14
  from transformers import AutoTokenizer, AutoModelForCausalLM
15
  import torch
16
+ from pydantic import BaseModel, Field
17
+ from typing import List, Dict, Callable, Any
18
  from .base import BaseKVStorage
19
  from .utils import compute_args_hash, wrap_embedding_func_with_attrs
20
 
 
425
 
426
  return embed_text
427
 
428
+ class Model(BaseModel):
429
+ """
430
+ This is a Pydantic model class named 'Model' that is used to define a custom language model.
431
+
432
+ Attributes:
433
+ gen_func (Callable[[Any], str]): A callable function that generates the response from the language model.
434
+ The function should take any argument and return a string.
435
+ kwargs (Dict[str, Any]): A dictionary that contains the arguments to pass to the callable function.
436
+ This could include parameters such as the model name, API key, etc.
437
+
438
+ Example usage:
439
+ Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]})
440
+
441
+ In this example, 'openai_complete_if_cache' is the callable function that generates the response from the OpenAI model.
442
+ The 'kwargs' dictionary contains the model name and API key to be passed to the function.
443
+ """
444
+
445
+ gen_func: Callable[[Any], str] = Field(..., description="A function that generates the response from the llm. The response must be a string")
446
+ kwargs: Dict[str, Any] = Field(..., description="The arguments to pass to the callable function. Eg. the api key, model name, etc")
447
+
448
+ class Config:
449
+ arbitrary_types_allowed = True
450
+
451
+
452
+ class MultiModel():
453
+ """
454
+ Distributes the load across multiple language models. Useful for circumventing low rate limits with certain api providers especially if you are on the free tier.
455
+ Could also be used for spliting across diffrent models or providers.
456
+
457
+ Attributes:
458
+ models (List[Model]): A list of language models to be used.
459
+
460
+ Usage example:
461
+ ```python
462
+ models = [
463
+ Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]}),
464
+ Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_2"]}),
465
+ Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_3"]}),
466
+ Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_4"]}),
467
+ Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_5"]}),
468
+ ]
469
+ multi_model = MultiModel(models)
470
+ rag = LightRAG(
471
+ llm_model_func=multi_model.llm_model_func
472
+ / ..other args
473
+ )
474
+ ```
475
+ """
476
+ def __init__(self, models: List[Model]):
477
+ self._models = models
478
+ self._current_model = 0
479
+
480
+ def _next_model(self):
481
+ self._current_model = (self._current_model + 1) % len(self._models)
482
+ return self._models[self._current_model]
483
+
484
+ async def llm_model_func(
485
+ self,
486
+ prompt, system_prompt=None, history_messages=[], **kwargs
487
+ ) -> str:
488
+ kwargs.pop("model", None) # stop from overwriting the custom model name
489
+ next_model = self._next_model()
490
+ args = dict(prompt=prompt, system_prompt=system_prompt, history_messages=history_messages, **kwargs, **next_model.kwargs)
491
+
492
+ return await next_model.gen_func(
493
+ **args
494
+ )
495
 
496
  if __name__ == "__main__":
497
  import asyncio