Spaces:
Build error
Build error
from PIL import Image | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
from torchvision import transforms | |
import transformer as transformer | |
import StyTR as StyTR | |
from collections import OrderedDict | |
import tensorflow_hub as tfhub | |
import tensorflow as tf | |
import os | |
import cv2 | |
import paddlehub as phub | |
############################################# TRANSFORMER ############################################ | |
def style_transform(h, w):
    """Build the preprocessing pipeline applied to the style image.

    Args:
        h: Target height passed to ``transforms.CenterCrop``.
        w: Target width passed to ``transforms.CenterCrop``.

    Returns:
        A ``transforms.Compose`` that center-crops its input to ``(h, w)``
        and then converts it with ``transforms.ToTensor``.
    """
    return transforms.Compose([
        transforms.CenterCrop((h, w)),
        transforms.ToTensor(),
    ])
def content_transform():
    """Build the preprocessing pipeline applied to the content image.

    Returns:
        A ``transforms.Compose`` whose only step is ``transforms.ToTensor``;
        no resizing or cropping is performed.
    """
    steps = [transforms.ToTensor()]
    return transforms.Compose(steps)
def StyleTransformer(content_img: Image.Image, style_img: Image.Image):
    """Stylise ``content_img`` with ``style_img`` using the StyTR network.

    The VGG, decoder, transformer and patch-embedding weights are read from
    four ``.pth`` files given as relative paths, so they are resolved against
    the current working directory. The models are rebuilt and the weights
    reloaded on every call.

    Args:
        content_img: PIL image providing the content; converted to RGB.
        style_img: PIL image providing the style; converted to RGB and
            center-cropped to 640x640.

    Returns:
        A PIL image built from the first element of the network output,
        scaled to 0-255 and converted to ``uint8``.
    """
    vgg_path = 'vgg_normalised.pth'
    decoder_path = 'decoder_iter_160000.pth'
    trans_path = 'transformer_iter_160000.pth'
    embedding_path = 'embedding_iter_160000.pth'

    # Side length of the center crop applied to the style image.
    style_size = 640

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Deserialize every checkpoint onto the CPU so that weights saved from
    # CUDA tensors still load on a CPU-only host; the assembled network is
    # moved to ``device`` afterwards.
    vgg = StyTR.vgg
    vgg.load_state_dict(torch.load(vgg_path, map_location="cpu"))
    # Keep only the first 44 child layers of the VGG (after loading, so the
    # full state dict still matches the module it is loaded into).
    vgg = nn.Sequential(*list(vgg.children())[:44])

    decoder = StyTR.decoder
    decoder.load_state_dict(torch.load(decoder_path, map_location="cpu"))

    trans = transformer.Transformer()
    trans.load_state_dict(torch.load(trans_path, map_location="cpu"))

    embedding = StyTR.PatchEmbed()
    embedding.load_state_dict(torch.load(embedding_path, map_location="cpu"))

    # Put every sub-module in inference mode before assembling the network.
    for module in (decoder, trans, vgg, embedding):
        module.eval()

    network = StyTR.StyTrans(vgg, decoder, embedding, trans)
    network.eval()
    network.to(device)

    content_tf = content_transform()
    style_tf = style_transform(style_size, style_size)

    content = content_tf(content_img.convert("RGB")).to(device).unsqueeze(0)
    style = style_tf(style_img.convert("RGB")).to(device).unsqueeze(0)

    with torch.no_grad():
        output = network(content, style)

    # The first element of the network output is the image tensor; drop the
    # singleton dimensions before reordering axes for PIL.
    output = output[0].cpu().squeeze()
    # Scale to 0-255, add 0.5 so the uint8 cast rounds instead of truncating,
    # clamp, and move the leading axis to the end.
    output = (
        output.mul(255)
        .add_(0.5)
        .clamp_(0, 255)
        .permute(1, 2, 0)
        .to('cpu', torch.uint8)
        .numpy()
    )
    return Image.fromarray(output)
############################################## STYLE-GAN ############################################# | |
def StyleGAN(content_image, style_image):
    """Stylise ``content_image`` with ``style_image`` via a TF Hub module.

    Despite the function name, the model loaded here is the Magenta
    "arbitrary-image-stylization-v1-256" module from TF Hub. The module is
    loaded afresh on every call.

    Args:
        content_image: Image data accepted by ``tf.convert_to_tensor``.
        style_image: Image data accepted by ``tf.convert_to_tensor``.

    Returns:
        A PIL image built from the first item of the model's first output,
        multiplied by 255 and cast to ``uint8``.
    """
    def _unit_batch(image):
        # float32 tensor with a new leading axis, divided by 255.
        return tf.convert_to_tensor(image, np.float32)[tf.newaxis, ...] / 255.

    stylizer = tfhub.load("https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2")
    outputs = stylizer(_unit_batch(content_image), _unit_batch(style_image))
    first_output = outputs[0]
    pixels = np.uint8(first_output[0] * 255)
    return Image.fromarray(pixels)
########################################### STYLE PROJECTION ########################################## | |
def styleProjection(content_image,style_image):
    """Stylise ``content_image`` with PaddleHub's ``stylepro_artistic`` module.

    Both inputs are converted to RGB arrays and have their channel axis
    reversed before being handed to the module; the returned data has its
    channel axis reversed again before a PIL image is built from it.
    The module is instantiated on every call.

    Args:
        content_image: PIL image providing the content.
        style_image: PIL image providing the (single) style.

    Returns:
        An RGB PIL image built from the ``'data'`` entry of the first result.
    """
    def _channels_reversed(image):
        # RGB array with the last axis flipped (presumably BGR for the module).
        return np.array(image.convert('RGB') )[:, :, ::-1]

    module = phub.Module(name="stylepro_artistic")
    request = {
        'content': _channels_reversed(content_image),
        'styles': [_channels_reversed(style_image)],
    }
    results = module.style_transfer(images=[request])
    flipped_back = np.uint8(results[0]['data'])[:, :, ::-1]
    return Image.fromarray(flipped_back).convert('RGB')
def create_styledSofa(content_image,style_image):
    """Return ``content_image`` stylised with ``style_image``.

    Delegates directly to ``StyleTransformer`` and returns its result.
    """
    return StyleTransformer(content_image, style_image)