rome / app.py
Pie31415's picture
added in submodules
b128124
raw
history blame
2.15 kB
import os, sys
import subprocess
import argparse
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image
# Pull in the git submodules at startup. The code imported right after this
# block comes from ./rome, which is presumably one of those submodules
# (NOTE(review): confirm against .gitmodules).
_submodule_update = subprocess.run(
    ["git", "submodule", "update", "--init", "--recursive"]
)
# Keep the call best-effort (no exception on failure), but surface a non-zero
# exit status so a broken checkout is not mistaken for a later import problem.
if _submodule_update.returncode != 0:
    print(
        "warning: 'git submodule update' exited with code "
        f"{_submodule_update.returncode}",
        file=sys.stderr,
    )

# Startup diagnostics: working directory and its contents.
print(os.getcwd())
print(os.listdir('.'))

# Make packages located under ./rome importable (e.g. `src`).
sys.path.append("./rome")
from src.utils import args as args_utils
from src.utils.processing import process_black_shape, tensor2image
# loading models ---- create model repo
from huggingface_hub import hf_hub_url
# Default checkpoint locations inside the 'Pie31415/rome' Hub repository.
# NOTE(review): hf_hub_url() only builds a download URL; it does not fetch the
# file. If the consumer of these values expects a local path,
# hf_hub_download() is the call to use instead -- confirm against Infer.
default_modnet_path = hf_hub_url(
    repo_id='Pie31415/rome',
    filename='modnet_photographic_portrait_matting.ckpt',
)
default_model_path = hf_hub_url(
    repo_id='Pie31415/rome',
    filename='models/rome.pth',
)
# parser configurations
# importlib is needed below and is not part of the top-of-file imports, so it
# is imported here to keep this section self-contained.
import importlib

# conflict_handler='resolve' lets a later add_argument() call redefine an
# option string that was already registered instead of raising an error.
parser = argparse.ArgumentParser(conflict_handler='resolve')
parser.add_argument('--save_dir', default='.', type=str)
parser.add_argument('--save_render', default='True', type=args_utils.str2bool, choices=[True, False])
parser.add_argument('--model_checkpoint', default=default_model_path, type=str)
parser.add_argument('--modnet_path', default=default_modnet_path, type=str)
parser.add_argument('--random_seed', default=0, type=int)
parser.add_argument('--debug', action='store_true')
parser.add_argument('--verbose', default='False', type=args_utils.str2bool, choices=[True, False])

# First pass over the command line, tolerating options that are not
# registered yet. Its result is replaced by the full parse below.
args, _ = parser.parse_known_args()

# Let the ROME class register its own options on the same parser, then parse
# the complete command line.
parser = importlib.import_module('src.rome').ROME.add_argparse_args(parser)
args = parser.parse_args()

# Set after parsing, so the command line does not control this value.
args.deca_path = 'DECA'
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Imported only now: `args` must exist first, and `infer` is resolved through
# the sys.path entry added during startup (presumably ./rome -- confirm).
from infer import Infer

# Build the inference wrapper from the parsed options and move it to `device`.
infer = Infer(args).to(device)
def predict(source_img, driver_img):
    """Run the model on a source/driver image pair and return a collage.

    Both images are passed to ``infer.evaluate`` with ``crop_center=False``.
    Four tensors from its output are moved to the CPU, concatenated along
    dimension 2 and converted with ``tensor2image``.

    Args:
        source_img: Source image, forwarded unchanged to ``infer.evaluate``.
        driver_img: Driver image, forwarded unchanged to ``infer.evaluate``.

    Returns:
        The converted collage with its last axis reversed.
    """
    result = infer.evaluate(source_img, driver_img, crop_center=False)
    data_dict = result['source_information']['data_dict']

    # Panel order: source, target, masked render, predicted target shape.
    # NOTE(review): 'render_masked' is the only entry not indexed with [0];
    # presumably it already lacks the batch dimension -- confirm.
    panels = [
        data_dict['source_img'][0].cpu(),
        data_dict['target_img'][0].cpu(),
        result['render_masked'].cpu(),
        result['pred_target_shape_img'][0].cpu(),
    ]

    # dim=2 is presumably the width axis of (C, H, W) tensors, which would
    # place the panels side by side -- TODO confirm.
    collage = tensor2image(torch.cat(panels, dim=2))

    # Reverse the last axis (presumably a BGR -> RGB channel swap; confirm
    # against tensor2image).
    return collage[..., ::-1]
import gradio as gr
# Web UI: two PIL image inputs (passed to `predict` in order) and one image
# output. No example inputs are registered.
demo = gr.Interface(
    fn=predict,
    inputs=[gr.Image(type="pil") for _ in range(2)],
    outputs=gr.Image(),
    examples=[],
)
demo.launch()