import streamlit as st
import torch
import yaml
from transformers import AutoTokenizer, AutoModelForCausalLM

# Set page config first
st.set_page_config(page_title="Coding Multiple Choice Q&A", layout="wide")

# Use the specified model
MODEL_PATH = "tuandunghcmut/Qwen25_Coder_MultipleChoice_v4"



from coding_examples import CODING_EXAMPLES_BY_CATEGORY

# Flatten examples
CODING_EXAMPLES = []
for category, examples in CODING_EXAMPLES_BY_CATEGORY.items():
    for example in examples:
        example["category"] = category
        CODING_EXAMPLES.append(example)

class PromptCreator:
    def __init__(self, prompt_type="yaml"):
        self.prompt_type = prompt_type
        
    def format_choices(self, choices):
        if not choices: return ""
        if isinstance(choices, str): return choices
        return "\n".join(f"{chr(65 + i)}. {choice}" for i, choice in enumerate(choices))

    def get_max_letter(self, choices):
        if not choices: return "A"
        if isinstance(choices, str):
            num_choices = len([line for line in choices.split("\n") if line.strip()])
            return "A" if num_choices == 0 else chr(64 + num_choices)
        return chr(64 + len(choices))

    def create_inference_prompt(self, question, choices):
        if not question: return ""
        formatted_choices = self.format_choices(choices)
        max_letter = self.get_max_letter(choices)
        
        return f"""Question: {question}

Choices:
{formatted_choices}

Analyze this question step-by-step and provide a detailed explanation.
Your response MUST be in YAML format as follows:

understanding: |
  <your understanding of what the question is asking>
analysis: |
  <your analysis of each option>
reasoning: |
  <your step-by-step reasoning process>
conclusion: |
  <your final conclusion>
answer: <single letter A through {max_letter}>

The answer field MUST contain ONLY a single character letter."""
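
# Illustrative usage of PromptCreator (the question and choices below are hypothetical, not taken from CODING_EXAMPLES):
#   PromptCreator("yaml").create_inference_prompt(
#       "Which keyword defines a function in Python?",
#       ["def", "func", "lambda", "fn"],
#   )
# returns a YAML-structured prompt whose "answer" field is constrained to the letters A through D.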

class QwenModelHandler:
    def __init__(self, model_path):
        with st.spinner("Loading model..."):
            try:
                self.tokenizer = AutoTokenizer.from_pretrained(
                    model_path,
                    trust_remote_code=True
                )

                # Load the base model with standard precision on CPU, then apply the fine-tuned PEFT adapter
                from peft import PeftModel

                base_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-Coder-1.5B-Instruct")
                self.model = PeftModel.from_pretrained(base_model, model_path)
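                # Earlier approach (kept for reference): load model_path directly with AutoModelForCausalLM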
                # self.model = AutoModelForCausalLM.from_pretrained(
                #     model_path,
                #     torch_dtype=torch.float32,
                #     device_map="cpu",
                #     trust_remote_code=True,
                #     # Explicitly disable quantization
                #     load_in_8bit=False,
                #     load_in_4bit=False
                # )
                
                if self.tokenizer.pad_token is None and self.tokenizer.eos_token is not None:
                    self.tokenizer.pad_token = self.tokenizer.eos_token
            except Exception as e:
                st.error(f"Error: {str(e)}")
                raise

    def generate_response(self, prompt, max_tokens=512, temperature=0.7, 
                          top_p=0.9, top_k=50, repetition_penalty=1.0, 
                          do_sample=True):
        try:
            inputs = self.tokenizer(prompt, return_tensors="pt")
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    repetition_penalty=repetition_penalty,
                    do_sample=do_sample,
                    pad_token_id=self.tokenizer.eos_token_id,
                )
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
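            # The decoded sequence includes the input prompt; keep only the newly generated continuation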
            if response.startswith(prompt):
                response = response[len(prompt):].strip()
            return response
        except Exception as e:
            return f"Error during generation: {str(e)}"

# Create prompt without requiring model
def create_prompt(question, choices):
    creator = PromptCreator(prompt_type="yaml")
    return creator.create_inference_prompt(question, choices)

def main():
    # Initialize session state
    if 'model_loaded' not in st.session_state:
        st.session_state.model_loaded = False
    if 'model_output' not in st.session_state:
        st.session_state.model_output = ""
    
    st.title("Coding Multiple Choice Q&A with YAML Reasoning")
    st.warning("⚠️ Running on CPU - model loading and inference will be slow")
    
    # Two-column layout
    col1, col2 = st.columns([4, 6])
    
    with col1:
        st.subheader("Examples")
        
        # Category selector
        category_options = ["All Categories"] + list(CODING_EXAMPLES_BY_CATEGORY.keys())
        selected_category = st.selectbox("Select a category", category_options)
        
        # Example selector
        if selected_category == "All Categories":
            example_options = [f"Example {i+1}: {ex['question']}" for i, ex in enumerate(CODING_EXAMPLES)]
        else:
            example_options = []
            start_idx = 0
            for cat, examples in CODING_EXAMPLES_BY_CATEGORY.items():
                if cat == selected_category:
                    example_options = [f"Example {start_idx+i+1}: {ex['question']}" for i, ex in enumerate(examples)]
                    break
                start_idx += len(examples)
        
        selected_example = st.selectbox("Select an example question", [""] + example_options)
        
        # Process selected example
        if selected_example:
            try:
                example_idx = int(selected_example.split(":")[0].split()[-1]) - 1
                example = CODING_EXAMPLES[example_idx]
                question = example["question"]
                choices = "\n".join(f"{chr(65+i)}. {choice}" for i, choice in enumerate(example["choices"]))
            except (ValueError, IndexError, KeyError):
                question = ""
                choices = ""
        else:
            question = ""
            choices = ""
        
        st.subheader("Your Question")
        question_input = st.text_area("Question", value=question, height=100, 
                                     placeholder="Enter your coding question here...")
        
        choices_input = st.text_area("Choices", value=choices, height=150,
                                    placeholder="Enter each choice on a new line...")
        
        # Model Parameters
        temperature = st.slider("Temperature", 0.0, 1.0, 0.7, 0.1)
        
        with st.expander("Advanced Parameters"):
            max_tokens = st.slider("Max Tokens", 128, 1024, 512, 128)
            top_p = st.slider("Top-p", 0.1, 1.0, 0.9, 0.1)
            top_k = st.slider("Top-k", 1, 100, 50, 10)
            repetition_penalty = st.slider("Repetition Penalty", 1.0, 2.0, 1.1, 0.1)
            do_sample = st.checkbox("Enable Sampling", True)
        
        # Load model button
        if not st.session_state.model_loaded:
            if st.button("Load Model", type="primary"):
                try:
                    st.session_state.model_handler = QwenModelHandler(MODEL_PATH)
                    st.session_state.prompt_creator = PromptCreator("yaml")
                    st.session_state.model_loaded = True
                    # st.experimental_rerun()
                    st.rerun()
                except Exception as e:
                    st.error(f"Failed to load model: {str(e)}")
        
        # Generate button
        if st.session_state.model_loaded:
            generate_button = st.button("Generate Response", type="primary")
        else:
            st.info("Please load the model first")
            generate_button = False
    
    with col2:
        # Show prompt
        st.subheader("Model Input")
        if question_input and choices_input:
            prompt = create_prompt(question_input, choices_input)
            st.text_area("Prompt", value=prompt, height=200, disabled=True)
        else:
            st.text_area("Prompt", value="", height=200, disabled=True)
        
        # Results Area
        st.subheader("Model Response")
        st.text_area("Response", value=st.session_state.model_output, height=300)
        
        # YAML parsing
        if st.session_state.model_output:
            with st.expander("Raw Output"):
                st.code(st.session_state.model_output, language="yaml")

            try:
                yaml_data = yaml.safe_load(st.session_state.model_output)
                with st.expander("Parsed Output", expanded=True):
                    st.json(yaml_data)
            except yaml.YAMLError:
                st.warning("Could not parse output as YAML")
    
    # Handle generation
    if generate_button and st.session_state.model_loaded:
        if not question_input or not choices_input:
            st.error("Please provide both a question and choices.")
        else:
            try:
                prompt = st.session_state.prompt_creator.create_inference_prompt(question_input, choices_input)
                with st.spinner("Generating response..."):
                    response = st.session_state.model_handler.generate_response(
                        prompt=prompt,
                        max_tokens=max_tokens,
                        temperature=temperature,
                        top_p=top_p,
                        top_k=top_k,
                        repetition_penalty=repetition_penalty,
                        do_sample=do_sample
                    )
                    st.session_state.model_output = response
                    st.rerun()
            except Exception as e:
                st.error(f"Error generating response: {e}")

if __name__ == "__main__":
    main()
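
# To launch locally (assuming this file is saved as app.py): streamlit run app.py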