|
"""Retrieval-augmented generation (RAG) pipeline for answering questions from
international students at the University at Buffalo."""

import os
from typing import Any, Dict, List, Optional

from huggingface_hub import InferenceClient

from buffalo_rag.vector_store.db import VectorStore
|
|
class BuffaloRAG:
    """Answer student queries by retrieving relevant chunks from the vector
    store and generating a grounded response with a hosted language model."""

    def __init__(
        self,
        model_name: str = "meta-llama/Llama-3.3-70B-Instruct",
        vector_store: Optional[VectorStore] = None,
    ):
        self.model_name = model_name
        self.vector_store = vector_store or VectorStore()

        # The hosted inference client needs a Hugging Face token; fail fast
        # with a clear message rather than a cryptic HTTP error later.
        hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
        if not hf_token:
            raise ValueError("Please set HUGGINGFACEHUB_API_TOKEN in your environment.")

        self.client = InferenceClient(
            provider="cerebras",
            api_key=hf_token,
        )
|
    def retrieve(self,
                 query: str,
                 k: int = 5,
                 filter_categories: Optional[List[str]] = None) -> List[Dict[str, Any]]:
        """Retrieve relevant chunks for a query."""
        return self.vector_store.hybrid_search(query, k=k, filter_categories=filter_categories)
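
    # For reference, each result returned by hybrid_search is consumed here and
    # in answer() as if it had the following shape (inferred from usage in this
    # file; the authoritative definition lives in buffalo_rag.vector_store.db):
    #
    # {
    #     "chunk": {"title": str, "url": str, "content": str},
    #     "score": float,           # retrieval score
    #     "rerank_score": float,    # optional; preferred when present
    # }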
|
|
|
    def format_context(self, results: List[Dict[str, Any]]) -> str:
        """Concatenate retrieved passages into context."""
        ctx = []
        for i, r in enumerate(results, start=1):
            c = r["chunk"]
            ctx.append(
                f"Source {i}: {c['title']}\n"
                f"URL: {c['url']}\n"
                f"Content: {c['content'][:500]}...\n"
            )
        return "\n".join(ctx)
|
    def generate_response(self, query: str, context: str) -> str:
        """Generate a response using the language model, with error handling."""
        prompt = f"""You are a friendly and professional counselor for international students at the University at Buffalo. Respond to the student's query in a supportive, detailed, and well-structured manner.

For your responses:
1. Address the student respectfully and empathetically
2. Provide clear, accurate information with specific details and steps when applicable
3. Organize your answer with appropriate headings, bullet points, or numbered lists when helpful
4. If the student's question is unclear or lacks essential details, ask 1-2 specific clarifying questions to better understand their situation
5. Include relevant deadlines, contacts, or resources when appropriate
6. Conclude with a brief encouraging statement
7. Only answer questions related to international students at UB; if a question is unrelated, just say "I'm sorry, I don't have information about that."
8. Do not entertain any questions that are not related to students at UB.

Question: {query}

Relevant Information:
{context}

Answer:"""

        try:
            completion = self.client.chat.completions.create(
                model=self.model_name,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=512,
            )
            return completion.choices[0].message.content
        except Exception as e:
            print(f"Error during generation: {e}")
            return (
                "I'm sorry, I encountered an issue generating a response. "
                "Please try asking your question in a different way or contact "
                "UB International Student Services directly for assistance."
            )
|
|
|
    def answer(self,
               query: str,
               k: int = 5,
               filter_categories: Optional[List[str]] = None) -> Dict[str, Any]:
        """End-to-end RAG pipeline: retrieve, build context, generate, and cite sources."""
        results = self.retrieve(query, k=k, filter_categories=filter_categories)
        context = self.format_context(results)
        response = self.generate_response(query, context)

        return {
            'query': query,
            'response': response,
            'sources': [
                {
                    'title': result['chunk']['title'],
                    'url': result['chunk']['url'],
                    # Prefer the reranker's score when one is available.
                    'score': result.get('rerank_score', result['score'])
                }
                for result in results
            ]
        }
|
|
if __name__ == "__main__": |
|
rag = BuffaloRAG(model_name="1bitLLM/bitnet_b1_58-large") |
|
response = rag.answer("How do I apply for OPT?") |
|
|
|
print(f"Query: {response['query']}") |
|
print(f"Response: {response['response']}") |
|
print("\nSources:") |
|
for source in response['sources']: |
|
print(f"- {source['title']} (Score: {source['score']:.4f})") |