# File: chagu-dev/rag_sec/senamtic_response_generator.py
# Last change: commit c3c1187 by talexm ("update"), 1.04 kB
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
class SemanticResponseGenerator:
    """Generate a short natural-language response from retrieved documents.

    Wraps a Hugging Face sequence-to-sequence model (``google/flan-t5-small``
    by default). The first ``max_docs`` retrieved documents are joined into a
    single prompt, tokenized (truncated to ``max_input_length`` tokens), and
    passed to ``model.generate``.
    """

    def __init__(self, model_name="google/flan-t5-small", max_input_length=512,
                 max_new_tokens=50, max_docs=2):
        """Load the tokenizer and model.

        Args:
            model_name: Model id or local path passed to ``from_pretrained``.
            max_input_length: Maximum prompt length in *tokens*; longer prompts
                are truncated by the tokenizer.
            max_new_tokens: Maximum number of tokens to generate.
            max_docs: Number of documents, taken from the front of
                ``retrieved_docs``, to include in the prompt.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        self.max_input_length = max_input_length
        self.max_new_tokens = max_new_tokens
        self.max_docs = max_docs

    def _build_prompt(self, retrieved_docs):
        """Return the (untruncated) prompt built from the first ``max_docs`` docs."""
        combined_docs = " ".join(retrieved_docs[:self.max_docs])
        return f"Based on the following information: {combined_docs}"

    def _pad_token_id(self):
        """Return the tokenizer's pad token id, falling back to EOS if it has none."""
        pad_id = self.tokenizer.pad_token_id
        # Compare against None explicitly: T5's pad id is 0, which is falsy.
        return pad_id if pad_id is not None else self.tokenizer.eos_token_id

    def generate_response(self, retrieved_docs):
        """Generate a response grounded in ``retrieved_docs``.

        Args:
            retrieved_docs: Sequence of document strings, presumably ordered
                most-relevant first; only the first ``max_docs`` are used.

        Returns:
            The decoded text of the first generated sequence, with special
            tokens removed.
        """
        prompt = self._build_prompt(retrieved_docs)
        # ``max_input_length`` is a token budget, so truncation is left to the
        # tokenizer rather than done by slicing characters. With the
        # tokenizer's default right-side truncation the instruction prefix at
        # the start of the prompt is preserved.
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_input_length,
        )
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=self.max_new_tokens,
            pad_token_id=self._pad_token_id(),
        )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)