jonathantiedchen committed
Commit 87de3b4 · verified · 1 Parent(s): b460ef7

Update app.py

Files changed (1): app.py (+24 -21)
app.py CHANGED
```diff
@@ -39,7 +39,7 @@ def load_models():
                 FastLanguageModel.for_inference(model)
 
             except Exception as e:
-                st.error(f"⚠️ Failed to load Mistral model with Unsloth: {e}")
+                st.sidebar.error(f"⚠️ Failed to load Mistral model with Unsloth: {e}")
                 continue
         else:
             tokenizer = AutoTokenizer.from_pretrained(path)
```
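For context, a minimal sketch of how this try/except plausibly sits inside `load_models()`. The loop structure, the `MODEL_OPTIONS` entry, and the exact Unsloth call are reconstructions from the visible context, not the repo's verbatim code:

```python
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
from unsloth import FastLanguageModel

# Hypothetical registry; the real MODEL_OPTIONS is defined elsewhere in app.py.
MODEL_OPTIONS = {"Mistral-7B (Unsloth)": "unsloth/mistral-7b-bnb-4bit"}

def load_models():
    models = {}
    for name, path in MODEL_OPTIONS.items():
        if "mistral" in name.lower():
            try:
                # Unsloth returns (model, tokenizer) and patches the model
                # for faster inference.
                model, tokenizer = FastLanguageModel.from_pretrained(path)
                FastLanguageModel.for_inference(model)
            except Exception as e:
                # This commit routes the failure message to the sidebar.
                st.sidebar.error(f"⚠️ Failed to load Mistral model with Unsloth: {e}")
                continue
        else:
            tokenizer = AutoTokenizer.from_pretrained(path)
            model = AutoModelForCausalLM.from_pretrained(path)
        models[name] = {"tokenizer": tokenizer, "model": model}
    return models
```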
```diff
@@ -49,8 +49,9 @@ def load_models():
         models[name] = {"tokenizer": tokenizer, "model": model}
     return models
 
-
+st.sidebar.write("📥 Load Models.")
 models = load_models()
+st.sidebar.write(f"✅ Successfully loaded models:{models}")
 
 model_choice = st.selectbox("Choose a model:", list(MODEL_OPTIONS.keys()))
 tokenizer = models[model_choice]["tokenizer"]
```
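One side effect of the new success line: `f"✅ Successfully loaded models:{models}"` interpolates the whole dict, so full model and tokenizer reprs land in the sidebar. A lighter variant that logs only the names, with `st.cache_resource` added so reruns skip reloading weights (both the decorator and the stub body are assumptions, not part of this commit):

```python
import streamlit as st

@st.cache_resource  # hypothetical addition: load weights once across reruns
def load_models():
    # Stub body; the real loader builds {"name": {"tokenizer": ..., "model": ...}}.
    return {"Mistral-7B (Unsloth)": {"tokenizer": None, "model": None}}

st.sidebar.write("📥 Load Models.")
models = load_models()
# Listing only the keys keeps the sidebar readable.
st.sidebar.write("✅ Successfully loaded models:", ", ".join(models))
```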
```diff
@@ -61,8 +62,9 @@ model = models[model_choice]["model"]
 def load_gsm8k_dataset():
     return load_dataset("openai/gsm8k", "main")["test"]
 
+st.sidebar.write("📥 Load GSM8K")
 gsm8k_data = load_gsm8k_dataset()
-st.write("📊 GSM8K loaded:", len(gsm8k_data), "samples")
+st.sidebar.write("📊 GSM8K loaded:", len(gsm8k_data), "samples")
 
 # === TABS ===
 tab1, tab2 = st.tabs(["🔓 Manual Prompting", "📊 GSM8K Evaluation"])
```
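The `gold_answer` field used further down follows GSM8K's convention: chain-of-thought text terminated by a line of the form `#### <final number>`. A small standalone check of the record shape:

```python
from datasets import load_dataset

gsm8k = load_dataset("openai/gsm8k", "main")["test"]
sample = gsm8k[0]
# Each record pairs a "question" with an "answer"; the answer's last line is
# "#### <final number>", the usual target for exact-match evaluation.
gold = sample["answer"].split("####")[-1].strip()
print(sample["question"][:80], "->", gold)
```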
```diff
@@ -71,7 +73,7 @@ tab1, tab2 = st.tabs(["🔓 Manual Prompting", "📊 GSM8K Evaluation"])
 with tab1:
     prompt = st.text_area("Enter your math prompt:", "Jasper has 5 apples and eats 2 of them. How many apples does he have left?")
     if st.button("Generate Response", key="manual"):
-        with st.spinner("Generating..."):
+        with st.sidebar.spinner("🔄 Generating..."):
             inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
             output = model.generate(
                 **inputs,
```
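The generation call in this hunk, factored into a standalone helper for clarity. The helper name and the `torch.no_grad()` wrapper are illustrative additions; the `generate` arguments mirror the app's:

```python
import torch

def generate_response(model, tokenizer, prompt, max_new_tokens=150):
    """Hypothetical helper mirroring the app's generation call."""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():  # inference only; skip gradient bookkeeping
        output = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    # Decode and drop the echoed prompt, as the app does.
    text = tokenizer.decode(output[0], skip_special_tokens=True)
    return text[len(prompt):].strip()
```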
```diff
@@ -98,23 +100,24 @@ with tab2:
 
     if st.button("Run GSM8K Sample"):
         try:
-            sample = random.choice(gsm8k_data)
-            question = sample["question"]
-            gold_answer = sample["answer"]
-
-            inputs = tokenizer(question, return_tensors="pt").to(model.device)
-
-            st.markdown(f"Create Output")
-            output = model.generate(
-                **inputs,
-                max_new_tokens=150,
-                temperature=0.7,
-                do_sample=True,
-                pad_token_id=tokenizer.eos_token_id,
-                eos_token_id=tokenizer.eos_token_id,
-            )
-            generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
-            response_only = generated_text[len(question):].strip()
+            with st.sidebar.spinner("🔄 Generating..."):
+                sample = random.choice(gsm8k_data)
+                question = sample["question"]
+                gold_answer = sample["answer"]
+
+                inputs = tokenizer(question, return_tensors="pt").to(model.device)
+
+                st.markdown(f"Create Output")
+                output = model.generate(
+                    **inputs,
+                    max_new_tokens=150,
+                    temperature=0.7,
+                    do_sample=True,
+                    pad_token_id=tokenizer.eos_token_id,
+                    eos_token_id=tokenizer.eos_token_id,
+                )
+                generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
+                response_only = generated_text[len(question):].strip()
 
             st.subheader("📌 GSM8K Question")
             st.markdown(question)
```
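The hunk stops at `response_only`; scoring it against `gold_answer` is the natural next step. A sketch of exact-match scoring on the final number (the regex heuristic is an assumption, not the app's logic):

```python
import re

def extract_final_number(text: str) -> str | None:
    """Pull the last number out of a model response (illustrative heuristic)."""
    matches = re.findall(r"-?\d+(?:\.\d+)?", text.replace(",", ""))
    return matches[-1] if matches else None

def is_correct(response_only: str, gold_answer: str) -> bool:
    # gold_answer is a GSM8K record field; its final line is "#### <number>".
    gold = gold_answer.split("####")[-1].strip()
    return extract_final_number(response_only) == gold
```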
 
39
  FastLanguageModel.for_inference(model)
40
 
41
  except Exception as e:
42
+ st.sidebar.error(f"⚠️ Failed to load Mistral model with Unsloth: {e}")
43
  continue
44
  else:
45
  tokenizer = AutoTokenizer.from_pretrained(path)
 
49
  models[name] = {"tokenizer": tokenizer, "model": model}
50
  return models
51
 
52
+ st.sidebar.write("πŸ“₯ Load Models.")
53
  models = load_models()
54
+ st.sidebar.write(f"βœ… Successfully loaded models:{models}")
55
 
56
  model_choice = st.selectbox("Choose a model:", list(MODEL_OPTIONS.keys()))
57
  tokenizer = models[model_choice]["tokenizer"]
 
62
  def load_gsm8k_dataset():
63
  return load_dataset("openai/gsm8k", "main")["test"]
64
 
65
+ st.sidebar.write("πŸ“₯ Load GSM8K")
66
  gsm8k_data = load_gsm8k_dataset()
67
+ st.sidebar.write("πŸ“Š GSM8K loaded:", len(gsm8k_data), "samples")
68
 
69
  # === TABS ===
70
  tab1, tab2 = st.tabs(["πŸ”“ Manual Prompting", "πŸ“Š GSM8K Evaluation"])
 
73
  with tab1:
74
  prompt = st.text_area("Enter your math prompt:", "Jasper has 5 apples and eats 2 of them. How many apples does he have left?")
75
  if st.button("Generate Response", key="manual"):
76
+ with st.sidebar.spinner("πŸ”„ Generating..."):
77
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
78
  output = model.generate(
79
  **inputs,
 
100
 
101
  if st.button("Run GSM8K Sample"):
102
  try:
103
+ with st.sidebar.spinner("πŸ”„ Generating..."):
104
+ sample = random.choice(gsm8k_data)
105
+ question = sample["question"]
106
+ gold_answer = sample["answer"]
107
+
108
+ inputs = tokenizer(question, return_tensors="pt").to(model.device)
109
+
110
+ st.markdown(f"Create Output")
111
+ output = model.generate(
112
+ **inputs,
113
+ max_new_tokens=150,
114
+ temperature=0.7,
115
+ do_sample=True,
116
+ pad_token_id=tokenizer.eos_token_id,
117
+ eos_token_id=tokenizer.eos_token_id,
118
+ )
119
+ generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
120
+ response_only = generated_text[len(question):].strip()
121
 
122
  st.subheader("πŸ“Œ GSM8K Question")
123
  st.markdown(question)