Tonic's picture
Update app.py (#3)
82e50e0 verified
raw
history blame
No virus
9.87 kB
# app.py
import spaces
from torch.nn import DataParallel
from torch import Tensor
from transformers import AutoTokenizer, AutoModel
from huggingface_hub import InferenceClient
from openai import OpenAI
from langchain_community.document_loaders import UnstructuredFileLoader
from langchain_chroma import Chroma
from chromadb import Documents, EmbeddingFunction, Embeddings
from chromadb.config import Settings
import chromadb #import HttpClient
import os
import tempfile
import re
import uuid
import gradio as gr
import torch
import torch.nn.functional as F
from dotenv import load_dotenv
from utils import load_env_variables, parse_and_route, escape_special_characters
from globalvars import API_BASE, intention_prompt, tasks, system_message, model_name, metadata_prompt
# Read a local .env file (if present) into the process environment before
# any credentials are looked up below.
load_dotenv()
# CUDA / PyTorch allocator tuning passed through environment variables.
# NOTE(review): these are assigned after `import torch`; presumably they still
# take effect because CUDA is initialised lazily — confirm.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:30'
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ['CUDA_CACHE_DISABLE'] = '1'
# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Ensure the temporary directory exists
temp_dir = '/tmp/gradio/'
os.makedirs(temp_dir, exist_ok=True)
# Set Gradio cache directory
# NOTE(review): this assigns a module attribute on gradio's file component
# module directly; whether Gradio reads it depends on the installed version.
gr.components.file.GRADIO_CACHE = temp_dir
### Utils
# The project helper returns two values, unpacked here as the Hugging Face
# token and the Yi API token.
hf_token, yi_token = load_env_variables()
def clear_cuda_cache():
    """Ask PyTorch to release the cached CUDA memory it is not using."""
    cuda_backend = torch.cuda
    cuda_backend.empty_cache()
# OpenAI-SDK client pointed at API_BASE and authenticated with the Yi token.
client = OpenAI(api_key=yi_token, base_url=API_BASE)
# Chroma client built from default Settings.
chroma_client = chromadb.Client(Settings())
# Fetch the collection, creating it only when it does not exist yet.
# `create_collection` raises if the name is already taken, which happens when
# this module is executed more than once against the same client state.
chroma_collection = chroma_client.get_or_create_collection("all-my-documents")
class EmbeddingGenerator:
    """Embed text with a Hugging Face model and extract metadata via a chat client.

    Construction loads the tokenizer and model named by ``model_name`` and
    moves the model to CUDA when available. ``intention_client`` is used for
    two chat-completion calls per text: one to pick an embedding task and one
    to produce metadata.
    """

    # Matches `"key": "value"` pairs inside the metadata reply. Any amount of
    # whitespace (including none) is accepted after the colon so that compact
    # and pretty-printed JSON-like output both parse.
    _METADATA_PATTERN = re.compile(r'"(\w+)":\s*"([^"]+)"')

    def __init__(self, model_name: str, token: str, intention_client):
        """Load tokenizer and model and remember the chat client.

        Args:
            model_name: Hugging Face model identifier.
            token: Hugging Face access token forwarded to ``from_pretrained``.
            intention_client: Client exposing ``chat.completions.create``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=token, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(model_name, token=token, trust_remote_code=True).to(self.device)
        self.intention_client = intention_client

    def clear_cuda_cache(self):
        """Ask PyTorch to release cached CUDA memory it is not using."""
        torch.cuda.empty_cache()

    @spaces.GPU
    def compute_embeddings(self, input_text: str):
        """Embed ``input_text`` and extract its metadata.

        Args:
            input_text: Raw text to embed.

        Returns:
            A tuple ``(embeddings_list, metadata)``. ``embeddings_list`` is the
            L2-normalised pooled model output converted to nested Python lists
            (one row per tokenized input; exactly one input is tokenized here).
            ``metadata`` is the dict produced by :meth:`extract_metadata`.
        """
        escaped_input_text = escape_special_characters(input_text)

        # Ask the chat model which embedding task fits this text.
        intention_completion = self.intention_client.chat.completions.create(
            model="yi-large",
            messages=[
                {"role": "system", "content": escape_special_characters(intention_prompt)},
                {"role": "user", "content": escaped_input_text},
            ],
        )
        intention_output = intention_completion.choices[0].message.content

        # Parse and route the intention
        selected_task = parse_and_route(intention_output)
        if selected_task in tasks:
            task_description = tasks[selected_task]
        else:
            task_description = tasks["DEFAULT"]
            print(f"Selected task not found: {selected_task}")

        # NOTE(review): the instruction prefix is built but never prepended to
        # the text that is tokenized below. Left unchanged because applying it
        # would alter every stored embedding — confirm the intended behaviour.
        query_prefix = f"Instruct: {task_description}\nQuery: "
        queries = [escaped_input_text]

        # Get the metadata
        metadata_completion = self.intention_client.chat.completions.create(
            model="yi-large",
            messages=[
                {"role": "system", "content": escape_special_characters(metadata_prompt)},
                {"role": "user", "content": escaped_input_text},
            ],
        )
        metadata_output = metadata_completion.choices[0].message.content
        metadata = self.extract_metadata(metadata_output)

        # Get the embeddings. The cache is cleared in `finally` so a failed
        # forward pass does not leave cached GPU memory behind.
        try:
            with torch.no_grad():
                inputs = self.tokenizer(
                    queries,
                    return_tensors='pt',
                    padding=True,
                    truncation=True,
                    max_length=4096,
                ).to(self.device)
                outputs = self.model(**inputs)
                # Prefer `last_hidden_state` (the value the original code ended
                # up using); fall back to "sentence_embeddings" for models
                # whose output only exposes that key. Previously both were
                # required, so models providing just one of them failed.
                token_states = getattr(outputs, "last_hidden_state", None)
                if token_states is None:
                    token_states = outputs["sentence_embeddings"]
                # assumes token_states is (batch, seq, dim) — TODO confirm for
                # the "sentence_embeddings" fallback.
                query_embeddings = token_states.mean(dim=1)
                # Normalize embeddings
                query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
            embeddings_list = query_embeddings.detach().cpu().numpy().tolist()
        finally:
            self.clear_cuda_cache()
        return embeddings_list, metadata

    def extract_metadata(self, metadata_output: str):
        """Parse ``"key": "value"`` pairs out of the chat model's reply.

        Args:
            metadata_output: Text returned by the metadata completion; ``None``
                or an empty string yields an empty dict.

        Returns:
            Dict mapping each matched key to its value. When a key occurs more
            than once the last occurrence wins.
        """
        return dict(self._METADATA_PATTERN.findall(metadata_output or ""))
class MyEmbeddingFunction(EmbeddingFunction):
    """Embedding function that delegates to a freshly built EmbeddingGenerator.

    Only the constructor arguments are stored; the generator (and therefore
    the model) is created on demand by :meth:`create_embedding_generator`.
    """

    def __init__(self, model_name: str, token: str, intention_client):
        """Remember the arguments needed to build an EmbeddingGenerator."""
        self.model_name = model_name
        self.token = token
        self.intention_client = intention_client

    def create_embedding_generator(self):
        """Return a new EmbeddingGenerator built from the stored arguments."""
        return EmbeddingGenerator(self.model_name, self.token, self.intention_client)

    def __call__(self, input: Documents) -> tuple[Embeddings, list]:
        """Embed every document in ``input``.

        Args:
            input: Documents to embed. Items may be plain strings or objects
                exposing ``page_content``.

        Returns:
            ``(embeddings, metadata)`` where ``embeddings`` is a flat list of
            vectors and ``metadata`` holds one metadata dict per vector, in
            the same order.

        NOTE(review): Chroma's EmbeddingFunction protocol is declared as
        returning only the embeddings; the tuple return is kept for
        compatibility with the original signature — confirm with callers.
        """
        # One generator per call so the model is not reloaded per document.
        embedding_generator = self.create_embedding_generator()
        embeddings_flattened = []
        metadata_flattened = []
        for doc in input:
            # Accept both raw strings and Document-like objects.
            text = getattr(doc, "page_content", doc)
            embeddings, metadata = embedding_generator.compute_embeddings(text)
            embeddings_flattened.extend(embeddings)
            # Keep the metadata dict itself aligned with each vector; the
            # previous flattening iterated the dict and produced its keys.
            metadata_flattened.extend(metadata for _ in embeddings)
        return embeddings_flattened, metadata_flattened
def load_documents(file_path: str, mode: str = "elements"):
    """Load ``file_path`` with UnstructuredFileLoader and return the text of each loaded document."""
    contents = []
    for document in UnstructuredFileLoader(file_path, mode=mode).load():
        contents.append(document.page_content)
    return contents
def initialize_chroma(collection_name: str, embedding_function: MyEmbeddingFunction):
    """Build a LangChain ``Chroma`` store on the module-level Chroma client.

    Args:
        collection_name: Name of the collection the store operates on.
        embedding_function: Embedding function handed to the store.

    Returns:
        The constructed ``Chroma`` instance.
    """
    return Chroma(
        client=chroma_client,
        collection_name=collection_name,
        embedding_function=embedding_function,
    )
def add_documents_to_chroma(documents: list, embedding_function: MyEmbeddingFunction):
    """Embed each document and store it in the module-level Chroma collection.

    Args:
        documents: Items to store. Each item may be a plain string or an
            object exposing ``page_content``.
        embedding_function: Provides ``create_embedding_generator()``.

    Each embedding returned for a document is added as its own record with a
    fresh UUID; the document's metadata dict is attached when it is non-empty.
    """
    if not documents:
        return

    # Build the generator once; creating it per document would repeat the
    # tokenizer/model construction for every item.
    generator = embedding_function.create_embedding_generator()
    for doc in documents:
        text = getattr(doc, "page_content", doc)
        embeddings, metadata = generator.compute_embeddings(text)
        # `metadata` is a single dict for the whole document, so it is attached
        # to every embedding rather than zipped with them (zipping iterated the
        # dict's keys and skipped the document entirely when it was empty).
        for embedding in embeddings:
            chroma_collection.add(
                ids=[str(uuid.uuid1())],
                documents=[text],
                embeddings=[embedding],
                # Pass None instead of an empty dict, which Chroma's metadata
                # validation may reject.
                metadatas=[metadata] if metadata else None,
            )
def query_chroma(query_text: str, embedding_function: MyEmbeddingFunction):
    """Return the two nearest stored documents for ``query_text``.

    Args:
        query_text: Text to search for.
        embedding_function: Provides ``create_embedding_generator()``.

    Returns:
        The result of ``chroma_collection.query`` for the computed embedding.
    """
    query_embeddings, _ = embedding_function.create_embedding_generator().compute_embeddings(query_text)
    # Search with the embedding computed above so queries live in the same
    # vector space as the stored documents. Passing `query_texts` instead makes
    # Chroma embed the text with the collection's own embedding function.
    return chroma_collection.query(
        query_embeddings=query_embeddings,
        n_results=2,
    )
# Initialize clients
# OpenAI-SDK client pointed at API_BASE and authenticated with the Yi token.
intention_client = OpenAI(api_key=yi_token, base_url=API_BASE)
# Module-level generator; constructing it loads the tokenizer and model.
# NOTE(review): nothing else in the visible source reads `embedding_generator`
# (the helpers call create_embedding_generator() instead) — confirm it is needed.
embedding_generator = EmbeddingGenerator(model_name=model_name, token=hf_token, intention_client=intention_client)
# Embedding function shared by the upload and query handlers.
embedding_function = MyEmbeddingFunction(model_name=model_name, token=hf_token, intention_client=intention_client)
# LangChain Chroma store for the "Tonic-instruct" collection.
# NOTE(review): the add/query helpers in this file operate on
# `chroma_collection` ("all-my-documents") instead — confirm which is intended.
chroma_db = initialize_chroma(collection_name="Tonic-instruct", embedding_function=embedding_function)
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Stream a chat reply that is grounded in retrieved documents.

    Args:
        message: The user's latest message.
        history: Prior ``(user, assistant)`` turns.
        system_message: System prompt for the conversation.
        max_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.

    Yields:
        The reply accumulated so far, once per streamed chunk.
    """
    retrieved_text = query_documents(message)

    messages = [{"role": "system", "content": escape_special_characters(system_message)}]
    for user_turn, assistant_turn in history:
        if user_turn:
            messages.append({"role": "user", "content": user_turn})
        if assistant_turn:
            messages.append({"role": "assistant", "content": assistant_turn})
    messages.append({"role": "user", "content": f"{retrieved_text}\n\n{escape_special_characters(message)}"})

    # The OpenAI SDK client exposes `chat.completions.create`; it has no
    # `chat_completion` method. The model matches the other calls in this file.
    stream = intention_client.chat.completions.create(
        model="yi-large",
        messages=messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    )

    response = ""
    for chunk in stream:
        if not chunk.choices:
            continue
        # Streamed deltas may carry no text (e.g. role-only or final chunks).
        token = chunk.choices[0].delta.content
        if token:
            response += token
        yield response
def upload_documents(files):
    """Load each uploaded file and store its contents in Chroma.

    Args:
        files: Uploaded files from the Gradio component. Entries may be path
            strings or objects exposing ``name``; ``None`` or an empty list
            means nothing was uploaded.

    Returns:
        A status message for display in the UI.
    """
    if not files:
        return "No documents were uploaded."

    for file in files:
        # Depending on the Gradio version the entry is either a path string or
        # a file wrapper whose `name` attribute holds the path.
        file_path = getattr(file, "name", file)
        loader = UnstructuredFileLoader(file_path)
        # Hand plain text to the store: the embedding pipeline and Chroma's
        # `documents` field both operate on strings, not Document objects.
        documents = [doc.page_content for doc in loader.load()]
        add_documents_to_chroma(documents, embedding_function)
    return "Documents uploaded and processed successfully!"
def query_documents(query):
    """Return the retrieved document texts for ``query`` as one string.

    Args:
        query: Text to search the collection for.

    Returns:
        The matching documents joined by blank lines, or an empty string when
        nothing was found.
    """
    results = query_chroma(query, embedding_function)
    # A Chroma query result is a mapping whose "documents" entry holds one list
    # of texts per query; iterating the mapping itself only yields its keys.
    per_query_documents = results.get("documents") or [[]]
    return "\n\n".join(text for text in per_query_documents[0] if text)
# Two-tab UI: one tab ingests documents, the other chats over / queries them.
with gr.Blocks() as demo:
    with gr.Tab("Upload Documents"):
        document_upload = gr.File(file_count="multiple", file_types=["document"])
        upload_button = gr.Button("Upload and Process")
        # Status text filled with the message returned by upload_documents.
        upload_status = gr.Text()
        upload_button.click(
            upload_documents,
            inputs=document_upload,
            outputs=upload_status,
        )

    with gr.Tab("Ask Questions"):
        with gr.Row():
            # Extra controls forwarded to `respond` after message and history.
            system_box = gr.Textbox(value="You are a friendly Chatbot.", label="System message")
            max_tokens_slider = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens")
            temperature_slider = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
            top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")
            chat_interface = gr.ChatInterface(
                respond,
                additional_inputs=[
                    system_box,
                    max_tokens_slider,
                    temperature_slider,
                    top_p_slider,
                ],
            )
        # Direct retrieval without the chat model.
        query_input = gr.Textbox(label="Query")
        query_button = gr.Button("Query")
        query_output = gr.Textbox()
        query_button.click(
            query_documents,
            inputs=query_input,
            outputs=query_output,
        )

if __name__ == "__main__":
    # A standalone Chroma server is not started here:
    # os.system("chroma run --host localhost --port 8000 &")
    demo.launch()