# Anime_Images_Style_Embedder / minimal_script.py
import os
import torch
import math
import numpy as np
import torch.nn as nn
import lightning.pytorch as pl
import imageio.v3
import safetensors
from torchvision.transforms import v2
from safetensors.torch import save_file, load_file
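# Minimal inference script for the anime image style embedder: the classes below define the
# embedding network, and the __main__ block loads pretrained weights from
# Style_Embedder_v3.safetensors and prints the style embedding of one example image.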
class BasicBlock(nn.Module):
    """Stack of Conv2d -> InstanceNorm2d -> LeakyReLU layers (kept from the original script; not used by EmbeddingNetwork below)."""
    def __init__(self, channels, kernel_size=(3, 3), dropout=0.0):
        super().__init__()
        layers = []
        num_conv = len(channels) - 1
        for i in range(num_conv):
            layers.append(nn.Conv2d(channels[i], channels[i+1],
                                    kernel_size=kernel_size, padding='same', padding_mode='reflect', bias=False))
            layers.append(nn.InstanceNorm2d(channels[i+1]))
            layers.append(nn.LeakyReLU(inplace=True))
            if dropout > 0.0:
                layers.append(nn.Dropout2d(dropout))
        self.operations = nn.Sequential(*layers)

    def forward(self, x):
        return self.operations(x)

class ResBlock(nn.Module):
    """Pre-activation residual block: (InstanceNorm -> LeakyReLU -> Conv2d) repeated num_conv times, with an identity skip connection."""
    def __init__(self, channels, kernel_size=3, num_conv=2, dropout=0.0):
        super().__init__()
        layers = []
        for i in range(num_conv):
            layers.append(nn.InstanceNorm2d(channels))
            if i == num_conv - 1 and dropout > 0.0:
                layers.append(nn.Dropout2d(dropout))
            layers.append(nn.LeakyReLU(inplace=True))
            layers.append(nn.Conv2d(channels, channels,
                                    kernel_size=kernel_size, padding='same', padding_mode='reflect', bias=False))
        self.operations = nn.Sequential(*layers)

    def forward(self, x):
        return x + self.operations(x)

class ConvPool(nn.Module):
    """Strided 4x4 convolution that halves the spatial resolution, followed by InstanceNorm."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        layers.append(nn.Conv2d(in_channels, out_channels, 4, 2, 1, padding_mode='reflect', bias=False))
        layers.append(nn.InstanceNorm2d(out_channels))
        # layers.append(nn.LeakyReLU(inplace=True))
        self.operations = nn.Sequential(*layers)

    def forward(self, x):
        return self.operations(x)

class CompactGramMatrix(nn.Module):
    """Computes the channel-correlation Gram matrix and keeps only its lower triangle (the matrix is symmetric)."""
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        # Precompute indices for the lower triangle (including the diagonal)
        self.register_buffer('tril_indices',
                             torch.tril_indices(in_channels, in_channels, offset=0, dtype=torch.int32))

    def forward(self, x):
        """
        Input: (B, C, H, W)
        Output: (B, C*(C+1)//2) compact Gram features
        """
        b, c, h, w = x.size()
        x = x.view(b, c, -1) / ((h * w) ** 0.5)  # Flatten spatial dimensions -> (B, C, H*W), then normalise
        # Compute the full Gram matrix (still needed temporarily)
        gram = torch.bmm(x, x.transpose(1, 2))  # (B, C, C)
        # Extract the lower triangle including the diagonal
        compact_gram = gram[:, self.tril_indices[0], self.tril_indices[1]]  # (B, n_unique)
        return compact_gram

class EmbeddingNetwork(nn.Module):
    """Convolutional style embedder: residual conv stages with downsampling, a compact Gram-matrix head, and an MLP projecting to a 6-dimensional style embedding."""
    def __init__(self):
        super().__init__()
        self.input_conv = nn.Conv2d(3, 32, 5, padding='same', padding_mode='reflect', bias=False)
        self.conv1 = ResBlock(32, 3, 3)
        self.pool1 = ConvPool(32, 64)    # total downsampling factor 2
        self.conv2 = ResBlock(64, 3, 3)
        self.pool2 = ConvPool(64, 128)   # total downsampling factor 4
        self.conv3 = ResBlock(128, 3, 3)
        self.pool3 = ConvPool(128, 256)  # total downsampling factor 8
        self.conv4 = ResBlock(256, 3, 3)
        self.gram = CompactGramMatrix(256)
        # 256*(256+1)//2 = 32896 compact Gram features, projected down to 1024
        self.compact = nn.Linear(256*(256+1)//2, 1024, bias=False)
        self.compactnorm = nn.LayerNorm(1024, elementwise_affine=False)
        self.fc1 = nn.Linear(1024, 1024, bias=False)
        self.fc1norm = nn.LayerNorm(1024, elementwise_affine=False)
        self.act = nn.LeakyReLU(inplace=True)
        self.fc2 = nn.Linear(1024, 1024, bias=False)
        self.fc2norm = nn.LayerNorm(1024, elementwise_affine=False)
        self.fc3 = nn.Linear(1024, 6)

    def forward(self, x):
        x = self.input_conv(x)
        x = self.pool1(self.conv1(x))
        x = self.pool2(self.conv2(x))
        x = self.pool3(self.conv3(x))
        x = self.conv4(x)
        x = self.gram(x)
        x = self.compact(x)
        x = self.compactnorm(x)
        x = self.act(self.fc1norm(self.fc1(x)))
        x = self.act(self.fc2norm(self.fc2(x)))
        x = self.fc3(x)
        return x

class PLModule(pl.LightningModule):
    """Lightning wrapper around EmbeddingNetwork; the *_sum/*_count buffers accumulate per-epoch statistics during training/validation and are not used in this minimal inference script."""
    def __init__(self):
        super().__init__()
        self.save_hyperparameters()
        self.network = EmbeddingNetwork()
        self.register_buffer("val_pos_sum", torch.tensor(0.0))
        self.register_buffer("val_neg_sum", torch.tensor(0.0))
        self.register_buffer("val_count", torch.tensor(0))
        self.register_buffer("train_pos_sum", torch.tensor(0.0))
        self.register_buffer("train_neg_sum", torch.tensor(0.0))
        self.register_buffer("train_count", torch.tensor(0))

    def forward(self, x):
        return self.network(x)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        outputs = self.forward(batch[0])
        return outputs, batch[1]

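# A hedged usage sketch (comments only, not executed): PLModule.predict_step is meant to be
# driven by a Lightning Trainer with a DataLoader yielding (image, label/identifier) batches, e.g.
#
#     trainer = pl.Trainer(accelerator="auto", devices=1)
#     predictions = trainer.predict(PLModule(), dataloaders=predict_loader)
#
# Here `predict_loader` is an assumed, user-supplied DataLoader.
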
def adj_size(img, size=1536):
    """Downscale img (C, H, W) so its area does not exceed size*size, preserving aspect ratio."""
    h, w = img.shape[1], img.shape[2]
    area = h * w
    if area > size ** 2:
        scale_factor = (size ** 2 / area) ** 0.5
        new_h = math.floor(h * scale_factor)
        new_w = math.floor(w * scale_factor)
        # torchvision's resize expects the target size as (height, width)
        img = v2.functional.resize(img, (new_h, new_w))
    return img

def closest_interval(img, interval=8):
    """Centre-crop img (C, H, W) so height and width are multiples of interval (the network downsamples by a factor of 8)."""
    c, h, w = img.shape
    new_h = h - (h % interval)
    new_w = w - (w % interval)
    h_start = (h - new_h) // 2
    w_start = (w - new_w) // 2
    new_h, new_w = max(new_h, interval), max(new_w, interval)
    return img[:, h_start:h_start + new_h, w_start:w_start + new_w]

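# A hedged convenience wrapper (not part of the original script) that bundles the same
# preprocessing steps used under __main__ below: load the image, limit its area, centre-crop
# to a multiple of 8, and map pixel values from [0, 255] to [-1, 1]. The device/dtype
# defaults are assumptions for single-image inference.
def preprocess_image(path, device=torch.device("cpu"), dtype=torch.float16):
    img = imageio.v3.imread(path).copy()
    img = torch.from_numpy(img).permute(2, 0, 1)  # HWC uint8 -> CHW
    img = closest_interval(adj_size(img))
    img = 2 * (img / 255) - 1                     # [0, 255] -> [-1, 1]
    return img.unsqueeze(0).to(device).to(dtype)
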
if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Build the network and load the pretrained weights
    model = EmbeddingNetwork()
    state_dict = load_file("Style_Embedder_v3.safetensors")
    model.load_state_dict(state_dict)
    model.to(device).to(torch.float16)
    model.eval()
    # Load an example image and convert it to a (C, H, W) tensor
    img = imageio.v3.imread('images_for_style_embedding/6857740.webp').copy()
    img = torch.from_numpy(img).permute(2, 0, 1)
    # Limit the area, crop to a multiple of 8, and map pixel values from [0, 255] to [-1, 1]
    img = closest_interval(adj_size(img))
    img = 2 * (img / 255) - 1
    img = img.unsqueeze(0).to(device).to(torch.float16)
    with torch.no_grad():
        pred = model(img)
    print(pred)
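    # A hedged follow-up sketch (not in the original script): persist the embedding with
    # safetensors' save_file, which is already imported above. The output filename is an
    # assumption; squeeze/cpu/contiguous make the tensor suitable for serialisation.
    save_file({"embedding": pred.squeeze(0).cpu().contiguous()},
              "style_embedding_6857740.safetensors")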