Spaces:
Build error
Build error
File size: 4,399 Bytes
ea2f7c7 e4fb230 4e64649 e4fb230 bf82406 5e8f5b8 bf82406 587b848 e4fb230 5e8f5b8 8e6efc1 7bebb02 e4fb230 7bebb02 e4fb230 7bebb02 4e64649 3b83a8e 8e6efc1 3b83a8e acda9fe 8e6efc1 acda9fe 8e6efc1 acda9fe 8e6efc1 acda9fe 8e6efc1 acda9fe ab7b996 baaaa83 21b3928 ab7b996 e4fb230 a907392 8e6efc1 e4fb230 7bebb02 12e61cd 8e6efc1 7bebb02 5e8f5b8 bf82406 c06d8b9 12e61cd 7bebb02 709b74f bf82406 587b848 709b74f bf82406 3b83a8e 7bebb02 3b83a8e bf82406 a402a5e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms
import StyleTransfer.transformer as transformer
import StyleTransfer.StyTR as StyTR
from collections import OrderedDict
import tensorflow_hub as tfhub
import tensorflow as tf
import paddlehub as phub
import os
############################################# TRANSFORMER ############################################
def style_transform(h: int, w: int) -> transforms.Compose:
    """Build the preprocessing pipeline applied to style images.

    Args:
        h: Height, in pixels, of the center crop.
        w: Width, in pixels, of the center crop.

    Returns:
        A ``transforms.Compose`` that center-crops a PIL image to ``(h, w)``
        and then converts it to a tensor with ``ToTensor``.
    """
    # NOTE: per torchvision docs, CenterCrop zero-pads images that are smaller
    # than (h, w) before cropping, so small style images get black borders.
    return transforms.Compose([
        transforms.CenterCrop((h, w)),
        transforms.ToTensor(),
    ])
def content_transform() -> transforms.Compose:
    """Build the preprocessing pipeline applied to content images.

    The pipeline only converts the PIL image to a tensor with ``ToTensor``;
    no resizing or cropping is done, so the content keeps its original size.
    """
    return transforms.Compose([transforms.ToTensor()])
# Assembled StyTr^2 networks keyed by device string, so the four checkpoints
# are read from disk once per process instead of on every request.
_STYTR_NETWORK_CACHE = {}


def _load_stytr_network(device: torch.device) -> nn.Module:
    """Return the pretrained StyTr^2 network on ``device``, building it on first use.

    Args:
        device: Device the assembled network is moved to.

    Returns:
        The ``StyTR.StyTrans`` network in eval mode.
    """
    key = str(device)
    cached = _STYTR_NETWORK_CACHE.get(key)
    if cached is not None:
        return cached

    vgg_path = 'StyleTransfer/models/vgg_normalised.pth'
    decoder_path = 'StyleTransfer/models/decoder_iter_160000.pth'
    trans_path = 'StyleTransfer/models/transformer_iter_160000.pth'
    embedding_path = 'StyleTransfer/models/embedding_iter_160000.pth'

    def _load(path):
        # map_location="cpu" lets checkpoints saved from a GPU process be read
        # on CPU-only hosts; the whole network is moved to ``device`` below.
        return torch.load(path, map_location="cpu")

    vgg = StyTR.vgg
    vgg.load_state_dict(_load(vgg_path))
    # Only the first 44 child layers of the VGG model are kept as the encoder.
    vgg = nn.Sequential(*list(vgg.children())[:44])

    decoder = StyTR.decoder
    decoder.load_state_dict(_load(decoder_path))

    trans = transformer.Transformer()
    trans.load_state_dict(_load(trans_path))

    embedding = StyTR.PatchEmbed()
    embedding.load_state_dict(_load(embedding_path))

    vgg.eval()
    decoder.eval()
    trans.eval()

    network = StyTR.StyTrans(vgg, decoder, embedding, trans)
    network.eval()
    network.to(device)
    _STYTR_NETWORK_CACHE[key] = network
    return network


