wassemgtk committed on
Commit
dd19186
Β·
verified Β·
1 Parent(s): 47395b1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +181 -124
app.py CHANGED
@@ -1,145 +1,202 @@
 
 
1
  import os
2
  import json
3
- import requests
4
- import gradio as gr
5
 
6
- FIREWORKS_URL = "https://api.fireworks.ai/inference/v1/chat/completions"
7
- MODEL_ID = os.getenv("FIREWORKS_MODEL_ID", "accounts/waseem-9b447b/models/ft-gdixl08u-sz53t")
 
8
 
9
- # Secrets (server-side only; never sent to the client UI)
10
- FIREWORKS_API_KEY = os.getenv("FIREWORKS_API_KEY") # required
11
- SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT")
12
 
13
- if not FIREWORKS_API_KEY:
14
- raise RuntimeError("Missing FIREWORKS_API_KEY environment variable")
15
-
16
- def _fireworks_stream(payload):
17
- """Generator that streams tokens from Fireworks chat completions SSE response."""
 
 
 
 
 
 
 
 
 
 
 
18
  headers = {
19
  "Accept": "application/json",
20
  "Content-Type": "application/json",
21
- "Authorization": f"Bearer {FIREWORKS_API_KEY}",
22
  }
23
- payload = dict(payload)
24
- payload["stream"] = True
25
- with requests.post(FIREWORKS_URL, headers=headers, json=payload, stream=True) as r:
26
- r.raise_for_status()
27
- buffer = ""
28
- for line in r.iter_lines(decode_unicode=True):
29
- if not line:
30
- continue
31
- if line.startswith("data:"):
32
- data = line[len("data:"):].strip()
33
- if data == "[DONE]":
34
- break
35
- try:
36
- obj = json.loads(data)
37
- except json.JSONDecodeError:
38
- buffer += data
39
- try:
40
- obj = json.loads(buffer)
41
- buffer = ""
42
- except Exception:
43
- continue
44
- try:
45
- delta = obj["choices"][0]["delta"]
46
- if "content" in delta and delta["content"]:
47
- yield delta["content"]
48
- except Exception:
49
- continue
50
-
51
- def _normalize_history_to_messages(history):
52
- """Normalize history from Gradio into OpenAI-style messages without system prompt."""
53
- # Chatbot(type='messages') already gives a list of dicts: [{'role': 'user'|'assistant', 'content': '...'}, ...]
54
- if not history:
55
- return []
56
- if isinstance(history, list) and len(history) > 0 and isinstance(history[0], dict) and "role" in history[0]:
57
- # Already messages format; pass through (filter any roles other than user/assistant)
58
- return [m for m in history if m.get("role") in ("user", "assistant")]
59
- # Back-compat: history may be list of (user, assistant) tuples
60
- msgs = []
61
- for u, a in history:
62
- if u:
63
- msgs.append({"role": "user", "content": u})
64
- if a:
65
- msgs.append({"role": "assistant", "content": a})
66
- return msgs
67
-
68
- def _build_messages(history, user_message):
69
- messages = []
70
- if SYSTEM_PROMPT:
71
- messages.append({"role": "system", "content": SYSTEM_PROMPT})
72
- messages.extend(_normalize_history_to_messages(history))
73
- if user_message:
74
- messages.append({"role": "user", "content": user_message})
75
- return messages
76
-
77
- def chat_fn(user_message, history, max_tokens, temperature, top_p, top_k, presence_penalty, frequency_penalty):
78
  payload = {
79
- "model": MODEL_ID,
80
- "max_tokens": int(max_tokens),
81
- "temperature": float(temperature),
82
- "top_p": float(top_p),
83
- "top_k": int(top_k),
84
- "presence_penalty": float(presence_penalty),
85
- "frequency_penalty": float(frequency_penalty),
86
- "messages": _build_messages(history, user_message),
87
  }
88
- for token in _fireworks_stream(payload):
89
- yield token
 
 
 
 
 
 
 
 
 
 
 
90
 
91
- def clear_history():
92
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
- with gr.Blocks(theme=gr.themes.Soft(), css="""
95
- :root { --radius: 16px; }
96
- #title { font-weight: 800; letter-spacing: -0.02em; }
97
- div.controls { gap: 10px !important; }
98
- """) as demo:
99
- gr.HTML("""
100
- <div style="display:flex; align-items:center; gap:12px; margin: 6px 0 16px;">
101
- <svg width="28" height="28" viewBox="0 0 24 24" fill="none"><path d="M12 3l7 4v6c0 5-7 8-7 8s-7-3-7-8V7l7-4z" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/></svg>
102
- <div>
103
- <div id="title" style="font-size:1.25rem;">Fireworks Chat Playground</div>
104
- <div style="opacity:0.7; font-size:0.95rem;">Secure, streamed chat to <code>inference/v1/chat/completions</code></div>
105
- </div>
106
- </div>
107
- """)
108
  with gr.Row():
