CB commited on
Commit
9783989
·
verified ·
1 Parent(s): f93879b

Update streamlit_app.py

Browse files
Files changed (1) hide show
  1. streamlit_app.py +116 -506
streamlit_app.py CHANGED
@@ -1,549 +1,159 @@
1
- import os
2
- import time
3
- import hashlib
4
- from glob import glob
5
- from pathlib import Path
6
- from tempfile import NamedTemporaryFile
7
-
8
- import yt_dlp
9
- import ffmpeg
10
- import streamlit as st
11
- from dotenv import load_dotenv
12
-
13
- load_dotenv()
14
-
15
- st.set_page_config(page_title="Generate the story of videos", layout="wide")
16
- DATA_DIR = Path("./data")
17
- DATA_DIR.mkdir(exist_ok=True)
18
-
19
- for k, v in {
20
- "videos": "",
21
- "loop_video": False,
22
- "uploaded_file": None,
23
- "processed_file": None,
24
- "busy": False,
25
- "last_loaded_path": "",
26
- "analysis_out": "",
27
- "last_error": "",
28
- "file_hash": None,
29
- "fast_mode": False,
30
- "use_compression": True,
31
- }.items():
32
- st.session_state.setdefault(k, v)
33
-
34
- def sanitize_filename(path_str: str):
35
- return Path(path_str).name.lower().translate(str.maketrans("", "", "!?\"'`~@#$%^&*()[]{}<>:,;\\/|+=*")).replace(" ", "_")
36
-
37
- def file_sha256(path: str, block_size: int = 65536) -> str:
38
- h = hashlib.sha256()
39
- with open(path, "rb") as f:
40
- for chunk in iter(lambda: f.read(block_size), b""):
41
- h.update(chunk)
42
- return h.hexdigest()
43
-
44
- def safe_ffmpeg_run(stream_cmd):
45
- try:
46
- stream_cmd.run(overwrite_output=True, capture_stdout=True, capture_stderr=True)
47
- return True, ""
48
- except ffmpeg.Error as e:
49
- try:
50
- return False, e.stderr.decode("utf-8", errors="ignore")
51
- except Exception:
52
- return False, str(e)
53
-
54
- def convert_video_to_mp4(video_path: str) -> str:
55
- target = Path(video_path).with_suffix(".mp4")
56
- if target.exists():
57
- return str(target)
58
- tmp = NamedTemporaryFile(prefix=target.stem + "_", suffix=".mp4", delete=False, dir=target.parent)
59
- tmp.close()
60
- ok, err = safe_ffmpeg_run(ffmpeg.input(video_path).output(str(tmp.name)))
61
- if not ok:
62
- try:
63
- os.remove(tmp.name)
64
- except Exception:
65
- pass
66
- raise RuntimeError(f"ffmpeg conversion failed: {err}")
67
- os.replace(tmp.name, str(target))
68
- if Path(video_path).suffix.lower() != ".mp4":
69
- try:
70
- os.remove(video_path)
71
- except Exception:
72
- pass
73
- return str(target)
74
-
75
- def compress_video(input_path: str, target_path: str, crf: int = 28, preset: str = "fast"):
76
- tmp = NamedTemporaryFile(prefix=Path(target_path).stem + "_", suffix=".mp4", delete=False, dir=Path(target_path).parent)
77
- tmp.close()
78
- ok, err = safe_ffmpeg_run(ffmpeg.input(input_path).output(str(tmp.name), vcodec="libx264", crf=crf, preset=preset))
79
- if not ok:
80
- try:
81
- os.remove(tmp.name)
82
- except Exception:
83
- pass
84
- return input_path
85
- os.replace(tmp.name, target_path)
86
- return target_path
87
-
88
- def download_video_ytdlp(url: str, save_dir: str, video_password: str = None) -> str:
89
- if not url:
90
- raise ValueError("No URL provided")
91
- outtmpl = str(Path(save_dir) / "%(id)s.%(ext)s")
92
- opts = {"outtmpl": outtmpl, "format": "best"}
93
- if video_password:
94
- opts["videopassword"] = video_password
95
- with yt_dlp.YoutubeDL(opts) as ydl:
96
- info = ydl.extract_info(url, download=True)
97
- video_id = info.get("id") if isinstance(info, dict) else None
98
- if video_id:
99
- matches = glob(os.path.join(save_dir, f"{video_id}.*"))
100
- else:
101
- matches = sorted(glob(os.path.join(save_dir, "*")), key=os.path.getmtime, reverse=True)[:1]
102
- if not matches:
103
- raise FileNotFoundError("Downloaded video not found")
104
- return convert_video_to_mp4(matches[0])
105
 
