hysts HF staff commited on
Commit
b565c7c
·
1 Parent(s): a0395fb
Files changed (4) hide show
  1. .pre-commit-config.yaml +60 -34
  2. .style.yapf +0 -5
  3. .vscode/settings.json +30 -0
  4. app.py +37 -63
.pre-commit-config.yaml CHANGED
@@ -1,35 +1,61 @@
 
1
  repos:
2
- - repo: https://github.com/pre-commit/pre-commit-hooks
3
- rev: v4.2.0
4
- hooks:
5
- - id: check-executables-have-shebangs
6
- - id: check-json
7
- - id: check-merge-conflict
8
- - id: check-shebang-scripts-are-executable
9
- - id: check-toml
10
- - id: check-yaml
11
- - id: double-quote-string-fixer
12
- - id: end-of-file-fixer
13
- - id: mixed-line-ending
14
- args: ['--fix=lf']
15
- - id: requirements-txt-fixer
16
- - id: trailing-whitespace
17
- - repo: https://github.com/myint/docformatter
18
- rev: v1.4
19
- hooks:
20
- - id: docformatter
21
- args: ['--in-place']
22
- - repo: https://github.com/pycqa/isort
23
- rev: 5.12.0
24
- hooks:
25
- - id: isort
26
- - repo: https://github.com/pre-commit/mirrors-mypy
27
- rev: v0.991
28
- hooks:
29
- - id: mypy
30
- args: ['--ignore-missing-imports']
31
- - repo: https://github.com/google/yapf
32
- rev: v0.32.0
33
- hooks:
34
- - id: yapf
35
- args: ['--parallel', '--in-place']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ exclude: ^patch
2
  repos:
3
+ - repo: https://github.com/pre-commit/pre-commit-hooks
4
+ rev: v4.6.0
5
+ hooks:
6
+ - id: check-executables-have-shebangs
7
+ - id: check-json
8
+ - id: check-merge-conflict
9
+ - id: check-shebang-scripts-are-executable
10
+ - id: check-toml
11
+ - id: check-yaml
12
+ - id: end-of-file-fixer
13
+ - id: mixed-line-ending
14
+ args: ["--fix=lf"]
15
+ - id: requirements-txt-fixer
16
+ - id: trailing-whitespace
17
+ - repo: https://github.com/myint/docformatter
18
+ rev: v1.7.5
19
+ hooks:
20
+ - id: docformatter
21
+ args: ["--in-place"]
22
+ - repo: https://github.com/pycqa/isort
23
+ rev: 5.13.2
24
+ hooks:
25
+ - id: isort
26
+ args: ["--profile", "black"]
27
+ - repo: https://github.com/pre-commit/mirrors-mypy
28
+ rev: v1.10.0
29
+ hooks:
30
+ - id: mypy
31
+ args: ["--ignore-missing-imports"]
32
+ additional_dependencies:
33
+ [
34
+ "types-python-slugify",
35
+ "types-requests",
36
+ "types-PyYAML",
37
+ "types-pytz",
38
+ ]
39
+ - repo: https://github.com/psf/black
40
+ rev: 24.4.2
41
+ hooks:
42
+ - id: black
43
+ language_version: python3.10
44
+ args: ["--line-length", "119"]
45
+ - repo: https://github.com/kynan/nbstripout
46
+ rev: 0.7.1
47
+ hooks:
48
+ - id: nbstripout
49
+ args:
50
+ [
51
+ "--extra-keys",
52
+ "metadata.interpreter metadata.kernelspec cell.metadata.pycharm",
53
+ ]
54
+ - repo: https://github.com/nbQA-dev/nbQA
55
+ rev: 1.8.5
56
+ hooks:
57
+ - id: nbqa-black
58
+ - id: nbqa-pyupgrade
59
+ args: ["--py37-plus"]
60
+ - id: nbqa-isort
61
+ args: ["--float-to-top"]
.style.yapf DELETED
@@ -1,5 +0,0 @@
1
- [style]
2
- based_on_style = pep8
3
- blank_line_before_nested_class_or_def = false
4
- spaces_before_comment = 2
5
- split_before_logical_operator = true
 
 
 
 
 
 
.vscode/settings.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "editor.formatOnSave": true,
3
+ "files.insertFinalNewline": false,
4
+ "[python]": {
5
+ "editor.defaultFormatter": "ms-python.black-formatter",
6
+ "editor.formatOnType": true,
7
+ "editor.codeActionsOnSave": {
8
+ "source.organizeImports": "explicit"
9
+ }
10
+ },
11
+ "[jupyter]": {
12
+ "files.insertFinalNewline": false
13
+ },
14
+ "black-formatter.args": [
15
+ "--line-length=119"
16
+ ],
17
+ "isort.args": ["--profile", "black"],
18
+ "flake8.args": [
19
+ "--max-line-length=119"
20
+ ],
21
+ "ruff.lint.args": [
22
+ "--line-length=119"
23
+ ],
24
+ "notebook.output.scrolling": true,
25
+ "notebook.formatOnCellExecution": true,
26
+ "notebook.formatOnSave.enabled": true,
27
+ "notebook.codeActionsOnSave": {
28
+ "source.organizeImports": "explicit"
29
+ }
30
+ }
app.py CHANGED
@@ -16,24 +16,24 @@ import numpy as np
16
  import PIL.Image
