Commit 532cb3d by hysts (HF staff)
Parent: 4c44f87

Migrate from yapf to black
Files changed (5):
  1. .pre-commit-config.yaml +54 -35
  2. .style.yapf +0 -5
  3. .vscode/settings.json +21 -0
  4. app.py +76 -97
  5. inference.py +12 -17
.pre-commit-config.yaml CHANGED
@@ -1,37 +1,56 @@
 exclude: patch
 repos:
-  - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.2.0
-    hooks:
-      - id: check-executables-have-shebangs
-      - id: check-json
-      - id: check-merge-conflict
-      - id: check-shebang-scripts-are-executable
-      - id: check-toml
-      - id: check-yaml
-      - id: double-quote-string-fixer
-      - id: end-of-file-fixer
-      - id: mixed-line-ending
-        args: ['--fix=lf']
-      - id: requirements-txt-fixer
-      - id: trailing-whitespace
-  - repo: https://github.com/myint/docformatter
-    rev: v1.4
-    hooks:
-      - id: docformatter
-        args: ['--in-place']
-  - repo: https://github.com/pycqa/isort
-    rev: 5.12.0
-    hooks:
-      - id: isort
-  - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v0.991
-    hooks:
-      - id: mypy
-        args: ['--ignore-missing-imports']
-        additional_dependencies: ['types-python-slugify']
-  - repo: https://github.com/google/yapf
-    rev: v0.32.0
-    hooks:
-      - id: yapf
-        args: ['--parallel', '--in-place']
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v4.4.0
+    hooks:
+      - id: check-executables-have-shebangs
+      - id: check-json
+      - id: check-merge-conflict
+      - id: check-shebang-scripts-are-executable
+      - id: check-toml
+      - id: check-yaml
+      - id: end-of-file-fixer
+      - id: mixed-line-ending
+        args: ["--fix=lf"]
+      - id: requirements-txt-fixer
+      - id: trailing-whitespace
+  - repo: https://github.com/myint/docformatter
+    rev: v1.7.5
+    hooks:
+      - id: docformatter
+        args: ["--in-place"]
+  - repo: https://github.com/pycqa/isort
+    rev: 5.12.0
+    hooks:
+      - id: isort
+        args: ["--profile", "black"]
+  - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: v1.5.1
+    hooks:
+      - id: mypy
+        args: ["--ignore-missing-imports"]
+        additional_dependencies:
+          ["types-python-slugify", "types-requests", "types-PyYAML"]
+  - repo: https://github.com/psf/black
+    rev: 23.9.1
+    hooks:
+      - id: black
+        language_version: python3.10
+        args: ["--line-length", "119"]
+  - repo: https://github.com/kynan/nbstripout
+    rev: 0.6.1
+    hooks:
+      - id: nbstripout
+        args:
+          [
+            "--extra-keys",
+            "metadata.interpreter metadata.kernelspec cell.metadata.pycharm",
+          ]
+  - repo: https://github.com/nbQA-dev/nbQA
+    rev: 1.7.0
+    hooks:
+      - id: nbqa-black
+      - id: nbqa-pyupgrade
+        args: ["--py37-plus"]
+      - id: nbqa-isort
+        args: ["--float-to-top"]
.style.yapf DELETED
@@ -1,5 +0,0 @@
-[style]
-based_on_style = pep8
-blank_line_before_nested_class_or_def = false
-spaces_before_comment = 2
-split_before_logical_operator = true
.vscode/settings.json ADDED
@@ -0,0 +1,21 @@
+{
+    "[python]": {
+        "editor.defaultFormatter": "ms-python.black-formatter",
+        "editor.formatOnType": true,
+        "editor.codeActionsOnSave": {
+            "source.organizeImports": true
+        }
+    },
+    "black-formatter.args": [
+        "--line-length=119"
+    ],
+    "isort.args": ["--profile", "black"],
+    "flake8.args": [
+        "--max-line-length=119"
+    ],
+    "ruff.args": [
+        "--line-length=119"
+    ],
+    "editor.formatOnSave": true,
+    "files.insertFinalNewline": true
+}
app.py CHANGED
@@ -18,89 +18,64 @@ class InferenceUtil:
         try:
             card = InferencePipeline.get_model_card(model_id, self.hf_token)
         except Exception:
-            return '', ''
-        base_model = getattr(card.data, 'base_model', '')
-        training_prompt = getattr(card.data, 'training_prompt', '')
+            return "", ""
+        base_model = getattr(card.data, "base_model", "")
+        training_prompt = getattr(card.data, "training_prompt", "")
         return base_model, training_prompt
 
 
-DESCRIPTION = '# [Tune-A-Video](https://tuneavideo.github.io/)'
+DESCRIPTION = "# [Tune-A-Video](https://tuneavideo.github.io/)"
 if not torch.cuda.is_available():
-    DESCRIPTION += '\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>'
+    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
 
-CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv(
-    'CACHE_EXAMPLES') == '1'
+CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1"
 
-HF_TOKEN = os.getenv('HF_TOKEN')
+HF_TOKEN = os.getenv("HF_TOKEN")
 pipe = InferencePipeline(HF_TOKEN)
 app = InferenceUtil(HF_TOKEN)
 
-with gr.Blocks(css='style.css') as demo:
+with gr.Blocks(css="style.css") as demo:
     gr.Markdown(DESCRIPTION)
 
     with gr.Row():
         with gr.Column():
             with gr.Box():
                 model_id = gr.Dropdown(
-                    label='Model ID',
+                    label="Model ID",
                     choices=[
-                        'Tune-A-Video-library/a-man-is-surfing',
-                        'Tune-A-Video-library/mo-di-bear-guitar',
-                        'Tune-A-Video-library/redshift-man-skiing',
+                        "Tune-A-Video-library/a-man-is-surfing",
+                        "Tune-A-Video-library/mo-di-bear-guitar",
+                        "Tune-A-Video-library/redshift-man-skiing",
                     ],
-                    value='Tune-A-Video-library/a-man-is-surfing')
-                with gr.Accordion(
-                        label=
-                        'Model info (Base model and prompt used for training)',
-                        open=False):
+                    value="Tune-A-Video-library/a-man-is-surfing",
+                )
+                with gr.Accordion(label="Model info (Base model and prompt used for training)", open=False):
                     with gr.Row():
-                        base_model_used_for_training = gr.Text(
-                            label='Base model', interactive=False)
-                        prompt_used_for_training = gr.Text(
-                            label='Training prompt', interactive=False)
-                prompt = gr.Textbox(label='Prompt',
-                                    max_lines=1,
-                                    placeholder='Example: "A panda is surfing"')
-                video_length = gr.Slider(label='Video length',
-                                         minimum=4,
-                                         maximum=12,
-                                         step=1,
-                                         value=8)
-                fps = gr.Slider(label='FPS',
-                                minimum=1,
-                                maximum=12,
-                                step=1,
-                                value=1)
-                seed = gr.Slider(label='Seed',
-                                 minimum=0,
-                                 maximum=100000,
-                                 step=1,
-                                 value=0)
-                with gr.Accordion('Other Parameters', open=False):
-                    num_steps = gr.Slider(label='Number of Steps',
-                                          minimum=0,
-                                          maximum=100,
-                                          step=1,
-                                          value=50)
-                    guidance_scale = gr.Slider(label='CFG Scale',
-                                               minimum=0,
-                                               maximum=50,
-                                               step=0.1,
-                                               value=7.5)
-
-                run_button = gr.Button('Generate')
-
-                gr.Markdown('''
+                        base_model_used_for_training = gr.Text(label="Base model", interactive=False)
+                        prompt_used_for_training = gr.Text(label="Training prompt", interactive=False)
+                prompt = gr.Textbox(label="Prompt", max_lines=1, placeholder='Example: "A panda is surfing"')
+                video_length = gr.Slider(label="Video length", minimum=4, maximum=12, step=1, value=8)
+                fps = gr.Slider(label="FPS", minimum=1, maximum=12, step=1, value=1)
+                seed = gr.Slider(label="Seed", minimum=0, maximum=100000, step=1, value=0)
+                with gr.Accordion("Other Parameters", open=False):
+                    num_steps = gr.Slider(label="Number of Steps", minimum=0, maximum=100, step=1, value=50)
+                    guidance_scale = gr.Slider(label="CFG Scale", minimum=0, maximum=50, step=0.1, value=7.5)
+
+                run_button = gr.Button("Generate")
+
+                gr.Markdown(
+                    """
     - It takes a few minutes to download model first.
     - Expected time to generate an 8-frame video: 70 seconds with T4, 24 seconds with A10G, (10 seconds with A100)
-    ''')
+    """
+                )
         with gr.Column():
-            result = gr.Video(label='Result')
+            result = gr.Video(label="Result")
     with gr.Row():
         examples = [
             [
-                'Tune-A-Video-library/a-man-is-surfing',
-                'A panda is surfing.',
+                "Tune-A-Video-library/a-man-is-surfing",
+                "A panda is surfing.",
                 8,
                 1,
                 3,
@@ -108,8 +83,8 @@ with gr.Blocks(css='style.css') as demo:
                 7.5,
             ],
             [
-                'Tune-A-Video-library/a-man-is-surfing',
-                'A racoon is surfing, cartoon style.',
+                "Tune-A-Video-library/a-man-is-surfing",
+                "A racoon is surfing, cartoon style.",
                 8,
                 1,
                 3,
@@ -117,8 +92,8 @@ with gr.Blocks(css='style.css') as demo:
                 7.5,
             ],
             [
-                'Tune-A-Video-library/mo-di-bear-guitar',
-                'a handsome prince is playing guitar, modern disney style.',
+                "Tune-A-Video-library/mo-di-bear-guitar",
+                "a handsome prince is playing guitar, modern disney style.",
                 8,
                 1,
                 123,
@@ -126,8 +101,8 @@ with gr.Blocks(css='style.css') as demo:
                 7.5,
             ],
             [
-                'Tune-A-Video-library/mo-di-bear-guitar',
-                'a magical princess is playing guitar, modern disney style.',
+                "Tune-A-Video-library/mo-di-bear-guitar",
+                "a magical princess is playing guitar, modern disney style.",
                 8,
                 1,
                 123,
@@ -135,8 +110,8 @@ with gr.Blocks(css='style.css') as demo:
                 7.5,
             ],
             [
-                'Tune-A-Video-library/mo-di-bear-guitar',
-                'a rabbit is playing guitar, modern disney style.',
+                "Tune-A-Video-library/mo-di-bear-guitar",
+                "a rabbit is playing guitar, modern disney style.",
                 8,
                 1,
                 123,
@@ -144,8 +119,8 @@ with gr.Blocks(css='style.css') as demo:
                 7.5,
             ],
             [
-                'Tune-A-Video-library/mo-di-bear-guitar',
-                'a baby is playing guitar, modern disney style.',
+                "Tune-A-Video-library/mo-di-bear-guitar",
+                "a baby is playing guitar, modern disney style.",
                 8,
                 1,
                 123,
@@ -153,8 +128,8 @@ with gr.Blocks(css='style.css') as demo:
                 7.5,
             ],
             [
-                'Tune-A-Video-library/redshift-man-skiing',
-                '(redshift style) spider man is skiing.',
+                "Tune-A-Video-library/redshift-man-skiing",
+                "(redshift style) spider man is skiing.",
                 8,
                 1,
                 123,
@@ -162,8 +137,8 @@ with gr.Blocks(css='style.css') as demo:
                 7.5,
             ],
             [
-                'Tune-A-Video-library/redshift-man-skiing',
-                '(redshift style) black widow is skiing.',
+                "Tune-A-Video-library/redshift-man-skiing",
+                "(redshift style) black widow is skiing.",
                 8,
                 1,
                 123,
@@ -171,8 +146,8 @@ with gr.Blocks(css='style.css') as demo:
                 7.5,
             ],
             [
-                'Tune-A-Video-library/redshift-man-skiing',
-                '(redshift style) batman is skiing.',
+                "Tune-A-Video-library/redshift-man-skiing",
+                "(redshift style) batman is skiing.",
                 8,
                 1,
                 123,
@@ -180,8 +155,8 @@ with gr.Blocks(css='style.css') as demo:
                 7.5,
             ],
             [
-                'Tune-A-Video-library/redshift-man-skiing',
-                '(redshift style) hulk is skiing.',
+                "Tune-A-Video-library/redshift-man-skiing",
+                "(redshift style) hulk is skiing.",
                 8,
                 1,
                 123,
@@ -189,26 +164,30 @@ with gr.Blocks(css='style.css') as demo:
                 7.5,
             ],
         ]
-        gr.Examples(examples=examples,
-                    inputs=[
-                        model_id,
-                        prompt,
-                        video_length,
-                        fps,
-                        seed,
-                        num_steps,
-                        guidance_scale,
-                    ],
-                    outputs=result,
-                    fn=pipe.run,
-                    cache_examples=CACHE_EXAMPLES)
-
-    model_id.change(fn=app.load_model_info,
-                    inputs=model_id,
-                    outputs=[
-                        base_model_used_for_training,
-                        prompt_used_for_training,
-                    ])
+        gr.Examples(
+            examples=examples,
+            inputs=[
+                model_id,
+                prompt,
+                video_length,
+                fps,
+                seed,
+                num_steps,
+                guidance_scale,
+            ],
+            outputs=result,
+            fn=pipe.run,
+            cache_examples=CACHE_EXAMPLES,
+        )
+
+    model_id.change(
+        fn=app.load_model_info,
+        inputs=model_id,
+        outputs=[
+            base_model_used_for_training,
+            prompt_used_for_training,
+        ],
+    )
     inputs = [
         model_id,
         prompt,
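
Most of the churn in app.py is mechanical black output: quotes are normalized to double, calls that fit within the 119-column limit are collapsed onto one line (the gr.Slider and gr.Text widgets), and calls whose argument list ends with a trailing comma stay exploded one argument per line, black's "magic trailing comma" (the rewritten gr.Dropdown, gr.Examples, and model_id.change calls). The two cases, sketched with lines from the diff above:

# No trailing comma, and the call fits in 119 columns: black collapses it.
fps = gr.Slider(label="FPS", minimum=1, maximum=12, step=1, value=1)

# Trailing comma after the last argument: black keeps the call exploded.
model_id.change(
    fn=app.load_model_info,
    inputs=model_id,
    outputs=[
        base_model_used_for_training,
        prompt_used_for_training,
    ],
)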
inference.py CHANGED
@@ -13,7 +13,7 @@ from diffusers.utils.import_utils import is_xformers_available
 from einops import rearrange
 from huggingface_hub import ModelCard
 
-sys.path.append('Tune-A-Video')
+sys.path.append("Tune-A-Video")
 
 from tuneavideo.models.unet import UNet3DConditionModel
 from tuneavideo.pipelines.pipeline_tuneavideo import TuneAVideoPipeline
@@ -23,8 +23,7 @@ class InferencePipeline:
     def __init__(self, hf_token: str | None = None):
         self.hf_token = hf_token
         self.pipe = None
-        self.device = torch.device(
-            'cuda:0' if torch.cuda.is_available() else 'cpu')
+        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
         self.model_id = None
 
     def clear(self) -> None:
@@ -39,10 +38,9 @@ class InferencePipeline:
         return pathlib.Path(model_id).exists()
 
     @staticmethod
-    def get_model_card(model_id: str,
-                       hf_token: str | None = None) -> ModelCard:
+    def get_model_card(model_id: str, hf_token: str | None = None) -> ModelCard:
         if InferencePipeline.check_if_model_is_local(model_id):
-            card_path = (pathlib.Path(model_id) / 'README.md').as_posix()
+            card_path = (pathlib.Path(model_id) / "README.md").as_posix()
         else:
             card_path = model_id
         return ModelCard.load(card_path, token=hf_token)
@@ -57,14 +55,11 @@ class InferencePipeline:
             return
         base_model_id = self.get_base_model_info(model_id, self.hf_token)
         unet = UNet3DConditionModel.from_pretrained(
-            model_id,
-            subfolder='unet',
-            torch_dtype=torch.float16,
-            use_auth_token=self.hf_token)
-        pipe = TuneAVideoPipeline.from_pretrained(base_model_id,
-                                                  unet=unet,
-                                                  torch_dtype=torch.float16,
-                                                  use_auth_token=self.hf_token)
+            model_id, subfolder="unet", torch_dtype=torch.float16, use_auth_token=self.hf_token
+        )
+        pipe = TuneAVideoPipeline.from_pretrained(
+            base_model_id, unet=unet, torch_dtype=torch.float16, use_auth_token=self.hf_token
+        )
         pipe = pipe.to(self.device)
         if is_xformers_available():
             pipe.unet.enable_xformers_memory_efficient_attention()
@@ -82,7 +77,7 @@ class InferencePipeline:
         guidance_scale: float,
     ) -> PIL.Image.Image:
         if not torch.cuda.is_available():
-            raise gr.Error('CUDA is not available.')
+            raise gr.Error("CUDA is not available.")
 
         self.load_pipe(model_id)
 
@@ -97,10 +92,10 @@ class InferencePipeline:
             generator=generator,
         )  # type: ignore
 
-        frames = rearrange(out.videos[0], 'c t h w -> t h w c')
+        frames = rearrange(out.videos[0], "c t h w -> t h w c")
         frames = (frames * 255).to(torch.uint8).numpy()
 
-        out_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
+        out_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
         writer = imageio.get_writer(out_file.name, fps=fps)
         for frame in frames:
             writer.append_data(frame)
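
Where a call exceeds 119 columns and carries no trailing comma, black breaks after the opening parenthesis and indents the arguments one level, rather than aligning continuation lines under the parenthesis as yapf did; that is what produced the two from_pretrained rewrites above. A before/after sketch using the same API as the earlier example (the SRC name is illustrative; the call body is from this diff):

import black

SRC = """pipe = TuneAVideoPipeline.from_pretrained(base_model_id,
                                           unet=unet,
                                           torch_dtype=torch.float16,
                                           use_auth_token=self.hf_token)
"""

# Collapsed onto one line this call would be 124 columns, over the limit,
# so black splits at the parentheses instead:
print(black.format_str(SRC, mode=black.Mode(line_length=119)))
# pipe = TuneAVideoPipeline.from_pretrained(
#     base_model_id, unet=unet, torch_dtype=torch.float16, use_auth_token=self.hf_token
# )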