109
  with gr.Column(scale=3):
110
- # Use messages format to avoid deprecation
111
- chatbot = gr.Chatbot(height=480, type="messages", avatar_images=(None, None))
112
- with gr.Row(elem_classes=["controls"]):
113
- max_tokens = gr.Slider(32, 8192, value=4000, step=16, label="Max tokens")
114
- temperature = gr.Slider(0.0, 2.0, value=0.6, step=0.05, label="Temperature")
115
- with gr.Column(scale=2):
116
- with gr.Group():
117
- top_p = gr.Slider(0.0, 1.0, value=1.0, step=0.01, label="top_p")
118
- top_k = gr.Slider(0, 200, value=40, step=1, label="top_k")
119
- presence_penalty = gr.Slider(-2.0, 2.0, value=0.0, step=0.05, label="presence_penalty")
120
- frequency_penalty = gr.Slider(-2.0, 2.0, value=0.0, step=0.05, label="frequency_penalty")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  gr.Markdown("""
122
- **Security notes**
123
- - API key and system prompt are server-side environment variables.
124
- - Change the model id with `FIREWORKS_MODEL_ID` (env var).
 
 
125
  """)
126
- clear_btn = gr.Button("Clear", variant="secondary")
127
- chat = gr.ChatInterface(
128
- fn=chat_fn,
129
- chatbot=chatbot,
130
- additional_inputs=[max_tokens, temperature, top_p, top_k, presence_penalty, frequency_penalty],
131
- title=None,
132
- submit_btn="Send",
133
 
134
- examples=[
135
- ["Hello!", 4000, 0.6, 1.0, 40, 0.0, 0.0],
136
- ["Summarize: Why is retrieval-augmented generation useful for insurers?", 4000, 0.6, 1.0, 40, 0.0, 0.0],
137
- ["Write a 3-bullet status update for the Palmyra team.", 4000, 0.6, 1.0, 40, 0.0, 0.0]
138
- ],
139
-
140
- description="Start chatting below. Streaming is enabled."
 
 
 
 
 
 
141
  )
142
- clear_btn.click(fn=clear_history, outputs=chatbot)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
 
144
  if __name__ == "__main__":
145
- demo.queue().launch(server_name="0.0.0.0")
 
1
+ import gradio as gr
2
+ import requests
3
  import os
4
  import json
 
 
5
 
6
# These will be set as Hugging Face Spaces secrets
# Fireworks API key; empty string when the secret is unset (checked in respond()).
API_KEY = os.environ.get("FIREWORKS_API_KEY", "")
# Server-side system prompt; hidden from the UI, falls back to a generic persona.
SYSTEM_PROMPT = os.environ.get("SYSTEM_PROMPT", "You are a helpful AI assistant.")

# API endpoint
API_URL = "https://api.fireworks.ai/inference/v1/chat/completions"
 
12
 
13
def chat_with_model(message, history, temperature, max_tokens, top_p, top_k):
    """Send one chat turn to the Fireworks AI API and return the reply text.

    Args:
        message: The user's new message.
        history: Prior turns as (user, assistant) pairs (Gradio tuples format).
        temperature: Sampling temperature forwarded to the API.
        max_tokens: Maximum tokens in the response (coerced to int).
        top_p: Nucleus-sampling threshold forwarded to the API.
        top_k: Top-k sampling cutoff (coerced to int).

    Returns:
        The assistant's reply, or a human-readable error string on failure.
    """
    # System prompt is injected server-side; it never appears in the client UI.
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]

    # Replay prior turns. Skip falsy entries: an incomplete turn (e.g. a None
    # assistant slot) would otherwise put a null "content" in the payload.
    for human, assistant in history:
        if human:
            messages.append({"role": "user", "content": human})
        if assistant:
            messages.append({"role": "assistant", "content": assistant})

    # Add current message
    messages.append({"role": "user", "content": message})

    headers = {
        "Accept": "application/json",
        "Content-Type": "application/json",
        "Authorization": f"Bearer {API_KEY}"
    }

    payload = {
        "model": "accounts/waseem-9b447b/models/ft-gdixl08u-sz53t",
        # Sliders can deliver floats; the API expects integer counts here.
        "max_tokens": int(max_tokens),
        "top_p": top_p,
        "top_k": int(top_k),
        "presence_penalty": 0,
        "frequency_penalty": 0,
        "temperature": temperature,
        "messages": messages
    }

    try:
        # Bounded timeout so a stuck request cannot hang the UI worker forever.
        response = requests.post(API_URL, headers=headers, json=payload, timeout=60)
        response.raise_for_status()

        result = response.json()
        assistant_message = result["choices"][0]["message"]["content"]
        return assistant_message

    except requests.exceptions.RequestException as e:
        return f"❌ Error: {str(e)}\n\nPlease check your API key in Hugging Face Spaces secrets."
    except (KeyError, IndexError) as e:
        return f"❌ Error parsing response: {str(e)}"
