pixelization / app.py
NoCrypt's picture
init
2c9c37b
raw
history blame
No virus
2.54 kB
import gradio as gr
import functools
from pixelization import Model
import torch
import argparse
import huggingface_hub
import os
# Hugging Face access token used to download the private model weights.
# It is read from the environment (e.g. a Space secret named HF_TOKEN) instead
# of being hard-coded: a token committed to source is readable by anyone with
# access to the file and has to be treated as leaked (revoke it).
# When HF_TOKEN is unset this is None, and hf_hub_download falls back to its
# default credential lookup.
TOKEN = os.environ.get("HF_TOKEN")
def parse_args(argv=None) -> argparse.Namespace:
    """Parse the demo's command-line options.

    Args:
        argv: Argument list to parse, without the program name. The default
            ``None`` makes argparse read ``sys.argv[1:]`` exactly as before;
            passing a list lets callers and tests parse options explicitly.

    Returns:
        Namespace with ``theme``, ``live``, ``share``, ``port``,
        ``enable_queue`` and ``allow_flagging``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--theme', type=str, default='default')
    parser.add_argument('--live', action='store_true')
    parser.add_argument('--share', action='store_true')
    # No default: None is passed through to launch(server_port=...).
    parser.add_argument('--port', type=int)
    # The flag *disables* queuing, so the stored value defaults to True.
    parser.add_argument('--disable-queue',
                        dest='enable_queue',
                        action='store_false')
    parser.add_argument('--allow-flagging', type=str, default='never')
    return parser.parse_args(argv)
# Hub repository holding the private pixelization weights, and the environment
# variable under which each weight file's local path is published.
_MODEL_REPO = "NoCrypt/pixelization_models"
_MODEL_FILES = {
    "PIX_MODEL": "pixelart_vgg19.pth",
    "NET_MODEL": "160_net_G_A.pth",
    "ALIAS_MODEL": "alias_net.pth",
}


def _resolve_model_paths():
    """Ensure every model environment variable points at a local weight file.

    A variable that is already set and non-empty is left untouched. This
    allows local testing, e.g. ``PIX_MODEL=pixelart_vgg19.pth``. Otherwise
    the file is downloaded from ``_MODEL_REPO`` and the variable is set to
    the downloaded path.
    """
    for env_var, filename in _MODEL_FILES.items():
        if os.environ.get(env_var):
            continue
        os.environ[env_var] = huggingface_hub.hf_hub_download(
            _MODEL_REPO, filename, token=TOKEN
        )


def main():
    """Fetch the model weights, load the model, and launch the Gradio demo."""
    args = parse_args()

    # The paths are handed over via environment variables; presumably
    # Model.load() reads PIX_MODEL / NET_MODEL / ALIAS_MODEL (not visible
    # here), so this must run before m.load().
    _resolve_model_paths()

    # To use GPU: set use_cpu to False, see the comment in networks.py at
    # lines 107 & 108, and install torch with CUDA support (requirements.txt).
    use_cpu = True
    m = Model(device="cpu" if use_cpu else "cuda")
    m.load()

    gr.Interface(
        m.pixelize_modified,
        [
            gr.components.Image(type='pil', label='Input'),
            gr.components.Slider(minimum=1, maximum=16, value=4, step=1,
                                 label='Pixel Size'),
            gr.components.Checkbox(True, label="Upscale after"),
        ],
        gr.components.Image(type='pil', label='Output'),
        title="Pixelization",
        description='''
Demo for [WuZongWei6/Pixelization](https://github.com/WuZongWei6/Pixelization)
Models that are used is private to comply with License.
''',
        theme=args.theme,
        allow_flagging=args.allow_flagging,
        live=args.live,
    ).launch(
        # NOTE(review): newer Gradio releases deprecate/remove
        # launch(enable_queue=...) in favour of .queue(); kept for the Gradio
        # version this app pins. Confirm before upgrading Gradio.
        enable_queue=args.enable_queue,
        server_port=args.port,
        share=args.share,
    )
# Launch the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()