SofaStyler / StyleTransfer /styleTransfer.py
Sophie98
change to streamlit
ad1ac8f
raw history blame
No virus
5.42 kB
import numpy as np
import paddlehub as phub
import StyleTransfer.srcTransformer.StyTR as StyTR
import StyleTransfer.srcTransformer.transformer as transformer
import tensorflow as tf
import tensorflow_hub as tfhub
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
# TRANSFORMER
# Relative paths of the weight files consumed by the transformer pipeline:
# one for the VGG encoder, one each for the decoder, transformer and
# patch-embedding modules.
vgg_path = "StyleTransfer/srcTransformer/Transformer_models/vgg_normalised.pth"
decoder_path = (
    "StyleTransfer/srcTransformer/Transformer_models/decoder_iter_160000.pth"
)
Trans_path = "StyleTransfer/srcTransformer/Transformer_models/transformer_iter_160000.pth"
embedding_path = "StyleTransfer/srcTransformer/Transformer_models/embedding_iter_160000.pth"
def style_transform(h, w):
    """
    Build the preprocessing pipeline for the style image: a centre crop
    to (h, w) followed by conversion to a tensor.
    Parameters:
        h = height of the crop
        w = width of the crop
    Return:
        the composed transformation pipeline
    """
    return transforms.Compose(
        [
            transforms.CenterCrop((h, w)),
            transforms.ToTensor(),
        ]
    )
def content_transform():
    """
    Build the preprocessing pipeline for the content image. Its only
    step is the conversion of the image into a tensor.
    Returns:
        the composed transformation pipeline
    """
    return transforms.Compose([transforms.ToTensor()])
# Build the network components once, when this module is imported.
vgg = StyTR.vgg
# map_location="cpu" allows a checkpoint that was serialized from a GPU to be
# read on a CPU-only host; the assembled network is moved to the active
# device later, inside StyleTransformer.
vgg.load_state_dict(torch.load(vgg_path, map_location="cpu"))
# Keep only the first 44 child modules of the VGG network.
vgg = nn.Sequential(*list(vgg.children())[:44])
decoder = StyTR.decoder
Trans = transformer.Transformer()
embedding = StyTR.PatchEmbed()
# Fixed edge length of the (square) content and style images.
style_size = 640
content_size = 640
def StyleTransformer(content_img: Image.Image, style_img: Image.Image) -> Image.Image:
    """
    This function assembles the Transformer network from the module-level
    components, loads their weights and applies the network to a content
    and a style image to create a styled image.

    The style image is centre-cropped to style_size x style_size; the
    content image is only converted to a tensor. Both are converted to RGB
    first.

    Parameters:
        content_img = the image with the content
        style_img = the image with the style/pattern
    Returns:
        output = an image that is a combination of both
    """
    # Inference only: put the components in evaluation mode.
    decoder.eval()
    Trans.eval()
    vgg.eval()

    # map_location="cpu" lets checkpoints that were saved from a GPU be read
    # on a CPU-only host; network.to(device) below moves the weights to the
    # device that is actually available.
    # NOTE(review): the weights are re-read from disk on every call; consider
    # caching the assembled network if this becomes a bottleneck.
    decoder.load_state_dict(torch.load(decoder_path, map_location="cpu"))
    Trans.load_state_dict(torch.load(Trans_path, map_location="cpu"))
    embedding.load_state_dict(torch.load(embedding_path, map_location="cpu"))

    network = StyTR.StyTrans(vgg, decoder, embedding, Trans)
    network.eval()

    content_tf = content_transform()
    style_tf = style_transform(style_size, style_size)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    network.to(device)

    # Convert both images to tensors and add a leading batch dimension.
    content = content_tf(content_img.convert("RGB"))
    style = style_tf(style_img.convert("RGB"))
    style = style.to(device).unsqueeze(0)
    content = content.to(device).unsqueeze(0)

    with torch.no_grad():
        output = network(content, style)

    # Scale to 0..255, add 0.5 so the uint8 cast rounds instead of truncating,
    # and reorder the axes from channel-first to channel-last for PIL.
    output = output[0].cpu().squeeze()
    output = (
        output.mul(255)
        .add_(0.5)
        .clamp_(0, 255)
        .permute(1, 2, 0)
        .to("cpu", torch.uint8)
        .numpy()
    )
    return Image.fromarray(output)
# STYLE-FAST
# Load the arbitrary-image-stylization model from TensorFlow Hub at module
# level, so that StyleFAST reuses the same loaded model on every call.
style_transfer_model = tfhub.load(
"https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2"
)
def StyleFAST(content_image: Image.Image, style_image: Image.Image) -> Image.Image:
    """
    This function applies a Fast image style transfer technique,
    which uses a pretrained model from TensorFlow Hub.

    Both images are converted to RGB, scaled to float32 values in [0, 1]
    and given a leading batch dimension before being passed to the model.

    Parameters:
        content_image = the image with the content
        style_image = the image with the style/pattern
    Returns:
        stylized_image = an image that is a combination of both
    """
    # Force three channels, as StyleTransformer and styleProjection do, so
    # that RGBA, greyscale or palette images do not reach the model with an
    # unexpected channel layout.
    content_tensor = (
        tf.convert_to_tensor(np.array(content_image.convert("RGB")), np.float32)[
            tf.newaxis, ...
        ]
        / 255.0
    )
    style_tensor = (
        tf.convert_to_tensor(np.array(style_image.convert("RGB")), np.float32)[
            tf.newaxis, ...
        ]
        / 255.0
    )

    output = style_transfer_model(content_tensor, style_tensor)
    stylized_image = output[0]

    # Clip before the uint8 cast so that a value marginally outside [0, 1]
    # cannot wrap around; in-range values are converted exactly as before.
    pixels = np.clip(np.asarray(stylized_image[0]) * 255, 0, 255)
    return Image.fromarray(pixels.astype(np.uint8))
# STYLE PROJECTION
# Create the PaddleHub "stylepro_artistic" module at module level;
# styleProjection calls its style_transfer method.
stylepro_artistic = phub.Module(name="stylepro_artistic")
def styleProjection(
    content_image: Image.Image, style_image: Image.Image, alpha: float = 1.0
):
    """
    Apply the parameter-free style transfer of the PaddleHub
    stylepro_artistic module to a content and a style image.

    The optional weight alpha is forwarded to the module and controls
    the balance between the image and the style.

    Parameters:
        content_image = the image with the content
        style_image = the image with the style/pattern
        alpha = weight for the image vs style.
            This should be a float between 0 and 1.
    Returns:
        result = an image that is a combination of both
    """
    # Both inputs are passed as RGB arrays with the channel axis reversed
    # (presumably because the module works in BGR order - confirm against
    # the PaddleHub documentation).
    content_arr = np.array(content_image.convert("RGB"))[:, :, ::-1]
    style_arr = np.array(style_image.convert("RGB"))[:, :, ::-1]
    request = {"content": content_arr, "styles": [style_arr]}

    result = stylepro_artistic.style_transfer(images=[request], alpha=alpha)

    # Reverse the channel axis again on the way out before building the image.
    styled = np.uint8(result[0]["data"])[:, :, ::-1]
    return Image.fromarray(styled).convert("RGB")