Goated121 committed
Commit 093f515 (verified) · Parent: 53c3ebe

Update app.py

Loads the LLM and the RAG components lazily on first request instead of at import time, vectorizes the filtered-embedding lookup, and gives the Gradio app a description and an explicit __main__ entry point.

Files changed (1):
  1. app.py +43 -21
app.py CHANGED
@@ -4,28 +4,44 @@ import faiss
 import pickle
 import numpy as np
 from sentence_transformers import SentenceTransformer
-
 import os
+
 print("Files in current directory:", os.listdir())
 
 # -----------------------------
-# Load LLM
+# Globals (lazy-loaded)
 # -----------------------------
-model = Llama(
-    model_path="qwen2.5-1.5B-q4.gguf",
-    n_ctx=4096,
-    n_gpu_layers=0,
-    chat_format="qwen",
-)
+model = None
+embed_model = None
+index = None
+chunks = None
+metadata = None
 
 # -----------------------------
-# Load RAG
+# Lazy-loading functions
 # -----------------------------
-embed_model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
-index = faiss.read_index("faiss_index.bin")
-chunks = pickle.load(open("chunks.pkl", "rb"))
-metadata = pickle.load(open("metadata.pkl", "rb"))
+def load_llm():
+    global model
+    if model is None:
+        print("Loading LLM...")
+        model = Llama(
+            model_path="qwen2.5-1.5B-q4.gguf",
+            n_ctx=4096,
+            n_gpu_layers=0,
+            chat_format="qwen",
+        )
+        print("LLM loaded.")
+
+def load_rag():
+    global embed_model, index, chunks, metadata
+    if embed_model is None or index is None or chunks is None or metadata is None:
+        print("Loading embedding model and FAISS index...")
+        embed_model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
+        index = faiss.read_index("faiss_index.bin")
+        chunks = pickle.load(open("chunks.pkl", "rb"))
+        metadata = pickle.load(open("metadata.pkl", "rb"))
+        print("RAG components loaded.")
+
 # -----------------------------
 # Detect query intent
 # -----------------------------
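Note: the lazy-loading pattern in this hunk keeps the app from paying the model-load cost at import time, but as written it is not thread-safe: Gradio can serve overlapping requests, so two first requests may both see model as None and load it twice. A minimal thread-safe variant of load_llm (a sketch, not part of this commit; it assumes Llama comes from llama-cpp-python, as the chat_format argument suggests):

import threading

from llama_cpp import Llama  # assumed import; app.py's own import is not shown in this diff

model = None
_load_lock = threading.Lock()

def load_llm():
    global model
    if model is None:              # cheap check without taking the lock
        with _load_lock:
            if model is None:      # re-check once the lock is held
                model = Llama(
                    model_path="qwen2.5-1.5B-q4.gguf",
                    n_ctx=4096,
                    n_gpu_layers=0,
                    chat_format="qwen",
                )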
@@ -51,6 +67,8 @@ def detect_query(query):
 # Retrieve context (RAG)
 # -----------------------------
 def retrieve_context(query):
+    load_rag()  # ensure RAG is loaded
+
     animal, topic = detect_query(query)
 
     filtered_indices = []
@@ -65,9 +83,7 @@ def retrieve_context(query):
         filtered_indices = list(range(len(chunks)))
 
     query_embedding = embed_model.encode([query])
-
-    filtered_embeddings = [index.reconstruct(i) for i in filtered_indices]
-    filtered_embeddings = np.array(filtered_embeddings)
+    filtered_embeddings = np.array([index.reconstruct(i) for i in filtered_indices])
 
     distances = np.linalg.norm(filtered_embeddings - query_embedding, axis=1)
     top_indices = distances.argsort()[:2]
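Note: index.reconstruct(i) returns the vector stored under id i, which works directly on flat FAISS indexes; IVF-style indexes would need index.make_direct_map() first. For a larger corpus the per-id Python loop can be replaced by a single bulk call — a sketch of the same filtered search, reusing this hunk's variables under the flat-index assumption, with top_chunk_ids as a hypothetical name for the mapped-back result:

import numpy as np

# Pull all stored vectors out in one call, then slice out the filtered rows.
all_vecs = index.reconstruct_n(0, index.ntotal)            # shape (ntotal, d)
filtered_embeddings = all_vecs[np.asarray(filtered_indices)]

distances = np.linalg.norm(filtered_embeddings - query_embedding, axis=1)
top_local = distances.argsort()[:2]
top_chunk_ids = [filtered_indices[i] for i in top_local]   # map back to chunk ids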
@@ -80,9 +96,11 @@ def retrieve_context(query):
     return context
 
 # -----------------------------
-# Chat function (UPDATED)
+# Chat function
 # -----------------------------
 def chat(user_input):
+    load_llm()  # ensure LLM is loaded
+
     context = retrieve_context(user_input)
 
     prompt = f"""
@@ -114,11 +132,15 @@ Answer in short and clear sentences.
     return response["choices"][0]["message"]["content"]
 
 # -----------------------------
-# Gradio UI (UNCHANGED)
+# Gradio UI
 # -----------------------------
-gr.Interface(
+demo = gr.Interface(
     fn=chat,
     inputs="text",
     outputs="text",
-    title="Livestock Chatbot"
-).launch()
+    title="Livestock Chatbot",
+    description="Ask questions about goats and cows. The assistant answers using only the provided knowledge base."
+)
+
+if __name__ == "__main__":
+    demo.launch(server_name="0.0.0.0", server_port=7860)
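Note: moving launch() under the if __name__ == "__main__": guard means importing the module no longer starts the server (0.0.0.0:7860 matches the port a Hugging Face Space expects), so the pipeline can be exercised without the UI — a sketch with a hypothetical test query, assuming the model and index files sit next to app.py:

# Smoke-test the full RAG + LLM pipeline once, without Gradio.
from app import chat

print(chat("How often should goats be dewormed?"))  # hypothetical query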