Prgckwb commited on
Commit
8aae6c2
1 Parent(s): 4af9e39

:tada: init

Browse files
Files changed (1) hide show
  1. app.py +15 -4
app.py CHANGED
@@ -9,26 +9,34 @@ from diffusers.schedulers import (
9
  DDIMScheduler,
10
  EulerAncestralDiscreteScheduler,
11
  DPMSolverMultistepScheduler,
12
- FlowMatchEulerDiscreteScheduler,
 
13
  )
14
  from diffusers.utils.torch_utils import randn_tensor
15
 
 
 
 
 
 
16
  SCHEDULERS = {
17
  "DDPMScheduler": DDPMScheduler,
18
  "DDIMScheduler": DDIMScheduler,
 
 
19
  "EulerAncestralDiscreteScheduler": EulerAncestralDiscreteScheduler,
20
  "DPMSolverMultistepScheduler": DPMSolverMultistepScheduler,
21
- "FlowMatchEulerDiscreteScheduler": FlowMatchEulerDiscreteScheduler,
22
  }
23
 
24
 
25
  def inference(
26
  image_pil: Image.Image,
 
27
  scheduler_name: str,
28
  per_step_time: int = 1,
29
  n_total_steps: int = 1000,
30
  ):
31
- scheduler = SCHEDULERS[scheduler_name]()
32
  scheduler.set_timesteps(num_inference_steps=n_total_steps)
33
  timesteps = torch.flip(scheduler.timesteps, dims=[0])
34
 
@@ -40,6 +48,8 @@ def inference(
40
  noise = randn_tensor(image_tensor.shape, generator)
41
 
42
  for i, t in enumerate(timesteps):
 
 
43
  noised_image_tensor = scheduler.add_noise(image_tensor, noise, timesteps=t)
44
  noised_image_pil = image_processor.postprocess(noised_image_tensor)[0]
45
  time.sleep(per_step_time)
@@ -61,6 +71,7 @@ if __name__ == '__main__':
61
  fn=inference,
62
  inputs=[
63
  gr.Image(type='pil', label='Input Image'),
 
64
  gr.Dropdown(list(SCHEDULERS.keys()), value='DDPMScheduler', label='Scheduler'),
65
  gr.Radio(choices=[0, 1, 2], value=0, label='Per-Step time'),
66
  gr.Radio(choices=[10, 25, 50, 100, 1000], value=50, label='Total Steps'),
@@ -79,7 +90,7 @@ if __name__ == '__main__':
79
  """,
80
  cache_examples=True,
81
  examples=[
82
- [Image.open("assets/corgi.png"), 'DDIMScheduler', 0, 50],
83
  ],
84
  )
85
  demo.launch()
 
9
  DDIMScheduler,
10
  EulerAncestralDiscreteScheduler,
11
  DPMSolverMultistepScheduler,
12
+ PNDMScheduler,
13
+ EulerDiscreteScheduler,
14
  )
15
  from diffusers.utils.torch_utils import randn_tensor
16
 
17
+ MODEL_IDS = {
18
+ 'Stable Diffusion v1.4': 'CompVis/stable-diffusion-v1-4',
19
+ 'Stable Diffusion v3 medium': 'stabilityai/stable-diffusion-3-medium-diffusers',
20
+ }
21
+
22
  SCHEDULERS = {
23
  "DDPMScheduler": DDPMScheduler,
24
  "DDIMScheduler": DDIMScheduler,
25
+ "PNDMScheduler": PNDMScheduler,
26
+ "EulerDiscreteScheduler": EulerDiscreteScheduler,
27
  "EulerAncestralDiscreteScheduler": EulerAncestralDiscreteScheduler,
28
  "DPMSolverMultistepScheduler": DPMSolverMultistepScheduler,
 
29
  }
30
 
31
 
32
  def inference(
33
  image_pil: Image.Image,
34
+ model_name: str,
35
  scheduler_name: str,
36
  per_step_time: int = 1,
37
  n_total_steps: int = 1000,
38
  ):
39
+ scheduler = SCHEDULERS[scheduler_name].from_pretrained(MODEL_IDS[model_name], subfolder='scheduler')
40
  scheduler.set_timesteps(num_inference_steps=n_total_steps)
41
  timesteps = torch.flip(scheduler.timesteps, dims=[0])
42
 
 
48
  noise = randn_tensor(image_tensor.shape, generator)
49
 
50
  for i, t in enumerate(timesteps):
51
+ t = torch.tensor([t])
52
+
53
  noised_image_tensor = scheduler.add_noise(image_tensor, noise, timesteps=t)
54
  noised_image_pil = image_processor.postprocess(noised_image_tensor)[0]
55
  time.sleep(per_step_time)
 
71
  fn=inference,
72
  inputs=[
73
  gr.Image(type='pil', label='Input Image'),
74
+ gr.Dropdown(list(MODEL_IDS.keys()), value='Stable Diffusion v1.4', label='Model ID'),
75
  gr.Dropdown(list(SCHEDULERS.keys()), value='DDPMScheduler', label='Scheduler'),
76
  gr.Radio(choices=[0, 1, 2], value=0, label='Per-Step time'),
77
  gr.Radio(choices=[10, 25, 50, 100, 1000], value=50, label='Total Steps'),
 
90
  """,
91
  cache_examples=True,
92
  examples=[
93
+ [Image.open("assets/corgi.png"), 'Stable Diffusion v1.4', 'DDIMScheduler', 0, 50],
94
  ],
95
  )
96
  demo.launch()