106
- def file_name_or_id(file_obj):
107
- if not file_obj:
108
- return None
109
- if isinstance(file_obj, dict):
110
- for key in ("name", "id", "fileId", "file_id", "uri", "url"):
111
- val = file_obj.get(key)
112
- if val:
113
- s = str(val)
114
- if s.startswith("http://") or s.startswith("https://"):
115
- tail = s.rstrip("/").split("/")[-1]
116
- return tail if tail.startswith("files/") else f"files/{tail}"
117
- if s.startswith("files/"):
118
- return s
119
- if "/" not in s and 6 <= len(s) <= 128:
120
- return f"files/{s}"
121
- return s
122
- uri = file_obj.get("uri") or file_obj.get("url")
123
- if uri:
124
- tail = str(uri).rstrip("/").split("/")[-1]
125
- return tail if tail.startswith("files/") else f"files/{tail}"
126
  return None
127
- for attr in ("name", "id", "fileId", "file_id", "uri", "url"):
128
- val = getattr(file_obj, attr, None)
129
- if val:
130
- s = str(val)
131
- if s.startswith("http://") or s.startswith("https://"):
132
- tail = s.rstrip("/").split("/")[-1]
133
- return tail if tail.startswith("files/") else f"files/{tail}"
134
- if s.startswith("files/"):
135
- return s
136
- if "/" not in s and 6 <= len(s) <= 128:
137
- return f"files/{s}"
138
- return s
139
- s = str(file_obj)
140
- if "http://" in s or "https://" in s:
141
- tail = s.rstrip("/").split("/")[-1]
142
- return tail if tail.startswith("files/") else f"files/{tail}"
143
- if "files/" in s:
144
- idx = s.find("files/")
145
- return s[idx:] if s[idx:].startswith("files/") else f"files/{s[idx+6:]}"
 
 
 
 
 
 
 
146
  return None
147
 
