Pie31415 commited on
Commit
2c22172
1 Parent(s): d906598
.gitignore ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ .python-version
86
+
87
+ # pipenv
88
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
90
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
91
+ # install all needed dependencies.
92
+ #Pipfile.lock
93
+
94
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95
+ __pypackages__/
96
+
97
+ # Celery stuff
98
+ celerybeat-schedule
99
+ celerybeat.pid
100
+
101
+ # SageMath parsed files
102
+ *.sage.py
103
+
104
+ # Environments
105
+ .env
106
+ .venv
107
+ env/
108
+ venv/
109
+ ENV/
110
+ env.bak/
111
+ venv.bak/
112
+
113
+ # Spyder project settings
114
+ .spyderproject
115
+ .spyproject
116
+
117
+ # Rope project settings
118
+ .ropeproject
119
+
120
+ # mkdocs documentation
121
+ /site
122
+
123
+ # mypy
124
+ .mypy_cache/
125
+ .dmypy.json
126
+ dmypy.json
127
+
128
+ # Pyre type checker
129
+ .pyre/
130
+
131
+ ###
132
+ .vscode/
133
+ *.pth
README.md CHANGED
@@ -1,14 +1,3 @@
1
- ---
2
- title: Control Animation
3
- emoji: 🔥
4
- sdk: gradio
5
- sdk_version: 3.23.0
6
- app_file: app.py
7
- pipeline_tag: text-to-video
8
- tags:
9
- - jax-diffusers-event
10
- ---
11
-
12
  # Control Animation
13
 
14
- Our code uses [Text2Video-Zero](https://github.com/Picsart-AI-Research/Text2Video-Zero) and the [Diffusers](https://github.com/huggingface/diffusers) library as inspiration.
 
 
 
 
 
 
 
 
 
 
 
 
1
  # Control Animation
2
 
3
+ Our code uses [Text2Video-Zero](https://github.com/Picsart-AI-Research/Text2Video-Zero) and the [Diffusers](https://github.com/huggingface/diffusers) library as inspiration.
requirements.txt CHANGED
@@ -1,6 +1,13 @@
1
- git+https://github.com/huggingface/diffusers.git #diffusers==0.16.0.dev0
2
- -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
3
- jax[tpu]==0.4.5
 
 
 
 
 
 
 
4
  absl-py==1.4.0
5
  accelerate==0.16.0
6
  addict==2.4.0
@@ -39,7 +46,6 @@ fastapi==0.95.1
39
  ffmpy==0.3.0
40
  filelock==3.11.0
41
  flatbuffers==23.3.3
42
- flax==0.6.7
43
  fonttools==4.39.3
44
  frozenlist==1.3.3
45
  fsspec==2023.4.0
@@ -60,8 +66,6 @@ imageio-ffmpeg==0.4.2
60
  importlib-metadata==6.5.0
61
  importlib-resources==5.12.0
62
  invisible-watermark==0.1.5
63
- jax
64
- jaxlib==0.4.4
65
  Jinja2==3.1.2
66
  joblib==1.2.0
67
  jsonschema==4.17.3
@@ -89,7 +93,7 @@ onnx==1.13.1
89
  onnxruntime==1.14.1
90
  open-clip-torch==2.16.0
91
  opencv-contrib-python==4.7.0.72
92
- opencv-python==4.7.0.72
93
  opencv-python-headless==4.7.0.72
94
  opt-einsum==3.3.0
95
  optax==0.1.4
 
1
+ jax[cuda11_pip]
2
+ -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
3
+ jaxlib
4
+ flax
5
+ git+https://github.com/huggingface/diffusers@main
6
+ opencv-python
7
+ torch
8
+ #git+https://github.com/huggingface/diffusers.git #diffusers==0.16.0.dev0
9
+ #-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
10
+ #jax[tpu]==0.4.5
11
  absl-py==1.4.0
12
  accelerate==0.16.0
13
  addict==2.4.0
 
46
  ffmpy==0.3.0
47
  filelock==3.11.0
48
  flatbuffers==23.3.3
 
49
  fonttools==4.39.3
50
  frozenlist==1.3.3
51
  fsspec==2023.4.0
 
66
  importlib-metadata==6.5.0
67
  importlib-resources==5.12.0
68
  invisible-watermark==0.1.5
 
 
69
  Jinja2==3.1.2
70
  joblib==1.2.0
71
  jsonschema==4.17.3
 
93
  onnxruntime==1.14.1
94
  open-clip-torch==2.16.0
95
  opencv-contrib-python==4.7.0.72
96
+ # opencv-python==4.7.0.72
97
  opencv-python-headless==4.7.0.72
98
  opt-einsum==3.3.0
99
  optax==0.1.4
text_to_animation/model.py CHANGED
@@ -6,6 +6,9 @@ import jax.numpy as jnp
6
  import tomesd
7
  import jax
8
 
 
 
 
9
  from flax.training.common_utils import shard
10
  from flax.jax_utils import replicate
11
  from flax import jax_utils
@@ -381,31 +384,12 @@ class ControlAnimationModel:
381
  result, fps, path=path, watermark=gradio_utils.logo_name_to_path(watermark)
382
  )
383
 
384
- def generate_animation(
385
- self,
386
- prompt: str,
387
- model_link: str = "dreamlike-art/dreamlike-photoreal-2.0",
388
- is_safetensor: bool = False,
389
- motion_field_strength_x: int = 12,
390
- motion_field_strength_y: int = 12,
391
- t0: int = 44,
392
- t1: int = 47,
393
- n_prompt: str = "",
394
- chunk_size: int = 8,
395
- video_length: int = 8,
396
- merging_ratio: float = 0.0,
397
- seed: int = 0,
398
- resolution: int = 512,
399
- fps: int = 2,
400
- use_cf_attn: bool = True,
401
- use_motion_field: bool = True,
402
- smooth_bg: bool = False,
403
- smooth_bg_strength: float = 0.4,
404
- path: str = None,
405
- ):
406
- if is_safetensor and model_link[-len(".safetensors") :] == ".safetensors":
407
- pipe = utils.load_safetensors_model(model_link)
408
- return
409
 
410
  def generate_initial_frames(
411
  self,
@@ -419,9 +403,8 @@ class ControlAnimationModel:
419
  # batch_size: int = 1,
420
  cfg_scale: float = 7.0,
421
  seed: int = 0,
422
- ):
423
- print(f">>> prompt: {prompt}, model_link: {model_link}")
424
-
425
  pipe = StableDiffusionPipeline.from_pretrained(model_link)
426
 
427
  batch_size = 4
@@ -434,6 +417,34 @@ class ControlAnimationModel:
434
  width=width,
435
  height=height,
436
  guidance_scale=cfg_scale,
 
437
  ).images
 
 
 
