Goated121 committed
Commit 1e93d04 · verified · 1 Parent(s): 093f515

Update app.py

Files changed (1): app.py +26 -54
app.py CHANGED
@@ -1,46 +1,33 @@
-from llama_cpp import Llama
+# app.py
 import gradio as gr
 import faiss
 import pickle
 import numpy as np
 from sentence_transformers import SentenceTransformer
-import os
+from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 
+import os
 print("Files in current directory:", os.listdir())
 
 # -----------------------------
-# Globals (lazy-loaded)
+# Load RAG components
 # -----------------------------
-model = None
-embed_model = None
-index = None
-chunks = None
-metadata = None
+embed_model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
+
+index = faiss.read_index("faiss_index.bin")
+chunks = pickle.load(open("chunks.pkl", "rb"))
+metadata = pickle.load(open("metadata.pkl", "rb"))
 
 # -----------------------------
-# Lazy-loading functions
+# Load Hugging Face LLM (CPU-friendly)
 # -----------------------------
-def load_llm():
-    global model
-    if model is None:
-        print("Loading LLM...")
-        model = Llama(
-            model_path="qwen2.5-1.5B-q4.gguf",
-            n_ctx=4096,
-            n_gpu_layers=0,
-            chat_format="qwen",
-        )
-        print("LLM loaded.")
-
-def load_rag():
-    global embed_model, index, chunks, metadata
-    if embed_model is None or index is None or chunks is None or metadata is None:
-        print("Loading embedding model and FAISS index...")
-        embed_model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
-        index = faiss.read_index("faiss_index.bin")
-        chunks = pickle.load(open("chunks.pkl", "rb"))
-        metadata = pickle.load(open("metadata.pkl", "rb"))
-        print("RAG components loaded.")
+# Small model for HF Spaces CPU limits
+model_name = "TheBloke/vicuna-7B-1.1-HF" # You can replace with a smaller model if needed
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto") # Hugging Face will manage CPU/GPU
+generator = pipeline("text-generation", model=model, tokenizer=tokenizer, max_length=200)
+
+print("LLM loaded successfully!")
 
 # -----------------------------
 # Detect query intent
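Note on this hunk: the added comments call the checkpoint "CPU-friendly" and "small", but vicuna-7B loaded at default fp32 precision needs on the order of 28 GB of RAM, well beyond free Spaces CPU hardware, and device_map="auto" additionally requires the accelerate package. A minimal sketch of the same loading step with a genuinely small instruct model; the checkpoint name here is an illustrative choice, not the author's:

# Sketch, not the committed code: a checkpoint sized for free-tier CPU.
# "Qwen/Qwen2.5-0.5B-Instruct" is an example choice, not the author's.
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

model_name = "Qwen/Qwen2.5-0.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=200,  # caps the completion; max_length would also count prompt tokens
)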
@@ -67,8 +54,6 @@ def detect_query(query):
 # Retrieve context (RAG)
 # -----------------------------
 def retrieve_context(query):
-    load_rag() # ensure RAG is loaded
-
     animal, topic = detect_query(query)
 
     filtered_indices = []
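The removed load_rag() call means retrieval now relies on the module-level loading added in the first hunk, so the Space downloads the embedding model at import time before Gradio even starts. If startup latency matters, a cached loader keeps one-time initialization without globals; a minimal sketch, assuming the commit's file names:

# Sketch, assuming the commit's file names: one-time lazy init via lru_cache.
import pickle
from functools import lru_cache

import faiss
from sentence_transformers import SentenceTransformer

@lru_cache(maxsize=1)
def get_rag():
    embed_model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
    index = faiss.read_index("faiss_index.bin")
    with open("chunks.pkl", "rb") as f:
        chunks = pickle.load(f)
    with open("metadata.pkl", "rb") as f:
        metadata = pickle.load(f)
    return embed_model, index, chunks, metadata

The first call pays the full load; every later call returns the cached tuple.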
@@ -83,7 +68,9 @@ def retrieve_context(query):
         filtered_indices = list(range(len(chunks)))
 
     query_embedding = embed_model.encode([query])
-    filtered_embeddings = np.array([index.reconstruct(i) for i in filtered_indices])
+
+    filtered_embeddings = [index.reconstruct(i) for i in filtered_indices]
+    filtered_embeddings = np.array(filtered_embeddings)
 
     distances = np.linalg.norm(filtered_embeddings - query_embedding, axis=1)
     top_indices = distances.argsort()[:2]
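Both the old and new versions rank candidates by reconstructing every vector with index.reconstruct(i) and computing L2 distances in NumPy. That only works on index types that support reconstruction (a flat index read via faiss.read_index does) and is O(N) per query. When no metadata filter applies, FAISS can run the same exact search itself; a sketch under the flat-L2-index assumption:

# Sketch, assuming a flat L2 index: let FAISS rank vectors directly
# instead of reconstructing all of them per query.
import numpy as np

def search_unfiltered(index, embed_model, query, k=2):
    query_embedding = embed_model.encode([query]).astype(np.float32)
    distances, indices = index.search(query_embedding, k)  # exact on IndexFlatL2
    return indices[0]  # positions of the k nearest chunks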
@@ -99,8 +86,6 @@ def retrieve_context(query):
 # Chat function
 # -----------------------------
 def chat(user_input):
-    load_llm() # ensure LLM is loaded
-
     context = retrieve_context(user_input)
 
     prompt = f"""
@@ -118,29 +103,16 @@ Question:
 Answer in short and clear sentences.
 """
 
-    messages = [
-        {"role": "system", "content": "You are a helpful assistant."},
-        {"role": "user", "content": prompt}
-    ]
-
-    response = model.create_chat_completion(
-        messages=messages,
-        max_tokens=200,
-        temperature=0.5,
-    )
-
-    return response["choices"][0]["message"]["content"]
+    # Generate response
+    response = generator(prompt, max_length=200, do_sample=True, temperature=0.5)
+    return response[0]["generated_text"]
 
 # -----------------------------
 # Gradio UI
 # -----------------------------
-demo = gr.Interface(
+gr.Interface(
     fn=chat,
     inputs="text",
     outputs="text",
-    title="Livestock Chatbot",
-    description="Ask questions about goats and cows. The assistant answers using only the provided knowledge base."
-)
-
-if __name__ == "__main__":
-    demo.launch(server_name="0.0.0.0", server_port=7860)
+    title="Livestock Chatbot"
+).launch()
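One behavior change worth flagging in this hunk: create_chat_completion returned only the assistant message, but a transformers text-generation pipeline includes the prompt in generated_text by default, so the UI will echo the whole RAG prompt back to the user; max_length=200 also counts prompt tokens, so a long retrieved context can leave no room for the answer. A minimal sketch of the call with both knobs adjusted (return_full_text and max_new_tokens are standard pipeline parameters; the wrapper function is illustrative):

# Sketch, not the committed code: return only the completion and budget
# tokens for the answer rather than for prompt + answer combined.
def generate_answer(generator, prompt):
    response = generator(
        prompt,
        max_new_tokens=200,      # counts newly generated tokens only
        do_sample=True,
        temperature=0.5,
        return_full_text=False,  # strip the echoed prompt from the output
    )
    return response[0]["generated_text"]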
 
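The rewritten UI also drops the demo handle, the description text, and the __main__ guard. On Spaces a bare .launch() works because the platform supplies the host and port, but keeping the guard lets the file run locally with python app.py as well; an optional sketch, with a stand-in for the app's chat():

# Optional sketch: keep a handle and a main guard; Spaces launches it either way.
import gradio as gr

def chat(user_input):  # stand-in for the app's real chat()
    return user_input

demo = gr.Interface(fn=chat, inputs="text", outputs="text", title="Livestock Chatbot")

if __name__ == "__main__":
    demo.launch()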