import os
from io import BytesIO
import requests
from datetime import datetime
import random

# Interface utilities
import gradio as gr

# Data utilities
import numpy as np
import pandas as pd

# Image utilities
from PIL import Image
import cv2

# FLAVA Model
import torch
from transformers import BertTokenizer, FlavaModel

# Style Transfer Model
import paddlehub as hub

os.system("hub install stylepro_artistic==1.0.1")
stylepro_artistic = hub.Module(name="stylepro_artistic")

# FLAVA Model
device = "cuda" if torch.cuda.is_available() else "cpu"
model = FlavaModel.from_pretrained("facebook/flava-full")
tokenizer = BertTokenizer.from_pretrained("facebook/flava-full")
model = model.to(device)

# Load Data
photo_features = np.load("unsplash-dataset/features.npy")
photo_data = pd.read_csv("unsplash-dataset/photos.csv")


def image_from_text(text_input):
    """Return the dataset photo that best matches the text, as a BGR array."""
    start = datetime.now()

    ## Inference
    with torch.no_grad():
        inputs = tokenizer([text_input], padding=True, return_tensors="pt").to(device)
        # Keep only the first ([CLS]) token embedding of the text encoder output
        text_features = model.get_text_features(**inputs)[:, 0, :].cpu().numpy()

    ## Find similarity
    similarities = list((text_features @ photo_features.T).squeeze(0))

    ## Return best image :)
    idx = sorted(
        zip(similarities, range(photo_features.shape[0])),
        key=lambda x: x[0],
        reverse=True,
    )[0][1]
    photo = photo_data.iloc[idx]

    print(f"Time spent at FLAVA: {datetime.now()-start}")

    start = datetime.now()
    # Download image
    response = requests.get(photo["path"])
    pil_image = Image.open(BytesIO(response.content)).convert("RGB")
    open_cv_image = np.array(pil_image)
    # Convert RGB to BGR
    open_cv_image = open_cv_image[:, :, ::-1].copy()

    print(f"Time spent at Image request: {datetime.now()-start}")

    return open_cv_image


def inference(content, style):
    """Retrieve an image for the text `content` and apply the uploaded style to it."""
    content_image = image_from_text(content)

    start = datetime.now()
    result = stylepro_artistic.style_transfer(
        images=[{
            "content": content_image,
            "styles": [cv2.imread(style.name)]
        }])
    print(f"Time spent at Style Transfer: {datetime.now()-start}")

    # The result is BGR; flip the channels back to RGB for PIL
    return Image.fromarray(np.uint8(result[0]["data"])[:, :, ::-1]).convert("RGB")


if __name__ == "__main__":
    title = "FLAVA Neural Style Transfer"
    description = (
        "Gradio demo for Neural Style Transfer. Inspired by this demo for CLIP. "
        "To use it, simply enter the text for the image content and upload a style image. "
        "Read more at the links below."
    )
    article = (
        "Parameter-Free Style Projection for Arbitrary Style Transfer | Github Repo"
        "<br>"
        "FLAVA paper | Hugging Face FLAVA Implementation"
    )
    examples = [
        ["a cute kangaroo", "styles/starry.jpeg"],
        ["man holding beer", "styles/mona1.jpeg"],
    ]

    demo = gr.Interface(
        inference,
        inputs=[
            gr.inputs.Textbox(
                lines=1,
                placeholder="Describe the content of the image",
                default="a modern city with neon lights",
                label="Describe the image to which the style will be applied",
            ),
            gr.inputs.Image(type="file", label="Style to be applied"),
        ],
        outputs=gr.outputs.Image(type="pil"),
        enable_queue=True,
        title=title,
        description=description,
        article=article,
        examples=examples,
    )
    demo.launch()