import json
import os
import time
from typing import Any, Dict, List, Optional

import gradio as gr
from dotenv import load_dotenv
from huggingface_hub import InferenceClient
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.llms.base import LLM
from langchain.vectorstores import Chroma
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, Tool

load_dotenv()

path_work = "."
hf_token = os.getenv("HF")

embeddings = HuggingFaceInstructEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    model_kwargs={"device": "cpu"},
)

vectordb = Chroma(
    persist_directory=path_work + "/new_papers",
    embedding_function=embeddings,
)

retriever = vectordb.as_retriever(search_kwargs={"k": 2})


class KwArgsModel(BaseModel):
    kwargs: Dict[str, Any] = Field(default_factory=dict)


class CustomInferenceClient(LLM, KwArgsModel):
    """LangChain-compatible LLM that delegates generation to the Hugging Face InferenceClient."""

    model_name: str
    inference_client: InferenceClient

    def __init__(self, model_name: str, hf_token: str, kwargs: Optional[Dict[str, Any]] = None):
        inference_client = InferenceClient(model=model_name, token=hf_token)
        # hf_token is only needed to build the client; it is not a declared field,
        # so it is not forwarded to the pydantic constructor.
        super().__init__(
            model_name=model_name,
            kwargs=kwargs or {},
            inference_client=inference_client,
        )

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        # Stream tokens from the Inference API and join them into a single string.
        response_gen = self.inference_client.text_generation(prompt, **self.kwargs, stream=True)
        return "".join(response_gen)

    @property
    def _llm_type(self) -> str:
        return "custom"

    @property
    def _identifying_params(self) -> dict:
        return {"model_name": self.model_name}


kwargs = {"max_new_tokens": 256, "temperature": 0.9, "top_p": 0.6, "repetition_penalty": 1.3, "do_sample": True}

model_list = [
    "meta-llama/Llama-2-13b-chat-hf",
    "HuggingFaceH4/zephyr-7b-alpha",
    "meta-llama/Llama-2-70b-chat-hf",
    "tiiuae/falcon-180B-chat",
]

qa_chain = None


def load_model(model_selected):
    """Build a RetrievalQA chain over the vector store for the selected model."""
    global qa_chain
    llm = CustomInferenceClient(model_name=model_selected, hf_token=hf_token, kwargs=kwargs)

    from langchain.chains import RetrievalQA

    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        verbose=True,
    )
    return qa_chain


load_model("meta-llama/Llama-2-70b-chat-hf")


from langchain.document_loaders import PyPDFDirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter


def load_and_process_pdfs(directory_path: str, chunk_size: int = 500, chunk_overlap: int = 200, collection_name: str = "my-collection"):
    """Load every PDF in the directory, split it into chunks, and index the chunks in Chroma."""
    loader = PyPDFDirectoryLoader(directory_path)
    documents = loader.load()

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    texts = text_splitter.split_documents(documents)

    db = Chroma.from_documents(texts, embeddings, collection_name=collection_name)
    return db


load_and_process_pdfs("new_papers/")


def predict(message, temperature=0.9, max_new_tokens=512, top_p=0.6, repetition_penalty=1.3):
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    llm_response = qa_chain(message)
    res_result = llm_response["result"]

    res_relevant_doc = [source.metadata["source"] for source in llm_response["source_documents"]]
    response = f"{res_result}" + "\n\n" + "[Answer Source Documents (Ctrl + Click!)] :" + "\n" + f" \n {res_relevant_doc}"
    print("response: =====> \n", response, "\n\n")

    # Wrap each line of the answer in an {"id": ..., "text": ...} entry and round-trip it through JSON.
    tokens = response.split("\n")
    token_list = []
    for idx, token in enumerate(tokens):
        token_dict = {"id": idx + 1, "text": token}
        token_list.append(token_dict)
    response = {"data": {"token": token_list}}
    response = json.dumps(response, indent=4)

    response = json.loads(response)
    data_dict = response.get("data", {})
    token_list = data_dict.get("token", [])

    # Stream the answer back character by character.
    partial_message = ""
    for token_entry in token_list:
        if token_entry:
            try:
                token_id = token_entry.get("id", None)
                token_text = token_entry.get("text", None)

                if token_text:
                    for char in token_text:
                        partial_message += char
                        yield partial_message
                        time.sleep(0.01)
                else:
                    print(f"[[Warning]] ==> The key 'text' does not exist or is None in this token entry: {token_entry}")
            except KeyError as e:
                gr.Warning(f"KeyError: {e} occurred for token entry: {token_entry}")
                continue


class TextGeneratorTool(Tool):
    name = "vector_retriever"
    description = "This tool searches in a vector store based on a given prompt."
    inputs = ["prompt"]
    outputs = ["generated_text"]

    def __init__(self):
        super().__init__()

    def __call__(self, prompt: str):
        # predict() is a generator, so the caller receives a stream of partial answers.
        result = predict(prompt, 0.9, 512, 0.6, 1.4)
        return result
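

# Example usage (a minimal sketch, not part of the original app wiring): it assumes the
# Chroma store under ./new_papers has been populated and that the HF environment variable
# holds a valid Hugging Face token; the question string is only illustrative.
if __name__ == "__main__":
    tool = TextGeneratorTool()
    final_answer = ""
    for partial in tool("What are the main contributions of these papers?"):
        final_answer = partial  # keep the latest partial; it grows into the full answer
    print(final_answer)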