lachine commited on
Commit
06afc8b
·
1 Parent(s): e8c277e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -1
app.py CHANGED
@@ -2,8 +2,10 @@ import gradio as gr
2
  import torch
3
  from torch import autocast
4
  from kandinsky2 import get_kandinsky2
 
5
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
6
  model = get_kandinsky2('cuda', task_type='text2img', cache_dir='/tmp/kandinsky2', model_version='2.1', use_flash_attention=False)
 
7
  def generate_text(prompt, quality="High (Default)"):
8
  length_dict = {"Low": 50, "High (Default)": 100, "Ultra": 150}
9
  length = length_dict[quality]
@@ -19,4 +21,9 @@ iface = gr.Interface(
19
  outputs=gr.outputs.Image(label="Generated image:")
20
  )
21
 
22
- iface.launch()
 
 
 
 
 
 
2
  import torch
3
  from torch import autocast
4
  from kandinsky2 import get_kandinsky2
5
+
6
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
7
  model = get_kandinsky2('cuda', task_type='text2img', cache_dir='/tmp/kandinsky2', model_version='2.1', use_flash_attention=False)
8
+
9
  def generate_text(prompt, quality="High (Default)"):
10
  length_dict = {"Low": 50, "High (Default)": 100, "Ultra": 150}
11
  length = length_dict[quality]
 
21
  outputs=gr.outputs.Image(label="Generated image:")
22
  )
23
 
24
+ if device.type == 'cpu':
25
+ model.load_state_dict(torch.load('path/to/model.pth', map_location=device))
26
+ else:
27
+ model.load_state_dict(torch.load('path/to/model.pth'))
28
+
29
+ iface.launch()