def StyleTransformer(content_img: Image.Image, style_img: Image.Image) -> Image.Image:
    """Stylize ``content_img`` with ``style_img`` using the StyTr^2 transformer model.

    The content image is used at its original resolution. The style image is
    center-cropped to 640x640. Both are converted to RGB first. The network
    runs on CUDA when available, otherwise on the CPU.

    Args:
        content_img: Image whose structure is preserved.
        style_img: Image whose style is transferred.

    Returns:
        The stylized image as an 8-bit PIL image.
    """
    style_size = 640  # side length of the square center crop for the style image
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    network = _load_stytr_network(device)

    content = content_transform()(content_img.convert("RGB")).to(device).unsqueeze(0)
    style = style_transform(style_size, style_size)(style_img.convert("RGB")).to(device).unsqueeze(0)

    with torch.no_grad():
        output = network(content, style)

    # After squeeze, output[0] must be (C, H, W) for the permute below to
    # produce an (H, W, C) array — TODO confirm against StyTR.StyTrans.forward.
    image = output[0].cpu().squeeze()
    # Scale [0, 1] floats to [0, 255]; +0.5 before the truncating uint8 cast rounds.
    image = image.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
    return Image.fromarray(image)
############################################## STYLE-FAST #############################################
style_transfer_model = tfhub.load("https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2")


def StyleFAST(content_image: Image.Image, style_image: Image.Image) -> Image.Image:
    """Stylize ``content_image`` with the Magenta arbitrary-image-stylization TF Hub model.

    Both images are converted to RGB, scaled to float32 in [0, 1], and given
    a leading batch axis before being passed to the model.

    Args:
        content_image: Image whose structure is preserved.
        style_image: Image whose style is transferred.

    Returns:
        The stylized image as an 8-bit PIL image.
    """
    # convert("RGB") guarantees three channels (RGBA / grayscale / palette
    # inputs would otherwise give the model a 4-channel or 2-D tensor).
    content = tf.convert_to_tensor(np.array(content_image.convert("RGB")), np.float32)[tf.newaxis, ...] / 255.
    style = tf.convert_to_tensor(np.array(style_image.convert("RGB")), np.float32)[tf.newaxis, ...] / 255.
    output = style_transfer_model(content, style)
    stylized_image = output[0]
    # Clip before the uint8 cast: out-of-range floats would otherwise wrap around.
    pixels = np.clip(np.asarray(stylized_image[0]) * 255, 0, 255)
    return Image.fromarray(pixels.astype(np.uint8))
########################################### STYLE PROJECTION ##########################################
# Install the PaddleHub module at import time. The shell exit status is not
# checked; presumably phub.Module below raises if installation failed — TODO confirm.
os.system("hub install stylepro_artistic==1.0.1")
stylepro_artistic = phub.Module(name="stylepro_artistic")


def StyleProjection(content_image: Image.Image, style_image: Image.Image) -> Image.Image:
    """Stylize ``content_image`` with PaddleHub's ``stylepro_artistic`` module.

    Args:
        content_image: Image whose structure is preserved.
        style_image: Image whose style is transferred.

    Returns:
        The stylized image as an RGB PIL image.
    """
    # Channels are flipped RGB -> BGR before calling the module and flipped
    # back on the result (the module presumably expects OpenCV-style BGR).
    content_bgr = np.array(content_image.convert('RGB'))[:, :, ::-1]
    style_bgr = np.array(style_image.convert('RGB'))[:, :, ::-1]
    result = stylepro_artistic.style_transfer(
        images=[{
            'content': content_bgr,
            'styles': [style_bgr],
        }],
        alpha=0.8,
    )
    return Image.fromarray(np.uint8(result[0]['data'])[:, :, ::-1]).convert('RGB')
def create_styledSofa(content_image: Image.Image, style_image: Image.Image, choice: str) -> Image.Image:
    """Dispatch to the style-transfer backend named by ``choice``.

    Recognised choices are "Style Transformer", "Style FAST" and
    "Style Projection". Any other value returns ``content_image`` unchanged.
    """
    if choice == "Style Transformer":
        return StyleTransformer(content_image, style_image)
    if choice == "Style FAST":
        return StyleFAST(content_image, style_image)
    if choice == "Style Projection":
        return StyleProjection(content_image, style_image)
    # Unknown option: pass the content image through untouched.
    return content_image
|