Rajan Sharma committed on
Commit
2bdb6e6
·
verified ·
1 Parent(s): e9ea6c6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -190
app.py CHANGED
@@ -6,15 +6,15 @@ from functools import lru_cache
6
  import gradio as gr
7
  import torch
8
 
9
- # Timezone conversion (Python 3.9+ stdlib)
10
  try:
11
  from zoneinfo import ZoneInfo
12
  except Exception:
13
- ZoneInfo = None # graceful fallback to UTC
14
 
15
- # Try Cohere SDK if present (for hosted path)
16
  try:
17
- import cohere # pip install cohere
18
  _HAS_COHERE = True
19
  except Exception:
20
  _HAS_COHERE = False
@@ -23,124 +23,84 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
23
  from huggingface_hub import login, HfApi
24
 
25
  # -------------------
26
- # Configuration
27
  # -------------------
28
  MODEL_ID = os.getenv("MODEL_ID", "CohereLabs/c4ai-command-r7b-12-2024")
29
-
30
- HF_TOKEN = (
31
- os.getenv("HUGGINGFACE_HUB_TOKEN") # official Spaces name
32
- or os.getenv("HF_TOKEN")
33
- )
34
-
35
  COHERE_API_KEY = os.getenv("COHERE_API_KEY")
36
  USE_HOSTED_COHERE = bool(COHERE_API_KEY and _HAS_COHERE)
37
 
38
  # -------------------
39
- # Helpers (status only)
40
  # -------------------
41
- def local_now_str(user_tz: str | None) -> tuple[str, str]:
42
- """Returns (label, formatted_time). Falls back to UTC if tz missing/invalid."""
43
- label = "UTC"
44
- dt = datetime.now(timezone.utc)
45
- if user_tz and ZoneInfo is not None:
46
- try:
47
- tz = ZoneInfo(user_tz)
48
- dt = datetime.now(tz)
49
- label = user_tz
50
- except Exception:
51
- dt = datetime.now(timezone.utc)
52
- label = "UTC"
53
- return label, dt.strftime("%Y-%m-%d %H:%M:%S")
54
-
55
-
56
  def pick_dtype_and_map():
57
  if torch.cuda.is_available():
58
  return torch.float16, "auto"
59
  if torch.backends.mps.is_available():
60
  return torch.float16, {"": "mps"}
61
- return torch.float32, "cpu" # CPU path (likely too big for R7B)
62
 
63
def is_identity_query(message: str, history) -> bool:
    """Return True when the user asks about the assistant's identity.

    Checks the current *message* and the most recent user turn of
    *history*. Accepts both Gradio history formats:
      - pair format:     [(user_text, assistant_text), ...]
      - messages format: [{"role": ..., "content": ...}, ...]
    The UI builds gr.ChatInterface(type="messages"), which delivers the
    dict format — the original tuple-only extraction silently never
    inspected the last user turn at runtime.
    """
    patterns = [
        r"\bwho\s+are\s+you\b",
        r"\bwhat\s+are\s+you\b",
        r"\bwhat\s+is\s+your\s+name\b",
        r"\bwho\s+is\s+this\b",
        r"\bidentify\s+yourself\b",
        r"\btell\s+me\s+about\s+yourself\b",
        r"\bdescribe\s+yourself\b",
        # Original used r"\band\s+you\s*\?\b": \b can never match between
        # "?" and end-of-string, so "and you?" was silently missed.
        r"\band\s+you\s*\?",
        r"\byour\s+name\b",
        r"\bwho\s+am\s+i\s+chatting\s+with\b",
    ]

    def hit(text) -> bool:
        # Non-string content (e.g. multimodal parts) can never match.
        if not isinstance(text, str):
            return False
        t = text.strip().lower()
        return any(re.search(p, t) for p in patterns)

    if hit(message):
        return True
    if not history:
        return False
    # Pull the latest user utterance from whichever history format we got.
    last_user = None
    last = history[-1]
    if isinstance(last, dict):
        # messages format: walk backwards to the last user-role entry.
        for turn in reversed(history):
            if isinstance(turn, dict) and turn.get("role") == "user":
                last_user = turn.get("content")
                break
    elif isinstance(last, (list, tuple)) and last:
        last_user = last[0]
    return hit(last_user)
