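# ImageEncoder / app.py
# Author: ciCic · commit 5a3c9bd ("init") · 1.64 kB
# (Header scraped from the Hugging Face file viewer; its UI labels "raw",
#  "history blame contribute delete" and the "No virus" scan badge are not content.)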
import gradio as gr
import os
import torch
from diffusers import AutoencoderTiny
from torchvision.transforms.functional import to_pil_image, center_crop, resize, to_tensor
# Pick the best available backend: CUDA first, then Apple MPS, falling back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
# Hugging Face Hub repository holding the TAESD weights loaded below.
model_id = "madebyollin/taesd"
# `use_safetensors` is the keyword diffusers' `from_pretrained` understands for
# requesting .safetensors weights; the former `safetensors=True` is not one of its
# loader options and therefore did not select the safetensors file.
vae = AutoencoderTiny.from_pretrained(model_id, use_safetensors=True).to(device)
@torch.no_grad()
def encode(image):
    """Encode an image with TAESD and pack the latent as an RGBA preview image.

    The image is resized so its shorter side is 512 px and center-cropped to
    512x512. It is then run through ``vae.encoder``. The latent is mapped to
    [0, 1] by ``vae.scale_latents``, quantized to uint8 and returned as a PIL
    image, one latent channel per colour channel.

    Args:
        image: A PIL image (as delivered by ``gr.Image(type="pil")``), or
            ``None`` when the input widget is empty.

    Returns:
        The quantized latent as a PIL image, or ``None`` if ``image`` is ``None``.
    """
    if image is None:
        # Gradio passes None when the user clears the input; show an empty output.
        return None
    DIM = 512
    # The encoder takes 3-channel input; normalise RGBA / greyscale / palette images.
    image = image.convert("RGB")
    processed = center_crop(resize(image, DIM), DIM)
    # to_tensor() yields values in [0, 1], but diffusers' EncoderTiny expects
    # [-1, 1] (it applies x.add(1).div(2) itself), so rescale 0..1 -> -1..1.
    tensor = to_tensor(processed).mul(2).sub(1).unsqueeze(0).to(device)
    latents = vae.encoder(tensor)
    # Same quantization AutoencoderTiny applies: [0, 1] -> uint8 "image" bytes.
    scaled = vae.scale_latents(latents).mul_(255).round_().byte()
    # Bring the tensor back to host memory before converting it to PIL.
    return to_pil_image(scaled[0].cpu())
# Default input image (also the first example); the path is built from this file's directory.
astronaut = os.path.join(os.path.split(__file__)[0], "images/6.png")
def app():
    """Build the Gradio demo for the TAESD image encoder.

    The input widget takes a PIL image and is pre-filled with ``astronaut``.
    The output widget shows the RGBA image returned by ``encode`` at a 256x256
    display size. Three bundled images are offered as examples, and flagging
    is disabled.

    Returns:
        The configured (not yet launched) ``gr.Interface``.
    """
    here = os.path.dirname(__file__)
    sample_images = [astronaut]
    sample_images.extend(os.path.join(here, rel) for rel in ("images/7.png", "images/34.png"))

    source_view = gr.Image(
        type="pil",
        mirror_webcam=False,
        label='512x512',
        value=astronaut,
    )
    latent_view = gr.Image(
        type="pil",
        image_mode="RGBA",
        label='64x64',
        height=256,
        width=256,
    )
    return gr.Interface(
        fn=encode,
        inputs=source_view,
        outputs=latent_view,
        examples=sample_images,
        allow_flagging='never',
        title='Image Encoder',
    )
if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    demo = app()
    demo.launch()