import json
import logging
import os
from typing import List

import torch
from langchain_core.documents import Document
from sentence_transformers import SentenceTransformer

try:
    from datasets import load_dataset
except ImportError:
    # The datasets library is optional; without it we fall back to local/mock data.
    load_dataset = None

logger = logging.getLogger(__name__)

def get_device():
    """
    Determine the appropriate device for PyTorch.

    Returns:
        str: Device name ('cuda', 'mps', or 'cpu').
    """
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"

def _mock_guest_docs() -> List[Document]:
    """Return the single-guest mock dataset used as a fallback."""
    description = (
        "Dr. Nikola Tesla is an old friend from your university days. "
        "He's recently patented a new wireless energy transmission system."
    )
    return [
        Document(
            page_content="\n".join([
                "Name: Dr. Nikola Tesla",
                "Relation: old friend from university days",
                f"Description: {description}",
                "Email: nikola.tesla@gmail.com",
            ]),
            metadata={
                "name": "Dr. Nikola Tesla",
                "relation": "old friend from university days",
                "description": description,
                "email": "nikola.tesla@gmail.com",
            },
        )
    ]


def load_guest_dataset(dataset_path: str = "agents-course/unit3-invitees") -> List[Document]:
    """
    Load the guest dataset from a local JSON file or a Hugging Face dataset.

    Args:
        dataset_path (str): Path to a local JSON file or a Hugging Face dataset name.

    Returns:
        List[Document]: List of Document objects with guest information.
    """
    try:
        # Try the Hugging Face Hub first if the datasets library is available
        # and the path does not point to a local file.
        if load_dataset and not os.path.exists(dataset_path):
            logger.info(f"Attempting to load Hugging Face dataset: {dataset_path}")
            guest_dataset = load_dataset(dataset_path, split="train")
            docs = [
                Document(
                    page_content="\n".join([
                        f"Name: {guest['name']}",
                        f"Relation: {guest['relation']}",
                        f"Description: {guest['description']}",
                        f"Email: {guest['email']}",
                    ]),
                    metadata={
                        "name": guest["name"],
                        "relation": guest["relation"],
                        "description": guest["description"],
                        "email": guest["email"],
                    },
                )
                for guest in guest_dataset
            ]
            logger.info(f"Loaded {len(docs)} guests from Hugging Face dataset")
            return docs

        # Otherwise try a local JSON file.
        if os.path.exists(dataset_path):
            logger.info(f"Loading guest dataset from local path: {dataset_path}")
            with open(dataset_path, "r") as f:
                guests = json.load(f)
            docs = [
                Document(
                    page_content=guest.get("description", ""),
                    metadata={
                        "name": guest.get("name", ""),
                        "relation": guest.get("relation", ""),
                        "description": guest.get("description", ""),
                        "email": guest.get("email", ""),  # Optional email field
                    },
                )
                for guest in guests
            ]
            logger.info(f"Loaded {len(docs)} guests from local JSON")
            return docs

        # Fall back to the mock dataset if neither source is available.
        logger.warning(f"Dataset not found at {dataset_path}, using mock dataset")
        docs = _mock_guest_docs()
        logger.info("Loaded mock dataset with 1 guest")
        return docs
    except Exception as e:
        logger.error(f"Failed to load guest dataset: {e}")
        # Return the mock dataset as a final fallback.
        docs = _mock_guest_docs()
        logger.info("Loaded mock dataset with 1 guest due to error")
        return docs

class BM25Retriever:
    """
    A retriever that performs BM25 lexical search over the guest dataset.

    Note: a SentenceTransformer embedding model is initialized in the
    constructor, but the `search` method below relies on BM25 keyword
    matching only.
    """

    def __init__(self, dataset_path: str):
        """
        Initialize the retriever and load the embedding model.

        Args:
            dataset_path (str): Path to the dataset for retrieval.

        Raises:
            Exception: If embedder initialization fails.
        """
        try:
            # Loaded onto the best available device; currently unused by `search`.
            self.model = SentenceTransformer("all-MiniLM-L6-v2", device=get_device())
            self.dataset_path = dataset_path
            logger.info("Initialized SentenceTransformer")
        except Exception as e:
            logger.error(f"Failed to initialize embedder: {e}")
            raise

    def search(self, query: str) -> List[dict]:
        """
        Search the dataset for relevant guest information.

        Args:
            query (str): Search query (e.g., guest name or relation).

        Returns:
            List[dict]: Metadata dictionaries of the matching guests,
            ordered by BM25 relevance.
        """
        try:
            docs = load_guest_dataset(self.dataset_path)
            if not docs:
                logger.warning("No documents available for search")
                return []
            # Index name, relation, and description so queries can match any field,
            # and attach each guest's metadata to its indexed document.
            bm25_docs = [
                Document(
                    page_content=f"{doc.metadata['name']} {doc.metadata['relation']} {doc.metadata['description']}",
                    metadata=doc.metadata,
                )
                for doc in docs
            ]
            # Imported locally so the langchain class does not shadow this class's name.
            from langchain_community.retrievers import BM25Retriever as LCBM25Retriever

            retriever = LCBM25Retriever.from_documents(bm25_docs)
            retriever.k = 3  # Limit to the top 3 results
            results = retriever.invoke(query)
            # Carry the metadata straight through, preserving BM25 relevance order.
            matches = [r.metadata for r in results]
            logger.info(f"Found {len(matches)} matches for query: {query}")
            return matches
        except Exception as e:
            logger.error(f"Search failed for query '{query}': {e}")
            return []
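

if __name__ == "__main__":
    # Minimal usage sketch (not part of the module API): exercises
    # load_guest_dataset and BM25Retriever.search end to end. The dataset
    # path is the module default; point it at a local JSON file to run
    # offline, and treat the query below as an illustrative example.
    logging.basicConfig(level=logging.INFO)

    docs = load_guest_dataset()
    print(f"Loaded {len(docs)} guest document(s)")

    retriever = BM25Retriever("agents-course/unit3-invitees")
    for guest in retriever.search("old friend from university"):
        print(f"{guest['name']} <{guest['email']}> - {guest['relation']}")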