Spaces:
Sleeping
Sleeping
Prgckwb
commited on
Commit
•
857a4f1
1
Parent(s):
16ea9e8
add requirements
Browse files- app.py +23 -18
- requirements.txt +3 -0
app.py
CHANGED
@@ -2,16 +2,16 @@ import gradio as gr
|
|
2 |
import numpy as np
|
3 |
import random
|
4 |
|
5 |
-
import spaces
|
6 |
from diffusers import DiffusionPipeline
|
7 |
import torch
|
8 |
|
9 |
|
10 |
model_ids = [
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
]
|
16 |
|
17 |
if torch.cuda.is_available():
|
@@ -21,15 +21,21 @@ else:
|
|
21 |
torch_dtype = torch.float32
|
22 |
device = "cpu"
|
23 |
|
24 |
-
pipelines = {
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
@spaces.GPU()
|
27 |
def inference(
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
):
|
34 |
pipe = pipelines[model_id].to(device)
|
35 |
|
@@ -42,21 +48,20 @@ def inference(
|
|
42 |
return image
|
43 |
|
44 |
|
45 |
-
if __name__ ==
|
46 |
theme = gr.themes.Ocean()
|
47 |
|
48 |
demo = gr.Interface(
|
49 |
fn=inference,
|
50 |
inputs=[
|
51 |
-
gr.Dropdown(label=
|
52 |
-
gr.Textbox(label=
|
53 |
-
gr.Slider(label=
|
54 |
-
gr.Slider(label=
|
55 |
],
|
56 |
outputs=[
|
57 |
-
gr.Image(label=
|
58 |
],
|
59 |
theme=theme,
|
60 |
)
|
61 |
demo.queue().launch()
|
62 |
-
|
|
|
2 |
import numpy as np
|
3 |
import random
|
4 |
|
5 |
+
import spaces # [uncomment to use ZeroGPU]
|
6 |
from diffusers import DiffusionPipeline
|
7 |
import torch
|
8 |
|
9 |
|
10 |
model_ids = [
|
11 |
+
"Prgckwb/trpfrog-sd3.5-large",
|
12 |
+
"Prgckwb/trpfrog-sd3.5-medium",
|
13 |
+
"Prgckwb/trpfrog-sdxl",
|
14 |
+
"Prgckwb/trpfrog-diffusion",
|
15 |
]
|
16 |
|
17 |
if torch.cuda.is_available():
|
|
|
21 |
torch_dtype = torch.float32
|
22 |
device = "cpu"
|
23 |
|
24 |
+
pipelines = {
|
25 |
+
model_id: DiffusionPipeline.from_pretrained(
|
26 |
+
model_id, device=device, dtype=torch_dtype
|
27 |
+
)
|
28 |
+
for model_id in model_ids
|
29 |
+
}
|
30 |
+
|
31 |
|
32 |
@spaces.GPU()
|
33 |
def inference(
|
34 |
+
model_id: str,
|
35 |
+
prompt: str,
|
36 |
+
width: int,
|
37 |
+
height: int,
|
38 |
+
progress=gr.Progress(track_tqdm=True),
|
39 |
):
|
40 |
pipe = pipelines[model_id].to(device)
|
41 |
|
|
|
48 |
return image
|
49 |
|
50 |
|
51 |
+
if __name__ == "__main__":
|
52 |
theme = gr.themes.Ocean()
|
53 |
|
54 |
demo = gr.Interface(
|
55 |
fn=inference,
|
56 |
inputs=[
|
57 |
+
gr.Dropdown(label="Model", choices=model_ids, value=model_ids[0]),
|
58 |
+
gr.Textbox(label="Prompt", placeholder="an icon of trpfrog"),
|
59 |
+
gr.Slider(label="Width", minimum=64, maximum=1024, step=64, value=1024),
|
60 |
+
gr.Slider(label="Height", minimum=64, maximum=1024, step=64, value=1024),
|
61 |
],
|
62 |
outputs=[
|
63 |
+
gr.Image(label="Output"),
|
64 |
],
|
65 |
theme=theme,
|
66 |
)
|
67 |
demo.queue().launch()
|
|
requirements.txt
CHANGED
@@ -1,3 +1,6 @@
|
|
1 |
diffusers
|
2 |
safetensors
|
3 |
accelerate
|
|
|
|
|
|
|
|
1 |
diffusers
|
2 |
safetensors
|
3 |
accelerate
|
4 |
+
transformers
|
5 |
+
sentencepiece
|
6 |
+
protobuf
|