import math

import torch
import torch.nn as nn
import lightning.pytorch as pl
import imageio
from torchvision.transforms import v2
from safetensors.torch import load_file


class BasicBlock(nn.Module):
    """Plain conv stack: (Conv2d -> InstanceNorm2d -> LeakyReLU [-> Dropout2d]) per stage."""

    def __init__(self, channels, kernel_size=(3, 3), dropout=0.0):
        super().__init__()
        layers = []
        num_conv = len(channels) - 1
        for i in range(num_conv):
            layers.append(nn.Conv2d(channels[i], channels[i + 1],
                                    kernel_size=kernel_size, padding='same',
                                    padding_mode='reflect', bias=False))
            layers.append(nn.InstanceNorm2d(channels[i + 1]))
            layers.append(nn.LeakyReLU(inplace=True))
            if dropout > 0.0:
                layers.append(nn.Dropout2d(dropout))
        self.operations = nn.Sequential(*layers)

    def forward(self, x):
        return self.operations(x)


class ResBlock(nn.Module):
    """Pre-activation residual block: (InstanceNorm2d -> LeakyReLU -> Conv2d) x num_conv plus identity skip."""

    def __init__(self, channels, kernel_size=3, num_conv=2, dropout=0.0):
        super().__init__()
        layers = []
        for i in range(num_conv):
            layers.append(nn.InstanceNorm2d(channels))
            # Dropout (if any) is applied only before the last convolution.
            if i == num_conv - 1 and dropout > 0.0:
                layers.append(nn.Dropout2d(dropout))
            layers.append(nn.LeakyReLU(inplace=True))
            layers.append(nn.Conv2d(channels, channels,
                                    kernel_size=kernel_size, padding='same',
                                    padding_mode='reflect', bias=False))
        self.operations = nn.Sequential(*layers)

    def forward(self, x):
        return x + self.operations(x)


class ConvPool(nn.Module):
    """Downsamples by 2x with a strided 4x4 convolution followed by InstanceNorm2d."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        layers.append(nn.Conv2d(in_channels, out_channels, 4, 2, 1,
                                padding_mode='reflect', bias=False))
        layers.append(nn.InstanceNorm2d(out_channels))
        self.operations = nn.Sequential(*layers)

    def forward(self, x):
        return self.operations(x)


class CompactGramMatrix(nn.Module):
    """Computes the lower triangle of the normalized Gram matrix as a flat feature vector."""

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        # Indices of the lower-triangular entries (including the diagonal).
        self.register_buffer('tril_indices',
                             torch.tril_indices(in_channels, in_channels, offset=0, dtype=torch.int32))

    def forward(self, x):
        """
        Input: (B, C, H, W)
        Output: (B, C*(C+1)//2) compact Gram features
        """
        b, c, h, w = x.size()
        # Normalize by sqrt(H*W) so Gram entries are averages over spatial positions.
        x = x.view(b, c, -1) / ((h * w) ** 0.5)
        gram = torch.bmm(x, x.transpose(1, 2))
        # Keep only the lower triangle; the Gram matrix is symmetric.
        compact_gram = gram[:, self.tril_indices[0], self.tril_indices[1]]
        return compact_gram


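# Illustrative shape check: with 256 channels the compact Gram vector has
# 256 * 257 // 2 = 32896 entries, which matches the in_features of
# EmbeddingNetwork.compact below.
#
#   >>> CompactGramMatrix(256)(torch.randn(1, 256, 32, 32)).shape
#   torch.Size([1, 32896])

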
class EmbeddingNetwork(nn.Module):
    """Residual conv trunk -> compact Gram statistics -> MLP head producing a 6-dim style embedding."""

    def __init__(self):
        super().__init__()
        self.input_conv = nn.Conv2d(3, 32, 5, padding='same', padding_mode='reflect', bias=False)
        self.conv1 = ResBlock(32, 3, 3)
        self.pool1 = ConvPool(32, 64)
        self.conv2 = ResBlock(64, 3, 3)
        self.pool2 = ConvPool(64, 128)
        self.conv3 = ResBlock(128, 3, 3)
        self.pool3 = ConvPool(128, 256)
        self.conv4 = ResBlock(256, 3, 3)
        self.gram = CompactGramMatrix(256)
        self.compact = nn.Linear(256 * (256 + 1) // 2, 1024, bias=False)
        self.compactnorm = nn.LayerNorm(1024, elementwise_affine=False)
        self.fc1 = nn.Linear(1024, 1024, bias=False)
        self.fc1norm = nn.LayerNorm(1024, elementwise_affine=False)
        self.act = nn.LeakyReLU(inplace=True)
        self.fc2 = nn.Linear(1024, 1024, bias=False)
        self.fc2norm = nn.LayerNorm(1024, elementwise_affine=False)
        self.fc3 = nn.Linear(1024, 6)

    def forward(self, x):
        x = self.input_conv(x)
        x = self.pool1(self.conv1(x))
        x = self.pool2(self.conv2(x))
        x = self.pool3(self.conv3(x))
        x = self.conv4(x)

        x = self.gram(x)
        x = self.compact(x)
        x = self.compactnorm(x)
        x = self.act(self.fc1norm(self.fc1(x)))
        x = self.act(self.fc2norm(self.fc2(x)))
        x = self.fc3(x)
        return x


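# Note: the three stride-2 ConvPool stages downsample the input by a factor of
# 8 overall, which is the same multiple that closest_interval() below crops
# inputs to before inference.

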
class PLModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.save_hyperparameters()
        self.network = EmbeddingNetwork()
        # Running sums/counts for positive- and negative-pair statistics, kept as
        # buffers so they follow the module across devices.
        self.register_buffer("val_pos_sum", torch.tensor(0.0))
        self.register_buffer("val_neg_sum", torch.tensor(0.0))
        self.register_buffer("val_count", torch.tensor(0))
        self.register_buffer("train_pos_sum", torch.tensor(0.0))
        self.register_buffer("train_neg_sum", torch.tensor(0.0))
        self.register_buffer("train_count", torch.tensor(0))

    def forward(self, x):
        return self.network(x)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        outputs = self.forward(batch[0])
        return outputs, batch[1]


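# Hypothetical batch-inference sketch with Lightning (style_dataloader is a
# placeholder, not defined in this file); each element of `preds` is the
# (embeddings, labels) tuple returned by predict_step above:
#
#   module = PLModule()
#   module.network.load_state_dict(load_file("Style_Embedder_v3.safetensors"))
#   trainer = pl.Trainer(accelerator="auto", devices=1, logger=False)
#   preds = trainer.predict(module, dataloaders=style_dataloader)

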
def adj_size(img, size=1536):
    """Downscales a (C, H, W) image so its area is at most size**2, preserving aspect ratio."""
    h, w = img.shape[1], img.shape[2]
    area = h * w
    if area > size ** 2:
        scale_factor = (size ** 2 / area) ** 0.5
        new_h = math.floor(h * scale_factor)
        new_w = math.floor(w * scale_factor)
        # torchvision's resize expects the target size as (height, width).
        img = v2.functional.resize(img, (new_h, new_w))
    return img


def closest_interval(img, interval=8):
    """Center-crops a (C, H, W) image so both spatial dims are multiples of `interval`."""
    c, h, w = img.shape
    new_h = h - (h % interval)
    new_w = w - (w % interval)
    h_start = (h - new_h) // 2
    w_start = (w - new_w) // 2
    new_h, new_w = max(new_h, interval), max(new_w, interval)
    return img[:, h_start:h_start + new_h, w_start:w_start + new_w]


if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = EmbeddingNetwork()
    state_dict = load_file("Style_Embedder_v3.safetensors")
    model.load_state_dict(state_dict)

    # Half precision is only reliably supported on GPU; fall back to float32 on CPU.
    dtype = torch.float16 if device.type == "cuda" else torch.float32
    model.to(device=device, dtype=dtype)
    model.eval()

    # Load an image, move channels first, downscale and crop to a multiple of 8,
    # then rescale pixel values from [0, 255] to [-1, 1].
    img = imageio.v3.imread('images_for_style_embedding/6857740.webp').copy()
    img = torch.from_numpy(img).permute(2, 0, 1)
    img = closest_interval(adj_size(img))
    img = 2 * (img / 255) - 1
    img = img.unsqueeze(0).to(device=device, dtype=dtype)

    with torch.inference_mode():
        pred = model(img)
    print(pred)
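
    # The 6-dim output is this model's style embedding. To compare two images'
    # styles one could, for example, embed both and measure similarity
    # (illustrative only, not part of this script):
    #   sim = torch.nn.functional.cosine_similarity(pred_a, pred_b)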
|