# app.py
# NOTE: `spaces` is kept first: on HF ZeroGPU it must be imported before CUDA is touched.
import spaces
from torch.nn import DataParallel
from torch import Tensor
from transformers import AutoTokenizer, AutoModel
from huggingface_hub import InferenceClient
from openai import OpenAI
from langchain_community.document_loaders import UnstructuredFileLoader
from langchain_chroma import Chroma
from chromadb import Documents, EmbeddingFunction, Embeddings
from chromadb.config import Settings
import chromadb  # import HttpClient
import os
import tempfile
import re
import uuid
import gradio as gr
import torch
import torch.nn.functional as F
from dotenv import load_dotenv
from utils import load_env_variables, parse_and_route, escape_special_characters
from globalvars import API_BASE, intention_prompt, tasks, system_message, model_name, metadata_prompt

load_dotenv()

os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:30'
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ['CUDA_CACHE_DISABLE'] = '1'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Ensure the temporary directory exists
temp_dir = '/tmp/gradio/'
os.makedirs(temp_dir, exist_ok=True)

# Set Gradio cache directory
gr.components.file.GRADIO_CACHE = temp_dir

### Utils
hf_token, yi_token = load_env_variables()

# Chat model served at API_BASE; used for intent routing, metadata extraction and answers.
CHAT_MODEL = "yi-large"


def clear_cuda_cache():
    """Release cached, unused CUDA memory (no-op when CUDA is not initialised)."""
    torch.cuda.empty_cache()


client = OpenAI(api_key=yi_token, base_url=API_BASE)
chroma_client = chromadb.Client(Settings())

# Create (or reuse, e.g. after a hot reload) the collection that stores document chunks.
chroma_collection = chroma_client.get_or_create_collection("all-my-documents")


def _document_text(doc) -> str:
    """Return the text of a LangChain ``Document``, or pass a plain string through."""
    return doc if isinstance(doc, str) else doc.page_content


class EmbeddingGenerator:
    """Embed text with a Hugging Face model, guided by chat-model task routing.

    For every input the chat model (``intention_client``) is asked twice:
    once with ``intention_prompt`` to pick a key of ``tasks`` (whose description
    becomes an ``Instruct: ...\\nQuery: `` prefix), and once with
    ``metadata_prompt`` to extract ``"key": "value"`` metadata pairs.
    """

    # Matches `"key": "value"` pairs in the chat model's metadata answer.
    _METADATA_PATTERN = re.compile(r'\"(\w+)\": \"([^\"]+)\"')

    def __init__(self, model_name: str, token: str, intention_client):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=token, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(model_name, token=token, trust_remote_code=True).to(self.device)
        self.intention_client = intention_client

    def clear_cuda_cache(self):
        """Release cached, unused CUDA memory after an embedding pass."""
        torch.cuda.empty_cache()

    def _chat(self, system_prompt: str, user_text: str) -> str:
        """Run one non-streaming chat completion and return its text ("" if empty)."""
        completion = self.intention_client.chat.completions.create(
            model=CHAT_MODEL,
            messages=[
                {"role": "system", "content": escape_special_characters(system_prompt)},
                {"role": "user", "content": user_text},
            ],
        )
        return completion.choices[0].message.content or ""

    @staticmethod
    def _pool(outputs, attention_mask):
        """Reduce model outputs to one vector per input row.

        Uses an attention-masked mean over ``last_hidden_state`` so padding
        tokens never dilute the average.
        """
        hidden = getattr(outputs, "last_hidden_state", None)
        if hidden is None:
            # NOTE(review): fallback for remote-code models that return a plain dict
            # with "sentence_embeddings"; its shape is not verified here.
            hidden = outputs["sentence_embeddings"]
            return hidden if hidden.dim() == 2 else hidden.mean(dim=1)
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
        return (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)

    @spaces.GPU
    def compute_embeddings(self, input_text: str):
        """Embed ``input_text`` and extract its metadata.

        Args:
            input_text: Raw text of a document chunk or a user query.

        Returns:
            ``(embeddings_list, metadata)``: a list holding one L2-normalised
            vector (list of floats), and a ``dict[str, str]`` of metadata.
        """
        escaped_input_text = escape_special_characters(input_text)

        # Ask the chat model which task this text belongs to.
        selected_task = parse_and_route(self._chat(intention_prompt, escaped_input_text))
        if selected_task in tasks:
            task_description = tasks[selected_task]
        else:
            task_description = tasks["DEFAULT"]
            print(f"Selected task not found: {selected_task}")
        # Instruction-tuned embedders expect the task prefix in front of the text;
        # previously the prefix was built but never used.
        query_prefix = f"Instruct: {task_description}\nQuery: "
        queries = [query_prefix + escaped_input_text]

        metadata = self.extract_metadata(self._chat(metadata_prompt, escaped_input_text))

        with torch.no_grad():
            inputs = self.tokenizer(queries, return_tensors='pt', padding=True, truncation=True, max_length=4096).to(self.device)
            outputs = self.model(**inputs)
            query_embeddings = self._pool(outputs, inputs["attention_mask"])

        # Normalize embeddings so cosine similarity equals dot product.
        query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
        embeddings_list = query_embeddings.detach().cpu().tolist()
        self.clear_cuda_cache()
        return embeddings_list, metadata

    def extract_metadata(self, metadata_output: str):
        """Parse ``"key": "value"`` pairs from the chat model's answer into a dict."""
        return dict(self._METADATA_PATTERN.findall(metadata_output or ""))


