ethanwinters1907 committed on
Commit
5cb50ed
·
verified ·
1 Parent(s): 6e23460

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -267
app.py CHANGED
@@ -1,283 +1,35 @@
1
- from flask import Flask, request, jsonify, render_template_string
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
4
- import os
5
 
6
  app = Flask(__name__)
7
 
8
- # Load model and tokenizer with proper configuration
9
- model_name = "openai/gpt-oss-20b"
10
- print("Loading model and tokenizer...")
11
-
12
  tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")
13
  model = AutoModelForCausalLM.from_pretrained("openai/gpt-oss-20b")
14
 
15
- torch_dtype="auto",
16
- device_map="auto"
17
- )
18
 
19
- print("Model loaded successfully!")
 
 
 
20
 
21
- # HTML template
22
- HTML_TEMPLATE = """
23
- <!DOCTYPE html>
24
- <html>
25
- <head>
26
- <title>OpenAI GPT-OSS-20B Chat</title>
27
- <meta name="viewport" content="width=device-width, initial-scale=1.0">
28
- <style>
29
- body {
30
- font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
31
- max-width: 800px;
32
- margin: 0 auto;
33
- padding: 20px;
34
- background-color: #f5f5f5;
35
- }
36
- .container {
37
- background-color: white;
38
- border-radius: 8px;
39
- padding: 20px;
40
- box-shadow: 0 2px 10px rgba(0,0,0,0.1);
41
- }
42
- h1 {
43
- text-align: center;
44
- color: #333;
45
- margin-bottom: 30px;
46
- }
47
- #chat-container {
48
- border: 1px solid #ddd;
49
- height: 400px;
50
- overflow-y: auto;
51
- padding: 15px;
52
- margin-bottom: 15px;
53
- background-color: #fafafa;
54
- border-radius: 6px;
55
- }
56
- .message {
57
- margin: 12px 0;
58
- padding: 10px 15px;
59
- border-radius: 8px;
60
- max-width: 80%;
61
- word-wrap: break-word;
62
- }
63
- .user {
64
- background-color: #007bff;
65
- color: white;
66
- margin-left: auto;
67
- text-align: right;
68
- }
69
- .assistant {
70
- background-color: #e9ecef;
71
- color: #333;
72
- margin-right: auto;
73
- }
74
- #input-container {
75
- display: flex;
76
- gap: 10px;
77
- align-items: center;
78
- }
79
- #message-input {
80
- flex: 1;
81
- padding: 12px;
82
- border: 2px solid #ddd;
83
- border-radius: 6px;
84
- font-size: 14px;
85
- }
86
- #message-input:focus {
87
- outline: none;
88
- border-color: #007bff;
89
- }
90
- #send-button {
91
- padding: 12px 20px;
92
- background-color: #007bff;
93
- color: white;
94
- border: none;
95
- cursor: pointer;
96
- border-radius: 6px;
97
- font-size: 14px;
98
- font-weight: 500;
99
- }
100
- #send-button:hover:not(:disabled) {
101
- background-color: #0056b3;
102
- }
103
- #send-button:disabled {
104
- background-color: #ccc;
105
- cursor: not-allowed;
106
- }
107
- #loading {
108
- display: none;
109
- text-align: center;
110
- color: #666;
111
- margin: 10px 0;
112
- font-style: italic;
113
- }
114
- .error {
115
- color: #d32f2f;
116
- }
117
- .typing-indicator {
118
- display: none;
119
- margin: 12px 0;
120
- padding: 10px 15px;
121
- background-color: #e9ecef;
122
- border-radius: 8px;
123
- max-width: 80%;
124
- }
125
- .typing-dots {
126
- display: inline-block;
127
- }
128
- .typing-dots span {
129
- display: inline-block;
130
- width: 8px;
131
- height: 8px;
132
- border-radius: 50%;
133
- background-color: #999;
134
- margin: 0 2px;
135
- animation: typing 1.4s infinite both;
136
- }
137
- .typing-dots span:nth-child(2) { animation-delay: 0.2s; }
138
- .typing-dots span:nth-child(3) { animation-delay: 0.4s; }
139
- @keyframes typing {
140
- 0%, 60%, 100% { transform: translateY(0); }
141
- 30% { transform: translateY(-10px); }
142
- }
143
- </style>
144
- </head>
145
- <body>
146
- <div class="container">
147
- <h1>🤖 OpenAI GPT-OSS-20B Chat</h1>
148
- <div id="chat-container">
149
- <div class="message assistant">
150
- <strong>Assistant:</strong> Hello! I'm GPT-OSS-20B. How can I help you today?
151
- </div>
152
- </div>
153
- <div class="typing-indicator" id="typing-indicator">
154
- <strong>Assistant:</strong> <div class="typing-dots"><span></span><span></span><span></span></div>
155
- </div>
156
- <div id="loading">Generating response...</div>
157
- <div id="input-container">
158
- <input type="text" id="message-input" placeholder="Type your message here..." onkeypress="if(event.key==='Enter') sendMessage()">
159
- <button id="send-button" onclick="sendMessage()">Send</button>
160
- </div>
161
- </div>
162
 
163
- <script>
164
- let chatHistory = [];
165
-
166
- function addMessage(role, content, isError = false) {
167
- const chatContainer = document.getElementById('chat-container');
168
- const messageDiv = document.createElement('div');
169
- messageDiv.className = `message ${role}`;
170
- if (isError) messageDiv.classList.add('error');
171
- messageDiv.innerHTML = `<strong>${role === 'user' ? 'You' : 'Assistant'}:</strong> ${content}`;
172
- chatContainer.appendChild(messageDiv);
173
- chatContainer.scrollTop = chatContainer.scrollHeight;
174
- }
175
-
176
- async function sendMessage() {
177
- const input = document.getElementById('message-input');
178
- const sendButton = document.getElementById('send-button');
179
- const typingIndicator = document.getElementById('typing-indicator');
180
- const message = input.value.trim();
181
-
182
- if (!message) return;
183
-
184
- addMessage('user', message);
185
- input.value = '';
186
- sendButton.disabled = true;
187
-
188
- // Show typing indicator
189
- typingIndicator.style.display = 'block';
190
- const chatContainer = document.getElementById('chat-container');
191
- chatContainer.scrollTop = chatContainer.scrollHeight;
192
-
193
- try {
194
- const response = await fetch('/chat', {
195
- method: 'POST',
196
- headers: { 'Content-Type': 'application/json' },
197
- body: JSON.stringify({ message: message, history: chatHistory })
198
- });
199
-
200
- const data = await response.json();
201
-
202
- if (data.error) {
203
- addMessage('assistant', `Error: ${data.error}`, true);
204
- } else {
205
- addMessage('assistant', data.response);
206
- chatHistory.push([message, data.response]);
207
- }
208
- } catch (error) {
209
- addMessage('assistant', `Network Error: ${error.message}`, true);
210
- } finally {
211
- typingIndicator.style.display = 'none';
212
- sendButton.disabled = false;
213
- input.focus();
214
- }
215
- }
216
-
217
- // Focus input on load
218
- document.addEventListener('DOMContentLoaded', function() {
219
- document.getElementById('message-input').focus();
220
- });
221
- </script>
222
- </body>
223
- </html>
224
- """
225
 
