File size: 1,497 Bytes
7dab4fa
 
 
8abd2c9
 
04d2061
7dab4fa
8abd2c9
7dab4fa
 
 
a675130
7dab4fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
04d2061
 
b9fcbbd
04d2061
 
7dab4fa
 
 
 
 
 
 
 
04d2061
 
 
7dab4fa
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from PIL import Image
import ot
import numpy as np
import gradio as gr


def preprocess(content_img, style_img, resize=100):
    """Recolour ``content_img`` with the palette of ``style_img`` via optimal transport.

    Both images are downscaled to ``resize`` x ``resize`` pixels. Each pixel is
    treated as a point in RGB space carrying uniform mass, and an exact
    earth mover's (EMD) transport plan is computed between the two pixel clouds.
    Every content pixel is then replaced by the style colour(s) it is
    transported to.

    Args:
        content_img: Image array (as delivered by the Gradio image input) whose
            spatial layout is kept.
        style_img: Image array whose colours are borrowed.
        resize: Side length, in pixels, of the square working resolution. The
            cost matrix and the transport plan each hold ``resize**4`` float64
            entries, so memory grows very quickly with this value (the default
            of 100 already needs ~800 MB per matrix).

    Returns:
        ``uint8`` array of shape ``(resize, resize, 3)``.
    """
    # Normalise any input mode (greyscale, RGBA, ...) to 3-channel RGB rather
    # than forcing the "RGB" mode onto arrays with a different channel count.
    content = Image.fromarray(np.asarray(content_img).astype("uint8")).convert("RGB")
    style = Image.fromarray(np.asarray(style_img).astype("uint8")).convert("RGB")

    content = content.resize((resize, resize))
    style = style.resize((resize, resize))

    # One row per pixel, one column per RGB channel.
    xs = np.asarray(content, dtype="float64").reshape(-1, 3)
    xt = np.asarray(style, dtype="float64").reshape(-1, 3)

    n = xs.shape[0]
    # Uniform mass 1/n on every pixel of both images.
    a = np.ones((n,)) / n
    b = np.ones((n,)) / n

    # Pairwise cost between content and style colours.
    C = ot.dist(xs, xt)
    # Rescale to [0, 1] for numerical stability. Skip this when every cost is
    # zero (e.g. both images are the same flat colour) to avoid 0/0 -> NaN.
    c_max = C.max()
    if c_max > 0:
        C /= c_max

    # Exact optimal transport plan: P[k, i] is the mass moved from content
    # pixel k to style pixel i. Each row sums to a[k] = 1/n.
    P = ot.emd(a, b, C)

    # Barycentric mapping: each content pixel becomes the mass-weighted average
    # of the style colours it is sent to. Multiplying by n rescales each row of
    # P to sum to 1.
    transferred = n * (P @ xt)
    # Round instead of truncating (254.9999 would otherwise become 254), and
    # clamp to the valid byte range before casting.
    transferred = np.clip(np.rint(transferred), 0, 255).astype(np.uint8)
    # Use the known working resolution instead of a float square root of n.
    return transferred.reshape(resize, resize, 3)


# Sample (content, style) image pair offered as a one-click example in the UI.
content_file = 'gohho.png'
style_file = 'sea.png'

# Two upload widgets: the first receives the content image, the second the
# style image, matching the positional parameters of preprocess().
# NOTE(review): gr.inputs.* is the legacy Gradio API (deprecated in Gradio 3,
# removed in 4); switch to gr.Image if the Gradio version is upgraded.
gr_img, gr_img2 = gr.inputs.Image(), gr.inputs.Image()

interface = gr.Interface(
    fn=preprocess,
    inputs=[gr_img, gr_img2],
    outputs="image",
    examples=[[content_file, style_file]],
    title="最適輸送距離による色相変換",
    description="content画像と,style画像を入力してください",
)

# Start the web server; this runs at import time, exactly as before.
interface.launch()