hysts HF staff commited on
Commit
1c9931d
1 Parent(s): 4c57c14
Files changed (1) hide show
  1. model.py +5 -5
model.py CHANGED
@@ -38,7 +38,7 @@ CONTROLNET_MODEL_IDS = {
38
  }
39
 
40
 
41
- def download_all_controlnet_weights():
42
  for model_id in CONTROLNET_MODEL_IDS.values():
43
  ControlNetModel.from_pretrained(model_id)
44
 
@@ -101,7 +101,7 @@ class Model:
101
  num_steps: int,
102
  guidance_scale: float,
103
  seed: int,
104
- ):
105
  generator = torch.Generator().manual_seed(seed)
106
  return self.pipe(prompt=prompt,
107
  negative_prompt=negative_prompt,
@@ -109,7 +109,7 @@ class Model:
109
  num_images_per_prompt=num_images,
110
  num_inference_steps=num_steps,
111
  generator=generator,
112
- image=control_image)
113
 
114
  def process(
115
  self,
@@ -123,7 +123,7 @@ class Model:
123
  num_steps: int,
124
  guidance_scale: float,
125
  seed: int,
126
- ):
127
  self.load_controlnet_weight(task_name)
128
  results = self.run_pipe(
129
  prompt=self.get_prompt(prompt, additional_prompt),
@@ -134,7 +134,7 @@ class Model:
134
  guidance_scale=guidance_scale,
135
  seed=seed,
136
  )
137
- return [vis_control_image] + results.images
138
 
139
  @staticmethod
140
  def preprocess_canny(
 
38
  }
39
 
40
 
41
+ def download_all_controlnet_weights() -> None:
42
  for model_id in CONTROLNET_MODEL_IDS.values():
43
  ControlNetModel.from_pretrained(model_id)
44
 
 
101
  num_steps: int,
102
  guidance_scale: float,
103
  seed: int,
104
+ ) -> list[PIL.Image.Image]:
105
  generator = torch.Generator().manual_seed(seed)
106
  return self.pipe(prompt=prompt,
107
  negative_prompt=negative_prompt,
 
109
  num_images_per_prompt=num_images,
110
  num_inference_steps=num_steps,
111
  generator=generator,
112
+ image=control_image).images
113
 
114
  def process(
115
  self,
 
123
  num_steps: int,
124
  guidance_scale: float,
125
  seed: int,
126
+ ) -> list[PIL.Image.Image]:
127
  self.load_controlnet_weight(task_name)
128
  results = self.run_pipe(
129
  prompt=self.get_prompt(prompt, additional_prompt),
 
134
  guidance_scale=guidance_scale,
135
  seed=seed,
136
  )
137
+ return [vis_control_image] + results
138
 
139
  @staticmethod
140
  def preprocess_canny(