# NOTE: The original capture of this file carried a Hugging Face Spaces status
# banner ("Spaces: Sleeping") above the code — scrape residue, not part of the program.
from transformers import RagTokenizer, RagSequenceForGeneration

from config.rag_config import RAGConfig
from src.embedder import Embedder
from src.retriever import Retriever


class RAGPipeline:
    """End-to-end retrieval-augmented generation pipeline.

    Wires together three collaborators built from a single ``RAGConfig``:
    an ``Embedder`` (query embedding), a ``Retriever`` (similarity search
    over ``doc_embeddings``/``docs``), and a Hugging Face RAG seq2seq model
    used for answer generation.
    """

    def __init__(self, config: RAGConfig, docs, doc_embeddings):
        """Build the pipeline components and load tokenizer/model weights.

        Args:
            config: project configuration; ``llm_model_name`` selects the
                pretrained RAG checkpoint, ``generation_kwargs`` feed
                ``model.generate``.
            docs: the document collection handed to the retriever.
            doc_embeddings: precomputed embeddings aligned with ``docs``.
        """
        self.config = config
        self.embedder = Embedder(config)
        self.retriever = Retriever(doc_embeddings, docs, config)
        # NOTE(review): loading weights in __init__ makes construction slow and
        # network-dependent; callers presumably accept that — confirm.
        self.tokenizer = RagTokenizer.from_pretrained(config.llm_model_name)
        self.model = RagSequenceForGeneration.from_pretrained(config.llm_model_name)

    def ask(self, query):
        """Answer ``query`` using retrieved context.

        Returns:
            A ``(answer, hits)`` tuple: the first decoded generation and the
            raw retriever output. Each hit is assumed to be a sequence whose
            element 0 is the document text — TODO confirm against Retriever.
        """
        embedding = self.embedder.embed_texts([query])[0]
        hits = self.retriever.retrieve(embedding)
        # Stitch the retrieved document texts into a single context string.
        context = "\n".join(hit[0] for hit in hits)
        prompt = f"Question: {query}\nContext: {context}"
        encoded = self.tokenizer(prompt, return_tensors="pt")
        generated = self.model.generate(
            **encoded,
            **self.config.generation_kwargs,
        )
        decoded = self.tokenizer.batch_decode(generated, skip_special_tokens=True)
        return decoded[0], hits