"""Gradio demo: run one uploaded image through the pretrained test-time model.

The script builds a small two-column web UI (input image on the left,
prediction on the right). Each click on the submit button builds the test
options, instantiates the model, runs a single forward pass and shows the
tensor stored under the ``'fake'`` key of the model's visuals as an image.
"""
import sys

import gradio as gr
import matplotlib.pyplot as plt
import torch
import torchvision.transforms as transforms
from PIL import Image

from options.test_options import TestOptions
from data import create_dataset
from models import create_model

# Reference command line for the underlying test script:
# python test.py --model test --name selfie2anime --dataroot selfie2anime/testB --num_test 100 --model_suffix '_B' --no_dropout

title = "MASFNet: Multi-scale Adaptive Sampling Fusion Network for Object Detection in Adverse Weather"
description = ""
article = ""


def reset_interface():
    """Return the update that clears the prediction image on page load.

    The handler is wired to exactly one output component, so it must return
    exactly one value. Returning a tuple for a single output makes Gradio
    treat the whole tuple as the component's value, which an image component
    cannot render.
    """
    return gr.update(value=None)


def resize_image(img):
    """Return ``img`` resized to 256x256 using bicubic interpolation."""
    return img.resize((256, 256), Image.BICUBIC)


def check_resolution(img):
    """Return True when ``img`` is exactly 256x256 or 64x64 pixels."""
    width, height = img.size
    return (width == 256 and height == 256) or (width == 64 and height == 64)


def _tensor_to_pil(tensor):
    """Convert a model output tensor into an 8-bit PIL image.

    The tensor is squeezed on dim 0 and transposed from channel-first to
    channel-last. Values are mapped with ``x * 0.5 + 0.5`` (i.e. assumed to be
    normalised to [-1, 1]) and scaled to 0..255.
    """
    array = tensor.squeeze(0).cpu().detach().numpy().transpose(1, 2, 0)
    array = (array * 0.5 + 0.5) * 255
    # Clip before the integer cast: values slightly outside [-1, 1] would
    # otherwise wrap around modulo 256 and show up as speckle artefacts.
    array = array.clip(0, 255).astype('uint8')
    return Image.fromarray(array)


def inference(img):
    """Run the model on one PIL image and return the predicted PIL image.

    Args:
        img: PIL image supplied by the Gradio input component, or None when
            nothing was uploaded.

    Returns:
        The predicted PIL image, or None when no image was supplied, the
        model produced no ``'fake'`` visual, the visual is not a tensor, or
        any exception was raised (the error is printed, not propagated, so
        the UI stays responsive).
    """
    try:
        if img is None:
            print("No image received!")
            return None

        # Only 256x256 and 64x64 inputs are passed through untouched.
        if not check_resolution(img):
            img = resize_image(img)

        # The option parser reads its arguments from sys.argv.
        # NOTE(review): if TestOptions relies on argparse defaults, element 0
        # ('--model') is consumed as the program name rather than parsed as a
        # flag -- confirm this is intended.
        sys.argv = ['--model', '--dataroot', './data/', '--num_test', '1', '--no_dropout']

        # Load the options and pin them for single-image inference.
        opt = TestOptions().parse()
        opt.num_threads = 0
        opt.batch_size = 1
        opt.serial_batches = True
        opt.no_flip = True
        opt.display_id = -1
        opt.name = ''
        opt.model_suffix = '_B'
        opt.num_test = 1
        opt.no_dropout = True

        # The dataset object itself is never used below; the call is kept so
        # any side effects it has are preserved.
        # NOTE(review): confirm whether it can be dropped.
        create_dataset(opt)

        # NOTE(review): the model is rebuilt on every request.
        model = create_model(opt)
        model.setup(opt)
        if opt.eval:
            model.eval()

        # Convert the PIL image to a batched tensor and move it to the GPU
        # when one is available.
        # NOTE(review): ToTensor yields values in [0, 1] while the output is
        # de-normalised as if in [-1, 1] -- confirm the expected input range.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        img_tensor = transforms.ToTensor()(img.convert('RGB')).unsqueeze(0)
        img_tensor = img_tensor.to(device)

        # Feed the single image to the model and run the forward pass.
        data = {'A': img_tensor, 'A_paths': './data/'}
        model.set_input(data)
        model.test()

        # Fetch the generated image from the model's visuals.
        img_out = model.get_current_visuals()
        output_img_tensor = img_out.get('fake')
        print(f'type of output_img_tensor: {type(img_out)}')

        if output_img_tensor is None:
            print("No output from model!")
            return None

        if not isinstance(output_img_tensor, torch.Tensor):
            print(f"意外的输出类型: {type(output_img_tensor)}")
            return None

        output_img = _tensor_to_pil(output_img_tensor)
        print(f'type if output_img_tensor: {type(output_img_tensor)}')
        return output_img
    except Exception as e:
        # UI boundary: report the failure and show an empty result instead of
        # crashing the request.
        print(f"Error during inference: {e}")
        return None


example_images = [
    "img/1.png"
]

with gr.Blocks() as demo:
    gr.Markdown(f"### {title}")
    gr.Markdown(description)

    with gr.Row():
        with gr.Column():
            img_input = gr.Image(type="pil", label="Upload an Image")
            submit_btn = gr.Button("Submit...")
        with gr.Column():
            output = gr.Image(type="pil", label="Prediction Result")

    submit_btn.click(fn=inference, inputs=img_input, outputs=output)

    # Clear any stale prediction whenever the page is (re)loaded.
    demo.load(reset_interface, None, output)

    gr.Examples(
        examples=example_images,
        inputs=img_input,
    )

demo.launch()