Spaces:
Build error
Build error
import cv2 | |
import gradio as gr | |
import os | |
from edit_func import * | |
from TransUnet import Trans_UNet | |
import TransUnet_Config as config2 | |
from huggingface_hub import hf_hub_download | |
from googletrans import Translator | |
import random | |
import torch.nn as nn | |
import spaces | |
# @spaces.GPU | |
class DTM(nn.Module):
    """Checkpoint ensemble for text detection.

    A single ``Trans_UNet`` network is built once; on every forward pass each
    downloaded checkpoint is loaded into it in turn and the outputs are averaged.
    """

    def __init__(self):
        super().__init__()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.detect_text_model = Trans_UNet(
            config2.in_channels, config2.adp_channels, config2.out_channels,
            config2.trans_num_layers, config2.trans_num_attn_heads, config2.trans_ffw_channels,
            config2.dropout
        ).to(self.device)
        self.repo_name = 'SS3M/detect-text-model'
        files = [f'detect-text-v3-{i}.pt' for i in range(8)]
        # Download every checkpoint up front; only the returned paths are kept
        # and later handed to ``torch.load``.
        self.files = [
            hf_hub_download(repo_id=self.repo_name, filename=file)
            for file in files
        ]

    def forward(self, X):
        """Return the mean prediction of all checkpoints for batch ``X``.

        Args:
            X: 4-D input tensor; its shape is unpacked as ``(N, C, H, W)``.

        Returns:
            Tensor of shape ``(N, 1, H, W)`` on ``self.device``.
            NOTE(review): assumes the network output broadcasts to that shape.
        """
        X = X.to(self.device)
        N, C, H, W = X.shape
        # The accumulator must live on the same device as the model output;
        # a CPU accumulator breaks the in-place add when running on CUDA.
        result = torch.zeros((N, 1, H, W), device=self.device)
        for model_path in self.files:
            best_model_state = torch.load(
                model_path,
                weights_only=True,
                map_location=self.device
            )
            self.detect_text_model.load_state_dict(best_model_state)
            result += self.detect_text_model(X)
        result /= len(self.files)
        return result
# @spaces.GPU | |
class DWBM(nn.Module):
    """Checkpoint ensemble for word-balloon detection.

    A single ``Trans_UNet`` network is built once; on every forward pass each
    downloaded checkpoint is loaded into it in turn and the outputs are averaged.
    """

    def __init__(self):
        super().__init__()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.detect_wordball_model = Trans_UNet(
            config2.in_channels, config2.adp_channels, config2.out_channels,
            config2.trans_num_layers, config2.trans_num_attn_heads, config2.trans_ffw_channels,
            config2.dropout
        ).to(self.device)
        self.repo_name = 'SS3M/detect-wordball-model'
        # NOTE(review): the checkpoint names say "detect-text" although they are
        # fetched from the wordball repo — confirm these are the intended files.
        files = [f'detect-text-v3-{i}.pt' for i in range(8)]
        # Download every checkpoint up front; only the returned paths are kept
        # and later handed to ``torch.load``.
        self.files = [
            hf_hub_download(repo_id=self.repo_name, filename=file)
            for file in files
        ]

    def forward(self, X):
        """Return the mean prediction of all checkpoints for batch ``X``.

        Args:
            X: 4-D input tensor; its shape is unpacked as ``(N, C, H, W)``.

        Returns:
            Tensor of shape ``(N, 1, H, W)`` on ``self.device``.
            NOTE(review): assumes the network output broadcasts to that shape.
        """
        X = X.to(self.device)
        N, C, H, W = X.shape
        # The accumulator must live on the same device as the model output;
        # a CPU accumulator breaks the in-place add when running on CUDA.
        result = torch.zeros((N, 1, H, W), device=self.device)
        for model_path in self.files:
            best_model_state = torch.load(
                model_path,
                weights_only=True,
                map_location=self.device
            )
            self.detect_wordball_model.load_state_dict(best_model_state)
            result += self.detect_wordball_model(X)
        result /= len(self.files)
        return result
