"""Gradio demo: colour transfer between two images via exact optimal transport.

The pixels of the content image are matched to the pixels of the style image by
solving an exact optimal-transport problem (``ot.emd``) on their RGB colours.
Each content pixel is then replaced by the style colour(s) it is transported to.
"""
from PIL import Image
import ot
import numpy as np
import gradio as gr


def _to_rgb_pixels(img, resize):
    """Return *img* resized to ``resize x resize`` as a ``(resize*resize, 3)`` float64 array.

    Raises:
        ValueError: if *img* is None (e.g. the user did not upload an image).
    """
    if img is None:
        raise ValueError("Both a content image and a style image are required.")
    # Let Pillow infer the mode from the array shape, then convert to RGB.
    # Forcing "RGB" on the raw array misreads RGBA (H, W, 4) input, and it does
    # not handle grayscale (H, W) input either.
    pil_img = Image.fromarray(np.asarray(img).astype("uint8")).convert("RGB")
    pil_img = pil_img.resize((resize, resize))
    return np.asarray(pil_img, dtype="float64").reshape(-1, 3)


def preprocess(content_img, style_img, resize=100):
    """Recolour *content_img* with the colour palette of *style_img*.

    Both images are resized to ``resize x resize``. Every pixel gets the
    uniform weight 1/n. The exact optimal-transport plan between the two pixel
    colour sets is computed, and each content pixel is mapped to the
    plan-weighted average (barycentric projection) of the style colours it is
    sent to.

    Args:
        content_img: array-like image supplying the layout. Grayscale, RGB or
            RGBA arrays are accepted.
        style_img: array-like image supplying the colours. Same formats.
        resize: side length, in pixels, of the working and output images.
            The transport problem is n x n with n = resize**2, so memory and
            time grow quickly with this value.

    Returns:
        A ``(resize, resize, 3)`` uint8 RGB array.

    Raises:
        ValueError: if either image is None.
    """
    xs = _to_rgb_pixels(content_img, resize)
    xt = _to_rgb_pixels(style_img, resize)
    n = xs.shape[0]

    # Weight of each point: uniform 1/n for both images.
    a = np.ones((n,)) / n
    b = np.ones((n,)) / n

    # Cost matrix: pairwise distances between RGB colours (ot.dist, default metric).
    C = ot.dist(xs, xt)
    c_max = C.max()
    # Rescaling the cost does not change the optimal plan; it only helps numerics.
    # Skip it when every cost is zero (single-colour images) to avoid 0/0 = NaN.
    if c_max > 0:
        C /= c_max

    # Compute the optimal transport plan.
    P = ot.emd(a, b, C)

    # Apply the plan. Each row of P sums to 1/n, so n * (P @ xt) is the
    # weighted average of the style colours each content pixel is sent to.
    transferred = n * (P @ xt)
    # Round (not truncate) and clip, so float error such as 254.9999 does not
    # become 254 and nothing wraps around when cast to uint8.
    transferred = np.clip(np.rint(transferred), 0, 255).astype("uint8")
    return transferred.reshape(resize, resize, 3)


content_file = 'gohho.png'
style_file = 'sea.png'

# NOTE: gr.inputs.* is the legacy Gradio input namespace (removed in Gradio 4);
# switch to gr.Image() if the pinned Gradio version is upgraded.
gr_img = gr.inputs.Image()
gr_img2 = gr.inputs.Image()

interface = gr.Interface(
    fn=preprocess,
    inputs=[
        gr_img,
        gr_img2,
    ],
    outputs="image",
    examples=[
        [content_file, style_file],
    ],
    title="最適輸送距離による色相変換",
    description="content画像と,style画像を入力してください",
)

# Launch only when run as a script (e.g. `python app.py`), so importing this
# module does not start a web server.
if __name__ == "__main__":
    interface.launch()