87
 
88
  # -------------------
89
- # Cohere Hosted Path
90
  # -------------------
91
  _co_client = None
92
  if USE_HOSTED_COHERE:
93
  _co_client = cohere.Client(api_key=COHERE_API_KEY)
94
 
95
  def _cohere_parse(resp):
96
- # v5+ responses.create
97
  if hasattr(resp, "output_text") and resp.output_text:
98
  return resp.output_text.strip()
99
  if getattr(resp, "message", None) and getattr(resp.message, "content", None):
100
  for p in resp.message.content:
101
  if hasattr(p, "text") and p.text:
102
  return p.text.strip()
103
- # v4 chat
104
  if hasattr(resp, "text") and resp.text:
105
  return resp.text.strip()
106
  return "Sorry, I couldn't parse the response from Cohere."
107
 
108
def cohere_chat(message, history):
    """Answer *message* via Cohere's hosted API.

    Tries the modern `responses.create` endpoint first, then falls back
    to the legacy v4 `chat` API. Any failure is returned as an error
    string rather than raised, so the UI shows it as the reply.
    """
    try:
        try:
            # Modern API: replay the whole conversation as role messages.
            conversation = []
            for user_turn, assistant_turn in (history or []):
                conversation.append({"role": "user", "content": user_turn})
                conversation.append({"role": "assistant", "content": assistant_turn})
            conversation.append({"role": "user", "content": message})
            resp = _co_client.responses.create(
                model="command-r7b-12-2024",
                messages=conversation,
                temperature=0.3,
                max_tokens=350,
            )
        except Exception:
            # Legacy v4 chat API takes only the latest message.
            resp = _co_client.chat(
                model="command-r7b-12-2024",
                message=message,
                temperature=0.3,
                max_tokens=350,
            )
        return _cohere_parse(resp)
    except Exception as e:
        return f"Error calling Cohere API: {e}"
134
 
135
  # -------------------
136
- # Local HF Path
137
  # -------------------
138
  @lru_cache(maxsize=1)
139
  def load_local_model():
140
  if not HF_TOKEN:
141
  raise RuntimeError(
142
- "HUGGINGFACE_HUB_TOKEN (or HF_TOKEN) is not set. "
143
- "Either set it, or provide COHERE_API_KEY to use Cohere's hosted API."
144
  )
145
  login(token=HF_TOKEN, add_to_git_credential=False)
146
  dtype, device_map = pick_dtype_and_map()
@@ -188,11 +148,10 @@ def local_generate(model, tokenizer, input_ids, max_new_tokens=350):
188
  eos_token_id=tokenizer.eos_token_id,
189
  )
190
  gen_only = out[0, input_ids.shape[-1]:]
191
- text = tokenizer.decode(gen_only, skip_special_tokens=True)
192
- return text.strip()
193
 
194
  # -------------------
195
- # Chat callback (no meta in replies)
196
  # -------------------
197
  def chat_fn(message, history, user_tz):
198
  try:
@@ -203,97 +162,39 @@ def chat_fn(message, history, user_tz):
203
  model, tokenizer = load_local_model()
204
  inputs = build_inputs(tokenizer, message, history)
205
  return local_generate(model, tokenizer, inputs, max_new_tokens=350)
206
- except RuntimeError as e:
207
- emsg = str(e)
208
- if "out of memory" in emsg.lower() or "cuda" in emsg.lower():
209
- return "Local load likely OOM. Use a GPU Space or set COHERE_API_KEY to run via Cohere hosted API."
210
- return f"Error during chat: {e}"
211
  except Exception as e:
212
- return f"Error during chat: {e}"
213
 
214
  # -------------------
215
- # Theme & Styles (compatible with broad Gradio versions)
216
  # -------------------
217
  theme = gr.themes.Soft(
218
  primary_hue="teal",
219
  neutral_hue="slate",
220
  radius_size=gr.themes.sizes.radius_lg,
221
- ).set(
222
- shadow_drop="0 6px 24px rgba(0,0,0,.06)",
223
- shadow_spread="0 2px 8px rgba(0,0,0,.04)",
224
  )
