import os
from io import BytesIO
import requests
from datetime import datetime

# Interface utilities
import gradio as gr

# Data utilities
import numpy as np
import pandas as pd

# Image utilities
from PIL import Image
import cv2

# CLIP Model
import torch
from transformers import CLIPTokenizer, CLIPModel

# Style Transfer Model
import paddlehub as hub



# Install and load the PaddleHub style-transfer module at startup
os.system("hub install stylepro_artistic==1.0.1")
stylepro_artistic = hub.Module(name="stylepro_artistic")
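# Note: stylepro_artistic.style_transfer consumes and returns images as BGR
# numpy arrays (the OpenCV convention), which is why the functions below
# convert between PIL's RGB layout and BGR.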



# CLIP model and tokenizer for text-to-image retrieval
device = "cuda" if torch.cuda.is_available() else "cpu"
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
model = model.to(device)

# Load Data
photos = pd.read_csv("unsplash-dataset/photos.tsv000", sep="\t", header=0)
photo_features = np.load("unsplash-dataset/features.npy")
photo_ids = pd.read_csv("unsplash-dataset/photo_ids.csv")
photo_ids = list(photo_ids["photo_id"])
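# photo_features is assumed to be an (N, 512) array of precomputed CLIP image
# embeddings, row-aligned with photo_ids; photos.tsv000 maps each photo_id to
# metadata such as photo_image_url.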

def image_from_text(text_input):
    start = datetime.now()

    ## Inference
    with torch.no_grad():
        inputs = tokenizer([text_input], padding=True, return_tensors="pt").to(device)
        text_features = model.get_text_features(**inputs).cpu().numpy()

    ## Find similarity: dot products of the text embedding against every
    ## precomputed image embedding rank the photos
    similarities = (text_features @ photo_features.T).squeeze(0)

    ## Return best image :)
    idx = int(np.argmax(similarities))
    photo_id = photo_ids[idx]
    photo_data = photos[photos["photo_id"] == photo_id].iloc[0]

    print(f"Time spent at CLIP: {datetime.now()-start}")

    start = datetime.now()
    # Download the image at 640px width
    response = requests.get(photo_data["photo_image_url"] + "?w=640")
    pil_image = Image.open(BytesIO(response.content)).convert("RGB")
    open_cv_image = np.array(pil_image)
    # Convert RGB to BGR for OpenCV/PaddleHub
    open_cv_image = open_cv_image[:, :, ::-1].copy()

    print(f"Time spent at Image request: {datetime.now()-start}")

    return open_cv_image
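# Example usage (hypothetical; assumes the Unsplash files above are present):
#   bgr_image = image_from_text("a cute kangaroo")  # best-matching photo, BGR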

def inference(content, style):
    content_image = image_from_text(content)
    start = datetime.now()

    # "style" is the temporary file Gradio provides for type="file" inputs
    result = stylepro_artistic.style_transfer(
        images=[{
            "content": content_image,
            "styles": [cv2.imread(style.name)],
        }])

    print(f"Time spent at Style Transfer: {datetime.now()-start}")
    # The module returns a BGR array under "data"; flip channels back to RGB
    return Image.fromarray(np.uint8(result[0]["data"])[:, :, ::-1]).convert("RGB")

if __name__ == "__main__": 
    title = "Neural Style Transfer"
    description = "Gradio demo for Neural Style Transfer. To use it, enter text describing the image content and upload a style image. Read more at the links below."
    article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2003.07694' target='_blank'>Parameter-Free Style Projection for Arbitrary Style Transfer</a> | <a href='https://github.com/PaddlePaddle/PaddleHub' target='_blank'>Github Repo</a><br><a href='https://arxiv.org/abs/2103.00020' target='_blank'>CLIP paper</a> | <a href='https://huggingface.co/transformers/model_doc/clip.html' target='_blank'>Hugging Face CLIP implementation</a></p>"
    examples = [
        ["a cute kangaroo", "styles/starry.jpeg"],
        ["man holding beer", "styles/mona1.jpeg"],
    ]
    interface = gr.Interface(inference,
        inputs=[
            gr.inputs.Textbox(lines=1, placeholder="Describe the content of the image", default="a cute kangaroo", label="Describe the image to which the style will be applied"),
            gr.inputs.Image(type="file", label="Style to be applied"),
        ],
        outputs=gr.outputs.Image(type="pil"),
        enable_queue=True,
        title=title,
        description=description,
        article=article,
        examples=examples)
    interface.launch()