ciCic commited on
Commit
ab87e9c
·
1 Parent(s): eaadaf6
app.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+
4
+ import torch
5
+ from diffusers import AutoencoderTiny
6
+ from torchvision.transforms.functional import to_pil_image, center_crop, resize, to_tensor
7
+
8
+ device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
9
+ d_type = torch.float32 if device == 'mps' else torch.float16
10
+
11
+ model_id = "madebyollin/taesd"
12
+ vae = AutoencoderTiny.from_pretrained(model_id, safetensors=True, torch_dtype=d_type).to(device)
13
+
14
+
15
+ @torch.no_grad()
16
+ def decode(image):
17
+ t = to_tensor(image).unsqueeze(0).to(device, dtype=d_type)
18
+ unscaled_t = vae.unscale_latents(t)
19
+ reconstructed = vae.decoder(unscaled_t).clamp(0, 1)
20
+ return to_pil_image(reconstructed[0])
21
+
22
+
23
+ astronaut = os.path.join(os.path.dirname(__file__), "images/21.encoded.png")
24
+
25
+
26
+ def app():
27
+ return gr.Interface(decode,
28
+ gr.Image(type="pil",
29
+ image_mode="RGBA",
30
+ mirror_webcam=False,
31
+ label='64x64',
32
+ value=astronaut),
33
+ gr.Image(type="pil",
34
+ image_mode="RGB",
35
+ label='512x512',
36
+ height=256,
37
+ width=256
38
+ ),
39
+ examples=[
40
+ os.path.join(os.path.dirname(__file__), "images/18.encoded.png"),
41
+ os.path.join(os.path.dirname(__file__), "images/20.encoded.png")
42
+ ])
43
+
44
+
45
+ if __name__ == "__main__":
46
+ app().launch()
images/18.encoded.png ADDED
images/20.encoded.png ADDED
images/21.encoded.png ADDED