Spaces:
Runtime error
Runtime error
Update agent.py
Browse files
agent.py
CHANGED
@@ -39,6 +39,11 @@ from docx import Document as DocxDocument
|
|
39 |
import openpyxl
|
40 |
from io import StringIO
|
41 |
|
|
|
|
|
|
|
|
|
|
|
42 |
load_dotenv()
|
43 |
|
44 |
@tool
|
@@ -313,47 +318,55 @@ for task in tasks:
|
|
313 |
# -------------------------------
|
314 |
# Initialize HuggingFace Embedding model
|
315 |
#embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
|
316 |
-
embedding_model = HuggingFaceEmbeddings(model_name="BAAI/bge-base-en-v1.5")
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
336 |
|
337 |
-
#
|
338 |
-
|
339 |
-
# -----------------------------
|
340 |
-
# Gather Wikipedia articles: one lookup per entry in `docs`, using the entry's
# text as the query and passing load_max_docs=1 to the loader.
wiki_docs = []
for doc in docs:
    try:
        wiki_docs.extend(
            WikipediaLoader(query=doc.page_content, load_max_docs=1).load()
        )
    except Exception as e:
        # Best-effort: report the failed lookup and move on to the next entry.
        print(f"Failed to load Wikipedia for: {doc.page_content} — {e}")
|
347 |
|
348 |
-
|
|
|
|
|
349 |
|
350 |
-
#
|
351 |
-
|
352 |
-
# -----------------------------
|
353 |
-
# Embed every entry of `all_docs` with the BGE base model, index the vectors
# in FAISS, and write the index to the local "faiss_index" folder.
embedding_model = HuggingFaceEmbeddings(model_name="BAAI/bge-base-en-v1.5")
vector_store = FAISS.from_documents(documents=all_docs, embedding=embedding_model)
vector_store.save_local(folder_path="faiss_index")
|
356 |
|
|
|
357 |
# -----------------------------
|
358 |
# Step 4: Create Retriever Tool
|
359 |
# -----------------------------
|
@@ -367,23 +380,6 @@ question_retriever_tool = create_retriever_tool(
|
|
367 |
|
368 |
|
369 |
|
370 |
-
# -------------------------------
|
371 |
-
# Step 5: Create Retriever Tool (for use in LangChain)
|
372 |
-
# -------------------------------
|
373 |
-
# Expose the vector store as a retriever and wrap it in a named tool.
retriever = vector_store.as_retriever()
question_retriever_tool = create_retriever_tool(
    retriever,
    "Question_Search",
    "A tool to retrieve documents related to a user's question.",
)

# Build the FAISS index from `all_docs` and write it to "faiss_index".
vector_store = FAISS.from_documents(documents=all_docs, embedding=embedding_model)
vector_store.save_local(folder_path="faiss_index")
|
384 |
-
|
385 |
-
|
386 |
-
|
387 |
def retriever(state: MessagesState):
|
388 |
"""Retriever node using similarity scores for filtering"""
|
389 |
query = state["messages"][0].content
|
|
|
39 |
import openpyxl
|
40 |
from io import StringIO
|
41 |
|
42 |
+
from transformers import BertTokenizer, BertModel
|
43 |
+
import torch
|
44 |
+
#from langchain.embeddings import Embedding
|
45 |
+
|
46 |
+
|
47 |
load_dotenv()
|
48 |
|
49 |
@tool
|
|
|
318 |
# -------------------------------
|
319 |
# Initialize HuggingFace Embedding model
|
320 |
#embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
|
321 |
+
#embedding_model = HuggingFaceEmbeddings(model_name="BAAI/bge-base-en-v1.5")
|
322 |
+
|
323 |
+
from transformers import BertTokenizer, BertModel
import torch
# LangChain's embeddings base class is `Embeddings`; vector stores such as
# FAISS call its `embed_documents` / `embed_query` methods.
from langchain.embeddings.base import Embeddings
from langchain.schema import Document


class BERTEmbedding(Embeddings):
    """LangChain embeddings backed by a mean-pooled BERT encoder.

    Each text is tokenized, run through the BERT model, and reduced to one
    vector by averaging the last-layer hidden states of its real (non-padding)
    tokens.
    """

    def __init__(self, model_name='bert-base-uncased'):
        """Load the tokenizer and model weights for ``model_name``."""
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.model = BertModel.from_pretrained(model_name)

    def embed(self, texts, batch_size=32):
        """Return one embedding per input text.

        Args:
            texts: A single string or an iterable of strings.
            batch_size: Number of texts encoded per forward pass.

        Returns:
            A list with one entry per text; each entry is a list of floats.
            An empty input yields an empty list.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if isinstance(texts, str):
            texts = [texts]
        else:
            texts = list(texts)
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")

        vectors = []
        # Encode in chunks so a large corpus is not pushed through the model
        # in a single forward pass.
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            inputs = self.tokenizer(
                batch, return_tensors='pt', padding=True, truncation=True
            )

            with torch.no_grad():
                outputs = self.model(**inputs)

            hidden = outputs.last_hidden_state  # (batch, seq_len, hidden_dim)
            # Weight padding positions by zero so a text's vector does not
            # depend on how much padding its batch happened to need.
            mask = inputs['attention_mask'].unsqueeze(-1).to(hidden.dtype)
            pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)

            # Convert to nested Python lists of floats.
            vectors.extend(pooled.cpu().numpy().tolist())
        return vectors

    def embed_documents(self, texts):
        """Embed a list of document texts (LangChain ``Embeddings`` API)."""
        return self.embed(texts)

    def embed_query(self, text):
        """Embed a single query string (LangChain ``Embeddings`` API)."""
        return self.embed([text])[0]
|
347 |
+
|
348 |
+
# Example usage of BERTEmbedding with LangChain
|
349 |
+
|
350 |
+
# Instantiate the BERT-backed embedding model.
embedding_model = BERTEmbedding(model_name="bert-base-uncased")

# Two sample documents (replace with your own text).
# NOTE(review): this rebinds `docs`; confirm no earlier `docs` value is still needed.
docs = [
    Document(page_content=sample_text)
    for sample_text in (
        "Mercedes Sosa was an Argentine singer and musician.",
        "The 2000s were a significant decade for music in Latin America.",
    )
]

# Embed the text of every sample document.
embeddings = embedding_model.embed([entry.page_content for entry in docs])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
360 |
|
361 |
+
# Now, you can use the embeddings with FAISS or other retrieval systems
|
362 |
+
# For example, with FAISS:
|
363 |
+
from langchain.vectorstores import FAISS
|
364 |
|
365 |
+
# Assuming 'docs' contains your list of documents and 'embedding_model' is the model you created
|
366 |
+
# Index `docs` in FAISS using `embedding_model`, then write the index to the
# local "faiss_index" folder.
vector_store = FAISS.from_documents(documents=docs, embedding=embedding_model)
vector_store.save_local(folder_path="faiss_index")
|
368 |
|
369 |
+
|
370 |
# -----------------------------
|
371 |
# Step 4: Create Retriever Tool
|
372 |
# -----------------------------
|
|
|
380 |
|
381 |
|
382 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
383 |
def retriever(state: MessagesState):
|
384 |
"""Retriever node using similarity scores for filtering"""
|
385 |
query = state["messages"][0].content
|