import json
import os
import time
from typing import Any, Dict, List, Optional

import gradio as gr
from dotenv import load_dotenv
from huggingface_hub import InferenceClient
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.llms.base import LLM
from langchain.vectorstores import Chroma
from pydantic import BaseModel, Field
from transformers import Tool

load_dotenv()

path_work = "."
hf_token = os.getenv("HF")

embeddings = HuggingFaceInstructEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    model_kwargs={"device": "cpu"},
)

vectordb = Chroma(
    persist_directory=path_work + "/new_papers",
    embedding_function=embeddings,
)

retriever = vectordb.as_retriever(search_kwargs={"k": 2})


class KwArgsModel(BaseModel):
    kwargs: Dict[str, Any] = Field(default_factory=dict)


class CustomInferenceClient(LLM, KwArgsModel):
    """LangChain LLM wrapper around the Hugging Face InferenceClient."""

    model_name: str
    inference_client: InferenceClient

    def __init__(self, model_name: str, hf_token: str, kwargs: Optional[Dict[str, Any]] = None):
        inference_client = InferenceClient(model=model_name, token=hf_token)
        super().__init__(
            model_name=model_name,
            kwargs=kwargs or {},
            inference_client=inference_client,
        )

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        # Stream tokens from the Inference API and join them into a single string.
        response_gen = self.inference_client.text_generation(prompt, **self.kwargs, stream=True)
        return "".join(response_gen)

    @property
    def _llm_type(self) -> str:
        return "custom"

    @property
    def _identifying_params(self) -> dict:
        return {"model_name": self.model_name}


kwargs = {
    "max_new_tokens": 256,
    "temperature": 0.9,
    "top_p": 0.6,
    "repetition_penalty": 1.3,
    "do_sample": True,
}

model_list = [
    "meta-llama/Llama-2-13b-chat-hf",
    "HuggingFaceH4/zephyr-7b-alpha",
    "meta-llama/Llama-2-70b-chat-hf",
    "tiiuae/falcon-180B-chat",
]

qa_chain = None


def load_model(model_selected):
    """Build a RetrievalQA chain for the selected model and cache it globally."""
    global qa_chain
    llm = CustomInferenceClient(model_name=model_selected, hf_token=hf_token, kwargs=kwargs)
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        verbose=True,
    )
    return qa_chain


load_model("meta-llama/Llama-2-70b-chat-hf")


def predict(message, temperature=0.9, max_new_tokens=512, top_p=0.6, repetition_penalty=1.3):
    """Run the QA chain and stream the answer character by character.

    Note: the actual generation parameters are fixed in `kwargs` at model-load
    time; the arguments here are kept for interface compatibility.
    """
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)

    llm_response = qa_chain(message)
    res_result = llm_response["result"]
    res_relevant_doc = [source.metadata["source"] for source in llm_response["source_documents"]]
    response = (
        f"{res_result}\n\n"
        "[Answer Source Documents (Ctrl + Click!)] :\n"
        f" \n {res_relevant_doc}"
    )
    print("response: =====> \n", response, "\n\n")

    # Split the answer into lines and wrap them in a JSON-style token structure.
    tokens = response.split("\n")
    token_list = [{"id": idx + 1, "text": token} for idx, token in enumerate(tokens)]
    response = {"data": {"token": token_list}}

    # Round-trip through JSON to guarantee a plain, serializable structure.
    response = json.loads(json.dumps(response, indent=4))
    token_list = response.get("data", {}).get("token", [])

    partial_message = ""
    for token_entry in token_list:
        if not token_entry:
            continue
        try:
            token_text = token_entry.get("text", None)
            if token_text:
                # Yield the growing answer one character at a time for streaming UIs.
                for char in token_text:
                    partial_message += char
                    yield partial_message
                    time.sleep(0.01)
            else:
                print(f"Warning ==> The key 'text' does not exist or is None in this token entry: {token_entry}")
        except KeyError as e:
            gr.Warning(f"KeyError: {e} occurred for token entry: {token_entry}")
            continue


class TextGeneratorTool(Tool):
    name = "vector_retriever"
    description = "This tool searches in a vector store based on a given prompt."
    inputs = ["prompt"]
    outputs = ["generated_text"]

    def __init__(self):
        super().__init__()

    def __call__(self, prompt: str):
        # `predict` is a generator of growing partial messages; keep the last
        # (complete) one as the tool output.
        result = ""
        for partial in predict(prompt, 0.9, 512, 0.6, 1.4):
            result = partial
        return result
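

# --- Usage sketch (illustrative, not part of the original app) --------------
# A minimal way to exercise the pieces above, assuming the Chroma store under
# ./new_papers exists and the HF token grants access to the selected model.
# The question strings below are hypothetical examples.
if __name__ == "__main__":
    tool = TextGeneratorTool()
    print(tool("What are the key findings of the indexed papers?"))

    # Streaming variant: `predict` yields the growing answer, as a Gradio
    # streaming callback would consume it.
    last = ""
    for partial in predict("Summarize the main contribution of the papers."):
        last = partial
    print(last)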