hysts (HF staff) committed
Commit 2a13f62
Parent: fb7a74a
Files changed (4)
  1. README.md +1 -1
  2. app.py +27 -25
  3. requirements.txt +3 -3
  4. style.css +3 -0
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🐢
 colorFrom: yellow
 colorTo: indigo
 sdk: gradio
-sdk_version: 3.19.1
+sdk_version: 3.34.0
 app_file: app.py
 pinned: false
 ---
app.py CHANGED
@@ -22,12 +22,7 @@ sys.path.insert(0, 'anime_face_landmark_detection')
 
 from CFA import CFA
 
-TITLE = 'kanosawa/anime_face_landmark_detection'
-DESCRIPTION = 'This is an unofficial demo for https://github.com/kanosawa/anime_face_landmark_detection.'
-
-HF_TOKEN = os.getenv('HF_TOKEN')
-MODEL_REPO = 'hysts/anime_face_landmark_detection'
-MODEL_FILENAME = 'checkpoint_landmark_191116.pth'
+DESCRIPTION = '# [kanosawa/anime_face_landmark_detection](https://github.com/kanosawa/anime_face_landmark_detection)'
 
 NUM_LANDMARK = 24
 CROP_SIZE = 128
@@ -39,8 +34,7 @@ def load_sample_image_paths() -> list[pathlib.Path]:
         dataset_repo = 'hysts/sample-images-TADNE'
         path = huggingface_hub.hf_hub_download(dataset_repo,
                                                'images.tar.gz',
-                                               repo_type='dataset',
-                                               use_auth_token=HF_TOKEN)
+                                               repo_type='dataset')
         with tarfile.open(path) as f:
             f.extractall()
     return sorted(image_dir.glob('*'))
@@ -55,9 +49,9 @@ def load_face_detector() -> cv2.CascadeClassifier:
 
 
 def load_landmark_detector(device: torch.device) -> torch.nn.Module:
-    path = huggingface_hub.hf_hub_download(MODEL_REPO,
-                                           MODEL_FILENAME,
-                                           use_auth_token=HF_TOKEN)
+    path = huggingface_hub.hf_hub_download(
+        'public-data/anime_face_landmark_detection',
+        'checkpoint_landmark_191116.pth')
     model = CFA(output_channel_num=NUM_LANDMARK + 1, checkpoint_name=path)
     model.to(device)
     model.eval()
@@ -122,17 +116,25 @@ transform = T.Compose([
     T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
 ])
 
-func = functools.partial(detect,
-                         face_detector=face_detector,
-                         device=device,
-                         transform=transform,
-                         landmark_detector=landmark_detector)
-
-gr.Interface(
-    fn=func,
-    inputs=gr.Image(label='Input', type='filepath'),
-    outputs=gr.Image(label='Output', type='numpy'),
-    examples=examples,
-    title=TITLE,
-    description=DESCRIPTION,
-).queue().launch(show_api=False)
+fn = functools.partial(detect,
+                       face_detector=face_detector,
+                       device=device,
+                       transform=transform,
+                       landmark_detector=landmark_detector)
+
+with gr.Blocks(css='style.css') as demo:
+    gr.Markdown(DESCRIPTION)
+    with gr.Row():
+        with gr.Column():
+            image = gr.Image(label='Input', type='filepath')
+            run_button = gr.Button('Run')
+        with gr.Column():
+            result = gr.Image(label='Result')
+
+    gr.Examples(examples=examples,
+                inputs=image,
+                outputs=result,
+                fn=fn,
+                cache_examples=os.getenv('CACHE_EXAMPLES') == '1')
+    run_button.click(fn=fn, inputs=image, outputs=result, api_name='predict')
+demo.queue(max_size=15).launch()
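
The new app.py replaces the gr.Interface call with a gr.Blocks layout: an explicit Run button, separate input and output columns, gr.Examples with optional caching, and a named predict endpoint. A minimal standalone sketch of that pattern, with a hypothetical run() pass-through standing in for the real detect partial and the gr.Examples wiring omitted, so it runs without the model or sample images:

import gradio as gr


def run(image_path: str) -> str:
    # Hypothetical stand-in for the detect(...) partial in app.py;
    # it echoes the input image back so the layout itself can be tested.
    return image_path


with gr.Blocks() as demo:
    gr.Markdown('# Demo')
    with gr.Row():
        with gr.Column():
            image = gr.Image(label='Input', type='filepath')
            run_button = gr.Button('Run')
        with gr.Column():
            result = gr.Image(label='Result')
    run_button.click(fn=run, inputs=image, outputs=result, api_name='predict')
demo.queue(max_size=15).launch()

Because the click handler is registered with api_name='predict', the queued Blocks app exposes the same /predict route the commit sets up.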
requirements.txt CHANGED
@@ -1,3 +1,3 @@
-opencv-python-headless>=4.5.5.62
-torch>=1.10.1
-torchvision>=0.11.2
+opencv-python-headless>=4.7.0.72
+torch==2.0.1
+torchvision==0.15.2
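
requirements.txt moves from loose lower bounds to an exact torch/torchvision pair (2.0.1 / 0.15.2), which are matching releases. A quick sanity check, assuming the packages were installed from this file, that the resolved versions line up with the pins:

import cv2
import torch
import torchvision

# Per requirements.txt: OpenCV >= 4.7.0.72, torch == 2.0.1, torchvision == 0.15.2
# (torch may report a local build suffix such as 2.0.1+cu118).
print(cv2.__version__)
print(torch.__version__)
print(torchvision.__version__)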
style.css ADDED
@@ -0,0 +1,3 @@
+h1 {
+  text-align: center;
+}