k-l-lambda commited on
Commit
a907317
1 Parent(s): 890b9fd

app.py: added controlnet configs.

Browse files
Files changed (1) hide show
  1. app.py +27 -4
app.py CHANGED
@@ -18,7 +18,7 @@ from style_template import styles
18
  MAX_SEED = np.iinfo(np.int32).max
19
  STYLE_NAMES = list(styles.keys())
20
  DEFAULT_STYLE_NAME = 'Watercolor'
21
- DEFAULT_MODEL_NAME = 'sd_xl_base_1.0'
22
  enable_lcm_arg = False
23
 
24
  # Path to InstantID models
@@ -100,6 +100,25 @@ SDXL_MODELS = [
100
  "zavychromaxl_v21_129006",
101
  ]
102
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
  def get_novita_client (novita_key):
105
  client = NovitaClient(novita_key, os.getenv('NOVITA_API_URI', None))
@@ -263,6 +282,10 @@ def generate_image (
263
  ref_image = PIL.Image.open(ref_image_path)
264
  width, height = ref_image.size
265
 
 
 
 
 
266
  res = client._post('/v3/async/instant-id', {
267
  'extra': {
268
  'response_image_type': 'jpeg',
@@ -273,7 +296,7 @@ def generate_image (
273
  'prompt': prompt,
274
  'negative_prompt': negative_prompt,
275
  'controlnet': {
276
- 'units': [], # TODO
277
  },
278
  'image_num': 1,
279
  'steps': num_steps,
@@ -290,9 +313,9 @@ def generate_image (
290
  def progress (x):
291
  print('progress:', x.task.status)
292
  final_res = client.wait_for_task_v3(res['task_id'], callback=progress)
 
 
293
  print('status:', final_res.task.status)
294
- if final_res.task.status == V3TaskResponseStatus.TASK_STATUS_FAILED:
295
- raise RuntimeError(f'Novita task failed: {final_res.task.status}')
296
 
297
  final_res.download_images()
298
  except Exception as e:
 
18
  MAX_SEED = np.iinfo(np.int32).max
19
  STYLE_NAMES = list(styles.keys())
20
  DEFAULT_STYLE_NAME = 'Watercolor'
21
+ DEFAULT_MODEL_NAME = 'sdxlUnstableDiffusers_v8HEAVENSWRATH_133813'
22
  enable_lcm_arg = False
23
 
24
  # Path to InstantID models
 
100
  "zavychromaxl_v21_129006",
101
  ]
102
 
103
+ CONTROLNET_DICT = dict(
104
+ pose={
105
+ 'model_name': 'controlnet-openpose-sdxl-1.0',
106
+ 'strength': 1,
107
+ 'preprocessor': 'openpose',
108
+ },
109
+ depth={
110
+ 'model_name': 'controlnet-depth-sdxl-1.0',
111
+ 'strength': 1,
112
+ 'preprocessor': 'depth',
113
+ },
114
+ canny={
115
+ 'model_name': 'controlnet-canny-sdxl-1.0',
116
+ 'strength': 1,
117
+ 'preprocessor': 'canny',
118
+ },
119
+ )
120
+
121
+
122
 
123
  def get_novita_client (novita_key):
124
  client = NovitaClient(novita_key, os.getenv('NOVITA_API_URI', None))
 
282
  ref_image = PIL.Image.open(ref_image_path)
283
  width, height = ref_image.size
284
 
285
+ CONTROLNET_DICT['pose']['strength'] = pose_strength
286
+ CONTROLNET_DICT['canny']['strength'] = canny_strength
287
+ CONTROLNET_DICT['depth']['strength'] = depth_strength
288
+
289
  res = client._post('/v3/async/instant-id', {
290
  'extra': {
291
  'response_image_type': 'jpeg',
 
296
  'prompt': prompt,
297
  'negative_prompt': negative_prompt,
298
  'controlnet': {
299
+ 'units': [CONTROLNET_DICT[name] for name in controlnet_selection if name in CONTROLNET_DICT],
300
  },
301
  'image_num': 1,
302
  'steps': num_steps,
 
313
  def progress (x):
314
  print('progress:', x.task.status)
315
  final_res = client.wait_for_task_v3(res['task_id'], callback=progress)
316
+ if final_res is None or final_res.task.status == V3TaskResponseStatus.TASK_STATUS_FAILED:
317
+ raise RuntimeError(f'Novita task failed: {final_res and final_res.task.status}')
318
  print('status:', final_res.task.status)
 
 
319
 
320
  final_res.download_images()
321
  except Exception as e: