Pie31415 committed
Commit 3cf6a6c • 1 parent: ae3c9eb

Initial commit

Files changed (2):
  1. .gitmodules +9 -0
  2. app.py +59 -0
.gitmodules ADDED
@@ -0,0 +1,9 @@
+ [submodule "ROME"]
+ path = ROME
+ url = https://github.com/khakhulin/DECA/tree/1cc2361a2929a206e1b0330ee8b89fcda478d037
+ [submodule "DECA"]
+ path = DECA
+ url = https://github.com/khakhulin/DECA/tree/1cc2361a2929a206e1b0330ee8b89fcda478d037
+ [submodule "MODNet"]
+ path = MODNet
+ url = https://github.com/ZHKKKe/MODNet/tree/28165a451e4610c9d77cfdf925a94610bb2810fb
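To reproduce this Space locally, the pinned submodules declared above need to be fetched before app.py can import from them. A minimal sketch under that assumption (the helper name is hypothetical; it only wraps the standard git command):

# Hypothetical setup helper: fetch the submodules declared in .gitmodules.
# Equivalent to running `git submodule update --init --recursive` by hand.
import subprocess

def fetch_submodules():
    subprocess.run(['git', 'submodule', 'update', '--init', '--recursive'], check=True)

if __name__ == '__main__':
    fetch_submodules()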
app.py ADDED
@@ -0,0 +1,59 @@
+ import os, sys
+ import argparse
+ import importlib
+ import numpy as np
+ import torch
+ from torchvision import transforms
+ import matplotlib.pyplot as plt
+ from PIL import Image
+
+ # make the ROME submodule importable; it is assumed to provide src/ and infer.py at its root
+ sys.path.append(os.path.abspath('ROME'))
+ from src.utils import args as args_utils  # assumed helper location inside the ROME submodule
+ from src.utils.processing import tensor2image  # assumed helper location inside the ROME submodule
+
+ # download the model checkpoints from the Hub (hf_hub_download returns a local file path)
+ from huggingface_hub import hf_hub_download
+ default_modnet_path = hf_hub_download('Pie31415/rome', 'modnet_photographic_portrait_matting.ckpt')
+ default_model_path = hf_hub_download('Pie31415/rome', 'models/rome.pth')
+
+ # parser configuration
+ parser = argparse.ArgumentParser(conflict_handler='resolve')
+ parser.add_argument('--save_dir', default='.', type=str)
+ parser.add_argument('--save_render', default='True', type=args_utils.str2bool, choices=[True, False])
+ parser.add_argument('--model_checkpoint', default=default_model_path, type=str)
+ parser.add_argument('--modnet_path', default=default_modnet_path, type=str)
+ parser.add_argument('--random_seed', default=0, type=int)
+ parser.add_argument('--debug', action='store_true')
+ parser.add_argument('--verbose', default='False', type=args_utils.str2bool, choices=[True, False])
+ args, _ = parser.parse_known_args()
+
+ # let the ROME model register its own command-line arguments, then re-parse
+ parser = importlib.import_module('src.rome').ROME.add_argparse_args(parser)
+ args = parser.parse_args()
+ args.deca_path = 'DECA'
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ from infer import Infer
+
+ infer = Infer(args)
+ infer = infer.to(device)
+
+ def predict(source_img, driver_img):
+     # drive the source portrait with the driver image and tile the intermediate outputs
+     out = infer.evaluate(source_img, driver_img, crop_center=False)
+     res = tensor2image(torch.cat([out['source_information']['data_dict']['source_img'][0].cpu(),
+                                   out['source_information']['data_dict']['target_img'][0].cpu(),
+                                   out['render_masked'].cpu(), out['pred_target_shape_img'][0].cpu()], dim=2))
+     return res[..., ::-1]  # reverse the channel order for display
+
+ import gradio as gr
+ gr.Interface(
+     fn=predict,
+     inputs=[
+         gr.Image(type="pil"),  # source portrait
+         gr.Image(type="pil")   # driver portrait
+     ],
+     outputs=gr.Image(),
+     examples=[]).launch()
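As a quick local smoke test of the Gradio callback, a minimal sketch (the image filenames are hypothetical, and it assumes the submodules and checkpoints above are already in place):

from PIL import Image

source = Image.open('source.png')   # hypothetical source portrait
driver = Image.open('driver.png')   # hypothetical driver portrait
result = predict(source, driver)    # image array with source, driver, masked render and predicted shape concatenated along the width
print(result.shape)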