Spaces:
Runtime error
Runtime error
Update agent.py
Browse files
agent.py
CHANGED
@@ -374,52 +374,107 @@ async def start_questions(request: Request):
|
|
374 |
# -----------------------------
|
375 |
# 1. Define Custom BERT Embedding Model
|
376 |
# -----------------------------
|
|
|
|
|
|
|
|
|
|
|
377 |
class BERTEmbeddings(Embeddings):
|
378 |
-
def __init__(self, model_name='bert-base-uncased'):
|
|
|
379 |
self.tokenizer = BertTokenizer.from_pretrained(model_name)
|
380 |
self.model = BertModel.from_pretrained(model_name)
|
381 |
self.model.eval() # Set model to eval mode
|
|
|
|
|
382 |
|
383 |
def embed_documents(self, texts):
|
384 |
-
|
|
|
|
|
|
|
385 |
with torch.no_grad():
|
386 |
outputs = self.model(**inputs)
|
|
|
|
|
387 |
embeddings = outputs.last_hidden_state.mean(dim=1)
|
388 |
-
|
|
|
|
|
|
|
|
|
389 |
return embeddings.cpu().numpy()
|
390 |
|
391 |
def embed_query(self, text):
|
|
|
392 |
return self.embed_documents([text])[0]
|
393 |
|
394 |
|
395 |
# -----------------------------
|
396 |
# 2. Initialize Embedding Model
|
397 |
# -----------------------------
|
398 |
-
embedding_model = BERTEmbeddings()
|
399 |
-
|
400 |
|
401 |
# -----------------------------
|
402 |
-
#
|
403 |
# -----------------------------
|
404 |
-
docs = [
|
405 |
-
Document(page_content="Mercedes Sosa released many albums between 2000 and 2009.", metadata={"id": 1}),
|
406 |
-
Document(page_content="She was a prominent Argentine folk singer.", metadata={"id": 2}),
|
407 |
-
Document(page_content="Her album 'Al Despertar' was released in 1998.", metadata={"id": 3}),
|
408 |
-
Document(page_content="She continued releasing music well into the 2000s.", metadata={"id": 4}),
|
409 |
-
]
|
410 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
411 |
|
412 |
# -----------------------------
|
413 |
-
#
|
414 |
# -----------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
415 |
vector_store = FAISS.from_documents(docs, embedding_model)
|
416 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
417 |
|
418 |
|
419 |
# -----------------------------
|
420 |
# 6. Create LangChain Retriever Tool
|
421 |
# -----------------------------
|
422 |
-
|
|
|
423 |
|
424 |
question_retriever_tool = create_retriever_tool(
|
425 |
retriever=retriever,
|
@@ -1052,6 +1107,8 @@ def process_all_tasks(tasks: list):
|
|
1052 |
## Langgraph
|
1053 |
|
1054 |
# Build graph function
|
|
|
|
|
1055 |
provider = "huggingface"
|
1056 |
|
1057 |
model_config = {
|
|
|
374 |
# -----------------------------
|
375 |
# 1. Define Custom BERT Embedding Model
|
376 |
# -----------------------------
|
377 |
+
import torch
|
378 |
+
import torch.nn.functional as F
|
379 |
+
from transformers import BertTokenizer, BertModel
|
380 |
+
from langchain.embeddings import Embeddings
|
381 |
+
|
382 |
class BERTEmbeddings(Embeddings):
|
383 |
+
def __init__(self, model_name='bert-base-uncased', device='cpu'):
|
384 |
+
# Initialize the tokenizer and model
|
385 |
self.tokenizer = BertTokenizer.from_pretrained(model_name)
|
386 |
self.model = BertModel.from_pretrained(model_name)
|
387 |
self.model.eval() # Set model to eval mode
|
388 |
+
self.device = device
|
389 |
+
self.model.to(self.device) # Move model to the specified device (CPU or GPU)
|
390 |
|
391 |
def embed_documents(self, texts):
|
392 |
+
# Tokenize the input texts
|
393 |
+
inputs = self.tokenizer(texts, return_tensors='pt', padding=True, truncation=True, max_length=512)
|
394 |
+
inputs = {key: value.to(self.device) for key, value in inputs.items()} # Move inputs to the specified device
|
395 |
+
|
396 |
with torch.no_grad():
|
397 |
outputs = self.model(**inputs)
|
398 |
+
|
399 |
+
# Get the embeddings by averaging the last hidden state across tokens
|
400 |
embeddings = outputs.last_hidden_state.mean(dim=1)
|
401 |
+
|
402 |
+
# Normalize embeddings for cosine similarity
|
403 |
+
embeddings = F.normalize(embeddings, p=2, dim=1)
|
404 |
+
|
405 |
+
# Return the embeddings as numpy array
|
406 |
return embeddings.cpu().numpy()
|
407 |
|
408 |
def embed_query(self, text):
|
409 |
+
# Embed a single query (text)
|
410 |
return self.embed_documents([text])[0]
|
411 |
|
412 |
|
413 |
# -----------------------------
|
414 |
# 2. Initialize Embedding Model
|
415 |
# -----------------------------
|
|
|
|
|
416 |
|
417 |
# -----------------------------
|
418 |
+
# Create FAISS Vector Store
|
419 |
# -----------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
420 |
|
421 |
+
class MyVectorStore:
|
422 |
+
def __init__(self, index: faiss.Index):
|
423 |
+
self.index = index
|
424 |
+
|
425 |
+
def save_local(self, path: str):
|
426 |
+
# Save the FAISS index to the specified file
|
427 |
+
faiss.write_index(self.index, "/home/wendy/Downloads")
|
428 |
+
print(f"Index saved to {path}")
|
429 |
+
|
430 |
+
@classmethod
|
431 |
+
def load_local(cls, path: str):
|
432 |
+
# Load the FAISS index from the specified file
|
433 |
+
index = faiss.read_index(path)
|
434 |
+
return cls(index)
|
435 |
|
436 |
# -----------------------------
|
437 |
+
# 3. Prepare Documents
|
438 |
# -----------------------------
|
439 |
+
# Define the URL where the JSON file is hosted
|
440 |
+
url = "https://agents-course-unit4-scoring.hf.space/questions"
|
441 |
+
|
442 |
+
# Download the JSON file from the URL
|
443 |
+
response = requests.get(url)
|
444 |
+
response.raise_for_status() # Ensure that the request was successful
|
445 |
+
|
446 |
+
# Parse the JSON data
|
447 |
+
docs = json.loads(response.text)
|
448 |
+
|
449 |
+
# Assuming the JSON structure has a 'text' field for each document
|
450 |
+
texts = [doc['text'] for doc in docs] # Extract text from JSON
|
451 |
+
|
452 |
+
# Initialize the embedding model
|
453 |
+
embedding_model = BERTEmbeddings()
|
454 |
+
|
455 |
+
# Generate embeddings for each document
|
456 |
+
embeddings = [embedding_model.encode(text) for text in texts]
|
457 |
+
|
458 |
+
# Create the FAISS index
|
459 |
vector_store = FAISS.from_documents(docs, embedding_model)
|
460 |
+
|
461 |
+
# Save the FAISS index
|
462 |
+
vector_store = MyVectorStore(index)
|
463 |
+
vector_store.save_local("/home/wt/Downloads/faiss_index.index")
|
464 |
+
|
465 |
+
# Load the FAISS index later
|
466 |
+
loaded_vector_store = MyVectorStore.load_local("faiss_index.index")
|
467 |
+
|
468 |
+
|
469 |
+
|
470 |
+
|
471 |
|
472 |
|
473 |
# -----------------------------
|
474 |
# 6. Create LangChain Retriever Tool
|
475 |
# -----------------------------
|
476 |
+
|
477 |
+
retriever = FAISS.load_local("faiss_index.index", embedding_model).as_retriever()
|
478 |
|
479 |
question_retriever_tool = create_retriever_tool(
|
480 |
retriever=retriever,
|
|
|
1107 |
## Langgraph
|
1108 |
|
1109 |
# Build graph function
|
1110 |
+
vector_store = vector_store.save_local("faiss_index")
|
1111 |
+
|
1112 |
provider = "huggingface"
|
1113 |
|
1114 |
model_config = {
|