Spaces:
Build error
Build error
File size: 4,830 Bytes
d171496 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import gradio as gr
import os, glob
from functools import partial
import glob
import torch
from torch import nn
from PIL import Image
import numpy as np
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


class RuleCA(nn.Module):
    """Cellular automaton whose per-cell update is conditioned on rule channels.

    Each step applies four fixed 3x3 filters to every state channel, appends
    ``rule_channels`` conditioning channels, and feeds the result through two
    1x1 convolutions (``w1`` -> ReLU -> ``w2``) to obtain a per-cell update.
    The update is added to a random subset of cells.

    Args:
        hidden_n: Number of output channels of ``w1`` (the hidden width).
        rule_channels: Number of conditioning channels concatenated to the
            filter responses.
        zero_w2: When true, ``w2`` starts at zero so a fresh model leaves the
            state unchanged.
        device: Device on which the filters and both layers are placed.
    """

    def __init__(self, hidden_n=6, rule_channels=4, zero_w2=True, device=device):
        super().__init__()
        # The hard-coded filters: identity, a Sobel kernel, its transpose, and
        # a Laplacian-style kernel whose weights sum to zero.
        self.filters = torch.stack([
            torch.tensor([[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]]),
            torch.tensor([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]]),
            torch.tensor([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]]).T,
            torch.tensor([[1.0, 2.0, 1.0], [2.0, -12.0, 2.0], [1.0, 2.0, 1.0]]),
        ]).to(device)
        self.chn = 4  # number of state channels per cell
        self.rule_channels = rule_channels
        # Every state channel yields one response per filter; the rule
        # channels are appended to those responses before ``w1``.
        n_features = self.chn * self.filters.shape[0] + rule_channels
        self.w1 = nn.Conv2d(n_features, hidden_n, 1).to(device)
        self.relu = nn.ReLU()
        self.w2 = nn.Conv2d(hidden_n, self.chn, 1, bias=False).to(device)
        if zero_w2:
            self.w2.weight.data.zero_()
        self.device = device

    def perchannel_conv(self, x, filters):
        """Apply every filter to every channel of ``x`` independently.

        Args:
            x: Tensor of shape ``[b, ch, h, w]``.
            filters: Tensor of shape ``[filter_n, 3, 3]``; the padding of one
                cell per side is fixed for 3x3 kernels.

        Returns:
            Tensor of shape ``[b, ch * filter_n, h, w]`` in which the
            responses of one input channel are adjacent.
        """
        b, ch, h, w = x.shape
        # Fold channels into the batch axis so each one is filtered on its own.
        y = x.reshape(b * ch, 1, h, w)
        # Circular padding makes the grid wrap around at its edges.
        y = torch.nn.functional.pad(y, [1, 1, 1, 1], 'circular')
        y = torch.nn.functional.conv2d(y, filters[:, None])
        return y.reshape(b, -1, h, w)

    def forward(self, x, rule=0, update_rate=0.5):
        """Run one step with the same rule active in every cell.

        Args:
            x: State tensor of shape ``[b, ch, h, w]``.
            rule: Index (or any valid channel index expression) of the rule
                channel(s) set to 1; all other rule channels stay 0.
            update_rate: Probability that a given cell receives the update.

        Returns:
            The updated state, same shape as ``x``.
        """
        b, _, h, w = x.shape
        rule_grid = torch.zeros(b, self.rule_channels, h, w, device=self.device)
        rule_grid[:, rule] = 1
        return self.forward_w_rule_grid(x, rule_grid, update_rate=update_rate)

    def forward_w_rule_grid(self, x, rule_grid, update_rate=0.5):
        """Run one step with an explicit per-cell rule grid.

        Args:
            x: State tensor of shape ``[b, ch, h, w]``.
            rule_grid: Tensor of shape ``[b, rule_channels, h, w]`` that is
                concatenated to the filter responses.
            update_rate: Probability that a given cell receives the update.

        Returns:
            The updated state, same shape as ``x``.
        """
        y = self.perchannel_conv(x, self.filters)  # Apply the filters
        y = torch.cat([y, rule_grid], dim=1)
        y = self.w2(self.relu(self.w1(y)))  # pass the result through our 'brain'
        b, _, h, w = y.shape
        # floor(u + rate) is 1 with probability ``rate`` for u uniform in [0, 1).
        # The mask is sampled on the default device and then moved, so a
        # seeded run draws the same mask wherever the model lives.
        update_mask = (torch.rand(b, 1, h, w).to(self.device) + update_rate).floor()
        return x + y * update_mask

    def to_rgb(self, x):
        """Return the first three state channels shifted by +0.5.

        With this offset an all-zero state maps to 0.5 in every colour channel.
        """
        return x[..., :3, :, :] + 0.5

    def seed(self, n, sz=128):
        """Initializes n 'grids', size sz. In this case all 0s."""
        return torch.zeros(n, self.chn, sz, sz, device=self.device)
def to_frames(video_file):
    """Split ``video_file`` into numbered JPEG frames under ``guide_frames/``.

    Any existing ``guide_frames`` directory is removed first, so the directory
    only ever holds frames of the most recent video. Frames are written as
    ``guide_frames/0001.jpg``, ``guide_frames/0002.jpg``, ...

    Args:
        video_file: Path of the video to split.

    Raises:
        FileNotFoundError: If the ``ffmpeg`` executable cannot be found.
    """
    # Function-scope imports keep this block self-contained.
    import shutil
    import subprocess

    shutil.rmtree('guide_frames', ignore_errors=True)
    os.makedirs('guide_frames', exist_ok=True)
    # Pass the path as its own argv entry instead of interpolating it into a
    # shell string: names with spaces or shell metacharacters are then handled
    # literally and cannot inject commands.
    # A non-zero ffmpeg exit status is deliberately not raised here, which
    # keeps the best-effort behaviour of the previous implementation.
    subprocess.run(
        ['ffmpeg', '-i', str(video_file), 'guide_frames/%04d.jpg'],
        check=False,
    )
