Spaces: refactor vectordb (build error)

app.py CHANGED
@@ -98,40 +98,123 @@ class MyEmbeddingFunction(EmbeddingFunction):
         embeddings = [self.embedding_generator.compute_embeddings(doc) for doc in input]
         embeddings = [item for sublist in embeddings for item in sublist]
         return embeddings
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+# main.py
+import os
+import uuid
+import spaces  # provides the @spaces.GPU decorator used below; this import was missing and is a likely cause of the build error
+import gradio as gr
+import torch
+import torch.nn.functional as F
+from torch.nn import DataParallel
+from torch import Tensor
+from transformers import AutoTokenizer, AutoModel
+from huggingface_hub import InferenceClient
+from openai import OpenAI
+from langchain_community.document_loaders import UnstructuredFileLoader
+from langchain_community.vectorstores import Chroma  # used by query_chroma below; also previously missing
+from chromadb import Documents, EmbeddingFunction, Embeddings
+from chromadb.config import Settings
+from chromadb import HttpClient
+from utils import load_env_variables, parse_and_route
+from globalvars import API_BASE, intention_prompt, tasks, system_message, model_name
+
+os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:30'
+os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
+os.environ['CUDA_CACHE_DISABLE'] = '1'
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+### Utils
+hf_token, yi_token = load_env_variables()
+
+def clear_cuda_cache():
+    torch.cuda.empty_cache()
+
+client = OpenAI(api_key=yi_token, base_url=API_BASE)
+
+class EmbeddingGenerator:
+    def __init__(self, model_name: str, token: str, intention_client):
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=token, trust_remote_code=True)
+        self.model = AutoModel.from_pretrained(model_name, token=token, trust_remote_code=True).to(self.device)
+        self.intention_client = intention_client
+
+    def clear_cuda_cache(self):
+        torch.cuda.empty_cache()
+
+    @spaces.GPU
+    def compute_embeddings(self, input_text: str):
+        # Get the intention
+        intention_completion = self.intention_client.chat.completions.create(
+            model="yi-large",
+            messages=[
+                {"role": "system", "content": intention_prompt},
+                {"role": "user", "content": input_text}
+            ]
+        )
+        intention_output = intention_completion.choices[0].message.content  # ChatCompletionMessage is an object, not a dict
+
+        # Parse and route the intention
+        parsed_task = parse_and_route(intention_output)
+        selected_task = list(parsed_task.keys())[0]
+
+        # Construct the prompt
+        try:
+            task_description = tasks[selected_task]
+        except KeyError:
+            print(f"Selected task not found: {selected_task}")
+            return f"Error: Task '{selected_task}' not found. Please select a valid task."
+
+        query_prefix = f"Instruct: {task_description}\nQuery: "
+        queries = [query_prefix + input_text]  # prepend the instruct prefix so the model sees the routed task
+
+        # Get the embeddings
+        with torch.no_grad():
+            inputs = self.tokenizer(queries, return_tensors='pt', padding=True, truncation=True, max_length=4096).to(self.device)
+            outputs = self.model(**inputs)
+            query_embeddings = outputs.last_hidden_state.mean(dim=1)
+
+        # Normalize embeddings
+        query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
+        embeddings_list = query_embeddings.detach().cpu().numpy().tolist()
+        self.clear_cuda_cache()
+        return embeddings_list
+
+class MyEmbeddingFunction(EmbeddingFunction):
+    def __init__(self, embedding_generator: EmbeddingGenerator):
+        self.embedding_generator = embedding_generator
+
+    def __call__(self, input: Documents) -> Embeddings:
+        embeddings = [self.embedding_generator.compute_embeddings(doc) for doc in input]
+        embeddings = [item for sublist in embeddings for item in sublist]
+        return embeddings
+
+def load_documents(file_path: str, mode: str = "elements"):
+    loader = UnstructuredFileLoader(file_path, mode=mode)
+    docs = loader.load()
+    return [doc.page_content for doc in docs]
+
+def initialize_chroma(collection_name: str, embedding_function: MyEmbeddingFunction):
+    client = HttpClient(settings=Settings(allow_reset=True))
+    client.reset()  # resets the database
+    collection = client.create_collection(collection_name)
+    return client, collection
+
+def add_documents_to_chroma(client, collection, documents: list, embedding_function: MyEmbeddingFunction):
+    for doc in documents:
+        collection.add(ids=[str(uuid.uuid1())], documents=[doc], embeddings=embedding_function([doc]))
+
+def query_chroma(client, collection_name: str, query_text: str, embedding_function: MyEmbeddingFunction):
+    db = Chroma(client=client, collection_name=collection_name, embedding_function=embedding_function)
+    result_docs = db.similarity_search(query_text)
+    return result_docs
 
-
 # Initialize clients
 intention_client = OpenAI(api_key=yi_token, base_url=API_BASE)
 embedding_generator = EmbeddingGenerator(model_name=model_name, token=hf_token, intention_client=intention_client)
 embedding_function = MyEmbeddingFunction(embedding_generator=embedding_generator)
-
-
+chroma_client, chroma_collection = initialize_chroma(collection_name="Tonic-instruct", embedding_function=embedding_function)
+
 def respond(
     message,
     history: list[tuple[str, str]],
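
A note on query_chroma: LangChain's Chroma wrapper expects an embeddings object exposing embed_documents and embed_query, while MyEmbeddingFunction follows chromadb's callable EmbeddingFunction protocol, so passing it straight through will fail at query time. A minimal adapter sketch under that assumption (the adapter class name is hypothetical, not part of this commit):

# Hypothetical adapter: bridges the chromadb-style callable to the
# embed_documents / embed_query interface LangChain's Chroma expects.
class LangchainEmbeddingAdapter:
    def __init__(self, fn: MyEmbeddingFunction):
        self.fn = fn

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        return self.fn(texts)  # MyEmbeddingFunction already returns one vector per text

    def embed_query(self, text: str) -> list[float]:
        return self.fn([text])[0]

Inside query_chroma this would read: db = Chroma(client=client, collection_name=collection_name, embedding_function=LangchainEmbeddingAdapter(embedding_function)).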
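
For reference, a minimal end-to-end sketch of the ingestion-and-query flow these helpers set up, assuming a Chroma server is reachable at HttpClient's default address, query_chroma is patched as above, and "example.pdf" stands in for a real local file:

# Sketch only: the file path and query text are placeholders.
docs = load_documents("example.pdf", mode="elements")  # parse the file into text elements
chroma_client, chroma_collection = initialize_chroma(
    collection_name="Tonic-instruct", embedding_function=embedding_function
)  # careful: initialize_chroma calls client.reset(), wiping existing data
add_documents_to_chroma(chroma_client, chroma_collection, docs, embedding_function)
results = query_chroma(chroma_client, "Tonic-instruct", "What is this document about?", embedding_function)
for doc in results:
    print(doc.page_content)

Note that each ingested document triggers its own yi-large intention call inside compute_embeddings, which is worth keeping in mind before loading large corpora.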
|