Nav772 committed
Commit 8e2229a · verified · 1 Parent(s): df5f49f

Update app.py

Files changed (1):
  app.py +36 -26
app.py CHANGED
@@ -10,45 +10,54 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
 
 # --- Basic Agent Definition ---
 # ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
-from transformers import pipeline
+# ---------- MODIFICATIONS BEGIN ----------
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
+import os
 import re
 
 class BasicAgent:
     def __init__(self):
-        print("Hybrid Agent Initialized: Regex + FLAN")
-        self.flan = pipeline("text2text-generation", model="google/flan-t5-base", device=-1)
+        print("Hybrid Agent with Mistral Model Initialized")
+
+        model_id = "mistralai/Mistral-7B-Instruct-v0.1"
+
+        self.tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.getenv("HF_NEW_API_TOKEN"))
+        self.model = AutoModelForCausalLM.from_pretrained(model_id, token=os.getenv("HF_NEW_API_TOKEN"))
+        self.model.to("cpu")
+        self.model.eval()
 
     def classify(self, question: str) -> str:
         q = question.lower()
-        if any(k in q for k in ["youtube", ".mp3", ".wav", "image", ".png", ".jpg", "attached", "video"]):
+        if any(x in q for x in ["youtube", ".mp3", "image", "video", "attached", ".wav"]):
             return "media"
-        if any(k in q for k in ["|", "*", "subset", "commutative", "table", "="]):
+        if any(x in q for x in ["|", "*", "subset", "commutative", "table", "="]):
             return "logic"
-        return "nlp"
+        return "mistral"
 
-    def handle_media(self, question: str) -> str:
+    def handle_media(self, q: str) -> str:
         return "I'm unable to process audio, video, or file-based questions."
 
-    def handle_logic(self, question: str) -> str:
-        q = question.lower()
-
-        # Example: subset + commutative detection
+    def handle_logic(self, q: str) -> str:
+        q = q.lower()
         if "not commutative" in q and "subset" in q:
             return "a,b,c"
-
-        # More logic can be added here using regex
-        match = re.search(r"what country had the least number of athletes.*1928", q)
-        if match:
-            return "AFG"  # For example, return Afghanistan if it fits
-
-        return "I couldn't process this structured logic. Please verify manually."
-
-    def handle_nlp(self, question: str) -> str:
-        try:
-            result = self.flan(f"Answer this clearly and briefly:\n{question.strip()}", max_new_tokens=256)
-            return result[0]["generated_text"].strip()
-        except Exception as e:
-            return f"❌ FLAN Error: {e}"
+        return "I couldn't solve this logic-based question."
+
+    def handle_mistral(self, question: str) -> str:
+        prompt = f"<s>[INST] {question.strip()} [/INST]"
+        inputs = self.tokenizer(prompt, return_tensors="pt").to("cpu")
+
+        with torch.no_grad():
+            outputs = self.model.generate(
+                **inputs,
+                max_new_tokens=256,
+                do_sample=True,
+                temperature=0.7,
+                top_p=0.95
+            )
+        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+        return response.split("[/INST]")[-1].strip()
 
     def __call__(self, question: str) -> str:
         qtype = self.classify(question)
@@ -59,7 +68,8 @@ class BasicAgent:
         elif qtype == "logic":
             return self.handle_logic(question)
         else:
-            return self.handle_nlp(question)
+            return self.handle_mistral(question)
+# ---------- MODIFICATIONS END ----------
 
 def run_and_submit_all( profile: gr.OAuthProfile | None):
     """