|
import gradio as gr |
|
import os |
|
import json |
|
import torch |
|
import torch.nn as nn |
|
import diffusers |
|
from einops import rearrange |
|
from PIL import Image |
|
from omegaconf import OmegaConf |
|
from tqdm import tqdm |
|
import cv2 |
|
|
|
# Recommended number of diffusion sampling steps (the UI text recommends 64).
# Target frame count for a generated video; when it exceeds the video model's
# num_frames, get_vid stitches overlapping temporal windows to reach it.
# Frame rate used when writing the output MP4.
NUM_STEPS, FRAMES, FPS = 64, 192, 32
|
|
|
# Custom CSS passed to gr.Blocks(css=...) when the UI is built.
# - .contain: fixes the page width to 1000px and centres it. NOTE(review):
#   presumably Gradio's outer container class — confirm for the installed version.
# - .svelte-1pijsyv: looks like a hashed, build-generated Svelte class name; such
#   names can silently stop matching after a Gradio upgrade — TODO confirm which
#   element it is meant to size to 448px.
# - .arrow / .arrow-tail / .arrow-head: draw the decorative arrow in myhtml
#   (black by default, white when the browser prefers a dark colour scheme).
mycss = """
.contain {
    width: 1000px;
    margin: 0 auto;
}

.svelte-1pijsyv {
    width: 448px;
}

.arrow {
    display: flex;
    align-items: center;
    margin: 7px 0;
}

.arrow-tail {
    width: 270px;
    height: 50px;
    background-color: black;
    transition: background-color 0.3s;
}

.arrow-head {
    width: 0;
    height: 0;
    border-top: 70px solid transparent;
    border-bottom: 70px solid transparent;
    border-left: 120px solid black;
    transition: border-left-color 0.3s;
}

@media (prefers-color-scheme: dark) {
    .arrow-tail {
        background-color: white;
    }
    .arrow-head {
        border-left-color: white;
    }
}

"""
|
|
|
# Decorative arrow markup (a rectangular tail plus a CSS-border triangle head),
# styled by the .arrow* rules in mycss and rendered via gr.HTML in the middle
# column, between the image column and the video column.
myhtml = """
<div class="arrow">
    <div class="arrow-tail"></div>
    <div class="arrow-head"></div>
</div>
"""
|
|
|
# JavaScript passed to gr.Blocks(js=...). NOTE(review): presumably Gradio calls
# this function once on page load — confirm for the installed Gradio version.
# setLoopTrue sets `loop = true` on the first <video> element in the document
# (if there is one) and re-schedules itself every 3000 ms, with the intent that
# the generated video keeps looping after it is (re)rendered.
# NOTE(review): only the first <video> in document order is touched — confirm it
# is always the output video.
myjs = """
function setLoopTrue() {
    let videos = document.getElementsByTagName('video');
    if (videos.length > 0) {
        document.getElementsByTagName('video')[0].loop = true;
    }
    setTimeout(setLoopTrue, 3000);
}
"""
|
|
|
def load_model(path):
    """Load a pretrained model from a directory that contains a ``config.json``.

    The ``_class_name`` entry of ``config.json`` names the model class. The
    class is looked up first in the ``diffusers`` package, then among this
    module's globals. It is instantiated with ``klass.from_pretrained(path)``.

    Args:
        path: Directory holding ``config.json`` and whatever files the class's
            ``from_pretrained`` reads.

    Returns:
        ``(model, config)`` where ``config`` is the parsed ``config.json``.

    Raises:
        FileNotFoundError: if ``config.json`` does not exist under ``path``.
        KeyError: if ``config.json`` has no ``_class_name`` entry.
        ValueError: if no class with that name is found.
        TypeError: if the class has no callable ``from_pretrained``.
    """
    json_path = os.path.join(path, "config.json")
    # Explicit checks rather than ``assert``: asserts are stripped under ``python -O``.
    if not os.path.exists(json_path):
        raise FileNotFoundError(f"Could not find config.json at {json_path}")
    with open(json_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    try:
        klass_name = config["_class_name"]
    except KeyError:
        raise KeyError(f"config.json at {json_path} has no '_class_name' entry") from None

    # Prefer a diffusers class; fall back to a class defined in this module.
    klass = getattr(diffusers, klass_name, None)
    if klass is None:
        klass = globals().get(klass_name, None)
    if klass is None:
        raise ValueError(f"Could not find class {klass_name} in diffusers or global scope.")
    if not callable(getattr(klass, "from_pretrained", None)):
        raise TypeError(f"Class {klass_name} does not support 'from_pretrained'.")

    model = klass.from_pretrained(path)
    return model, config
|
|
|
def load_scheduler(config):
    """Instantiate the diffusers noise scheduler described by ``config.noise_scheduler``.

    The ``noise_scheduler`` section is converted to a plain container. Its
    ``_class_name`` entry selects the ``diffusers`` class, and every remaining
    entry is passed to that class as a keyword argument.

    Raises:
        KeyError: if the section has no ``_class_name`` entry.
        ValueError: if ``diffusers`` has no class with that name.
    """
    scheduler_kwargs = OmegaConf.to_container(config.noise_scheduler)
    scheduler_klass_name = scheduler_kwargs.pop("_class_name")
    scheduler_klass = getattr(diffusers, scheduler_klass_name, None)
    if scheduler_klass is None:
        # Without this check the None would be called, failing with an opaque
        # "'NoneType' object is not callable".
        raise ValueError(f"Could not find scheduler class {scheduler_klass_name} in diffusers.")
    return scheduler_klass(**scheduler_kwargs)
|
|
|
def padf(tensor, mult=3):
    """Replicate-pad the last two dims symmetrically toward a multiple of ``2**mult``.

    The pad amount is derived from the last dimension only, so square spatial
    dims are assumed (NOTE(review): confirm callers never pass non-square input).
    The same amount is added on all four spatial sides. The third-from-last
    dimension gets zero padding.

    Caveats of the arithmetic (left unchanged because altering them would
    change what the models receive):
      * a size that is already a multiple of ``2**mult`` still gets
        ``2**(mult - 1)`` extra elements per side;
      * an odd shortfall loses one element to the floor division, so the
        result is then not an exact multiple;
      * the per-side amount can be 0.

    Returns:
        ``(padded, per_side)``; hand ``per_side`` to ``unpadf`` to crop back.
    """
    block = 2 ** mult
    shortfall = block - tensor.shape[-1] % block
    per_side = shortfall // 2
    # F.pad lists (left, right) pairs starting from the last dim: W, H, then dim -3.
    spec = (per_side,) * 4 + (0, 0)
    return nn.functional.pad(tensor, spec, mode="replicate"), per_side
|
|
|
def unpadf(tensor, pad=1):
    """Crop ``pad`` elements from both ends of each of the last two dimensions.

    This undoes the symmetric spatial padding applied by ``padf``, whose
    per-side amount is passed back here as ``pad``.

    Args:
        tensor: Tensor with at least two dimensions; only the last two are cropped.
        pad: Non-negative per-side amount. ``pad=0`` returns the tensor unchanged.
    """
    height, width = tensor.shape[-2], tensor.shape[-1]
    # Explicit end indices: the shorter ``pad:-pad`` form becomes ``0:0`` when
    # pad == 0 (which padf can return) and would yield an empty tensor.
    return tensor[..., pad:height - pad, pad:width - pad]
|
|
|
def pad_reshape(tensor, mult=3):
    """Spatially pad a ``(b, c, t, h, w)`` tensor, then move time before channels.

    Returns:
        ``(reshaped, pad)``: the padded tensor laid out as ``(b, t, c, h, w)``,
        and the per-side pad amount that ``unpad_reshape`` needs to undo it.
    """
    padded_tensor, amount = padf(tensor, mult=mult)
    return rearrange(padded_tensor, "b c t h w -> b t c h w"), amount
|
|
|
def unpad_reshape(tensor, pad=1):
    """Inverse of ``pad_reshape``: move channels back before time, then crop.

    Takes a ``(b, t, c, h, w)`` tensor and returns ``(b, c, t, h', w')`` with
    ``pad`` elements removed from each side of the last two dimensions.
    """
    return unpadf(rearrange(tensor, "b t c h w -> b c t h w"), pad=pad)
|
|
|
class Context:
    """Owns the loaded models and scheduler, and runs both generation stages.

    ``get_img`` samples a single image with the ``lidm`` model. ``get_vid``
    samples a video with the ``lvdm`` model, conditioned on that image's latent
    and on an ejection-fraction value.
    """

    def __init__(self, lidm_path, lvdm_path, vae_path, config_path):
        """Load the three models and the config, then move the models to the device.

        Args:
            lidm_path: Directory of the image model (see ``load_model``).
            lvdm_path: Directory of the video model.
            vae_path: Directory of the VAE used to decode latents.
            config_path: OmegaConf file; its ``noise_scheduler`` section builds
                the scheduler and ``unet.addition_time_embed_dim`` is read by
                ``get_vid``.
        """
        self.lidm, self.lidm_config = load_model(lidm_path)
        self.lvdm, self.lvdm_config = load_model(lvdm_path)
        self.vae, self.vae_config = load_model(vae_path)
        self.config = OmegaConf.load(config_path)
        self.models = [self.lidm, self.lvdm, self.vae]
        self.scheduler = load_scheduler(self.config)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = torch.float32

        for model in self.models:
            model.to(self.device, dtype=self.dtype)
            model.eval()

        print("Models loaded")

    def _autocast(self):
        """Return a mixed-precision context that is active only on CUDA.

        On GPU this is the same as ``torch.autocast("cuda")``. On the CPU
        fallback device it is a no-op, so sampling stays in float32.
        """
        return torch.autocast(self.device.type, enabled=self.device.type == "cuda")

    @staticmethod
    def _window_layout(T, target_frames):
        """Plan overlapping temporal windows for sampling a long video.

        The video model denoises ``T`` frames per window. If ``target_frames``
        exceeds ``T``, consecutive windows overlap by ``T // 2`` frames and
        advance by ``T - T // 2`` frames.

        Returns:
            ``(OT, TR, NT)``:
              * ``OT``: overlap between consecutive windows;
              * ``TR``: number of windows;
              * ``NT``: total frames generated (``>= target_frames``).
            If ``target_frames <= T``, a single ``T``-frame window is used and ``OT == 0``.
        """
        if target_frames <= T:
            return 0, 1, T
        OT = T // 2
        stride = T - OT
        # The stride is derived from T, not a fixed 32, so the layout holds for
        # any model frame count. Ceil division guarantees
        # NT = stride * TR + OT >= target_frames.
        TR = (target_frames - OT + stride - 1) // stride
        return OT, TR, stride * TR + OT

    @staticmethod
    def _write_video(frames, path):
        """Write ``(T, H, W, 3)`` uint8 frames to ``path`` as MP4 at ``FPS``; return ``path``."""
        height, width = frames.shape[1:3]
        # Take the frame size from the data: VideoWriter expects every frame to
        # match the size it was opened with.
        writer = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*"mp4v"), FPS, (width, height))
        try:
            for frame in frames:
                writer.write(frame)
        finally:
            writer.release()
        return path

    def get_img(self, steps):
        """Sample one image with the ``lidm`` model.

        Args:
            steps: Number of scheduler sampling steps. It is cast to ``int`` in
                case the UI delivers a float.

        Returns:
            ``(img, latents)``:
              * ``img``: a PIL image whose decoded channels are averaged to
                grayscale and replicated to 3 channels;
              * ``latents``: the final latents divided by the VAE scaling factor
                (the tensor that was decoded), which ``get_vid`` takes as
                ``ref_latent``.
        """
        print("generating image")
        self.scheduler.set_timesteps(int(steps))
        size = self.lidm_config["sample_size"]
        shape = (1, self.lidm_config["in_channels"], size, size)

        with torch.no_grad():
            latents = torch.randn(shape, device=self.device, dtype=self.dtype)
            with self._autocast():
                for t in tqdm(self.scheduler.timesteps):
                    model_input = self.scheduler.scale_model_input(latents, timestep=t)
                    # padf pads the spatial dims toward a multiple of 2**3
                    # (presumably required by the UNet); crop the prediction back.
                    model_input, padding = padf(model_input, mult=3)
                    noise_pred = self.lidm(model_input, timestep=t).sample
                    noise_pred = unpadf(noise_pred, pad=padding)
                    latents = self.scheduler.step(noise_pred, t, latents).prev_sample

            latents = latents / self.vae.config.scaling_factor
            img = self.vae.decode(latents).sample

        img = (img + 1) * 128
        img = img.mean(1, keepdim=True).repeat(1, 3, 1, 1)
        img = img.clamp(0, 255).to(torch.uint8).cpu().numpy()
        img = Image.fromarray(img[0].transpose(1, 2, 0))
        return img, latents

    def get_vid(self, lvef: int, ref_latent: torch.Tensor, steps: int):
        """Sample a video conditioned on a reference latent and an ejection fraction.

        Args:
            lvef: Ejection fraction in percent (the UI slider spans 0-100).
                It is divided by 100 before being fed to the model.
            ref_latent: Latent returned by ``get_img``. It was divided by the
                VAE scaling factor there and is multiplied back here.
            steps: Number of scheduler sampling steps.

        Returns:
            Path of the written MP4 file (``"output.mp4"`` in the working directory).

        Raises:
            gr.Error: if no reference latent is available yet.
        """
        print("generating video")
        if ref_latent is None:
            # The latent State is only filled by get_img; without it there is
            # nothing to condition on.
            raise gr.Error("Please generate an image first.")
        self.scheduler.set_timesteps(int(steps))

        size = self.lvdm_config["sample_size"]
        B, C, T, H, W = 1, 4, self.lvdm_config["num_frames"], size, size
        OT, TR, NT = self._window_layout(T, FRAMES)
        stride = T - OT

        with torch.no_grad():
            # One conditioning row per window (batch dim B * TR with B == 1).
            ef = torch.full((TR, 1, 1), lvef / 100, device=self.device, dtype=self.dtype)
            forward_kwargs = {
                "added_time_ids": torch.zeros(
                    (B * TR, self.config.unet.addition_time_embed_dim),
                    device=self.device,
                    dtype=self.dtype,
                ),
                "encoder_hidden_states": ef,
            }

            # Repeat the reference latent along time; it is concatenated to the
            # noisy latents' channels at every step.
            cond = ref_latent * self.vae.config.scaling_factor
            cond = cond[:, :, None].repeat(1, 1, NT, 1, 1).to(self.device, dtype=self.dtype)

            latents = torch.randn((B, C, NT, H, W), device=self.device, dtype=self.dtype)

            with self._autocast():
                for t in tqdm(self.scheduler.timesteps):
                    forward_kwargs["timestep"] = t
                    model_input = self.scheduler.scale_model_input(latents, timestep=t)
                    model_input = torch.cat((model_input, cond), dim=1)
                    # -> (b, t, c, h, w), spatially padded
                    model_input, padding = pad_reshape(model_input, mult=3)

                    # Stack the overlapping temporal windows along the batch dim
                    # so all of them are denoised in a single forward pass.
                    windows = torch.cat(
                        [model_input[:, r * stride:r * stride + T] for r in range(TR)], dim=0
                    )
                    window_preds = torch.chunk(self.lvdm(windows, **forward_kwargs).sample, TR, dim=0)
                    # Stitch along time. Keep the first window whole; later
                    # windows drop their first OT frames, which the previous
                    # window already produced.
                    noise_pred = torch.cat(
                        [window_preds[0]] + [pred[:, OT:] for pred in window_preds[1:]], dim=1
                    )
                    noise_pred = unpad_reshape(noise_pred, pad=padding)
                    latents = self.scheduler.step(noise_pred, t, latents).prev_sample

            print("done generating noise")

            latents = latents / self.vae.config.scaling_factor
            frames = rearrange(latents, "b c t h w -> (b t) c h w")
            # Decode 16 frames at a time (presumably to bound peak memory),
            # on whichever device the models live on.
            decoded = [
                self.vae.decode(chunk.to(self.device, dtype=torch.float32)).sample.cpu()
                for chunk in torch.split(frames, 16, dim=0)
            ]

        video = torch.cat(decoded, dim=0)
        video = rearrange(video, "(b t) c h w -> b t h w c", b=B)[0]
        video = (video + 1) * 128
        video = video.mean(-1, keepdim=True).repeat(1, 1, 1, 3)
        video = video.clamp(0, 255).to(torch.uint8).cpu().numpy()
        return self._write_video(video, "output.mp4")
