from PIL import Image
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms
import transformer
import StyTR
import tensorflow_hub as tfhub
import tensorflow as tf
import paddlehub as phub
############################################# TRANSFORMER ############################################
def style_transform(h, w):
    # Center-crop the style image to (h, w), then convert it to a tensor.
    transform_list = [
        transforms.CenterCrop((h, w)),
        transforms.ToTensor(),
    ]
    return transforms.Compose(transform_list)

def content_transform():
    # The content image is only converted to a tensor; it keeps its original size.
    return transforms.Compose([transforms.ToTensor()])
def StyleTransformer(content_img: Image.Image, style_img: Image.Image) -> Image.Image:
    # Pretrained StyTR-2 checkpoint paths.
    vgg_path = 'vgg_normalised.pth'
    decoder_path = 'decoder_iter_160000.pth'
    Trans_path = 'transformer_iter_160000.pth'
    embedding_path = 'embedding_iter_160000.pth'
    # Advanced options
    style_size = 640

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the normalised VGG and keep only the layers used as the encoder.
    vgg = StyTR.vgg
    vgg.load_state_dict(torch.load(vgg_path, map_location=device))
    vgg = nn.Sequential(*list(vgg.children())[:44])

    decoder = StyTR.decoder
    Trans = transformer.Transformer()
    embedding = StyTR.PatchEmbed()
    decoder.load_state_dict(torch.load(decoder_path, map_location=device))
    Trans.load_state_dict(torch.load(Trans_path, map_location=device))
    embedding.load_state_dict(torch.load(embedding_path, map_location=device))

    network = StyTR.StyTrans(vgg, decoder, embedding, Trans)
    network.eval()
    network.to(device)

    # Prepare inputs: the style image is cropped to style_size, the content
    # image keeps its resolution; both get a batch dimension.
    content_tf = content_transform()
    style_tf = style_transform(style_size, style_size)
    content = content_tf(content_img.convert("RGB")).unsqueeze(0).to(device)
    style = style_tf(style_img.convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        output = network(content, style)

    # Convert the stylized tensor (first element of the output) back to a PIL image.
    output = output[0].cpu().squeeze()
    output = output.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to(torch.uint8).numpy()
    return Image.fromarray(output)
############################################## STYLE-GAN #############################################
def StyleGAN(content_image, style_image):
    # Arbitrary-image stylization model (magenta) from TensorFlow Hub.
    style_transfer_model = tfhub.load("https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2")
    # Add a batch dimension and scale pixel values to [0, 1].
    content_image = tf.convert_to_tensor(np.array(content_image), np.float32)[tf.newaxis, ...] / 255.
    style_image = tf.convert_to_tensor(np.array(style_image), np.float32)[tf.newaxis, ...] / 255.
    output = style_transfer_model(content_image, style_image)
    stylized_image = output[0]
    # Drop the batch dimension and rescale to 8-bit RGB.
    return Image.fromarray(np.uint8(stylized_image[0] * 255))
########################################### STYLE PROJECTION ##########################################
def styleProjection(content_image, style_image):
    # PaddleHub's StylePro Artistic module; it works on BGR arrays,
    # hence the channel flips via [:, :, ::-1].
    stylepro_artistic = phub.Module(name="stylepro_artistic")
    result = stylepro_artistic.style_transfer(
        images=[{
            'content': np.array(content_image.convert('RGB'))[:, :, ::-1],
            'styles': [np.array(style_image.convert('RGB'))[:, :, ::-1]],
        }])
    # Flip the BGR result back to RGB before building the PIL image.
    return Image.fromarray(np.uint8(result[0]['data'])[:, :, ::-1]).convert('RGB')
def create_styledSofa(content_image, style_image):
    # Default entry point: style the sofa image with the StyTR-2 transformer.
    return StyleTransformer(content_image, style_image)
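
# A minimal usage sketch, not part of the original module: the image paths
# below are hypothetical, and the pretrained checkpoints referenced in
# StyleTransformer must already be present in the working directory.
if __name__ == '__main__':
    content = Image.open('sofa.jpg')     # hypothetical content image path
    style = Image.open('fabric.jpg')     # hypothetical style image path
    styled = create_styledSofa(content, style)
    styled.save('styled_sofa.jpg')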