zerhero commited on
Commit
35e6931
·
1 Parent(s): e2a2f56

refactor gui.py

Browse files
Files changed (2) hide show
  1. app.py +1 -8
  2. gui.py +1 -14
app.py CHANGED
@@ -1,12 +1,7 @@
1
  import spaces
2
  import os
3
- from stablepy import Model_Diffusers
4
- from stablepy.diffusers_vanilla.model import scheduler_names
5
  from stablepy.diffusers_vanilla.style_prompt_config import STYLE_NAMES
6
- import torch
7
  import re
8
- import shutil
9
- import random
10
  from stablepy import (
11
  CONTROLNET_MODEL_IDS,
12
  VALID_TASKS,
@@ -22,7 +17,6 @@ from stablepy import (
22
  SD15_TASKS,
23
  SDXL_TASKS,
24
  )
25
- import urllib.parse
26
  from config import (
27
  MINIMUM_IMAGE_NUMBER,
28
  MAXIMUM_IMAGE_NUMBER,
@@ -204,7 +198,6 @@ warnings.filterwarnings(
204
 
205
  logger.setLevel(logging.DEBUG)
206
 
207
-
208
  # init GuiSD
209
  sd_gen: object = GuiSD(
210
  model_list=model_list,
@@ -220,7 +213,7 @@ sdxl_task = [k for k, v in task_stablepy.items() if v in SDXL_TASKS]
220
  sd_task = [k for k, v in task_stablepy.items() if v in SD15_TASKS]
221
 
222
 
223
- def update_task_options(model_name, task_name):
224
  if model_name in model_list:
225
  if "xl" in model_name.lower():
226
  new_choices = sdxl_task
 
1
  import spaces
2
  import os
 
 
3
  from stablepy.diffusers_vanilla.style_prompt_config import STYLE_NAMES
 
4
  import re
 
 
5
  from stablepy import (
6
  CONTROLNET_MODEL_IDS,
7
  VALID_TASKS,
 
17
  SD15_TASKS,
18
  SDXL_TASKS,
19
  )
 
20
  from config import (
21
  MINIMUM_IMAGE_NUMBER,
22
  MAXIMUM_IMAGE_NUMBER,
 
198
 
199
  logger.setLevel(logging.DEBUG)
200
 
 
201
  # init GuiSD
202
  sd_gen: object = GuiSD(
203
  model_list=model_list,
 
213
  sd_task = [k for k, v in task_stablepy.items() if v in SD15_TASKS]
214
 
215
 
216
+ def update_task_options(model_name: str, task_name: str):
217
  if model_name in model_list:
218
  if "xl" in model_name.lower():
219
  new_choices = sdxl_task
gui.py CHANGED
@@ -1,21 +1,11 @@
1
  import spaces
2
  import os
3
  from stablepy import Model_Diffusers
4
- from stablepy.diffusers_vanilla.model import scheduler_names
5
- from stablepy.diffusers_vanilla.style_prompt_config import STYLE_NAMES
6
  import torch
7
- import re
8
- import shutil
9
  import random
10
  import spaces
11
  import gradio as gr
12
- from PIL import Image
13
- import IPython.display
14
- import time, json
15
- from IPython.utils import capture
16
- import logging
17
- from utils.string_utils import extract_parameters
18
- from stablepy import logger
19
 
20
  from models.upscaler import upscaler_dict_gui
21
 
@@ -23,7 +13,6 @@ logging.getLogger("diffusers").setLevel(logging.ERROR)
23
  import diffusers
24
 
25
  diffusers.utils.logging.set_verbosity(40)
26
- import warnings
27
  from utils.download_utils import download_things
28
 
29
  hf_token: str = os.environ.get("HF_TOKEN")
@@ -411,8 +400,6 @@ class GuiSD:
411
  "ip_adapter_scale": params_ip_scale,
412
  }
413
 
414
- # print(pipe_params)
415
-
416
  random_number = random.randint(1, 100)
417
  if random_number < 25 and num_images < 3:
418
  if (not upscaler_model and
 
1
  import spaces
2
  import os
3
  from stablepy import Model_Diffusers
 
 
4
  import torch
5
+ import logging
 
6
  import random
7
  import spaces
8
  import gradio as gr
 
 
 
 
 
 
 
9
 
10
  from models.upscaler import upscaler_dict_gui
11
 
 
13
  import diffusers
14
 
15
  diffusers.utils.logging.set_verbosity(40)
 
16
  from utils.download_utils import download_things
17
 
18
  hf_token: str = os.environ.get("HF_TOKEN")
 
400
  "ip_adapter_scale": params_ip_scale,
401
  }
402
 
 
 
403
  random_number = random.randint(1, 100)
404
  if random_number < 25 and num_images < 3:
405
  if (not upscaler_model and