58
 
59
# Custom CSS for a modern look
# Applied via gr.Blocks(css=custom_css); #title / #description match the
# gr.HTML elements created in the Blocks layout below.
custom_css = """
.gradio-container {
    font-family: 'Inter', sans-serif;
}
#title {
    text-align: center;
    background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    font-size: 2.5em;
    font-weight: bold;
    margin-bottom: 0.5em;
}
#description {
    text-align: center;
    font-size: 1.1em;
    color: #666;
    margin-bottom: 2em;
}
.message-wrap {
    border-radius: 12px !important;
}
"""
83
 
84
# Create Gradio interface
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    gr.HTML("<h1 id='title'>🚀 AI Model Playground</h1>")
    gr.HTML("<p id='description'>Powered by Fireworks AI - Fine-tuned Model</p>")

    with gr.Row():
        with gr.Column(scale=3):
            # NOTE(review): bubble_full_width is deprecated in newer Gradio
            # releases — confirm against the pinned gradio version.
            chatbot = gr.Chatbot(
                height=500,
                bubble_full_width=False,
                avatar_images=(None, "https://api.dicebear.com/7.x/bottts/svg?seed=ai"),
                show_copy_button=True
            )

            with gr.Row():
                msg = gr.Textbox(
                    placeholder="Type your message here...",
                    show_label=False,
                    scale=4,
                    container=False
                )
                submit_btn = gr.Button("Send 📤", scale=1, variant="primary")

            with gr.Row():
                # ClearButton resets both the textbox and the chat history.
                clear_btn = gr.ClearButton([msg, chatbot], value="Clear Chat 🗑️")

        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Model Parameters")

            temperature = gr.Slider(
                minimum=0,
                maximum=2,
                value=0.6,
                step=0.1,
                label="Temperature",
                info="Controls randomness. Lower = more focused"
            )

            max_tokens = gr.Slider(
                minimum=100,
                maximum=4000,
                value=2000,
                step=100,
                label="Max Tokens",
                info="Maximum length of response"
            )

            top_p = gr.Slider(
                minimum=0,
                maximum=1,
                value=1,
                step=0.05,
                label="Top P",
                info="Nucleus sampling threshold"
            )

            top_k = gr.Slider(
                minimum=1,
                maximum=100,
                value=40,
                step=1,
                label="Top K",
                info="Number of top tokens to consider"
            )

            gr.Markdown("---")
            gr.Markdown("### 📝 Setup Instructions")
            gr.Markdown("""
            1. Go to your Space **Settings**
            2. Add these secrets:
               - `FIREWORKS_API_KEY`: Your API key
               - `SYSTEM_PROMPT`: Custom system prompt
            3. Restart the Space
            """)

    # Handle message submission
    def respond(message, chat_history, temp, max_tok, top_p_val, top_k_val):
        """Append the user turn + model reply to history; clear the textbox."""
        # Ignore blank submissions so an empty Send doesn't burn an API call
        # or append an empty bubble to the chat.
        if not message or not message.strip():
            return "", chat_history
        if not API_KEY:
            bot_message = "⚠️ Please set FIREWORKS_API_KEY in Hugging Face Spaces secrets!"
        else:
            bot_message = chat_with_model(message, chat_history, temp, max_tok, top_p_val, top_k_val)

        chat_history.append((message, bot_message))
        # First output clears the textbox; second updates the chatbot.
        return "", chat_history

    # Enter key and Send button share the same handler and wiring.
    msg.submit(
        respond,
        [msg, chatbot, temperature, max_tokens, top_p, top_k],
        [msg, chatbot]
    )

    submit_btn.click(
        respond,
        [msg, chatbot, temperature, max_tokens, top_p, top_k],
        [msg, chatbot]
    )

    # Add examples
    gr.Examples(
        examples=[
            ["Hello! Can you introduce yourself?"],
            ["What can you help me with?"],
            ["Tell me an interesting fact about AI."],
        ],
        inputs=msg,
        label="💡 Try these examples"
    )

    gr.Markdown("""
    ---
    ### 🔒 Privacy & Security
    - Your API key is stored securely in Hugging Face Spaces secrets
    - System prompt is hidden from users
    - All conversations are private to your session
    """)
199
 
200
# Launch the app
# Entry-point guard: only start the server when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()