"""Generator wrapper around Hugging Face transformers pipelines.""" | |
import logging | |
from typing import List | |
try: | |
from transformers import pipeline | |
except ImportError: | |
pipeline = None | |
from .base import Generator | |
logger = logging.getLogger(__name__) | |


class HFGenerator(Generator):
    """Seq2seq generator using a Hugging Face model (e.g., Flan-T5)."""

    def __init__(self, model_name: str = "google/flan-t5-base", device: str = "cpu"):
        self.model_name = model_name
        # Map "cuda"/"cuda:N" to a pipeline device index; anything else is CPU (-1).
        if device.startswith("cuda"):
            _, _, idx = device.partition(":")
            device_index = int(idx) if idx else 0
        else:
            device_index = -1
        if pipeline is None:
            logger.warning(
                "transformers.pipeline is not available; "
                "HFGenerator.generate() will return an empty string."
            )
            self.pipe = lambda *args, **kwargs: [{"generated_text": ""}]
        else:
            try:
                self.pipe = pipeline(
                    "text2text-generation",
                    model=model_name,
                    device=device_index,
                )
                logger.info("HFGenerator loaded model '%s' on %s", model_name, device)
            except Exception as e:
                logger.warning(
                    "HFGenerator failed to load '%s'; generate() will return "
                    "an empty string. (%s)",
                    model_name,
                    e,
                )
                self.pipe = lambda *args, **kwargs: [{"generated_text": ""}]

    def generate(
        self,
        question: str,
        contexts: List[str],
        *,
        max_new_tokens: int = 256,
        temperature: float = 0.0,
    ) -> str:
        """Answer `question` using only `contexts`; returns "" on any failure."""
        # Join contexts outside the f-string so the backslash stays legal
        # on Python versions before 3.12.
        context_block = "\n".join(contexts)
        prompt = (
            "Answer the question using only the provided context.\n\n"
            "Context:\n"
            f"{context_block}\n\n"
            f"Question: {question}\nAnswer:"
        )
        # Only pass sampling arguments when sampling: temperature=0.0 with
        # do_sample=False triggers a warning in recent transformers releases.
        gen_kwargs = {"max_new_tokens": max_new_tokens}
        if temperature > 0:
            gen_kwargs.update(do_sample=True, temperature=temperature)
        try:
            outputs = self.pipe(prompt, **gen_kwargs)
            return outputs[0].get("generated_text", "").strip()
        except Exception:
            logger.warning(
                "HFGenerator.generate() failed; returning empty string.",
                exc_info=True,
            )
            return ""

    def __repr__(self):
        return f"HFGenerator(model={self.model_name})"