import gradio as gr
import os
import glob
import torch
from torch import nn
from PIL import Image
import numpy as np

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
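
# RuleCA: a minimal Neural Cellular Automaton. Each cell perceives its
# neighbourhood through four fixed 3x3 filters, concatenates extra 'rule'
# channels, and passes the result through a small two-layer 1x1-conv network
# to compute its state update. Different rule channels select different
# learned textures from the same network.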

class RuleCA(nn.Module):
  def __init__(self, hidden_n=6, rule_channels=4, zero_w2=True, device=device):
    super().__init__()
    # The hard-coded perception filters: identity, Sobel x, Sobel y, Laplacian
    self.filters = torch.stack([torch.tensor([[0.0,0.0,0.0],[0.0,1.0,0.0],[0.0,0.0,0.0]]),
          torch.tensor([[-1.0,0.0,1.0],[-2.0,0.0,2.0],[-1.0,0.0,1.0]]),
          torch.tensor([[-1.0,0.0,1.0],[-2.0,0.0,2.0],[-1.0,0.0,1.0]]).T,
          torch.tensor([[1.0,2.0,1.0],[2.0,-12,2.0],[1.0,2.0,1.0]])]).to(device)
    self.chn = 4
    self.rule_channels = rule_channels
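    # Input to w1: 4 state channels x 4 filters = 16 features, plus the rule planes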
    self.w1 = nn.Conv2d(4*4+rule_channels, hidden_n, 1).to(device)
    self.relu = nn.ReLU()
    self.w2 = nn.Conv2d(hidden_n, 4, 1, bias=False).to(device)
    if zero_w2:
      self.w2.weight.data.zero_()
    self.device = device

  def perchannel_conv(self, x, filters):
    '''filters: [filter_n, h, w]'''
    b, ch, h, w = x.shape
    y = x.reshape(b*ch, 1, h, w)
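    # One conv2d call applies each filter to every channel independently;
    # 'circular' padding makes the grid toroidal, so patterns wrap at the edges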
    y = torch.nn.functional.pad(y, [1, 1, 1, 1], 'circular')
    y = torch.nn.functional.conv2d(y, filters[:,None])
    return y.reshape(b, -1, h, w)
    
  def forward(self, x, rule=0, update_rate=0.5):
    b, ch, xsz, ysz = x.shape
    # Broadcast the chosen rule index as a one-hot plane across the whole grid
    rule_grid = torch.zeros(b, self.rule_channels, xsz, ysz).to(self.device)
    rule_grid[:,rule] = 1
    y = self.perchannel_conv(x, self.filters) # Apply the filters
    y = torch.cat([y, rule_grid], dim=1)
    y = self.w2(self.relu(self.w1(y))) # Pass the result through our little 'brain'
    b, c, h, w = y.shape
    # Stochastic update: each cell fires with probability update_rate
    update_mask = (torch.rand(b, 1, h, w).to(self.device)+update_rate).floor()
    return x+y*update_mask

  def forward_w_rule_grid(self, x, rule_grid, update_rate=0.5):
    # As forward(), but the rule can vary per pixel (here it's driven by video frames)
    y = self.perchannel_conv(x, self.filters) # Apply the filters
    y = torch.cat([y, rule_grid], dim=1)
    y = self.w2(self.relu(self.w1(y))) # Pass the result through our little 'brain'
    b, c, h, w = y.shape
    update_mask = (torch.rand(b, 1, h, w).to(self.device)+update_rate).floor()
    return x+y*update_mask
  
  def to_rgb(self, x):
    # Treat the first three state channels as RGB, shifted by 0.5 into roughly [0, 1]
    return x[...,:3,:,:]+0.5

  def seed(self, n, sz=128):
    """Initializes n 'grids', size sz. In this case all 0s."""
    return torch.zeros(n, self.chn, sz, sz).to(self.device)
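
# A quick usage sketch (hypothetical, not executed by this app): load a trained
# checkpoint, seed a grid, run a few steps with a fixed rule, grab an RGB frame.
#   ca = RuleCA(hidden_n=32, rule_channels=3)
#   ca.load_state_dict(torch.load('glowing_crystals.pt', map_location=device))
#   x = ca.seed(1, sz=64)
#   with torch.no_grad():
#       for _ in range(32):
#           x = ca(x, rule=0)
#   frame = ca.to_rgb(x)  # shape (1, 3, 64, 64), roughly in [0, 1]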

def to_frames(video_file):
  """Extract every frame of the input video into guide_frames/ as 0001.jpg, 0002.jpg, ..."""
  os.system('rm -rf guide_frames; mkdir guide_frames')
  os.system(f'ffmpeg -i "{video_file}" guide_frames/%04d.jpg')

def update(preset, enhance, scale2x, video_file):

  # Load the chosen preset: each maps to a checkpoint file (expected alongside
  # this script); 'Dragon Scales' was trained with a smaller hidden layer
  presets = {'Glowing Crystals': (32, 'glowing_crystals.pt'),
             'Rainbow Diamonds': (32, 'rainbow_diamonds.pt'),
             'Dark Diamonds': (32, 'dark_diamonds.pt'),
             'Dragon Scales': (16, 'dragon_scales.pt')}
  hidden_n, ca_fn = presets[preset]
  ca = RuleCA(hidden_n=hidden_n, rule_channels=3)
  ca.load_state_dict(torch.load(ca_fn, map_location=device))

  # Get video frames
  to_frames(video_file)

  # Cap the larger side at 256px, preserving aspect ratio (PIL sizes are (width, height))
  vid_size = Image.open('guide_frames/0001.jpg').size
  if vid_size[0] > vid_size[1]:
    size = (256, int(256*(vid_size[1]/vid_size[0])))
  else:
    size = (int(256*(vid_size[0]/vid_size[1])), 256)
  if scale2x:
    size = (size[0]*2, size[1]*2)
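  # e.g. a 1920x1080 input becomes (256, 144), or (512, 288) with scale2x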

  # Starting grid: all zeros, shape (batch, channels, height, width)
  x = torch.zeros(1, 4, size[1], size[0]).to(ca.device)
  os.system("rm -rf steps; mkdir steps")
  for i in range(2*len(glob.glob('guide_frames/*.jpg'))-1): # Two CA steps per guide frame
    # Load the guide frame for this step
    im = Image.open(f'guide_frames/{i//2+1:04}.jpg').convert('RGB').resize(size)

    # Make the rule grid from the frame's RGB values (scaled to [0, 1])
    rule_grid = torch.tensor(np.array(im)/255).permute(2, 0, 1).unsqueeze(0).to(ca.device)
    if enhance:
      rule_grid = rule_grid * 2 - 0.3 # Rescale beyond [0, 1] for a more extreme effect
    
    # Apply the update, saving every other step as an output frame
    with torch.no_grad():
      x = ca.forward_w_rule_grid(x, rule_grid.float())
      if i%2==0:
        img = ca.to_rgb(x).detach().cpu().clip(0, 1).squeeze().permute(1, 2, 0)
        img = Image.fromarray(np.array(img*255).astype(np.uint8))
        img.save(f'steps/{i//2:05}.jpeg')

  # Write output video from saved frames
  os.system("ffmpeg -y -v 0 -framerate 24 -i steps/%05d.jpeg video.mp4")
  return 'video.mp4'
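
# Hypothetical direct invocation (bypassing the UI), assuming a local clip.mp4:
#   update('Glowing Crystals', enhance=False, scale2x=False, video_file='clip.mp4')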


demo = gr.Blocks()

with demo:
    gr.Markdown("Choose a preset below, upload a video and then click **Run** to see the output. Read [this report](https://wandb.ai/johnowhitaker/nca/reports/Fun-with-Neural-Cellular-Automata--VmlldzoyMDQ5Mjg0) for background on this project, or check out my [AI art course](https://github.com/johnowhitaker/aiaiart) for an in-depth lesson on Neural Cellular Automata like this.")
    with gr.Row():
        preset = gr.Dropdown(['Glowing Crystals', 'Rainbow Diamonds', 'Dark Diamonds', 'Dragon Scales'], label='Preset')
        with gr.Column():
            enhance = gr.Checkbox(label='Rescale inputs (more extreme results)')
            scale2x = gr.Checkbox(label='Larger output (slower)')
    with gr.Row():
        inp = gr.Video(format='mp4', source='upload', label="Input video (ideally <30s)")
        out = gr.Video(label="Output")
    btn = gr.Button("Run")
    btn.click(fn=update, inputs=[preset, enhance, scale2x, inp], outputs=out)
    
    with gr.Row():
        gr.Markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=gradio-blocks_video_nca)")

demo.launch(enable_queue=True)