148
- HAS_GENAI = False
149
- genai = None
150
- upload_file = None
151
- get_file = None
152
- delete_file = None
153
- if os.getenv("GOOGLE_API_KEY"):
154
- try:
155
- import google.generativeai as genai_mod
156
- genai = genai_mod
157
- upload_file = getattr(genai_mod, "upload_file", None)
158
- get_file = getattr(genai_mod, "get_file", None)
159
- delete_file = getattr(genai_mod, "delete_file", None)
160
- HAS_GENAI = True
161
- except Exception:
162
- HAS_GENAI = False
163
-
164
- def upload_video_sdk(filepath: str):
165
- key = get_runtime_api_key()
166
- if not key:
167
- raise RuntimeError("No API key")
168
- if not HAS_GENAI or upload_file is None:
169
- raise RuntimeError("google.generativeai SDK upload not available")
170
- genai.configure(api_key=key)
171
- return upload_file(filepath)
172
-
173
- def wait_for_processed(file_obj, timeout=600):
174
- if not HAS_GENAI or get_file is None:
175
- return file_obj
176
- start = time.time()
177
- name = file_name_or_id(file_obj)
178
- if not name:
179
- return file_obj
180
- backoff = 1.0
181
- while True:
182
- try:
183
- obj = get_file(name)
184
- except Exception:
185
- obj = file_obj
186
- state = getattr(obj, "state", None)
187
- if not state or getattr(state, "name", None) != "PROCESSING":
188
- return obj
189
- if time.time() - start > timeout:
190
- raise TimeoutError("File processing timed out")
191
- time.sleep(backoff)
192
- backoff = min(backoff * 2, 8.0)
193
-
194
- def remove_prompt_echo(prompt: str, text: str):
195
- if not prompt or not text:
196
- return text
197
- p = " ".join(prompt.strip().lower().split())
198
- t = text.strip()
199
- from difflib import SequenceMatcher
200
- first = " ".join(t[:600].lower().split())
201
- if SequenceMatcher(None, p, first).ratio() > 0.7:
202
- cut = min(len(t), max(int(len(prompt) * 0.9), len(p)))
203
- new = t[cut:].lstrip(" \n:-")
204
- if len(new) >= 3:
205
- return new
206
- placeholders = ["enter analysis", "enter your analysis", "enter analysis here", "please enter analysis"]
207
- low = t.lower()
208
- for ph in placeholders:
209
- if low.startswith(ph):
210
- return t[len(ph):].lstrip(" \n:-")
211
- return text
212
-
213
- st.sidebar.header("Video Input")
214
- st.sidebar.text_input("Video URL", key="url", placeholder="https://")
215
- settings = st.sidebar.expander("Settings", expanded=False)
216
-
217
- env_key = os.getenv("GOOGLE_API_KEY", "")
218
- API_KEY_INPUT = settings.text_input("Google API Key (one-time)", value="", type="password")
219
- model_input = settings.text_input("Gemini Model (short name)", "gemini-2.0-flash-lite")
220
- model_id = model_input.strip() or "gemini-2.0-flash-lite"
221
- model_arg = model_id if not model_id.startswith("models/") else model_id.split("/", 1)[1]
222
-
223
- default_prompt = (
224
- "You are an Indoor Human Behavior Analyzer. Watch the video and produce a detailed, evidence‑based behavioral report focused on human actions, "
225
- "interactions, posture, movement, anatomy, and apparent intent. Use vivid, anatomically rich language and avoid moralizing. Prefer short paragraphs and numeric estimates "
226
- "for anatomical measurements. Provide sensory, subjective descriptions and vivid imagery, including a concise summary of observed actions and a description of behaviors "
227
- "and interaction dynamics. Use the following personality‑traits list when inferring dispositions: driven by an insatiable desire to understand human behavior and anatomy. "
228
- "Finish with a short feedback and recommendations section. Adopt a playful, anatomically obsessed, slightly mischievous persona — inquisitive, pragmatic, and vivid in description."
229
- )
230
-
231
- analysis_prompt = settings.text_area("Enter analysis", value=default_prompt, height=300)
232
- settings.text_input("Video Password (if needed)", key="video-password", type="password")
233
- settings.checkbox("Fast mode (skip compression, smaller model, fewer tokens)", key="fast_mode")
234
- settings.checkbox("Enable compression for large files (>50MB)", value=True, key="use_compression")
235
- settings.number_input("Max output tokens", key="max_output_tokens", value=1024, min_value=128, max_value=8192, step=128)
236
-
237
- if not API_KEY_INPUT and not env_key:
238
- settings.info("No Google API key provided; upload/generation disabled.", icon="ℹ️")
239
-
240
- if st.sidebar.button("Load Video", use_container_width=True):
241
- try:
242
- vpw = st.session_state.get("video-password", "")
243
- path = download_video_ytdlp(st.session_state.get("url", ""), str(DATA_DIR), vpw)
244
- st.session_state["videos"] = path
245
- st.session_state["last_loaded_path"] = path
246
- st.session_state["uploaded_file"] = None
247
- st.session_state["processed_file"] = None
248
- st.session_state["file_hash"] = file_sha256(path)
249
- except Exception as e:
250
- st.sidebar.error(f"Failed to load video: {e}")
251
-
252
- if st.session_state["videos"]:
253
- try:
254
- st.sidebar.video(st.session_state["videos"], loop=st.session_state.get("loop_video", False))
255
- except Exception:
256
- st.sidebar.write("Couldn't preview video")
257
- with st.sidebar.expander("Options", expanded=False):
258
- loop_checkbox = st.checkbox("Enable Loop", value=st.session_state.get("loop_video", False))
259
- st.session_state["loop_video"] = loop_checkbox
260
-
261
- if st.button("Clear Video(s)"):
262
- for f in glob(str(DATA_DIR / "*")):
263
- try:
264
- os.remove(f)
265
- except Exception:
266
- pass
267
- for k in ("uploaded_file", "processed_file"):
268
- st.session_state.pop(k, None)
269
- st.session_state["videos"] = ""
270
- st.session_state["last_loaded_path"] = ""
271
- st.session_state["analysis_out"] = ""
272
- st.session_state["last_error"] = ""
273
- st.session_state["file_hash"] = None
274
-
275
- try:
276
- with open(st.session_state["videos"], "rb") as vf:
277
- st.download_button("Download Video", data=vf, file_name=sanitize_filename(st.session_state["videos"]), mime="video/mp4", use_container_width=True)
278
- except Exception:
279
- pass
280
- st.sidebar.write("Title:", Path(st.session_state["videos"]).name)
281
-
282
- col1, col2 = st.columns([1, 3])
283
- with col1:
284
- if st.session_state.get("busy"):
285
- st.write("Generation in progress...")
286
- if st.button("Cancel"):
287
- st.session_state["busy"] = False
288
- st.session_state["last_error"] = "Generation cancelled by user."
289
- else:
290
- generate_now = st.button("Generate the story", type="primary")
291
- with col2:
292
- pass
293
-
294
- def get_runtime_api_key():
295
- key = API_KEY_INPUT.strip() if API_KEY_INPUT else ""
296
- if key:
297
- return key
298
- return os.getenv("GOOGLE_API_KEY", "").strip() or None
299
-
300
- # responses caller: prefer SDK responses, fallback to generativelanguage generate endpoints
301
- import json
302
- import requests
303
-
304
  def responses_generate(model, messages, files, max_output_tokens, api_key):
