"""Gradio web demo for WuZongWei6/Pixelization.

Downloads the pixelization model checkpoints from the Hugging Face Hub,
loads them into ``pixelization.Model`` and serves an image-in/image-out UI.
"""
import argparse
import functools
import os

import gradio as gr
from pixelization import Model
import torch
import huggingface_hub

# Hugging Face access token for the private model repository. It is read from
# the environment rather than hard-coded: a token committed to source is leaked
# to everyone who can read the repository. (The token previously embedded here
# must be treated as compromised and revoked.) When HF_TOKEN is unset this is
# None, and huggingface_hub falls back to its default credential lookup
# (e.g. a token cached by `huggingface-cli login`).
TOKEN = os.environ.get("HF_TOKEN")

# Hub repository holding the model checkpoints.
MODEL_REPO = "NoCrypt/pixelization_models"

# (environment variable, checkpoint filename) pairs. Each checkpoint's local
# path is exported under its variable before Model.load() is called.
# NOTE(review): presumably pixelization.Model reads these variables; confirm.
_MODEL_FILES = (
    ("PIX_MODEL", "pixelart_vgg19.pth"),
    ("NET_MODEL", "160_net_G_A.pth"),
    ("ALIAS_MODEL", "alias_net.pth"),
)


def parse_args() -> argparse.Namespace:
    """Parse the demo's command-line options.

    Returns:
        Namespace with ``theme``, ``live``, ``share``, ``port``,
        ``enable_queue`` (True unless ``--disable-queue``) and
        ``allow_flagging``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--theme', type=str, default='default')
    parser.add_argument('--live', action='store_true')
    parser.add_argument('--share', action='store_true')
    parser.add_argument('--port', type=int)
    parser.add_argument('--disable-queue',
                        dest='enable_queue',
                        action='store_false')
    parser.add_argument('--allow-flagging', type=str, default='never')
    return parser.parse_args()


def _download_models() -> None:
    """Download every checkpoint in _MODEL_FILES and export its local path."""
    for env_var, filename in _MODEL_FILES:
        os.environ[env_var] = huggingface_hub.hf_hub_download(
            MODEL_REPO, filename, token=TOKEN)


def main() -> None:
    """Download the models, load the pixelization model and launch the UI."""
    args = parse_args()

    _download_models()
    # For local testing, skip the download above and point the variables at
    # local checkpoint files instead:
    # os.environ['PIX_MODEL'] = "pixelart_vgg19.pth"
    # os.environ['NET_MODEL'] = "160_net_G_A.pth"
    # os.environ['ALIAS_MODEL'] = "alias_net.pth"

    use_cpu = True
    m = Model(device="cpu" if use_cpu else "cuda")
    m.load()

    # To use GPU: Change use_cpu to false, and checkout my comment on
    # networks.py at line 107 & 108
    # + Use torch with cuda support (Change in requirements.txt)

    gr.Interface(
        m.pixelize_modified,
        [
            gr.components.Image(type='pil', label='Input'),
            gr.components.Slider(minimum=1, maximum=16, value=4, step=1,
                                 label='Pixel Size'),
            gr.components.Checkbox(True, label="Upscale after"),
        ],
        gr.components.Image(type='pil', label='Output'),
        title="Pixelization",
        description='''
Demo for [WuZongWei6/Pixelization](https://github.com/WuZongWei6/Pixelization)

Models that are used is private to comply with License.
''',
        theme=args.theme,
        allow_flagging=args.allow_flagging,
        live=args.live,
    ).launch(
        enable_queue=args.enable_queue,
        server_port=args.port,
        share=args.share,
    )


if __name__ == '__main__':
    main()