# RAG_Eval/evaluation/generators/hf_generator.py
"""Generator wrapper around Hugging Face transformers pipelines."""
import logging
from typing import List
try:
from transformers import pipeline
except ImportError:
pipeline = None
from .base import Generator
logger = logging.getLogger(__name__)
class HFGenerator(Generator):
    """Seq2seq generator using a Hugging Face model (e.g., Flan-T5)."""

    def __init__(self, model_name: str = "google/flan-t5-base", device: str = "cpu"):
        self.model_name = model_name
        # transformers pipelines take an integer device index:
        # 0 for the first CUDA device, -1 for CPU.
        device_index = 0 if device.startswith("cuda") else -1
        if pipeline is None:
            logger.warning(
                "transformers.pipeline is not available; "
                "HFGenerator.generate() will return an empty string."
            )
            # Stub that mimics the pipeline's output shape.
            self.pipe = lambda *args, **kwargs: [{"generated_text": ""}]
        else:
            try:
                self.pipe = pipeline(
                    "text2text-generation",
                    model=model_name,
                    device=device_index,
                )
                logger.info("HFGenerator loaded model '%s' on %s", model_name, device)
            except Exception as e:
                logger.warning(
                    "HFGenerator failed to load '%s'; generate() will return "
                    "an empty string. (%s)",
                    model_name,
                    e,
                )
                self.pipe = lambda *args, **kwargs: [{"generated_text": ""}]
    def generate(
        self,
        question: str,
        contexts: List[str],
        *,
        max_new_tokens: int = 256,
        temperature: float = 0.0,
    ) -> str:
        """Answer ``question`` using ``contexts``; returns "" on failure."""
        # Join contexts outside the f-string: backslashes are not allowed
        # inside f-string expressions before Python 3.12.
        context_block = "\n".join(contexts)
        prompt = (
            "Answer the question using only the provided context.\n\n"
            "Context:\n"
            f"{context_block}\n\n"
            f"Question: {question}\nAnswer:"
        )
        # Only pass sampling arguments for nonzero temperatures; recent
        # transformers versions warn when `temperature` is set while
        # `do_sample` is False.
        gen_kwargs = {"max_new_tokens": max_new_tokens}
        if temperature > 0:
            gen_kwargs.update(do_sample=True, temperature=temperature)
        try:
            outputs = self.pipe(prompt, **gen_kwargs)
            return outputs[0].get("generated_text", "").strip()
        except Exception:
            logger.exception("Generation failed for model '%s'.", self.model_name)
            return ""
    def __repr__(self):
        return f"HFGenerator(model={self.model_name})"