Prgckwb commited on
Commit
941ac0f
1 Parent(s): 7926358

:tada: init

Browse files
Files changed (1) hide show
  1. app.py +28 -8
app.py CHANGED
@@ -2,6 +2,15 @@ import gradio as gr
2
  import spaces
3
  import torch
4
  from diffusers import DiffusionPipeline
 
 
 
 
 
 
 
 
 
5
 
6
 
7
  @spaces.GPU()
@@ -11,13 +20,18 @@ def inference(
11
  prompt: str,
12
  negative_prompt: str = "",
13
  progress=gr.Progress(track_tqdm=True),
14
- ):
15
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
 
17
- pipe = DiffusionPipeline.from_pretrained(
18
- model_id,
19
- torch_dtype=torch.float16,
20
- ).to(device)
 
 
 
 
 
21
 
22
  image = pipe(
23
  prompt,
@@ -31,8 +45,14 @@ if __name__ == "__main__":
31
  demo = gr.Interface(
32
  fn=inference,
33
  inputs=[
34
- gr.Text(
35
  label="Model ID",
 
 
 
 
 
 
36
  value="stabilityai/stable-diffusion-3-medium-diffusers",
37
  ),
38
  gr.Text(label="Prompt", value=""),
@@ -42,4 +62,4 @@ if __name__ == "__main__":
42
  gr.Image(label="Image", type="pil"),
43
  ],
44
  )
45
- demo.launch()
 
2
  import spaces
3
  import torch
4
  from diffusers import DiffusionPipeline
5
+ from PIL import Image
6
+
7
+ # Global Variables
8
+ current_model_id = "stabilityai/stable-diffusion-3-medium-diffusers"
9
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
+ pipe = DiffusionPipeline.from_pretrained(
11
+ current_model_id,
12
+ torch_dtype=torch.float16,
13
+ ).to(device)
14
 
15
 
16
  @spaces.GPU()
 
20
  prompt: str,
21
  negative_prompt: str = "",
22
  progress=gr.Progress(track_tqdm=True),
23
+ ) -> Image.Image:
24
+ global current_model_id, pipe
25
 
26
+ if model_id != current_model_id:
27
+ try:
28
+ pipe = DiffusionPipeline.from_pretrained(
29
+ model_id,
30
+ torch_dtype=torch.float16,
31
+ ).to(device)
32
+ current_model_id = model_id
33
+ except Exception as e:
34
+ raise gr.Error(str(e))
35
 
36
  image = pipe(
37
  prompt,
 
45
  demo = gr.Interface(
46
  fn=inference,
47
  inputs=[
48
+ gr.Dropdown(
49
  label="Model ID",
50
+ choices=[
51
+ "stabilityai/stable-diffusion-3-medium-diffusers",
52
+ "stabilityai/stable-diffusion-xl-base-1.0",
53
+ "stabilityai/stable-diffusion-2-1",
54
+ "runwayml/stable-diffusion-v1-5",
55
+ ],
56
  value="stabilityai/stable-diffusion-3-medium-diffusers",
57
  ),
58
  gr.Text(label="Prompt", value=""),
 
62
  gr.Image(label="Image", type="pil"),
63
  ],
64
  )
65
+ demo.queue().launch()