"""Gradio demo for MARCONetPlus text image restoration.

Optionally super-resolves the background with BSRGAN before the text
restoration model is applied.
"""
import os

# Environment configuration is applied before the imports below, exactly as
# in the original statement order (presumably so torch / gradio see these
# values at import time -- keep this ordering).
os.environ['TORCH_CUDA_ARCH_LIST'] = "7.5;8.6;9.0;9.0a"
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["GRADIO_TEMP_DIR"] = "./gradio_tmp"

import spaces
import os.path as osp
import torch
import cv2
import numpy as np
import time
import gradio as gr

from models.TextEnhancement import MARCONetPlus
from utils.utils_image import imread_uint, uint2tensor4, tensor2uint
from networks.rrdbnet2_arch import RRDBNet as BSRGAN

# Initialize device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Background restoration model (lazy loading)
BGModel = None


def load_bg_model():
    """Load the BSRGAN model for background super-resolution (once).

    The module-level ``BGModel`` is only assigned after the checkpoint has
    been read and copied into the network, so a failed load leaves it as
    ``None`` and the next call retries instead of silently reusing a
    randomly initialised network.
    """
    global BGModel
    if BGModel is not None:
        return

    model = BSRGAN(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=2)
    model_old = torch.load('./checkpoints/bsrgan_bg.pth', map_location=device)

    # The checkpoint's keys are ignored: tensors are copied into the
    # network's state dict purely by position.
    # NOTE(review): zip() stops at the shorter sequence, so a checkpoint with
    # fewer entries than the network leaves the trailing weights untouched,
    # and strict=True cannot detect that because the keys come from the
    # network itself.
    state_dict = model.state_dict()
    for target_key, param in zip(list(state_dict), model_old.values()):
        state_dict[target_key] = param
    model.load_state_dict(state_dict, strict=True)

    model.eval()
    for param in model.parameters():
        param.requires_grad = False

    BGModel = model.to(device)


# Text restoration model
TextModel = MARCONetPlus(
    './checkpoints/net_w_encoder_860000.pth',
    './checkpoints/net_prior_860000.pth',
    './checkpoints/net_sr_860000.pth',
    './checkpoints/yolo11m_short_character.pt',
    device=device
)


def _floor_to_multiple_of_8(value):
    """Round ``value`` down to a multiple of 8, but never below 8.

    The lower bound keeps ``cv2.resize`` from receiving a zero-sized target
    for images narrower or shorter than 8 pixels.
    """
    return max(8, int(value) // 8 * 8)


def _to_bg_input(img_bgr, width, height):
    """Resize ``img_bgr`` to an 8-aligned size and move it to ``device``."""
    resized = cv2.resize(
        img_bgr,
        (_floor_to_multiple_of_8(width), _floor_to_multiple_of_8(height)),
        interpolation=cv2.INTER_AREA,
    )
    return uint2tensor4(resized).to(device)


def _restore_background(img_L):
    """Run BSRGAN on ``img_L`` and return the result via ``tensor2uint``.

    If the forward pass raises ``RuntimeError`` (which includes CUDA
    out-of-memory errors), the image is shrunk so that its longest side is
    at most 1536 pixels and the forward pass is retried once; an error in
    the retry propagates to the caller.
    """
    load_bg_model()
    height_L, width_L = img_L.shape[:2]

    bg_input = _to_bg_input(img_L, width_L, height_L)
    with torch.no_grad():
        needs_retry = False
        try:
            img_E = BGModel(bg_input)
        except RuntimeError:
            needs_retry = True

        # The retry runs outside the except clause on purpose: while the
        # handler is active the exception keeps the failed forward pass's
        # frames (and their tensors) referenced, so the cache could not
        # actually be released.
        if needs_retry:
            del bg_input
            torch.cuda.empty_cache()
            max_size = 1536
            scale = min(max_size / width_L, max_size / height_L, 1.0)
            bg_input = _to_bg_input(
                img_L, int(width_L * scale), int(height_L * scale)
            )
            img_E = BGModel(bg_input)

    return tensor2uint(img_E)


@spaces.GPU(duration=120)
def gradio_inference(input_img, aligned=False, bg_sr=False, scale_factor=2):
    """Run MARCONetPlus inference with optional background SR.

    Args:
        input_img: PIL image from the Gradio input, or ``None``.
        aligned: Forwarded to ``TextModel.handle_texts`` as ``is_aligned``.
            When true, background SR is skipped and the first entry of the
            enhanced-text list is returned.
        bg_sr: When true (and ``aligned`` is false), the background is first
            processed by BSRGAN.
        scale_factor: Output scale relative to the input size; also forwarded
            to ``TextModel.handle_texts`` as ``sf``.

    Returns:
        A ``uint8`` array with the channel order reversed relative to the
        model output, or ``None`` when there is no input image or the text
        model returns no result.
    """
    if input_img is None:
        return None

    # Convert input image (PIL) to OpenCV format
    img_L = cv2.cvtColor(np.array(input_img), cv2.COLOR_RGB2BGR)
    height_L, width_L = img_L.shape[:2]

    # Background super-resolution
    if not aligned and bg_sr:
        img_E = _restore_background(img_L)
    else:
        img_E = img_L

    # Resize background. cv2.resize only accepts integer sizes, so the
    # product is cast explicitly in case the slider delivers a float.
    width_S = int(round(width_L * scale_factor))
    height_S = int(round(height_L * scale_factor))
    img_E = cv2.resize(img_E, (width_S, height_S), interpolation=cv2.INTER_AREA)

    # Text restoration
    SQ, ori_texts, en_texts, debug_texts, pred_texts = TextModel.handle_texts(
        img=img_L, bg=img_E, sf=scale_factor, is_aligned=aligned
    )

    if SQ is None:
        return None

    if aligned:
        return en_texts[0][:, :, ::-1].astype(np.uint8)

    SQ = cv2.resize(
        SQ.astype(np.float32), (width_S, height_S), interpolation=cv2.INTER_AREA
    )
    return SQ[:, :, ::-1].astype(np.uint8)


# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# MARCONetPlus Text Image Restoration")

    with gr.Row():
        input_img = gr.Image(type="pil", label="Input Image")
        output_img = gr.Image(type="numpy", label="Restored Output")

    with gr.Row():
        aligned = gr.Checkbox(label="Aligned (cropped text regions)", value=False)
        bg_sr = gr.Checkbox(label="Background SR (BSRGAN)", value=False)
        scale_factor = gr.Slider(1, 4, value=2, step=1, label="Scale Factor")

    run_btn = gr.Button("Run Inference")
    run_btn.click(
        fn=gradio_inference,
        inputs=[input_img, aligned, bg_sr, scale_factor],
        outputs=[output_img]
    )

if __name__ == "__main__":
    demo.launch(share=True)