File size: 1,497 Bytes
7dab4fa 8abd2c9 04d2061 7dab4fa 8abd2c9 7dab4fa a675130 7dab4fa 04d2061 b9fcbbd 04d2061 7dab4fa 04d2061 7dab4fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
from PIL import Image
import ot
import numpy as np
import gradio as gr
def preprocess(content_img, style_img, *, size=100):
    """Recolour *content_img* with the colour palette of *style_img* via optimal transport.

    Both images are resized to ``size x size``. Each is treated as a cloud of
    RGB colours in which every pixel has weight ``1/n``. The exact optimal
    transport plan between the two clouds is computed with ``ot.emd`` over the
    ``ot.dist`` cost matrix. Each content pixel is then replaced by the
    weighted mean (barycentre) of the style colours it is transported to.

    Args:
        content_img: Content image as an array-like of pixel values
            (presumably H x W x C uint8 as delivered by the Gradio image
            input — TODO confirm). Grayscale and RGBA inputs are converted
            to RGB.
        style_img: Style image, same format as *content_img*.
        size: Side length in pixels that both images are resized to before
            matching. The transport problem has ``size**2`` points per image,
            so runtime and memory grow very quickly with this value.

    Returns:
        ``(size, size, 3)`` uint8 array holding the recoloured content image.
    """
    xs = _to_rgb_pixels(content_img, size)
    xt = _to_rgb_pixels(style_img, size)
    n = xs.shape[0]

    # Weight of each point: uniform 1/n for every pixel of both images.
    a = np.full(n, 1.0 / n)
    b = np.full(n, 1.0 / n)

    # Pairwise colour cost, normalised so its maximum is 1.
    C = ot.dist(xs, xt)
    c_max = C.max()
    if c_max > 0:  # an all-zero cost (identical flat colours) would give 0/0 -> NaN
        C /= c_max

    # Optimal transport plan, shape (n, n).
    P = ot.emd(a, b, C)

    # Actually transport with P (barycentric projection): content pixel k
    # becomes sum_i P[k, i] * xt[i] / a[k], and dividing by a[k] == 1/n is
    # the same as multiplying by n.
    transferred_img = (P @ xt) * n

    # Round rather than truncate, and clip so float noise cannot wrap around
    # when converting to uint8.
    transferred_img = np.clip(np.rint(transferred_img), 0, 255).astype(np.uint8)
    return transferred_img.reshape(size, size, 3)


def _to_rgb_pixels(img, size):
    """Resize *img* to ``size x size`` RGB and return it as a ``(size*size, 3)`` float64 array."""
    pil_img = Image.fromarray(np.asarray(img).astype("uint8")).convert("RGB")
    pil_img = pil_img.resize((size, size))
    return np.asarray(pil_img, dtype=np.float64).reshape(-1, 3)
# Example image pair shown under the interface; presumably files shipped next
# to this script — TODO confirm they exist in the app directory.
content_file = 'gohho.png'
style_file = 'sea.png'

# `gr.inputs.*` was deprecated in Gradio 3 and removed in Gradio 4. Prefer the
# top-level component and fall back to the legacy namespace on old versions.
# Both default to delivering the uploaded image as a numpy array.
_ImageInput = gr.Image if hasattr(gr, "Image") else gr.inputs.Image
gr_img = _ImageInput()   # content image
gr_img2 = _ImageInput()  # style image

# Title: "Colour transfer using optimal transport distance".
# Description: "Please input a content image and a style image".
interface = gr.Interface(
    fn=preprocess,
    inputs=[
        gr_img,
        gr_img2,
    ],
    outputs="image",
    examples=[
        [content_file, style_file],
    ],
    title="最適輸送距離による色相変換",
    description="content画像と,style画像を入力してください",
)
# Starts the web server; this runs on import as well as when executed directly.
interface.launch()
|