MySafeCode commited on
Commit
f3670b8
·
verified ·
1 Parent(s): 7e090a9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -78
app.py CHANGED
@@ -18,19 +18,7 @@ current_outputs = []
18
  current_page = 0
19
  page_size = 10
20
 
21
- # Default generation settings
22
- DEFAULT_SETTINGS = {
23
- "model_id": "048b4ea3-5586-47ed-900f-f4341c96bdb2", # SDXL 1.0
24
- "width": 1024,
25
- "height": 1024,
26
- "num_outputs": 1,
27
- "guidance_scale": 7.5,
28
- "inference_steps": 30,
29
- "scheduler_id": "euler",
30
- "seed": None
31
- }
32
-
33
- # Get available models for dropdown
34
  def get_model_list():
35
  """Get list of models for dropdown"""
36
  try:
@@ -40,31 +28,39 @@ def get_model_list():
40
  models = response.json().get('models', [])
41
  # Sort: default first, then by name
42
  models.sort(key=lambda x: (not x.get('is_default', False), x.get('name', '')))
43
- return [(m['name'], m['id']) for m in models]
44
- except:
45
- pass
46
- return [("SDXL 1.0", "048b4ea3-5586-47ed-900f-f4341c96bdb2")]
 
47
 
48
- # Get schedulers for dropdown
49
  def get_scheduler_list():
50
  """Get list of schedulers"""
51
  return [
52
- ("Euler", "euler"),
53
- ("Euler A", "euler_a"),
54
- ("DDIM", "ddim"),
55
- ("DPMSolver++", "dpmpp_2m"),
56
- ("DPM++ 2M Karras", "dpmpp_2m_karras"),
57
- ("DPM++ SDE", "dpmpp_sde"),
58
- ("DPM++ SDE Karras", "dpmpp_sde_karras"),
59
- ("Heun", "heun"),
60
- ("LMS", "lms"),
61
- ("LMS Karras", "lms_karras")
62
  ]
63
 
64
  # Initialize model and scheduler lists
65
  MODELS = get_model_list()
66
  SCHEDULERS = get_scheduler_list()
67
 
 
 
 
 
 
 
 
 
68
  # ========== MODELS TAB ==========
69
  def get_models():
70
  """Fetch and display available models"""
@@ -102,13 +98,17 @@ def get_models():
102
  return f"❌ Error: {str(e)}", "No data"
103
 
104
  # ========== GENERATE TAB ==========
105
- def generate_image(prompt, negative_prompt, model_id, width, height,
106
  num_outputs, guidance_scale, inference_steps,
107
- scheduler_id, seed, init_image_url, prompt_strength):
108
  """Generate images using StableCog API"""
109
  try:
110
  url = f'{API_HOST}/v1/image/generation/create'
111
 
 
 
 
 
112
  # Prepare request data
