Spaces:
Build error
Build error
from PIL import Image | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
from PIL import Image | |
from torchvision import transforms | |
import transformer as transformer | |
import StyTR as StyTR | |
import numpy as np | |
from collections import OrderedDict | |
def test_transform(size, crop):
    """Build the preprocessing pipeline that turns a PIL image into a tensor.

    Args:
        size: Target size passed to ``transforms.Resize`` and, when ``crop``
            is truthy, to ``transforms.CenterCrop``. ``0`` disables both
            resizing and cropping.
        crop: When truthy, center-crop to ``size`` after resizing.

    Returns:
        A ``transforms.Compose`` whose last step is ``transforms.ToTensor()``.
    """
    steps = []
    if size != 0:
        steps.append(transforms.Resize(size))
        # Only crop when a real target size was given: CenterCrop(0) would
        # request a zero-sized crop and yield an empty image.
        if crop:
            steps.append(transforms.CenterCrop(size))
    steps.append(transforms.ToTensor())
    return transforms.Compose(steps)
def style_transform(h, w):
    """Build a pipeline that center-crops to ``(h, w)`` and converts to a tensor.

    Args:
        h: Crop height passed to ``transforms.CenterCrop``.
        w: Crop width passed to ``transforms.CenterCrop``.

    Returns:
        A ``transforms.Compose`` of ``CenterCrop((h, w))`` followed by
        ``ToTensor()``.
    """
    # The previous version also computed max(h, w) into a local that was
    # never used; that dead computation has been dropped.
    return transforms.Compose([
        transforms.CenterCrop((h, w)),
        transforms.ToTensor(),
    ])
def content_transform():
    """Return a pipeline that only converts its input with ``ToTensor()``.

    No resizing or cropping is applied.
    """
    to_tensor_only = [transforms.ToTensor()]
    return transforms.Compose(to_tensor_only)
def _load_weights(module, path):
    """Load the state dict stored at ``path`` into ``module`` and return it."""
    # NOTE: torch.load unpickles the file, so only point this at trusted
    # checkpoints.
    # map_location="cpu" lets checkpoints saved from a GPU run load on
    # CPU-only hosts; the assembled network is moved to the target device
    # after all weights are in place.
    module.load_state_dict(torch.load(path, map_location="cpu"))
    return module


def StyleTransformer(content_img: Image.Image, style_img: Image.Image,
                     vgg_path: str = 'vgg_normalised.pth',
                     decoder_path: str = 'decoder_iter_160000.pth',
                     Trans_path: str = 'transformer_iter_160000.pth',
                     embedding_path: str = 'embedding_iter_160000.pth') -> Image.Image:
    """Render ``content_img`` in the style of ``style_img``.

    Both images are converted to RGB, resized and center-cropped to 640, and
    passed through a ``StyTR.StyTrans`` network assembled from the four
    checkpoints. Inference runs on CUDA when available, otherwise on CPU.

    Args:
        content_img: PIL image providing the content.
        style_img: PIL image providing the style.
        vgg_path: Checkpoint loaded into ``StyTR.vgg``.
        decoder_path: Checkpoint loaded into ``StyTR.decoder``.
        Trans_path: Checkpoint loaded into ``transformer.Transformer()``.
        embedding_path: Checkpoint loaded into ``StyTR.PatchEmbed()``.

    Returns:
        The stylized result as a PIL image.
    """
    # Advanced options
    content_size = 640
    style_size = 640
    crop = True  # previously the truthy string 'store_true' (argparse leftover)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # The VGG weights are loaded into the full module first; only its first
    # 44 children are then kept for the network.
    vgg = _load_weights(StyTR.vgg, vgg_path)
    vgg = nn.Sequential(*list(vgg.children())[:44])
    decoder = _load_weights(StyTR.decoder, decoder_path)
    Trans = _load_weights(transformer.Transformer(), Trans_path)
    embedding = _load_weights(StyTR.PatchEmbed(), embedding_path)

    for module in (vgg, decoder, Trans, embedding):
        module.eval()

    network = StyTR.StyTrans(vgg, decoder, embedding, Trans)
    network.eval()
    network.to(device)

    content_tf = test_transform(content_size, crop)
    style_tf = test_transform(style_size, crop)

    content = content_tf(content_img.convert("RGB")).to(device).unsqueeze(0)
    style = style_tf(style_img.convert("RGB")).to(device).unsqueeze(0)

    with torch.no_grad():
        output = network(content, style)
    output = output[0].cpu()
    # NOTE(review): if the network returns a tuple, its first element may
    # still carry the batch dimension; ToPILImage needs a single C x H x W
    # image. Confirm against StyTR.StyTrans.forward.
    if output.dim() == 4:
        output = output[0]
    # Keep float values in [0, 1] so the 8-bit conversion cannot wrap around.
    output = output.clamp(0, 1)
    # ToPILImage is a transform class: instantiate it, then call it on the
    # tensor. Passing the tensor to the constructor returned the transform
    # object instead of an image.
    return transforms.ToPILImage()(output)
def create_styledSofa(sofa:Image, style:Image):
    """Return ``sofa`` restyled with ``style``.

    Delegates to ``StyleTransformer`` using its default checkpoint paths.
    """
    return StyleTransformer(sofa, style)
# image = Image.open('sofa_office.jpg') | |
# image.show() | |
# image = np.array(image) | |
# image,box = resize_sofa(image) | |
# image = image.crop(box) | |
# print(box) | |
# image.show() |