jbilcke-hf HF staff commited on
Commit
53f3635
1 Parent(s): 4fec0c4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -24
app.py CHANGED
@@ -6,29 +6,24 @@ import gradio as gr
6
  import numpy as np
7
  import PIL.Image
8
  import torch
9
- from diffusers import DiffusionPipeline, UNet2DConditionModel, LCMScheduler
10
 
11
  MAX_SEED = np.iinfo(np.int32).max
12
  MAX_IMAGE_SIZE = int(os.getenv('MAX_IMAGE_SIZE', '1024'))
13
  SECRET_TOKEN = os.getenv('SECRET_TOKEN', 'default_secret')
14
 
 
 
 
15
  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
16
  if torch.cuda.is_available():
17
- unet = UNet2DConditionModel.from_pretrained(
18
- "latent-consistency/lcm-ssd-1b",
19
- torch_dtype=torch.float16,
20
- variant="fp16"
21
- )
22
-
23
- pipe = DiffusionPipeline.from_pretrained(
24
- "segmind/SSD-1B",
25
- unet=unet,
26
- torch_dtype=torch.float16,
27
- variant="fp16"
28
- )
29
-
30
  pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
31
- pipe.to(device)
 
 
 
 
32
  else:
33
  pipe = None
34
 
@@ -44,8 +39,8 @@ def generate(prompt: str,
44
  seed: int = 0,
45
  width: int = 1024,
46
  height: int = 1024,
47
- guidance_scale: float = 1.0,
48
- num_inference_steps: int = 6,
49
  secret_token: str = '') -> PIL.Image.Image:
50
  if secret_token != SECRET_TOKEN:
51
  raise gr.Error(
@@ -69,7 +64,7 @@ with gr.Blocks() as demo:
69
  gr.HTML("""
70
  <div style="z-index: 100; position: fixed; top: 0px; right: 0px; left: 0px; bottom: 0px; width: 100%; height: 100%; background: white; display: flex; align-items: center; justify-content: center; color: black;">
71
  <div style="text-align: center; color: black;">
72
- <p style="color: black;">This space is a REST API to programmatically generate images using LCM-SSD-1B.</p>
73
  <p style="color: black;">It is not meant to be directly used through a user interface, but using code and an access key.</p>
74
  </div>
75
  </div>""")
@@ -117,16 +112,16 @@ with gr.Blocks() as demo:
117
  )
118
  guidance_scale = gr.Slider(
119
  label='Guidance scale',
120
- minimum=1,
121
- maximum=20,
122
  step=0.1,
123
- value=1.0)
124
  num_inference_steps = gr.Slider(
125
  label='Number of inference steps',
126
- minimum=2,
127
- maximum=40,
128
  step=1,
129
- value=6)
130
 
131
  use_negative_prompt.change(
132
  fn=lambda x: gr.update(visible=x),
 
6
  import numpy as np
7
  import PIL.Image
8
  import torch
9
+ from diffusers import LCMScheduler, AutoPipelineForText2Image
10
 
11
  MAX_SEED = np.iinfo(np.int32).max
12
  MAX_IMAGE_SIZE = int(os.getenv('MAX_IMAGE_SIZE', '1024'))
13
  SECRET_TOKEN = os.getenv('SECRET_TOKEN', 'default_secret')
14
 
15
+ MODEL_ID = "segmind/SSD-1B"
16
+ ADAPTER_ID = "latent-consistency/lcm-lora-ssd-1b"
17
+
18
  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
19
  if torch.cuda.is_available():
20
+ pipe = AutoPipelineForText2Image.from_pretrained(MODEL_ID, torch_dtype=torch.float16, variant="fp16")
 
 
 
 
 
 
 
 
 
 
 
 
21
  pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
22
+ pipe.to("cuda")
23
+
24
+ # load and fuse
25
+ pipe.load_lora_weights(ADAPTER_ID)
26
+ pipe.fuse_lora()
27
  else:
28
  pipe = None
29
 
 
39
  seed: int = 0,
40
  width: int = 1024,
41
  height: int = 1024,
42
+ guidance_scale: float = 0.0,
43
+ num_inference_steps: int = 4,
44
  secret_token: str = '') -> PIL.Image.Image:
45
  if secret_token != SECRET_TOKEN:
46
  raise gr.Error(
 
64
  gr.HTML("""
65
  <div style="z-index: 100; position: fixed; top: 0px; right: 0px; left: 0px; bottom: 0px; width: 100%; height: 100%; background: white; display: flex; align-items: center; justify-content: center; color: black;">
66
  <div style="text-align: center; color: black;">
67
+ <p style="color: black;">This space is a REST API to programmatically generate images using LCM SDXL LoRA.</p>
68
  <p style="color: black;">It is not meant to be directly used through a user interface, but using code and an access key.</p>
69
  </div>
70
  </div>""")
 
112
  )
113
  guidance_scale = gr.Slider(
114
  label='Guidance scale',
115
+ minimum=0,
116
+ maximum=2,
117
  step=0.1,
118
+ value=0.0)
119
  num_inference_steps = gr.Slider(
120
  label='Number of inference steps',
121
+ minimum=1,
122
+ maximum=8,
123
  step=1,
124
+ value=4)
125
 
126
  use_negative_prompt.change(
127
  fn=lambda x: gr.update(visible=x),