"""
Based on the implementation from:
https://huggingface.co/spaces/fffiloni/lama-video-watermark-remover/tree/main
Modules were adapted by Hans Brouwer to only support the final configuration of the model uploaded here:
https://huggingface.co/akhaliq/lama
Apache License 2.0: https://github.com/advimman/lama/blob/main/LICENSE
@article{suvorov2021resolution,
title={Resolution-robust Large Mask Inpainting with Fourier Convolutions},
author={Suvorov, Roman and Logacheva, Elizaveta and Mashikhin, Anton and Remizova, Anastasia and Ashukha, Arsenii and Silvestrov, Aleksei and Kong, Naejin and Goka, Harshith and Park, Kiwoong and Lempitsky, Victor},
journal={arXiv preprint arXiv:2109.07161},
year={2021}
}
"""
import os
import sys
from urllib.request import urlretrieve
import torch
from einops import rearrange
from PIL import Image
from torch import nn
from torch.nn import functional as F
from torchvision.transforms.functional import to_tensor
from tqdm import tqdm
from train import export_to_video
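
# Pretrained LaMa checkpoint hosted on the akhaliq/lama Hugging Face repo,
# downloaded to LAMA_PATH on first use.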
LAMA_URL = "https://huggingface.co/akhaliq/lama/resolve/main/best.ckpt"
LAMA_PATH = "models/lama.ckpt"
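
# tqdm reporthook for urlretrieve: translates its (block count, block size,
# total size) callbacks into progress-bar updates.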
def download_progress(t):
last_b = [0]
def update_to(b=1, bsize=1, tsize=None):
if tsize is not None:
t.total = tsize
t.update((b - last_b[0]) * bsize)
last_b[0] = b
return update_to
def download(url, path):
with tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=path) as t:
urlretrieve(url, filename=path, reporthook=download_progress(t), data=None)
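
# FourierUnit is the core of the Fast Fourier Convolution: it takes features
# into the frequency domain with a real 2D FFT, mixes the real/imaginary parts
# with a 1x1 conv + BN + ReLU (a point-wise op in frequency space, so every
# output location sees the whole image), and maps back with the inverse FFT.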
class FourierUnit(nn.Module):
def __init__(self, in_channels, out_channels, groups=1):
super(FourierUnit, self).__init__()
self.groups = groups
self.conv_layer = torch.nn.Conv2d(
in_channels=in_channels * 2,
out_channels=out_channels * 2,
kernel_size=1,
stride=1,
padding=0,
groups=self.groups,
bias=False,
)
self.bn = torch.nn.BatchNorm2d(out_channels * 2)
self.relu = torch.nn.ReLU(inplace=True)
def forward(self, x):
batch = x.shape[0]
        # real 2D FFT over the spatial dims; the complex result has shape (batch, c, h, w/2+1)
        fft_dim = (-2, -1)
        ffted = torch.fft.rfftn(x, dim=fft_dim, norm="ortho")
        ffted = torch.stack((ffted.real, ffted.imag), dim=-1)  # (batch, c, h, w/2+1, 2)
        ffted = ffted.permute(0, 1, 4, 2, 3).contiguous()  # (batch, c, 2, h, w/2+1)
        ffted = ffted.view((batch, -1) + ffted.size()[3:])
        ffted = self.conv_layer(ffted)  # (batch, c*2, h, w/2+1)
ffted = self.relu(self.bn(ffted))
        # (batch, c, h, w/2+1, 2)
ffted = ffted.view((batch, -1, 2) + ffted.size()[2:]).permute(0, 1, 3, 4, 2).contiguous()
ffted = torch.complex(ffted[..., 0], ffted[..., 1])
ifft_shape_slice = x.shape[-2:]
output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm="ortho")
return output
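
# SpectralTransform wraps the FourierUnit between two 1x1 convolutions, with an
# optional AvgPool downsample when stride == 2 and a residual connection around
# the spectral path.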
class SpectralTransform(nn.Module):
def __init__(self, in_channels, out_channels, stride=1, groups=1):
super(SpectralTransform, self).__init__()
self.stride = stride
if stride == 2:
self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
else:
self.downsample = nn.Identity()
self.conv1 = nn.Sequential(
nn.Conv2d(in_channels, out_channels // 2, kernel_size=1, groups=groups, bias=False),
nn.BatchNorm2d(out_channels // 2),
nn.ReLU(inplace=True),
)
self.fu = FourierUnit(out_channels // 2, out_channels // 2, groups)
self.conv2 = torch.nn.Conv2d(out_channels // 2, out_channels, kernel_size=1, groups=groups, bias=False)
def forward(self, x):
x = self.downsample(x)
x = self.conv1(x)
output = self.fu(x)
output = self.conv2(x + output)
return output
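
# FFC (Fast Fourier Convolution, Chi et al., NeurIPS 2020) splits the channels
# into a local branch (ordinary convolutions) and a global branch
# (SpectralTransform), with four paths between them (local-to-local,
# local-to-global, global-to-local, global-to-global). ratio_gin / ratio_gout
# give the fraction of input / output channels assigned to the global branch.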
class FFC(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
ratio_gin,
ratio_gout,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=False,
padding_type="reflect",
gated=False,
):
super(FFC, self).__init__()
assert stride == 1 or stride == 2, "Stride should be 1 or 2."
self.stride = stride
in_cg = int(in_channels * ratio_gin)
in_cl = in_channels - in_cg
out_cg = int(out_channels * ratio_gout)
out_cl = out_channels - out_cg
self.ratio_gin = ratio_gin
self.ratio_gout = ratio_gout
self.global_in_num = in_cg
module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d
self.convl2l = module(
in_cl, out_cl, kernel_size, stride, padding, dilation, groups, bias, padding_mode=padding_type
)
module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d
self.convl2g = module(
in_cl, out_cg, kernel_size, stride, padding, dilation, groups, bias, padding_mode=padding_type
)
module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d
self.convg2l = module(
in_cg, out_cl, kernel_size, stride, padding, dilation, groups, bias, padding_mode=padding_type
)
module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform
self.convg2g = module(in_cg, out_cg, stride, 1 if groups == 1 else groups // 2)
self.gated = gated
module = nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d
self.gate = module(in_channels, 2, 1)
def forward(self, x):
x_l, x_g = x if type(x) is tuple else (x, 0)
out_xl, out_xg = 0, 0
if self.gated:
total_input_parts = [x_l]
if torch.is_tensor(x_g):
total_input_parts.append(x_g)
total_input = torch.cat(total_input_parts, dim=1)
gates = torch.sigmoid(self.gate(total_input))
g2l_gate, l2g_gate = gates.chunk(2, dim=1)
else:
g2l_gate, l2g_gate = 1, 1
if self.ratio_gout != 1:
out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate
if self.ratio_gout != 0:
out_xg = self.convl2g(x_l) * l2g_gate + self.convg2g(x_g)
return out_xl, out_xg
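
# FFC followed by per-branch BatchNorm and activation; branches with a channel
# ratio of 0 or 1 collapse to nn.Identity.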
class FFC_BN_ACT(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
ratio_gin=0,
ratio_gout=0,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=False,
norm_layer=nn.BatchNorm2d,
activation_layer=nn.ReLU,
):
super(FFC_BN_ACT, self).__init__()
self.ffc = FFC(
in_channels, out_channels, kernel_size, ratio_gin, ratio_gout, stride, padding, dilation, groups, bias
)
lnorm = nn.Identity if ratio_gout == 1 else norm_layer
gnorm = nn.Identity if ratio_gout == 0 else norm_layer
global_channels = int(out_channels * ratio_gout)
self.bn_l = lnorm(out_channels - global_channels)
self.bn_g = gnorm(global_channels)
lact = nn.Identity if ratio_gout == 1 else activation_layer
gact = nn.Identity if ratio_gout == 0 else activation_layer
self.act_l = lact(inplace=True)
self.act_g = gact(inplace=True)
def forward(self, x):
x_l, x_g = self.ffc(x)
x_l = self.act_l(self.bn_l(x_l))
x_g = self.act_g(self.bn_g(x_g))
return x_l, x_g
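
# Residual block made of two FFC_BN_ACT layers, with separate skip connections
# for the local and global branches.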
class FFCResnetBlock(nn.Module):
def __init__(self, dim, ratio_gin, ratio_gout):
super().__init__()
self.conv1 = FFC_BN_ACT(
dim, dim, kernel_size=3, padding=1, dilation=1, ratio_gin=ratio_gin, ratio_gout=ratio_gout
)
self.conv2 = FFC_BN_ACT(
dim, dim, kernel_size=3, padding=1, dilation=1, ratio_gin=ratio_gin, ratio_gout=ratio_gout
)
def forward(self, x):
x_l, x_g = x if type(x) is tuple else (x, 0)
id_l, id_g = x_l, x_g
x_l, x_g = self.conv1((x_l, x_g))
x_l, x_g = self.conv2((x_l, x_g))
x_l, x_g = id_l + x_l, id_g + x_g
out = x_l, x_g
return out
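
# Re-joins the (local, global) tuple into a single tensor before the plain
# (non-FFC) convolutional decoder.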
class ConcatTupleLayer(nn.Module):
def forward(self, x):
assert isinstance(x, tuple)
x_l, x_g = x
assert torch.is_tensor(x_l) or torch.is_tensor(x_g)
if not torch.is_tensor(x_g):
return x_l
return torch.cat(x, dim=1)
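
# The LaMa generator: a reflection-padded stem, three stride-2 downsampling
# stages, 18 FFC residual blocks (global ratio 0.75), a concat of both branches,
# then three transposed-conv upsampling stages and a sigmoid head. The input is
# the masked RGB image concatenated with the mask (4 channels); forward()
# composites the prediction back into the unmasked regions of the original.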
class LargeMaskInpainting(nn.Module):
def __init__(self, input_nc=4, output_nc=3, ngf=64, n_downsampling=3, n_blocks=18, max_features=1024):
super().__init__()
model = [nn.ReflectionPad2d(3), FFC_BN_ACT(input_nc, ngf, kernel_size=7)]
### downsample
for i in range(n_downsampling):
mult = 2**i
model += [
FFC_BN_ACT(
min(max_features, ngf * mult),
min(max_features, ngf * mult * 2),
kernel_size=3,
stride=2,
padding=1,
ratio_gout=0.75 if i == n_downsampling - 1 else 0,
)
]
### resnet blocks
for i in range(n_blocks):
cur_resblock = FFCResnetBlock(min(max_features, ngf * 2**n_downsampling), ratio_gin=0.75, ratio_gout=0.75)
model += [cur_resblock]
model += [ConcatTupleLayer()]
### upsample
for i in range(n_downsampling):
mult = 2 ** (n_downsampling - i)
model += [
nn.ConvTranspose2d(
min(max_features, ngf * mult),
min(max_features, int(ngf * mult / 2)),
kernel_size=3,
stride=2,
padding=1,
output_padding=1,
),
nn.BatchNorm2d(min(max_features, int(ngf * mult / 2))),
nn.ReLU(True),
]
model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7), nn.Sigmoid()]
self.model = nn.Sequential(*model)
def forward(self, img, mask):
masked_img = img * (1 - mask)
masked_img = torch.cat([masked_img, mask], dim=1)
pred = self.model(masked_img)
inpainted = mask * pred + (1 - mask) * img
return inpainted
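
# Removes the watermark from a batch of frames. imgs is expected as a float
# tensor in [0, 1] with shape (frames, 3, h, w); the watermark region is given
# by the fixed mask at ./utils/mask.png (white = area to inpaint).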
@torch.inference_mode()
def inpaint_watermark(imgs):
    if not os.path.exists(LAMA_PATH):
        os.makedirs(os.path.dirname(LAMA_PATH), exist_ok=True)
        download(LAMA_URL, LAMA_PATH)
mask = to_tensor(Image.open("./utils/mask.png").convert("L")).unsqueeze(0).to(imgs.device)
if mask.shape[-1] != imgs.shape[-1]:
mask = F.interpolate(mask, size=(imgs.shape[2], imgs.shape[3]), mode="nearest")
mask = mask.expand(imgs.shape[0], 1, mask.shape[2], mask.shape[3])
model = LargeMaskInpainting().to(imgs.device)
state_dict = torch.load(LAMA_PATH, map_location=imgs.device)["state_dict"]
g_dict = {k.replace("generator.", ""): v for k, v in state_dict.items() if k.startswith("generator")}
    model.load_state_dict(g_dict)
    model.eval()  # use the checkpoint's BatchNorm running statistics at inference time
    inpainted = model(imgs, mask)
return inpainted
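
# The same function also works on a single image (a minimal sketch; "photo.png"
# is a placeholder path, and ./utils/mask.png must still cover the region to
# remove):
#
#     img = to_tensor(Image.open("photo.png").convert("RGB")).unsqueeze(0)
#     result = inpaint_watermark(img)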
if __name__ == "__main__":
import decord
decord.bridge.set_bridge("torch")
if len(sys.argv) < 2:
print("Usage: python -m utils.lama <path/to/video>")
sys.exit(1)
video_path = sys.argv[1]
    out_path = video_path.replace(".mp4", "-inpainted.mp4")
vr = decord.VideoReader(video_path)
fps = vr.get_avg_fps()
    video = rearrange(vr[:], "f h w c -> f c h w").float().div(255)
inpainted = inpaint_watermark(video)
inpainted = rearrange(inpainted, "f c h w -> f h w c").clamp(0, 1).mul(255).byte().cpu().numpy()
export_to_video(inpainted, out_path, fps)