225
 
226
  custom_css = """
227
  :root {
228
- --brand-bg: #e6f7f8; /* soft medical teal */
229
- --brand-card: #ffffff;
230
- --brand-text: #0f172a; /* slate-900 */
231
- --brand-subtle: #475569; /* slate-600 */
232
- --brand-accent: #0d9488; /* teal-600 */
233
- --brand-border: #cbd5e1; /* slate-300 */
234
  }
235
 
236
- /* Page background */
237
  .gradio-container {
238
  background: var(--brand-bg);
239
- color: var(--brand-text);
240
  }
241
 
242
- /* Title */
243
- h1, .prose h1 {
244
- color: var(--brand-text);
245
  font-weight: 700;
246
- letter-spacing: -0.01em;
247
- margin-bottom: 0.25rem !important;
248
- font-size: 28px !important; /* set via CSS for compatibility */
249
- }
250
-
251
- /* Chat bubbles */
252
- .message.user {
253
- background: var(--brand-accent) !important; /* teal bubble */
254
- color: #ffffff !important; /* white text */
255
- }
256
- .message.bot {
257
- background: var(--brand-card) !important; /* white bubble */
258
- color: var(--brand-text) !important; /* dark text */
259
- }
260
-
261
- /* Status badge wrapper */
262
- .status-wrap {
263
- display: flex;
264
- align-items: center;
265
- gap: .5rem;
266
- margin-bottom: 0.75rem;
267
  }
268
 
269
- /* Badge */
270
- .badge {
271
- display: inline-flex;
272
- align-items: center;
273
- gap: .5rem;
274
- padding: .45rem .75rem;
275
- border-radius: 999px;
276
- border: 1px solid var(--brand-border);
277
- background: #ecfdf5; /* green-50 */
278
- color: #065f46; /* green-800 */
279
- font-weight: 600;
280
- font-size: 14px;
281
- }
282
-
283
- /* Helper text */
284
- .helper {
285
- color: var(--brand-subtle);
286
- margin: .25rem 0 1rem 0;
287
- }
288
-
289
- /* Card rounding */
290
- .block, .gr-box, .gr-panel, .gr-group, .gr-form, .gradio-container .form {
291
- border-radius: 16px !important;
292
- }
293
-
294
- /* Inputs */
295
- textarea, input, .gr-input {
296
- border-radius: 12px !important;
297
  }
298
  """
299
 