438
 
439
- return images
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  import tomesd
7
  import jax
8
 
9
+ from PIL import Image
10
+ from typing import List
11
+
12
  from flax.training.common_utils import shard
13
  from flax.jax_utils import replicate
14
  from flax import jax_utils
 
384
  result, fps, path=path, watermark=gradio_utils.logo_name_to_path(watermark)
385
  )
386
 
387
+ @staticmethod
388
+ def to_pil_images(images: torch.Tensor) -> List[Image.Image]:
389
+ images = (images / 2 + 0.5).clamp(0, 1)
390
+ images = images.cpu().permute(0, 2, 3, 1).float().numpy()
391
+ images = np.round(images * 255).astype(np.uint8)
392
+ return [Image.fromarray(image) for image in images]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
393
 
394
  def generate_initial_frames(
395
  self,
 
403
  # batch_size: int = 1,
404
  cfg_scale: float = 7.0,
405
  seed: int = 0,
406
+ ) -> List[Image.Image]:
407
+ generator = torch.Generator(device=self.device).manual_seed(seed)
 
408
  pipe = StableDiffusionPipeline.from_pretrained(model_link)
409
 
410
  batch_size = 4
 
417
  width=width,
418
  height=height,
419
  guidance_scale=cfg_scale,
420
+ generator=generator,
421
  ).images
422
+ pil_images = self.to_pil_images(images)
423
+
424
+ return pil_images
425
 
426
+ def generate_animation(
427
+ self,
428
+ prompt: str,
429
+ model_link: str = "dreamlike-art/dreamlike-photoreal-2.0",
430
+ is_safetensor: bool = False,
431
+ motion_field_strength_x: int = 12,
432
+ motion_field_strength_y: int = 12,
433
+ t0: int = 44,
434
+ t1: int = 47,
435
+ n_prompt: str = "",
436
+ chunk_size: int = 8,
437
+ video_length: int = 8,
438
+ merging_ratio: float = 0.0,
439
+ seed: int = 0,
440
+ resolution: int = 512,
441
+ fps: int = 2,
442
+ use_cf_attn: bool = True,
443
+ use_motion_field: bool = True,
444
+ smooth_bg: bool = False,
445
+ smooth_bg_strength: float = 0.4,
446
+ path: str = None,
447
+ ):
448
+ if is_safetensor and model_link[-len(".safetensors") :] == ".safetensors":
449
+ pipe = utils.load_safetensors_model(model_link)
450
+ return
webui/app_control_animation.py CHANGED
@@ -106,9 +106,7 @@ def create_demo(model: ControlAnimationModel):
106
  with gr.Column(scale=3):
107
  initial_frames = gr.Gallery(
108
  label="Initial Frames", show_label=False
109
- ).style(
110
- columns=[2], rows=[2], object_fit="scale-down", height="auto"
111
- )
112
  initial_frames.select(select_initial_frame)
113
  select_frame_button = gr.Button(
114
  value="Select Initial Frame", variant="secondary"
 
106
  with gr.Column(scale=3):
107
  initial_frames = gr.Gallery(
108
  label="Initial Frames", show_label=False
109
+ ).style(columns=4, object_fit="contain")
 
 
110
  initial_frames.select(select_initial_frame)
111
  select_frame_button = gr.Button(
112
  value="Select Initial Frame", variant="secondary"