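"""BuffaloRAG: retrieval-augmented generation pipeline for answering questions from
international students at the University at Buffalo, combining retrieval from a
vector store with a causal language model."""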
import os
import json
from typing import List, Dict, Any, Optional, Tuple

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

from buffalo_rag.vector_store.db import VectorStore


class BuffaloRAG:
    """Retrieval-augmented generation helper: retrieves relevant chunks from a
    vector store and generates answers with a causal language model."""

    def __init__(self,
                 model_name: str = "Qwen/Qwen1.5-1.8B-Chat",
                 vector_store: Optional[VectorStore] = None):
        # Reuse the provided vector store, or build one with default settings.
        self.vector_store = vector_store or VectorStore()

        # Load the primary chat model; on any failure, fall back to a smaller model below.
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True,
                low_cpu_mem_usage=True
            )

            # Deterministic text-generation pipeline (sampling disabled).
            self.pipe = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                max_new_tokens=256,
                do_sample=False,
                pad_token_id=self.tokenizer.eos_token_id
            )
        except Exception as e:
            print(f"Error loading main model: {str(e)}")
            print("Falling back to smaller model...")

            # Fallback: a much smaller model so the pipeline still works.
            self.tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
            self.model = AutoModelForCausalLM.from_pretrained("distilgpt2")
            self.pipe = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                max_new_tokens=256
            )

    def retrieve(self,
                 query: str,
                 k: int = 5,
                 filter_categories: Optional[List[str]] = None) -> List[Dict[str, Any]]:
        """Retrieve relevant chunks for a query."""
        return self.vector_store.hybrid_search(query, k=k, filter_categories=filter_categories)

    def format_context(self, results: List[Dict[str, Any]]) -> str:
        """Format retrieved results into a context string."""
        context = ""

        for i, result in enumerate(results):
            chunk = result['chunk']
            context += f"Source {i+1}: {chunk['title']}\n"
            context += f"URL: {chunk['url']}\n"
            # Truncate each chunk to its first 500 characters to keep the prompt compact.
            context += f"Content: {chunk['content'][:500]}...\n\n"

        return context

    def generate_response(self, query: str, context: str) -> str:
        """Generate a response using the language model, with error handling."""
        prompt = f"""You are a friendly and professional counselor for international students at the University at Buffalo. Respond to the student's query in a supportive, detailed, and well-structured manner.

For your responses:
1. Address the student respectfully and empathetically
2. Provide clear, accurate information with specific details and steps when applicable
3. Organize your answer with appropriate headings, bullet points, or numbered lists when helpful
4. If the student's question is unclear or lacks essential details, ask 1-2 specific clarifying questions to better understand their situation
5. Include relevant deadlines, contacts, or resources when appropriate
6. Conclude with a brief encouraging statement
7. Only answer questions related to international students at UB; if a question is not related to international students at UB, just say "I'm sorry, I don't have information about that."
8. Do not entertain any questions that are not related to students at UB.

Question: {query}

Relevant Information:
{context}

Answer:"""

        try:
            response = self.pipe(prompt)[0]['generated_text']

            # The pipeline returns the prompt followed by the completion;
            # keep only the newly generated portion.
            generated = response[len(prompt):].strip()

            return generated
        except Exception as e:
            print(f"Error during generation: {str(e)}")
            return "I'm sorry, I encountered an issue generating a response. Please try asking your question in a different way or contact UB International Student Services directly for assistance."

    def answer(self,
               query: str,
               k: int = 5,
               filter_categories: Optional[List[str]] = None) -> Dict[str, Any]:
        """End-to-end RAG pipeline: retrieve, build context, generate, and cite sources."""
        # 1. Retrieve the most relevant chunks from the vector store.
        results = self.retrieve(query, k=k, filter_categories=filter_categories)

        # 2. Format the retrieved chunks into a context block for the prompt.
        context = self.format_context(results)

        # 3. Generate the answer conditioned on the retrieved context.
        response = self.generate_response(query, context)

        # 4. Return the answer along with its sources (reranked score when available).
        return {
            'query': query,
            'response': response,
            'sources': [
                {
                    'title': result['chunk']['title'],
                    'url': result['chunk']['url'],
                    'score': result.get('rerank_score', result['score'])
                }
                for result in results
            ]
        }


if __name__ == "__main__":
    # Example run; overrides the default Qwen model with a BitNet checkpoint.
    rag = BuffaloRAG(model_name="1bitLLM/bitnet_b1_58-large")
    response = rag.answer("How do I apply for OPT?")

    print(f"Query: {response['query']}")
    print(f"Response: {response['response']}")
    print("\nSources:")
    for source in response['sources']:
        print(f"- {source['title']} (Score: {source['score']:.4f})")