# Spaces: Running on Zero
import os

# Runtime knobs, applied before torch / gradio are imported below so those
# libraries see them when they initialise.
os.environ.update({
    "TORCH_CUDA_ARCH_LIST": "7.5;8.6;9.0;9.0a",  # archs targeted by CUDA extension builds
    "OMP_NUM_THREADS": "1",                      # single-threaded OpenMP
    "GRADIO_TEMP_DIR": "./gradio_tmp",           # Gradio's temporary file directory
})

# `spaces` is kept ahead of `torch`: the ZeroGPU helper must be imported
# before CUDA is touched. The order of the remaining imports is unchanged.
import spaces
import os.path as osp
import torch
import cv2
import numpy as np
import time
import gradio as gr

from models.TextEnhancement import MARCONetPlus
from utils.utils_image import imread_uint, uint2tensor4, tensor2uint
from networks.rrdbnet2_arch import RRDBNet as BSRGAN

# Compute device used by both models in this file: CUDA when visible, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Background restoration model (lazy loading): stays None until the first
# request that enables background SR calls load_bg_model().
BGModel = None


def load_bg_model():
    """Load the BSRGAN background super-resolution model on first use.

    Builds a 2x RRDBNet (imported as ``BSRGAN``) and copies the weights from
    ``./checkpoints/bsrgan_bg.pth`` into it. It then switches the model to eval
    mode, freezes it, moves it to ``device`` and publishes it in the
    module-level ``BGModel``. Once ``BGModel`` is set, further calls do nothing.

    Raises:
        ValueError: if the checkpoint holds fewer tensors than the model has
            state-dict entries. Weights are matched by position, so the
            remaining parameters would otherwise silently keep their random
            initialisation.
    """
    global BGModel
    if BGModel is not None:
        return

    # Build into a local: if anything below fails (missing file, bad
    # checkpoint), the global stays None and the next call retries, instead
    # of finding a half-initialised random-weight model.
    model = BSRGAN(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=2)
    model_old = torch.load('./checkpoints/bsrgan_bg.pth', map_location=device)
    state_dict = model.state_dict()

    # Weights are copied by position, not by key name (presumably because the
    # checkpoint's key names differ from this architecture's). zip() stops at
    # the shorter side, so a short checkpoint must be rejected explicitly.
    if len(model_old) < len(state_dict):
        raise ValueError(
            f"BSRGAN checkpoint has {len(model_old)} entries but the model "
            f"expects {len(state_dict)}; refusing to leave weights uninitialised"
        )
    for (_, param), key in zip(model_old.items(), state_dict):
        state_dict[key] = param
    model.load_state_dict(state_dict, strict=True)

    model.eval()
    model.requires_grad_(False)  # inference only
    BGModel = model.to(device)
