Tonic committed
Commit 05b47fe · verified · 1 Parent(s): e52e0d0

Upload app.py with huggingface_hub
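The upload in this commit follows the standard huggingface_hub flow; a minimal sketch of how such a commit is typically produced (the repo_id and repo_type below are placeholders/assumptions, not taken from this page):

from huggingface_hub import HfApi

api = HfApi()  # authenticates via HF_TOKEN or a prior `huggingface-cli login`
api.upload_file(
    path_or_fileobj="app.py",          # local file to upload
    path_in_repo="app.py",             # destination path in the repo
    repo_id="Tonic/your-space-name",   # placeholder: the target repo id is not shown here
    repo_type="space",                 # assumption: a Gradio Space hosting app.py
    commit_message="Upload app.py with huggingface_hub",
)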

Files changed (1)
  1. app.py +272 -0
app.py ADDED
@@ -0,0 +1,272 @@
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, pipeline
+ import torch
+ from threading import Thread
+ import gradio as gr
+ import spaces
+ import re
+ import logging
+ import os
+ from peft import PeftModel
+
+ # Environment variables for GPT-OSS model configuration
+ os.environ['HF_MODEL_ID'] = 'Tonic/med-gpt-oss-20b'
+ os.environ['LORA_MODEL_ID'] = 'Tonic/med-gpt-oss-20b'
+ os.environ['BASE_MODEL_ID'] = 'openai/gpt-oss-20b'
+ os.environ['MODEL_SUBFOLDER'] = ''
+ os.environ['MODEL_NAME'] = 'med-gpt-oss-20b'
+
+
+
+ # ----------------------------------------------------------------------
+ # Environment Variables Configuration
+ # ----------------------------------------------------------------------
+
+ # Get model configuration from environment variables
+ BASE_MODEL_ID = os.getenv('BASE_MODEL_ID', 'openai/gpt-oss-20b')
+ LORA_MODEL_ID = os.getenv('LORA_MODEL_ID', os.getenv('HF_MODEL_ID', 'Tonic/gpt-oss-20b-multilingual-reasoner'))
+ MODEL_NAME = os.getenv('MODEL_NAME', 'GPT-OSS Multilingual Reasoner')
+ MODEL_SUBFOLDER = os.getenv('MODEL_SUBFOLDER', '')
+
+ # If the LORA_MODEL_ID is the same as BASE_MODEL_ID, this is a merged model, not LoRA
+ USE_LORA = LORA_MODEL_ID != BASE_MODEL_ID and not LORA_MODEL_ID.startswith(BASE_MODEL_ID)
+
+ print(f"🔧 Configuration:")
+ print(f" Base Model: {BASE_MODEL_ID}")
+ print(f" Model ID: {LORA_MODEL_ID}")
+ print(f" Model Name: {MODEL_NAME}")
+ print(f" Model Subfolder: {MODEL_SUBFOLDER}")
+ print(f" Use LoRA: {USE_LORA}")
+
+ # ----------------------------------------------------------------------
+ # KaTeX delimiter config for Gradio
+ # ----------------------------------------------------------------------
+
+ LATEX_DELIMS = [
+     {"left": "$$", "right": "$$", "display": True},
+     {"left": "$", "right": "$", "display": False},
+     {"left": "\\[", "right": "\\]", "display": True},
+     {"left": "\\(", "right": "\\)", "display": False},
+ ]
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+
+ # Load the model
+ try:
+     if USE_LORA:
+         # Load base model and LoRA adapter separately
+         print(f"🔄 Loading base model: {BASE_MODEL_ID}")
+         base_model = AutoModelForCausalLM.from_pretrained(
+             BASE_MODEL_ID,
+             torch_dtype="auto",
+             device_map="auto",
+             attn_implementation="kernels-community/vllm-flash-attn3"
+         )
+         tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
+
+         # Load the LoRA adapter
+         try:
+             print(f"🔄 Loading LoRA adapter: {LORA_MODEL_ID}")
+             if MODEL_SUBFOLDER and MODEL_SUBFOLDER.strip():
+                 model = PeftModel.from_pretrained(base_model, LORA_MODEL_ID, subfolder=MODEL_SUBFOLDER)
+             else:
+                 model = PeftModel.from_pretrained(base_model, LORA_MODEL_ID)
+             print("✅ LoRA model loaded successfully!")
+         except Exception as lora_error:
+             print(f"⚠️ LoRA adapter failed to load: {lora_error}")
+             print("🔄 Falling back to base model...")
+             model = base_model
+     else:
+         # Load merged/fine-tuned model directly
+         print(f"🔄 Loading merged model: {LORA_MODEL_ID}")
+         model_kwargs = {
+             "torch_dtype": "auto",
+             "device_map": "auto",
+             "attn_implementation": "kernels-community/vllm-flash-attn3"
+         }
+
+         if MODEL_SUBFOLDER and MODEL_SUBFOLDER.strip():
+             model = AutoModelForCausalLM.from_pretrained(LORA_MODEL_ID, subfolder=MODEL_SUBFOLDER, **model_kwargs)
+             tokenizer = AutoTokenizer.from_pretrained(LORA_MODEL_ID, subfolder=MODEL_SUBFOLDER)
+         else:
+             model = AutoModelForCausalLM.from_pretrained(LORA_MODEL_ID, **model_kwargs)
+             tokenizer = AutoTokenizer.from_pretrained(LORA_MODEL_ID)
+         print("✅ Merged model loaded successfully!")
+
+ except Exception as e:
+     print(f"❌ Error loading model: {e}")
+     raise e
+
+ def format_conversation_history(chat_history):
+     """Normalize Gradio 'messages'-style history into plain role/content dicts."""
+     messages = []
+     for item in chat_history:
+         role = item["role"]
+         content = item["content"]
+         if isinstance(content, list):
+             content = content[0]["text"] if content and "text" in content[0] else str(content)
+         messages.append({"role": role, "content": content})
+     return messages
+
+ def format_analysis_response(text):
+     """Enhanced response formatting with better structure and LaTeX support."""
+     # Look for analysis section followed by final response
+     m = re.search(r"analysis(.*?)assistantfinal", text, re.DOTALL | re.IGNORECASE)
+     if m:
+         reasoning = m.group(1).strip()
+         response = text.split("assistantfinal", 1)[-1].strip()
+
+         # Clean up the reasoning section
+         reasoning = re.sub(r'^analysis\s*', '', reasoning, flags=re.IGNORECASE).strip()
+
+         # Format with improved structure
+         formatted = (
+             f"**🤔 Analysis & Reasoning:**\n\n"
+             f"*{reasoning}*\n\n"
+             f"---\n\n"
+             f"**💬 Final Response:**\n\n{response}"
+         )
+
+         # Ensure LaTeX delimiters are balanced
+         if formatted.count("$") % 2:
+             formatted += "$"
+
+         return formatted
+
+     # Fallback: clean up the text and return as-is
+     cleaned = re.sub(r'^analysis\s*', '', text, flags=re.IGNORECASE).strip()
+     if cleaned.count("$") % 2:
+         cleaned += "$"
+     return cleaned
+
+ @spaces.GPU(duration=60)
+ def generate_response(input_data, chat_history, max_new_tokens, system_prompt, temperature, top_p, top_k, repetition_penalty):
+     if not input_data.strip():
+         yield "Please enter a prompt."
+         return
+
+     # Log the request
+     logging.info(f"[User] {input_data}")
+     logging.info(f"[System] {system_prompt} | Temp={temperature} | Max tokens={max_new_tokens}")
+
+     new_message = {"role": "user", "content": input_data}
+     system_message = [{"role": "system", "content": system_prompt}] if system_prompt else []
+     processed_history = format_conversation_history(chat_history)
+     messages = system_message + processed_history + [new_message]
+     prompt = tokenizer.apply_chat_template(
+         messages,
+         tokenize=False,
+         add_generation_prompt=True
+     )
+
+     # Create streamer for proper streaming
+     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+     # Prepare generation kwargs
+     generation_kwargs = {
+         "max_new_tokens": max_new_tokens,
+         "do_sample": True,
+         "temperature": temperature,
+         "top_p": top_p,
+         "top_k": top_k,
+         "repetition_penalty": repetition_penalty,
+         "pad_token_id": tokenizer.eos_token_id,
+         "streamer": streamer,
+         "use_cache": True
+     }
+
+     # Tokenize input using the chat template
+     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+     # Start generation in a separate thread
+     thread = Thread(target=model.generate, kwargs={**inputs, **generation_kwargs})
+     thread.start()
+
+     # Stream the response with enhanced formatting
+     collected_text = ""
+     buffer = ""
+     yielded_once = False
+
+     try:
+         for chunk in streamer:
+             if not chunk:
+                 continue
+
+             collected_text += chunk
+             buffer += chunk
+
+             # Initial yield to show immediate response
+             if not yielded_once:
+                 yield chunk
+                 buffer = ""
+                 yielded_once = True
+                 continue
+
+             # Yield accumulated text periodically for smooth streaming
+             if "\n" in buffer or len(buffer) > 150:
+                 # Use enhanced formatting for partial text
+                 partial_formatted = format_analysis_response(collected_text)
+                 yield partial_formatted
+                 buffer = ""
+
+         # Final formatting with complete text
+         final_formatted = format_analysis_response(collected_text)
+         yield final_formatted
+
+     except Exception as e:
+         logging.exception("Generation streaming failed")
+         yield f"❌ Error during generation: {e}"
+
+ demo = gr.ChatInterface(
+     fn=generate_response,
+     additional_inputs=[
+         gr.Slider(label="Max new tokens", minimum=64, maximum=4096, step=1, value=2048),
+         gr.Textbox(
+             label="System Prompt",
+             value="You are a helpful assistant. Reasoning: medium",
+             lines=4,
+             placeholder="Change system prompt"
+         ),
+         gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.7),
+         gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
+         gr.Slider(label="Top-k", minimum=1, maximum=100, step=1, value=50),
+         gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.0)
+     ],
+     examples=[
+         [{"text": "Explain Newton's laws clearly and concisely with mathematical formulas"}],
+         [{"text": "Write a Python function to calculate the Fibonacci sequence"}],
+         [{"text": "What are the benefits of open weight AI models? Include analysis."}],
+         [{"text": "Solve this equation: $x^2 + 5x + 6 = 0$"}],
+     ],
+     cache_examples=False,
+     type="messages",
+     description=f"""
+
+ # 🙋🏻‍♂️Welcome to 🌟{MODEL_NAME} Demo !
+
+ **Model**: `{LORA_MODEL_ID}`
+ **Base**: `{BASE_MODEL_ID}`
+
+ ✨ **Enhanced Features:**
+ - 🧠 **Advanced Reasoning**: Detailed analysis and step-by-step thinking
+ - 📊 **LaTeX Support**: Mathematical formulas rendered beautifully (use `$` or `$$`)
+ - 🎯 **Improved Formatting**: Clear separation of reasoning and final responses
+ - 📝 **Smart Logging**: Better error handling and request tracking
+
+ 💡 **Usage Tips:**
+ - Adjust reasoning level in system prompt (e.g., "Reasoning: high")
+ - Use LaTeX for math: `$E = mc^2$` or `$$\\int x^2 dx$$`
+ - Wait a couple of seconds initially for model loading
+ """,
+     fill_height=True,
+     textbox=gr.Textbox(
+         label="Query Input",
+         placeholder="Type your prompt (supports LaTeX: $x^2 + y^2 = z^2$)"
+     ),
+     stop_btn="Stop Generation",
+     multimodal=False,
+     theme=gr.themes.Soft()
+ )
+
+ if __name__ == "__main__":
+     demo.launch(share=True)
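
Outside the Space UI, the generator defined above can also be exercised directly once the model has loaded; a minimal smoke test, assuming a local GPU with enough memory for gpt-oss-20b and that the @spaces.GPU decorator is a no-op outside a ZeroGPU Space (values mirror the Gradio defaults; none of this is part of the uploaded app.py):

# Hypothetical local check of generate_response, not part of the commit.
history = []  # empty chat history in Gradio "messages" format
last = ""
for partial in generate_response(
    "State the quadratic formula.",
    history,
    max_new_tokens=256,
    system_prompt="You are a helpful assistant. Reasoning: medium",
    temperature=0.7,
    top_p=0.9,
    top_k=50,
    repetition_penalty=1.0,
):
    last = partial  # each yield is the formatted text streamed so far
print(last)  # final text with the "Analysis & Reasoning" / "Final Response" split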