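"""RAG pipeline for UB international students.

Retrieves relevant chunks from a vector store and generates answers with a
hosted chat model via the Hugging Face Inference API.
"""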
import os
from typing import Any, Dict, List, Optional

from huggingface_hub import InferenceClient

from buffalo_rag.vector_store.db import VectorStore


class BuffaloRAG:
    def __init__(
        self,
        model_name: str = "meta-llama/Llama-3.3-70B-Instruct",
        vector_store: Optional[VectorStore] = None,
    ):
        # 1. Vector store
        self.vector_store = vector_store or VectorStore()

        # 2. Chat model used for generation (must be available from the
        #    configured inference provider)
        self.model_name = model_name

        # 3. Hugging Face Inference client
        hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
        if not hf_token:
            raise ValueError("Please set HUGGINGFACEHUB_API_TOKEN in your environment.")
        self.client = InferenceClient(
            provider="cerebras",
            api_key=hf_token,
        )

    def retrieve(
        self,
        query: str,
        k: int = 5,
        filter_categories: Optional[List[str]] = None,
    ) -> List[Dict[str, Any]]:
        """Retrieve the k most relevant chunks for a query.

        Each result is expected to hold a 'chunk' dict (title, url, content),
        a 'score', and optionally a 'rerank_score' when a reranker ran.
        """
        return self.vector_store.hybrid_search(query, k=k, filter_categories=filter_categories)

    def format_context(self, results: List[Dict[str, Any]]) -> str:
        """Concatenate retrieved passages into a single context string."""
        ctx = []
        for i, r in enumerate(results, start=1):
            c = r["chunk"]
            # Cap each passage at 500 characters to keep the prompt compact.
            ctx.append(
                f"Source {i}: {c['title']}\n"
                f"URL: {c['url']}\n"
                f"Content: {c['content'][:500]}...\n"
            )
        return "\n".join(ctx)

    def generate_response(self, query: str, context: str) -> str:
        """Generate a response from the language model, with error handling."""
        prompt = f"""You are a friendly and professional counselor for international students at the University at Buffalo. Respond to the student's query in a supportive, detailed, and well-structured manner.

For your responses:
1. Address the student respectfully and empathetically
2. Provide clear, accurate information with specific details and steps when applicable
3. Organize your answer with appropriate headings, bullet points, or numbered lists when helpful
4. If the student's question is unclear or lacks essential details, ask 1-2 specific clarifying questions to better understand their situation
5. Include relevant deadlines, contacts, or resources when appropriate
6. Conclude with a brief encouraging statement
7. Only answer questions related to international students at UB; if a question is not, just say "I'm sorry, I don't have information about that."
8. Do not entertain any questions that are unrelated to students at UB.

Question: {query}

Relevant Information:
{context}

Answer:"""
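        # Single-turn chat completion against the configured model;
        # max_tokens caps the length of the generated answer.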
        try:
            completion = self.client.chat.completions.create(
                model=self.model_name,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=512,
            )
            return completion.choices[0].message.content
        except Exception as e:
            print(f"Error during generation: {e}")
            # Fallback response when the inference call fails
            return (
                "I'm sorry, I encountered an issue generating a response. "
                "Please try asking your question in a different way or contact "
                "UB International Student Services directly for assistance."
            )

    def answer(
        self,
        query: str,
        k: int = 5,
        filter_categories: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """End-to-end RAG pipeline: retrieve, build context, generate."""
        # Retrieve relevant chunks
        results = self.retrieve(query, k=k, filter_categories=filter_categories)

        # Format context
        context = self.format_context(results)

        # Generate response
        response = self.generate_response(query, context)

        # Return the answer plus source metadata, preferring the reranker
        # score over the raw retrieval score when one is present.
        return {
            'query': query,
            'response': response,
            'sources': [
                {
                    'title': result['chunk']['title'],
                    'url': result['chunk']['url'],
                    'score': result.get('rerank_score', result['score']),
                }
                for result in results
            ],
        }

# Example usage
if __name__ == "__main__":
    rag = BuffaloRAG()  # uses the default chat model; pass model_name= to override
    response = rag.answer("How do I apply for OPT?")
    print(f"Query: {response['query']}")
    print(f"Response: {response['response']}")
    print("\nSources:")
    for source in response['sources']:
        print(f"- {source['title']} (Score: {source['score']:.4f})")