|
"""Generator component for the RAG system.""" |
|
|
|
from typing import List, Dict |
|
import torch |
|
from transformers import ( |
|
AutoTokenizer, |
|
AutoModelForCausalLM, |
|
LogitsProcessor, |
|
LogitsProcessorList |
|
) |
|
|
|
class FinancialContextProcessor(LogitsProcessor):
    """Custom logits processor for financial context.

    NOTE(review): this is currently a pass-through stub — `__call__` returns
    the scores unmodified and ``self.constraints`` is never read. Presumably
    intended to bias/suppress tokens based on financial constraints; confirm
    before relying on it for constrained generation.
    """

    def __init__(self, financial_constraints: Dict):
        # Constraint dict is stored for future use; no code reads it yet.
        self.constraints = financial_constraints

    def __call__(self, input_ids: torch.LongTensor,
                 scores: torch.FloatTensor) -> torch.FloatTensor:
        # Placeholder: no constraint logic applied — scores pass through
        # unchanged, so generation behaves as if this processor were absent.
        return scores
|
|
|
class RAGGenerator:
    """Generate answers with a causal LM conditioned on retrieved documents."""

    def __init__(self, config: Dict):
        """Initialize tokenizer and model.

        Args:
            config: Generation settings. Recognized keys (all optional,
                defaults preserve the previous hard-coded behavior):
                ``model_name`` (default ``"gpt2"``) and
                ``max_length`` (default ``512``).
                Previously this argument was accepted but ignored.
        """
        config = config or {}
        self.model_name = config.get("model_name", "gpt2")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
        # GPT-2 (and similar models) ship without a pad token; reuse EOS so
        # generate() gets a valid pad_token_id and emits no warnings.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.max_length = config.get("max_length", 512)

    def prepare_context(self, retrieved_docs: List[Dict]) -> str:
        """Join retrieved document texts into one newline-separated context.

        Args:
            retrieved_docs: Retriever hits shaped like
                ``{"document": {"text": ...}}`` — TODO confirm against the
                retriever's output schema.

        Returns:
            The document texts separated by newlines, with surrounding
            whitespace stripped. Empty string for an empty list.
        """
        # str.join avoids the quadratic += pattern; the final strip matches
        # the previous behavior of stripping the concatenated result.
        return "\n".join(
            doc["document"]["text"] for doc in retrieved_docs
        ).strip()

    def generate(self, query: str, retrieved_docs: List[Dict],
                 financial_constraints: Dict = None) -> str:
        """Generate text based on the query and retrieved documents.

        Args:
            query: User question to answer.
            retrieved_docs: Documents from the retriever (see
                :meth:`prepare_context` for the expected shape).
            financial_constraints: Optional constraints forwarded to
                :class:`FinancialContextProcessor`.

        Returns:
            The decoded model output (prompt included, as decoded from the
            full generated sequence).
        """
        context = self.prepare_context(retrieved_docs)
        prompt = f"Context: {context}\nQuery: {query}\nResponse:"

        processors = LogitsProcessorList()
        if financial_constraints:
            processors.append(FinancialContextProcessor(financial_constraints))

        # Truncate the prompt so it cannot exceed the generation budget;
        # an over-long prompt would otherwise make max_length unreachable.
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_length,
        )
        # Inference only — no gradients needed.
        with torch.no_grad():
            outputs = self.model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_length=self.max_length,
                num_return_sequences=1,
                logits_processor=processors,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=self.tokenizer.pad_token_id,
            )

        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
|