113
  data = {
114
  "prompt": prompt,
@@ -136,14 +136,20 @@ def generate_image(prompt, negative_prompt, model_id, width, height,
136
  if prompt_strength is not None:
137
  data["prompt_strength"] = float(prompt_strength)
138
 
 
 
 
139
  # Make API request
140
  response = requests.post(
141
  url,
142
- data=json.dumps(data),
143
  headers=headers,
144
  timeout=30 # Longer timeout for generation
145
  )
146
 
 
 
 
147
  if response.status_code == 200:
148
  result = response.json()
149
  outputs = result.get('outputs', [])
@@ -178,33 +184,22 @@ def generate_image(prompt, negative_prompt, model_id, width, height,
178
  for key, value in settings.items():
179
  display_text += f" {key}: {value}\n"
180
 
181
- return display_text, gallery_html, str(result), None
182
 
183
  else:
184
  error_msg = f"❌ Generation failed: {response.status_code}"
185
- return error_msg, "", f"Error: {response.text}", None
 
 
 
 
 
186
 
187
  except Exception as e:
188
  error_msg = f"❌ Error: {str(e)}"
189
- return error_msg, "", "No data", None
190
 
191
  # ========== OUTPUTS TAB ==========
192
- def fetch_outputs():
193
- """Fetch outputs from API"""
194
- global current_outputs
195
- try:
196
- url = f'{API_HOST}/v1/image/generation/outputs'
197
- response = requests.get(url, headers=headers, timeout=10)
198
-
199
- if response.status_code == 200:
200
- data = response.json()
201
- current_outputs = data.get('outputs', [])
202
- return data
203
- else:
204
- return None
205
- except:
206
- return None
207
-
208
  def create_gallery_html(image_urls, title="Gallery"):
209
  """Create HTML gallery with lightbox"""
210
  html = f"""
@@ -335,6 +330,22 @@ def create_gallery_html(image_urls, title="Gallery"):
335
  html += "</div>"
336
  return html
337
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
338
  def update_outputs_display():
339
  """Update outputs display with current page"""
340
  global current_outputs, current_page, page_size
@@ -517,27 +528,25 @@ with gr.Blocks(title="StableCog Dashboard", theme=gr.themes.Soft()) as demo:
517
  gr.Markdown("### ⚙️ Generation Settings")
518
 
519
  with gr.Row():
520
- model_id = gr.Dropdown(
521
  label="Model",
522
  choices=[m[0] for m in MODELS],
523
- value=MODELS[0][0] if MODELS else "SDXL 1.0"
524
  )
525
- # Store actual model IDs
526
- model_map = {m[0]: m[1] for m in MODELS}
527
-
528
  with gr.Row():
529
  width = gr.Slider(
530
  label="Width",
531
  minimum=256,
532
  maximum=1024,
533
- value=1024,
534
  step=8
535
  )
536
  height = gr.Slider(
537
  label="Height",
538
  minimum=256,
539
  maximum=1024,
540
- value=1024,
541
  step=8
542
  )
543
 
@@ -553,7 +562,7 @@ with gr.Blocks(title="StableCog Dashboard", theme=gr.themes.Soft()) as demo:
553
  label="Guidance Scale",
554
  minimum=1.0,
555
  maximum=20.0,
556
- value=7.5,
557
  step=0.5
558
  )
559
 
@@ -566,13 +575,11 @@ with gr.Blocks(title="StableCog Dashboard", theme=gr.themes.Soft()) as demo:
566
  )
567
 
568
  with gr.Row():
569
- scheduler_id = gr.Dropdown(
570
  label="Scheduler",
571
  choices=[s[0] for s in SCHEDULERS],
572
- value="Euler"
573
  )
574
- # Store actual scheduler IDs
575
- scheduler_map = {s[0]: s[1] for s in SCHEDULERS}
576
 
577
  seed = gr.Textbox(
578
  label="Seed (Optional)",
@@ -607,25 +614,12 @@ with gr.Blocks(title="StableCog Dashboard", theme=gr.themes.Soft()) as demo:
607
  )
608
 
609
  # Connect generate button
610
- def on_generate(prompt, negative_prompt, model_name, width, height,
611
- num_outputs, guidance_scale, inference_steps,
612
- scheduler_name, seed, init_image_url, prompt_strength):
613
- # Get actual IDs from maps
614
- actual_model_id = model_map.get(model_name, MODELS[0][1] if MODELS else DEFAULT_SETTINGS["model_id"])
615
- actual_scheduler_id = scheduler_map.get(scheduler_name, "euler")
616
-
617
- return generate_image(
618
- prompt, negative_prompt, actual_model_id, width, height,
619
- num_outputs, guidance_scale, inference_steps,
620
- actual_scheduler_id, seed, init_image_url, prompt_strength
621
- )
622
-
623
  generate_btn.click(
624
- on_generate,
625
  inputs=[
626
- prompt, negative_prompt, model_id, width, height,
627
  num_outputs, guidance_scale, inference_steps,
628
- scheduler_id, seed, init_image_url, prompt_strength
629
  ],
630
  outputs=[generate_output, generate_gallery, generate_raw]
631
  )
 
18
  current_page = 0
19
  page_size = 10
20
 
21
+ # ========== MODEL AND SCHEDULER MANAGEMENT ==========
 
 
 
 
 
 
 
 
 
 
 
 
22
  def get_model_list():
23
  """Get list of models for dropdown"""
24
  try:
 
28
  models = response.json().get('models', [])
29
  # Sort: default first, then by name
30
  models.sort(key=lambda x: (not x.get('is_default', False), x.get('name', '')))
31
+ return [(f"{m['name']}", m['id']) for m in models]
32
+ except Exception as e:
33
+ print(f"Error fetching models: {e}")
34
+ # Fallback to SDXL
35
+ return [("SDXL 1.0", "22b0857d-7edc-4d00-9cd9-45aa509db093")]
36
 
 
37
  def get_scheduler_list():
38
  """Get list of schedulers"""
39
  return [
40
+ ("Euler", "b7224e56-1440-43b9-ac86-66d66f9e8c91"),
41
+ ("Euler A", "6fb13a76-990d-49df-a2ab-7d9d22c33e3d"),
42
+ ("DDIM", "c5a0bad3-bd9d-4c5c-96e3-9d8e8c0c7a6b"),
43
+ ("DPMSolver++", "e9c3a7b3-2b5c-4a5d-8e2d-7b1c9a3d4e5f"),
44
+ ("DPM++ 2M Karras", "f1c3a7b3-2b5c-4a5d-8e2d-7b1c9a3d4e5f"),
45
+ ("DPM++ SDE", "a1c3a7b3-2b5c-4a5d-8e2d-7b1c9a3d4e5f"),
46
+ ("DPM++ SDE Karras", "b1c3a7b3-2b5c-4a5d-8e2d-7b1c9a3d4e5f"),
47
+ ("Heun", "c1c3a7b3-2b5c-4a5d-8e2d-7b1c9a3d4e5f"),
48
+ ("LMS", "d1c3a7b3-2b5c-4a5d-8e2d-7b1c9a3d4e5f"),
49
+ ("LMS Karras", "e1c3a7b3-2b5c-4a5d-8e2d-7b1c9a3d4e5f")
50
  ]
51
 
52
  # Initialize model and scheduler lists
53
  MODELS = get_model_list()
54
  SCHEDULERS = get_scheduler_list()
55
 
56
+ # Create mapping dictionaries
57
+ MODEL_MAP = {name: id for name, id in MODELS}
58
+ SCHEDULER_MAP = {name: id for name, id in SCHEDULERS}
59
+
60
+ # Default values
61
+ DEFAULT_MODEL = MODELS[0][0] if MODELS else "SDXL 1.0"
62
+ DEFAULT_SCHEDULER = SCHEDULERS[0][0] if SCHEDULERS else "Euler"
63
+
64
  # ========== MODELS TAB ==========
65
  def get_models():
66
  """Fetch and display available models"""
 
98
  return f"❌ Error: {str(e)}", "No data"
99
 
100
  # ========== GENERATE TAB ==========
101
+ def generate_image(prompt, negative_prompt, model_name, width, height,
102
  num_outputs, guidance_scale, inference_steps,
103
+ scheduler_name, seed, init_image_url, prompt_strength):
104
  """Generate images using StableCog API"""
105
  try:
106
  url = f'{API_HOST}/v1/image/generation/create'
107
 
108
+ # Get actual IDs from maps
109
+ model_id = MODEL_MAP.get(model_name, MODELS[0][1] if MODELS else "22b0857d-7edc-4d00-9cd9-45aa509db093")
110
+ scheduler_id = SCHEDULER_MAP.get(scheduler_name, "b7224e56-1440-43b9-ac86-66d66f9e8c91")
111
+
112
  # Prepare request data
113
  data = {
114
  "prompt": prompt,
 
136
  if prompt_strength is not None:
137
  data["prompt_strength"] = float(prompt_strength)
138
 
139
+ # Debug: Print the data being sent
140
+ print(f"Sending data: {json.dumps(data, indent=2)}")
141
+
142
  # Make API request
143
  response = requests.post(
144
  url,
145
+ json=data, # Use json parameter instead of data=json.dumps()
146
  headers=headers,
147
  timeout=30 # Longer timeout for generation
148
  )
149
 
150
+ print(f"Response status: {response.status_code}")
151
+ print(f"Response text: {response.text[:200]}...")
152
+
153
  if response.status_code == 200:
154
  result = response.json()
155
  outputs = result.get('outputs', [])
 
184
  for key, value in settings.items():
185
  display_text += f" {key}: {value}\n"
186
 
187
+ return display_text, gallery_html, str(result)
188
 
189
  else:
190
  error_msg = f"❌ Generation failed: {response.status_code}"
191
+ try:
192
+ error_detail = response.json()
193
+ error_msg += f"\nDetails: {json.dumps(error_detail, indent=2)}"
194
+ except:
195
+ error_msg += f"\nResponse: {response.text}"
196
+ return error_msg, "", error_msg
197
 
198
  except Exception as e:
199
  error_msg = f"❌ Error: {str(e)}"
200
+ return error_msg, "", error_msg
201
 
202
  # ========== OUTPUTS TAB ==========
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  def create_gallery_html(image_urls, title="Gallery"):
204
  """Create HTML gallery with lightbox"""
205
  html = f"""
 
330
  html += "</div>"
331
  return html
332
 
333
+ def fetch_outputs():
334
+ """Fetch outputs from API"""
335
+ global current_outputs
336
+ try:
337
+ url = f'{API_HOST}/v1/image/generation/outputs'
338
+ response = requests.get(url, headers=headers, timeout=10)
339
+
340
+ if response.status_code == 200:
341
+ data = response.json()
342
+ current_outputs = data.get('outputs', [])
343
+ return data
344
+ else:
345
+ return None
346
+ except:
347
+ return None
348
+
349
  def update_outputs_display():
350
  """Update outputs display with current page"""
351
  global current_outputs, current_page, page_size
 
528
  gr.Markdown("### ⚙️ Generation Settings")
529
 
530
  with gr.Row():
531
+ model_dropdown = gr.Dropdown(
532
  label="Model",
533
  choices=[m[0] for m in MODELS],
534
+ value=DEFAULT_MODEL
535
  )
536
+
 
 
537
  with gr.Row():
538
  width = gr.Slider(
539
  label="Width",
540
  minimum=256,
541
  maximum=1024,
542
+ value=768,
543
  step=8
544
  )
545
  height = gr.Slider(
546
  label="Height",
547
  minimum=256,
548
  maximum=1024,
549
+ value=768,
550
  step=8
551
  )
552
 
 
562
  label="Guidance Scale",
563
  minimum=1.0,
564
  maximum=20.0,
565
+ value=7.0,
566
  step=0.5
567
  )
568
 
 
575
  )
576
 
577
  with gr.Row():
578
+ scheduler_dropdown = gr.Dropdown(
579
  label="Scheduler",
580
  choices=[s[0] for s in SCHEDULERS],
581
+ value=DEFAULT_SCHEDULER
582
  )
 
 
583
 
584
  seed = gr.Textbox(
585
  label="Seed (Optional)",
 
614
  )
615
 
616
  # Connect generate button
 
 
 
 
 
 
 
 
 
 
 
 
 
617
  generate_btn.click(
618
+ generate_image,
619
  inputs=[
620
+ prompt, negative_prompt, model_dropdown, width, height,
621
  num_outputs, guidance_scale, inference_steps,
622
+ scheduler_dropdown, seed, init_image_url, prompt_strength
623
  ],
624
  outputs=[generate_output, generate_gallery, generate_raw]
625
  )