# otd / app.py
# ryusei-iki
# example_list
# 04d2061
from PIL import Image
import ot
import numpy as np
import gradio as gr
def _to_rgb_points(img, size):
    """Resize an image array to ``size`` x ``size`` RGB and flatten it.

    Returns a ``(size * size, 3)`` float64 array with one row per pixel.
    ``.convert("RGB")`` normalizes grayscale or RGBA input instead of
    reinterpreting its raw bytes as RGB triples.
    """
    pil_img = Image.fromarray(np.asarray(img).astype("uint8")).convert("RGB")
    pil_img = pil_img.resize((size, size))
    return np.asarray(pil_img, dtype="float64").reshape(-1, 3)


def preprocess(content_img, style_img, *, size=100):
    """Recolor *content_img* with the colors of *style_img* via optimal transport.

    Both images are resized to ``size`` x ``size`` and flattened into clouds of
    RGB points with uniform weights. An exact transport plan between the two
    clouds is computed with ``ot.emd`` on the normalized ``ot.dist`` cost.
    Each content pixel is then replaced by the style color(s) it is
    transported to.

    The name is kept for compatibility with the Gradio wiring; the function
    performs the whole transfer, not just preprocessing.

    Args:
        content_img: image array from the first Gradio input (presumably
            H x W x 3 uint8 RGB -- TODO confirm; alpha/gray are converted).
        style_img: image array whose colors are transferred.
        size: side length both images are resized to. The cost and plan
            matrices are ``n x n`` with ``n = size**2``, so memory grows as
            ``size**4``.

    Returns:
        ``(size, size, 3)`` uint8 array with the recolored content image.
    """
    xs = _to_rgb_points(content_img, size)  # source (content) colors
    xt = _to_rgb_points(style_img, size)    # target (style) colors
    n = xs.shape[0]

    # Weight of each point: uniform 1/n for both images.
    a, b = np.ones((n,)) / n, np.ones((n,)) / n

    # Pairwise color cost, normalized to [0, 1]. Guard the all-zero case
    # (e.g. both images the same single color), where 0/0 would yield NaN.
    C = ot.dist(xs, xt)
    c_max = C.max()
    if c_max > 0:
        C /= c_max

    # Optimal transport plan: P[k, i] is the mass moved from content pixel k
    # to style pixel i; each row sums to a[k] = 1/n.
    P = ot.emd(a, b, C)

    # Apply the plan: each content pixel becomes the P-weighted sum of style
    # colors, multiplied by n to undo the 1/n row mass.
    transferred = (P @ xt) * n
    # Round (not truncate) and clip, so float error such as 199.9999 maps to
    # 200 and nothing wraps around in the uint8 cast.
    transferred = np.clip(np.rint(transferred), 0, 255).astype(np.uint8)
    return transferred.reshape(size, size, 3)
# Example inputs shown under the interface. These are relative paths,
# presumably resolved against the directory the app is launched from.
content_file = 'gohho.png'
style_file = 'sea.png'

# `gr.inputs.Image` is deprecated in Gradio 3.x and removed in Gradio 4.x.
# Prefer the top-level component; fall back to the legacy namespace on
# older Gradio releases that predate `gr.Image`.
_image_input_cls = getattr(gr, "Image", None) or gr.inputs.Image
gr_img = _image_input_cls()   # content image
gr_img2 = _image_input_cls()  # style image

# Title: "Hue conversion by optimal transport distance".
# Description: "Please input a content image and a style image".
interface = gr.Interface(
    fn=preprocess,
    inputs=[
        gr_img,
        gr_img2,
    ],
    outputs="image",
    examples=[
        [content_file, style_file],
    ],
    title="最適輸送距離による色相変換",
    description="content画像と,style画像を入力してください",
)
interface.launch()