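# Source: SengaFiller / jgdildsengafiller.py (author: Mya-Mya)
# Commit 17725c1 ("First Commit"), 791 bytes.
# NOTE(review): the numpy import below also pulls in `cast` and `int8`, which
# the code shown here does not use; `numpy.cast` was removed in NumPy 2.0, so
# that import fails on NumPy >= 2.0.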
from backend import SengaFiller
from numpy import ndarray, cast, int8, ones
from PIL import Image
class JGDILDSengaFiller(SengaFiller):
    """SengaFiller backed by a Keras model loaded from an ``.h5`` file.

    ``run`` binarizes the input drawing, feeds it to the model as a single
    one-channel image and returns the model output as a 2-D array.
    """

    # Grayscale level used to binarize the input: pixels brighter than this
    # become 1, pixels at or below it become 0.
    binarize_threshold: int = 200

    def __init__(self, model_path: str = "./model1.h5") -> None:
        """Load the Keras model used by ``run``.

        Args:
            model_path: Path of the saved Keras model. Defaults to the
                original ``./model1.h5``, which is resolved against the
                current working directory, not this module's location.
        """
        super().__init__()
        # Imported here rather than at module level, so importing this module
        # does not by itself require TensorFlow.
        from tensorflow import keras

        self.model = keras.models.load_model(model_path)

    def run(self, image_pil: Image.Image) -> ndarray:
        """Run the model on a drawing and return its output as a 2-D array.

        Args:
            image_pil: Input drawing in any mode Pillow can convert to "L".
                After conversion, pixels brighter than ``binarize_threshold``
                become 1.0 and all others 0.0 before reaching the model.

        Returns:
            The model output for the single input image, reshaped to
            ``(output_height, output_width)`` as reported by the output's
            dimensions 1 and 2.

        Raises:
            ValueError: if the model output does not hold exactly one value
                per output pixel (e.g. several channels), because it cannot
                then be reshaped to 2-D.
        """
        from numpy import asarray

        input_width, input_height = image_pil.size
        # Image.point(..., mode="L") only accepts single-band 8-bit input:
        # RGB/RGBA would raise, and for "P" the lookup table would apply to
        # palette indices instead of brightness. Normalize to grayscale first.
        if image_pil.mode != "L":
            image_pil = image_pil.convert("L")
        threshold = self.binarize_threshold
        image_mono_pil = image_pil.point(lambda v: int(v > threshold), mode="L")
        # (H, W) float64 array holding only 0.0 / 1.0.
        image_numpy = asarray(image_mono_pil, dtype=float)
        # Batch of one single-channel image: (1, H, W, 1).
        x = image_numpy.reshape((1, input_height, input_width, 1))
        y = self.model(x)
        output_height, output_width = y.shape[1:3]
        # reshape (rather than indexing [0, :, :, 0]) so an unexpected output
        # layout fails loudly instead of being silently truncated.
        return y.numpy().reshape(output_height, output_width)