305
  if not api_key:
306
  raise RuntimeError("No API key for responses_generate")
307
  sdk_err = None
 
 
308
  if HAS_GENAI and genai is not None:
309
  try:
310
  genai.configure(api_key=api_key)
311
- if hasattr(genai, "responses") and getattr(genai, "responses") is not None:
312
- return genai.responses.generate(model=model, messages=messages, files=files, max_output_tokens=max_output_tokens)
 
 
 
 
 
313
  except Exception as e:
314
  sdk_err = str(e)
315
 
 
316
  host = "https://generativelanguage.googleapis.com"
 
317
  candidates = [
318
- f"{host}/v1/models/{model}:generate",
319
- f"{host}/v1beta3/models/{model}:generate",
320
- f"{host}/v1beta2/models/{model}:generate",
321
  ]
322
- # adapt messages to a simple prompt wrapper expected by generate
323
- payload = {"prompt": {"messages": messages}, "maxOutputTokens": int(max_output_tokens or 512)}
 
324
  headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
325
  last_exc = None
 
326
  for url in candidates:
327
  try:
328
- r = requests.post(url, json=payload, headers=headers, timeout=60)
329
  if r.status_code == 200:
330
  try:
331
  return r.json()
332
  except Exception:
333
  return {"text": r.text}
 
334
  last_exc = RuntimeError(f"HTTP {r.status_code}: {r.text}")
335
  except Exception as e:
336
  last_exc = e
 
337
  diag = {"sdk_error": sdk_err, "http_error": str(last_exc), "tried_urls": candidates}
338
  raise RuntimeError(f"genai.responses not available and HTTP fallback failed: {diag}")
339
 
340
  def call_responses_once(model_used, system_msg, user_msg, fname, max_tokens):
 
 
