video_nca / app.py
johnowhitaker's picture
Create app.py
d171496
raw
history blame
4.83 kB
import glob
import os
import shutil
import subprocess
from functools import partial

import gradio as gr
import numpy as np
import torch
from PIL import Image
from torch import nn
# Prefer the first CUDA GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
class RuleCA(nn.Module):
    """Cellular automaton whose update is conditioned on extra 'rule' channels.

    One step applies four fixed 3x3 filters (identity, a Sobel-x kernel, its
    transpose, and a Laplacian-like kernel whose weights sum to zero) to every
    state channel with circular padding, concatenates the rule channels, and
    feeds the result through two 1x1 convolutions (``w1`` -> ReLU -> ``w2``)
    to obtain a per-cell state delta.  The delta is then applied to a random
    subset of cells only.

    Args:
        hidden_n: Number of hidden channels between ``w1`` and ``w2``.
        rule_channels: Number of conditioning channels concatenated to the
            filtered state before ``w1``.
        zero_w2: If true, the weights of ``w2`` start at zero, so an untrained
            model leaves the state unchanged.
        device: Device on which the filters and both convolutions are placed.
    """

    def __init__(self, hidden_n=6, rule_channels=4, zero_w2=True, device=device):
        super().__init__()

        # The hard-coded filters.  They are kept as a plain attribute (not a
        # parameter or buffer) so the state_dict only contains w1 and w2.
        identity = torch.tensor([[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]])
        sobel_x = torch.tensor([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]])
        laplacian = torch.tensor([[1.0, 2.0, 1.0], [2.0, -12, 2.0], [1.0, 2.0, 1.0]])
        self.filters = torch.stack([identity, sobel_x, sobel_x.T, laplacian]).to(device)

        self.chn = 4  # number of state channels
        self.rule_channels = rule_channels

        # Every state channel is seen through every filter, then the rule
        # channels are appended: chn * n_filters + rule_channels inputs.
        n_filters = self.filters.shape[0]
        self.w1 = nn.Conv2d(self.chn * n_filters + rule_channels, hidden_n, 1).to(device)
        self.relu = nn.ReLU()
        self.w2 = nn.Conv2d(hidden_n, self.chn, 1, bias=False).to(device)
        if zero_w2:
            self.w2.weight.data.zero_()
        self.device = device

    def perchannel_conv(self, x, filters):
        """Apply each filter to each channel of ``x`` independently.

        Args:
            x: Tensor of shape ``[b, ch, h, w]``.
            filters: Tensor of shape ``[filter_n, 3, 3]`` (the padding of one
                cell per side below matches 3x3 kernels).

        Returns:
            Tensor of shape ``[b, ch * filter_n, h, w]``.
        """
        b, ch, h, w = x.shape
        # Fold channels into the batch so one conv2d call filters them all.
        y = x.reshape(b * ch, 1, h, w)
        # Circular padding makes the grid wrap around at the edges.
        y = torch.nn.functional.pad(y, [1, 1, 1, 1], 'circular')
        y = torch.nn.functional.conv2d(y, filters[:, None])
        return y.reshape(b, -1, h, w)

    def forward(self, x, rule=0, update_rate=0.5):
        """Run one step with a single rule switched on for every cell.

        Args:
            x: State tensor of shape ``[b, ch, h, w]``.
            rule: Index of the rule channel set to 1 everywhere (all other
                rule channels are 0).
            update_rate: Offset added to a uniform random value before
                flooring; acts as the per-cell probability of being updated.

        Returns:
            The updated state, same shape as ``x``.
        """
        b, _, xsz, ysz = x.shape
        rule_grid = torch.zeros(b, self.rule_channels, xsz, ysz, device=self.device)
        rule_grid[:, rule] = 1
        return self.forward_w_rule_grid(x, rule_grid, update_rate=update_rate)

    def forward_w_rule_grid(self, x, rule_grid, update_rate=0.5):
        """Run one step with an explicit, spatially varying rule grid.

        Args:
            x: State tensor of shape ``[b, ch, h, w]``.
            rule_grid: Conditioning tensor of shape
                ``[b, rule_channels, h, w]``.  It is cast to the dtype and
                device of the filtered state, so e.g. float64 image data can
                be passed directly.
            update_rate: Offset added to a uniform random value before
                flooring; acts as the per-cell probability of being updated.

        Returns:
            The updated state, same shape as ``x``.
        """
        y = self.perchannel_conv(x, self.filters)  # Apply the filters
        rule_grid = rule_grid.to(device=y.device, dtype=y.dtype)
        y = torch.cat([y, rule_grid], dim=1)
        y = self.w2(self.relu(self.w1(y)))  # pass the result through our 'brain'
        b, _, h, w = y.shape
        # floor(U[0, 1) + update_rate) is 1 with probability update_rate and 0
        # otherwise; one mask value per cell is shared by all its channels.
        update_mask = (torch.rand(b, 1, h, w).to(self.device) + update_rate).floor()
        return x + y * update_mask

    def to_rgb(self, x):
        """Return the first three state channels shifted by +0.5.

        The result is not clipped; callers that need a displayable image
        should clip it to [0, 1] themselves.
        """
        return x[..., :3, :, :] + 0.5

    def seed(self, n, sz=128):
        """Initializes n 'grids', size sz. In this case all 0s."""
        return torch.zeros(n, self.chn, sz, sz, device=self.device)