17
  import torch
18
 
19
- if os.getenv('SYSTEM') == 'spaces':
20
- with open('patch') as f:
21
- subprocess.run(shlex.split('patch -p1'), cwd='gan-control', stdin=f)
22
 
23
- sys.path.insert(0, 'gan-control/src')
24
 
25
  from gan_control.inference.controller import Controller
26
 
27
- TITLE = 'GAN-Control'
28
- DESCRIPTION = 'https://github.com/amazon-research/gan-control'
29
 
30
 
31
  def download_models() -> None:
32
- model_dir = pathlib.Path('controller_age015id025exp02hai04ori02gam15')
33
  if not model_dir.exists():
34
  path = huggingface_hub.hf_hub_download(
35
- 'public-data/gan-control',
36
- 'controller_age015id025exp02hai04ori02gam15.tar.gz')
37
  with tarfile.open(path) as f:
38
  f.extractall()
39
 
@@ -55,33 +55,31 @@ def run(
55
  ) -> PIL.Image.Image:
56
  seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
57
  batch_size = nrows * ncols
58
- latent_size = controller.config.model_config['latent_size']
59
- latent = torch.from_numpy(
60
- np.random.RandomState(seed).randn(batch_size,
61
- latent_size)).float().to(device)
62
 
63
  initial_image_tensors, initial_latent_z, initial_latent_w = controller.gen_batch(
64
- latent=latent, truncation=truncation)
65
- res0 = controller.make_resized_grid_image(initial_image_tensors,
66
- nrow=ncols)
67
 
68
  pose_control = torch.tensor([[yaw, pitch, 0]], dtype=torch.float32)
69
  image_tensors, _, modified_latent_w = controller.gen_batch_by_controls(
70
- latent=initial_latent_w,
71
- input_is_latent=True,
72
- orientation=pose_control)
73
  res1 = controller.make_resized_grid_image(image_tensors, nrow=ncols)
74
 
75
  age_control = torch.tensor([[age]], dtype=torch.float32)
76
  image_tensors, _, modified_latent_w = controller.gen_batch_by_controls(
77
- latent=initial_latent_w, input_is_latent=True, age=age_control)
 
78
  res2 = controller.make_resized_grid_image(image_tensors, nrow=ncols)
79
 
80
- hair_color = torch.tensor([[hair_color_r, hair_color_g, hair_color_b]],
81
- dtype=torch.float32) / 255
82
  hair_color = torch.clamp(hair_color, 0, 1)
83
  image_tensors, _, modified_latent_w = controller.gen_batch_by_controls(
84
- latent=initial_latent_w, input_is_latent=True, hair=hair_color)
 
85
  res3 = controller.make_resized_grid_image(image_tensors, nrow=ncols)
86
 
87
  return res0, res1, res2, res3