# Text restoration model, built once when the module is imported.
# The four checkpoints are passed positionally in the order MARCONetPlus
# expects them. Going by the file names they are presumably the W-encoder,
# prior, SR network and a YOLO character detector; confirm against
# models/TextEnhancement.
TextModel = MARCONetPlus(
    './checkpoints/net_w_encoder_860000.pth',
    './checkpoints/net_prior_860000.pth',
    './checkpoints/net_sr_860000.pth',
    './checkpoints/yolo11m_short_character.pt',
    device=device,
)
def _bg_super_resolve(img_L, max_size=1536):
    """Run the BSRGAN background model on a BGR uint8 image.

    The image is first resized so both sides are multiples of 8 (at least
    8 px). If the forward pass raises ``RuntimeError`` (typically out of
    memory), cached allocator blocks are released. The pass is then retried
    once on a copy shrunk so that neither side exceeds ``max_size``.

    Args:
        img_L: OpenCV-style BGR image.
        max_size: side-length cap used for the retry.

    Returns:
        The model output converted back with ``tensor2uint``.
    """
    load_bg_model()
    height_L, width_L = img_L.shape[:2]

    def to_tensor(width, height):
        # Floor each side to a multiple of 8, but never below 8: a zero-sized
        # target would make cv2.resize fail on very small inputs.
        size = (max(8, width // 8 * 8), max(8, height // 8 * 8))
        resized = cv2.resize(img_L, size, interpolation=cv2.INTER_AREA)
        return uint2tensor4(resized).to(device)

    with torch.no_grad():
        img_T = to_tensor(width_L, height_L)
        try:
            img_E = BGModel(img_T)
        except RuntimeError:
            img_E = None
        if img_E is None:
            # Recover outside the except clause: while the exception is live,
            # its traceback pins the failed pass's tensors, so memory could
            # not be freed inside the handler.
            del img_T
            torch.cuda.empty_cache()
            scale = min(max_size / width_L, max_size / height_L, 1.0)
            img_E = BGModel(to_tensor(int(width_L * scale), int(height_L * scale)))
    return tensor2uint(img_E)


@spaces.GPU
def gradio_inference(input_img, aligned=False, bg_sr=False, scale_factor=2):
    """Run MARCONetPlus inference with optional background SR.

    Args:
        input_img: PIL image from the Gradio input component, or ``None``.
        aligned: forwarded to ``TextModel.handle_texts`` as ``is_aligned``.
            When true, background SR is skipped and the first entry of
            ``en_texts`` is returned instead of the full composite.
        bg_sr: upscale the background with BSRGAN before text restoration.
            Only used when ``aligned`` is false.
        scale_factor: output upscaling factor. Coerced to ``int`` because
            OpenCV needs integer sizes and a Gradio slider may deliver a float.

    Returns:
        An RGB ``uint8`` numpy array, or ``None`` when there is no input or
        the text model produced no result.
    """
    if input_img is None:
        return None
    scale_factor = int(scale_factor)

    # PIL delivers RGB; the OpenCV-based pipeline below works in BGR.
    img_L = cv2.cvtColor(np.array(input_img), cv2.COLOR_RGB2BGR)
    height_L, width_L = img_L.shape[:2]
    width_S, height_S = width_L * scale_factor, height_L * scale_factor

    # Background: BSRGAN output (or the input itself) resized to the target size.
    img_E = _bg_super_resolve(img_L) if bg_sr and not aligned else img_L
    img_E = cv2.resize(img_E, (width_S, height_S), interpolation=cv2.INTER_AREA)

    # Text restoration; only the composite (SQ) and en_texts are used here.
    SQ, _, en_texts, _, _ = TextModel.handle_texts(
        img=img_L, bg=img_E, sf=scale_factor, is_aligned=aligned
    )
    if SQ is None:
        return None

    if aligned:
        if en_texts is None or len(en_texts) == 0:
            return None  # nothing restored; avoid IndexError on en_texts[0]
        out = en_texts[0]
    else:
        out = cv2.resize(
            SQ.astype(np.float32), (width_S, height_S), interpolation=cv2.INTER_AREA
        )
    # Clip before the uint8 cast so out-of-range values saturate instead of
    # wrapping, and flip BGR -> RGB for display.
    return np.clip(out, 0, 255)[:, :, ::-1].astype(np.uint8)
# ---------------------------------------------------------------------------
# Gradio UI: image panes, option controls and a single run button, all wired
# to gradio_inference.
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# MARCONetPlus Text Image Restoration")

    # Input is handed over as a PIL image; the result comes back as a numpy array.
    with gr.Row():
        input_img = gr.Image(type="pil", label="Input Image")
        output_img = gr.Image(type="numpy", label="Restored Output")

    # Options, passed to gradio_inference in the order listed in `inputs`.
    with gr.Row():
        aligned = gr.Checkbox(label="Aligned (cropped text regions)", value=False)
        bg_sr = gr.Checkbox(label="Background SR (BSRGAN)", value=False)
        scale_factor = gr.Slider(1, 4, value=2, step=1, label="Scale Factor")

    run_btn = gr.Button("Run Inference")
    # Event listeners must be registered inside the Blocks context.
    run_btn.click(
        fn=gradio_inference,
        inputs=[input_img, aligned, bg_sr, scale_factor],
        outputs=[output_img],
    )


if __name__ == "__main__":
    # share=True additionally asks Gradio for a public share link when run locally.
    demo.launch(share=True)