hysts (HF staff) committed
Commit 2db804d
1 Parent(s): 83321d4
Files changed (4)
  1. README.md +4 -1
  2. app.py +33 -34
  3. requirements.txt +5 -5
  4. style.css +3 -0
README.md CHANGED
@@ -4,9 +4,12 @@ emoji: 📚
 colorFrom: indigo
 colorTo: pink
 sdk: gradio
-sdk_version: 3.19.1
+sdk_version: 3.35.2
 app_file: app.py
 pinned: false
+suggested_hardware: t4-small
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
+
+https://arxiv.org/abs/2104.04767
app.py CHANGED
@@ -3,7 +3,6 @@
 from __future__ import annotations
 
 import functools
-import os
 
 import gradio as gr
 import numpy as np
@@ -13,8 +12,7 @@ from huggingface_hub import hf_hub_download
 
 from model import Model
 
-TITLE = 'MobileStyleGAN'
-DESCRIPTION = 'This is an unofficial demo for https://github.com/bes-dev/MobileStyleGAN.pytorch.'
+DESCRIPTION = '# [MobileStyleGAN](https://github.com/bes-dev/MobileStyleGAN.pytorch)'
 SAMPLE_IMAGE_DIR = 'https://huggingface.co/spaces/hysts/MobileStyleGAN/resolve/main/samples'
 ARTICLE = f'''## Generated images
 ### FFHQ
@@ -24,8 +22,6 @@ ARTICLE = f'''## Generated images
 ![FFHQ]({SAMPLE_IMAGE_DIR}/ffhq.jpg)
 '''
 
-HF_TOKEN = os.getenv('HF_TOKEN')
-
 
 def generate_z(z_dim: int, seed: int, device: torch.device) -> torch.Tensor:
     return torch.from_numpy(np.random.RandomState(seed).randn(
@@ -45,9 +41,8 @@ def generate_image(seed: int, truncation_psi: float, generator: str,
 
 
 def load_model(device: torch.device) -> nn.Module:
-    path = hf_hub_download('hysts/MobileStyleGAN',
-                           'models/mobilestylegan_ffhq_v2.pth',
-                           use_auth_token=HF_TOKEN)
+    path = hf_hub_download('public-data/MobileStyleGAN',
+                           'models/mobilestylegan_ffhq_v2.pth')
     ckpt = torch.load(path)
     model = Model()
     model.load_state_dict(ckpt['state_dict'], strict=False)
@@ -62,29 +57,33 @@ def load_model(device: torch.device) -> nn.Module:
 device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
 model = load_model(device)
 
-func = functools.partial(generate_image, model=model, device=device)
-
-gr.Interface(
-    fn=func,
-    inputs=[
-        gr.Slider(label='Seed',
-                  minimum=0,
-                  maximum=100000,
-                  step=1,
-                  value=0,
-                  randomize=True),
-        gr.Slider(label='Truncation psi',
-                  minimum=0,
-                  maximum=2,
-                  step=0.05,
-                  value=1.0),
-        gr.Radio(label='Generator',
-                 choices=['student', 'teacher'],
-                 type='value',
-                 value='student'),
-    ],
-    outputs=gr.Image(label='Output', type='numpy'),
-    title=TITLE,
-    description=DESCRIPTION,
-    article=ARTICLE,
-).queue().launch(show_api=False)
+fn = functools.partial(generate_image, model=model, device=device)
+
+with gr.Blocks(css='style.css') as demo:
+    gr.Markdown(DESCRIPTION)
+    with gr.Row():
+        with gr.Column():
+            with gr.Group():
+                seed = gr.Slider(label='Seed',
+                                 minimum=0,
+                                 maximum=100000,
+                                 step=1,
+                                 value=0,
+                                 randomize=True)
+                psi = gr.Slider(label='Truncation psi',
+                                minimum=0,
+                                maximum=2,
+                                step=0.05,
+                                value=1.0)
+                generator = gr.Radio(label='Generator',
+                                     choices=['student', 'teacher'],
+                                     type='value',
+                                     value='student')
+                run_button = gr.Button('Run')
+        with gr.Column():
+            result = gr.Image(label='Output', type='numpy')
+    with gr.Row():
+        gr.Markdown(ARTICLE)
+
+    run_button.click(fn=fn, inputs=[seed, psi, generator], outputs=result)
+demo.queue(max_size=10).launch()
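
Note: the checkpoint is now fetched from the public `public-data/MobileStyleGAN` repo, which is why the `HF_TOKEN` / `use_auth_token` plumbing could be removed. A minimal sketch (not part of the commit) of downloading and inspecting the checkpoint anonymously, assuming the file layout used by `load_model()` above:

```python
import torch
from huggingface_hub import hf_hub_download

# Anonymous download works because the model repo is public; no token is needed.
path = hf_hub_download('public-data/MobileStyleGAN',
                       'models/mobilestylegan_ffhq_v2.pth')

# Load on CPU just to inspect the contents; load_model() in app.py reads the
# 'state_dict' entry from this checkpoint.
ckpt = torch.load(path, map_location='cpu')
print(list(ckpt.keys()))
```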
requirements.txt CHANGED
@@ -1,8 +1,8 @@
 git+https://github.com/fbcotter/pytorch_wavelets.git
-numpy==1.22.3
-Pillow==9.0.1
+numpy==1.23.5
+Pillow==10.0.0
 piq==0.6.0
 PyWavelets==1.2.0
-scipy==1.8.0
-torch==1.11.0
-torchvision==0.12.0
+scipy==1.10.1
+torch==2.0.1
+torchvision==0.15.2
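
The version bumps above move the Space to the torch 2.x line. A quick sanity-check sketch (not part of the commit) for confirming that the pinned packages resolved together and that a GPU is visible on the suggested T4 hardware:

```python
import numpy, PIL, scipy, torch, torchvision

# Print the resolved versions of the pinned packages and whether CUDA is available;
# app.py falls back to CPU when torch.cuda.is_available() is False.
print(numpy.__version__, PIL.__version__, scipy.__version__)
print(torch.__version__, torchvision.__version__, torch.cuda.is_available())
```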
style.css ADDED
@@ -0,0 +1,3 @@
+h1 {
+  text-align: center;
+}