# Module-level singletons used by the callbacks below (down1 / down2).
# Constructing DTM and DWBM runs their __init__, which calls hf_hub_download
# for every checkpoint, so this happens once at import time.
detect_text_model = DTM()
detect_wordball_model = DWBM()
# googletrans client handed to translate() in down2.
translator = Translator()
def down1(src_img):
    """Run text and word-balloon detection on the uploaded image.

    Args:
        src_img: Image array converted here with ``cv2.COLOR_RGB2BGR`` before
            being passed to the detectors.

    Returns:
        A 7-tuple matching the outputs wired to the first arrow button:
        text mask * 255, word-balloon mask * 255, box indices, box positions,
        box areas, per-box random RGB colours (the last four serialized as
        ``'; '``-separated strings, tuples joined with ``', '``) and the
        status string ``'Xong'``.
    """
    src_img = cv2.cvtColor(src_img, cv2.COLOR_RGB2BGR)
    text_msk = create_text_mask(src_img, detect_text_model)
    wordball_msk = create_wordball_mask(src_img, detect_wordball_model)
    text_positions, areas = get_text_positions(text_msk, text_value=0)

    box_count = len(areas)
    # One random colour per detected box, drawn in r, g, b order.
    rgbs = [
        (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
        for _ in range(box_count)
    ]

    idx = '; '.join(str(i) for i in range(box_count))
    text_positions = '; '.join(', '.join(str(i) for i in pos) for pos in text_positions)
    areas = '; '.join(str(i) for i in areas)
    rgbs = '; '.join(', '.join(str(i) for i in rgb) for rgb in rgbs)
    # The source image itself is not returned, so no conversion back to RGB
    # is needed here.
    return text_msk*255, wordball_msk*255, idx, text_positions, areas, rgbs, 'Xong'
def idx_txt_change(src_img, idx_txt, pos_txt, rgb_txt):
    """Draw the currently selected text boxes on top of the source image.

    Args:
        src_img: Source image; converted RGB -> BGR for drawing and back.
        idx_txt: ``'; '``-separated indices of the boxes to draw.
        pos_txt: ``'; '``-separated boxes, each ``'min_x, min_y, max_x, max_y'``.
        rgb_txt: ``'; '``-separated colours, each ``'r, g, b'``.

    Returns:
        The image with rectangles drawn, or ``src_img`` unchanged when anything
        goes wrong (for example the state strings are still empty).
    """
    try:
        text_positions = [
            tuple(int(i) for i in pos.split(', ')) for pos in pos_txt.split('; ')
        ]
        rgbs = [tuple(int(i) for i in rgb.split(', ')) for rgb in rgb_txt.split('; ')]
        idxes = {int(idx) for idx in idx_txt.split('; ')}

        src_img2 = cv2.cvtColor(src_img, cv2.COLOR_RGB2BGR)
        for idx, ((min_x, min_y, max_x, max_y), (r, g, b)) in enumerate(zip(text_positions, rgbs)):
            if idx in idxes:
                # Colour is given as RGB but the working image is BGR.
                cv2.rectangle(src_img2, (min_x, min_y), (max_x, max_y), (b, g, r), thickness=4)
        return cv2.cvtColor(src_img2, cv2.COLOR_BGR2RGB)
    except Exception:
        # Best-effort preview: fall back to the untouched input, but let
        # KeyboardInterrupt / SystemExit propagate (a bare except would not).
        return src_img
def scale_area_change(min_area, max_area, area_txt):
    """Select the boxes whose area lies within ``[min_area, max_area]``.

    Args:
        min_area: Inclusive lower bound.
        max_area: Inclusive upper bound.
        area_txt: ``'; '``-separated integer areas, one per box; may be empty.

    Returns:
        ``'; '``-separated indices of the matching boxes (``''`` when none match
        or when ``area_txt`` is empty).
    """
    # Skip blank fragments so an empty state string yields no boxes instead of
    # failing on int('').
    areas = [int(area) for area in area_txt.split('; ') if area.strip()]
    return '; '.join(
        str(idx) for idx, area in enumerate(areas) if min_area <= area <= max_area
    )
def position_block_change(X, Y, W, H, ID, pos_txt_value):
    """Replace one box in the serialized position list.

    Args:
        X, Y: New top-left corner of the box.
        W, H: New width and height of the box.
        ID: Index of the box to replace.
        pos_txt_value: ``'; '``-separated boxes, each
            ``'min_x, min_y, max_x, max_y'``.

    Returns:
        The updated serialized list. The input string is returned unchanged when
        it is empty or when any numeric argument is ``None``.
    """
    if not pos_txt_value or any(v is None for v in (X, Y, W, H, ID)):
        return pos_txt_value

    # Coerce to int so a float input never serializes as e.g. "5.0", which the
    # int() parsing of this string would later reject.
    x, y, w, h = (int(round(v)) for v in (X, Y, W, H))

    boxes = [
        tuple(int(i) for i in pos.split(', ')) for pos in pos_txt_value.split('; ')
    ]
    updated = [
        (x, y, x + w, y + h) if idx == ID else box
        for idx, box in enumerate(boxes)
    ]
    return '; '.join(', '.join(str(i) for i in box) for box in updated)
def ID_block_change(ID_value, checkbox_value, ID_txt_value):
    """Add or remove one box index from the serialized selection.

    Args:
        ID_value: Index of the box whose checkbox changed.
        checkbox_value: ``True`` to select the box, ``False`` to deselect it.
        ID_txt_value: ``'; '``-separated selected indices; may be empty.

    Returns:
        The updated selection as a sorted ``'; '``-separated string
        (``''`` when nothing is selected).
    """
    # An empty selection serializes to '', so blank fragments must be skipped;
    # otherwise no box could be re-checked after all were unchecked.
    selected = {int(i) for i in ID_txt_value.split('; ') if i.strip()}
    # Normalise to int so a float ID never serializes as "1.0".
    box_id = int(ID_value)
    if checkbox_value:
        selected.add(box_id)
    else:
        selected.discard(box_id)
    return '; '.join(str(i) for i in sorted(selected))
def down2(src_img_value, txt_mask_value, wordball_mask_value, idx_txt_value, pos_txt_value):
    """Erase the selected text from the image and translate it.

    Args:
        src_img_value: Source image; converted RGB -> BGR before processing.
        txt_mask_value: Text mask image; only channel 0 is used.
        wordball_mask_value: Word-balloon mask, forwarded to ``clear_text``.
        idx_txt_value: ``'; '``-separated indices of the selected boxes.
        pos_txt_value: ``'; '``-separated boxes, each
            ``'min_x, min_y, max_x, max_y'``.

    Returns:
        An 8-tuple matching the outputs wired to the second arrow button:
        the text-free image, the translated texts, then per-text font, size,
        stroke and pad defaults (all ``'; '``-separated strings), a random
        "switch" string and the status string ``'Xong'``.
    """
    src_img_value = cv2.cvtColor(src_img_value, cv2.COLOR_RGB2BGR)

    # Blank fragments are skipped so empty state strings mean "no boxes".
    text_positions = [
        tuple(int(i) for i in pos.split(', '))
        for pos in pos_txt_value.split('; ') if pos.strip()
    ]
    idxes = {int(i) for i in idx_txt_value.split('; ') if i.strip()}

    # astype() returns a new array, so the edits below do not touch the
    # caller's mask.
    text_mask = txt_mask_value[:, :, 0].astype(np.uint8)
    for idx, (min_x, min_y, max_x, max_y) in enumerate(text_positions):
        if idx not in idxes:
            # Overwrite deselected boxes with the non-text value (255) so they
            # are left alone by clear_text.
            text_mask[min_y:max_y+1, min_x:max_x+1] = 255

    non_text_src_img = clear_text(src_img_value, text_mask, wordball_mask_value, text_value=0, non_text_value=255, r=5)

    selected_positions = [pos for idx, pos in enumerate(text_positions) if idx in idxes]
    list_texts = get_list_texts(src_img_value, selected_positions)
    list_translated_texts = translate(list_texts, translator)

    # Default rendering settings, one entry per translated text.
    count = len(list_translated_texts)
    list_fonts = '; '.join(['MTO Astro City.ttf'] * count)
    list_sizes = '; '.join(['20'] * count)
    list_strokes = '; '.join(['3'] * count)
    list_pads = '; '.join(['5'] * count)
    # NOTE(review): a translation containing '; ' would corrupt this list.
    list_translated_texts = '; '.join(list_translated_texts)
    # Fresh random value on every run — presumably so the hidden "switch"
    # textbox always changes; confirm what listens to it.
    switch = str(random.random())
    return non_text_src_img, list_translated_texts, list_fonts, list_sizes, list_strokes, list_pads, switch, 'Xong'
def text_info_change(non_txt_img_value, translated_txt_value, pos_txt_value, idx_txt_value, font_txt_value, size_txt_value, stroke_txt_value, pad_txt_value):
    """Render the translated texts onto a copy of the text-free image.

    Args:
        non_txt_img_value: Image with the original text removed; it is copied
            before being passed on.
        translated_txt_value: ``'; '``-separated translated texts.
        pos_txt_value: ``'; '``-separated boxes, each
            ``'min_x, min_y, max_x, max_y'``.
        idx_txt_value: ``'; '``-separated indices of the selected boxes.
        font_txt_value: ``'; '``-separated font names.
        size_txt_value: ``'; '``-separated integer font sizes.
        stroke_txt_value: ``'; '``-separated integer stroke widths.
        pad_txt_value: ``'; '``-separated integer paddings.

    Returns:
        The result of ``insert_text``, or the plain image copy when no box is
        selected.
    """
    non_txt_img_value = non_txt_img_value.copy()
    # Blank fragments are skipped so an empty selection does not hit int('').
    idxes = {int(i) for i in idx_txt_value.split('; ') if i.strip()}
    positions = [
        tuple(map(int, pos.split(', ')))
        for idx, pos in enumerate(pos_txt_value.split('; '))
        if idx in idxes and pos.strip()
    ]
    if not positions:
        # Nothing selected: there is no text to insert.
        return non_txt_img_value

    translated_text_src_img = insert_text(non_txt_img_value,
                                          translated_txt_value.split('; '),
                                          positions,
                                          font=font_txt_value.split('; '),
                                          font_size=[int(i) for i in size_txt_value.split('; ')],
                                          pad=[int(i) for i in pad_txt_value.split('; ')],
                                          stroke=[int(i) for i in stroke_txt_value.split('; ')])
    return translated_text_src_img
def value2_change(value, ID2_value, txt_value):
    """Replace one entry of a ``'; '``-separated list with ``value``.

    Args:
        value: New value for the entry; stored via ``str(value)``.
        ID2_value: Index of the entry to replace.
        txt_value: The ``'; '``-separated list.

    Returns:
        The updated list. ``txt_value`` is returned unchanged when ``value`` is
        ``None``, and as-is (re-joined) when the index matches no entry.
    """
    if value is None:
        # A cleared input field yields None; writing 'None' into the list
        # would break the int() parsing done by the consumers of this state.
        return txt_value
    if isinstance(value, float) and value.is_integer():
        # Keep integral numbers as "20", not "20.0".
        value = int(value)
    # NOTE(review): a value containing '; ' would corrupt the list.
    return '; '.join(
        str(value) if idx == ID2_value else str(text)
        for idx, text in enumerate(txt_value.split('; '))
    )
# Build the Gradio interface
# NOTE(review): this file was recovered without indentation; the nesting below
# is reconstructed from statement order — confirm it against the real layout.
with gr.Blocks() as demo:
    # Layout
    # Step 1: upload an image and run detection with the first arrow button.
    src_img = gr.Image(type="numpy", label="Upload Image")
    down_bttn_1 = gr.Button("↓", elem_classes="arrow-button")
    with gr.Row():
        txt_mask = gr.Image(type="numpy", label="Upload Image", visible=True)
        wordball_mask = gr.Image(type="numpy", label="Upload Image", visible=True)
    # Receives the status string returned by down1.
    complete = gr.Textbox()
    # Hidden textboxes holding the '; '-separated state strings produced by down1.
    with gr.Row():
        idx_txt = gr.Textbox(label='ID', interactive=False, visible=False)
        pos_txt = gr.Textbox(label='Pos', interactive=False, visible=False)
        area_txt = gr.Textbox(label='Area', interactive=False, visible=False)
        rgb_txt = gr.Textbox(label='rgb', interactive=False, visible=False)
    with gr.Row():
        boxed_txt_img = gr.Image(type="numpy", label="Upload Image")
        with gr.Column() as down_1_column:
            # Builds one editor group (checkbox + X/Y/W/H fields) per detected box.
            # NOTE(review): create_box is never called in the visible code;
            # presumably it was registered as a dynamic render function
            # (e.g. through a decorator lost in extraction) — confirm.
            def create_box(pos_txt_value, rgb_txt_value):
                text_positions = pos_txt_value.split('; ')
                for idx in range(len(text_positions)):
                    text_positions[idx] = (int(i) for i in text_positions[idx].split(', '))
                rgbs = rgb_txt_value.split('; ')
                for idx in range(len(rgbs)):
                    rgbs[idx] = (int(i) for i in rgbs[idx].split(', '))
                elements = []
                for idx, (min_x, min_y, max_x, max_y) in enumerate(text_positions):
                    with gr.Group() as box:
                        r, g, b = rgbs[idx]
                        with gr.Row():
                            # Colour swatch plus a 1-based label for the box.
                            gr.Markdown(
                                f"""
<div style="margin-left: 20px; display: flex; align-items: center;">
<div style="width: 10px; height: 10px; background-color: rgb({r}, {g}, {b}); margin-right: 5px;"></div>
<span style="font-size: 20px;">Textbox {idx+1}</span>
</div>
"""
                            )
                            checkbox = gr.Checkbox(value=True, label='', min_width=50, interactive=True)
                        with gr.Row():
                            X = gr.Number(label="X", value=min_x, interactive=True)
                            Y = gr.Number(label="Y", value=min_y, interactive=True)
                            W = gr.Number(label="W", value=max_x-min_x, interactive=True)
                            H = gr.Number(label="H", value=max_y-min_y, interactive=True)
                            # Hidden field carrying the 0-based box index to the callbacks.
                            ID = gr.Number(label="ID", value=idx, interactive=True, visible=False)
                    elements.append((X, Y, W, H, ID))
                    # Toggling the checkbox updates the selection, then redraws the preview.
                    checkbox.change(
                        fn=ID_block_change,
                        inputs=[ID, checkbox, idx_txt],
                        outputs=idx_txt,
                        show_progress=False
                    ).then(
                        fn=idx_txt_change,
                        inputs=[src_img, idx_txt, pos_txt, rgb_txt],
                        outputs=boxed_txt_img,
                    )
                    # Each of X / Y / W / H rewrites this box in pos_txt, then
                    # redraws the preview; the four handlers are identical.
                    X.change(
                        fn=position_block_change,
                        inputs=[X, Y, W, H, ID, pos_txt],
                        outputs=pos_txt,
                        show_progress=False
                    ).then(
                        fn=idx_txt_change,
                        inputs=[src_img, idx_txt, pos_txt, rgb_txt],
                        outputs=boxed_txt_img,
                        show_progress=False
                    )
                    Y.change(
                        fn=position_block_change,
                        inputs=[X, Y, W, H, ID, pos_txt],
                        outputs=pos_txt,
                        show_progress=False
                    ).then(
                        fn=idx_txt_change,
                        inputs=[src_img, idx_txt, pos_txt, rgb_txt],
                        outputs=boxed_txt_img,
                        show_progress=False
                    )
                    W.change(
                        fn=position_block_change,
                        inputs=[X, Y, W, H, ID, pos_txt],
                        outputs=pos_txt,
                        show_progress=False
                    ).then(
                        fn=idx_txt_change,
                        inputs=[src_img, idx_txt, pos_txt, rgb_txt],
                        outputs=boxed_txt_img,
                        show_progress=False
                    )
                    H.change(
                        fn=position_block_change,
                        inputs=[X, Y, W, H, ID, pos_txt],
                        outputs=pos_txt,
                        show_progress=False
                    ).then(
                        fn=idx_txt_change,
                        inputs=[src_img, idx_txt, pos_txt, rgb_txt],
                        outputs=boxed_txt_img,
                        show_progress=False
                    )
    # Step 2: erase the selected text and translate it with the second arrow button.
    down_bttn_2 = gr.Button("↓", elem_classes="arrow-button")
    non_txt_img = gr.Image(type="numpy", label="Upload Image", visible=False)
    # Receives the status string returned by down2.
    complete2 = gr.Textbox()
    # Hidden textboxes holding the '; '-separated state strings produced by down2.
    with gr.Row():
        translated_txt = gr.Textbox(label='translated', interactive=False, visible=False)
        font_txt = gr.Textbox(label='font', interactive=False, visible=False)
        size_txt = gr.Textbox(label='size', interactive=False, visible=False)
        stroke_txt = gr.Textbox(label='stroke', interactive=False, visible=False)
        pad_txt = gr.Textbox(label='pad', interactive=False, visible=False)
        # Receives the random "switch" string from down2.
        # NOTE(review): nothing visible here listens to switch_txt — confirm its consumer.
        switch_txt = gr.Textbox(label='switch', value='1', interactive=False, visible=False)
    with gr.Row():
        boxed_inserted_non_txt_img = gr.Image(type="numpy", label="Upload Image")
        with gr.Column():
            # Builds one editor group (text, font, size, stroke, pad) per translated text.
            # NOTE(review): create_box2 is never called in the visible code;
            # presumably it was registered as a dynamic render function — confirm.
            def create_box2(translated_txt_value, font_txt_value, size_txt_value, stroke_txt_value, pad_txt_value):
                translated_txt_value = translated_txt_value.split('; ')
                font_txt_value = font_txt_value.split('; ')
                size_txt_value = size_txt_value.split('; ')
                stroke_txt_value = stroke_txt_value.split('; ')
                pad_txt_value = pad_txt_value.split('; ')
                elements = []
                for idx in range(len(font_txt_value)):
                    with gr.Group():
                        gr.Markdown(
                            f"""
<div style="margin-left: 20px; display: flex; align-items: center;">
<div style="width: 10px; height: 10px; background-color: rgb(255, 255, 255); margin-right: 5px;"></div>
<span style="font-size: 20px;">Text box {idx}</span>
</div>
"""
                        )
                        translated_text_box = gr.Textbox(label="Translate", value=translated_txt_value[idx], interactive=True)
                        with gr.Row():
                            # Font choices are the entries of the 'MTO Font' directory.
                            font = gr.Dropdown(choices=os.listdir('MTO Font'), label="Phông chữ", value=font_txt_value[idx], interactive=True, scale=7)
                            size = gr.Number(label="Size", value=int(size_txt_value[idx]), interactive=True, minimum=1)
                            stroke = gr.Number(label="Stroke", value=int(stroke_txt_value[idx]), interactive=True, minimum=0, maximum=5)
                            pad = gr.Number(label="Pad", value=int(pad_txt_value[idx]), interactive=True, minimum=1, maximum=10)
                            # Hidden field carrying the 0-based entry index to the callbacks.
                            ID2 = gr.Number(label="ID", value=int(idx), interactive=True, visible=False)
                    # Every control follows the same pattern: write the new value
                    # into its state string, then re-render the output image.
                    translated_text_box.submit(
                        fn=value2_change,
                        inputs=[translated_text_box, ID2, translated_txt],
                        outputs=translated_txt,
                        show_progress=False
                    ).then(
                        fn=text_info_change,
                        inputs=[non_txt_img, translated_txt, pos_txt, idx_txt, font_txt, size_txt, stroke_txt, pad_txt],
                        outputs=boxed_inserted_non_txt_img,
                    )
                    font.change(
                        fn=value2_change,
                        inputs=[font, ID2, font_txt],
                        outputs=font_txt,
                        show_progress=False
                    ).then(
                        fn=text_info_change,
                        inputs=[non_txt_img, translated_txt, pos_txt, idx_txt, font_txt, size_txt, stroke_txt, pad_txt],
                        outputs=boxed_inserted_non_txt_img,
                    )
                    size.change(
                        fn=value2_change,
                        inputs=[size, ID2, size_txt],
                        outputs=size_txt,
                        show_progress=False
                    ).then(
                        fn=text_info_change,
                        inputs=[non_txt_img, translated_txt, pos_txt, idx_txt, font_txt, size_txt, stroke_txt, pad_txt],
                        outputs=boxed_inserted_non_txt_img,
                    )
                    stroke.change(
                        fn=value2_change,
                        inputs=[stroke, ID2, stroke_txt],
                        outputs=stroke_txt,
                        show_progress=False
                    ).then(
                        fn=text_info_change,
                        inputs=[non_txt_img, translated_txt, pos_txt, idx_txt, font_txt, size_txt, stroke_txt, pad_txt],
                        outputs=boxed_inserted_non_txt_img,
                    )
                    pad.change(
                        fn=value2_change,
                        inputs=[pad, ID2, pad_txt],
                        outputs=pad_txt,
                        show_progress=False
                    ).then(
                        fn=text_info_change,
                        inputs=[non_txt_img, translated_txt, pos_txt, idx_txt, font_txt, size_txt, stroke_txt, pad_txt],
                        outputs=boxed_inserted_non_txt_img,
                    )
    # Css
    demo.css = """
.arrow-button {
font-size: 40px; /* Kích thước font */
}
.group-elem {
height: 70px;
}
"""
    # Event wiring
    # First arrow: run detection and fill the masks plus the hidden state strings.
    down_bttn_1.click(
        fn=down1,
        inputs=src_img,
        outputs=[txt_mask, wordball_mask, idx_txt, pos_txt, area_txt, rgb_txt, complete],
    )
    # Second arrow: erase and translate the text, then render the translated image.
    down_bttn_2.click(
        fn=down2,
        inputs=[src_img, txt_mask, wordball_mask, idx_txt, pos_txt],
        outputs=[non_txt_img, translated_txt, font_txt, size_txt, stroke_txt, pad_txt, switch_txt, complete2],
    ).then(
        fn=text_info_change,
        inputs=[non_txt_img, translated_txt, pos_txt, idx_txt, font_txt, size_txt, stroke_txt, pad_txt],
        outputs=boxed_inserted_non_txt_img,
    )
demo.launch()