Goated121 committed on
Commit
f888dd3
·
verified ·
1 Parent(s): d584e33

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -14
app.py CHANGED
@@ -3,6 +3,8 @@ import faiss
3
  import pickle
4
  import numpy as np
5
  from sentence_transformers import SentenceTransformer
 
 
6
  import os
7
 
8
  print("Files in current directory:", os.listdir())
@@ -11,11 +13,33 @@ print("Files in current directory:", os.listdir())
11
  # Load RAG components
12
  # -----------------------------
13
  embed_model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
14
-
15
  index = faiss.read_index("faiss_index.bin")
16
  chunks = pickle.load(open("chunks.pkl", "rb"))
17
  metadata = pickle.load(open("metadata.pkl", "rb"))
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  # -----------------------------
20
  # Intent detection
21
  # -----------------------------
@@ -42,40 +66,55 @@ def detect_query(query):
42
  def retrieve_context(query, top_k=2):
43
  animal, topic = detect_query(query)
44
 
45
- # Filter relevant chunks based on metadata
46
  filtered_indices = [
47
  i for i, meta in enumerate(metadata)
48
  if (not animal or meta["animal"] == animal) and
49
  (not topic or meta["topic"] == topic)
50
  ]
51
 
52
- # If no specific filter matches, consider all chunks
53
  if not filtered_indices:
54
  filtered_indices = list(range(len(chunks)))
55
 
56
- # Embed query
57
  query_embedding = embed_model.encode([query])
58
  filtered_embeddings = np.array([index.reconstruct(i) for i in filtered_indices])
59
 
60
- # Compute distances and get top-k closest chunks
61
  distances = np.linalg.norm(filtered_embeddings - query_embedding, axis=1)
62
  top_indices = distances.argsort()[:top_k]
63
 
64
- # Combine top chunks into context
65
  context = "\n".join(chunks[filtered_indices[idx]] for idx in top_indices)
66
-
67
  return context.strip()
68
 
69
  # -----------------------------
70
- # Chat function (RAG only)
71
  # -----------------------------
72
  def chat(user_input):
73
  context = retrieve_context(user_input)
74
  if not context:
75
  return "I don't know."
76
-
77
- # Return context with clear formatting
78
- return f"Answer from retrieved data:\n\n{context}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
  # -----------------------------
81
  # Gradio UI
@@ -84,7 +123,6 @@ gr.Interface(
84
  fn=chat,
85
  inputs=gr.Textbox(lines=2, placeholder="Ask a question about livestock..."),
86
  outputs=gr.Textbox(),
87
- title="Livestock Chatbot (RAG only)",
88
- description="This chatbot answers livestock questions using only retrieved data. No AI model is used.",
89
- allow_flagging="never"
90
  ).launch()
 
3
  import pickle
4
  import numpy as np
5
  from sentence_transformers import SentenceTransformer
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
7
+ import torch
8
  import os
9
 
10
  print("Files in current directory:", os.listdir())
 
# Load RAG components
# -----------------------------
# Sentence-embedding model used to encode each incoming query.
embed_model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')

# FAISS index read from disk; retrieve_context() reconstructs stored vectors
# from it by position.
index = faiss.read_index("faiss_index.bin")

# SECURITY NOTE: pickle.load can execute arbitrary code embedded in the file.
# Only deploy chunks.pkl / metadata.pkl that come from a trusted build step.
# The `with` blocks close each file handle deterministically instead of
# leaving an anonymous open() result for the garbage collector.
# chunks[i] is the text returned to the user; metadata[i] is the record
# retrieve_context() filters on ("animal" / "topic" keys) for that position.
with open("chunks.pkl", "rb") as chunks_file:
    chunks = pickle.load(chunks_file)
with open("metadata.pkl", "rb") as metadata_file:
    metadata = pickle.load(metadata_file)
19
 
# -----------------------------
# Load Qwen2.5-1.5B-Instruct model
# -----------------------------
model_name = "Qwen/Qwen2.5-1.5B-Instruct"

# Half precision when CUDA is available, full precision otherwise.
_weights_dtype = torch.float32
if torch.cuda.is_available():
    _weights_dtype = torch.float16

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=_weights_dtype,
)