@@ -301,57 +202,16 @@ textarea, input, .gr-input {
301
  # UI
302
  # -------------------
303
  with gr.Blocks(theme=theme, css=custom_css) as demo:
304
- # Hidden textbox to hold browser timezone
305
  tz_box = gr.Textbox(visible=False)
 
 
306
 
307
- # Capture browser timezone via JS and store in tz_box
308
- demo.load(
309
- fn=lambda tz: tz, # echo JS value
310
- inputs=[tz_box],
311
- outputs=[tz_box],
312
- js="() => Intl.DateTimeFormat().resolvedOptions().timeZone"
313
- )
314
-
315
- # Model status (auto, one-line badge)
316
def model_status(_user_tz):
    """Build the one-line HTML status badge shown under the page title.

    Reports the active backend (Cohere hosted vs. local HF model); on
    any failure an amber warning badge carrying the error text is
    returned instead. The *_user_tz* argument is accepted for wiring
    compatibility but unused.
    """
    try:
        if USE_HOSTED_COHERE:
            return (
                '<div class="status-wrap">'
                '<span class="badge">✅ Connected • Cohere API — model: '
                '<strong>command-r7b-12-2024</strong></span></div>'
            )
        hub = HfApi(token=HF_TOKEN)
        info = hub.model_info(MODEL_ID)
        return (
            '<div class="status-wrap">'
            f'<span class="badge">✅ Connected • Local HF — model: '
            f'<strong>{info.modelId}</strong></span></div>'
        )
    except Exception as e:
        return (
            '<div class="status-wrap">'
            f'<span class="badge" style="background:#fff7ed;color:#9a3412;border-color:#fed7aa;">'
            f'⚠️ Connection Issue — {str(e)}</span></div>'
        )
-
338
- # Header + status
339
  gr.Markdown("# Medical Decision Support AI")
340
- status_line = gr.HTML("<div class='status-wrap'><span class='badge'>Connecting…</span></div>")
341
- demo.load(fn=model_status, inputs=[tz_box], outputs=[status_line])
342
-
343
- # Helper text
344
- gr.Markdown(
345
- "<div class='helper'>Designed for healthcare executives: concise, reliable decision support. "
346
- "First response may take a moment while the model warms up.</div>"
347
- )
348
 
349
- # Chat
350
  gr.ChatInterface(
351
  fn=chat_fn,
352
  type="messages",
353
- additional_inputs=[tz_box], # pass timezone into chat_fn (future use)
354
- description="",
355
  examples=[
356
  ["What are the symptoms of hypertension?", ""],
357
  ["What are common drug interactions with aspirin?", ""],
 
6
  import gradio as gr
7
  import torch
8
 
9
+ # Timezone (Python 3.9+)
10
  try:
11
  from zoneinfo import ZoneInfo
12
  except Exception:
13
+ ZoneInfo = None
14
 
15
+ # Cohere SDK
16
  try:
17
+ import cohere
18
  _HAS_COHERE = True
19
  except Exception:
20
  _HAS_COHERE = False
 
23
  from huggingface_hub import login, HfApi
24
 
25
  # -------------------
26
+ # Config
27
  # -------------------
28
  MODEL_ID = os.getenv("MODEL_ID", "CohereLabs/c4ai-command-r7b-12-2024")
29
+ HF_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN") or os.getenv("HF_TOKEN")
 
 
 
 
 
30
  COHERE_API_KEY = os.getenv("COHERE_API_KEY")
31
  USE_HOSTED_COHERE = bool(COHERE_API_KEY and _HAS_COHERE)
32
 
33
  # -------------------
34
+ # Helpers
35
  # -------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  def pick_dtype_and_map():
37
  if torch.cuda.is_available():
38
  return torch.float16, "auto"
39
  if torch.backends.mps.is_available():
40
  return torch.float16, {"": "mps"}
41
+ return torch.float32, "cpu"
42
 
43
def is_identity_query(message, history):
    """Detect identity questions ("who are you?") in *message* or in the
    most recent user turn of *history*.

    Handles both Gradio history formats:
      - pair format:     [(user_text, assistant_text), ...]
      - messages format: [{"role": ..., "content": ...}, ...]
    The UI uses gr.ChatInterface(type="messages"), so the dict format is
    what actually arrives; the tuple-only version also crashed with
    IndexError on an empty pair (history[-1][0] without a guard).
    """
    patterns = [
        r"\bwho\s+are\s+you\b", r"\bwhat\s+are\s+you\b",
        r"\bwhat\s+is\s+your\s+name\b", r"\bwho\s+is\s+this\b",
        r"\bidentify\s+yourself\b", r"\btell\s+me\s+about\s+yourself\b",
        # Trailing \b after "?" can never match at end-of-string, so the
        # original r"\band\s+you\s*\?\b" missed "and you?"; dropped it.
        r"\bdescribe\s+yourself\b", r"\band\s+you\s*\?",
        r"\byour\s+name\b", r"\bwho\s+am\s+i\s+chatting\s+with\b"
    ]
    def match(t):
        # Non-string content (multimodal parts, None) never matches.
        if not isinstance(t, str):
            return False
        return any(re.search(p, t.strip().lower()) for p in patterns)
    if match(message):
        return True
    if not history:
        return False
    # Extract the latest user utterance from either history format.
    last_user = None
    last = history[-1]
    if isinstance(last, dict):
        for turn in reversed(history):
            if isinstance(turn, dict) and turn.get("role") == "user":
                last_user = turn.get("content")
                break
    elif isinstance(last, (list, tuple)) and last:
        last_user = last[0]
    return match(last_user)
60
 
61
  # -------------------
62
+ # Cohere Hosted
63
  # -------------------
64
  _co_client = None
65
  if USE_HOSTED_COHERE:
66
  _co_client = cohere.Client(api_key=COHERE_API_KEY)
67
 
68
  def _cohere_parse(resp):
 
69
  if hasattr(resp, "output_text") and resp.output_text:
70
  return resp.output_text.strip()
71
  if getattr(resp, "message", None) and getattr(resp.message, "content", None):
72
  for p in resp.message.content:
73
  if hasattr(p, "text") and p.text:
74
  return p.text.strip()
 
75
  if hasattr(resp, "text") and resp.text:
76
  return resp.text.strip()
77
  return "Sorry, I couldn't parse the response from Cohere."
78
 
79
def cohere_chat(message, history):
    """Send the conversation to Cohere's hosted responses API.

    The full (user, assistant) history is replayed as role messages
    followed by the current *message*. Failures are returned as display
    strings instead of being raised, so the UI shows them as the reply.
    """
    try:
        conversation = [
            entry
            for user_turn, assistant_turn in (history or [])
            for entry in (
                {"role": "user", "content": user_turn},
                {"role": "assistant", "content": assistant_turn},
            )
        ]
        conversation.append({"role": "user", "content": message})
        resp = _co_client.responses.create(
            model="command-r7b-12-2024",
            messages=conversation,
            temperature=0.3,
            max_tokens=350,
        )
        return _cohere_parse(resp)
    except Exception as e:
        return f"Error calling Cohere API: {e}"
95
 
96
  # -------------------
97
+ # Local HF Model
98
  # -------------------
99
  @lru_cache(maxsize=1)
100
  def load_local_model():
101
  if not HF_TOKEN:
102
  raise RuntimeError(
103
+ "HUGGINGFACE_HUB_TOKEN is not set."
 
104
  )
105
  login(token=HF_TOKEN, add_to_git_credential=False)
106
  dtype, device_map = pick_dtype_and_map()
 
148
  eos_token_id=tokenizer.eos_token_id,
149
  )
150
  gen_only = out[0, input_ids.shape[-1]:]
151
+ return tokenizer.decode(gen_only, skip_special_tokens=True).strip()
 
152
 
153
  # -------------------
154
+ # Chat Function
155
  # -------------------
156
  def chat_fn(message, history, user_tz):
157
  try:
 
162
  model, tokenizer = load_local_model()
163
  inputs = build_inputs(tokenizer, message, history)
164
  return local_generate(model, tokenizer, inputs, max_new_tokens=350)
 
 
 
 
 
165
  except Exception as e:
166
+ return f"Error: {e}"
167
 
168
  # -------------------
169
+ # Theme & CSS
170
  # -------------------
171
  theme = gr.themes.Soft(
172
  primary_hue="teal",
173
  neutral_hue="slate",
174
  radius_size=gr.themes.sizes.radius_lg,
 
 
 
175
  )
176
 
177
  custom_css = """
178
  :root {
179
+ --brand-bg: #e6f7f8; /* soft medical teal */
180
+ --brand-accent: #0d9488; /* teal-600 */
181
+ --brand-text-light: #ffffff;
 
 
 
182
  }
183
 
 
184
  .gradio-container {
185
  background: var(--brand-bg);
 
186
  }
187
 
188
+ h1 {
189
+ color: #0f172a;
 
190
  font-weight: 700;
191
+ font-size: 28px !important;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
  }
193
 
194
+ /* Both bot and user bubbles teal with white text */
195
+ .message.user, .message.bot {
196
+ background: var(--brand-accent) !important;
197
+ color: var(--brand-text-light) !important;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  }
199
  """
200
 
 
202
  # UI
203
  # -------------------
204
  with gr.Blocks(theme=theme, css=custom_css) as demo:
 
205
  tz_box = gr.Textbox(visible=False)
206
+ demo.load(lambda tz: tz, inputs=[tz_box], outputs=[tz_box],
207
+ js="() => Intl.DateTimeFormat().resolvedOptions().timeZone")
208
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
  gr.Markdown("# Medical Decision Support AI")
 
 
 
 
 
 
 
 
210
 
 
211
  gr.ChatInterface(
212
  fn=chat_fn,
213
  type="messages",
214
+ additional_inputs=[tz_box],
 
215
  examples=[
216
  ["What are the symptoms of hypertension?", ""],
217
  ["What are common drug interactions with aspirin?", ""],