hf-rag-multi / src /rag_pipeline.py
siyu618's picture
Upload 18 files
94f5c4b verified
from transformers import RagTokenizer, RagSequenceForGeneration
from config.rag_config import RAGConfig
from src.embedder import Embedder
from src.retriever import Retriever
class RAGPipeline:
    """Retrieval-augmented question answering over a fixed document set.

    Combines the project's ``Embedder`` and ``Retriever`` with a Hugging Face
    RAG tokenizer/model pair, both loaded from ``config.llm_model_name``.
    """

    def __init__(self, config: "RAGConfig", docs, doc_embeddings):
        """Create the embedder, retriever, tokenizer and generator.

        Args:
            config: Pipeline settings. This class reads ``llm_model_name``
                and ``generation_kwargs`` from it and also hands it to
                ``Embedder`` and ``Retriever``.
            docs: Documents forwarded to ``Retriever``.
            doc_embeddings: Embeddings forwarded to ``Retriever``;
                presumably aligned one-to-one with ``docs`` -- confirm
                against ``Retriever``.
        """
        self.config = config
        self.embedder = Embedder(config)
        self.retriever = Retriever(doc_embeddings, docs, config)
        # NOTE(review): the model is loaded without a Hugging Face retriever
        # attached; confirm ``generate`` works with question inputs alone for
        # the configured checkpoint.
        self.tokenizer = RagTokenizer.from_pretrained(config.llm_model_name)
        self.model = RagSequenceForGeneration.from_pretrained(
            config.llm_model_name
        )

    def ask(self, query):
        """Answer ``query`` using retrieved documents as prompt context.

        Args:
            query: The question text.

        Returns:
            A ``(answer, retrieved)`` tuple: the first decoded generation and
            the unmodified result of ``Retriever.retrieve``.
        """
        query_emb = self.embedder.embed_texts([query])[0]
        retrieved = self.retriever.retrieve(query_emb)

        # Each hit is indexed with [0] to obtain its text; presumably hits
        # are (text, score) pairs -- confirm against ``Retriever.retrieve``.
        context = "\n".join(hit[0] for hit in retrieved)
        input_text = f"Question: {query}\nContext: {context}"

        # The joined context has no length bound, so truncate to the
        # tokenizer's maximum instead of feeding an over-long sequence to
        # the model.
        inputs = self.tokenizer(
            input_text,
            return_tensors="pt",
            truncation=True,
        )
        output = self.model.generate(
            **inputs,
            **self.config.generation_kwargs,
        )

        answer = self.tokenizer.batch_decode(
            output, skip_special_tokens=True
        )[0]
        return answer, retrieved