|
from PIL import Image |
|
import ot |
|
import numpy as np |
|
import gradio as gr |
|
|
|
|
|
def _pixels_rgb(img, size):
    """Return ``img`` resized to ``size`` x ``size`` RGB as an ``(size*size, 3)`` float64 array.

    Let PIL infer the mode from the array shape (grayscale, RGB, RGBA, ...) and
    then normalize to RGB. Forcing mode "RGB" on 2-D or 4-channel data
    misinterprets the pixel buffer.
    """
    arr = np.asarray(img).astype("uint8")
    if arr.ndim == 3 and arr.shape[2] == 1:
        # A (h, w, 1) stack is not accepted by Image.fromarray; treat as grayscale.
        arr = arr[:, :, 0]
    pil_img = Image.fromarray(arr).convert("RGB").resize((size, size))
    return np.asarray(pil_img, dtype="float64").reshape(-1, 3)


def preprocess(content_img, style_img, *, resize=100):
    """Recolor ``content_img`` with the color distribution of ``style_img`` via optimal transport.

    Both images are resized to ``resize`` x ``resize`` pixels. Each pixel is
    treated as a point in RGB space with uniform mass. An exact transport plan
    between the two pixel clouds is computed with ``ot.emd`` on the cost from
    ``ot.dist`` (its default metric, squared Euclidean), normalized to a
    maximum of 1. Each content pixel is then replaced by the plan-weighted
    mean (barycentric projection) of the style colors it is matched to.

    Args:
        content_img: Image whose colors are replaced. Presumably the numpy array
            produced by the Gradio image input. Grayscale and RGBA inputs are
            converted to RGB.
        style_img: Image that supplies the target color distribution.
        resize: Side length, in pixels, that both images are resized to before
            matching. The cost matrix and the plan each hold ``resize**4``
            float64 entries, roughly 800 MB each for the default of 100.

    Returns:
        A ``uint8`` array of shape ``(resize, resize, 3)``.

    Raises:
        ValueError: If either image is missing or ``resize`` is not positive.
    """
    if content_img is None or style_img is None:
        raise ValueError("Both a content image and a style image are required.")
    if resize < 1:
        raise ValueError(f"resize must be a positive integer, got {resize!r}")

    xs = _pixels_rgb(content_img, resize)  # source colors, one row per pixel
    xt = _pixels_rgb(style_img, resize)    # target colors, one row per pixel
    n, m = xs.shape[0], xt.shape[0]

    # Uniform mass on every pixel of each image.
    a = np.ones((n,)) / n
    b = np.ones((m,)) / m

    C = ot.dist(xs, xt)
    c_max = C.max()
    if c_max > 0:
        # Scale costs into [0, 1]. Skip this when all costs are zero (both images
        # a single identical color); dividing would turn the cost into NaN.
        C /= c_max

    # NOTE(review): ot.emd stops after its default iteration cap; with
    # resize=100 (10^4 points per side) confirm the solver reaches optimality.
    P = ot.emd(a, b, C)

    # Row k of P sums to a[k] == 1/n, so multiplying by n turns P @ xt into the
    # weighted mean of the style colors assigned to content pixel k.
    transferred = (P @ xt) * n

    # Round and clip before the uint8 cast. A bare cast truncates toward zero,
    # systematically darkening values. Float drift outside [0, 255] would
    # otherwise be cast unchecked.
    out = np.clip(np.rint(transferred), 0, 255).astype(np.uint8)
    return out.reshape(resize, resize, 3)
|
|
|
|
|
# Example image files offered as a clickable sample in the UI.
content_file, style_file = 'gohho.png', 'sea.png'

# Two image inputs, in the same order as preprocess(content_img, style_img):
# the first is the content picture, the second the style picture.
gr_img, gr_img2 = gr.inputs.Image(), gr.inputs.Image()

# Title: "Hue transformation using optimal transport distance".
# Description: "Please input a content image and a style image".
interface = gr.Interface(
    fn=preprocess,
    inputs=[gr_img, gr_img2],
    outputs="image",
    examples=[[content_file, style_file]],
    title="最適輸送距離による色相変換",
    description="content画像と,style画像を入力してください",
)

# Start the web app (runs at import time, as in the original script).
interface.launch()
|
|