class MyEmbeddingFunction(EmbeddingFunction):
    """Chroma-style embedding function backed by one shared ``EmbeddingGenerator``.

    NOTE(review): Chroma's ``EmbeddingFunction`` contract returns only
    ``Embeddings``; this one returns ``(embeddings, metadatas)``, so it is not
    safe to hand to Chroma for automatic embedding.
    """

    def __init__(self, model_name: str, token: str, intention_client, embedding_generator=None):
        self.model_name = model_name
        self.token = token
        self.intention_client = intention_client
        # Reused across calls so the HF model is loaded at most once.
        self._embedding_generator = embedding_generator

    def create_embedding_generator(self):
        """Return the shared generator, loading the model on first use only."""
        if self._embedding_generator is None:
            self._embedding_generator = EmbeddingGenerator(self.model_name, self.token, self.intention_client)
        return self._embedding_generator

    def __call__(self, input: Documents) -> tuple[Embeddings, list[dict]]:
        """Embed a batch of documents.

        Args:
            input: Strings (Chroma ``Documents``) or LangChain ``Document`` objects.

        Returns:
            ``(embeddings, metadatas)``: one vector per document and, aligned
            index-for-index, the metadata dict extracted for that document.
        """
        generator = self.create_embedding_generator()
        embeddings_flattened, metadata_flattened = [], []
        for doc in input:
            vectors, metadata = generator.compute_embeddings(_document_text(doc))
            embeddings_flattened.extend(vectors)
            metadata_flattened.extend(dict(metadata) for _ in vectors)
        return embeddings_flattened, metadata_flattened


def load_documents(file_path: str, mode: str = "elements"):
    """Load ``file_path`` with Unstructured and return the text of each element."""
    loader = UnstructuredFileLoader(file_path, mode=mode)
    docs = loader.load()
    return [doc.page_content for doc in docs]


def initialize_chroma(collection_name: str, embedding_function: MyEmbeddingFunction):
    """Wrap ``collection_name`` on the shared client in a LangChain ``Chroma`` store."""
    db = Chroma(client=chroma_client, collection_name=collection_name, embedding_function=embedding_function)
    return db


def add_documents_to_chroma(documents: list, embedding_function: MyEmbeddingFunction):
    """Embed each document and store it, with its metadata, in ``chroma_collection``.

    Args:
        documents: LangChain ``Document`` objects or plain strings.
        embedding_function: Source of the (shared) embedding generator.
    """
    generator = embedding_function.create_embedding_generator()
    for doc in documents:
        text = _document_text(doc)
        if not text or not text.strip():
            continue  # nothing worth embedding; avoids two wasted LLM calls
        embeddings, metadata = generator.compute_embeddings(text)
        for embedding in embeddings:
            chroma_collection.add(
                ids=[str(uuid.uuid4())],
                documents=[text],
                embeddings=[embedding],
                # Chroma rejects empty metadata dicts, so omit metadata when none was extracted.
                metadatas=[metadata] if metadata else None,
            )


