tokeron committed
Commit 887521b
1 Parent(s): 740a0ae

Update diffusion_lens.py

Files changed (1)
  1. diffusion_lens.py +27 -7
diffusion_lens.py CHANGED
@@ -1,19 +1,39 @@
 from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
 import torch
+import gradio as gr  # required for the gr.Info call below
 
-pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
-pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
+model_dict = {
+    'sd1': "CompVis/stable-diffusion-v1-4",
+    'sd2': "stabilityai/stable-diffusion-2-1",
+}
 
-# Check if CUDA is available and set the device accordingly
+# Text-encoder layer count per model, used to report which layer the image comes from
+model_num_of_layers = {
+    'sd1': 12,
+    'sd2': 22,
+}
+
+
+# global variables
 device = "cuda" if torch.cuda.is_available() else "cpu"
+dtype = torch.float16 if "cuda" in device else torch.float32
+
 
-# Move the pipeline to the device
-pipeline.to(device)
 
-def get_images(prompt, skip_layers):
+def get_images(prompt, skip_layers, model, seed):
+    model_name = model_dict[model]
+    pipeline = StableDiffusionPipeline.from_pretrained(
+        model_name,
+        torch_dtype=dtype,
+        variant="fp16",
+        add_watermarker=False,
+    )
+    # Move the pipeline to the device
+    pipeline.to(device)
+    pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
     print('inside get images')
+    layer = model_num_of_layers[model] - skip_layers
+    gr.Info(f"Generating image from layer {layer}")
     print(f'skipping {skip_layers}')
-    pipeline_output = pipeline(prompt, clip_skip=skip_layers, num_images_per_prompt=1, return_tensors=False)
+    generator = torch.Generator(device).manual_seed(seed)  # seed via a torch.Generator
+    pipeline_output = pipeline(prompt, clip_skip=skip_layers, num_images_per_prompt=1, generator=generator)
     print('after pipeline')
     images = pipeline_output.images
     print('got images')
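
For context, below is a minimal caller sketch showing how the updated get_images(prompt, skip_layers, model, seed) signature might be wired into a Gradio app, consistent with the gr.Info call added in this commit. The component choices, value ranges, the `demo` name, and the assumption that get_images returns the generated images (its return statement is outside this hunk) are illustrative only and not part of the commit.

# Hypothetical Gradio wiring for the updated get_images signature -- a sketch,
# not the Space's actual app.py (which is not shown in this diff).
import gradio as gr

from diffusion_lens import get_images, model_dict  # module-level names from this file

def generate(prompt, skip_layers, model, seed):
    # Assumption: get_images returns the list of images produced by the pipeline.
    return get_images(prompt, int(skip_layers), model, int(seed))

demo = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Slider(0, 11, step=1, value=0, label="CLIP layers to skip (clip_skip)"),
        gr.Dropdown(choices=list(model_dict.keys()), value="sd1", label="Model"),
        gr.Number(value=42, precision=0, label="Seed"),
    ],
    outputs=gr.Gallery(label="Generated images"),
    title="Diffusion Lens",
)

if __name__ == "__main__":
    demo.launch()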