import os
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
import matplotlib.pyplot as plt
import cv2
import torch.nn.functional as F
#torch.set_printoptions(precision=10)
class _bn_relu_conv(nn.Module):
    def __init__(self, in_filters, nb_filters, fw, fh, subsample=1):
        super(_bn_relu_conv, self).__init__()
        self.model = nn.Sequential(
            nn.BatchNorm2d(in_filters, eps=1e-3),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_filters, nb_filters, (fw, fh), stride=subsample, padding=(fw//2, fh//2), padding_mode='zeros')
        )
    def forward(self, x, debug=False):
        if not debug:
            return self.model(x)
        # optional debug path: print activation statistics before the block
        # and after the conv layer
        print("****", x.max().item(), x.min().item(), x.mean().item(), x.std().item(), x.shape)
        for i, layer in enumerate(self.model):
            x = layer(x)
            if i == 2:  # the Conv2d layer
                #x = nn.functional.pad(x, (1, 1, 1, 1), mode='constant', value=0)
                print("____", x.max().item(), x.min().item(), x.mean().item(), x.std().item(), x.shape)
                print(x[0])
        return x
class _u_bn_relu_conv(nn.Module):
    def __init__(self, in_filters, nb_filters, fw, fh, subsample=1):
        super(_u_bn_relu_conv, self).__init__()
        self.model = nn.Sequential(
            nn.BatchNorm2d(in_filters, eps=1e-3),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_filters, nb_filters, (fw, fh), stride=subsample, padding=(fw//2, fh//2)),
            nn.Upsample(scale_factor=2, mode='nearest')
        )

    def forward(self, x):
        return self.model(x)
class _shortcut(nn.Module):
    def __init__(self, in_filters, nb_filters, subsample=1):
        super(_shortcut, self).__init__()
        self.process = False
        self.model = None
        if in_filters != nb_filters or subsample != 1:
            self.process = True
            self.model = nn.Sequential(
                nn.Conv2d(in_filters, nb_filters, (1, 1), stride=subsample)
            )

    def forward(self, x, y):
        #print(x.size(), y.size(), self.process)
        if self.process:
            y0 = self.model(x)
            #print("merge+", torch.max(y0+y), torch.min(y0+y), torch.mean(y0+y), torch.std(y0+y), y0.shape)
            return y0 + y
        else:
            #print("merge", torch.max(x+y), torch.min(x+y), torch.mean(x+y), torch.std(x+y), y.shape)
            return x + y
class _u_shortcut(nn.Module):
    def __init__(self, in_filters, nb_filters, subsample):
        super(_u_shortcut, self).__init__()
        self.process = False
        self.model = None
        if in_filters != nb_filters:
            self.process = True
            self.model = nn.Sequential(
                nn.Conv2d(in_filters, nb_filters, (1, 1), stride=subsample, padding_mode='zeros'),
                nn.Upsample(scale_factor=2, mode='nearest')
            )

    def forward(self, x, y):
        if self.process:
            return self.model(x) + y
        else:
            return x + y
class basic_block(nn.Module):
    def __init__(self, in_filters, nb_filters, init_subsample=1):
        super(basic_block, self).__init__()
        self.conv1 = _bn_relu_conv(in_filters, nb_filters, 3, 3, subsample=init_subsample)
        self.residual = _bn_relu_conv(nb_filters, nb_filters, 3, 3)
        self.shortcut = _shortcut(in_filters, nb_filters, subsample=init_subsample)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.residual(x1)
        return self.shortcut(x, x2)

class _u_basic_block(nn.Module):
    def __init__(self, in_filters, nb_filters, init_subsample=1):
        super(_u_basic_block, self).__init__()
        self.conv1 = _u_bn_relu_conv(in_filters, nb_filters, 3, 3, subsample=init_subsample)
        self.residual = _bn_relu_conv(nb_filters, nb_filters, 3, 3)
        self.shortcut = _u_shortcut(in_filters, nb_filters, subsample=init_subsample)

    def forward(self, x):
        y = self.residual(self.conv1(x))
        return self.shortcut(x, y)
class _residual_block(nn.Module):
    def __init__(self, in_filters, nb_filters, repetitions, is_first_layer=False):
        super(_residual_block, self).__init__()
        layers = []
        for i in range(repetitions):
            init_subsample = 1
            if i == repetitions - 1 and not is_first_layer:
                init_subsample = 2
            if i == 0:
                l = basic_block(in_filters=in_filters, nb_filters=nb_filters, init_subsample=init_subsample)
            else:
                l = basic_block(in_filters=nb_filters, nb_filters=nb_filters, init_subsample=init_subsample)
            layers.append(l)
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

class _upsampling_residual_block(nn.Module):
    def __init__(self, in_filters, nb_filters, repetitions):
        super(_upsampling_residual_block, self).__init__()
        layers = []
        for i in range(repetitions):
            l = None
            if i == 0:
                l = _u_basic_block(in_filters=in_filters, nb_filters=nb_filters)#(input)
            else:
                l = basic_block(in_filters=nb_filters, nb_filters=nb_filters)#(input)
            layers.append(l)
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
class res_skip(nn.Module):
    def __init__(self):
        super(res_skip, self).__init__()
        self.block0 = _residual_block(in_filters=1, nb_filters=24, repetitions=2, is_first_layer=True)#(input)
        self.block1 = _residual_block(in_filters=24, nb_filters=48, repetitions=3)#(block0)
        self.block2 = _residual_block(in_filters=48, nb_filters=96, repetitions=5)#(block1)
        self.block3 = _residual_block(in_filters=96, nb_filters=192, repetitions=7)#(block2)
        self.block4 = _residual_block(in_filters=192, nb_filters=384, repetitions=12)#(block3)
        self.block5 = _upsampling_residual_block(in_filters=384, nb_filters=192, repetitions=7)#(block4)
        self.res1 = _shortcut(in_filters=192, nb_filters=192)#(block3, block5, subsample=(1,1))
        self.block6 = _upsampling_residual_block(in_filters=192, nb_filters=96, repetitions=5)#(res1)
        self.res2 = _shortcut(in_filters=96, nb_filters=96)#(block2, block6, subsample=(1,1))
        self.block7 = _upsampling_residual_block(in_filters=96, nb_filters=48, repetitions=3)#(res2)
        self.res3 = _shortcut(in_filters=48, nb_filters=48)#(block1, block7, subsample=(1,1))
        self.block8 = _upsampling_residual_block(in_filters=48, nb_filters=24, repetitions=2)#(res3)
        self.res4 = _shortcut(in_filters=24, nb_filters=24)#(block0, block8, subsample=(1,1))
        self.block9 = _residual_block(in_filters=24, nb_filters=16, repetitions=2, is_first_layer=True)#(res4)
        self.conv15 = _bn_relu_conv(in_filters=16, nb_filters=1, fh=1, fw=1, subsample=1)#(block9)

    def forward(self, x):
        x0 = self.block0(x)
        x1 = self.block1(x0)
        x2 = self.block2(x1)
        x3 = self.block3(x2)
        x4 = self.block4(x3)
        x5 = self.block5(x4)
        res1 = self.res1(x3, x5)
        x6 = self.block6(res1)
        res2 = self.res2(x2, x6)
        x7 = self.block7(res2)
        res3 = self.res3(x1, x7)
        x8 = self.block8(res3)
        res4 = self.res4(x0, x8)
        x9 = self.block9(res4)
        y = self.conv15(x9)
        return y
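# A minimal usage sketch (an assumption, not part of the original file): res_skip
# expects a single-channel NCHW tensor; H and W should be multiples of 16 so the
# four stride-2 stages and the four 2x upsampling stages line up again.
def _demo_res_skip():
    model = res_skip().eval()
    with torch.no_grad():
        out = model(torch.rand(1, 1, 256, 256))
    print(out.shape)  # expected: torch.Size([1, 1, 256, 256])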
class MyDataset(Dataset):
    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def get_class_label(self, image_name):
        # the label is simply the file name (tail of the path)
        head, tail = os.path.split(image_name)
        #print(tail)
        return tail

    def __getitem__(self, index):
        image_path = self.image_paths[index]
        x = Image.open(image_path)
        y = self.get_class_label(image_path.split('/')[-1])
        if self.transform is not None:
            x = self.transform(x)
        return x, y

    def __len__(self):
        return len(self.image_paths)
def loadImages(folder):
    matches = []
    # list all entries under the folder
    for filename in os.listdir(folder):
        # build the full path
        file_path = os.path.join(folder, filename)
        # keep regular files only
        if os.path.isfile(file_path):
            matches.append(file_path)
    return matches
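# Usage sketch (an assumption): wire loadImages and MyDataset into a DataLoader.
# The folder name and transform below are placeholders, not from the original.
def _demo_dataset(folder="./imgs"):
    from torch.utils.data import DataLoader
    tfm = transforms.Compose([transforms.Resize((512, 512)), transforms.ToTensor()])
    dataset = MyDataset(loadImages(folder), transform=tfm)
    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    for x, y in loader:
        print(x.shape, y)  # (4, C, 512, 512) and a tuple of file names
        break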
def crop_center_square(image):
    """
    Center-crop an image to a square.
    :param image: PIL.Image object
    :return: the cropped PIL.Image object
    """
    # image width and height
    width, height = image.size
    # side length of the square
    side_length = min(width, height)
    # top-left corner of the crop box
    left = (width - side_length) // 2
    top = (height - side_length) // 2
    right = left + side_length
    bottom = top + side_length
    # perform the crop
    cropped_image = image.crop((left, top, right, bottom))
    return cropped_image
def crop_image(image, crop_size, stride):
    """
    Crop an image with the given crop size and stride and return the crops.
    :param image: PIL.Image object
    :param crop_size: crop size, e.g. (384, 384)
    :param stride: stride between overlapping crops, e.g. 128
    :return: list of cropped images
    """
    width, height = image.size
    crop_width, crop_height = crop_size
    cropped_images = []
    for j in range(0, height - crop_height + 1, stride):
        for i in range(0, width - crop_width + 1, stride):
            crop_box = (i, j, i + crop_width, j + crop_height)
            cropped_image = image.crop(crop_box)
            cropped_images.append(cropped_image)
    return cropped_images
def process_image_ref(image):
    """
    Process an input PIL image and return a list of crops.
    :param image: PIL.Image object
    :return: list containing the resized image followed by all crops
    """
    # resize to 512x512
    resized_image_512 = image.resize((512, 512))
    # the 512x512 image is the first element of the list
    image_list = [resized_image_512]
    # overlapping 2x2 grid of 384x384 crops
    crop_size_384 = (384, 384)
    stride_384 = 128
    image_list.extend(crop_image(resized_image_512, crop_size_384, stride_384))
    # overlapping 256x256 crops (disabled)
    # crop_size_256 = (256, 256)
    # stride_256 = 256
    # image_list.extend(crop_image(resized_image_512, crop_size_256, stride_256))
    return image_list
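# Quick check (an assumption, for illustration): with a 512x512 source, 384x384
# crops at stride 128 form a 2x2 grid, so process_image_ref returns 1 + 4 images.
def _demo_process_image_ref():
    img = Image.new("RGB", (800, 600))
    tiles = process_image_ref(img)
    print(len(tiles), [t.size for t in tiles])
    # expected: 5 [(512, 512), (384, 384), (384, 384), (384, 384), (384, 384)]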
def process_image_Q(image):
    """
    Process an input PIL image and return a list of crops (no resized copy).
    :param image: PIL.Image object
    :return: list of cropped images
    """
    # resize to 512x512 and force RGB
    resized_image_512 = image.resize((512, 512)).convert("RGB")
    image_list = []
    # overlapping 2x2 grid of 384x384 crops
    crop_size_384 = (384, 384)
    stride_384 = 128
    image_list.extend(crop_image(resized_image_512, crop_size_384, stride_384))
    return image_list
def process_image(image, target_width=512, target_height=512):
    # input image size and aspect ratio
    img_width, img_height = image.size
    img_ratio = img_width / img_height
    # target aspect ratio
    # target_width, target_height = target_ratio
    target_ratio = target_width / target_height
    # relative aspect-ratio error
    ratio_error = abs(img_ratio - target_ratio) / target_ratio
    if ratio_error < 0.15:
        # error below 15%: resize directly to the target size
        resized_image = image.resize((target_width, target_height), Image.BICUBIC)
    else:
        # otherwise center-crop to the target aspect ratio first
        if img_ratio > target_ratio:
            # image too wide: crop the width
            new_width = int(img_height * target_ratio)
            # left = random.randint(0, img_width - new_width)
            left = int((0 + img_width - new_width) / 2)
            top = 0
            right = left + new_width
            bottom = img_height
        else:
            # image too tall: crop the height
            new_height = int(img_width / target_ratio)
            left = 0
            # top = random.randint(0, img_height - new_height)
            top = int((0 + img_height - new_height) / 2)
            right = img_width
            bottom = top + new_height
        cropped_image = image.crop((left, top, right, bottom))
        resized_image = cropped_image.resize((target_width, target_height), Image.BICUBIC)
    return resized_image.convert('RGB')
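# Behavior sketch (an assumption, for illustration): a near-target aspect ratio
# is resized directly, anything else is center-cropped to the ratio first.
def _demo_process_image():
    wide = Image.new("RGB", (1600, 600))  # ratio 2.67 -> crop, then resize
    near = Image.new("RGB", (520, 500))   # ratio 1.04 -> direct resize
    print(process_image(wide).size, process_image(near).size)  # (512, 512) twice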
def crop_image_varres(image, crop_size, h_stride, w_stride):
    """
    Crop an image with the given crop size and separate strides per axis.
    :param image: PIL.Image object
    :param crop_size: crop size, e.g. (384, 384)
    :param h_stride: vertical stride between overlapping crops
    :param w_stride: horizontal stride between overlapping crops
    :return: list of cropped images
    """
    width, height = image.size
    crop_width, crop_height = crop_size
    cropped_images = []
    for j in range(0, height - crop_height + 1, h_stride):
        for i in range(0, width - crop_width + 1, w_stride):
            crop_box = (i, j, i + crop_width, j + crop_height)
            cropped_image = image.crop(crop_box)
            cropped_images.append(cropped_image)
    return cropped_images
def process_image_ref_varres(image, target_width=512, target_height=512):
    """
    Process an input PIL image and return a list of crops at the target resolution.
    :param image: PIL.Image object
    :return: list containing the resized image followed by all crops
    """
    # resize to the target resolution
    resized_image = image.resize((target_width, target_height))
    # the resized image is the first element of the list
    image_list = [resized_image]
    # overlapping 2x2 grid of crops at 3/4 of the target size
    crop_size = (target_width // 4 * 3, target_height // 4 * 3)
    w_stride = target_width // 4
    h_stride = target_height // 4
    image_list.extend(crop_image_varres(resized_image, crop_size, h_stride=h_stride, w_stride=w_stride))
    # overlapping 256x256 crops (disabled)
    # crop_size_256 = (256, 256)
    # stride_256 = 256
    # image_list.extend(crop_image_varres(resized_image, crop_size_256, stride_256))
    return image_list
def process_image_Q_varres(image, target_width=512, target_height=512):
    """
    Process an input PIL image and return a list of crops (no resized copy).
    :param image: PIL.Image object
    :return: list of cropped images
    """
    # resize to the target resolution and force RGB
    resized_image = image.resize((target_width, target_height)).convert("RGB")
    image_list = []
    # overlapping 2x2 grid of crops at 3/4 of the target size
    crop_size = (target_width // 4 * 3, target_height // 4 * 3)
    w_stride = target_width // 4
    h_stride = target_height // 4
    image_list.extend(crop_image_varres(resized_image, crop_size, h_stride=h_stride, w_stride=w_stride))
    return image_list
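# Quick check (an assumption, for illustration): with 3/4-size crops and 1/4-size
# strides, the varres helpers also yield a 2x2 grid, e.g. 4 crops of 384x288.
def _demo_varres():
    img = Image.new("RGB", (1024, 768))
    tiles = process_image_Q_varres(img, target_width=512, target_height=384)
    print(len(tiles), tiles[0].size)  # expected: 4 (384, 288)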
# A simple ResNet block
class ResNetBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResNetBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)  # element-wise residual add
        out = F.relu(out)
        return out
# A stack of ResNet blocks (despite the name, it stacks four blocks, not two)
class TwoLayerResNet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(TwoLayerResNet, self).__init__()
        self.block1 = ResNetBlock(in_channels, out_channels)
        self.block2 = ResNetBlock(out_channels, out_channels)
        self.block3 = ResNetBlock(out_channels, out_channels)
        self.block4 = ResNetBlock(out_channels, out_channels)

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        return x
class MultiHiddenResNetModel(nn.Module):
    def __init__(self, channels_list, num_tensors):
        super(MultiHiddenResNetModel, self).__init__()
        # one ResNet stack per input tensor: input channels are doubled, output
        # channels come from two entries ahead in channels_list (clamped to the last)
        self.two_layer_resnets = nn.ModuleList([
            TwoLayerResNet(channels_list[idx] * 2, channels_list[min(len(channels_list) - 1, idx + 2)])
            for idx in range(num_tensors)
        ])

    def forward(self, tensor_list):
        processed_list = []
        for i, tensor in enumerate(tensor_list):
            # apply the matching ResNet stack
            tensor = self.two_layer_resnets[i](tensor)
            processed_list.append(tensor)
        return processed_list
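# Shape sketch (an assumption about the intended use): each entry of tensor_list
# carries 2*channels_list[i] channels and is mapped to the clamped channel count.
def _demo_multi_hidden():
    channels_list = [24, 48, 96, 192]
    model = MultiHiddenResNetModel(channels_list, num_tensors=2)
    feats = [torch.rand(1, 48, 64, 64), torch.rand(1, 96, 32, 32)]
    outs = model(feats)
    print([o.shape[1] for o in outs])  # expected: [96, 192]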
def calculate_target_size(h, w):
    # round the height and width down to multiples of 8, roughly preserving the
    # aspect ratio (the original had three random branches that all computed the
    # same values, so the randomness is dropped)
    target_h = (h // 8) * 8
    target_w = (w // 8) * 8
    # never return a zero-sized dimension
    if target_h == 0:
        target_h = 8
    if target_w == 0:
        target_w = 8
    return target_h, target_w
def downsample_tensor(tensor):
    # tensor height and width
    b, c, h, w = tensor.shape
    # target height and width (multiples of 8)
    target_h, target_w = calculate_target_size(h, w)
    # bilinear interpolation down to the target resolution
    downsampled_tensor = F.interpolate(tensor, size=(target_h, target_w), mode='bilinear', align_corners=False)
    return downsampled_tensor
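# Quick check (an assumption, for illustration): sizes round down to multiples of 8.
def _demo_downsample_tensor():
    t = torch.rand(1, 4, 67, 130)
    print(downsample_tensor(t).shape)  # expected: torch.Size([1, 4, 64, 128])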
def get_pixart_config():
    pixart_config = {
        "_class_name": "Transformer2DModel",
        "_diffusers_version": "0.22.0.dev0",
        "activation_fn": "gelu-approximate",
        "attention_bias": True,
        "attention_head_dim": 72,
        "attention_type": "default",
        "caption_channels": 4096,
        "cross_attention_dim": 1152,
        "double_self_attention": False,
        "dropout": 0.0,
        "in_channels": 4,
        # "interpolation_scale": 2,
        "norm_elementwise_affine": False,
        "norm_eps": 1e-06,
        "norm_num_groups": 32,
        "norm_type": "ada_norm_single",
        "num_attention_heads": 16,
        "num_embeds_ada_norm": 1000,
        "num_layers": 28,
        "num_vector_embeds": None,
        "only_cross_attention": False,
        "out_channels": 8,
        "patch_size": 2,
        "sample_size": 128,
        "upcast_attention": False,
        # "use_additional_conditions": False,
        "use_linear_projection": False
    }
    return pixart_config
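# Usage sketch (an assumption, not from the original file): the dict mirrors a
# diffusers Transformer2DModel config, so the private "_" bookkeeping keys must
# be stripped before passing it to the constructor.
def _demo_pixart_transformer():
    from diffusers.models import Transformer2DModel
    config = {k: v for k, v in get_pixart_config().items() if not k.startswith("_")}
    model = Transformer2DModel(**config)
    print(sum(p.numel() for p in model.parameters()))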
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def forward(self, x):
        return self.double_conv(x)
class UNet(nn.Module):
    def __init__(self):
        super().__init__()
        # left (encoder)
        self.left_conv_1 = DoubleConv(6, 64)
        self.down_1 = nn.MaxPool2d(2, 2)
        self.left_conv_2 = DoubleConv(64, 128)
        self.down_2 = nn.MaxPool2d(2, 2)
        self.left_conv_3 = DoubleConv(128, 256)
        self.down_3 = nn.MaxPool2d(2, 2)
        self.left_conv_4 = DoubleConv(256, 512)
        self.down_4 = nn.MaxPool2d(2, 2)
        # center (bottleneck)
        self.center_conv = DoubleConv(512, 1024)
        # right (decoder)
        self.up_1 = nn.ConvTranspose2d(1024, 512, 2, 2)
        self.right_conv_1 = DoubleConv(1024, 512)
        self.up_2 = nn.ConvTranspose2d(512, 256, 2, 2)
        self.right_conv_2 = DoubleConv(512, 256)
        self.up_3 = nn.ConvTranspose2d(256, 128, 2, 2)
        self.right_conv_3 = DoubleConv(256, 128)
        self.up_4 = nn.ConvTranspose2d(128, 64, 2, 2)
        self.right_conv_4 = DoubleConv(128, 64)
        # output
        self.output = nn.Conv2d(64, 3, 1, 1, 0)

    def forward(self, x):
        # left
        x1 = self.left_conv_1(x)
        x1_down = self.down_1(x1)
        x2 = self.left_conv_2(x1_down)
        x2_down = self.down_2(x2)
        x3 = self.left_conv_3(x2_down)
        x3_down = self.down_3(x3)
        x4 = self.left_conv_4(x3_down)
        x4_down = self.down_4(x4)
        # center
        x5 = self.center_conv(x4_down)
        # right, with skip connections concatenated on the channel axis
        x6_up = self.up_1(x5)
        temp = torch.cat((x6_up, x4), dim=1)
        x6 = self.right_conv_1(temp)
        x7_up = self.up_2(x6)
        temp = torch.cat((x7_up, x3), dim=1)
        x7 = self.right_conv_2(temp)
        x8_up = self.up_3(x7)
        temp = torch.cat((x8_up, x2), dim=1)
        x8 = self.right_conv_3(temp)
        x9_up = self.up_4(x8)
        temp = torch.cat((x9_up, x1), dim=1)
        x9 = self.right_conv_4(temp)
        # output
        output = self.output(x9)
        return output
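# Shape sketch (an assumption): the UNet takes a 6-channel input (e.g. two RGB
# images concatenated) and returns 3 channels at the same resolution, provided
# H and W are multiples of 16 so the four pooling stages divide evenly.
def _demo_unet():
    net = UNet().eval()
    with torch.no_grad():
        out = net(torch.rand(1, 6, 256, 256))
    print(out.shape)  # expected: torch.Size([1, 3, 256, 256])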
# ScreenVAE input helpers (BidirectionalTranslation)
import sys
sys.path.append('./BidirectionalTranslation')
from data.base_dataset import BaseDataset, get_params, get_transform
from data.image_folder import make_dataset
def get_ScreenVAE_input(A_img, opt):
    # load the image
    # A_img = Image.open(image_path).convert('RGB')
    # load the line image if one exists
    # if os.path.exists(image_path.replace('imgs','line')[:-4]+'.jpg'):
    #     L_img = cv2.imread(image_path.replace('imgs','line')[:-4]+'.jpg')
    #     kernel = np.ones((3,3), np.uint8)
    #     L_img = cv2.erode(L_img, kernel, iterations=1)
    #     L_img = Image.fromarray(L_img)
    # else:
    L_img = A_img
    # match the image sizes (Image.LANCZOS replaces the removed Image.ANTIALIAS)
    if A_img.size != L_img.size:
        A_img = A_img.resize(L_img.size, Image.LANCZOS)
    if A_img.size[1] > 2500:
        A_img = A_img.resize((A_img.size[0] // 2, A_img.size[1] // 2), Image.LANCZOS)
    # transform parameters
    ow, oh = A_img.size
    transform_params = get_params(opt, A_img.size)
    # apply the transforms
    A_transform = get_transform(opt, transform_params, grayscale=False)
    L_transform = get_transform(opt, transform_params, grayscale=True)
    A = A_transform(A_img)
    L = L_transform(L_img)
    # grayscale image from the RGB channels (ITU-R BT.601 weights)
    tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
    Ai = tmp.unsqueeze(0)
    return {'A': A.unsqueeze(0), 'Ai': Ai.unsqueeze(0), 'L': L.unsqueeze(0), 'A_paths': '', 'h': oh, 'w': ow,
            'B': torch.zeros(1),
            'Bs': torch.zeros(1),
            'Bi': torch.zeros(1),
            'Bl': torch.zeros(1)}
def get_bidirectional_translation_opt(opt):
    opt.results_dir = './results/test/western2manga'
    opt.dataroot = './datasets/color2manga'
    opt.checkpoints_dir = '/group/40034/zhuangjunhao/ScreenStyle/BidirectionalTranslation/checkpoints/color2manga/'
    opt.name = 'color2manga_cycle_ganstft'
    opt.model = 'cycle_ganstft'
    opt.direction = 'BtoA'
    opt.preprocess = 'none'
    opt.load_size = 512
    opt.crop_size = 1024
    opt.input_nc = 1
    opt.output_nc = 3
    opt.nz = 64
    opt.netE = 'conv_256'
    opt.num_test = 30
    opt.n_samples = 1
    opt.upsample = 'bilinear'
    opt.ngf = 48
    opt.nef = 48
    opt.ndf = 32
    opt.center_crop = True
    opt.color2screen = True
    opt.no_flip = True
    # set other options
    opt.num_threads = 1
    opt.batch_size = 1
    opt.serial_batches = True
    return opt
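# Usage sketch (an assumption: the BidirectionalTranslation repo follows the
# pix2pix/CycleGAN template, so a plain namespace with these fields can stand in
# for its parsed options when building a ScreenVAE input dict).
def _demo_screenvae_input():
    from types import SimpleNamespace
    opt = get_bidirectional_translation_opt(SimpleNamespace())
    data = get_ScreenVAE_input(Image.new('RGB', (512, 512)), opt)
    print(data['A'].shape, data['L'].shape)  # e.g. (1, 3, H, W) and (1, 1, H, W)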