# Maps each preset name to (checkpoint file, hidden width of its RuleCA).
_PRESETS = {
    'Glowing Crystals': ('glowing_crystals.pt', 32),
    'Rainbow Diamonds': ('rainbow_diamonds.pt', 32),
    'Dark Diamonds': ('dark_diamonds.pt', 32),
    'Dragon Scales': ('dragon_scales.pt', 16),
}


def update(preset, enhance, video_file):
    """Render ``video.mp4`` by driving a preset cellular automaton with a video.

    The input video is split into frames; every frame, scaled to [0, 1] and
    resized so its longer side is 256 pixels, is used as the rule grid of the
    automaton. The automaton takes two steps per guide frame (one for the last
    frame) and the state after the first of them is saved as an output frame.
    The saved frames are then encoded at 24 fps.

    Args:
        preset: Name of the preset to load; must be a key of ``_PRESETS``.
        enhance: When truthy, the rule grid is rescaled as ``grid * 2 - 0.3``.
        video_file: Path of the guide video.

    Returns:
        The path of the written video, ``'video.mp4'``.

    Raises:
        ValueError: If ``preset`` is unknown or no video was given.
        RuntimeError: If no frame could be extracted from the video.
    """
    # Function-scope imports keep this block self-contained.
    import shutil
    import subprocess

    if preset not in _PRESETS:
        raise ValueError(
            f"Unknown preset {preset!r}; choose one of {sorted(_PRESETS)}"
        )
    if not video_file:
        raise ValueError("No input video was provided")

    # Load presets
    ca_fn, hidden_n = _PRESETS[preset]
    ca = RuleCA(hidden_n=hidden_n, rule_channels=3)
    ca.load_state_dict(torch.load(ca_fn, map_location=device))

    # Get video frames
    to_frames(video_file)
    n_frames = len(glob.glob('guide_frames/*.jpg'))
    if n_frames == 0:
        raise RuntimeError("No frames could be extracted from the input video")

    # Scale the longer side to 256 while keeping the aspect ratio; the shorter
    # side is clamped to 1 so extreme ratios cannot yield an empty grid.
    with Image.open('guide_frames/0001.jpg') as first_frame:
        vid_w, vid_h = first_frame.size
    if vid_w > vid_h:
        size = (256, max(1, int(256 * (vid_h / vid_w))))
    else:
        size = (max(1, int(256 * (vid_w / vid_h))), 256)

    # Starting grid (PIL sizes are (width, height), tensors are [.., h, w])
    x = torch.zeros(1, 4, size[1], size[0]).to(ca.device)

    shutil.rmtree('steps', ignore_errors=True)
    os.makedirs('steps', exist_ok=True)

    with torch.no_grad():
        for frame_idx in range(n_frames):
            # load frame (once, it is shared by both steps below)
            with Image.open(f'guide_frames/{frame_idx + 1:04}.jpg') as frame:
                # Force three channels so grayscale frames fit the rule grid.
                im = frame.convert('RGB').resize(size)

            # make rule grid
            rule_grid = torch.tensor(np.array(im) / 255).permute(2, 0, 1)
            rule_grid = rule_grid.unsqueeze(0).to(ca.device)
            if enhance:
                rule_grid = rule_grid * 2 - 0.3  # 'enhance' the effect
            rule_grid = rule_grid.float()

            # Apply the first update and save the result as an output frame
            x = ca.forward_w_rule_grid(x, rule_grid)
            # Index the batch axis explicitly: a bare squeeze() would also
            # drop a spatial axis of size 1.
            img = ca.to_rgb(x).cpu().clip(0, 1)[0].permute(1, 2, 0)
            img = Image.fromarray(np.array(img * 255).astype(np.uint8))
            img.save(f'steps/{frame_idx:05}.jpeg')

            # Second update with the same frame, except after the last frame
            if frame_idx < n_frames - 1:
                x = ca.forward_w_rule_grid(x, rule_grid)

    # Write output video from saved frames
    subprocess.run(
        ['ffmpeg', '-y', '-v', '0', '-framerate', '24',
         '-i', 'steps/%05d.jpeg', 'video.mp4'],
        check=False,
    )
    return 'video.mp4'
# Gradio UI: a preset selector and a rescale toggle, an input/output video
# pair, and a button that runs `update` on the chosen inputs.
# NOTE(review): the original indentation was lost, so the nesting of the rows
# below is reconstructed - confirm against the deployed layout.
with gr.Blocks() as demo:
    gr.Markdown("Start typing below and then click **Run** to see the output.")

    # First row: generation options.
    with gr.Row():
        preset = gr.Dropdown(
            choices=[
                'Glowing Crystals',
                'Rainbow Diamonds',
                'Dark Diamonds',
                'Dragon Scales',
            ],
            label='Preset',
        )
        enhance = gr.Checkbox(label='Rescale inputs (more extreme results)')

    # Second row: guide video in, rendered video out.
    with gr.Row():
        inp = gr.Video(
            format='mp4',
            source='upload',
            label="Input video (ideally <30s)",
        )
        out = gr.Video(label="Output")

    btn = gr.Button("Run")
    btn.click(fn=update, inputs=[preset, enhance, inp], outputs=out)

# Start the app with request queueing enabled.
demo.launch(enable_queue=True)