Spaces:
Runtime error
Runtime error
Update agent.py
Browse files
agent.py
CHANGED
@@ -29,7 +29,7 @@ import json
|
|
29 |
from langchain_core.documents import Document
|
30 |
from langchain_community.vectorstores import FAISS
|
31 |
from langchain.vectorstores import FAISS
|
32 |
-
from langchain.embeddings import BERTEmbeddings
|
33 |
from langchain_community.embeddings import HuggingFaceEmbeddings
|
34 |
|
35 |
from youtube_transcript_api import YouTubeTranscriptApi
|
@@ -343,37 +343,6 @@ for name in enabled_tool_names:
|
|
343 |
tools.append(tool_map[name])
|
344 |
|
345 |
|
346 |
-
|
347 |
-
# -------------------------------
|
348 |
-
# Set up BERT Embeddings
|
349 |
-
# -------------------------------
|
350 |
-
|
351 |
-
# -----------------------------
|
352 |
-
# Define Custom BERT Embedding Model
|
353 |
-
# -----------------------------
|
354 |
-
import torch
|
355 |
-
import torch.nn.functional as F
|
356 |
-
from transformers import BertTokenizer, BertModel
|
357 |
-
|
358 |
-
class BERTEmbeddings:
|
359 |
-
def __init__(self, model_name='bert-base-uncased'):
|
360 |
-
self.tokenizer = BertTokenizer.from_pretrained(model_name)
|
361 |
-
self.model = BertModel.from_pretrained(model_name)
|
362 |
-
self.model.eval() # Set to evaluation mode
|
363 |
-
|
364 |
-
def embed_documents(self, texts):
|
365 |
-
inputs = self.tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
|
366 |
-
with torch.no_grad():
|
367 |
-
outputs = self.model(**inputs)
|
368 |
-
embeddings = outputs.last_hidden_state.mean(dim=1)
|
369 |
-
embeddings = F.normalize(embeddings, p=2, dim=1) # Normalize for cosine similarity
|
370 |
-
return embeddings.cpu().numpy()
|
371 |
-
|
372 |
-
def embed_query(self, text):
|
373 |
-
return self.embed_documents([text])[0]
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
# -----------------------------
|
378 |
# Create FAISS Vector Store
|
379 |
# -----------------------------
|
|
|
29 |
from langchain_core.documents import Document
|
30 |
from langchain_community.vectorstores import FAISS
|
31 |
from langchain.vectorstores import FAISS
|
32 |
+
#from langchain.embeddings import BERTEmbeddings
|
33 |
from langchain_community.embeddings import HuggingFaceEmbeddings
|
34 |
|
35 |
from youtube_transcript_api import YouTubeTranscriptApi
|
|
|
343 |
tools.append(tool_map[name])
|
344 |
|
345 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
346 |
# -----------------------------
|
347 |
# Create FAISS Vector Store
|
348 |
# -----------------------------
|