File size: 4,272 Bytes
d8f06d4
 
 
 
7017d8a
 
 
d8f06d4
 
 
 
7017d8a
 
 
 
 
 
d8f06d4
7017d8a
 
 
 
 
 
 
 
d8f06d4
7017d8a
d8f06d4
 
 
 
 
 
 
 
7017d8a
 
 
 
 
 
 
 
 
d8f06d4
 
 
 
7017d8a
 
 
 
 
 
 
 
 
d8f06d4
7017d8a
d8f06d4
7017d8a
 
d8f06d4
7017d8a
d8f06d4
 
7017d8a
 
 
 
 
 
 
 
81a46db
7017d8a
 
 
d8f06d4
 
 
 
 
 
 
 
3673d92
d8f06d4
 
3673d92
d8f06d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3673d92
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import os
import json
from typing import List, Dict, Any, Optional, Tuple

# from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

from huggingface_hub import InferenceClient

from buffalo_rag.vector_store.db import VectorStore

class BuffaloRAG:
    """Retrieval-augmented generation pipeline for UB international-student Q&A.

    Retrieves relevant document chunks from a :class:`VectorStore` and
    generates a grounded answer with a hosted chat model through the
    Hugging Face ``InferenceClient`` (Cerebras provider).
    """

    # Chat model actually used for generation (previously hard-coded inline).
    DEFAULT_GENERATION_MODEL = "meta-llama/Llama-3.3-70B-Instruct"

    def __init__(
        self,
        model_name: str = "meta-llama/Llama-2-7b-chat-hf",
        vector_store: Optional[VectorStore] = None,
        generation_model: Optional[str] = None,
    ):
        """Initialize the RAG pipeline.

        Args:
            model_name: Recorded for reference/introspection. NOTE: earlier
                versions accepted this parameter but silently ignored it;
                generation is controlled by ``generation_model``.
            vector_store: Existing store to search; a fresh ``VectorStore``
                is constructed when omitted.
            generation_model: Chat model id used by ``generate_response``.
                Defaults to ``DEFAULT_GENERATION_MODEL`` so existing callers
                see unchanged behavior.

        Raises:
            ValueError: If ``HUGGINGFACEHUB_API_TOKEN`` is not set in the
                environment.
        """
        # 1. Vector store (lazily constructed when not supplied).
        self.vector_store = vector_store or VectorStore()

        # Keep the configured names on the instance instead of discarding them.
        self.model_name = model_name
        self.generation_model = generation_model or self.DEFAULT_GENERATION_MODEL

        # 2. Hugging Face Inference client — requires an API token up front so
        # failures surface at construction time, not on the first query.
        hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
        if not hf_token:
            raise ValueError("Please set HUGGINGFACEHUB_API_TOKEN in your environment.")
        self.client = InferenceClient(
            provider="cerebras",
            api_key=hf_token,
        )

    def retrieve(self,
                query: str,
                k: int = 5,
                filter_categories: Optional[List[str]] = None) -> List[Dict[str, Any]]:
        """Return the top-``k`` chunks for ``query`` via hybrid search.

        Args:
            query: Natural-language search query.
            k: Maximum number of results to return.
            filter_categories: Optional category names to restrict the search.

        Returns:
            Result dicts as produced by ``VectorStore.hybrid_search`` (each
            containing at least a ``'chunk'`` mapping and a ``'score'``).
        """
        return self.vector_store.hybrid_search(query, k=k, filter_categories=filter_categories)

    def format_context(self, results: List[Dict[str, Any]]) -> str:
        """Render retrieval results as a numbered, human-readable context string.

        Each result contributes a "Source N" section with title, URL, and up
        to 500 characters of content. An ellipsis is appended only when the
        content was actually truncated.
        """
        sections = []
        for i, result in enumerate(results, start=1):
            chunk = result["chunk"]
            content = chunk["content"]
            # Cap context size per source; mark truncation honestly.
            suffix = "..." if len(content) > 500 else ""
            sections.append(
                f"Source {i}: {chunk['title']}\n"
                f"URL: {chunk['url']}\n"
                f"Content: {content[:500]}{suffix}\n"
            )
        return "\n".join(sections)

    def generate_response(self, query: str, context: str) -> str:
        """Generate an answer to ``query`` grounded in ``context``.

        Builds a counselor-persona prompt and sends a single-turn chat
        completion request. On any failure, returns a fixed apology string
        rather than raising, so the caller always gets displayable text.
        """
        prompt = f"""You are a friendly and professional counselor for international students at the University at Buffalo. Respond to the student's query in a supportive, detailed, and well-structured manner.

                For your responses:
                1. Address the student respectfully and empathetically
                2. Provide clear, accurate information with specific details and steps when applicable
                3. Organize your answer with appropriate headings, bullet points, or numbered lists when helpful
                4. If the student's question is unclear or lacks essential details, ask 1-2 specific clarifying questions to better understand their situation
                5. Include relevant deadlines, contacts, or resources when appropriate
                6. Conclude with a brief encouraging statement
                7. Only answer related to international students at UB, if it's not related to international students at UB, just say "I'm sorry, I don't have information about that."
                8. Do not entertain any questions that are not related to students at UB.

                Question: {query}

                Relevant Information:
                {context}

                Answer:"""

        try:
            completion = self.client.chat.completions.create(
                # Use the configured model instead of a hard-coded id so the
                # constructor's `generation_model` argument takes effect.
                model=self.generation_model,
                messages=[
                    {
                        "role": "user",
                        "content": prompt
                    }
                ],
                max_tokens=2048,
            )
            return completion.choices[0].message.content
        except Exception as e:
            # Best-effort fallback: never propagate provider errors to the UI.
            print(f"Error during generation: {str(e)}")
            return "I'm sorry, I encountered an issue generating a response. Please try asking your question in a different way or contact UB International Student Services directly for assistance."

    def answer(self,
              query: str,
              k: int = 5,
              filter_categories: Optional[List[str]] = None) -> Dict[str, Any]:
        """End-to-end RAG: retrieve, format context, generate, and package.

        Args:
            query: The student's question.
            k: Number of chunks to retrieve.
            filter_categories: Optional category filter passed to retrieval.

        Returns:
            Dict with keys ``'query'``, ``'response'``, and ``'sources'``
            (title/url/score per retrieved chunk; prefers the rerank score
            when present).
        """
        results = self.retrieve(query, k=k, filter_categories=filter_categories)
        context = self.format_context(results)
        response = self.generate_response(query, context)

        return {
            'query': query,
            'response': response,
            'sources': [
                {
                    'title': result['chunk']['title'],
                    'url': result['chunk']['url'],
                    'score': result.get('rerank_score', result['score'])
                }
                for result in results
            ]
        }