def query_chroma(query_text: str, embedding_function: MyEmbeddingFunction, n_results: int = 2):
    """Return the ``n_results`` stored chunks nearest to ``query_text``.

    The query is embedded with the same generator as the stored documents, so
    both live in the same vector space. Returns Chroma's ``QueryResult`` dict.
    """
    if chroma_collection.count() == 0:
        # Nothing uploaded yet: skip the LLM and GPU work entirely.
        return {"ids": [[]], "documents": [[]], "metadatas": [[]], "distances": [[]]}
    query_embeddings, _query_metadata = embedding_function.create_embedding_generator().compute_embeddings(query_text)
    return chroma_collection.query(query_embeddings=query_embeddings, n_results=n_results)


# Initialize clients
intention_client = OpenAI(api_key=yi_token, base_url=API_BASE)
embedding_generator = EmbeddingGenerator(model_name=model_name, token=hf_token, intention_client=intention_client)
embedding_function = MyEmbeddingFunction(
    model_name=model_name,
    token=hf_token,
    intention_client=intention_client,
    embedding_generator=embedding_generator,  # reuse the model loaded above
)
# NOTE(review): this LangChain store uses a different collection ("Tonic-instruct")
# from `chroma_collection`, and is not used elsewhere in this section.
chroma_db = initialize_chroma(collection_name="Tonic-instruct", embedding_function=embedding_function)


def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Stream a chat answer to ``message``, grounded on retrieved document chunks.

    Yields the cumulative response text after each streamed chunk, as
    ``gr.ChatInterface`` expects.
    """
    retrieved_text = query_documents(message)

    messages = [{"role": "system", "content": escape_special_characters(system_message)}]
    for user_turn, assistant_turn in history:
        if user_turn:
            messages.append({"role": "user", "content": user_turn})
        if assistant_turn:
            messages.append({"role": "assistant", "content": assistant_turn})
    messages.append({"role": "user", "content": f"{retrieved_text}\n\n{escape_special_characters(message)}"})

    # The OpenAI SDK exposes chat via `chat.completions.create` (not `chat_completion`).
    stream = intention_client.chat.completions.create(
        model=CHAT_MODEL,
        messages=messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    )
    response = ""
    for chunk in stream:
        if not chunk.choices:
            continue
        # The final delta (and role-only deltas) carry `content=None`.
        response += chunk.choices[0].delta.content or ""
        yield response


def upload_documents(files):
    """Load each uploaded file with Unstructured and index it in Chroma.

    Args:
        files: Gradio upload value: file objects with ``.name`` or plain paths.
    """
    if not files:
        return "No documents were uploaded."
    for file in files:
        path = getattr(file, "name", file)
        documents = UnstructuredFileLoader(path).load()
        add_documents_to_chroma(documents, embedding_function)
    return "Documents uploaded and processed successfully!"
def query_documents(query):
    """Return the chunks retrieved for ``query``, joined by blank lines ("" if none)."""
    results = query_chroma(query, embedding_function)
    # Chroma returns one list of documents per query embedding; exactly one was sent.
    per_query_docs = results.get("documents") or []
    hits = per_query_docs[0] if per_query_docs else []
    return "\n\n".join(doc for doc in hits if doc)


with gr.Blocks() as demo:
    with gr.Tab("Upload Documents"):
        # NOTE(review): "document" is not one of Gradio's file categories
        # ("file", "image", "audio", "video", "text"); confirm the picker accepts
        # the intended types, or list extensions such as ".pdf".
        document_upload = gr.File(file_count="multiple", file_types=["document"])
        upload_button = gr.Button("Upload and Process")
        upload_status = gr.Text()
        upload_button.click(upload_documents, inputs=document_upload, outputs=upload_status)

    with gr.Tab("Ask Questions"):
        with gr.Row():
            chat_interface = gr.ChatInterface(
                respond,
                additional_inputs=[
                    gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
                    gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
                    gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
                    gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
                ],
            )
            query_input = gr.Textbox(label="Query")
            query_button = gr.Button("Query")
            query_output = gr.Textbox()
            query_button.click(query_documents, inputs=query_input, outputs=query_output)

if __name__ == "__main__":
    # NOTE: to use a standalone Chroma server instead of the in-process client,
    # start one with `chroma run --host localhost --port 8000` before launching.
    demo.launch()