def to_frames(video_file):
    """Extract the frames of ``video_file`` into ``guide_frames/`` as JPEGs.

    Any existing ``guide_frames`` directory is removed first, so afterwards it
    only holds frames of this video, written by ffmpeg with the numbering
    pattern ``%04d.jpg`` (``0001.jpg``, ``0002.jpg``, ...).

    Args:
        video_file: Path of the video to decode.

    Raises:
        ValueError: If ``video_file`` is empty or ``None``.
        RuntimeError: If ffmpeg reports an error and produced no frames.
    """
    if not video_file:
        raise ValueError("No input video was provided.")

    shutil.rmtree('guide_frames', ignore_errors=True)
    os.makedirs('guide_frames', exist_ok=True)

    # Pass an argument list instead of a shell string: the path may contain
    # spaces or shell metacharacters, which must never be interpreted.
    result = subprocess.run(
        ['ffmpeg', '-i', os.fspath(video_file), 'guide_frames/%04d.jpg']
    )
    # A non-zero exit with some frames on disk (e.g. a damaged tail) is still
    # usable; only a run that produced nothing is treated as a failure.
    if result.returncode != 0 and not glob.glob('guide_frames/*.jpg'):
        raise RuntimeError(
            f"ffmpeg exited with status {result.returncode} and extracted no frames "
            f"from {video_file!r}."
        )
# UI preset name -> (checkpoint file, hidden_n used to build the network
# before that checkpoint is loaded into it).
_PRESET_CHECKPOINTS = {
    'Glowing Crystals': ('glowing_crystals.pt', 32),
    'Rainbow Diamonds': ('rainbow_diamonds.pt', 32),
    'Dark Diamonds': ('dark_diamonds.pt', 32),
    'Dragon Scales': ('dragon_scales.pt', 16),
}


