# Captured from (apparently) a Hugging Face Spaces page; the Space's status banner read "Runtime error".
import json
import os
from datetime import datetime

import faiss
import gradio as gr
import numpy as np
import torch
from langchain.document_loaders import DirectoryLoader, TextLoader
from langchain.vectorstores import FAISS
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
class DocumentRetrievalAndGeneration:
    """Retrieval-augmented QA over a folder of text files.

    A query is embedded with a SentenceTransformer and looked up in a FAISS index. The
    matching documents are pasted into a prompt, and a Hugging Face causal LM generates
    the answer.
    """

    def __init__(self, embedding_model_name, lm_model_id, data_folder, faiss_index_path):
        """Load documents, the query embedder, the FAISS index and the LLM.

        Args:
            embedding_model_name: SentenceTransformer model name/path used to embed queries.
            lm_model_id: Hugging Face model id of the causal LM used for generation.
            data_folder: Directory whose files are loaded with ``TextLoader``.
            faiss_index_path: Path to a serialized FAISS index (read with ``faiss.read_index``).
        """
        self.documents = self.load_documents(data_folder)
        self.embeddings = SentenceTransformer(embedding_model_name)
        # Holds a GPU index when FAISS can use a GPU, otherwise the CPU index.
        self.gpu_index = self.load_faiss_index(faiss_index_path)
        self.llm = self.initialize_llm(lm_model_id)

    def load_documents(self, folder_path):
        """Load every file under ``folder_path`` as a LangChain document.

        NOTE(review): retrieval maps FAISS row ids directly to positions in this list.
        That assumes the index was built from the same files in the same order that
        DirectoryLoader returns them. Confirm this against the indexing script.
        """
        loader = DirectoryLoader(folder_path, loader_cls=TextLoader)
        documents = loader.load()
        print('Length of documents:', len(documents))
        return documents

    def load_faiss_index(self, faiss_index_path):
        """Read a FAISS index from disk, moving it to GPU 0 when FAISS can use a GPU.

        Falls back to the CPU index when the installed FAISS build has no GPU support
        (e.g. ``faiss-cpu``) or reports no GPUs. Without this fallback, loading
        crashes on CPU-only hosts.
        """
        cpu_index = faiss.read_index(faiss_index_path)
        # CPU-only FAISS builds do not expose StandardGpuResources at all; check that
        # before asking how many GPUs are visible.
        if not hasattr(faiss, "StandardGpuResources") or faiss.get_num_gpus() < 1:
            print("FAISS GPU support unavailable; searching on CPU")
            return cpu_index
        gpu_resource = faiss.StandardGpuResources()
        return faiss.index_cpu_to_gpu(gpu_resource, 0, cpu_index)

    def initialize_llm(self, model_id):
        """Load ``model_id`` and wrap it in a text-generation pipeline.

        With CUDA available, the weights are loaded 4-bit NF4-quantized via bitsandbytes.
        Without CUDA, the model is loaded unquantized in bfloat16 instead. Transformers'
        bitsandbytes integration has historically required a GPU, so the original
        unconditional quantized load crashed on CPU hosts.

        Returns:
            A ``transformers`` text-generation pipeline. Its ``.model`` and ``.tokenizer``
            are used directly by ``query_and_generate_response``.
        """
        if torch.cuda.is_available():
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
            )
            model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config)
        else:
            print("CUDA not available; loading the model unquantized on CPU")
            # bfloat16 (2 bytes/parameter) halves memory compared with the float32 default.
            model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        # NOTE: temperature/max_new_tokens below apply only when the pipeline object is
        # called. query_and_generate_response calls model.generate() directly instead.
        generate_text = pipeline(
            model=model,
            tokenizer=tokenizer,
            return_full_text=True,
            task='text-generation',
            temperature=0.6,
            max_new_tokens=2048,
        )
        return generate_text

    def query_and_generate_response(self, query, *, top_k=5, max_new_tokens=1000):
        """Retrieve the ``top_k`` nearest documents for ``query`` and generate an answer.

        Args:
            query: The user's question.
            top_k: Number of neighbours requested from the FAISS index.
            max_new_tokens: Generation budget passed to ``model.generate``.

        Returns:
            Only the model's continuation, decoded with special tokens removed. The echoed
            chat prompt, including the retrieved documents, is not returned.
        """
        # FAISS expects a 2-D float32 query matrix: one row per query.
        query_embedding = np.asarray(
            self.embeddings.encode(query, convert_to_numpy=True), dtype=np.float32
        ).reshape(1, -1)
        _distances, indices = self.gpu_index.search(query_embedding, k=top_k)

        chunks = []
        for idx in indices[0]:
            # FAISS pads results with -1 when it finds fewer than k neighbours. Using that
            # as a list index would silently pull in the last document.
            if idx < 0:
                continue
            page_content = self.documents[idx].page_content
            chunks.append("-" * 50 + "\n" + page_content + "\n")
            print(page_content)
            print("############################")
        content = "".join(chunks)
        prompt = f"Query: {query}\nSolution: {content}\n"

        tokenizer = self.llm.tokenizer
        model = self.llm.model
        messages = [{"role": "user", "content": prompt}]
        # Use the model's own device rather than the pipeline's, since generation bypasses the pipeline.
        model_inputs = tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt"
        ).to(model.device)

        start_time = datetime.now()
        generated_ids = model.generate(
            model_inputs,
            # A single unpadded sequence: every position is a real token.
            attention_mask=torch.ones_like(model_inputs),
            max_new_tokens=max_new_tokens,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
        elapsed_time = datetime.now() - start_time

        # generate() returns prompt + continuation; keep only the newly generated tokens.
        new_token_ids = generated_ids[:, model_inputs.shape[-1]:]
        generated_response = tokenizer.batch_decode(new_token_ids, skip_special_tokens=True)[0]
        print("Generated response:", generated_response)
        print("Time elapsed:", elapsed_time)
        print("Device in use:", model.device)
        return generated_response

    def qa_infer_gradio(self, query):
        """Gradio callback: answer ``query`` using the default retrieval settings."""
        return self.query_and_generate_response(query)
if __name__ == "__main__":
    # Example usage: model ids and file locations for this deployment.
    embedding_model_name = 'flax-sentence-embeddings/all_datasets_v3_MiniLM-L12'
    lm_model_id = "mistralai/Mistral-7B-Instruct-v0.2"
    data_folder = 'sample_embedding_folder'
    faiss_index_path = 'faiss_index_new_model3.index'
    doc_retrieval_gen = DocumentRetrievalAndGeneration(
        embedding_model_name, lm_model_id, data_folder, faiss_index_path
    )

    def launch_interface():
        """Build the Gradio QA interface around ``doc_retrieval_gen`` and launch it.

        Sample ticket names are read from ``ticketNames.txt`` (parsed as JSON) for a
        dropdown. A missing file gives an empty list rather than aborting startup.
        """
        css_code = """
        .gradio-container {
            background-color: #daccdb;
        }
        /* Button styling for all buttons */
        button {
            background-color: #927fc7; /* Default color for all other buttons */
            color: black;
            border: 1px solid black;
            padding: 10px;
            margin-right: 10px;
            font-size: 16px; /* Increase font size */
            font-weight: bold; /* Make text bold */
        }
        """
        EXAMPLES = ["TDA4 product planning and datasheet release progress? ",
                    "I'm using Code Composer Studio 5.4.0.00091 and enabled FPv4SPD16 floating point support for CortexM4 in TDA2. However, after building the project, the .asm file shows --float_support=vfplib instead of FPv4SPD16. Why is this happening?",
                    "Master core in TDA2XX is a15 and in TDA3XX it is m4,so we have to shift all modules that are being used by a15 in TDA2XX to m4 in TDA3xx."]

        file_path = "ticketNames.txt"
        try:
            with open(file_path, "r", encoding="utf-8") as file:
                # Presumably a JSON list of strings, since it feeds Dropdown choices. Confirm.
                ticket_names = json.load(file)
        except FileNotFoundError:
            print(f"{file_path} not found; sample-query dropdown left empty")
            ticket_names = []
        # TODO(review): this dropdown is never passed to gr.Interface, so it is not
        # rendered. Wire it into `inputs` or remove it along with the file read above.
        dropdown = gr.Dropdown(label="Sample queries", choices=ticket_names)  # noqa: F841

        interface = gr.Interface(
            fn=doc_retrieval_gen.qa_infer_gradio,
            inputs=[gr.Textbox(label="QUERY", placeholder="Enter your query here")],
            allow_flagging='never',
            examples=EXAMPLES,
            cache_examples=False,
            outputs=gr.Textbox(label="SOLUTION"),
            css=css_code,
        )
        interface.launch(debug=True)

    launch_interface()