import random
import streamlit as st
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteriaList
from unsloth import FastLanguageModel, is_bfloat16_supported
from utils import SpecificStringStoppingCriteria
from cot import EIGHT_SHOT_PROMPT, FOUR_SHOT_PROMPT

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Strings that should terminate generation: the next few-shot question
# marker ("Q:") and end-of-sequence markers rendered as plain text.
generation_util = [
    "Q:",
    "</s>",
    "<|im_end|>"
]
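
# A minimal sketch of the project-local stopping criterion imported from
# utils.py (assumed interface; the real implementation may differ):
#
#   class SpecificStringStoppingCriteria(StoppingCriteria):
#       def __init__(self, tokenizer, stop_strings, prompt_len):
#           self.tokenizer = tokenizer
#           self.stop_strings = stop_strings
#           self.prompt_len = prompt_len  # character offset of the prompt
#       def __call__(self, input_ids, scores, **kwargs):
#           text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
#           return any(s in text[self.prompt_len:] for s in self.stop_strings)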

# GPT-2 and Mistral model registry
gpt_models = {
    "GPT-2 Small BL": "openai-community/gpt2",
    "GPT-2 Small CPT+CL+IFT": "jonathantiedchen/GPT2-Small-CPT-CL-IFT"
}

mistral_models = {
    "Mistral 7B BL": "unsloth/mistral-7b-bnb-4bit",
    "Mistral 7B CPT+CL": "jonathantiedchen/Mistral-7B-CPT-CL",
    "Mistral 7B CPT+IFT": "jonathantiedchen/MistralMath-CPT-IFT"
}

all_models = gpt_models | mistral_models  # dict union (requires Python 3.9+)


### Load GSM8K once
@st.cache_resource
def load_gsm8k_dataset():
    return load_dataset("openai/gsm8k", "main")["test"]


### Load Mistral
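# The leading underscore in `_models` tells st.cache_resource not to hash
# that argument, so the mutable model registry can be passed straight through
# (the same applies to load_gpts below).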
@st.cache_resource
def load_mistral(mistral_path, _models):
    try:
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=mistral_path,
            max_seq_length=2048,
            dtype=torch.bfloat16 if is_bfloat16_supported() else torch.float16,
            load_in_4bit=True
        )
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        FastLanguageModel.for_inference(model)
        _models[mistral_path] = {"tokenizer": tokenizer, "model": model}
    except Exception as e:
        st.sidebar.error(f"⚠️ Failed to load Mistral model with Unsloth: {e}")
    return _models


### Load GPT-2
@st.cache_resource
def load_gpts(path, _models):
    try:
        tokenizer = AutoTokenizer.from_pretrained(path)
        model = AutoModelForCausalLM.from_pretrained(path).to(device)
        model.eval()
        _models[path] = {"tokenizer": tokenizer, "model": model}
    except Exception as e:
        st.sidebar.error(f"⚠️ Failed to load GPT model: {e}")
    return _models


# Load models
st.title("🧠 Math LLM Demo")
models = {}
with st.sidebar:
    with st.spinner("πŸ“₯ Load all Models. That might take a while."):
        for model_path in mistral_models.values():
            models = load_mistral(model_path, models)
        for model_path in gpt_models.values():
            models = load_gpts(model_path, models)
    st.write("βœ… Successfully loaded all models.")


# Load GSM8K dataset and allow selection
st.sidebar.write("πŸ“₯ Load GSM8K")
gsm8k_data = load_gsm8k_dataset()
st.sidebar.write("πŸ“Š GSM8K loaded:", len(gsm8k_data), "samples")

# Read the question index from the query params (set by the random-question
# button below); fall back to 0 and clamp to the valid range.
try:
    default_index = int(st.query_params.get("question_index", 0))
except (ValueError, TypeError):
    default_index = 0
default_index = min(max(default_index, 0), len(gsm8k_data) - 1)

question_index = st.selectbox("πŸ”’ Select GSM8K question index", range(len(gsm8k_data)), index=default_index)

if st.button("🎲 Pick Random Question"):
    new_random_index = random.randint(0, len(gsm8k_data) - 1)
    st.query_params.update(question_index=new_random_index)
    st.rerun()  # Force app to rerun to update the selectbox

# Fall back to a hand-written prompt if no GSM8K question is selected.
default_prompt = "Jasper has 5 apples and eats 2 of them. How many apples does he have left?"
if question_index is not None:
    selected_question = gsm8k_data[question_index]["question"]
    correct_answer = gsm8k_data[question_index]["answer"]
else:
    selected_question = default_prompt
    correct_answer = None
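# Note: GSM8K reference answers contain the full reasoning chain and end
# with the final result after a "#### " marker.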


# Prompt options
st.write("##")  # spacer
use_cot = st.toggle("Use Few-Shot Prompt")  # few-shot chain-of-thought prompting
model_choice = st.selectbox("Choose a model:", list(all_models.keys()))
model_path = all_models[model_choice]
tokenizer = models[model_path]["tokenizer"]
model = models[model_path]["model"]

# Prompt input
prompt = st.text_area("Enter your math prompt:", selected_question)

# Generation
if st.button("Generate Response", key="manual"):
    # Check if the current prompt is from GSM8K dataset
    is_gsm8k_question = prompt == selected_question
    
    with st.sidebar:
        with st.spinner("πŸ”„ Generating..."):

            if use_cot:
                # Mistral's longer context window fits the eight-shot CoT
                # prompt; GPT-2 gets the four-shot variant, presumably because
                # its 1,024-token context cannot hold eight examples.
                if 'mistral' in model_choice.lower():
                    prompt_template = EIGHT_SHOT_PROMPT
                else:
                    prompt_template = FOUR_SHOT_PROMPT
                input_text = prompt_template.format(question=prompt)
            else:
                input_text = prompt

            inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
            # The stopping criterion watches the text generated past the prompt
            # (hence the character offset len(input_text)) for any stop string.
            stop_criteria = SpecificStringStoppingCriteria(tokenizer, generation_util, len(input_text))
            stopping_criteria_list = StoppingCriteriaList([stop_criteria])

            with torch.no_grad():
                output = model.generate(
                    **inputs,
                    max_new_tokens=512,
                    do_sample=False,  # greedy decoding; temperature only applies when sampling
                    pad_token_id=tokenizer.eos_token_id,
                    stopping_criteria=stopping_criteria_list
                )
                generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
                # Strip the echoed prompt so only the model's continuation remains.
                response_only = generated_text[len(input_text):].strip() if generated_text.startswith(input_text) else generated_text.strip()

    with st.expander("πŸ”Ž Prompt"):
        st.write(input_text)
    st.subheader("🧠 Model Output")
    st.success(response_only)
    
    # Only show the reference answer when the prompt is an unmodified GSM8K question
    if is_gsm8k_question and correct_answer is not None:
        st.subheader("βœ… Correct Answer (GSM8K)")
        st.info(correct_answer)
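
# A possible extension (not wired into the app): score the model's output
# against the reference by extracting the number after the "#### " marker,
# e.g. with a helper along these lines:
#
#   import re
#   def extract_final_answer(answer_text):
#       match = re.search(r"####\s*(-?[\d.,]+)", answer_text)
#       return match.group(1).replace(",", "") if match else None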