import gradio as gr
from argparse import ArgumentParser


def parse_args(argv=None):
    """Parse command-line options for the Waifu Scorer Gradio demo.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behavior.

    Returns:
        An ``argparse.Namespace`` with the attributes ``model_path``,
        ``model_type``, ``fix_model_path``, ``device`` and ``share``.
    """
    parser = ArgumentParser()
    parser.add_argument(
        '--model_path',
        type=str,
        default='./model/v3.pth',
        help='Path or url to the model file',
    )
    parser.add_argument(
        '--model_type',
        type=str,
        default='mlp',
        help='Type of the model',
    )
    parser.add_argument(
        '--fix_model_path',
        action='store_true',
        help='Fix the model path',
    )
    parser.add_argument(
        '--device',
        type=str,
        default='cuda',
        help='Device to use',
    )
    parser.add_argument(
        '--share',
        action='store_true',
        help='Share the demo',
    )
    return parser.parse_args(argv)


def ui(args):
    """Build the Gradio interface for scoring images.

    The layout has an image input on the left. On the right are a model-path
    textbox and a numeric score output. Changing the image re-scores it.
    Submitting the textbox reloads the model (the textbox is read-only when
    ``args.fix_model_path`` is set).

    Args:
        args: Namespace as returned by :func:`parse_args`.

    Returns:
        The constructed ``gr.Blocks`` demo (not yet launched).
    """
    # Deferred import of the project package. NOTE(review): presumably kept
    # local so that importing this module / parsing args stays cheap; confirm.
    from waifu_scorer.predict import WaifuScorer, load_model

    scorer = WaifuScorer(
        model_path=args.model_path,
        model_type=args.model_type,
        device=args.device,
    )

    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                image = gr.Image(
                    label='Image',
                    type='pil',
                    height=512,
                    sources=['upload', 'clipboard'],
                )
            with gr.Column():
                with gr.Row():
                    model_path = gr.Textbox(
                        label='Model Path',
                        value=args.model_path,
                        placeholder='Path or URL to the model file',
                        interactive=not args.fix_model_path,
                    )
                with gr.Row():
                    score = gr.Number(
                        label='Score',
                    )

        def change_model(path):
            """Load the model at ``path`` and install it on the shared scorer.

            Only an attribute of ``scorer`` is mutated, so no ``nonlocal`` is
            needed. The new model is assigned only after a successful load, so a
            failed load leaves the previous model in use.
            """
            try:
                new_model = load_model(path, model_type=args.model_type, device=args.device)
            except Exception as err:
                # UI-callback boundary: re-raise as gr.Error so the reason is
                # shown to the user instead of a generic error toast.
                raise gr.Error(f"Failed to load model `{path}`: {err}") from err
            scorer.mlp = new_model
            print(f"Model changed to `{path}`")
            # No-op update: the textbox keeps whatever the user typed.
            return gr.update()

        def score_image(img):
            """Return the score for ``img``, or ``None`` when the image is cleared."""
            if img is None:
                return None
            # The image is duplicated into a batch of two and only the first
            # score is kept. NOTE(review): presumably a workaround for the model
            # misbehaving on batch size 1 — confirm before simplifying.
            return scorer.predict([img] * 2)[0]

        model_path.submit(
            fn=change_model,
            inputs=model_path,
            outputs=model_path,
        )
        image.change(
            fn=score_image,
            inputs=image,
            outputs=score,
        )

    return demo


def launch(args):
    """Build the demo from ``args`` and start the Gradio server."""
    demo = ui(args)
    demo.launch(share=args.share)


if __name__ == '__main__':
    launch(parse_args())