csxmli's picture
Update app.py
3a261c5 verified
import os
os.environ['TORCH_CUDA_ARCH_LIST']="7.5;8.6;9.0;9.0a"
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["GRADIO_TEMP_DIR"] = "./gradio_tmp"
import spaces
import os.path as osp
import torch
import cv2
import numpy as np
import time
import gradio as gr
from models.TextEnhancement import MARCONetPlus
from utils.utils_image import imread_uint, uint2tensor4, tensor2uint
from networks.rrdbnet2_arch import RRDBNet as BSRGAN
# Pick the compute device once at import time: CUDA if visible, else CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# BSRGAN background model; stays None until load_bg_model() builds it on demand.
BGModel = None
def load_bg_model():
    """Lazily build the BSRGAN background super-resolution model.

    On the first call this constructs a 2x RRDBNet (BSRGAN) and copies the
    tensors from ``./checkpoints/bsrgan_bg.pth`` into it. It then sets eval
    mode, freezes all parameters, moves the model to ``device`` and stores it
    in the module-level ``BGModel``. Later calls return the cached model.

    Checkpoint tensors are matched to the model's state-dict entries by
    position, not by name. Presumably the checkpoint's key names differ from
    this architecture's (TODO confirm). The two must therefore contain the
    same number of entries.

    Returns:
        The loaded model, also stored in ``BGModel``.

    Raises:
        RuntimeError: if the checkpoint and the model have a different number
            of entries, or (from ``load_state_dict``) on a tensor shape
            mismatch.
    """
    global BGModel
    if BGModel is not None:
        return BGModel

    model = BSRGAN(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=2)
    checkpoint = torch.load('./checkpoints/bsrgan_bg.pth', map_location=device)
    state_dict = model.state_dict()
    if len(checkpoint) != len(state_dict):
        # zip() would silently stop at the shorter side, leaving the
        # remaining layers randomly initialised while strict loading passes.
        raise RuntimeError(
            f"BSRGAN checkpoint has {len(checkpoint)} entries but the model "
            f"expects {len(state_dict)}"
        )
    # Positional remap: the i-th checkpoint tensor goes to the i-th model key.
    for model_key, tensor in zip(list(state_dict), checkpoint.values()):
        state_dict[model_key] = tensor
    model.load_state_dict(state_dict, strict=True)
    model.eval()
    model.requires_grad_(False)

    # Publish the global only once fully initialised. A failed load is then
    # retried on the next call instead of caching a half-built model.
    BGModel = model.to(device)
    return BGModel
# Text restoration model, built eagerly at import time on `device`.
# The four checkpoint paths are passed positionally, in the order
# MARCONetPlus expects. Judging by the file names, they are the W-encoder,
# prior and SR network weights (iteration 860000) plus a YOLO11m character
# detector -- TODO confirm against MARCONetPlus.__init__.
TextModel = MARCONetPlus(
    *(
        './checkpoints/net_w_encoder_860000.pth',
        './checkpoints/net_prior_860000.pth',
        './checkpoints/net_sr_860000.pth',
        './checkpoints/yolo11m_short_character.pt',
    ),
    device=device,
)
def _resize_to(img, width, height):
    """Resize ``img`` to ``(width, height)``, choosing the filter by direction.

    OpenCV recommends INTER_AREA for shrinking, but it behaves like
    nearest-neighbour when enlarging. INTER_CUBIC is used whenever either
    side grows.
    """
    cur_h, cur_w = img.shape[:2]
    if width <= cur_w and height <= cur_h:
        interpolation = cv2.INTER_AREA
    else:
        interpolation = cv2.INTER_CUBIC
    return cv2.resize(img, (width, height), interpolation=interpolation)