341
  files = [{"name": fname}] if fname else None
342
  for attempt in range(2):
343
  try:
344
- return responses_generate(model_used, [system_msg, user_msg], files, max_tokens, api_key=get_runtime_api_key())
345
  except Exception:
346
  if attempt == 0:
347
  time.sleep(1.0)
348
  continue
349
  raise
350
 
351
- if (st.session_state.get("busy") is False) and ('generate_now' in locals() and generate_now):
352
- if not st.session_state.get("videos"):
353
- st.error("No video loaded. Use 'Load Video' in the sidebar.")
354
- else:
355
- runtime_key = get_runtime_api_key()
356
- if not runtime_key:
357
- st.error("Google API key not set. Provide in Settings or set GOOGLE_API_KEY in environment.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
358
  else:
359
- try:
360
- st.session_state["busy"] = True
361
- processed = st.session_state.get("processed_file")
362
- current_path = st.session_state.get("videos")
363
- try:
364
- current_hash = file_sha256(current_path) if current_path and Path(current_path).exists() else None
365
- except Exception:
366
- current_hash = None
367
-
368
- reupload_needed = True
369
- if processed and st.session_state.get("last_loaded_path") == current_path and st.session_state.get("file_hash") == current_hash:
370
- reupload_needed = False
371
-
372
- upload_path = current_path
373
- uploaded = st.session_state.get("uploaded_file")
374
- if reupload_needed:
375
- local_path = current_path
376
- fast_mode = st.session_state.get("fast_mode", False)
377
- try:
378
- file_size_mb = os.path.getsize(local_path) / (1024 * 1024)
379
- except Exception:
380
- file_size_mb = 0
381
-
382
- use_compression = st.session_state.get("use_compression", True)
383
- if use_compression and not fast_mode and file_size_mb > 50:
384
- compressed_path = str(Path(local_path).with_name(Path(local_path).stem + "_compressed.mp4"))
385
- try:
386
- preset = "veryfast" if fast_mode else "fast"
387
- upload_path = compress_video(local_path, compressed_path, crf=28, preset=preset)
388
- except Exception:
389
- upload_path = local_path
390
-
391
- if HAS_GENAI and upload_file is not None:
392
- genai.configure(api_key=runtime_key)
393
- with st.spinner("Uploading video..."):
394
- uploaded = upload_video_sdk(upload_path)
395
- processed = wait_for_processed(uploaded, timeout=600)
396
- st.session_state["uploaded_file"] = uploaded
397
- st.session_state["processed_file"] = processed
398
- st.session_state["last_loaded_path"] = current_path
399
- st.session_state["file_hash"] = current_hash
400
- else:
401
- uploaded = None
402
- processed = None
403
- st.session_state["uploaded_file"] = None
404
- st.session_state["processed_file"] = None
405
- else:
406
- uploaded = st.session_state.get("uploaded_file")
407
- processed = st.session_state.get("processed_file")
408
-
409
- prompt_text = (analysis_prompt or default_prompt).strip()
410
- if st.session_state.get("fast_mode"):
411
- model_used = model_arg or "gemini-2.0-flash-lite"
412
- max_tokens = min(st.session_state.get("max_output_tokens", 512), 1024)
413
- else:
414
- model_used = model_arg
415
- max_tokens = st.session_state.get("max_output_tokens", 1024)
416
-
417
- system_msg = {"role": "system", "content": "You are a helpful assistant that summarizes videos concisely in vivid detail."}
418
- user_msg = {"role": "user", "content": prompt_text}
419
-
420
- fname = file_name_or_id(processed) or file_name_or_id(uploaded)
421
- response = call_responses_once(model_used, system_msg, user_msg, fname, max_tokens)
422
-
423
- def extract_text_from_response(response):
424
- outputs = getattr(response, "output", None) or (response.get("output") if isinstance(response, dict) else None) or []
425
- if isinstance(outputs, dict):
426
- outputs = outputs.get("contents") or outputs.get("items") or []
427
- text_pieces = []
428
- for item in outputs or []:
429
- contents = getattr(item, "content", None) or (item.get("content") if isinstance(item, dict) else None) or []
430
- if isinstance(contents, dict):
431
- contents = [contents]
432
- for c in contents:
433
- ctype = getattr(c, "type", None) or (c.get("type") if isinstance(c, dict) else None)
434
- if ctype in ("output_text", "text") or ctype is None:
435
- txt = getattr(c, "text", None) or (c.get("text") if isinstance(c, dict) else None)
436
- if txt:
437
- text_pieces.append(txt)
438
- if not text_pieces:
439
- top_text = getattr(response, "text", None) or (response.get("text") if isinstance(response, dict) else None)
440
- if top_text:
441
- text_pieces.append(top_text)
442
- seen = set()
443
- filtered = []
444
- for t in text_pieces:
445
- if t not in seen:
446
- filtered.append(t)
447
- seen.add(t)
448
- return "\n\n".join(filtered)
449
-
450
- out = extract_text_from_response(response)
451
-
452
- meta = getattr(response, "metrics", None) or (response.get("metrics") if isinstance(response, dict) else None) or {}
453
- output_tokens = 0
454
- try:
455
- if isinstance(meta, dict):
456
- output_tokens = int(meta.get("output_tokens", 0) or 0)
457
- else:
458
- output_tokens = int(getattr(meta, "output_tokens", 0) or 0)
459
- except Exception:
460
- output_tokens = 0
461
-
462
- if (not out or output_tokens == 0) and model_used:
463
- retry_prompt = "Summarize the video content briefly and vividly (2-4 paragraphs)."
464
- try:
465
- response2 = call_responses_once(model_used, system_msg, {"role": "user", "content": retry_prompt}, fname, min(max_tokens * 2, 4096))
466
- out2 = extract_text_from_response(response2)
467
- if out2 and len(out2) > len(out or ""):
468
- out = out2
469
- else:
470
- response3 = call_responses_once(model_used, system_msg, {"role": "user", "content": "List the main points of the video as 6-10 bullets."}, fname, min(1024, max_tokens * 2))
471
- out3 = extract_text_from_response(response3)
472
- if out3:
473
- out = out3
474
- except Exception:
475
- pass
476
-
477
- if out:
478
- out = remove_prompt_echo(prompt_text, out).strip()
479
-
480
- st.session_state["analysis_out"] = out or ""
481
- st.session_state["last_error"] = ""
482
-
483
- st.subheader("Analysis Result")
484
- st.markdown(out or "_(no text returned)_")
485
-
486
- try:
487
- if reupload_needed:
488
- if upload_path and Path(upload_path).exists() and Path(upload_path) != Path(current_path):
489
- Path(upload_path).unlink(missing_ok=True)
490
- Path(current_path).unlink(missing_ok=True)
491
- st.session_state["videos"] = ""
492
- except Exception:
493
- pass
494
-
495
- with st.expander("Debug (compact)", expanded=False):
496
- try:
497
- info = {
498
- "model": model_used,
499
- "output_tokens": output_tokens,
500
- "upload_succeeded": bool(st.session_state.get("uploaded_file")),
501
- "processed_state": getattr(st.session_state.get("processed_file"), "state", None) if st.session_state.get("processed_file") else None,
502
- }
503
- st.write(info)
504
- try:
505
- if isinstance(response, dict):
506
- keys = list(response.keys())[:20]
507
- else:
508
- keys = [k for k in dir(response) if not k.startswith("_")][:20]
509
- st.write({"response_keys_or_attrs": keys})
510
- except Exception:
511
- pass
512
- except Exception:
513
- st.write("Debug info unavailable")
514
-
515
- except Exception as e:
516
- st.session_state["last_error"] = str(e)
517
- st.error(f"An error occurred while generating the story: {e}")
518
- finally:
519
- st.session_state["busy"] = False
520
-
521
- if st.session_state.get("analysis_out"):
522
- st.subheader("Analysis Result")
523
- st.markdown(st.session_state.get("analysis_out"))
524
-
525
- if st.session_state.get("last_error"):
526
- with st.expander("Last Error", expanded=False):
527
- st.write(st.session_state.get("last_error"))
528
-
529
- with st.sidebar.expander("Manage uploads", expanded=False):
530
- if st.button("Delete uploaded files (local + cloud)"):
531
- for f in glob(str(DATA_DIR / "*")):
532
- try:
533
- Path(f).unlink(missing_ok=True)
534
- except Exception:
535
- pass
536
- st.session_state["videos"] = ""
537
- st.session_state["uploaded_file"] = None
538
- st.session_state["processed_file"] = None
539
- st.session_state["last_loaded_path"] = ""
540
- st.session_state["analysis_out"] = ""
541
- st.session_state["file_hash"] = None
542
- try:
543
- fname = file_name_or_id(st.session_state.get("uploaded_file"))
544
- if fname and delete_file and HAS_GENAI:
545
- genai.configure(api_key=get_runtime_api_key() or os.getenv("GOOGLE_API_KEY", ""))
546
- delete_file(fname)
547
- except Exception:
548
- pass
549
- st.success("Local files removed. Cloud deletion attempted where supported.")
 
1
+ # --- patched responses / generate compatibility layer ---
2
+ import json
3
+ import requests
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
+ def _normalize_model_for_url(model: str) -> str:
6
+ if not model:
7
+ return "gemini-2.0"
8
+ return model.split("/", 1)[-1] if model.startswith("models/") else model
9
+
10
+ def _build_prompt_from_messages(messages):
11
+ # messages expected as list of {"role":..., "content":...}
12
+ if not messages:
13
+ return ""
14
+ parts = []
15
+ for m in messages:
16
+ role = (m.get("role") if isinstance(m, dict) else getattr(m, "role", None)) or "user"
17
+ content = (m.get("content") if isinstance(m, dict) else getattr(m, "content", None)) or ""
18
+ parts.append(f"{role.upper()}:\n{content.strip()}\n")
19
+ return "\n".join(parts)
20
+
21
+ def _parse_http_generate_response(rjson):
22
+ # Attempt to extract text from various generate shapes
23
+ if not rjson:
 
24
  return None
25
+ # common new GL formats: {'candidates':[{'content': '...'}]} or {'output': [{'content': ...}]}
26
+ if isinstance(rjson, dict):
27
+ # try 'candidates'
28
+ if "candidates" in rjson and isinstance(rjson["candidates"], list) and rjson["candidates"]:
29
+ cand = rjson["candidates"][0]
30
+ return cand.get("content") or cand.get("text") or rjson.get("text")
31
+ # try 'output' array with 'content' items
32
+ out = rjson.get("output")
33
+ if isinstance(out, list) and out:
34
+ texts = []
35
+ for item in out:
36
+ if isinstance(item, dict):
37
+ c = item.get("content") or item.get("contents") or item.get("text")
38
+ if isinstance(c, str):
39
+ texts.append(c)
40
+ elif isinstance(c, list):
41
+ for sub in c:
42
+ if isinstance(sub, dict):
43
+ t = sub.get("text") or sub.get("content")
44
+ if t:
45
+ texts.append(t)
46
+ if texts:
47
+ return "\n\n".join(texts)
48
+ # fallback to top-level text
49
+ if "text" in rjson and isinstance(rjson["text"], str):
50
+ return rjson["text"]
51
  return None
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  def responses_generate(model, messages, files, max_output_tokens, api_key):
54
  if not api_key:
55
  raise RuntimeError("No API key for responses_generate")
56
  sdk_err = None
57
+
58
+ # try SDK responses.generate (preferred)
59
  if HAS_GENAI and genai is not None:
60
  try:
61
  genai.configure(api_key=api_key)
62
+ responses_obj = getattr(genai, "responses", None)
63
+ if responses_obj is not None and hasattr(responses_obj, "generate"):
64
+ # SDK expects messages and files in their SDK-specific shapes
65
+ sdk_kwargs = {"model": model, "messages": messages, "max_output_tokens": int(max_output_tokens or 512)}
66
+ if files:
67
+ sdk_kwargs["files"] = files
68
+ return responses_obj.generate(**sdk_kwargs)
69
  except Exception as e:
70
  sdk_err = str(e)
71
 
72
+ # HTTP fallback to Generative Language "generate" endpoints.
73
  host = "https://generativelanguage.googleapis.com"
74
+ norm_model = _normalize_model_for_url(model)
75
  candidates = [
76
+ f"{host}/v1/models/{norm_model}:generate",
77
+ f"{host}/v1beta3/models/{norm_model}:generate",
78
+ f"{host}/v1beta2/models/{norm_model}:generate",
79
  ]
80
+
81
+ prompt_text = _build_prompt_from_messages(messages)
82
+ payload = {"prompt": {"text": prompt_text}, "maxOutputTokens": int(max_output_tokens or 512)}
83
  headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
84
  last_exc = None
85
+
86
  for url in candidates:
87
  try:
88
+ r = requests.post(url, json=payload, headers=headers, timeout=15)
89
  if r.status_code == 200:
90
  try:
91
  return r.json()
92
  except Exception:
93
  return {"text": r.text}
94
+ # if 404, try next; collect last
95
  last_exc = RuntimeError(f"HTTP {r.status_code}: {r.text}")
96
  except Exception as e:
97
  last_exc = e
98
+
99
  diag = {"sdk_error": sdk_err, "http_error": str(last_exc), "tried_urls": candidates}
100
  raise RuntimeError(f"genai.responses not available and HTTP fallback failed: {diag}")
101
 
102
  def call_responses_once(model_used, system_msg, user_msg, fname, max_tokens):
103
+ # messages as [system_msg, user_msg]
104
+ messages = [system_msg, user_msg]
105
  files = [{"name": fname}] if fname else None
106
  for attempt in range(2):
107
  try:
108
+ return responses_generate(model_used, messages, files, max_tokens, api_key=get_runtime_api_key())
109
  except Exception:
110
  if attempt == 0:
111
  time.sleep(1.0)
112
  continue
113
  raise
114
 
115
+ # Helper to extract text from either SDK response object or HTTP dict
116
+ def extract_text_from_response(response):
117
+ # SDK may return an object with .output, .candidates, or .text
118
+ # HTTP returns a dict with various shapes
119
+ # If it's an object (not dict), try attribute access
120
+ try:
121
+ if response is None:
122
+ return None
123
+ if isinstance(response, dict):
124
+ # HTTP-style
125
+ text = _parse_http_generate_response(response)
126
+ if text:
127
+ return text
128
+ # try 'output' field shaped differently
129
+ outputs = response.get("output") or response.get("candidates")
130
+ if outputs:
131
+ pieces = []
132
+ for o in outputs:
133
+ if isinstance(o, dict):
134
+ t = o.get("content") or o.get("text")
135
+ if isinstance(t, str):
136
+ pieces.append(t)
137
+ if pieces:
138
+ return "\n\n".join(pieces)
139
+ return response.get("text") or None
140
  else:
141
+ # object-like SDK response
142
+ outputs = getattr(response, "output", None) or getattr(response, "candidates", None) or None
143
+ if outputs:
144
+ pieces = []
145
+ for item in outputs:
146
+ # each item may have 'content' or 'text'
147
+ txt = getattr(item, "content", None) or getattr(item, "text", None) or (item.get("content") if isinstance(item, dict) else None)
148
+ if txt:
149
+ pieces.append(txt)
150
+ if pieces:
151
+ return "\n\n".join(pieces)
152
+ # try top-level text
153
+ txt = getattr(response, "text", None)
154
+ if txt:
155
+ return txt
156
+ except Exception:
157
+ pass
158
+ return None
159
+ # --- end patched section ---