Update agents/research_agent.py
agents/research_agent.py (+13 -8)
@@ -9,16 +9,21 @@ class ResearchAgent:
         """
         Initialize the research agent with local Ollama LLM.
         """
-        print("Initializing ResearchAgent with Hugging Face Transformers...")
-        model_name = getattr(settings, "HF_MODEL_RESEARCH", "google/flan-t5-large")
-
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.float16)
 
-
-        self.model.to(self.device)
+        print("Initializing RelevanceChecker with lightweight Hugging Face model...")
 
-
+        # Use a smaller, CPU-friendly model by default
+        model_name = getattr(settings, "HF_MODEL_RELEVANCE", "google/flan-t5-small")
+
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        # Use float32 on CPU (fp16 only works on GPU)
+        torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch_dtype).to(self.device)
+
+        print(f"Model '{model_name}' loaded on {self.device} with dtype={torch_dtype}.")
 
 
     def sanitize_response(self, response_text: str) -> str:
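For context, the device- and dtype-aware loading pattern this commit introduces can be exercised end to end roughly as below. This is a minimal sketch using the standard transformers and torch APIs; the check_relevance() helper, its prompt wording, and the sample call are illustrative assumptions, not code from the repository.

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Same selection logic as the commit: fp16 only pays off on GPU,
# so CPU-only hosts fall back to fp32.
model_name = "google/flan-t5-small"
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch_dtype).to(device)

# Hypothetical helper (not from the repo): a yes/no relevance probe
# run through the seq2seq model loaded above.
def check_relevance(question: str, passage: str) -> str:
    prompt = (
        "Is the passage relevant to the question? Answer yes or no.\n"
        f"Question: {question}\nPassage: {passage}"
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=5)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

print(check_relevance("What year did Apollo 11 land?",
                      "Apollo 11 landed on the Moon in 1969."))

Deciding the dtype next to the device check, as the new code does, avoids the failure mode in the removed lines: loading unconditionally in float16 and then moving the model with .to(self.device), which misbehaves on CPU-only hosts where float16 kernels are poorly supported.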