# Default sampling settings baked into the text-generation pipeline.
_generation_defaults = {
    "max_new_tokens": 200,
    "do_sample": True,
    "temperature": 0.6,
}
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    **_generation_defaults,
)

print("Qwen model loaded successfully!")
42
+
43
  # -----------------------------
44
  # Intent detection
45
  # -----------------------------
 
def retrieve_context(query, top_k=2):
    """Return the text of the stored chunks closest to *query*.

    The query is first passed to ``detect_query`` to obtain an
    ``(animal, topic)`` pair. Chunks whose metadata matches every detected
    value are kept as candidates; when nothing matches, every chunk is a
    candidate. Candidates are ranked by Euclidean distance between their
    vector (reconstructed from the FAISS index) and the query embedding.

    Args:
        query: The user's question.
        top_k: Maximum number of chunks to include in the result.

    Returns:
        The selected chunks joined with newlines, stripped of surrounding
        whitespace.
    """
    animal, topic = detect_query(query)

    def _matches(meta):
        # A falsy animal/topic means "no constraint" for that field.
        if animal and meta["animal"] != animal:
            return False
        if topic and meta["topic"] != topic:
            return False
        return True

    candidate_ids = [pos for pos, meta in enumerate(metadata) if _matches(meta)]
    if not candidate_ids:
        # Nothing matched the detected intent: search across all chunks.
        candidate_ids = list(range(len(chunks)))

    query_vec = embed_model.encode([query])
    candidate_vecs = np.array([index.reconstruct(pos) for pos in candidate_ids])

    # Smaller L2 distance means a closer match; argsort is ascending.
    l2_distances = np.linalg.norm(candidate_vecs - query_vec, axis=1)
    nearest = l2_distances.argsort()[:top_k]

    joined = "\n".join(chunks[candidate_ids[rank]] for rank in nearest)
    return joined.strip()
86
 
87
  # -----------------------------
88
+ # Chat function (RAG + Qwen)
89
  # -----------------------------
def chat(user_input):
    """Answer a livestock question using retrieved context and the Qwen generator.

    Args:
        user_input: Question text entered by the user.

    Returns:
        The model's answer. Returns a short request to enter a question when
        the input is blank, and "I don't know." when no context is retrieved
        or the model produces no text.
    """
    # A blank question has nothing to retrieve against; skip the model call.
    if not user_input or not user_input.strip():
        return "Please enter a question."

    context = retrieve_context(user_input)
    if not context:
        return "I don't know."

    # Built line by line so the prompt text does not depend on how this
    # function happens to be indented in the source file.
    instruction = "\n".join([
        "You are a livestock expert assistant.",
        "",
        "Use ONLY the information below to answer the question.",
        "If the answer is not present, say \"I don't know\".",
        "",
        "Context:",
        context,
        "",
        "Question:",
        user_input,
        "",
        "Answer in full, clear sentences.",
    ])

    # Instruct checkpoints are trained on their chat template; wrapping the
    # instruction as a user turn and appending the generation prompt makes
    # the model answer as the assistant instead of continuing raw text.
    prompt = generator.tokenizer.apply_chat_template(
        [{"role": "user", "content": instruction}],
        tokenize=False,
        add_generation_prompt=True,
    )

    # return_full_text=False makes the pipeline return only the newly
    # generated text, so no string-splitting on the prompt is needed.
    response = generator(
        prompt,
        max_new_tokens=200,
        do_sample=True,
        temperature=0.6,
        return_full_text=False,
    )
    text = response[0]["generated_text"].strip()

    # An empty generation is treated the same as a missing answer.
    return text or "I don't know."
118
 
119
  # -----------------------------
120
  # Gradio UI
 
123
  fn=chat,
124
  inputs=gr.Textbox(lines=2, placeholder="Ask a question about livestock..."),
125
  outputs=gr.Textbox(),
126
+ title="Livestock Chatbot (RAG + Qwen)",
127
+ description="This chatbot answers livestock questions using retrieved data and Qwen Instruct model."
 
128
  ).launch()