cartoonize / app.py
darkheartdsanju's picture
update app.py
e2188c4 verified
raw
history blame contribute delete
No virus
3.6 kB
import numpy as np
import gradio as gr
import cv2
import os
import argparse
from inference import Predictor
import io
#from black import to_black
# os.system("wget https://huggingface.co/YANGYYYY/cartoonize/tree/main/GeneratorV2_train_photo_Hayao_init.pt")
# if os.path.exists("GeneratorV2_train_photo_Hayao_init.pt"):
# print("下载成功!")
# else:
# print("下载失败!")
def parse_args(argv=None):
    """Parse the command-line options used for image style transfer.

    This is called from inside the Gradio click handler (``transfer``), so it
    uses ``parse_known_args``. Arguments this parser does not define (for
    example ones injected by a hosting launcher or a notebook kernel) are
    ignored instead of triggering ``SystemExit``, which would abort the request.

    Args:
        argv: Argument list to parse. ``None`` (the default) means
            ``sys.argv[1:]``, matching ``argparse``'s own default.

    Returns:
        argparse.Namespace with ``weight`` (checkpoint path) and ``device``
        (e.g. ``'cpu'`` or ``'cuda'``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--weight', type=str, default='GeneratorV2_train_photo_Hayao_init.pt')
    parser.add_argument('--device', type=str, default='cpu', help='Device, cuda or cpu')
    # Unrecognised arguments are returned separately and deliberately ignored.
    args, _unknown = parser.parse_known_args(argv)
    return args
def parse_args_video(argv=None):
    """Parse the command-line options for video style transfer.

    Like ``parse_args``, this uses ``parse_known_args``, so arguments it does
    not define are ignored rather than raising ``SystemExit``.

    Args:
        argv: Argument list to parse. ``None`` (the default) means
            ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``weight``, ``src`` (input video path),
        ``out`` (output video path), ``batch_size``, ``start`` and ``end``
        (both in seconds).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--weight', type=str, default='GeneratorV2_train_photo_Hayao_init.pt')
    parser.add_argument('--src', type=str, default='dataset/video/花.mp4', help='Path to input video')
    parser.add_argument('--out', type=str, default='dataset/video_Hayao/hua_hayao.mp4', help='Path to save new video')
    parser.add_argument('--batch-size', type=int, default=4)
    parser.add_argument('--start', type=int, default=0, help='Start time of video (second)')
    # NOTE(review): the help text says 0 means "not set", but the default is
    # 10 seconds -- confirm which behaviour is intended.
    parser.add_argument('--end', type=int, default=10, help='End time of video (second), 0 if not set')
    # Unrecognised arguments are returned separately and deliberately ignored.
    args, _unknown = parser.parse_known_args(argv)
    return args
# Checkpoint used for each style offered by the UI dropdowns. ``None`` means
# "use the --weight command-line value", whose default is the Hayao checkpoint.
# The paths are relative; presumably Predictor resolves them against the
# current working directory -- confirm.
_STYLE_WEIGHTS = {
    "Hayao": None,
    "Shinkai": 'GeneratorV2_train_photo_Shinkai_init.pt',
    "Kon Satoshi": 'GeneratorV2_train_photo_Paprika_init.pt',
}

# Loaded predictors keyed by (weight, device), so each checkpoint is loaded
# once per process instead of on every button click.
# NOTE(review): assumes a single Predictor instance can safely serve repeated
# (and, under Gradio, possibly concurrent) transform_image calls -- confirm.
_PREDICTOR_CACHE = {}


def _get_predictor(weight, device):
    """Return the cached Predictor for (weight, device), creating it on first use."""
    key = (weight, device)
    predictor = _PREDICTOR_CACHE.get(key)
    if predictor is None:
        predictor = Predictor(weight, device)
        _PREDICTOR_CACHE[key] = predictor
    return predictor


def transfer(image, transfer_style):
    """Gradio callback: render ``image`` in the selected anime style.

    Args:
        image: Value of the input ``gr.Image`` component. It is ``None`` when
            nothing has been uploaded.
        transfer_style: One of ``"Hayao"``, ``"Shinkai"`` or ``"Kon Satoshi"``.
            Any other value (including ``None``, meaning no selection) is not
            a known style.

    Returns:
        The stylised image from ``Predictor.transform_image``. If there is no
        image or the style is unknown, ``image`` is returned unchanged.
    """
    # Nothing to stylise, or no model for this style: echo the input back.
    if image is None or transfer_style not in _STYLE_WEIGHTS:
        return image
    args = parse_args()
    weight = _STYLE_WEIGHTS[transfer_style]
    if weight is None:
        weight = args.weight
    predictor = _get_predictor(weight, args.device)
    return predictor.transform_image(image)
def clear_output(input_widget):
    """Gradio callback that empties the output component it is wired to.

    Args:
        input_widget: Current value of the input component. It is ignored and
            kept only so existing ``.click(clear_output, inputs=...)`` wiring
            keeps working.

    Returns:
        ``None``, which Gradio renders as an empty component.
    """
    # Gradio updates the UI only from the return value; rebinding the
    # argument locally would have no visible effect.
    return None
# Style names shown in both dropdowns; they must match the keys ``transfer``
# recognises.
STYLE_CHOICES = ["Hayao", "Shinkai", "Kon Satoshi"]

# Two-tab UI: image style transfer, plus a video tab.
with gr.Blocks() as demo:
    gr.Markdown("Transfer image or video files using this demo.")
    with gr.Tabs():
        with gr.TabItem("Transfer Image"):
            with gr.Row():
                image_input = gr.Image()
                image_output = gr.Image()
            with gr.Row():
                # Preselect a style so that clicking Transfer immediately
                # produces a result instead of echoing the input back.
                image_dropdown = gr.Dropdown(label="Transfer Style", choices=STYLE_CHOICES, value=STYLE_CHOICES[0])
                image_button = gr.Button("Transfer")
                clear_image_button = gr.Button("Clear")
        with gr.TabItem("Transfer Video"):
            with gr.Row():
                video_input = gr.Video()
                video_output = gr.Video()
            with gr.Row():
                video_dropdown = gr.Dropdown(label="Transfer Style", choices=STYLE_CHOICES, value=STYLE_CHOICES[0])
                # TODO: video_button has no click handler here. parse_args_video
                # suggests a video pipeline was planned, but none is wired up.
                video_button = gr.Button("Transfer")
                clear_video_button = gr.Button("Clear")
    # Event listeners must be attached inside the Blocks context.
    image_button.click(transfer, inputs=[image_input, image_dropdown], outputs=image_output)
    clear_image_button.click(clear_output, inputs=image_input, outputs=image_output)
    # Mirror the image tab so the video tab's Clear button actually clears.
    clear_video_button.click(clear_output, inputs=video_input, outputs=video_output)
demo.launch()
# Launch the interface. Alternative: bind to 127.0.0.1 on a fixed port:
#demo.launch(server_name='127.0.0.1',server_port=7788)