def _even_floor(value):
    """Round ``value`` down to an even integer, never returning less than 2."""
    return max(2, int(value) // 2 * 2)


def update(preset, enhance, video_file):
    """Render a video in which the guide video's colours drive the CA's rules.

    The input video is split into frames; each frame, scaled so that its
    longer side is 256 px, is used as the 3-channel rule grid of a ``RuleCA``.
    The automaton is advanced two steps per guide frame (one step for the last
    frame) and its state is saved as an image after the first of them.  The
    saved images are finally encoded into ``video.mp4`` at 24 fps.

    Working files go to ``guide_frames/``, ``steps/`` and ``video.mp4`` in the
    current directory and are overwritten on every call.

    Args:
        preset: One of the keys of ``_PRESET_CHECKPOINTS``.
        enhance: If truthy, rescale the rule grid (``* 2 - 0.3``) for a more
            extreme effect.
        video_file: Path of the uploaded guide video.

    Returns:
        The path of the rendered video, ``'video.mp4'``.

    Raises:
        ValueError: If the preset is unknown, no video was given, or no frame
            could be extracted from it.
        subprocess.CalledProcessError: If encoding the output video fails.
    """
    if preset not in _PRESET_CHECKPOINTS:
        raise ValueError(
            f"Unknown preset {preset!r}; choose one of {sorted(_PRESET_CHECKPOINTS)}."
        )
    if not video_file:
        raise ValueError("Please upload an input video first.")

    # Load presets
    ca_fn, hidden_n = _PRESET_CHECKPOINTS[preset]
    ca = RuleCA(hidden_n=hidden_n, rule_channels=3)
    ca.load_state_dict(torch.load(ca_fn, map_location=device))

    # Get video frames
    to_frames(video_file)
    n_frames = len(glob.glob('guide_frames/*.jpg'))
    if n_frames == 0:
        raise ValueError("No frames could be extracted from the input video.")

    # Scale so the longer side is 256 px.  The shorter side is floored to an
    # even number: 4:2:0 H.264 encoders typically reject odd dimensions, and a
    # very elongated video must not collapse to a zero-sized grid.
    with Image.open('guide_frames/0001.jpg') as first_frame:
        vid_w, vid_h = first_frame.size
    if vid_w > vid_h:
        size = (256, _even_floor(256 * (vid_h / vid_w)))
    else:
        size = (_even_floor(256 * (vid_w / vid_h)), 256)

    # Starting grid (PIL sizes are (width, height); tensors are [.., h, w]).
    x = torch.zeros(1, ca.chn, size[1], size[0], device=ca.device)

    shutil.rmtree('steps', ignore_errors=True)
    os.makedirs('steps', exist_ok=True)

    with torch.no_grad():
        for frame_idx in range(n_frames):
            # Load the guide frame; force RGB so there are always 3 channels.
            with Image.open(f'guide_frames/{frame_idx + 1:04}.jpg') as im:
                frame = im.convert('RGB').resize(size)

            # Make the rule grid: [1, 3, h, w] with values in [0, 1].
            rule_grid = torch.tensor(np.array(frame) / 255).permute(2, 0, 1).unsqueeze(0)
            rule_grid = rule_grid.to(ca.device)
            if enhance:
                rule_grid = rule_grid * 2 - 0.3  # 'enhance' the effect
            rule_grid = rule_grid.float()

            # First step for this frame, then save the visible state.
            x = ca.forward_w_rule_grid(x, rule_grid)
            rgb = ca.to_rgb(x)[0].detach().cpu().clip(0, 1).permute(1, 2, 0)
            img = Image.fromarray((rgb * 255).numpy().astype(np.uint8))
            img.save(f'steps/{frame_idx:05}.jpeg')

            # Second, unsaved step; the last frame only gets one step.
            if frame_idx < n_frames - 1:
                x = ca.forward_w_rule_grid(x, rule_grid)

    # Write output video from saved frames.  check=True makes an encoder
    # failure surface as an error instead of returning a stale video.mp4.
    subprocess.run(
        ['ffmpeg', '-y', '-v', '0', '-framerate', '24', '-i', 'steps/%05d.jpeg', 'video.mp4'],
        check=True,
    )
    return 'video.mp4'
# Gradio UI: a preset picker and an 'enhance' toggle on the first row, the
# input and output video players on the second, and a button that runs
# `update` with the three inputs and shows the returned file in `out`.
# NOTE(review): the nesting below is reconstructed; the Markdown text looks
# like a leftover from a template (nothing is typed here) - confirm wording.
with gr.Blocks() as demo:
    gr.Markdown("Start typing below and then click **Run** to see the output.")

    with gr.Row():
        preset = gr.Dropdown(
            ['Glowing Crystals', 'Rainbow Diamonds', 'Dark Diamonds', 'Dragon Scales'],
            label='Preset',
        )
        enhance = gr.Checkbox(label='Rescale inputs (more extreme results)')

    with gr.Row():
        inp = gr.Video(format='mp4', source='upload', label="Input video (ideally <30s)")
        out = gr.Video(label="Output")

    btn = gr.Button("Run")
    btn.click(fn=update, inputs=[preset, enhance, inp], outputs=out)

# Start the web server with request queueing enabled.
demo.launch(enable_queue=True)