# salient-style-transfer / inference.py
# Uploaded by jamino30 (commit 28ac920, "Upload folder using huggingface_hub", 2.34 kB).
import torch
import torch.optim as optim
import torch.nn.functional as F
from torchvision.transforms.functional import gaussian_blur
from tqdm import tqdm
def gram_matrix(feature):
    """Return the (unnormalized) Gram matrix of a 4-D feature map.

    The ``(b, c, h, w)`` tensor is flattened to ``(b * c, h * w)`` and multiplied
    by its own transpose, producing a ``(b * c, b * c)`` matrix of inner products
    between channel activation maps. No division by the number of elements is
    applied. When ``b > 1``, rows from different batch items share one matrix.

    Args:
        feature: a 4-D tensor; unpacking ``feature.size()`` into four names
            raises ``ValueError`` for any other rank.

    Returns:
        The ``(b * c, b * c)`` Gram matrix.
    """
    b, c, h, w = feature.size()
    # reshape instead of view: view raises on non-contiguous inputs (e.g. a
    # transposed or channels_last feature map), while reshape returns the same
    # zero-copy view whenever one exists and copies only when it must.
    flat = feature.reshape(b * c, h * w)
    return flat @ flat.t()
def compute_loss(generated, content, style, bg_masks, alpha, beta):
    """Compute the weighted content loss, style loss, and their sum.

    Args:
        generated: per-layer feature maps of the image being optimized.
        content: per-layer feature maps of the content image, paired with
            ``generated`` by position.
        style: per-layer feature maps of the style image, paired with
            ``generated`` by position.
        bg_masks: optional per-layer masks. When given (and non-empty), each
            mask is multiplied into both the generated and the style features
            of its layer before their Gram matrices are taken. A falsy value
            means no masking.
        alpha: weight applied to the content loss.
        beta: weight applied to the style loss.

    Returns:
        Tuple ``(alpha * content_loss, beta * style_loss, total)``, where
        ``total`` is the sum of the first two. The content loss is the sum of
        per-layer MSEs. The style loss is the sum of per-layer Gram-matrix MSEs,
        each divided by ``len(generated)``.
    """
    layer_count = len(generated)
    masks = bg_masks if bg_masks else [None] * layer_count

    content_term = 0
    for gen_feat, content_feat in zip(generated, content):
        content_term = content_term + F.mse_loss(gen_feat, content_feat)

    style_term = 0
    for gen_feat, style_feat, mask in zip(generated, style, masks):
        if mask is not None:
            # Restrict the style statistics to the masked region only.
            gen_feat = gen_feat * mask
            style_feat = style_feat * mask
        layer_loss = F.mse_loss(gram_matrix(gen_feat), gram_matrix(style_feat))
        style_term = style_term + layer_loss / layer_count

    weighted_content = alpha * content_term
    weighted_style = beta * style_term
    return weighted_content, weighted_style, weighted_content + weighted_style
def inference(
    *,
    model,
    sod_model,
    content_image,
    content_image_norm,
    style_features,
    apply_to_background,
    lr=5e-2,
    iterations=101,
    optim_caller=optim.AdamW,
    alpha=1,
    beta=1,
    bg_threshold=0.7,
):
    """Optimize a copy of ``content_image`` so its features match the style.

    The pixels of the generated image are the only optimized parameters. Each
    step minimizes ``compute_loss`` between the features ``model`` extracts
    from the generated image and the content/style features.

    Args:
        model: callable mapping an image tensor to a sequence of per-layer
            feature maps (paired positionally with ``style_features``).
        sod_model: saliency (salient-object-detection) model. It is called only
            when ``apply_to_background`` is true, and element ``[0]`` of its
            output is passed through ``torch.sigmoid``.
        content_image: image to stylize; also the starting point of the
            optimization.
        content_image_norm: input fed to ``sod_model`` (presumably a normalized
            copy of ``content_image`` — confirm with caller).
        style_features: precomputed per-layer features of the style image.
        apply_to_background: if true, restrict style statistics to the
            non-salient region and keep content pixels in the salient region.
        lr: learning rate passed to ``optim_caller``.
        iterations: number of optimizer steps.
        optim_caller: optimizer factory, called as ``optim_caller([param], lr=lr)``.
            Its ``step`` receives a closure.
        alpha: content-loss weight.
        beta: style-loss weight.
        bg_threshold: pixels whose sigmoid saliency is ``<= bg_threshold``
            count as background.

    Returns:
        The optimized image tensor. It is a leaf that still has
        ``requires_grad=True``.
    """
    # Detach first: if content_image belongs to an autograd graph, clone() alone
    # yields a non-leaf tensor, which torch optimizers refuse to optimize.
    generated_image = content_image.detach().clone().requires_grad_(True)
    optimizer = optim_caller([generated_image], lr=lr)

    with torch.no_grad():
        content_features = model(content_image)

        bg_masks = None
        bg_keep = None
        fg_content = None
        if apply_to_background:
            seg_output = torch.sigmoid(sod_model(content_image_norm)[0])
            bg_mask = (seg_output <= bg_threshold).float()
            # NOTE(review): unsqueeze(1) assumes bg_mask is (N, H, W), so the
            # result is the 4-D input bilinear interpolation needs — confirm
            # against sod_model's output format.
            bg_masks = [
                F.interpolate(bg_mask.unsqueeze(1), size=cf.shape[2:], mode='bilinear', align_corners=False)
                for cf in content_features
            ]
            # The foreground mask never changes during optimization, so build
            # the blend factors once instead of on every iteration.
            fg_mask = F.interpolate(1 - bg_masks[0], size=generated_image.shape[2:], mode='nearest')
            bg_keep = 1 - fg_mask
            fg_content = content_image * fg_mask

    def closure():
        optimizer.zero_grad()
        generated_features = model(generated_image)
        _, _, total_loss = compute_loss(
            generated_features, content_features, style_features, bg_masks, alpha, beta
        )
        total_loss.backward()
        return total_loss

    for _ in tqdm(range(iterations)):
        optimizer.step(closure)
        if bg_keep is not None:
            # Re-impose the content pixels on the salient foreground after each
            # step. In-place edits of the leaf are allowed under no_grad (the
            # same way optimizers update parameters), so .data is unnecessary.
            with torch.no_grad():
                generated_image.mul_(bg_keep).add_(fg_content)

    return generated_image