File size: 4,830 Bytes
d171496
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import gradio as gr
import os, glob
from functools import partial
import glob
import torch
from torch import nn
from PIL import Image
import numpy as np

# Run on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


class RuleCA(nn.Module):
  """Cellular automaton with fixed perception filters and a learned update rule.

  One step convolves every state channel with four hard-coded 3x3 filters,
  appends `rule_channels` conditioning channels, passes the result through
  two 1x1 convolutions (`w1` -> ReLU -> `w2`) and adds the output to the
  state at a random subset of cells.

  Args:
    hidden_n: Number of channels in the hidden 1x1 convolution layer.
    rule_channels: Number of conditioning ('rule') channels appended to the
      filtered state before the update network.
    zero_w2: If True, zero the weights of the output layer so that a freshly
      built automaton leaves the state unchanged.
    device: Device the filters and layers are created on.
  """

  def __init__(self, hidden_n=6, rule_channels=4, zero_w2=True, device=device):
    super().__init__()
    # The hard-coded filters: identity, a Sobel kernel, its transpose, and a
    # Laplacian-style kernel whose weights sum to zero.
    filters = torch.stack([
        torch.tensor([[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]]),
        torch.tensor([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]]),
        torch.tensor([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]]).T,
        torch.tensor([[1.0, 2.0, 1.0], [2.0, -12, 2.0], [1.0, 2.0, 1.0]]),
    ]).to(device)
    # A non-persistent buffer follows `.to()` / `.cuda()` with the rest of the
    # module but stays out of the state dict, so checkpoints saved when
    # `filters` was a plain attribute still load with strict key checking.
    self.register_buffer('filters', filters, persistent=False)
    self.chn = 4
    self.rule_channels = rule_channels
    # Input to w1: every state channel seen through every filter, plus rules.
    self.w1 = nn.Conv2d(self.chn * filters.shape[0] + rule_channels, hidden_n, 1).to(device)
    self.relu = nn.ReLU()
    self.w2 = nn.Conv2d(hidden_n, self.chn, 1, bias=False).to(device)
    if zero_w2:
      nn.init.zeros_(self.w2.weight)
    self.device = device

  def perchannel_conv(self, x, filters):
    """Convolve every channel of `x` with every filter, wrapping at the edges.

    Args:
      x: Tensor of shape [b, ch, h, w].
      filters: Tensor of shape [filter_n, 3, 3].

    Returns:
      Tensor of shape [b, ch * filter_n, h, w]; the responses of one input
      channel to all filters are adjacent along dim 1.
    """
    b, ch, h, w = x.shape
    # Fold channels into the batch so each one is filtered independently.
    y = x.reshape(b * ch, 1, h, w)
    # One cell of circular padding keeps the 3x3 output at h x w.
    y = torch.nn.functional.pad(y, [1, 1, 1, 1], 'circular')
    y = torch.nn.functional.conv2d(y, filters[:, None])
    return y.reshape(b, -1, h, w)

  def forward(self, x, rule=0, update_rate=0.5):
    """Run one update step with a single rule active at every cell.

    Args:
      x: State tensor of shape [b, ch, h, w].
      rule: Index (or indices) of the rule channel(s) set to 1 everywhere;
        all other rule channels are 0.
      update_rate: Passed through to `forward_w_rule_grid`.

    Returns:
      The updated state, same shape as `x`.
    """
    b, _, h, w = x.shape
    rule_grid = torch.zeros(b, self.rule_channels, h, w, device=x.device)
    rule_grid[:, rule] = 1
    return self.forward_w_rule_grid(x, rule_grid, update_rate=update_rate)

  def forward_w_rule_grid(self, x, rule_grid, update_rate=0.5):
    """Run one update step with an explicit per-cell rule grid.

    Args:
      x: State tensor of shape [b, ch, h, w].
      rule_grid: Tensor of shape [b, rule_channels, h, w] concatenated to the
        filtered state along dim 1.
      update_rate: Added to a uniform [0, 1) sample per cell before flooring,
        so (for values in [0, 1]) each cell is updated with this probability.

    Returns:
      The updated state, same shape as `x`.
    """
    y = self.perchannel_conv(x, self.filters)  # Apply the filters
    y = torch.cat([y, rule_grid], dim=1)
    y = self.w2(self.relu(self.w1(y)))  # Pass the result through our 'brain'
    b, _, h, w = y.shape
    # Sampled on the CPU and then moved, so the random stream matches the
    # previous implementation; the same mask is shared by all channels.
    update_mask = (torch.rand(b, 1, h, w).to(y.device) + update_rate).floor()
    return x + y * update_mask

  def to_rgb(self, x):
    """Return the first three state channels shifted up by 0.5."""
    return x[..., :3, :, :] + 0.5

  def seed(self, n, sz=128):
    """Initializes n 'grids', size sz. In this case all 0s."""
    return torch.zeros(n, self.chn, sz, sz).to(self.device)

def to_frames(video_file):
  """Split `video_file` into JPEG frames under a fresh `guide_frames/` directory.

  ffmpeg writes the frames as `guide_frames/0001.jpg`, `guide_frames/0002.jpg`,
  and so on. Any existing `guide_frames/` directory is removed first.

  Args:
    video_file: Path of the input video.

  Raises:
    ValueError: If `video_file` is None or empty.
    FileNotFoundError: If the `ffmpeg` executable cannot be found.
  """
  import shutil
  import subprocess

  if not video_file:
    raise ValueError('No input video was provided.')

  # Start from an empty directory so frames of an earlier video never leak in.
  shutil.rmtree('guide_frames', ignore_errors=True)
  os.makedirs('guide_frames', exist_ok=True)
  # An argument list (no shell) passes the path verbatim, so spaces, quotes
  # or shell metacharacters in an uploaded file name cannot break or hijack
  # the command. The exit status is ignored, as before: a failed extraction
  # shows up to the caller as missing frames.
  subprocess.run(['ffmpeg', '-i', str(video_file), 'guide_frames/%04d.jpg'], check=False)