def _background_sr(img_L, max_size=1536):
    """Run the 2x BSRGAN background model on a BGR image.

    The image is first resized so both sides are multiples of 8 (and at
    least 8 px). If the forward pass raises a RuntimeError, the CUDA cache is
    cleared and the pass is retried once on a copy downscaled so neither side
    exceeds ``max_size``. ``torch.cuda.OutOfMemoryError`` is a RuntimeError
    subclass.

    Args:
        img_L: BGR image as produced by ``cv2.cvtColor`` in the caller.
        max_size: Longest side allowed for the retry after a failed pass.

    Returns:
        The model output converted back to an image by ``tensor2uint``.
    """
    load_bg_model()
    height, width = img_L.shape[:2]

    def _to_model_input(scale):
        # Multiples of 8 for the network; clamp to >= 8 px so tiny uploads
        # do not produce an empty (0-sized) resize target.
        new_w = max(8, int(width * scale) // 8 * 8)
        new_h = max(8, int(height * scale) // 8 * 8)
        resized = cv2.resize(img_L, (new_w, new_h), interpolation=cv2.INTER_AREA)
        return uint2tensor4(resized).to(device)

    with torch.no_grad():
        model_in = _to_model_input(1.0)
        try:
            out = BGModel(model_in)
        except RuntimeError:
            # Retry below, outside the handler. The traceback (and the
            # activations its frames still reference) is released first, so
            # empty_cache() can actually free that memory.
            out = None
        if out is None:
            del model_in
            torch.cuda.empty_cache()
            scale = min(max_size / width, max_size / height, 1.0)
            out = BGModel(_to_model_input(scale))
    return tensor2uint(out)


@spaces.GPU(duration=120)
def gradio_inference(input_img, aligned=False, bg_sr=False, scale_factor=2):
    """Restore text in an image with MARCONetPlus, optionally with background SR.

    Args:
        input_img: Input image (a PIL image from the Gradio component, or an
            RGB array). ``None`` yields ``None``.
        aligned: Forwarded as ``is_aligned`` to ``TextModel.handle_texts``.
            When true, background SR is skipped and the first entry of
            ``en_texts`` is returned.
        bg_sr: Run BSRGAN on the background (ignored when ``aligned``).
        scale_factor: Integer output upscale factor (>= 1). Float values such
            as ``2.0`` from the slider are accepted and truncated to int.

    Returns:
        An RGB uint8 numpy array. Returns ``None`` when there is no input,
        when the text model returns no ``SQ``, or when aligned mode yields no
        enhanced text.

    Raises:
        ValueError: if ``scale_factor`` is below 1.
    """
    if input_img is None:
        return None

    # cv2.resize needs integer sizes, so a float slider value must be cast.
    scale_factor = int(scale_factor)
    if scale_factor < 1:
        raise ValueError(f"scale_factor must be >= 1, got {scale_factor}")

    # Normalise PIL uploads (grayscale, RGBA, palette...) to 3-channel RGB
    # before converting to OpenCV's BGR layout.
    if hasattr(input_img, "convert"):
        input_img = input_img.convert("RGB")
    img_L = cv2.cvtColor(np.array(input_img), cv2.COLOR_RGB2BGR)
    height_L, width_L = img_L.shape[:2]

    # Background: BSRGAN output for full scenes when requested, else the input.
    img_E = _background_sr(img_L) if (bg_sr and not aligned) else img_L

    # Bring the background to the final output resolution.
    width_S = width_L * scale_factor
    height_S = height_L * scale_factor
    img_E = _resize_to(img_E, width_S, height_S)

    # Text restoration; only SQ and en_texts are used here.
    SQ, _ori_texts, en_texts, _debug_texts, _pred_texts = TextModel.handle_texts(
        img=img_L, bg=img_E, sf=scale_factor, is_aligned=aligned
    )
    if SQ is None:
        return None

    if aligned:
        if not en_texts:
            return None
        restored = en_texts[0]
    else:
        restored = _resize_to(SQ.astype(np.float32), width_S, height_S)

    # Saturate before the uint8 cast (out-of-range values would otherwise
    # wrap around), then flip BGR -> RGB for the numpy output component.
    return np.clip(restored, 0, 255)[:, :, ::-1].astype(np.uint8)
# Gradio UI: an input image, a restored-output image and the inference options.
with gr.Blocks() as demo:
    gr.Markdown("# MARCONetPlus Text Image Restoration")

    # Input and output panes side by side.
    with gr.Row():
        input_img = gr.Image(type="pil", label="Input Image")
        output_img = gr.Image(type="numpy", label="Restored Output")

    # Options, listed in gradio_inference's positional order.
    with gr.Row():
        aligned = gr.Checkbox(value=False, label="Aligned (cropped text regions)")
        bg_sr = gr.Checkbox(value=False, label="Background SR (BSRGAN)")
        scale_factor = gr.Slider(
            minimum=1, maximum=4, step=1, value=2, label="Scale Factor"
        )

    run_btn = gr.Button("Run Inference")
    run_btn.click(
        gradio_inference,
        inputs=[input_img, aligned, bg_sr, scale_factor],
        outputs=[output_img],
    )


if __name__ == "__main__":
    demo.launch(share=True)