@@ -89,54 +87,30 @@ def run(
89
 
90
  download_models()
91
 
92
- device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
93
- path = 'controller_age015id025exp02hai04ori02gam15/'
94
  controller = Controller(path, device)
95
  fn = functools.partial(run, controller=controller, device=device)
96
 
97
  gr.Interface(
98
  fn=fn,
99
  inputs=[
100
- gr.Slider(label='Seed', minimum=0, maximum=1000000, step=1, value=0),
101
- gr.Slider(label='Truncation',
102
- minimum=0,
103
- maximum=1,
104
- step=0.1,
105
- value=0.7),
106
- gr.Slider(label='Yaw', minimum=-90, maximum=90, step=1, value=30),
107
- gr.Slider(label='Pitch', minimum=-90, maximum=90, step=1, value=0),
108
- gr.Slider(label='Age', minimum=15, maximum=75, step=1, value=75),
109
- gr.Slider(label='Hair Color (R)',
110
- minimum=0,
111
- maximum=255,
112
- step=1,
113
- value=186),
114
- gr.Slider(label='Hair Color (G)',
115
- minimum=0,
116
- maximum=255,
117
- step=1,
118
- value=158),
119
- gr.Slider(label='Hair Color (B)',
120
- minimum=0,
121
- maximum=255,
122
- step=1,
123
- value=92),
124
- gr.Slider(label='Number of Rows',
125
- minimum=1,
126
- maximum=3,
127
- step=1,
128
- value=1),
129
- gr.Slider(label='Number of Columns',
130
- minimum=1,
131
- maximum=5,
132
- step=1,
133
- value=5),
134
  ],
135
  outputs=[
136
- gr.Image(label='Generated Image', type='pil'),
137
- gr.Image(label='Head Pose Controlled', type='pil'),
138
- gr.Image(label='Age Controlled', type='pil'),
139
- gr.Image(label='Hair Color Controlled', type='pil'),
140
  ],
141
  title=TITLE,
142
  description=DESCRIPTION,
 
16
  import PIL.Image
17
  import torch
18
 
19
+ if os.getenv("SYSTEM") == "spaces":
20
+ with open("patch") as f:
21
+ subprocess.run(shlex.split("patch -p1"), cwd="gan-control", stdin=f)
22
 
23
+ sys.path.insert(0, "gan-control/src")
24
 
25
  from gan_control.inference.controller import Controller
26
 
27
+ TITLE = "GAN-Control"
28
+ DESCRIPTION = "https://github.com/amazon-research/gan-control"
29
 
30
 
31
  def download_models() -> None:
32
+ model_dir = pathlib.Path("controller_age015id025exp02hai04ori02gam15")
33
  if not model_dir.exists():
34
  path = huggingface_hub.hf_hub_download(
35
+ "public-data/gan-control", "controller_age015id025exp02hai04ori02gam15.tar.gz"
36
+ )
37
  with tarfile.open(path) as f:
38
  f.extractall()
39
 
 
55
  ) -> PIL.Image.Image:
56
  seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
57
  batch_size = nrows * ncols
58
+ latent_size = controller.config.model_config["latent_size"]
59
+ latent = torch.from_numpy(np.random.RandomState(seed).randn(batch_size, latent_size)).float().to(device)
 
 
60
 
61
  initial_image_tensors, initial_latent_z, initial_latent_w = controller.gen_batch(
62
+ latent=latent, truncation=truncation
63
+ )
64
+ res0 = controller.make_resized_grid_image(initial_image_tensors, nrow=ncols)
65
 
66
  pose_control = torch.tensor([[yaw, pitch, 0]], dtype=torch.float32)
67
  image_tensors, _, modified_latent_w = controller.gen_batch_by_controls(
68
+ latent=initial_latent_w, input_is_latent=True, orientation=pose_control
69
+ )
 
70
  res1 = controller.make_resized_grid_image(image_tensors, nrow=ncols)
71
 
72
  age_control = torch.tensor([[age]], dtype=torch.float32)
73
  image_tensors, _, modified_latent_w = controller.gen_batch_by_controls(
74
+ latent=initial_latent_w, input_is_latent=True, age=age_control
75
+ )
76
  res2 = controller.make_resized_grid_image(image_tensors, nrow=ncols)
77
 
78
+ hair_color = torch.tensor([[hair_color_r, hair_color_g, hair_color_b]], dtype=torch.float32) / 255
 
79
  hair_color = torch.clamp(hair_color, 0, 1)
80
  image_tensors, _, modified_latent_w = controller.gen_batch_by_controls(
81
+ latent=initial_latent_w, input_is_latent=True, hair=hair_color
82
+ )
83
  res3 = controller.make_resized_grid_image(image_tensors, nrow=ncols)
84
 
85
  return res0, res1, res2, res3
 
87
 
88
  download_models()
89
 
90
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
91
+ path = "controller_age015id025exp02hai04ori02gam15/"
92
  controller = Controller(path, device)
93
  fn = functools.partial(run, controller=controller, device=device)
94
 
95
  gr.Interface(
96
  fn=fn,
97
  inputs=[
98
+ gr.Slider(label="Seed", minimum=0, maximum=1000000, step=1, value=0),
99
+ gr.Slider(label="Truncation", minimum=0, maximum=1, step=0.1, value=0.7),
100
+ gr.Slider(label="Yaw", minimum=-90, maximum=90, step=1, value=30),
101
+ gr.Slider(label="Pitch", minimum=-90, maximum=90, step=1, value=0),
102
+ gr.Slider(label="Age", minimum=15, maximum=75, step=1, value=75),
103
+ gr.Slider(label="Hair Color (R)", minimum=0, maximum=255, step=1, value=186),
104
+ gr.Slider(label="Hair Color (G)", minimum=0, maximum=255, step=1, value=158),
105
+ gr.Slider(label="Hair Color (B)", minimum=0, maximum=255, step=1, value=92),
106
+ gr.Slider(label="Number of Rows", minimum=1, maximum=3, step=1, value=1),
107
+ gr.Slider(label="Number of Columns", minimum=1, maximum=5, step=1, value=5),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  ],
109
  outputs=[
110
+ gr.Image(label="Generated Image", type="pil"),
111
+ gr.Image(label="Head Pose Controlled", type="pil"),
112
+ gr.Image(label="Age Controlled", type="pil"),
113
+ gr.Image(label="Hair Color Controlled", type="pil"),
114
  ],
115
  title=TITLE,
116
  description=DESCRIPTION,