# sberbank-ai
# feat: Add app file
# Commit: 346b427 (1.33 kB)
import numpy as np
import cv2
import gradio as gr
from huggingface_hub import hf_hub_download
from scgan.config import Config
from scgan.generate_images import ImgGenerator
def download_weights(
    repo_id: str,
    char_map_filename: str = "char_map.pkl",
    weights_filename: str = "model_checkpoint_epoch_200.pth.tar",
):
    """Download the character map and model checkpoint from a Hub repo.

    Args:
        repo_id: Hugging Face Hub repository id, e.g. "owner/name".
        char_map_filename: file name of the character map inside the repo.
        weights_filename: file name of the model checkpoint inside the repo;
            override it to load a different training epoch.

    Returns:
        A ``(char_map_path, weights_path)`` pair of the local paths returned
        by ``hf_hub_download``.
    """
    # NOTE(review): the char map is a .pkl file; presumably it is unpickled
    # downstream (ImgGenerator), so only point this at trusted repos.
    char_map_path = hf_hub_download(repo_id, char_map_filename)
    weights_path = hf_hub_download(repo_id, weights_filename)
    return char_map_path, weights_path
def get_text_from_image(
    img: np.ndarray,
    color_min=(0, 0, 0),
    color_max=(250, 250, 160),
) -> np.ndarray:
    """Copy the text ("ink") pixels of an RGB image onto a white background.

    A pixel is treated as text when its HSV value, as produced by
    ``cv2.COLOR_RGB2HSV``, lies inside ``[color_min, color_max]`` on every
    channel (inclusive). With the default bounds this keeps pixels with
    saturation <= 250 and value <= 160, i.e. the darker pixels. OpenCV's 8-bit
    hue range is 0-179, so the default hue bound excludes nothing.

    Args:
        img: image in RGB channel order. NOTE(review): presumably a uint8
            HxWx3 array as produced by the generator; confirm.
        color_min: inclusive lower (H, S, V) bound.
        color_max: inclusive upper (H, S, V) bound.

    Returns:
        A new uint8 array with ``img``'s shape. Text pixels keep their
        original RGB values and all other pixels are 255 (white).
    """
    lower = np.array(color_min, np.uint8)
    upper = np.array(color_max, np.uint8)
    # Threshold on a separate HSV copy and copy pixels from the untouched RGB
    # input. Converting back HSV -> RGB in uint8 is lossy (hue is quantized),
    # which would slightly alter the text colours.
    hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
    text_mask = cv2.inRange(hsv, lower, upper).astype(bool)
    result = np.full(img.shape, 255, dtype=np.uint8)
    result[text_mask] = img[text_mask]
    return result
def predict(text: str) -> np.ndarray:
    """Render *text* as generated handwriting on a white background.

    Gradio callback: runs the module-level ``GENERATOR`` on a one-word list
    and post-processes the first generated image with
    ``get_text_from_image``.

    Args:
        text: the string typed into the UI textbox.

    Returns:
        The cleaned-up image as a NumPy array.
    """
    # NOTE(review): empty/whitespace-only input is passed straight to the
    # generator; confirm it handles that case.
    # The second value returned by generate() is not needed here.
    imgs, _ = GENERATOR.generate(word_list=[text])
    return get_text_from_image(imgs[0])
# Fetch the model assets once, at import time, and build the single shared
# generator that predict() uses for every request.
CHAR_MAP_PATH, WEIGHTS_PATH = download_weights("sberbank-ai/scrabblegan-peter")
GENERATOR = ImgGenerator(
    config=Config,
    char_map_path=CHAR_MAP_PATH,
    checkpt_path=WEIGHTS_PATH,
)
# Keep the interface in a module-level ``demo`` variable so tools that import
# this file (e.g. Gradio's reload mode) can find it. The server only starts
# when the file is run as a script, which is how `python app.py` runs it.
demo = gr.Interface(
    predict,
    inputs=gr.Textbox(label="Type your text to generate it on an image"),
    outputs=gr.Image(label="Generated image"),
    title="Peter handwritten image generation",
)

if __name__ == "__main__":
    demo.launch()