hysts HF staff commited on
Commit
cb2f6c2
1 Parent(s): 3f9cb7e
Files changed (3) hide show
  1. README.md +5 -2
  2. app.py +10 -13
  3. requirements.txt +2 -2
README.md CHANGED
@@ -1,12 +1,15 @@
1
  ---
2
- title: Gan Control
3
  emoji: ⚡
4
  colorFrom: red
5
  colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 3.19.1
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
 
 
 
1
  ---
2
+ title: GAN-Control
3
  emoji: ⚡
4
  colorFrom: red
5
  colorTo: yellow
6
  sdk: gradio
7
+ sdk_version: 3.35.2
8
  app_file: app.py
9
  pinned: false
10
+ suggested_hardware: t4-small
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
14
+
15
+ https://arxiv.org/abs/2101.02477
app.py CHANGED
@@ -5,6 +5,7 @@ from __future__ import annotations
5
  import functools
6
  import os
7
  import pathlib
 
8
  import subprocess
9
  import sys
10
  import tarfile
@@ -17,27 +18,22 @@ import torch
17
 
18
  if os.getenv('SYSTEM') == 'spaces':
19
  with open('patch') as f:
20
- subprocess.run('patch -p1'.split(), cwd='gan-control', stdin=f)
21
 
22
  sys.path.insert(0, 'gan-control/src')
23
 
24
  from gan_control.inference.controller import Controller
25
 
26
- DESCRIPTION = '''GAN-Control
27
-
28
- This is an unofficial demo for https://github.com/amazon-research/gan-control.
29
- '''
30
-
31
- TOKEN = os.getenv('HF_TOKEN')
32
 
33
 
34
  def download_models() -> None:
35
  model_dir = pathlib.Path('controller_age015id025exp02hai04ori02gam15')
36
  if not model_dir.exists():
37
  path = huggingface_hub.hf_hub_download(
38
- 'hysts/gan-control',
39
- 'controller_age015id025exp02hai04ori02gam15.tar.gz',
40
- use_auth_token=TOKEN)
41
  with tarfile.open(path) as f:
42
  f.extractall()
43
 
@@ -96,10 +92,10 @@ download_models()
96
  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
97
  path = 'controller_age015id025exp02hai04ori02gam15/'
98
  controller = Controller(path, device)
99
- func = functools.partial(run, controller=controller, device=device)
100
 
101
  gr.Interface(
102
- fn=func,
103
  inputs=[
104
  gr.Slider(label='Seed', minimum=0, maximum=1000000, step=1, value=0),
105
  gr.Slider(label='Truncation',
@@ -142,5 +138,6 @@ gr.Interface(
142
  gr.Image(label='Age Controlled', type='pil'),
143
  gr.Image(label='Hair Color Controlled', type='pil'),
144
  ],
 
145
  description=DESCRIPTION,
146
- ).queue().launch(show_api=False)
 
5
  import functools
6
  import os
7
  import pathlib
8
+ import shlex
9
  import subprocess
10
  import sys
11
  import tarfile
 
18
 
19
  if os.getenv('SYSTEM') == 'spaces':
20
  with open('patch') as f:
21
+ subprocess.run(shlex.split('patch -p1'), cwd='gan-control', stdin=f)
22
 
23
  sys.path.insert(0, 'gan-control/src')
24
 
25
  from gan_control.inference.controller import Controller
26
 
27
+ TITLE = 'GAN-Control'
28
+ DESCRIPTION = 'https://github.com/amazon-research/gan-control'
 
 
 
 
29
 
30
 
31
  def download_models() -> None:
32
  model_dir = pathlib.Path('controller_age015id025exp02hai04ori02gam15')
33
  if not model_dir.exists():
34
  path = huggingface_hub.hf_hub_download(
35
+ 'public-data/gan-control',
36
+ 'controller_age015id025exp02hai04ori02gam15.tar.gz')
 
37
  with tarfile.open(path) as f:
38
  f.extractall()
39
 
 
92
  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
93
  path = 'controller_age015id025exp02hai04ori02gam15/'
94
  controller = Controller(path, device)
95
+ fn = functools.partial(run, controller=controller, device=device)
96
 
97
  gr.Interface(
98
+ fn=fn,
99
  inputs=[
100
  gr.Slider(label='Seed', minimum=0, maximum=1000000, step=1, value=0),
101
  gr.Slider(label='Truncation',
 
138
  gr.Image(label='Age Controlled', type='pil'),
139
  gr.Image(label='Hair Color Controlled', type='pil'),
140
  ],
141
+ title=TITLE,
142
  description=DESCRIPTION,
143
+ ).queue(max_size=10).launch()
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
- numpy==1.22.3
2
- Pillow==9.1.0
3
  torch==1.11.0
4
  torchvision==0.12.0
 
1
+ numpy==1.23.5
2
+ Pillow==10.0.0
3
  torch==1.11.0
4
  torchvision==0.12.0