# Checkpoint file and hidden-layer width for each preset.
_PRESETS = {
  'Glowing Crystals': ('glowing_crystals.pt', 32),
  'Rainbow Diamonds': ('rainbow_diamonds.pt', 32),
  'Dark Diamonds': ('dark_diamonds.pt', 32),
  'Dragon Scales': ('dragon_scales.pt', 16),
}


def _ca_grid_size(vid_size, long_side=256):
  """Return (width, height) with the longer side scaled to `long_side`."""
  width, height = vid_size
  if width > height:
    size = (long_side, int(long_side * (height / width)))
  else:
    size = (int(long_side * (width / height)), long_side)
  # Round down to even, non-zero dimensions. NOTE(review): the H.264 encoder
  # ffmpeg normally picks for .mp4 rejects odd sizes with 4:2:0 input, and an
  # extreme aspect ratio could otherwise give 0 - confirm on the deployed build.
  return tuple(max(2, side - side % 2) for side in size)


def update(preset, enhance, video_file):
  """Render a video of the chosen automaton steered by the frames of `video_file`.

  Every guide frame is resized and fed to the automaton as its rule grid. The
  automaton takes two steps per frame (one for the last frame) and the state
  after the first of them is saved as an output frame, so the output has one
  frame per input frame.

  Args:
    preset: One of the keys of `_PRESETS`.
    enhance: If true, rescale the rule grid (`* 2 - 0.3`) for a stronger effect.
    video_file: Path of the input video.

  Returns:
    The path of the rendered video, `'video.mp4'`.

  Raises:
    ValueError: If `preset` is not a known preset or no video was given.
    RuntimeError: If no frames could be extracted from the video.
    subprocess.CalledProcessError: If ffmpeg fails to encode the output.
  """
  import shutil
  import subprocess

  if preset not in _PRESETS:
    raise ValueError(f'Unknown preset {preset!r}; choose one of {sorted(_PRESETS)}.')
  if not video_file:
    raise ValueError('No input video was provided.')

  # Load presets
  ca_fn, hidden_n = _PRESETS[preset]
  ca = RuleCA(hidden_n=hidden_n, rule_channels=3)
  ca.load_state_dict(torch.load(ca_fn, map_location=device))

  # Get video frames
  to_frames(video_file)
  n_frames = len(glob.glob('guide_frames/*.jpg'))
  if n_frames == 0:
    raise RuntimeError('No frames could be extracted from the input video.')

  with Image.open('guide_frames/0001.jpg') as first_frame:
    size = _ca_grid_size(first_frame.size)

  # Starting grid
  x = torch.zeros(1, 4, size[1], size[0]).to(ca.device)
  shutil.rmtree('steps', ignore_errors=True)
  os.makedirs('steps', exist_ok=True)

  with torch.no_grad():
    for k in range(n_frames):
      # Load the frame once; it drives both automaton steps below.
      with Image.open(f'guide_frames/{k + 1:04}.jpg') as frame:
        # Force three channels so the grid always matches rule_channels=3.
        im = frame.convert('RGB').resize(size)

      # Make rule grid
      rule_grid = torch.tensor(np.array(im) / 255).permute(2, 0, 1).unsqueeze(0).to(ca.device)
      if enhance:
        rule_grid = rule_grid * 2 - 0.3  # 'Enhance' the effect
      rule_grid = rule_grid.float()

      # First step for this frame: its result becomes output frame k.
      x = ca.forward_w_rule_grid(x, rule_grid)
      img = ca.to_rgb(x).cpu().clip(0, 1).squeeze(0).permute(1, 2, 0)
      img = Image.fromarray((img * 255).numpy().astype(np.uint8))
      img.save(f'steps/{k:05}.jpeg')

      # Second, unsaved step; skipped after the last frame since nothing
      # would consume its result.
      if k < n_frames - 1:
        x = ca.forward_w_rule_grid(x, rule_grid)

  # Write output video from saved frames. check=True so a failed encode is
  # reported instead of returning a missing or stale video.mp4.
  subprocess.run(
      ['ffmpeg', '-y', '-v', '0', '-framerate', '24', '-i', 'steps/%05d.jpeg', 'video.mp4'],
      check=True)
  return 'video.mp4'


# Gradio UI: a preset picker and an 'enhance' toggle on the first row, the
# input and output videos on the second, and a button that runs `update`.
demo = gr.Blocks()

with demo:
    gr.Markdown("Pick a preset, upload a video and then click **Run** to see the output.")
    with gr.Row():
        # A default selection means clicking Run straight away never passes
        # None as the preset.
        preset = gr.Dropdown(['Glowing Crystals', 'Rainbow Diamonds', 'Dark Diamonds', 'Dragon Scales'],
                             value='Glowing Crystals', label='Preset')
        enhance = gr.Checkbox(label='Rescale inputs (more extreme results)')
    with gr.Row():
        inp = gr.Video(format='mp4', source='upload', label="Input video (ideally <30s)")
        out = gr.Video(label="Output")
    btn = gr.Button("Run")
    # `update` receives (preset, enhance, video path) and returns the path of
    # the rendered video shown in `out`.
    btn.click(fn=update, inputs=[preset, enhance, inp], outputs=out)

demo.launch(enable_queue=True)