Ahsen Khaliq commited on
Commit
ddaf006
1 Parent(s): 79681a1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -18
app.py CHANGED
@@ -2,7 +2,7 @@ import os
2
  from PIL import Image
3
  import torch
4
  import gradio as gr
5
- os.system("pip install dlib")
6
  os.system("git clone https://github.com/mchong6/JoJoGAN.git")
7
  os.chdir("JoJoGAN")
8
 
@@ -29,9 +29,9 @@ os.makedirs('style_images', exist_ok=True)
29
  os.makedirs('style_images_aligned', exist_ok=True)
30
  os.makedirs('models', exist_ok=True)
31
 
32
- os.system("wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2")
33
- os.system("bzip2 -dk shape_predictor_68_face_landmarks.dat.bz2")
34
- os.system("mv shape_predictor_68_face_landmarks.dat models/dlibshape_predictor_68_face_landmarks.dat")
35
 
36
 
37
  device = 'cpu'
@@ -44,7 +44,7 @@ latent_dim = 512
44
 
45
  # Load original generator
46
  original_generator = Generator(1024, latent_dim, 8, 2).to(device)
47
- ckpt = torch.load(os.path.join('models', 'stylegan2-ffhq-config-f.pt'), map_location=lambda storage, loc: storage)
48
  original_generator.load_state_dict(ckpt["g_ema"], strict=False)
49
  mean_latent = original_generator.mean_latent(10000)
50
 
@@ -60,21 +60,12 @@ transform = transforms.Compose(
60
  )
61
 
62
  plt.rcParams['figure.dpi'] = 150
63
- os.system("gdown https://drive.google.com/uc?id=1-7UlCppmiG4DKbhYDNbIZTc6mHy9JMWJ")
64
- os.system("mv e4e_ffhq_encode.pt models/e4e_ffhq_encode.pt")
65
 
66
 
67
  os.system("gdown https://drive.google.com/uc?id=1-8E0PFT37v5fZs-61oIrFbNpE28Unp2y")
68
- os.system("mv e4e_ffhq_encode.pt models/jojo.pt")
69
 
70
- def inference(img):
71
- filepath = img
72
-
73
- name = strip_path_extension(filepath)+'.pt'
74
-
75
- aligned_face = align_face(filepath)
76
-
77
- my_w = e4e_projection(aligned_face, name, device).unsqueeze(0)
78
 
79
 
80
  plt.rcParams['figure.dpi'] = 150
@@ -83,7 +74,7 @@ def inference(img):
83
  preserve_color = False
84
 
85
 
86
- ckpt = torch.load(os.path.join('models', 'jojo.pt'), map_location=lambda storage, loc: storage)
87
  generator.load_state_dict(ckpt["g"], strict=False)
88
 
89
  with torch.no_grad():
@@ -98,4 +89,4 @@ title = "AnimeGANv2"
98
  description = "Gradio Demo for AnimeGanv2 Face Portrait. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below. Please use a cropped portrait picture for best results similar to the examples below."
99
  article = "<p style='text-align: center'><a href='https://github.com/bryandlee/animegan2-pytorch' target='_blank'>Github Repo Pytorch</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_animegan' alt='visitor badge'></center> <p style='text-align: center'>samples from repo: <img src='https://user-images.githubusercontent.com/26464535/129888683-98bb6283-7bb8-4d1a-a04a-e795f5858dcf.gif' alt='animation'/> <img src='https://user-images.githubusercontent.com/26464535/137619176-59620b59-4e20-4d98-9559-a424f86b7f24.jpg' alt='animation'/><img src='https://user-images.githubusercontent.com/26464535/127134790-93595da2-4f8b-4aca-a9d7-98699c5e6914.jpg' alt='animation'/></p>"
100
 
101
- gr.Interface(inference, [gr.inputs.Image(type="filepath")], gr.outputs.Image(type="numpy"),title=title,description=description,article=article,enable_queue=True,allow_flagging=False).launch()
 
2
  from PIL import Image
3
  import torch
4
  import gradio as gr
5
+ #os.system("pip install dlib")
6
  os.system("git clone https://github.com/mchong6/JoJoGAN.git")
7
  os.chdir("JoJoGAN")
8
 
 
29
  os.makedirs('style_images_aligned', exist_ok=True)
30
  os.makedirs('models', exist_ok=True)
31
 
32
+ #os.system("wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2")
33
+ #os.system("bzip2 -dk shape_predictor_68_face_landmarks.dat.bz2")
34
+ #os.system("mv shape_predictor_68_face_landmarks.dat models/dlibshape_predictor_68_face_landmarks.dat")
35
 
36
 
37
  device = 'cpu'
 
44
 
45
  # Load original generator
46
  original_generator = Generator(1024, latent_dim, 8, 2).to(device)
47
+ ckpt = torch.load('stylegan2-ffhq-config-f.pt', map_location=lambda storage, loc: storage)
48
  original_generator.load_state_dict(ckpt["g_ema"], strict=False)
49
  mean_latent = original_generator.mean_latent(10000)
50
 
 
60
  )
61
 
62
  plt.rcParams['figure.dpi'] = 150
 
 
63
 
64
 
65
  os.system("gdown https://drive.google.com/uc?id=1-8E0PFT37v5fZs-61oIrFbNpE28Unp2y")
 
66
 
67
+ def inference(img):
68
+ my_w = e4e_projection(img, "test.pt", device).unsqueeze(0)
 
 
 
 
 
 
69
 
70
 
71
  plt.rcParams['figure.dpi'] = 150
 
74
  preserve_color = False
75
 
76
 
77
+ ckpt = torch.load('jojo.pt', map_location=lambda storage, loc: storage)
78
  generator.load_state_dict(ckpt["g"], strict=False)
79
 
80
  with torch.no_grad():
 
89
  description = "Gradio Demo for AnimeGanv2 Face Portrait. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below. Please use a cropped portrait picture for best results similar to the examples below."
90
  article = "<p style='text-align: center'><a href='https://github.com/bryandlee/animegan2-pytorch' target='_blank'>Github Repo Pytorch</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_animegan' alt='visitor badge'></center> <p style='text-align: center'>samples from repo: <img src='https://user-images.githubusercontent.com/26464535/129888683-98bb6283-7bb8-4d1a-a04a-e795f5858dcf.gif' alt='animation'/> <img src='https://user-images.githubusercontent.com/26464535/137619176-59620b59-4e20-4d98-9559-a424f86b7f24.jpg' alt='animation'/><img src='https://user-images.githubusercontent.com/26464535/127134790-93595da2-4f8b-4aca-a9d7-98699c5e6914.jpg' alt='animation'/></p>"
91
 
92
+ gr.Interface(inference, [gr.inputs.Image(type="pil")], gr.outputs.Image(type="numpy"),title=title,description=description,article=article,enable_queue=True,allow_flagging=False).launch()