226
- @app.route('/')
227
- def home():
228
- return render_template_string(HTML_TEMPLATE)
229
 
230
- @app.route('/chat', methods=['POST'])
231
- def chat():
232
- try:
233
- data = request.json
234
- message = data.get('message', '')
235
- history = data.get('history', [])
236
-
237
- # Format messages
238
- messages = []
239
- for human_msg, assistant_msg in history:
240
- messages.append({"role": "user", "content": human_msg})
241
- messages.append({"role": "assistant", "content": assistant_msg})
242
- messages.append({"role": "user", "content": message})
243
-
244
- # Apply chat template
245
- inputs = tokenizer.apply_chat_template(
246
- messages,
247
- add_generation_prompt=True,
248
- return_tensors="pt",
249
- return_dict=True,
250
- ).to(model.device)
251
-
252
- # Generate response
253
- with torch.no_grad():
254
- outputs = model.generate(
255
- **inputs,
256
- max_new_tokens=300,
257
- temperature=0.7,
258
- do_sample=True,
259
- pad_token_id=tokenizer.eos_token_id
260
- )
261
-
262
- # Decode response
263
- response = tokenizer.decode(
264
- outputs[0][inputs["input_ids"].shape[-1]:],
265
- skip_special_tokens=True
266
- )
267
-
268
- return jsonify({"response": response.strip()})
269
-
270
- except Exception as e:
271
- print(f"Error: {str(e)}")
272
- return jsonify({"error": str(e)}), 500
273
 
274
- @app.route('/health')
275
- def health():
276
- return jsonify({
277
- "status": "healthy",
278
- "model": "openai/gpt-oss-20b"
279
- })
280
 
281
  if __name__ == '__main__':
282
- port = int(os.environ.get("PORT", 7860))
283
- app.run(host='0.0.0.0', port=port, debug=False)
 
1
from flask import Flask, request, jsonify
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

app = Flask(__name__)

# Load the tokenizer and model a single time at startup so every request
# reuses the same in-memory weights instead of re-reading them from disk.
tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")
model = AutoModelForCausalLM.from_pretrained("openai/gpt-oss-20b")

# Run on a CUDA device when one is present; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)
14
 
15
@app.route('/generate', methods=['POST'])
def generate_text():
    """Generate a text completion for the JSON field ``prompt``.

    Expects a body like ``{"prompt": "..."}`` and returns
    ``{"generated_text": "..."}``, or ``{"error": "..."}`` with HTTP 400
    when the body is missing, not valid JSON, or has no prompt.
    """
    # silent=True makes get_json() return None instead of raising on a
    # missing/invalid JSON body, so bad requests get a 400, not a 500.
    data = request.get_json(silent=True)
    if not data:
        return jsonify({'error': 'No prompt provided'}), 400

    prompt = data.get('prompt')
    if not prompt:
        return jsonify({'error': 'No prompt provided'}), 400

    # Tokenize the prompt and move the tensors to the model's device.
    inputs = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # max_new_tokens bounds only the *generated* length. The previous
    # max_length=50 counted the prompt as well, so prompts of ~50+ tokens
    # produced empty or truncated completions. no_grad() skips autograd
    # bookkeeping during inference; pad_token_id silences the generate()
    # warning for models without a dedicated pad token.
    with torch.no_grad():
        outputs = model.generate(
            inputs,
            max_new_tokens=50,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            pad_token_id=tokenizer.eos_token_id,
        )

    # NOTE(review): decoding outputs[0] in full echoes the prompt back as
    # part of generated_text (original behavior, kept); slice off
    # inputs.shape[-1] tokens if only the completion is wanted.
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

    return jsonify({'generated_text': generated_text})
 
 
 
 
 
33
 
34
if __name__ == '__main__':
    import os

    # Bind to all interfaces and honour the platform-assigned port
    # (Hugging Face Spaces injects PORT; 7860 is its default). debug=True
    # must not be exposed: the Werkzeug debugger permits arbitrary remote
    # code execution.
    port = int(os.environ.get("PORT", 7860))
    app.run(host='0.0.0.0', port=port, debug=False)