Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -2,8 +2,10 @@ import gradio as gr
|
|
2 |
import torch
|
3 |
from torch import autocast
|
4 |
from kandinsky2 import get_kandinsky2
|
|
|
5 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
6 |
model = get_kandinsky2('cuda', task_type='text2img', cache_dir='/tmp/kandinsky2', model_version='2.1', use_flash_attention=False)
|
|
|
7 |
def generate_text(prompt, quality="High (Default)"):
|
8 |
length_dict = {"Low": 50, "High (Default)": 100, "Ultra": 150}
|
9 |
length = length_dict[quality]
|
@@ -19,4 +21,9 @@ iface = gr.Interface(
|
|
19 |
outputs=gr.outputs.Image(label="Generated image:")
|
20 |
)
|
21 |
|
22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
2 |
import torch
|
3 |
from torch import autocast
|
4 |
from kandinsky2 import get_kandinsky2
|
5 |
+
|
6 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
7 |
model = get_kandinsky2('cuda', task_type='text2img', cache_dir='/tmp/kandinsky2', model_version='2.1', use_flash_attention=False)
|
8 |
+
|
9 |
def generate_text(prompt, quality="High (Default)"):
|
10 |
length_dict = {"Low": 50, "High (Default)": 100, "Ultra": 150}
|
11 |
length = length_dict[quality]
|
|
|
21 |
outputs=gr.outputs.Image(label="Generated image:")
|
22 |
)
|
23 |
|
24 |
+
if device.type == 'cpu':
|
25 |
+
model.load_state_dict(torch.load('path/to/model.pth', map_location=device))
|
26 |
+
else:
|
27 |
+
model.load_state_dict(torch.load('path/to/model.pth'))
|
28 |
+
|
29 |
+
iface.launch()
|