|
|
|
|
|
# Build the shared model context once, when the script runs; the UI callbacks
# (get_img / get_vid) are bound to this single instance.
ctx = Context(
    config_path="resources/config.yaml",
    lidm_path="resources/lidm",
    lvdm_path="resources/lvdm",
    vae_path="resources/ivae",
)
|
|
|
# Gradio UI: image generation on the left, conditioning controls in the middle,
# video generation on the right, precomputed examples underneath.
with gr.Blocks(css=mycss, js=myjs) as demo:
    with gr.Row():
        gr.Markdown(f"""
# EchoNet-Synthetic: Privacy-preserving Video Generation for Safe Medical Data Sharing
This demo is attached to a paper under review at MICCAI 2024, and is targeted at the reviewers of that paper.

1. Start by generating an image using the "Generate Image" button. This will generate a random image, similar to the EchoNet-Dynamic dataset.
2. Adjust the "Ejection Fraction Score" slider to choose the ejection fraction of the video that will be generated from that image.
3. Generate a video using the "Generate Video" button. This will generate a video from the generated image, with the ejection fraction score you chose.

We leave the ejection fraction input completely open, so you can see how the video generation changes with different ejection fraction scores, even unrealistic ones. The normal ejection fraction range is 50-75%.<br>
We recommend {NUM_STEPS} steps for ideal image quality, but you can adjust this to see how it affects the video generation.<br>
The examples below show precomputed results; "Generate Video" always starts from the last image produced with "Generate Image".
""")

    with gr.Row():
        with gr.Column():
            img = gr.Image(interactive=False, label="Generated Image")
            img_btn = gr.Button("Generate Image")

        with gr.Column():
            gr.HTML(myhtml)
            efslider = gr.Slider(minimum=0, maximum=100, value=65, step=1, label="Ejection Fraction Score (%)")
            dsslider = gr.Slider(minimum=1, maximum=999, value=NUM_STEPS, step=1, label="Sampling Steps")

        with gr.Column():
            vid = gr.Video(interactive=False, autoplay=True, label="Generated Video")
            vid_btn = gr.Button("Generate Video")

    with gr.Row():
        # Selecting an example only fills the components below; it does not
        # update ltt_img, so it cannot seed "Generate Video".
        gr.Examples(
            examples=[
                [f"resources/examples/ef{ef}.png", f"resources/examples/ef{ef}.mp4", ef, NUM_STEPS]
                for ef in range(20, 100, 10)
            ],
            inputs=[img, vid, efslider, dsslider],
            outputs=None,
            fn=None,
            cache_examples=False,
        )

    # Latent of the most recently generated image; written by get_img, read by get_vid.
    ltt_img = gr.State()

    img_btn.click(fn=ctx.get_img, inputs=[dsslider], outputs=[img, ltt_img])
    vid_btn.click(fn=ctx.get_vid, inputs=[efslider, ltt_img, dsslider], outputs=[vid])

demo.launch(share=False)