# Source: app.py from AdamOswald1's Hugging Face Space
# Commit 1ac77da: "fix for app.py line 65" (file size 2.87 kB)
import numpy as np
import matplotlib.pylab as plt
import ot
import ot.plot
from PIL import Image
from sklearn.cluster import KMeans
import gradio as gr
def kclus(data, n_clusters, resize1, resize2):
    """Cluster an image's pixels with k-means and return each cluster's mean colour.

    Parameters
    ----------
    data : array-like
        Image pixels accepted by ``PIL.Image.fromarray`` (for example the
        uint8 array a Gradio ``"image"`` input provides). Non-RGB modes
        (RGBA, greyscale, ...) are converted to RGB first.
    n_clusters : int
        Number of k-means clusters.
    resize1, resize2 : int
        Working height and width. PIL's ``resize`` takes ``(width, height)``,
        so the image is resized with ``(resize2, resize1)``.

    Returns
    -------
    ans : list of numpy.ndarray
        ``ans[i]`` is the mean RGB colour (float64, shape ``(3,)``) of the
        pixels assigned to cluster ``i``.
    cluster_pred : numpy.ndarray
        Cluster label of every resized pixel in row-major order,
        ``resize1 * resize2`` entries.
    """
    img = Image.fromarray(data)
    # reshape(-1, 3) below only makes sense for 3-channel pixels: RGBA would be
    # silently mis-grouped and greyscale rejected, so normalise to RGB.
    if img.mode != "RGB":
        img = img.convert("RGB")
    pixels = np.array(img.resize((resize2, resize1)), dtype="float64").reshape(-1, 3)

    kmeans = KMeans(n_clusters, random_state=0)
    cluster_pred = kmeans.fit_predict(pixels)

    ans = []
    for i in range(n_clusters):
        members = pixels[cluster_pred == i]
        if len(members):
            ans.append(members.mean(axis=0))
        else:
            # An empty cluster (e.g. the image has fewer distinct colours than
            # n_clusters) would give 0/0 = NaN here and break the transport
            # step downstream; use the k-means centre instead.
            ans.append(kmeans.cluster_centers_[i])
    return ans, cluster_pred
def saiteki(source, target, gaso, cluster_pred, resize1, resize2):
    """Recolour a clustered image by optimal transport between two colour palettes.

    Each source cluster colour is replaced by the transport-weighted average
    of the target colours it is matched to, and every pixel then takes the
    new colour of its cluster.

    Parameters
    ----------
    source, target : sequence of RGB colours
        Palettes of ``gaso`` colours each (e.g. the first return value of
        ``kclus``); both are treated as uniformly weighted.
    gaso : int
        Number of colours in each palette.
    cluster_pred : sequence of int
        Source-palette index of every pixel, ``resize1 * resize2`` entries in
        row-major order.
    resize1, resize2 : int
        Output height and width.

    Returns
    -------
    numpy.ndarray
        uint8 image of shape ``(resize1, resize2, 3)``.
    """
    xs = np.array(source, dtype="float64")
    xt = np.array(target, dtype="float64")

    # Uniform weights: every palette colour carries the same mass.
    a = np.ones((gaso,)) / gaso
    b = np.ones((gaso,)) / gaso

    M = ot.dist(xs, xt)
    m_max = M.max()
    # Normalise costs to [0, 1]. When all costs are zero (identical palettes)
    # the division would produce NaN, and the zero matrix is already normalised.
    if m_max > 0:
        M /= m_max
    G0 = ot.emd(a, b, M)
    # Row i of the plan sums to a[i] = 1/gaso; rescale so each row sums to 1.
    P = G0 * gaso
    # kansei[j] = sum_i P[j, i] * xt[i]: barycentric image of source colour j.
    kansei = P @ xt

    labels = np.asarray(cluster_pred, dtype=np.intp)
    new_data = kansei[labels]
    # Round to nearest instead of truncating, and clip to guard against
    # floating-point drift just outside [0, 255].
    new_data = np.clip(np.rint(new_data), 0, 255).astype("uint8")
    return new_data.reshape(resize1, resize2, 3)
# Scratch usage example kept for reference. NOTE: it predates the current
# kclus signature, so the kclus calls below would also need resize1, resize2.
# resize1 = 200
# resize2 = 350
#
# im1 = Image.open('data/gohho.jpeg')
# im2 = Image.open('data/yuya_sea.jpeg')
# source = np.array(im1.resize((resize2,resize1)),dtype="float64").reshape(-1,3)
# target = np.array(im2.resize((resize2,resize1)),dtype="float64").reshape(-1,3)
# source = np.array(source, dtype="uint8").reshape(resize1,resize2,3)
# plt.figure(1)
# plt.imshow(source)
# plt.show()
# gaso = 200
# so,cluster_pred1 = kclus(source,gaso)
# re,cluster_pred2 = kclus(target,gaso)
# saiteki(so, re,gaso,cluster_pred1,resize1,resize2)
def sepia(source, target, *, resize1=200, resize2=300, gaso=150):
    """Transfer the colour palette of ``target`` onto ``source``.

    Both images are reduced to ``gaso`` k-means colour clusters at a working
    size of ``resize1`` x ``resize2`` (height x width). The source clusters
    are then mapped onto the target palette by optimal transport.

    Parameters
    ----------
    source, target : numpy.ndarray or None
        Images as delivered by the Gradio ``"image"`` inputs; ``None`` when
        the user has not supplied one.
    resize1, resize2 : int, keyword-only
        Working height and width in pixels (defaults are the values this
        function previously hard-coded).
    gaso : int, keyword-only
        Number of colour clusters per image.

    Returns
    -------
    numpy.ndarray
        uint8 image of shape ``(resize1, resize2, 3)``.

    Raises
    ------
    ValueError
        If either image is missing.
    """
    # Without this check a missing upload fails deep inside PIL with an
    # unhelpful AttributeError.
    if source is None or target is None:
        raise ValueError("Please provide both a source and a target image.")
    source_palette, source_labels = kclus(source, gaso, resize1, resize2)
    # Only the target's palette is needed; its per-pixel labels are unused.
    target_palette, _ = kclus(target, gaso, resize1, resize2)
    return saiteki(source_palette, target_palette, gaso, source_labels, resize1, resize2)
# Two image inputs (source, target) -> one recoloured image output.
demo = gr.Interface(fn=sepia, inputs=["image", "image"], outputs=["image"])

# Start the server only when run as a script, so importing this module
# (e.g. from tests) does not launch a web server as a side effect.
if __name__ == "__main__":
    demo.launch()