diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..3476c46f3aef9cad29bfd49275684b4add6d470c
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,167 @@
+outputs/
+processed/
+profile/
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+docs/.build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# IDE
+.idea/
+.vscode/
+
+# macos
+*.DS_Store
+#data/
+
+docs/.build
+
+# pytorch checkpoint
+*.pt
+
+# ignore any kernel build files
+.o
+.so
+
+# ignore python interface defition file
+.pyi
+
+# ignore coverage test file
+coverage.lcov
+coverage.xml
+
+# ignore testmon and coverage files
+.coverage
+.testmondata*
+
+pretrained
+samples
+cache_dir
+test_outputs
diff --git a/.isort.cfg b/.isort.cfg
new file mode 100644
index 0000000000000000000000000000000000000000..ccbf575fdbfacd185cf880431ad81462e0ae8fdf
--- /dev/null
+++ b/.isort.cfg
@@ -0,0 +1,7 @@
+[settings]
+line_length = 120
+multi_line_output=3
+include_trailing_comma = true
+ignore_comments = true
+profile = black
+honor_noqa = true
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9871e1184462f2069071ea8db96495b20059d645
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,39 @@
+repos:
+
+ - repo: https://github.com/PyCQA/autoflake
+ rev: v2.2.1
+ hooks:
+ - id: autoflake
+ name: autoflake (python)
+ args: ['--in-place', '--remove-unused-variables', '--remove-all-unused-imports', '--ignore-init-module-imports']
+
+ - repo: https://github.com/pycqa/isort
+ rev: 5.12.0
+ hooks:
+ - id: isort
+ name: sort all imports (python)
+
+ - repo: https://github.com/psf/black-pre-commit-mirror
+ rev: 23.9.1
+ hooks:
+ - id: black
+ name: black formatter
+ args: ['--line-length=120', '--target-version=py37', '--target-version=py38', '--target-version=py39','--target-version=py310']
+
+ - repo: https://github.com/pre-commit/mirrors-clang-format
+ rev: v13.0.1
+ hooks:
+ - id: clang-format
+ name: clang formatter
+ types_or: [c++, c]
+
+ - repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v4.3.0
+ hooks:
+ - id: check-yaml
+ - id: check-merge-conflict
+ - id: check-case-conflict
+ - id: trailing-whitespace
+ - id: end-of-file-fixer
+ - id: mixed-line-ending
+ args: ['--fix=lf']
diff --git a/app.py b/app.py
index 420354c347a399f11b2a1072f587fdca1927e421..a4d31e717babb8d82ae6a2d431dcdae395dd3c2c 100644
--- a/app.py
+++ b/app.py
@@ -2,131 +2,107 @@ import os
os.environ["GRADIO_TEMP_DIR"] = os.path.join(os.getcwd(), ".tmp_outputs")
-import torch
-from openai import OpenAI
-from time import time
-import tempfile
-import uuid
import logging
+import uuid
+
+import GPUtil
import gradio as gr
-from videosys import CogVideoConfig, VideoSysEngine
-from videosys.models.cogvideo.pipeline import CogVideoPABConfig
import psutil
-import GPUtil
-
+import torch
+from videosys import CogVideoXConfig, CogVideoXPABConfig, VideoSysEngine
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
-dtype = torch.bfloat16
-sys_prompt = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets.
-
-For example , outputting " a beautiful morning in the woods with the sun peaking through the trees " will trigger your partner bot to output an video of a forest morning , as described. You will be prompted by people looking to create detailed , amazing videos. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive.
-There are a few rules to follow:
-
-You will only ever output a single video description per user request.
-
-When modifications are requested , you should not simply make the description longer . You should refactor the entire description to integrate the suggestions.
-Other times the user will not want modifications , but instead want a new image . In this case , you should ignore your previous conversation with the user.
-
-Video descriptions must have the same num of words as examples below. Extra words will be ignored.
-"""
+dtype = torch.float16
-def convert_prompt(prompt: str, retry_times: int = 3) -> str:
- if not os.environ.get("OPENAI_API_KEY"):
- return prompt
- client = OpenAI()
- text = prompt.strip()
-
- for i in range(retry_times):
- response = client.chat.completions.create(
- messages=[
- {"role": "system", "content": sys_prompt},
- {
- "role": "user",
- "content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : "a girl is on the beach"',
- },
- {
- "role": "assistant",
- "content": "A radiant woman stands on a deserted beach, arms outstretched, wearing a beige trench coat, white blouse, light blue jeans, and chic boots, against a backdrop of soft sky and sea. Moments later, she is seen mid-twirl, arms exuberant, with the lighting suggesting dawn or dusk. Then, she runs along the beach, her attire complemented by an off-white scarf and black ankle boots, the tranquil sea behind her. Finally, she holds a paper airplane, her pose reflecting joy and freedom, with the ocean's gentle waves and the sky's soft pastel hues enhancing the serene ambiance.",
- },
- {
- "role": "user",
- "content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : "A man jogging on a football field"',
- },
- {
- "role": "assistant",
- "content": "A determined man in athletic attire, including a blue long-sleeve shirt, black shorts, and blue socks, jogs around a snow-covered soccer field, showcasing his solitary exercise in a quiet, overcast setting. His long dreadlocks, focused expression, and the serene winter backdrop highlight his dedication to fitness. As he moves, his attire, consisting of a blue sports sweatshirt, black athletic pants, gloves, and sneakers, grips the snowy ground. He is seen running past a chain-link fence enclosing the playground area, with a basketball hoop and children's slide, suggesting a moment of solitary exercise amidst the empty field.",
- },
- {
- "role": "user",
- "content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : " A woman is dancing, HD footage, close-up"',
- },
- {
- "role": "assistant",
- "content": "A young woman with her hair in an updo and wearing a teal hoodie stands against a light backdrop, initially looking over her shoulder with a contemplative expression. She then confidently makes a subtle dance move, suggesting rhythm and movement. Next, she appears poised and focused, looking directly at the camera. Her expression shifts to one of introspection as she gazes downward slightly. Finally, she dances with confidence, her left hand over her heart, symbolizing a poignant moment, all while dressed in the same teal hoodie against a plain, light-colored background.",
- },
- {
- "role": "user",
- "content": f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: "{text}"',
- },
- ],
- model="glm-4-0520",
- temperature=0.01,
- top_p=0.7,
- stream=False,
- max_tokens=250,
- )
- if response.choices:
- return response.choices[0].message.content
- return prompt
-def load_model(enable_video_sys=False, pab_threshold=[100, 850], pab_gap=2):
- pab_config = CogVideoPABConfig(full_threshold=pab_threshold, full_gap=pab_gap)
- config = CogVideoConfig(world_size=1, enable_pab=enable_video_sys, pab_config=pab_config)
+def load_model(enable_video_sys=False, pab_threshold=[100, 850], pab_range=2):
+ pab_config = CogVideoXPABConfig(spatial_threshold=pab_threshold, spatial_range=pab_range)
+ config = CogVideoXConfig(world_size=1, enable_pab=enable_video_sys, pab_config=pab_config)
engine = VideoSysEngine(config)
return engine
+
def generate(engine, prompt, num_inference_steps=50, guidance_scale=6.0):
- try:
- video = engine.generate(prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale).video[0]
+ video = engine.generate(prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale).video[0]
- with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_file:
- temp_file.name
- unique_filename = f"{uuid.uuid4().hex}.mp4"
- output_path = os.path.join("./temp_outputs", unique_filename)
+ unique_filename = f"{uuid.uuid4().hex}.mp4"
+ output_path = os.path.join("./.tmp_outputs", unique_filename)
- engine.save_video(video, output_path)
- return output_path
- except Exception as e:
- logger.error(f"An error occurred: {str(e)}")
- return None
+ engine.save_video(video, output_path)
+ return output_path
def get_server_status():
cpu_percent = psutil.cpu_percent()
memory = psutil.virtual_memory()
- disk = psutil.disk_usage('/')
+ disk = psutil.disk_usage("/")
gpus = GPUtil.getGPUs()
gpu_info = []
for gpu in gpus:
- gpu_info.append({
- 'id': gpu.id,
- 'name': gpu.name,
- 'load': f"{gpu.load*100:.1f}%",
- 'memory_used': f"{gpu.memoryUsed}MB",
- 'memory_total': f"{gpu.memoryTotal}MB"
- })
-
+ gpu_info.append(
+ {
+ "id": gpu.id,
+ "name": gpu.name,
+ "load": f"{gpu.load*100:.1f}%",
+ "memory_used": f"{gpu.memoryUsed}MB",
+ "memory_total": f"{gpu.memoryTotal}MB",
+ }
+ )
+
+ return {"cpu": f"{cpu_percent}%", "memory": f"{memory.percent}%", "disk": f"{disk.percent}%", "gpu": gpu_info}
+
+
+def generate_vanilla(prompt, num_inference_steps, guidance_scale, progress=gr.Progress(track_tqdm=True)):
+ engine = load_model()
+ video_path = generate(engine, prompt, num_inference_steps, guidance_scale)
+ return video_path
+
+
+def generate_vs(
+ prompt,
+ num_inference_steps,
+ guidance_scale,
+ threshold_start,
+ threshold_end,
+ gap,
+ progress=gr.Progress(track_tqdm=True),
+):
+ threshold = [int(threshold_end), int(threshold_start)]
+ gap = int(gap)
+ engine = load_model(enable_video_sys=True, pab_threshold=threshold, pab_range=gap)
+ video_path = generate(engine, prompt, num_inference_steps, guidance_scale)
+ return video_path
+
+
+def get_server_status():
+ cpu_percent = psutil.cpu_percent()
+ memory = psutil.virtual_memory()
+ disk = psutil.disk_usage("/")
+ try:
+ gpus = GPUtil.getGPUs()
+ if gpus:
+ gpu = gpus[0]
+ gpu_memory = f"{gpu.memoryUsed}/{gpu.memoryTotal}MB ({gpu.memoryUtil*100:.1f}%)"
+ else:
+ gpu_memory = "No GPU found"
+ except:
+ gpu_memory = "GPU information unavailable"
+
return {
- 'cpu': f"{cpu_percent}%",
- 'memory': f"{memory.percent}%",
- 'disk': f"{disk.percent}%",
- 'gpu': gpu_info
+ "cpu": f"{cpu_percent}%",
+ "memory": f"{memory.percent}%",
+ "disk": f"{disk.percent}%",
+ "gpu_memory": gpu_memory,
}
+def update_server_status():
+ status = get_server_status()
+ return (status["cpu"], status["memory"], status["disk"], status["gpu_memory"])
+
css = """
body {
@@ -137,16 +113,17 @@ body {
padding: 20px;
}
+
.container {
display: flex;
flex-direction: column;
- gap: 20px;
+ gap: 10px;
}
.row {
display: flex;
flex-wrap: wrap;
- gap: 18px;
+ gap: 10px;
}
.column {
@@ -186,12 +163,6 @@ body {
font-size: 0.9em !important;
line-height: 1.2 !important;
}
-.server-status button {
- padding: 1px 8px !important;
- height: 22px !important;
- font-size: 0.9em !important;
- margin-top: 2px !important;
-}
.server-status .textbox {
gap: 0 !important;
}
@@ -215,150 +186,76 @@ body {
"""
with gr.Blocks(css=css) as demo:
- gr.HTML("""
+ gr.HTML(
+ """
- VideoSys Huggingface Space🤗
+ VideoSys for CogVideoX🤗
🌐 Github:
https://github.com/NUS-HPC-AI-Lab/VideoSys
-
- ⚠️ This demo is for academic research and experiential use only.
+
+ ⚠️ This demo is for academic research and experiential use only.
Users should strictly adhere to local laws and ethics.
-
+
💡 This demo only demonstrates single-device inference. To experience the full power of VideoSys, please deploy it with multiple devices.
- """)
+ """
+ )
with gr.Row():
with gr.Column():
- prompt = gr.Textbox(label="Prompt (Less than 200 Words)", value="Sunset over the sea.", lines=5)
- with gr.Row():
- gr.Markdown(
- "✨Upon pressing the enhanced prompt button, we will use [GLM-4 Model](https://github.com/THUDM/GLM-4) to polish the prompt and overwrite the original one."
- )
- enhance_button = gr.Button("✨ Enhance Prompt(Optional)")
+ prompt = gr.Textbox(label="Prompt (Less than 200 Words)", value="Sunset over the sea.", lines=4)
with gr.Column():
- gr.Markdown(
- "**Optional Parameters** (default values are recommended)
"
- "Turn Inference Steps larger if you want more detailed video, but it will be slower.
"
- "50 steps are recommended for most cases. will cause 120 seconds for inference.
"
- )
+ gr.Markdown("**Generation Parameters**
")
with gr.Row():
num_inference_steps = gr.Number(label="Inference Steps", value=50)
guidance_scale = gr.Number(label="Guidance Scale", value=6.0)
- pab_gap = gr.Number(label="PAB Gap", value=2, precision=0)
- pab_threshold = gr.Textbox(label="PAB Threshold", value="100,850", lines=1)
with gr.Row():
- generate_button = gr.Button("🎬 Generate Video")
+ pab_range = gr.Number(
+ label="PAB Broadcast Range", value=2, precision=0, info="Broadcast timesteps range."
+ )
+ pab_threshold_start = gr.Number(label="PAB Start Timestep", value=850, info="Start from step 1000.")
+ pab_threshold_end = gr.Number(label="PAB End Timestep", value=100, info="End at step 0.")
+ with gr.Row():
generate_button_vs = gr.Button("⚡️ Generate Video with VideoSys (Faster)")
+ generate_button = gr.Button("🎬 Generate Video (Original)")
with gr.Column(elem_classes="server-status"):
gr.Markdown("#### Server Status")
-
+
with gr.Row():
cpu_status = gr.Textbox(label="CPU", scale=1)
memory_status = gr.Textbox(label="Memory", scale=1)
-
+
with gr.Row():
disk_status = gr.Textbox(label="Disk", scale=1)
gpu_status = gr.Textbox(label="GPU Memory", scale=1)
-
+
with gr.Row():
- refresh_button = gr.Button("Refresh", size="sm")
+ refresh_button = gr.Button("Refresh")
with gr.Column():
- with gr.Row():
- video_output = gr.Video(label="CogVideoX", width=720, height=480)
- with gr.Row():
- download_video_button = gr.File(label="📥 Download Video", visible=False)
- elapsed_time = gr.Textbox(label="Elapsed Time", value="0s", visible=False)
with gr.Row():
video_output_vs = gr.Video(label="CogVideoX with VideoSys", width=720, height=480)
with gr.Row():
- download_video_button_vs = gr.File(label="📥 Download Video", visible=False)
- elapsed_time_vs = gr.Textbox(label="Elapsed Time", value="0s", visible=False)
- # with gr.Column():
- # task_status = gr.Textbox(label="任务状态", visible=False)
-
-
-
-
- def generate_vanilla(prompt, num_inference_steps, guidance_scale, progress=gr.Progress(track_tqdm=True)):
- engine = load_model()
- t = time()
- video_path = generate(engine, prompt, num_inference_steps, guidance_scale)
- elapsed_time = time() - t
- video_update = gr.update(visible=True, value=video_path)
- elapsed_time = gr.update(visible=True, value=f"{elapsed_time:.2f}s")
-
- return video_path, video_update, elapsed_time
-
- def generate_vs(prompt, num_inference_steps, guidance_scale, threshold, gap, progress=gr.Progress(track_tqdm=True)):
- threshold = [int(i) for i in threshold.split(",")]
- gap = int(gap)
- engine = load_model(enable_video_sys=True, pab_threshold=threshold, pab_gap=gap)
- t = time()
- video_path = generate(engine, prompt, num_inference_steps, guidance_scale)
- elapsed_time = time() - t
- video_update = gr.update(visible=True, value=video_path)
- elapsed_time = gr.update(visible=True, value=f"{elapsed_time:.2f}s")
-
- return video_path, video_update, elapsed_time
-
- def enhance_prompt_func(prompt):
- return convert_prompt(prompt, retry_times=1)
-
- def get_server_status():
- cpu_percent = psutil.cpu_percent()
- memory = psutil.virtual_memory()
- disk = psutil.disk_usage('/')
- try:
- gpus = GPUtil.getGPUs()
- if gpus:
- gpu = gpus[0]
- gpu_memory = f"{gpu.memoryUsed}/{gpu.memoryTotal}MB ({gpu.memoryUtil*100:.1f}%)"
- else:
- gpu_memory = "No GPU found"
- except:
- gpu_memory = "GPU information unavailable"
-
- return {
- 'cpu': f"{cpu_percent}%",
- 'memory': f"{memory.percent}%",
- 'disk': f"{disk.percent}%",
- 'gpu_memory': gpu_memory
- }
-
-
- def update_server_status():
- status = get_server_status()
- return (
- status['cpu'],
- status['memory'],
- status['disk'],
- status['gpu_memory']
- )
+ video_output = gr.Video(label="CogVideoX", width=720, height=480)
-
generate_button.click(
generate_vanilla,
inputs=[prompt, num_inference_steps, guidance_scale],
- outputs=[video_output, download_video_button, elapsed_time],
+ outputs=[video_output],
)
generate_button_vs.click(
generate_vs,
- inputs=[prompt, num_inference_steps, guidance_scale, pab_threshold, pab_gap],
- outputs=[video_output_vs, download_video_button_vs, elapsed_time_vs],
+ inputs=[prompt, num_inference_steps, guidance_scale, pab_threshold_start, pab_threshold_end, pab_range],
+ outputs=[video_output_vs],
)
- enhance_button.click(enhance_prompt_func, inputs=[prompt], outputs=[prompt])
-
-
refresh_button.click(update_server_status, outputs=[cpu_status, memory_status, disk_status, gpu_status])
demo.load(update_server_status, outputs=[cpu_status, memory_status, disk_status, gpu_status], every=1)
if __name__ == "__main__":
demo.queue(max_size=10, default_concurrency_limit=1)
- demo.launch()
\ No newline at end of file
+ demo.launch()
diff --git a/docs/dsp.md b/docs/dsp.md
deleted file mode 100644
index 2a08cbc44db772909cb7763d449f0a6df51f10bb..0000000000000000000000000000000000000000
--- a/docs/dsp.md
+++ /dev/null
@@ -1,25 +0,0 @@
-# DSP
-
-paper: https://arxiv.org/abs/2403.10266
-
-![dsp_overview](../assets/figures/dsp_overview.png)
-
-
-DSP (Dynamic Sequence Parallelism) is a novel, elegant and super efficient sequence parallelism for [OpenSora](https://github.com/hpcaitech/Open-Sora), [Latte](https://github.com/Vchitect/Latte) and other multi-dimensional transformer architecture.
-
-The key idea is to dynamically switch the parallelism dimension according to the current computation stage, leveraging the potential characteristics of multi-dimensional transformers. Compared with splitting head and sequence dimension as previous methods, it can reduce at least 75% of communication cost.
-
-It achieves **3x** speed for training and **2x** speed for inference in OpenSora compared with sota sequence parallelism ([DeepSpeed Ulysses](https://arxiv.org/abs/2309.14509)). For a 10s (80 frames) of 512x512 video, the inference latency of OpenSora is:
-
-| Method | 1xH800 | 8xH800 (DS Ulysses) | 8xH800 (DSP) |
-| ------ | ------ | ------ | ------ |
-| Latency(s) | 106 | 45 | 22 |
-
-The following is DSP's end-to-end throughput for training of OpenSora:
-
-![dsp_overview](../assets/figures/dsp_exp.png)
-
-
-### Usage
-
-DSP is currently supported for: OpenSora, OpenSoraPlan and Latte. To enable DSP, you just need to launch with multiple GPUs.
diff --git a/docs/pab.md b/docs/pab.md
deleted file mode 100644
index 0de7b98139e52b17edc5fd43e5aad1a5f9a01525..0000000000000000000000000000000000000000
--- a/docs/pab.md
+++ /dev/null
@@ -1,121 +0,0 @@
-# Pyramid Attention Broadcast(PAB)
-
-[[paper](https://arxiv.org/abs/2408.12588)][[blog](https://arxiv.org/abs/2403.10266)]
-
-Pyramid Attention Broadcast(PAB)(#pyramid-attention-broadcastpab)
-- [Pyramid Attention Broadcast(PAB)](#pyramid-attention-broadcastpab)
- - [Insights](#insights)
- - [Pyramid Attention Broadcast (PAB) Mechanism](#pyramid-attention-broadcast-pab-mechanism)
- - [Experimental Results](#experimental-results)
- - [Usage](#usage)
- - [Supported Models](#supported-models)
- - [Configuration for PAB](#configuration-for-pab)
- - [Parameters](#parameters)
- - [Example Configuration](#example-configuration)
-
-
-We introduce Pyramid Attention Broadcast (PAB), the first approach that achieves real-time DiT-based video generation. By mitigating redundant attention computation, PAB achieves up to 21.6 FPS with 10.6x acceleration, without sacrificing quality across popular DiT-based video generation models including Open-Sora, Open-Sora-Plan, and Latte. Notably, as a training-free approach, PAB can enpower any future DiT-based video generation models with real-time capabilities.
-
-## Insights
-
-![method](../assets/figures/pab_motivation.png)
-
-Our study reveals two key insights of three **attention mechanisms** within video diffusion transformers:
-- First, attention differences across time steps exhibit a U-shaped pattern, with significant variations occurring during the first and last 15% of steps, while the middle 70% of steps show very stable, minor differences.
-- Second, within the stable middle segment, the variability differs among attention types:
- - **Spatial attention** varies the most, involving high-frequency elements like edges and textures;
- - **Temporal attention** exhibits mid-frequency variations related to movements and dynamics in videos;
- - **Cross-modal attention** is the most stable, linking text with video content, analogous to low-frequency signals reflecting textual semantics.
-
-## Pyramid Attention Broadcast (PAB) Mechanism
-
-![method](../assets/figures/pab_method.png)
-
-Building on these insights, we propose a **pyramid attention broadcast(PAB)** mechanism to minimize unnecessary computations and optimize the utility of each attention module, as shown in Figure[xx figure] below.
-
-In the middle segment, we broadcast one step's attention outputs to its subsequent several steps, thereby significantly reducing the computational cost on attention modules.
-
-For more efficient broadcast and minimum influence to effect, we set varied broadcast ranges for different attentions based on their stability and differences.
-**The smaller the variation in attention, the broader the potential broadcast range.**
-
-
-## Experimental Results
-Here are the results of our experiments, more results are shown in https://oahzxl.github.io/PAB:
-
-![pab_vis](../assets/figures/pab_vis.png)
-
-
-## Usage
-
-### Supported Models
-
-PAB currently supports Open-Sora, Open-Sora-Plan, and Latte.
-
-### Configuration for PAB
-
-To efficiently use the Pyramid Attention Broadcast (PAB) mechanism, configure the following parameters to control the broadcasting for different attention types. This helps reduce computational costs by skipping certain steps based on attention stability.
-
-#### Parameters
-
-- **spatial_broadcast**: Enable or disable broadcasting for spatial attention.
- - Type: `True` or `False`
-
-- **spatial_threshold**: Set the range of diffusion steps within which spatial attention is applied.
- - Format: `[min_value, max_value]`
-
-- **spatial_gap**: Number of blocks in model to skip during broadcasting for spatial attention.
- - Type: Integer
-
-- **temporal_broadcast**: Enable or disable broadcasting for temporal attention.
- - Type: `True` or `False`
-
-- **temporal_threshold**: Set the range of diffusion steps within which temporal attention is applied.
- - Format: `[min_value, max_value]`
-
-- **temporal_gap**: Number of steps to skip during broadcasting for temporal attention.
- - Type: Integer
-
-- **cross_broadcast**: Enable or disable broadcasting for cross-modal attention.
- - Type: `True` or `False`
-
-- **cross_threshold**: Set the range of diffusion steps within which cross-modal attention is applied.
- - Format: `[min_value, max_value]`
-
-- **cross_gap**: Number of steps to skip during broadcasting for cross-modal attention.
- - Type: Integer
-
-#### Example Configuration
-
-```yaml
-spatial_broadcast: True
-spatial_threshold: [100, 800]
-spatial_gap: 2
-
-temporal_broadcast: True
-temporal_threshold: [100, 800]
-temporal_gap: 3
-
-cross_broadcast: True
-cross_threshold: [100, 900]
-cross_gap: 5
-```
-
-Explanation:
-
-- **Spatial Attention**:
- - Broadcasting enabled (`spatial_broadcast: True`)
- - Applied within the threshold range of 100 to 800
- - Skips every 2 steps (`spatial_gap: 2`)
- - Active within the first 28 steps (`spatial_block: [0, 28]`)
-
-- **Temporal Attention**:
- - Broadcasting enabled (`temporal_broadcast: True`)
- - Applied within the threshold range of 100 to 800
- - Skips every 3 steps (`temporal_gap: 3`)
-
-- **Cross-Modal Attention**:
- - Broadcasting enabled (`cross_broadcast: True`)
- - Applied within the threshold range of 100 to 900
- - Skips every 5 steps (`cross_gap: 5`)
-
-Adjust these settings based on your specific needs to optimize the performance of each attention mechanism.
diff --git a/eval/pab/commom_metrics/README.md b/eval/pab/commom_metrics/README.md
deleted file mode 100644
index 1e595d9229094f1f165146b680843f93138a577d..0000000000000000000000000000000000000000
--- a/eval/pab/commom_metrics/README.md
+++ /dev/null
@@ -1,6 +0,0 @@
-Common metrics
-
-Include LPIPS, PSNR and SSIM.
-
-The code is adapted from [common_metrics_on_video_quality
-](https://github.com/JunyaoHu/common_metrics_on_video_quality).
diff --git a/eval/pab/commom_metrics/calculate_lpips.py b/eval/pab/commom_metrics/calculate_lpips.py
deleted file mode 100644
index 9d9efcf24235f6d91701541ab8acfa7279bbecf4..0000000000000000000000000000000000000000
--- a/eval/pab/commom_metrics/calculate_lpips.py
+++ /dev/null
@@ -1,97 +0,0 @@
-import lpips
-import numpy as np
-import torch
-
-spatial = True # Return a spatial map of perceptual distance.
-
-# Linearly calibrated models (LPIPS)
-loss_fn = lpips.LPIPS(net="alex", spatial=spatial) # Can also set net = 'squeeze' or 'vgg'
-# loss_fn = lpips.LPIPS(net='alex', spatial=spatial, lpips=False) # Can also set net = 'squeeze' or 'vgg'
-
-
-def trans(x):
- # if greyscale images add channel
- if x.shape[-3] == 1:
- x = x.repeat(1, 1, 3, 1, 1)
-
- # value range [0, 1] -> [-1, 1]
- x = x * 2 - 1
-
- return x
-
-
-def calculate_lpips(videos1, videos2, device):
- # image should be RGB, IMPORTANT: normalized to [-1,1]
-
- assert videos1.shape == videos2.shape
-
- # videos [batch_size, timestamps, channel, h, w]
-
- # support grayscale input, if grayscale -> channel*3
- # value range [0, 1] -> [-1, 1]
- videos1 = trans(videos1)
- videos2 = trans(videos2)
-
- lpips_results = []
-
- for video_num in range(videos1.shape[0]):
- # get a video
- # video [timestamps, channel, h, w]
- video1 = videos1[video_num]
- video2 = videos2[video_num]
-
- lpips_results_of_a_video = []
- for clip_timestamp in range(len(video1)):
- # get a img
- # img [timestamps[x], channel, h, w]
- # img [channel, h, w] tensor
-
- img1 = video1[clip_timestamp].unsqueeze(0).to(device)
- img2 = video2[clip_timestamp].unsqueeze(0).to(device)
-
- loss_fn.to(device)
-
- # calculate lpips of a video
- lpips_results_of_a_video.append(loss_fn.forward(img1, img2).mean().detach().cpu().tolist())
- lpips_results.append(lpips_results_of_a_video)
-
- lpips_results = np.array(lpips_results)
-
- lpips = {}
- lpips_std = {}
-
- for clip_timestamp in range(len(video1)):
- lpips[clip_timestamp] = np.mean(lpips_results[:, clip_timestamp])
- lpips_std[clip_timestamp] = np.std(lpips_results[:, clip_timestamp])
-
- result = {
- "value": lpips,
- "value_std": lpips_std,
- "video_setting": video1.shape,
- "video_setting_name": "time, channel, heigth, width",
- }
-
- return result
-
-
-# test code / using example
-
-
-def main():
- NUMBER_OF_VIDEOS = 8
- VIDEO_LENGTH = 50
- CHANNEL = 3
- SIZE = 64
- videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
- videos2 = torch.ones(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
- device = torch.device("cuda")
- # device = torch.device("cpu")
-
- import json
-
- result = calculate_lpips(videos1, videos2, device)
- print(json.dumps(result, indent=4))
-
-
-if __name__ == "__main__":
- main()
diff --git a/eval/pab/commom_metrics/calculate_psnr.py b/eval/pab/commom_metrics/calculate_psnr.py
deleted file mode 100644
index 416bc48a94e5fa5c242cd92a89c9b165e522b86b..0000000000000000000000000000000000000000
--- a/eval/pab/commom_metrics/calculate_psnr.py
+++ /dev/null
@@ -1,90 +0,0 @@
-import math
-
-import numpy as np
-import torch
-
-
-def img_psnr(img1, img2):
- # [0,1]
- # compute mse
- # mse = np.mean((img1-img2)**2)
- mse = np.mean((img1 / 1.0 - img2 / 1.0) ** 2)
- # compute psnr
- if mse < 1e-10:
- return 100
- psnr = 20 * math.log10(1 / math.sqrt(mse))
- return psnr
-
-
-def trans(x):
- return x
-
-
-def calculate_psnr(videos1, videos2):
- # videos [batch_size, timestamps, channel, h, w]
-
- assert videos1.shape == videos2.shape
-
- videos1 = trans(videos1)
- videos2 = trans(videos2)
-
- psnr_results = []
-
- for video_num in range(videos1.shape[0]):
- # get a video
- # video [timestamps, channel, h, w]
- video1 = videos1[video_num]
- video2 = videos2[video_num]
-
- psnr_results_of_a_video = []
- for clip_timestamp in range(len(video1)):
- # get a img
- # img [timestamps[x], channel, h, w]
- # img [channel, h, w] numpy
-
- img1 = video1[clip_timestamp].numpy()
- img2 = video2[clip_timestamp].numpy()
-
- # calculate psnr of a video
- psnr_results_of_a_video.append(img_psnr(img1, img2))
-
- psnr_results.append(psnr_results_of_a_video)
-
- psnr_results = np.array(psnr_results)
-
- psnr = {}
- psnr_std = {}
-
- for clip_timestamp in range(len(video1)):
- psnr[clip_timestamp] = np.mean(psnr_results[:, clip_timestamp])
- psnr_std[clip_timestamp] = np.std(psnr_results[:, clip_timestamp])
-
- result = {
- "value": psnr,
- "value_std": psnr_std,
- "video_setting": video1.shape,
- "video_setting_name": "time, channel, heigth, width",
- }
-
- return result
-
-
-# test code / using example
-
-
-def main():
- NUMBER_OF_VIDEOS = 8
- VIDEO_LENGTH = 50
- CHANNEL = 3
- SIZE = 64
- videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
- videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
-
- import json
-
- result = calculate_psnr(videos1, videos2)
- print(json.dumps(result, indent=4))
-
-
-if __name__ == "__main__":
- main()
diff --git a/eval/pab/commom_metrics/calculate_ssim.py b/eval/pab/commom_metrics/calculate_ssim.py
deleted file mode 100644
index aa78bd5ca2cd826f04b4f42e7dd5c53d61ed7231..0000000000000000000000000000000000000000
--- a/eval/pab/commom_metrics/calculate_ssim.py
+++ /dev/null
@@ -1,116 +0,0 @@
-import cv2
-import numpy as np
-import torch
-
-
-def ssim(img1, img2):
- C1 = 0.01**2
- C2 = 0.03**2
- img1 = img1.astype(np.float64)
- img2 = img2.astype(np.float64)
- kernel = cv2.getGaussianKernel(11, 1.5)
- window = np.outer(kernel, kernel.transpose())
- mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
- mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
- mu1_sq = mu1**2
- mu2_sq = mu2**2
- mu1_mu2 = mu1 * mu2
- sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
- sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
- sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
- ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
- return ssim_map.mean()
-
-
-def calculate_ssim_function(img1, img2):
- # [0,1]
- # ssim is the only metric extremely sensitive to gray being compared to b/w
- if not img1.shape == img2.shape:
- raise ValueError("Input images must have the same dimensions.")
- if img1.ndim == 2:
- return ssim(img1, img2)
- elif img1.ndim == 3:
- if img1.shape[0] == 3:
- ssims = []
- for i in range(3):
- ssims.append(ssim(img1[i], img2[i]))
- return np.array(ssims).mean()
- elif img1.shape[0] == 1:
- return ssim(np.squeeze(img1), np.squeeze(img2))
- else:
- raise ValueError("Wrong input image dimensions.")
-
-
-def trans(x):
- return x
-
-
-def calculate_ssim(videos1, videos2):
- # videos [batch_size, timestamps, channel, h, w]
-
- assert videos1.shape == videos2.shape
-
- videos1 = trans(videos1)
- videos2 = trans(videos2)
-
- ssim_results = []
-
- for video_num in range(videos1.shape[0]):
- # get a video
- # video [timestamps, channel, h, w]
- video1 = videos1[video_num]
- video2 = videos2[video_num]
-
- ssim_results_of_a_video = []
- for clip_timestamp in range(len(video1)):
- # get a img
- # img [timestamps[x], channel, h, w]
- # img [channel, h, w] numpy
-
- img1 = video1[clip_timestamp].numpy()
- img2 = video2[clip_timestamp].numpy()
-
- # calculate ssim of a video
- ssim_results_of_a_video.append(calculate_ssim_function(img1, img2))
-
- ssim_results.append(ssim_results_of_a_video)
-
- ssim_results = np.array(ssim_results)
-
- ssim = {}
- ssim_std = {}
-
- for clip_timestamp in range(len(video1)):
- ssim[clip_timestamp] = np.mean(ssim_results[:, clip_timestamp])
- ssim_std[clip_timestamp] = np.std(ssim_results[:, clip_timestamp])
-
- result = {
- "value": ssim,
- "value_std": ssim_std,
- "video_setting": video1.shape,
- "video_setting_name": "time, channel, heigth, width",
- }
-
- return result
-
-
-# test code / using example
-
-
-def main():
- NUMBER_OF_VIDEOS = 8
- VIDEO_LENGTH = 50
- CHANNEL = 3
- SIZE = 64
- videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
- videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
- torch.device("cuda")
-
- import json
-
- result = calculate_ssim(videos1, videos2)
- print(json.dumps(result, indent=4))
-
-
-if __name__ == "__main__":
- main()
diff --git a/eval/pab/commom_metrics/eval.py b/eval/pab/commom_metrics/eval.py
deleted file mode 100644
index 5f300510d6ecedd5875486a8d260a6b4bbd22327..0000000000000000000000000000000000000000
--- a/eval/pab/commom_metrics/eval.py
+++ /dev/null
@@ -1,160 +0,0 @@
-import argparse
-import os
-
-import imageio
-import torch
-import torchvision.transforms.functional as F
-import tqdm
-from calculate_lpips import calculate_lpips
-from calculate_psnr import calculate_psnr
-from calculate_ssim import calculate_ssim
-
-
-def load_videos(directory, video_ids, file_extension):
- videos = []
- for video_id in video_ids:
- video_path = os.path.join(directory, f"{video_id}.{file_extension}")
- if os.path.exists(video_path):
- video = load_video(video_path) # Define load_video based on how videos are stored
- videos.append(video)
- else:
- raise ValueError(f"Video {video_id}.{file_extension} not found in {directory}")
- return videos
-
-
-def load_video(video_path):
- """
- Load a video from the given path and convert it to a PyTorch tensor.
- """
- # Read the video using imageio
- reader = imageio.get_reader(video_path, "ffmpeg")
-
- # Extract frames and convert to a list of tensors
- frames = []
- for frame in reader:
- # Convert the frame to a tensor and permute the dimensions to match (C, H, W)
- frame_tensor = torch.tensor(frame).cuda().permute(2, 0, 1)
- frames.append(frame_tensor)
-
- # Stack the list of tensors into a single tensor with shape (T, C, H, W)
- video_tensor = torch.stack(frames)
-
- return video_tensor
-
-
-def resize_video(video, target_height, target_width):
- resized_frames = []
- for frame in video:
- resized_frame = F.resize(frame, [target_height, target_width])
- resized_frames.append(resized_frame)
- return torch.stack(resized_frames)
-
-
-def preprocess_eval_video(eval_video, generated_video_shape):
- T_gen, _, H_gen, W_gen = generated_video_shape
- T_eval, _, H_eval, W_eval = eval_video.shape
-
- if T_eval < T_gen:
- raise ValueError(f"Eval video time steps ({T_eval}) are less than generated video time steps ({T_gen}).")
-
- if H_eval < H_gen or W_eval < W_gen:
- # Resize the video maintaining the aspect ratio
- resize_height = max(H_gen, int(H_gen * (H_eval / W_eval)))
- resize_width = max(W_gen, int(W_gen * (W_eval / H_eval)))
- eval_video = resize_video(eval_video, resize_height, resize_width)
- # Recalculate the dimensions
- T_eval, _, H_eval, W_eval = eval_video.shape
-
- # Center crop
- start_h = (H_eval - H_gen) // 2
- start_w = (W_eval - W_gen) // 2
- cropped_video = eval_video[:T_gen, :, start_h : start_h + H_gen, start_w : start_w + W_gen]
-
- return cropped_video
-
-
-def main(args):
- device = "cuda"
- gt_video_dir = args.gt_video_dir
- generated_video_dir = args.generated_video_dir
-
- video_ids = []
- file_extension = "mp4"
- for f in os.listdir(generated_video_dir):
- if f.endswith(f".{file_extension}"):
- video_ids.append(f.replace(f".{file_extension}", ""))
- if not video_ids:
- raise ValueError("No videos found in the generated video dataset. Exiting.")
-
- print(f"Find {len(video_ids)} videos")
- prompt_interval = 1
- batch_size = 16
- calculate_lpips_flag, calculate_psnr_flag, calculate_ssim_flag = True, True, True
-
- lpips_results = []
- psnr_results = []
- ssim_results = []
-
- total_len = len(video_ids) // batch_size + (1 if len(video_ids) % batch_size != 0 else 0)
-
- for idx, video_id in enumerate(tqdm.tqdm(range(total_len))):
- gt_videos_tensor = []
- generated_videos_tensor = []
- for i in range(batch_size):
- video_idx = idx * batch_size + i
- if video_idx >= len(video_ids):
- break
- video_id = video_ids[video_idx]
- generated_video = load_video(os.path.join(generated_video_dir, f"{video_id}.{file_extension}"))
- generated_videos_tensor.append(generated_video)
- eval_video = load_video(os.path.join(gt_video_dir, f"{video_id}.{file_extension}"))
- gt_videos_tensor.append(eval_video)
- gt_videos_tensor = (torch.stack(gt_videos_tensor) / 255.0).cpu()
- generated_videos_tensor = (torch.stack(generated_videos_tensor) / 255.0).cpu()
-
- if calculate_lpips_flag:
- result = calculate_lpips(gt_videos_tensor, generated_videos_tensor, device=device)
- result = result["value"].values()
- result = sum(result) / len(result)
- lpips_results.append(result)
-
- if calculate_psnr_flag:
- result = calculate_psnr(gt_videos_tensor, generated_videos_tensor)
- result = result["value"].values()
- result = sum(result) / len(result)
- psnr_results.append(result)
-
- if calculate_ssim_flag:
- result = calculate_ssim(gt_videos_tensor, generated_videos_tensor)
- result = result["value"].values()
- result = sum(result) / len(result)
- ssim_results.append(result)
-
- if (idx + 1) % prompt_interval == 0:
- out_str = ""
- for results, name in zip([lpips_results, psnr_results, ssim_results], ["lpips", "psnr", "ssim"]):
- result = sum(results) / len(results)
- out_str += f"{name}: {result:.4f}, "
- print(f"Processed {idx + 1} videos. {out_str[:-2]}")
-
- out_str = ""
- for results, name in zip([lpips_results, psnr_results, ssim_results], ["lpips", "psnr", "ssim"]):
- result = sum(results) / len(results)
- out_str += f"{name}: {result:.4f}, "
- out_str = out_str[:-2]
-
- # save
- with open(f"./{os.path.basename(generated_video_dir)}.txt", "w+") as f:
- f.write(out_str)
-
- print(f"Processed all videos. {out_str}")
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument("--gt_video_dir", type=str)
- parser.add_argument("--generated_video_dir", type=str)
-
- args = parser.parse_args()
-
- main(args)
diff --git a/eval/pab/experiments/attention_ablation.py b/eval/pab/experiments/attention_ablation.py
deleted file mode 100644
index c78964d09a1e1f74bcdf382d468c1d2ca03e5ce9..0000000000000000000000000000000000000000
--- a/eval/pab/experiments/attention_ablation.py
+++ /dev/null
@@ -1,60 +0,0 @@
-from utils import generate_func, read_prompt_list
-
-import videosys
-from videosys import OpenSoraConfig, OpenSoraPipeline
-from videosys.models.open_sora import OpenSoraPABConfig
-
-
-def attention_ablation_func(pab_kwargs, prompt_list, output_dir):
- pab_config = OpenSoraPABConfig(**pab_kwargs)
- config = OpenSoraConfig(enable_pab=True, pab_config=pab_config)
- pipeline = OpenSoraPipeline(config)
-
- generate_func(pipeline, prompt_list, output_dir)
-
-
-def main(prompt_list):
- # spatial
- gap_list = [2, 3, 4, 5]
- for gap in gap_list:
- pab_kwargs = {
- "spatial_broadcast": True,
- "spatial_gap": gap,
- "temporal_broadcast": False,
- "cross_broadcast": False,
- "mlp_skip": False,
- }
- output_dir = f"./samples/attention_ablation/spatial_g{gap}"
- attention_ablation_func(pab_kwargs, prompt_list, output_dir)
-
- # temporal
- gap_list = [3, 4, 5, 6]
- for gap in gap_list:
- pab_kwargs = {
- "spatial_broadcast": False,
- "temporal_broadcast": True,
- "temporal_gap": gap,
- "cross_broadcast": False,
- "mlp_skip": False,
- }
- output_dir = f"./samples/attention_ablation/temporal_g{gap}"
- attention_ablation_func(pab_kwargs, prompt_list, output_dir)
-
- # cross
- gap_list = [5, 6, 7, 8]
- for gap in gap_list:
- pab_kwargs = {
- "spatial_broadcast": False,
- "temporal_broadcast": False,
- "cross_broadcast": True,
- "cross_gap": gap,
- "mlp_skip": False,
- }
- output_dir = f"./samples/attention_ablation/cross_g{gap}"
- attention_ablation_func(pab_kwargs, prompt_list, output_dir)
-
-
-if __name__ == "__main__":
- videosys.initialize(42)
- prompt_list = read_prompt_list("vbench/VBench_full_info.json")
- main(prompt_list)
diff --git a/eval/pab/experiments/components_ablation.py b/eval/pab/experiments/components_ablation.py
deleted file mode 100644
index 12d88f3a61f031d1f51035876d6660d950fa4575..0000000000000000000000000000000000000000
--- a/eval/pab/experiments/components_ablation.py
+++ /dev/null
@@ -1,46 +0,0 @@
-from utils import generate_func, read_prompt_list
-
-import videosys
-from videosys import OpenSoraConfig, OpenSoraPipeline
-from videosys.models.open_sora import OpenSoraPABConfig
-
-
-def wo_spatial(prompt_list):
- pab_config = OpenSoraPABConfig(spatial_broadcast=False)
- config = OpenSoraConfig(enable_pab=True, pab_config=pab_config)
- pipeline = OpenSoraPipeline(config)
-
- generate_func(pipeline, prompt_list, "./samples/components_ablation/wo_spatial")
-
-
-def wo_temporal(prompt_list):
- pab_config = OpenSoraPABConfig(temporal_broadcast=False)
- config = OpenSoraConfig(enable_pab=True, pab_config=pab_config)
- pipeline = OpenSoraPipeline(config)
-
- generate_func(pipeline, prompt_list, "./samples/components_ablation/wo_temporal")
-
-
-def wo_cross(prompt_list):
- pab_config = OpenSoraPABConfig(cross_broadcast=False)
- config = OpenSoraConfig(enable_pab=True, pab_config=pab_config)
- pipeline = OpenSoraPipeline(config)
-
- generate_func(pipeline, prompt_list, "./samples/components_ablation/wo_cross")
-
-
-def wo_mlp(prompt_list):
- pab_config = OpenSoraPABConfig(mlp_skip=False)
- config = OpenSoraConfig(enable_pab=True, pab_config=pab_config)
- pipeline = OpenSoraPipeline(config)
-
- generate_func(pipeline, prompt_list, "./samples/components_ablation/wo_mlp")
-
-
-if __name__ == "__main__":
- videosys.initialize(42)
- prompt_list = read_prompt_list("./vbench/VBench_full_info.json")
- wo_spatial(prompt_list)
- wo_temporal(prompt_list)
- wo_cross(prompt_list)
- wo_mlp(prompt_list)
diff --git a/eval/pab/experiments/latte.py b/eval/pab/experiments/latte.py
deleted file mode 100644
index 5748dbaf78b6b8a9af784aea7188f4719d1aaf8c..0000000000000000000000000000000000000000
--- a/eval/pab/experiments/latte.py
+++ /dev/null
@@ -1,57 +0,0 @@
-from utils import generate_func, read_prompt_list
-
-import videosys
-from videosys import LatteConfig, LattePipeline
-from videosys.models.latte import LattePABConfig
-
-
-def eval_base(prompt_list):
- config = LatteConfig()
- pipeline = LattePipeline(config)
-
- generate_func(pipeline, prompt_list, "./samples/latte_base", loop=5)
-
-
-def eval_pab1(prompt_list):
- pab_config = LattePABConfig(
- spatial_gap=2,
- temporal_gap=3,
- cross_gap=6,
- )
- config = LatteConfig(enable_pab=True, pab_config=pab_config)
- pipeline = LattePipeline(config)
-
- generate_func(pipeline, prompt_list, "./samples/latte_pab1", loop=5)
-
-
-def eval_pab2(prompt_list):
- pab_config = LattePABConfig(
- spatial_gap=3,
- temporal_gap=4,
- cross_gap=7,
- )
- config = LatteConfig(enable_pab=True, pab_config=pab_config)
- pipeline = LattePipeline(config)
-
- generate_func(pipeline, prompt_list, "./samples/latte_pab2", loop=5)
-
-
-def eval_pab3(prompt_list):
- pab_config = LattePABConfig(
- spatial_gap=4,
- temporal_gap=6,
- cross_gap=9,
- )
- config = LatteConfig(enable_pab=True, pab_config=pab_config)
- pipeline = LattePipeline(config)
-
- generate_func(pipeline, prompt_list, "./samples/latte_pab3", loop=5)
-
-
-if __name__ == "__main__":
- videosys.initialize(42)
- prompt_list = read_prompt_list("vbench/VBench_full_info.json")
- eval_base(prompt_list)
- eval_pab1(prompt_list)
- eval_pab2(prompt_list)
- eval_pab3(prompt_list)
diff --git a/eval/pab/experiments/opensora.py b/eval/pab/experiments/opensora.py
deleted file mode 100644
index 7799c67308704bb2c825996ca8d800f95ba5d2c6..0000000000000000000000000000000000000000
--- a/eval/pab/experiments/opensora.py
+++ /dev/null
@@ -1,44 +0,0 @@
-from utils import generate_func, read_prompt_list
-
-import videosys
-from videosys import OpenSoraConfig, OpenSoraPipeline
-from videosys.models.open_sora import OpenSoraPABConfig
-
-
-def eval_base(prompt_list):
- config = OpenSoraConfig()
- pipeline = OpenSoraPipeline(config)
-
- generate_func(pipeline, prompt_list, "./samples/opensora_base", loop=5)
-
-
-def eval_pab1(prompt_list):
- config = OpenSoraConfig(enable_pab=True)
- pipeline = OpenSoraPipeline(config)
-
- generate_func(pipeline, prompt_list, "./samples/opensora_pab1", loop=5)
-
-
-def eval_pab2(prompt_list):
- pab_config = OpenSoraPABConfig(spatial_gap=3, temporal_gap=5, cross_gap=7)
- config = OpenSoraConfig(enable_pab=True, pab_config=pab_config)
- pipeline = OpenSoraPipeline(config)
-
- generate_func(pipeline, prompt_list, "./samples/opensora_pab2", loop=5)
-
-
-def eval_pab3(prompt_list):
- pab_config = OpenSoraPABConfig(spatial_gap=5, temporal_gap=7, cross_gap=9)
- config = OpenSoraConfig(enable_pab=True, pab_config=pab_config)
- pipeline = OpenSoraPipeline(config)
-
- generate_func(pipeline, prompt_list, "./samples/opensora_pab3", loop=5)
-
-
-if __name__ == "__main__":
- videosys.initialize(42)
- prompt_list = read_prompt_list("vbench/VBench_full_info.json")
- eval_base(prompt_list)
- eval_pab1(prompt_list)
- eval_pab2(prompt_list)
- eval_pab3(prompt_list)
diff --git a/eval/pab/experiments/opensora_plan.py b/eval/pab/experiments/opensora_plan.py
deleted file mode 100644
index a4e8efc955e6ec469aa4a40d20abab31a8481a42..0000000000000000000000000000000000000000
--- a/eval/pab/experiments/opensora_plan.py
+++ /dev/null
@@ -1,57 +0,0 @@
-from utils import generate_func, read_prompt_list
-
-import videosys
-from videosys import OpenSoraPlanConfig, OpenSoraPlanPipeline
-from videosys.models.open_sora_plan import OpenSoraPlanPABConfig
-
-
-def eval_base(prompt_list):
- config = OpenSoraPlanConfig()
- pipeline = OpenSoraPlanPipeline(config)
-
- generate_func(pipeline, prompt_list, "./samples/opensoraplan_base", loop=5)
-
-
-def eval_pab1(prompt_list):
- pab_config = OpenSoraPlanPABConfig(
- spatial_gap=2,
- temporal_gap=4,
- cross_gap=6,
- )
- config = OpenSoraPlanConfig(enable_pab=True, pab_config=pab_config)
- pipeline = OpenSoraPlanPipeline(config)
-
- generate_func(pipeline, prompt_list, "./samples/opensoraplan_pab1", loop=5)
-
-
-def eval_pab2(prompt_list):
- pab_config = OpenSoraPlanPABConfig(
- spatial_gap=3,
- temporal_gap=5,
- cross_gap=7,
- )
- config = OpenSoraPlanConfig(enable_pab=True, pab_config=pab_config)
- pipeline = OpenSoraPlanPipeline(config)
-
- generate_func(pipeline, prompt_list, "./samples/opensoraplan_pab2", loop=5)
-
-
-def eval_pab3(prompt_list):
- pab_config = OpenSoraPlanPABConfig(
- spatial_gap=5,
- temporal_gap=7,
- cross_gap=9,
- )
- config = OpenSoraPlanConfig(enable_pab=True, pab_config=pab_config)
- pipeline = OpenSoraPlanPipeline(config)
-
- generate_func(pipeline, prompt_list, "./samples/opensoraplan_pab3", loop=5)
-
-
-if __name__ == "__main__":
- videosys.initialize(42)
- prompt_list = read_prompt_list("vbench/VBench_full_info.json")
- eval_base(prompt_list)
- eval_pab1(prompt_list)
- eval_pab2(prompt_list)
- eval_pab3(prompt_list)
diff --git a/eval/pab/experiments/utils.py b/eval/pab/experiments/utils.py
deleted file mode 100644
index cb52309fda21056b8ba352696e9ca4cf1fe1788e..0000000000000000000000000000000000000000
--- a/eval/pab/experiments/utils.py
+++ /dev/null
@@ -1,22 +0,0 @@
-import json
-import os
-
-import tqdm
-
-from videosys.utils.utils import set_seed
-
-
-def generate_func(pipeline, prompt_list, output_dir, loop: int = 5, kwargs: dict = {}):
- kwargs["verbose"] = False
- for prompt in tqdm.tqdm(prompt_list):
- for l in range(loop):
- set_seed(l)
- video = pipeline.generate(prompt, **kwargs).video[0]
- pipeline.save_video(video, os.path.join(output_dir, f"{prompt}-{l}.mp4"))
-
-
-def read_prompt_list(prompt_list_path):
- with open(prompt_list_path, "r") as f:
- prompt_list = json.load(f)
- prompt_list = [prompt["prompt_en"] for prompt in prompt_list]
- return prompt_list
diff --git a/eval/pab/vbench/VBench_full_info.json b/eval/pab/vbench/VBench_full_info.json
deleted file mode 100644
index e60c40eb0050a5304791490972be3b32de309e4a..0000000000000000000000000000000000000000
--- a/eval/pab/vbench/VBench_full_info.json
+++ /dev/null
@@ -1,9132 +0,0 @@
-[
- {
- "prompt_en": "In a still frame, a stop sign",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "a toilet, frozen in time",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "a laptop, frozen in time",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of alley",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of bar",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of barn",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of bathroom",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of bedroom",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of cliff",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "In a still frame, courtyard",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "In a still frame, gas station",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of house",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "indoor gymnasium, frozen in time",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of indoor library",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of kitchen",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of palace",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "In a still frame, parking lot",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "In a still frame, phone booth",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of restaurant",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of tower",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of a bowl",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of an apple",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of a bench",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of a bed",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of a chair",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of a cup",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of a dining table",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "In a still frame, a pear",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of a bunch of grapes",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of a bowl on the kitchen counter",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of a beautiful, handcrafted ceramic bowl",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of an antique bowl",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of an exquisite mahogany dining table",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of a wooden bench in the park",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of a beautiful wrought-iron bench surrounded by blooming flowers",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "In a still frame, a park bench with a view of the lake",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of a vintage rocking chair was placed on the porch",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of the jail cell was small and dimly lit, with cold, steel bars",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of the phone booth was tucked away in a quiet alley",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "a dilapidated phone booth stood as a relic of a bygone era on the sidewalk, frozen in time",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of the old red barn stood weathered and iconic against the backdrop of the countryside",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of a picturesque barn was painted a warm shade of red and nestled in a picturesque meadow",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "In a still frame, within the desolate desert, an oasis unfolded, characterized by the stoic presence of palm trees and a motionless, glassy pool of water",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "In a still frame, the Parthenon's majestic Doric columns stand in serene solitude atop the Acropolis, framed by the tranquil Athenian landscape",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "In a still frame, the Temple of Hephaestus, with its timeless Doric grace, stands stoically against the backdrop of a quiet Athens",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "In a still frame, the ornate Victorian streetlamp stands solemnly, adorned with intricate ironwork and stained glass panels",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of the Stonehenge presented itself as an enigmatic puzzle, each colossal stone meticulously placed against the backdrop of tranquility",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "In a still frame, in the vast desert, an oasis nestled among dunes, featuring tall palm trees and an air of serenity",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "static view on a desert scene with an oasis, palm trees, and a clear, calm pool of water",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of an ornate Victorian streetlamp standing on a cobblestone street corner, illuminating the empty night",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of a tranquil lakeside cabin nestled among tall pines, its reflection mirrored perfectly in the calm water",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "In a still frame, a vintage gas lantern, adorned with intricate details, gracing a historic cobblestone square",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "In a still frame, a tranquil Japanese tea ceremony room, with tatami mats, a delicate tea set, and a bonsai tree in the corner",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of the Parthenon stands resolute in its classical elegance, a timeless symbol of Athens' cultural legacy",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of in the heart of Plaka, the neoclassical architecture of the old city harmonizes with the ancient ruins",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of in the desolate beauty of the American Southwest, Chaco Canyon's ancient ruins whispered tales of an enigmatic civilization that once thrived amidst the arid landscapes",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of at the edge of the Arabian Desert, the ancient city of Petra beckoned with its enigmatic rock-carved fa\u00e7ades",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "In a still frame, amidst the cobblestone streets, an Art Nouveau lamppost stood tall",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of in the quaint village square, a traditional wrought-iron streetlamp featured delicate filigree patterns and amber-hued glass panels",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of the lampposts were adorned with Art Deco motifs, their geometric shapes and frosted glass creating a sense of vintage glamour",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "In a still frame, in the picturesque square, a Gothic-style lamppost adorned with intricate stone carvings added a touch of medieval charm to the setting",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "In a still frame, in the heart of the old city, a row of ornate lantern-style streetlamps bathed the narrow alleyway in a warm, welcoming light",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of in the heart of the Utah desert, a massive sandstone arch spanned the horizon",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of in the Arizona desert, a massive stone bridge arched across a rugged canyon",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of in the corner of the minimalist tea room, a bonsai tree added a touch of nature's beauty to the otherwise simple and elegant space",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "In a still frame, amidst the hushed ambiance of the traditional tea room, a meticulously arranged tea set awaited, with porcelain cups, a bamboo whisk",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "In a still frame, nestled in the Zen garden, a rustic teahouse featured tatami seating and a traditional charcoal brazier",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of a country estate's library featured elegant wooden shelves",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of beneath the shade of a solitary oak tree, an old wooden park bench sat patiently",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of beside a tranquil pond, a weeping willow tree draped its branches gracefully over the water's surface, creating a serene tableau of reflection and calm",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of in the Zen garden, a perfectly raked gravel path led to a serene rock garden",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "In a still frame, a tranquil pond was fringed by weeping cherry trees, their blossoms drifting lazily onto the glassy surface",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "In a still frame, within the historic library's reading room, rows of antique leather chairs and mahogany tables offered a serene haven for literary contemplation",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of a peaceful orchid garden showcased a variety of delicate blooms",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "A tranquil tableau of in the serene courtyard, a centuries-old stone well stood as a symbol of a bygone era, its mossy stones bearing witness to the passage of time",
- "dimension": [
- "temporal_flickering"
- ]
- },
- {
- "prompt_en": "a bird and a cat",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "bird and cat"
- }
- }
- },
- {
- "prompt_en": "a cat and a dog",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "cat and dog"
- }
- }
- },
- {
- "prompt_en": "a dog and a horse",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "dog and horse"
- }
- }
- },
- {
- "prompt_en": "a horse and a sheep",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "horse and sheep"
- }
- }
- },
- {
- "prompt_en": "a sheep and a cow",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "sheep and cow"
- }
- }
- },
- {
- "prompt_en": "a cow and an elephant",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "cow and elephant"
- }
- }
- },
- {
- "prompt_en": "an elephant and a bear",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "elephant and bear"
- }
- }
- },
- {
- "prompt_en": "a bear and a zebra",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "bear and zebra"
- }
- }
- },
- {
- "prompt_en": "a zebra and a giraffe",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "zebra and giraffe"
- }
- }
- },
- {
- "prompt_en": "a giraffe and a bird",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "giraffe and bird"
- }
- }
- },
- {
- "prompt_en": "a chair and a couch",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "chair and couch"
- }
- }
- },
- {
- "prompt_en": "a couch and a potted plant",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "couch and potted plant"
- }
- }
- },
- {
- "prompt_en": "a potted plant and a tv",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "potted plant and tv"
- }
- }
- },
- {
- "prompt_en": "a tv and a laptop",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "tv and laptop"
- }
- }
- },
- {
- "prompt_en": "a laptop and a remote",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "laptop and remote"
- }
- }
- },
- {
- "prompt_en": "a remote and a keyboard",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "remote and keyboard"
- }
- }
- },
- {
- "prompt_en": "a keyboard and a cell phone",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "keyboard and cell phone"
- }
- }
- },
- {
- "prompt_en": "a cell phone and a book",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "cell phone and book"
- }
- }
- },
- {
- "prompt_en": "a book and a clock",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "book and clock"
- }
- }
- },
- {
- "prompt_en": "a clock and a backpack",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "clock and backpack"
- }
- }
- },
- {
- "prompt_en": "a backpack and an umbrella",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "backpack and umbrella"
- }
- }
- },
- {
- "prompt_en": "an umbrella and a handbag",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "umbrella and handbag"
- }
- }
- },
- {
- "prompt_en": "a handbag and a tie",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "handbag and tie"
- }
- }
- },
- {
- "prompt_en": "a tie and a suitcase",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "tie and suitcase"
- }
- }
- },
- {
- "prompt_en": "a suitcase and a vase",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "suitcase and vase"
- }
- }
- },
- {
- "prompt_en": "a vase and scissors",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "vase and scissors"
- }
- }
- },
- {
- "prompt_en": "scissors and a teddy bear",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "scissors and teddy bear"
- }
- }
- },
- {
- "prompt_en": "a teddy bear and a frisbee",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "teddy bear and frisbee"
- }
- }
- },
- {
- "prompt_en": "a frisbee and skis",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "frisbee and skis"
- }
- }
- },
- {
- "prompt_en": "skis and a snowboard",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "skis and snowboard"
- }
- }
- },
- {
- "prompt_en": "a snowboard and a sports ball",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "snowboard and sports ball"
- }
- }
- },
- {
- "prompt_en": "a sports ball and a kite",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "sports ball and kite"
- }
- }
- },
- {
- "prompt_en": "a kite and a baseball bat",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "kite and baseball bat"
- }
- }
- },
- {
- "prompt_en": "a baseball bat and a baseball glove",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "baseball bat and baseball glove"
- }
- }
- },
- {
- "prompt_en": "a baseball glove and a skateboard",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "baseball glove and skateboard"
- }
- }
- },
- {
- "prompt_en": "a skateboard and a surfboard",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "skateboard and surfboard"
- }
- }
- },
- {
- "prompt_en": "a surfboard and a tennis racket",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "surfboard and tennis racket"
- }
- }
- },
- {
- "prompt_en": "a tennis racket and a bottle",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "tennis racket and bottle"
- }
- }
- },
- {
- "prompt_en": "a bottle and a chair",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "bottle and chair"
- }
- }
- },
- {
- "prompt_en": "an airplane and a train",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "airplane and train"
- }
- }
- },
- {
- "prompt_en": "a train and a boat",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "train and boat"
- }
- }
- },
- {
- "prompt_en": "a boat and an airplane",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "boat and airplane"
- }
- }
- },
- {
- "prompt_en": "a bicycle and a car",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "bicycle and car"
- }
- }
- },
- {
- "prompt_en": "a car and a motorcycle",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "car and motorcycle"
- }
- }
- },
- {
- "prompt_en": "a motorcycle and a bus",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "motorcycle and bus"
- }
- }
- },
- {
- "prompt_en": "a bus and a traffic light",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "bus and traffic light"
- }
- }
- },
- {
- "prompt_en": "a traffic light and a fire hydrant",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "traffic light and fire hydrant"
- }
- }
- },
- {
- "prompt_en": "a fire hydrant and a stop sign",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "fire hydrant and stop sign"
- }
- }
- },
- {
- "prompt_en": "a stop sign and a parking meter",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "stop sign and parking meter"
- }
- }
- },
- {
- "prompt_en": "a parking meter and a truck",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "parking meter and truck"
- }
- }
- },
- {
- "prompt_en": "a truck and a bicycle",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "truck and bicycle"
- }
- }
- },
- {
- "prompt_en": "a toilet and a hair drier",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "toilet and hair drier"
- }
- }
- },
- {
- "prompt_en": "a hair drier and a toothbrush",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "hair drier and toothbrush"
- }
- }
- },
- {
- "prompt_en": "a toothbrush and a sink",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "toothbrush and sink"
- }
- }
- },
- {
- "prompt_en": "a sink and a toilet",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "sink and toilet"
- }
- }
- },
- {
- "prompt_en": "a wine glass and a chair",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "wine glass and chair"
- }
- }
- },
- {
- "prompt_en": "a cup and a couch",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "cup and couch"
- }
- }
- },
- {
- "prompt_en": "a fork and a potted plant",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "fork and potted plant"
- }
- }
- },
- {
- "prompt_en": "a knife and a tv",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "knife and tv"
- }
- }
- },
- {
- "prompt_en": "a spoon and a laptop",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "spoon and laptop"
- }
- }
- },
- {
- "prompt_en": "a bowl and a remote",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "bowl and remote"
- }
- }
- },
- {
- "prompt_en": "a banana and a keyboard",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "banana and keyboard"
- }
- }
- },
- {
- "prompt_en": "an apple and a cell phone",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "apple and cell phone"
- }
- }
- },
- {
- "prompt_en": "a sandwich and a book",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "sandwich and book"
- }
- }
- },
- {
- "prompt_en": "an orange and a clock",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "orange and clock"
- }
- }
- },
- {
- "prompt_en": "broccoli and a backpack",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "broccoli and backpack"
- }
- }
- },
- {
- "prompt_en": "a carrot and an umbrella",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "carrot and umbrella"
- }
- }
- },
- {
- "prompt_en": "a hot dog and a handbag",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "hot dog and handbag"
- }
- }
- },
- {
- "prompt_en": "a pizza and a tie",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "pizza and tie"
- }
- }
- },
- {
- "prompt_en": "a donut and a suitcase",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "donut and suitcase"
- }
- }
- },
- {
- "prompt_en": "a cake and a vase",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "cake and vase"
- }
- }
- },
- {
- "prompt_en": "an oven and scissors",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "oven and scissors"
- }
- }
- },
- {
- "prompt_en": "a toaster and a teddy bear",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "toaster and teddy bear"
- }
- }
- },
- {
- "prompt_en": "a microwave and a frisbee",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "microwave and frisbee"
- }
- }
- },
- {
- "prompt_en": "a refrigerator and skis",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "refrigerator and skis"
- }
- }
- },
- {
- "prompt_en": "a bicycle and an airplane",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "bicycle and airplane"
- }
- }
- },
- {
- "prompt_en": "a car and a train",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "car and train"
- }
- }
- },
- {
- "prompt_en": "a motorcycle and a boat",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "motorcycle and boat"
- }
- }
- },
- {
- "prompt_en": "a person and a toilet",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "person and toilet"
- }
- }
- },
- {
- "prompt_en": "a person and a hair drier",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "person and hair drier"
- }
- }
- },
- {
- "prompt_en": "a person and a toothbrush",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "person and toothbrush"
- }
- }
- },
- {
- "prompt_en": "a person and a sink",
- "dimension": [
- "multiple_objects"
- ],
- "auxiliary_info": {
- "multiple_objects": {
- "object": "person and sink"
- }
- }
- },
- {
- "prompt_en": "A person is riding a bike",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is marching",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is roller skating",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is tasting beer",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is clapping",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is drawing",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is petting animal (not cat)",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is eating watermelon",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is playing harp",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is wrestling",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is riding scooter",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is sweeping floor",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is skateboarding",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is dunking basketball",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is playing flute",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is stretching leg",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is tying tie",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is skydiving",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is shooting goal (soccer)",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is playing piano",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is finger snapping",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is canoeing or kayaking",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is laughing",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is digging",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is clay pottery making",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is shooting basketball",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is bending back",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is shaking hands",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is bandaging",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is push up",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is catching or throwing frisbee",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is playing trumpet",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is flying kite",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is filling eyebrows",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is shuffling cards",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is folding clothes",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is smoking",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is tai chi",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is squat",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is playing controller",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is throwing axe",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is giving or receiving award",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is air drumming",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is taking a shower",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is planting trees",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is sharpening knives",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is robot dancing",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is rock climbing",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is hula hooping",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is writing",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is bungee jumping",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is pushing cart",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is cleaning windows",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is cutting watermelon",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is cheerleading",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is washing hands",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is ironing",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is cutting nails",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is hugging",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is trimming or shaving beard",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is jogging",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is making bed",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is washing dishes",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is grooming dog",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is doing laundry",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is knitting",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is reading book",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is baby waking up",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is massaging legs",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is brushing teeth",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is crawling baby",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is motorcycling",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is driving car",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is sticking tongue out",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is shaking head",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is sword fighting",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is doing aerobics",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is strumming guitar",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is riding or walking with horse",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is archery",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is catching or throwing baseball",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is playing chess",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is rock scissors paper",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is using computer",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is arranging flowers",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is bending metal",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is ice skating",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is climbing a rope",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is crying",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is dancing ballet",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is getting a haircut",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is running on treadmill",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is kissing",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is counting money",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is barbequing",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is peeling apples",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is milking cow",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is shining shoes",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is making snowman",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "A person is sailing",
- "dimension": [
- "human_action"
- ]
- },
- {
- "prompt_en": "a person swimming in ocean",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a person giving a presentation to a room full of colleagues",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a person washing the dishes",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a person eating a burger",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a person walking in the snowstorm",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a person drinking coffee in a cafe",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a person playing guitar",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a bicycle leaning against a tree",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a bicycle gliding through a snowy field",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a bicycle slowing down to stop",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a bicycle accelerating to gain speed",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a car stuck in traffic during rush hour",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a car turning a corner",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a car slowing down to stop",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a car accelerating to gain speed",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a motorcycle cruising along a coastal highway",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a motorcycle turning a corner",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a motorcycle slowing down to stop",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a motorcycle gliding through a snowy field",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a motorcycle accelerating to gain speed",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "an airplane soaring through a clear blue sky",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "an airplane taking off",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "an airplane landing smoothly on a runway",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "an airplane accelerating to gain speed",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a bus turning a corner",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a bus stuck in traffic during rush hour",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a bus accelerating to gain speed",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a train speeding down the tracks",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a train crossing over a tall bridge",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a train accelerating to gain speed",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a truck turning a corner",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a truck anchored in a tranquil bay",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a truck stuck in traffic during rush hour",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a truck slowing down to stop",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a truck accelerating to gain speed",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a boat sailing smoothly on a calm lake",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a boat slowing down to stop",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a boat accelerating to gain speed",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a bird soaring gracefully in the sky",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a bird building a nest from twigs and leaves",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a bird flying over a snowy forest",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a cat grooming itself meticulously with its tongue",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a cat playing in park",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a cat drinking water",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a cat running happily",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a dog enjoying a peaceful walk",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a dog playing in park",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a dog drinking water",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a dog running happily",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a horse bending down to drink water from a river",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a horse galloping across an open field",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a horse taking a peaceful walk",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a horse running to join a herd of its kind",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a sheep bending down to drink water from a river",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a sheep taking a peaceful walk",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a sheep running to join a herd of its kind",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a cow bending down to drink water from a river",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a cow chewing cud while resting in a tranquil barn",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a cow running to join a herd of its kind",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "an elephant spraying itself with water using its trunk to cool down",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "an elephant taking a peaceful walk",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "an elephant running to join a herd of its kind",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a bear catching a salmon in its powerful jaws",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a bear sniffing the air for scents of food",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a bear climbing a tree",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a bear hunting for prey",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a zebra bending down to drink water from a river",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a zebra running to join a herd of its kind",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a zebra taking a peaceful walk",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a giraffe bending down to drink water from a river",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a giraffe taking a peaceful walk",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a giraffe running to join a herd of its kind",
- "dimension": [
- "subject_consistency",
- "dynamic_degree",
- "motion_smoothness"
- ]
- },
- {
- "prompt_en": "a person",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "person"
- }
- }
- },
- {
- "prompt_en": "a bicycle",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "bicycle"
- }
- }
- },
- {
- "prompt_en": "a car",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "car"
- }
- }
- },
- {
- "prompt_en": "a motorcycle",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "motorcycle"
- }
- }
- },
- {
- "prompt_en": "an airplane",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "airplane"
- }
- }
- },
- {
- "prompt_en": "a bus",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "bus"
- }
- }
- },
- {
- "prompt_en": "a train",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "train"
- }
- }
- },
- {
- "prompt_en": "a truck",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "truck"
- }
- }
- },
- {
- "prompt_en": "a boat",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "boat"
- }
- }
- },
- {
- "prompt_en": "a traffic light",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "traffic light"
- }
- }
- },
- {
- "prompt_en": "a fire hydrant",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "fire hydrant"
- }
- }
- },
- {
- "prompt_en": "a stop sign",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "stop sign"
- }
- }
- },
- {
- "prompt_en": "a parking meter",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "parking meter"
- }
- }
- },
- {
- "prompt_en": "a bench",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "bench"
- }
- }
- },
- {
- "prompt_en": "a bird",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "bird"
- }
- }
- },
- {
- "prompt_en": "a cat",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "cat"
- }
- }
- },
- {
- "prompt_en": "a dog",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "dog"
- }
- }
- },
- {
- "prompt_en": "a horse",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "horse"
- }
- }
- },
- {
- "prompt_en": "a sheep",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "sheep"
- }
- }
- },
- {
- "prompt_en": "a cow",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "cow"
- }
- }
- },
- {
- "prompt_en": "an elephant",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "elephant"
- }
- }
- },
- {
- "prompt_en": "a bear",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "bear"
- }
- }
- },
- {
- "prompt_en": "a zebra",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "zebra"
- }
- }
- },
- {
- "prompt_en": "a giraffe",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "giraffe"
- }
- }
- },
- {
- "prompt_en": "a backpack",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "backpack"
- }
- }
- },
- {
- "prompt_en": "an umbrella",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "umbrella"
- }
- }
- },
- {
- "prompt_en": "a handbag",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "handbag"
- }
- }
- },
- {
- "prompt_en": "a tie",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "tie"
- }
- }
- },
- {
- "prompt_en": "a suitcase",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "suitcase"
- }
- }
- },
- {
- "prompt_en": "a frisbee",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "frisbee"
- }
- }
- },
- {
- "prompt_en": "skis",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "skis"
- }
- }
- },
- {
- "prompt_en": "a snowboard",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "snowboard"
- }
- }
- },
- {
- "prompt_en": "a sports ball",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "sports ball"
- }
- }
- },
- {
- "prompt_en": "a kite",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "kite"
- }
- }
- },
- {
- "prompt_en": "a baseball bat",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "baseball bat"
- }
- }
- },
- {
- "prompt_en": "a baseball glove",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "baseball glove"
- }
- }
- },
- {
- "prompt_en": "a skateboard",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "skateboard"
- }
- }
- },
- {
- "prompt_en": "a surfboard",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "surfboard"
- }
- }
- },
- {
- "prompt_en": "a tennis racket",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "tennis racket"
- }
- }
- },
- {
- "prompt_en": "a bottle",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "bottle"
- }
- }
- },
- {
- "prompt_en": "a wine glass",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "wine glass"
- }
- }
- },
- {
- "prompt_en": "a cup",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "cup"
- }
- }
- },
- {
- "prompt_en": "a fork",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "fork"
- }
- }
- },
- {
- "prompt_en": "a knife",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "knife"
- }
- }
- },
- {
- "prompt_en": "a spoon",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "spoon"
- }
- }
- },
- {
- "prompt_en": "a bowl",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "bowl"
- }
- }
- },
- {
- "prompt_en": "a banana",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "banana"
- }
- }
- },
- {
- "prompt_en": "an apple",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "apple"
- }
- }
- },
- {
- "prompt_en": "a sandwich",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "sandwich"
- }
- }
- },
- {
- "prompt_en": "an orange",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "orange"
- }
- }
- },
- {
- "prompt_en": "broccoli",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "broccoli"
- }
- }
- },
- {
- "prompt_en": "a carrot",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "carrot"
- }
- }
- },
- {
- "prompt_en": "a hot dog",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "hot dog"
- }
- }
- },
- {
- "prompt_en": "a pizza",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "pizza"
- }
- }
- },
- {
- "prompt_en": "a donut",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "donut"
- }
- }
- },
- {
- "prompt_en": "a cake",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "cake"
- }
- }
- },
- {
- "prompt_en": "a chair",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "chair"
- }
- }
- },
- {
- "prompt_en": "a couch",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "couch"
- }
- }
- },
- {
- "prompt_en": "a potted plant",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "potted plant"
- }
- }
- },
- {
- "prompt_en": "a bed",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "bed"
- }
- }
- },
- {
- "prompt_en": "a dining table",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "dining table"
- }
- }
- },
- {
- "prompt_en": "a toilet",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "toilet"
- }
- }
- },
- {
- "prompt_en": "a tv",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "tv"
- }
- }
- },
- {
- "prompt_en": "a laptop",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "laptop"
- }
- }
- },
- {
- "prompt_en": "a remote",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "remote"
- }
- }
- },
- {
- "prompt_en": "a keyboard",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "keyboard"
- }
- }
- },
- {
- "prompt_en": "a cell phone",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "cell phone"
- }
- }
- },
- {
- "prompt_en": "a microwave",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "microwave"
- }
- }
- },
- {
- "prompt_en": "an oven",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "oven"
- }
- }
- },
- {
- "prompt_en": "a toaster",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "toaster"
- }
- }
- },
- {
- "prompt_en": "a sink",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "sink"
- }
- }
- },
- {
- "prompt_en": "a refrigerator",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "refrigerator"
- }
- }
- },
- {
- "prompt_en": "a book",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "book"
- }
- }
- },
- {
- "prompt_en": "a clock",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "clock"
- }
- }
- },
- {
- "prompt_en": "a vase",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "vase"
- }
- }
- },
- {
- "prompt_en": "scissors",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "scissors"
- }
- }
- },
- {
- "prompt_en": "a teddy bear",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "teddy bear"
- }
- }
- },
- {
- "prompt_en": "a hair drier",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "hair drier"
- }
- }
- },
- {
- "prompt_en": "a toothbrush",
- "dimension": [
- "object_class"
- ],
- "auxiliary_info": {
- "object_class": {
- "object": "toothbrush"
- }
- }
- },
- {
- "prompt_en": "a red bicycle",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "red"
- }
- }
- },
- {
- "prompt_en": "a green bicycle",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "green"
- }
- }
- },
- {
- "prompt_en": "a blue bicycle",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "blue"
- }
- }
- },
- {
- "prompt_en": "a yellow bicycle",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "yellow"
- }
- }
- },
- {
- "prompt_en": "an orange bicycle",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "orange"
- }
- }
- },
- {
- "prompt_en": "a purple bicycle",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "purple"
- }
- }
- },
- {
- "prompt_en": "a pink bicycle",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "pink"
- }
- }
- },
- {
- "prompt_en": "a black bicycle",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "black"
- }
- }
- },
- {
- "prompt_en": "a white bicycle",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "white"
- }
- }
- },
- {
- "prompt_en": "a red car",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "red"
- }
- }
- },
- {
- "prompt_en": "a green car",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "green"
- }
- }
- },
- {
- "prompt_en": "a blue car",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "blue"
- }
- }
- },
- {
- "prompt_en": "a yellow car",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "yellow"
- }
- }
- },
- {
- "prompt_en": "an orange car",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "orange"
- }
- }
- },
- {
- "prompt_en": "a purple car",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "purple"
- }
- }
- },
- {
- "prompt_en": "a pink car",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "pink"
- }
- }
- },
- {
- "prompt_en": "a black car",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "black"
- }
- }
- },
- {
- "prompt_en": "a white car",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "white"
- }
- }
- },
- {
- "prompt_en": "a red bird",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "red"
- }
- }
- },
- {
- "prompt_en": "a green bird",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "green"
- }
- }
- },
- {
- "prompt_en": "a blue bird",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "blue"
- }
- }
- },
- {
- "prompt_en": "a yellow bird",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "yellow"
- }
- }
- },
- {
- "prompt_en": "an orange bird",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "orange"
- }
- }
- },
- {
- "prompt_en": "a purple bird",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "purple"
- }
- }
- },
- {
- "prompt_en": "a pink bird",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "pink"
- }
- }
- },
- {
- "prompt_en": "a black bird",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "black"
- }
- }
- },
- {
- "prompt_en": "a white bird",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "white"
- }
- }
- },
- {
- "prompt_en": "a black cat",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "black"
- }
- }
- },
- {
- "prompt_en": "a white cat",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "white"
- }
- }
- },
- {
- "prompt_en": "an orange cat",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "orange"
- }
- }
- },
- {
- "prompt_en": "a yellow cat",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "yellow"
- }
- }
- },
- {
- "prompt_en": "a red umbrella",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "red"
- }
- }
- },
- {
- "prompt_en": "a green umbrella",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "green"
- }
- }
- },
- {
- "prompt_en": "a blue umbrella",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "blue"
- }
- }
- },
- {
- "prompt_en": "a yellow umbrella",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "yellow"
- }
- }
- },
- {
- "prompt_en": "an orange umbrella",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "orange"
- }
- }
- },
- {
- "prompt_en": "a purple umbrella",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "purple"
- }
- }
- },
- {
- "prompt_en": "a pink umbrella",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "pink"
- }
- }
- },
- {
- "prompt_en": "a black umbrella",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "black"
- }
- }
- },
- {
- "prompt_en": "a white umbrella",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "white"
- }
- }
- },
- {
- "prompt_en": "a red suitcase",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "red"
- }
- }
- },
- {
- "prompt_en": "a green suitcase",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "green"
- }
- }
- },
- {
- "prompt_en": "a blue suitcase",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "blue"
- }
- }
- },
- {
- "prompt_en": "a yellow suitcase",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "yellow"
- }
- }
- },
- {
- "prompt_en": "an orange suitcase",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "orange"
- }
- }
- },
- {
- "prompt_en": "a purple suitcase",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "purple"
- }
- }
- },
- {
- "prompt_en": "a pink suitcase",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "pink"
- }
- }
- },
- {
- "prompt_en": "a black suitcase",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "black"
- }
- }
- },
- {
- "prompt_en": "a white suitcase",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "white"
- }
- }
- },
- {
- "prompt_en": "a red bowl",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "red"
- }
- }
- },
- {
- "prompt_en": "a green bowl",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "green"
- }
- }
- },
- {
- "prompt_en": "a blue bowl",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "blue"
- }
- }
- },
- {
- "prompt_en": "a yellow bowl",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "yellow"
- }
- }
- },
- {
- "prompt_en": "an orange bowl",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "orange"
- }
- }
- },
- {
- "prompt_en": "a purple bowl",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "purple"
- }
- }
- },
- {
- "prompt_en": "a pink bowl",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "pink"
- }
- }
- },
- {
- "prompt_en": "a black bowl",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "black"
- }
- }
- },
- {
- "prompt_en": "a white bowl",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "white"
- }
- }
- },
- {
- "prompt_en": "a red chair",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "red"
- }
- }
- },
- {
- "prompt_en": "a green chair",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "green"
- }
- }
- },
- {
- "prompt_en": "a blue chair",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "blue"
- }
- }
- },
- {
- "prompt_en": "a yellow chair",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "yellow"
- }
- }
- },
- {
- "prompt_en": "an orange chair",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "orange"
- }
- }
- },
- {
- "prompt_en": "a purple chair",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "purple"
- }
- }
- },
- {
- "prompt_en": "a pink chair",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "pink"
- }
- }
- },
- {
- "prompt_en": "a black chair",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "black"
- }
- }
- },
- {
- "prompt_en": "a white chair",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "white"
- }
- }
- },
- {
- "prompt_en": "a red clock",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "red"
- }
- }
- },
- {
- "prompt_en": "a green clock",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "green"
- }
- }
- },
- {
- "prompt_en": "a blue clock",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "blue"
- }
- }
- },
- {
- "prompt_en": "a yellow clock",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "yellow"
- }
- }
- },
- {
- "prompt_en": "an orange clock",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "orange"
- }
- }
- },
- {
- "prompt_en": "a purple clock",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "purple"
- }
- }
- },
- {
- "prompt_en": "a pink clock",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "pink"
- }
- }
- },
- {
- "prompt_en": "a black clock",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "black"
- }
- }
- },
- {
- "prompt_en": "a white clock",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "white"
- }
- }
- },
- {
- "prompt_en": "a red vase",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "red"
- }
- }
- },
- {
- "prompt_en": "a green vase",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "green"
- }
- }
- },
- {
- "prompt_en": "a blue vase",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "blue"
- }
- }
- },
- {
- "prompt_en": "a yellow vase",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "yellow"
- }
- }
- },
- {
- "prompt_en": "an orange vase",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "orange"
- }
- }
- },
- {
- "prompt_en": "a purple vase",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "purple"
- }
- }
- },
- {
- "prompt_en": "a pink vase",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "pink"
- }
- }
- },
- {
- "prompt_en": "a black vase",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "black"
- }
- }
- },
- {
- "prompt_en": "a white vase",
- "dimension": [
- "color"
- ],
- "auxiliary_info": {
- "color": {
- "color": "white"
- }
- }
- },
- {
- "prompt_en": "A beautiful coastal beach in spring, waves lapping on sand, Van Gogh style",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "Van Gogh style"
- }
- }
- },
- {
- "prompt_en": "A beautiful coastal beach in spring, waves lapping on sand, oil painting",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "oil painting"
- }
- }
- },
- {
- "prompt_en": "A beautiful coastal beach in spring, waves lapping on sand by Hokusai, in the style of Ukiyo",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "by Hokusai, in the style of Ukiyo"
- }
- }
- },
- {
- "prompt_en": "A beautiful coastal beach in spring, waves lapping on sand, black and white",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "black and white"
- }
- }
- },
- {
- "prompt_en": "A beautiful coastal beach in spring, waves lapping on sand, pixel art",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "pixel art"
- }
- }
- },
- {
- "prompt_en": "A beautiful coastal beach in spring, waves lapping on sand, in cyberpunk style",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "in cyberpunk style"
- }
- }
- },
- {
- "prompt_en": "A beautiful coastal beach in spring, waves lapping on sand, animated style",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "animated style"
- }
- }
- },
- {
- "prompt_en": "A beautiful coastal beach in spring, waves lapping on sand, watercolor painting",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "watercolor painting"
- }
- }
- },
- {
- "prompt_en": "A beautiful coastal beach in spring, waves lapping on sand, surrealism style",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "surrealism style"
- }
- }
- },
- {
- "prompt_en": "The bund Shanghai, Van Gogh style",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "Van Gogh style"
- }
- }
- },
- {
- "prompt_en": "The bund Shanghai, oil painting",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "oil painting"
- }
- }
- },
- {
- "prompt_en": "The bund Shanghai by Hokusai, in the style of Ukiyo",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "by Hokusai, in the style of Ukiyo"
- }
- }
- },
- {
- "prompt_en": "The bund Shanghai, black and white",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "black and white"
- }
- }
- },
- {
- "prompt_en": "The bund Shanghai, pixel art",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "pixel art"
- }
- }
- },
- {
- "prompt_en": "The bund Shanghai, in cyberpunk style",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "in cyberpunk style"
- }
- }
- },
- {
- "prompt_en": "The bund Shanghai, animated style",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "animated style"
- }
- }
- },
- {
- "prompt_en": "The bund Shanghai, watercolor painting",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "watercolor painting"
- }
- }
- },
- {
- "prompt_en": "The bund Shanghai, surrealism style",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "surrealism style"
- }
- }
- },
- {
- "prompt_en": "a shark is swimming in the ocean, Van Gogh style",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "Van Gogh style"
- }
- }
- },
- {
- "prompt_en": "a shark is swimming in the ocean, oil painting",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "oil painting"
- }
- }
- },
- {
- "prompt_en": "a shark is swimming in the ocean by Hokusai, in the style of Ukiyo",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "by Hokusai, in the style of Ukiyo"
- }
- }
- },
- {
- "prompt_en": "a shark is swimming in the ocean, black and white",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "black and white"
- }
- }
- },
- {
- "prompt_en": "a shark is swimming in the ocean, pixel art",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "pixel art"
- }
- }
- },
- {
- "prompt_en": "a shark is swimming in the ocean, in cyberpunk style",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "in cyberpunk style"
- }
- }
- },
- {
- "prompt_en": "a shark is swimming in the ocean, animated style",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "animated style"
- }
- }
- },
- {
- "prompt_en": "a shark is swimming in the ocean, watercolor painting",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "watercolor painting"
- }
- }
- },
- {
- "prompt_en": "a shark is swimming in the ocean, surrealism style",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "surrealism style"
- }
- }
- },
- {
- "prompt_en": "A panda drinking coffee in a cafe in Paris, Van Gogh style",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "Van Gogh style"
- }
- }
- },
- {
- "prompt_en": "A panda drinking coffee in a cafe in Paris, oil painting",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "oil painting"
- }
- }
- },
- {
- "prompt_en": "A panda drinking coffee in a cafe in Paris by Hokusai, in the style of Ukiyo",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "by Hokusai, in the style of Ukiyo"
- }
- }
- },
- {
- "prompt_en": "A panda drinking coffee in a cafe in Paris, black and white",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "black and white"
- }
- }
- },
- {
- "prompt_en": "A panda drinking coffee in a cafe in Paris, pixel art",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "pixel art"
- }
- }
- },
- {
- "prompt_en": "A panda drinking coffee in a cafe in Paris, in cyberpunk style",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "in cyberpunk style"
- }
- }
- },
- {
- "prompt_en": "A panda drinking coffee in a cafe in Paris, animated style",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "animated style"
- }
- }
- },
- {
- "prompt_en": "A panda drinking coffee in a cafe in Paris, watercolor painting",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "watercolor painting"
- }
- }
- },
- {
- "prompt_en": "A panda drinking coffee in a cafe in Paris, surrealism style",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "surrealism style"
- }
- }
- },
- {
- "prompt_en": "A cute happy Corgi playing in park, sunset, Van Gogh style",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "Van Gogh style"
- }
- }
- },
- {
- "prompt_en": "A cute happy Corgi playing in park, sunset, oil painting",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "oil painting"
- }
- }
- },
- {
- "prompt_en": "A cute happy Corgi playing in park, sunset by Hokusai, in the style of Ukiyo",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "by Hokusai, in the style of Ukiyo"
- }
- }
- },
- {
- "prompt_en": "A cute happy Corgi playing in park, sunset, black and white",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "black and white"
- }
- }
- },
- {
- "prompt_en": "A cute happy Corgi playing in park, sunset, pixel art",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "pixel art"
- }
- }
- },
- {
- "prompt_en": "A cute happy Corgi playing in park, sunset, in cyberpunk style",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "in cyberpunk style"
- }
- }
- },
- {
- "prompt_en": "A cute happy Corgi playing in park, sunset, animated style",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "animated style"
- }
- }
- },
- {
- "prompt_en": "A cute happy Corgi playing in park, sunset, watercolor painting",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "watercolor painting"
- }
- }
- },
- {
- "prompt_en": "A cute happy Corgi playing in park, sunset, surrealism style",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "surrealism style"
- }
- }
- },
- {
- "prompt_en": "Gwen Stacy reading a book, Van Gogh style",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "Van Gogh style"
- }
- }
- },
- {
- "prompt_en": "Gwen Stacy reading a book, oil painting",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "oil painting"
- }
- }
- },
- {
- "prompt_en": "Gwen Stacy reading a book by Hokusai, in the style of Ukiyo",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "by Hokusai, in the style of Ukiyo"
- }
- }
- },
- {
- "prompt_en": "Gwen Stacy reading a book, black and white",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "black and white"
- }
- }
- },
- {
- "prompt_en": "Gwen Stacy reading a book, pixel art",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "pixel art"
- }
- }
- },
- {
- "prompt_en": "Gwen Stacy reading a book, in cyberpunk style",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "in cyberpunk style"
- }
- }
- },
- {
- "prompt_en": "Gwen Stacy reading a book, animated style",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "animated style"
- }
- }
- },
- {
- "prompt_en": "Gwen Stacy reading a book, watercolor painting",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "watercolor painting"
- }
- }
- },
- {
- "prompt_en": "Gwen Stacy reading a book, surrealism style",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "surrealism style"
- }
- }
- },
- {
- "prompt_en": "A boat sailing leisurely along the Seine River with the Eiffel Tower in background, Van Gogh style",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "Van Gogh style"
- }
- }
- },
- {
- "prompt_en": "A boat sailing leisurely along the Seine River with the Eiffel Tower in background, oil painting",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "oil painting"
- }
- }
- },
- {
- "prompt_en": "A boat sailing leisurely along the Seine River with the Eiffel Tower in background by Hokusai, in the style of Ukiyo",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "by Hokusai, in the style of Ukiyo"
- }
- }
- },
- {
- "prompt_en": "A boat sailing leisurely along the Seine River with the Eiffel Tower in background, black and white",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "black and white"
- }
- }
- },
- {
- "prompt_en": "A boat sailing leisurely along the Seine River with the Eiffel Tower in background, pixel art",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "pixel art"
- }
- }
- },
- {
- "prompt_en": "A boat sailing leisurely along the Seine River with the Eiffel Tower in background, in cyberpunk style",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "in cyberpunk style"
- }
- }
- },
- {
- "prompt_en": "A boat sailing leisurely along the Seine River with the Eiffel Tower in background, animated style",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "animated style"
- }
- }
- },
- {
- "prompt_en": "A boat sailing leisurely along the Seine River with the Eiffel Tower in background, watercolor painting",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "watercolor painting"
- }
- }
- },
- {
- "prompt_en": "A boat sailing leisurely along the Seine River with the Eiffel Tower in background, surrealism style",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "surrealism style"
- }
- }
- },
- {
- "prompt_en": "A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, Van Gogh style",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "Van Gogh style"
- }
- }
- },
- {
- "prompt_en": "A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, oil painting",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "oil painting"
- }
- }
- },
- {
- "prompt_en": "A couple in formal evening wear going home get caught in a heavy downpour with umbrellas by Hokusai, in the style of Ukiyo",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "by Hokusai, in the style of Ukiyo"
- }
- }
- },
- {
- "prompt_en": "A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, black and white",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "black and white"
- }
- }
- },
- {
- "prompt_en": "A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, pixel art",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "pixel art"
- }
- }
- },
- {
- "prompt_en": "A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, in cyberpunk style",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "in cyberpunk style"
- }
- }
- },
- {
- "prompt_en": "A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, animated style",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "animated style"
- }
- }
- },
- {
- "prompt_en": "A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, watercolor painting",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "watercolor painting"
- }
- }
- },
- {
- "prompt_en": "A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, surrealism style",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "surrealism style"
- }
- }
- },
- {
- "prompt_en": "An astronaut flying in space, Van Gogh style",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "Van Gogh style"
- }
- }
- },
- {
- "prompt_en": "An astronaut flying in space, oil painting",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "oil painting"
- }
- }
- },
- {
- "prompt_en": "An astronaut flying in space by Hokusai, in the style of Ukiyo",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "by Hokusai, in the style of Ukiyo"
- }
- }
- },
- {
- "prompt_en": "An astronaut flying in space, black and white",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "black and white"
- }
- }
- },
- {
- "prompt_en": "An astronaut flying in space, pixel art",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "pixel art"
- }
- }
- },
- {
- "prompt_en": "An astronaut flying in space, in cyberpunk style",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "in cyberpunk style"
- }
- }
- },
- {
- "prompt_en": "An astronaut flying in space, animated style",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "animated style"
- }
- }
- },
- {
- "prompt_en": "An astronaut flying in space, watercolor painting",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "watercolor painting"
- }
- }
- },
- {
- "prompt_en": "An astronaut flying in space, surrealism style",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "surrealism style"
- }
- }
- },
- {
- "prompt_en": "Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, Van Gogh style",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "Van Gogh style"
- }
- }
- },
- {
- "prompt_en": "Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, oil painting",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "oil painting"
- }
- }
- },
- {
- "prompt_en": "Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks by Hokusai, in the style of Ukiyo",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "by Hokusai, in the style of Ukiyo"
- }
- }
- },
- {
- "prompt_en": "Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, black and white",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "black and white"
- }
- }
- },
- {
- "prompt_en": "Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, pixel art",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "pixel art"
- }
- }
- },
- {
- "prompt_en": "Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, in cyberpunk style",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "in cyberpunk style"
- }
- }
- },
- {
- "prompt_en": "Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, animated style",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "animated style"
- }
- }
- },
- {
- "prompt_en": "Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, watercolor painting",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "watercolor painting"
- }
- }
- },
- {
- "prompt_en": "Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, surrealism style",
- "dimension": [
- "appearance_style"
- ],
- "auxiliary_info": {
- "appearance_style": {
- "appearance_style": "surrealism style"
- }
- }
- },
- {
- "prompt_en": "A beautiful coastal beach in spring, waves lapping on sand, in super slow motion",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A beautiful coastal beach in spring, waves lapping on sand, zoom in",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A beautiful coastal beach in spring, waves lapping on sand, zoom out",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A beautiful coastal beach in spring, waves lapping on sand, pan left",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A beautiful coastal beach in spring, waves lapping on sand, pan right",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A beautiful coastal beach in spring, waves lapping on sand, tilt up",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A beautiful coastal beach in spring, waves lapping on sand, tilt down",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A beautiful coastal beach in spring, waves lapping on sand, with an intense shaking effect",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A beautiful coastal beach in spring, waves lapping on sand, featuring a steady and smooth perspective",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A beautiful coastal beach in spring, waves lapping on sand, racking focus",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "The bund Shanghai, in super slow motion",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "The bund Shanghai, zoom in",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "The bund Shanghai, zoom out",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "The bund Shanghai, pan left",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "The bund Shanghai, pan right",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "The bund Shanghai, tilt up",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "The bund Shanghai, tilt down",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "The bund Shanghai, with an intense shaking effect",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "The bund Shanghai, featuring a steady and smooth perspective",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "The bund Shanghai, racking focus",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "a shark is swimming in the ocean, in super slow motion",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "a shark is swimming in the ocean, zoom in",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "a shark is swimming in the ocean, zoom out",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "a shark is swimming in the ocean, pan left",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "a shark is swimming in the ocean, pan right",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "a shark is swimming in the ocean, tilt up",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "a shark is swimming in the ocean, tilt down",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "a shark is swimming in the ocean, with an intense shaking effect",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "a shark is swimming in the ocean, featuring a steady and smooth perspective",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "a shark is swimming in the ocean, racking focus",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A panda drinking coffee in a cafe in Paris, in super slow motion",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A panda drinking coffee in a cafe in Paris, zoom in",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A panda drinking coffee in a cafe in Paris, zoom out",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A panda drinking coffee in a cafe in Paris, pan left",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A panda drinking coffee in a cafe in Paris, pan right",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A panda drinking coffee in a cafe in Paris, tilt up",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A panda drinking coffee in a cafe in Paris, tilt down",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A panda drinking coffee in a cafe in Paris, with an intense shaking effect",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A panda drinking coffee in a cafe in Paris, featuring a steady and smooth perspective",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A panda drinking coffee in a cafe in Paris, racking focus",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A cute happy Corgi playing in park, sunset, in super slow motion",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A cute happy Corgi playing in park, sunset, zoom in",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A cute happy Corgi playing in park, sunset, zoom out",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A cute happy Corgi playing in park, sunset, pan left",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A cute happy Corgi playing in park, sunset, pan right",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A cute happy Corgi playing in park, sunset, tilt up",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A cute happy Corgi playing in park, sunset, tilt down",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A cute happy Corgi playing in park, sunset, with an intense shaking effect",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A cute happy Corgi playing in park, sunset, featuring a steady and smooth perspective",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A cute happy Corgi playing in park, sunset, racking focus",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "Gwen Stacy reading a book, in super slow motion",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "Gwen Stacy reading a book, zoom in",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "Gwen Stacy reading a book, zoom out",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "Gwen Stacy reading a book, pan left",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "Gwen Stacy reading a book, pan right",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "Gwen Stacy reading a book, tilt up",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "Gwen Stacy reading a book, tilt down",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "Gwen Stacy reading a book, with an intense shaking effect",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "Gwen Stacy reading a book, featuring a steady and smooth perspective",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "Gwen Stacy reading a book, racking focus",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A boat sailing leisurely along the Seine River with the Eiffel Tower in background, in super slow motion",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A boat sailing leisurely along the Seine River with the Eiffel Tower in background, zoom in",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A boat sailing leisurely along the Seine River with the Eiffel Tower in background, zoom out",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A boat sailing leisurely along the Seine River with the Eiffel Tower in background, pan left",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A boat sailing leisurely along the Seine River with the Eiffel Tower in background, pan right",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A boat sailing leisurely along the Seine River with the Eiffel Tower in background, tilt up",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A boat sailing leisurely along the Seine River with the Eiffel Tower in background, tilt down",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A boat sailing leisurely along the Seine River with the Eiffel Tower in background, with an intense shaking effect",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A boat sailing leisurely along the Seine River with the Eiffel Tower in background, featuring a steady and smooth perspective",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A boat sailing leisurely along the Seine River with the Eiffel Tower in background, racking focus",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, in super slow motion",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, zoom in",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, zoom out",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, pan left",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, pan right",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, tilt up",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, tilt down",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, with an intense shaking effect",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, featuring a steady and smooth perspective",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, racking focus",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "An astronaut flying in space, in super slow motion",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "An astronaut flying in space, zoom in",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "An astronaut flying in space, zoom out",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "An astronaut flying in space, pan left",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "An astronaut flying in space, pan right",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "An astronaut flying in space, tilt up",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "An astronaut flying in space, tilt down",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "An astronaut flying in space, with an intense shaking effect",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "An astronaut flying in space, featuring a steady and smooth perspective",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "An astronaut flying in space, racking focus",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, in super slow motion",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, zoom in",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, zoom out",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, pan left",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, pan right",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, tilt up",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, tilt down",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, with an intense shaking effect",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, featuring a steady and smooth perspective",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, racking focus",
- "dimension": [
- "temporal_style"
- ]
- },
- {
- "prompt_en": "Close up of grapes on a rotating table.",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "Turtle swimming in ocean.",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "A storm trooper vacuuming the beach.",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "A panda standing on a surfboard in the ocean in sunset.",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "An astronaut feeding ducks on a sunny afternoon, reflection from the water.",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "Two pandas discussing an academic paper.",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "Sunset time lapse at the beach with moving clouds and colors in the sky.",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "A fat rabbit wearing a purple robe walking through a fantasy landscape.",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "A koala bear playing piano in the forest.",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "An astronaut flying in space.",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "Fireworks.",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "An animated painting of fluffy white clouds moving in sky.",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "Flying through fantasy landscapes.",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "A bigfoot walking in the snowstorm.",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "A squirrel eating a burger.",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "A cat wearing sunglasses and working as a lifeguard at a pool.",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks.",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "Splash of turquoise water in extreme slow motion, alpha channel included.",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "an ice cream is melting on the table.",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "a drone flying over a snowy forest.",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "a shark is swimming in the ocean.",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "Aerial panoramic video from a drone of a fantasy land.",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "a teddy bear is swimming in the ocean.",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "time lapse of sunrise on mars.",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "golden fish swimming in the ocean.",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "An artist brush painting on a canvas close up.",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "A drone view of celebration with Christmas tree and fireworks, starry sky - background.",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "happy dog wearing a yellow turtleneck, studio, portrait, facing camera, dark background",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "Origami dancers in white paper, 3D render, on white background, studio shot, dancing modern dance.",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "Campfire at night in a snowy forest with starry sky in the background.",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "a fantasy landscape",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "A 3D model of a 1800s victorian house.",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "this is how I do makeup in the morning.",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "A raccoon that looks like a turtle, digital art.",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "Robot dancing in Times Square.",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "Busy freeway at night.",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "Balloon full of water exploding in extreme slow motion.",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "An astronaut is riding a horse in the space in a photorealistic style.",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "Macro slo-mo. Slow motion cropped closeup of roasted coffee beans falling into an empty bowl.",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "Sewing machine, old sewing machine working.",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "Motion colour drop in water, ink swirling in water, colourful ink in water, abstraction fancy dream cloud of ink.",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "Few big purple plums rotating on the turntable. water drops appear on the skin during rotation. isolated on the white background. close-up. macro.",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "Vampire makeup face of beautiful girl, red contact lenses.",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "Ashtray full of butts on table, smoke flowing on black background, close-up",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "Pacific coast, carmel by the sea ocean and waves.",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "A teddy bear is playing drum kit in NYC Times Square.",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "A corgi is playing drum kit.",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "An Iron man is playing the electronic guitar, high electronic guitar.",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "A raccoon is playing the electronic guitar.",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "A boat sailing leisurely along the Seine River with the Eiffel Tower in background by Vincent van Gogh",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "A corgi's head depicted as an explosion of a nebula",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "A fantasy landscape",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "A future where humans have achieved teleportation technology",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "A jellyfish floating through the ocean, with bioluminescent tentacles",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "A Mars rover moving on Mars",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "A panda drinking coffee in a cafe in Paris",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "A space shuttle launching into orbit, with flames and smoke billowing out from the engines",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "A steam train moving on a mountainside",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "A super cool giant robot in Cyberpunk Beijing",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "A tropical beach at sunrise, with palm trees and crystal-clear water in the foreground",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "Cinematic shot of Van Gogh's selfie, Van Gogh style",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "Gwen Stacy reading a book",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "Iron Man flying in the sky",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "The bund Shanghai, oil painting",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "Yoda playing guitar on the stage",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "A beautiful coastal beach in spring, waves lapping on sand by Hokusai, in the style of Ukiyo",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "A beautiful coastal beach in spring, waves lapping on sand by Vincent van Gogh",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "A boat sailing leisurely along the Seine River with the Eiffel Tower in background",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "A car moving slowly on an empty street, rainy evening",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "A cat eating food out of a bowl",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "A cat wearing sunglasses at a pool",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "A confused panda in calculus class",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "A cute fluffy panda eating Chinese food in a restaurant",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "A cute happy Corgi playing in park, sunset",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "A cute raccoon playing guitar in a boat on the ocean",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "A happy fuzzy panda playing guitar nearby a campfire, snow mountain in the background",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "A lightning striking atop of eiffel tower, dark clouds in the sky",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "A modern art museum, with colorful paintings",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "A panda cooking in the kitchen",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "A panda playing on a swing set",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "A polar bear is playing guitar",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "A raccoon dressed in suit playing the trumpet, stage background",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "A robot DJ is playing the turntable, in heavy raining futuristic tokyo rooftop cyberpunk night, sci-fi, fantasy",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "A shark swimming in clear Caribbean ocean",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "A super robot protecting city",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "A teddy bear washing the dishes",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "An epic tornado attacking above a glowing city at night, the tornado is made of smoke",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "An oil painting of a couple in formal evening wear going home get caught in a heavy downpour with umbrellas",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "Clown fish swimming through the coral reef",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "Hyper-realistic spaceship landing on Mars",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "The bund Shanghai, vibrant color",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "Vincent van Gogh is painting in the room",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "Yellow flowers swing in the wind",
- "dimension": [
- "overall_consistency",
- "aesthetic_quality",
- "imaging_quality"
- ]
- },
- {
- "prompt_en": "alley",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "alley"
- }
- }
- }
- },
- {
- "prompt_en": "amusement park",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "amusement park"
- }
- }
- }
- },
- {
- "prompt_en": "aquarium",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "aquarium"
- }
- }
- }
- },
- {
- "prompt_en": "arch",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "arch"
- }
- }
- }
- },
- {
- "prompt_en": "art gallery",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "art gallery"
- }
- }
- }
- },
- {
- "prompt_en": "bathroom",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "bathroom"
- }
- }
- }
- },
- {
- "prompt_en": "bakery shop",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "bakery shop"
- }
- }
- }
- },
- {
- "prompt_en": "ballroom",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "ballroom"
- }
- }
- }
- },
- {
- "prompt_en": "bar",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "bar"
- }
- }
- }
- },
- {
- "prompt_en": "barn",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "barn"
- }
- }
- }
- },
- {
- "prompt_en": "basement",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "basement"
- }
- }
- }
- },
- {
- "prompt_en": "beach",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "beach"
- }
- }
- }
- },
- {
- "prompt_en": "bedroom",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "bedroom"
- }
- }
- }
- },
- {
- "prompt_en": "bridge",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "bridge"
- }
- }
- }
- },
- {
- "prompt_en": "botanical garden",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "botanical garden"
- }
- }
- }
- },
- {
- "prompt_en": "cafeteria",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "cafeteria"
- }
- }
- }
- },
- {
- "prompt_en": "campsite",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "campsite"
- }
- }
- }
- },
- {
- "prompt_en": "campus",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "campus"
- }
- }
- }
- },
- {
- "prompt_en": "carrousel",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "carrousel"
- }
- }
- }
- },
- {
- "prompt_en": "castle",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "castle"
- }
- }
- }
- },
- {
- "prompt_en": "cemetery",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "cemetery"
- }
- }
- }
- },
- {
- "prompt_en": "classroom",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "classroom"
- }
- }
- }
- },
- {
- "prompt_en": "cliff",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "cliff"
- }
- }
- }
- },
- {
- "prompt_en": "crosswalk",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "crosswalk"
- }
- }
- }
- },
- {
- "prompt_en": "construction site",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "construction site"
- }
- }
- }
- },
- {
- "prompt_en": "corridor",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "corridor"
- }
- }
- }
- },
- {
- "prompt_en": "courtyard",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "courtyard"
- }
- }
- }
- },
- {
- "prompt_en": "desert",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "desert"
- }
- }
- }
- },
- {
- "prompt_en": "downtown",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "downtown"
- }
- }
- }
- },
- {
- "prompt_en": "driveway",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "driveway"
- }
- }
- }
- },
- {
- "prompt_en": "farm",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "farm"
- }
- }
- }
- },
- {
- "prompt_en": "food court",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "food court"
- }
- }
- }
- },
- {
- "prompt_en": "football field",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "football field"
- }
- }
- }
- },
- {
- "prompt_en": "forest road",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "forest road"
- }
- }
- }
- },
- {
- "prompt_en": "fountain",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "fountain"
- }
- }
- }
- },
- {
- "prompt_en": "gas station",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "gas station"
- }
- }
- }
- },
- {
- "prompt_en": "glacier",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "glacier"
- }
- }
- }
- },
- {
- "prompt_en": "golf course",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "golf course"
- }
- }
- }
- },
- {
- "prompt_en": "indoor gymnasium",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "indoor gymnasium"
- }
- }
- }
- },
- {
- "prompt_en": "harbor",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "harbor"
- }
- }
- }
- },
- {
- "prompt_en": "highway",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "highway"
- }
- }
- }
- },
- {
- "prompt_en": "hospital",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "hospital"
- }
- }
- }
- },
- {
- "prompt_en": "house",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "house"
- }
- }
- }
- },
- {
- "prompt_en": "iceberg",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "iceberg"
- }
- }
- }
- },
- {
- "prompt_en": "industrial area",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "industrial area"
- }
- }
- }
- },
- {
- "prompt_en": "jail cell",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "jail cell"
- }
- }
- }
- },
- {
- "prompt_en": "junkyard",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "junkyard"
- }
- }
- }
- },
- {
- "prompt_en": "kitchen",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "kitchen"
- }
- }
- }
- },
- {
- "prompt_en": "indoor library",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "indoor library"
- }
- }
- }
- },
- {
- "prompt_en": "lighthouse",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "lighthouse"
- }
- }
- }
- },
- {
- "prompt_en": "laboratory",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "laboratory"
- }
- }
- }
- },
- {
- "prompt_en": "mansion",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "mansion"
- }
- }
- }
- },
- {
- "prompt_en": "marsh",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "marsh"
- }
- }
- }
- },
- {
- "prompt_en": "mountain",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "mountain"
- }
- }
- }
- },
- {
- "prompt_en": "indoor movie theater",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "indoor movie theater"
- }
- }
- }
- },
- {
- "prompt_en": "indoor museum",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "indoor museum"
- }
- }
- }
- },
- {
- "prompt_en": "music studio",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "music studio"
- }
- }
- }
- },
- {
- "prompt_en": "nursery",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "nursery"
- }
- }
- }
- },
- {
- "prompt_en": "ocean",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "ocean"
- }
- }
- }
- },
- {
- "prompt_en": "office",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "office"
- }
- }
- }
- },
- {
- "prompt_en": "palace",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "palace"
- }
- }
- }
- },
- {
- "prompt_en": "parking lot",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "parking lot"
- }
- }
- }
- },
- {
- "prompt_en": "pharmacy",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "pharmacy"
- }
- }
- }
- },
- {
- "prompt_en": "phone booth",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "phone booth"
- }
- }
- }
- },
- {
- "prompt_en": "raceway",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "raceway"
- }
- }
- }
- },
- {
- "prompt_en": "restaurant",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "restaurant"
- }
- }
- }
- },
- {
- "prompt_en": "river",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "river"
- }
- }
- }
- },
- {
- "prompt_en": "science museum",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "science museum"
- }
- }
- }
- },
- {
- "prompt_en": "shower",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "shower"
- }
- }
- }
- },
- {
- "prompt_en": "ski slope",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "ski slope"
- }
- }
- }
- },
- {
- "prompt_en": "sky",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "sky"
- }
- }
- }
- },
- {
- "prompt_en": "skyscraper",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "skyscraper"
- }
- }
- }
- },
- {
- "prompt_en": "baseball stadium",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "baseball stadium"
- }
- }
- }
- },
- {
- "prompt_en": "staircase",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "staircase"
- }
- }
- }
- },
- {
- "prompt_en": "street",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "street"
- }
- }
- }
- },
- {
- "prompt_en": "supermarket",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "supermarket"
- }
- }
- }
- },
- {
- "prompt_en": "indoor swimming pool",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "indoor swimming pool"
- }
- }
- }
- },
- {
- "prompt_en": "tower",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "tower"
- }
- }
- }
- },
- {
- "prompt_en": "outdoor track",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "outdoor track"
- }
- }
- }
- },
- {
- "prompt_en": "train railway",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "train railway"
- }
- }
- }
- },
- {
- "prompt_en": "train station platform",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "train station platform"
- }
- }
- }
- },
- {
- "prompt_en": "underwater coral reef",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "underwater coral reef"
- }
- }
- }
- },
- {
- "prompt_en": "valley",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "valley"
- }
- }
- }
- },
- {
- "prompt_en": "volcano",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "volcano"
- }
- }
- }
- },
- {
- "prompt_en": "waterfall",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "waterfall"
- }
- }
- }
- },
- {
- "prompt_en": "windmill",
- "dimension": [
- "scene",
- "background_consistency"
- ],
- "auxiliary_info": {
- "scene": {
- "scene": {
- "scene": "windmill"
- }
- }
- }
- },
- {
- "prompt_en": "a bicycle on the left of a car, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "bicycle",
- "object_b": "car",
- "relationship": "on the left of"
- }
- }
- }
- },
- {
- "prompt_en": "a car on the right of a motorcycle, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "car",
- "object_b": "motorcycle",
- "relationship": "on the right of"
- }
- }
- }
- },
- {
- "prompt_en": "a motorcycle on the left of a bus, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "motorcycle",
- "object_b": "bus",
- "relationship": "on the left of"
- }
- }
- }
- },
- {
- "prompt_en": "a bus on the right of a traffic light, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "bus",
- "object_b": "traffic light",
- "relationship": "on the right of"
- }
- }
- }
- },
- {
- "prompt_en": "a traffic light on the left of a fire hydrant, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "traffic light",
- "object_b": "fire hydrant",
- "relationship": "on the left of"
- }
- }
- }
- },
- {
- "prompt_en": "a fire hydrant on the right of a stop sign, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "fire hydrant",
- "object_b": "stop sign",
- "relationship": "on the right of"
- }
- }
- }
- },
- {
- "prompt_en": "a stop sign on the left of a parking meter, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "stop sign",
- "object_b": "parking meter",
- "relationship": "on the left of"
- }
- }
- }
- },
- {
- "prompt_en": "a parking meter on the right of a bench, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "parking meter",
- "object_b": "bench",
- "relationship": "on the right of"
- }
- }
- }
- },
- {
- "prompt_en": "a bench on the left of a truck, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "bench",
- "object_b": "truck",
- "relationship": "on the left of"
- }
- }
- }
- },
- {
- "prompt_en": "a truck on the right of a bicycle, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "truck",
- "object_b": "bicycle",
- "relationship": "on the right of"
- }
- }
- }
- },
- {
- "prompt_en": "a bird on the left of a cat, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "bird",
- "object_b": "cat",
- "relationship": "on the left of"
- }
- }
- }
- },
- {
- "prompt_en": "a cat on the right of a dog, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "cat",
- "object_b": "dog",
- "relationship": "on the right of"
- }
- }
- }
- },
- {
- "prompt_en": "a dog on the left of a horse, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "dog",
- "object_b": "horse",
- "relationship": "on the left of"
- }
- }
- }
- },
- {
- "prompt_en": "a horse on the right of a sheep, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "horse",
- "object_b": "sheep",
- "relationship": "on the right of"
- }
- }
- }
- },
- {
- "prompt_en": "a sheep on the left of a cow, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "sheep",
- "object_b": "cow",
- "relationship": "on the left of"
- }
- }
- }
- },
- {
- "prompt_en": "a cow on the right of an elephant, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "cow",
- "object_b": "elephant",
- "relationship": "on the right of"
- }
- }
- }
- },
- {
- "prompt_en": "an elephant on the left of a bear, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "elephant",
- "object_b": "bear",
- "relationship": "on the left of"
- }
- }
- }
- },
- {
- "prompt_en": "a bear on the right of a zebra, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "bear",
- "object_b": "zebra",
- "relationship": "on the right of"
- }
- }
- }
- },
- {
- "prompt_en": "a zebra on the left of a giraffe, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "zebra",
- "object_b": "giraffe",
- "relationship": "on the left of"
- }
- }
- }
- },
- {
- "prompt_en": "a giraffe on the right of a bird, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "giraffe",
- "object_b": "bird",
- "relationship": "on the right of"
- }
- }
- }
- },
- {
- "prompt_en": "a bottle on the left of a wine glass, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "bottle",
- "object_b": "wine glass",
- "relationship": "on the left of"
- }
- }
- }
- },
- {
- "prompt_en": "a wine glass on the right of a cup, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "wine glass",
- "object_b": "cup",
- "relationship": "on the right of"
- }
- }
- }
- },
- {
- "prompt_en": "a cup on the left of a fork, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "cup",
- "object_b": "fork",
- "relationship": "on the left of"
- }
- }
- }
- },
- {
- "prompt_en": "a fork on the right of a knife, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "fork",
- "object_b": "knife",
- "relationship": "on the right of"
- }
- }
- }
- },
- {
- "prompt_en": "a knife on the left of a spoon, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "knife",
- "object_b": "spoon",
- "relationship": "on the left of"
- }
- }
- }
- },
- {
- "prompt_en": "a spoon on the right of a bowl, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "spoon",
- "object_b": "bowl",
- "relationship": "on the right of"
- }
- }
- }
- },
- {
- "prompt_en": "a bowl on the left of a bottle, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "bowl",
- "object_b": "bottle",
- "relationship": "on the left of"
- }
- }
- }
- },
- {
- "prompt_en": "a potted plant on the left of a remote, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "potted plant",
- "object_b": "remote",
- "relationship": "on the left of"
- }
- }
- }
- },
- {
- "prompt_en": "a remote on the right of a clock, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "remote",
- "object_b": "clock",
- "relationship": "on the right of"
- }
- }
- }
- },
- {
- "prompt_en": "a clock on the left of a vase, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "clock",
- "object_b": "vase",
- "relationship": "on the left of"
- }
- }
- }
- },
- {
- "prompt_en": "a vase on the right of scissors, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "vase",
- "object_b": "scissors",
- "relationship": "on the right of"
- }
- }
- }
- },
- {
- "prompt_en": "scissors on the left of a teddy bear, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "scissors",
- "object_b": "teddy bear",
- "relationship": "on the left of"
- }
- }
- }
- },
- {
- "prompt_en": "a teddy bear on the right of a potted plant, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "teddy bear",
- "object_b": "potted plant",
- "relationship": "on the right of"
- }
- }
- }
- },
- {
- "prompt_en": "a frisbee on the left of a sports ball, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "frisbee",
- "object_b": "sports ball",
- "relationship": "on the left of"
- }
- }
- }
- },
- {
- "prompt_en": "a sports ball on the right of a baseball bat, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "sports ball",
- "object_b": "baseball bat",
- "relationship": "on the right of"
- }
- }
- }
- },
- {
- "prompt_en": "a baseball bat on the left of a baseball glove, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "baseball bat",
- "object_b": "baseball glove",
- "relationship": "on the left of"
- }
- }
- }
- },
- {
- "prompt_en": "a baseball glove on the right of a tennis racket, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "baseball glove",
- "object_b": "tennis racket",
- "relationship": "on the right of"
- }
- }
- }
- },
- {
- "prompt_en": "a tennis racket on the left of a frisbee, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "tennis racket",
- "object_b": "frisbee",
- "relationship": "on the left of"
- }
- }
- }
- },
- {
- "prompt_en": "a toilet on the left of a hair drier, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "toilet",
- "object_b": "hair drier",
- "relationship": "on the left of"
- }
- }
- }
- },
- {
- "prompt_en": "a hair drier on the right of a toothbrush, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "hair drier",
- "object_b": "toothbrush",
- "relationship": "on the right of"
- }
- }
- }
- },
- {
- "prompt_en": "a toothbrush on the left of a sink, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "toothbrush",
- "object_b": "sink",
- "relationship": "on the left of"
- }
- }
- }
- },
- {
- "prompt_en": "a sink on the right of a toilet, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "sink",
- "object_b": "toilet",
- "relationship": "on the right of"
- }
- }
- }
- },
- {
- "prompt_en": "a chair on the left of a couch, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "chair",
- "object_b": "couch",
- "relationship": "on the left of"
- }
- }
- }
- },
- {
- "prompt_en": "a couch on the right of a bed, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "couch",
- "object_b": "bed",
- "relationship": "on the right of"
- }
- }
- }
- },
- {
- "prompt_en": "a bed on the left of a tv, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "bed",
- "object_b": "tv",
- "relationship": "on the left of"
- }
- }
- }
- },
- {
- "prompt_en": "a tv on the right of a dining table, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "tv",
- "object_b": "dining table",
- "relationship": "on the right of"
- }
- }
- }
- },
- {
- "prompt_en": "a dining table on the left of a chair, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "dining table",
- "object_b": "chair",
- "relationship": "on the left of"
- }
- }
- }
- },
- {
- "prompt_en": "an airplane on the left of a train, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "airplane",
- "object_b": "train",
- "relationship": "on the left of"
- }
- }
- }
- },
- {
- "prompt_en": "a train on the right of a boat, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "train",
- "object_b": "boat",
- "relationship": "on the right of"
- }
- }
- }
- },
- {
- "prompt_en": "a boat on the left of an airplane, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "boat",
- "object_b": "airplane",
- "relationship": "on the left of"
- }
- }
- }
- },
- {
- "prompt_en": "an oven on the top of a toaster, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "oven",
- "object_b": "toaster",
- "relationship": "on the top of"
- }
- }
- }
- },
- {
- "prompt_en": "an oven on the bottom of a toaster, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "oven",
- "object_b": "toaster",
- "relationship": "on the bottom of"
- }
- }
- }
- },
- {
- "prompt_en": "a toaster on the top of a microwave, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "toaster",
- "object_b": "microwave",
- "relationship": "on the top of"
- }
- }
- }
- },
- {
- "prompt_en": "a toaster on the bottom of a microwave, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "toaster",
- "object_b": "microwave",
- "relationship": "on the bottom of"
- }
- }
- }
- },
- {
- "prompt_en": "a microwave on the top of an oven, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "microwave",
- "object_b": "oven",
- "relationship": "on the top of"
- }
- }
- }
- },
- {
- "prompt_en": "a microwave on the bottom of an oven, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "microwave",
- "object_b": "oven",
- "relationship": "on the bottom of"
- }
- }
- }
- },
- {
- "prompt_en": "a banana on the top of an apple, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "banana",
- "object_b": "apple",
- "relationship": "on the top of"
- }
- }
- }
- },
- {
- "prompt_en": "a banana on the bottom of an apple, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "banana",
- "object_b": "apple",
- "relationship": "on the bottom of"
- }
- }
- }
- },
- {
- "prompt_en": "an apple on the top of a sandwich, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "apple",
- "object_b": "sandwich",
- "relationship": "on the top of"
- }
- }
- }
- },
- {
- "prompt_en": "an apple on the bottom of a sandwich, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "apple",
- "object_b": "sandwich",
- "relationship": "on the bottom of"
- }
- }
- }
- },
- {
- "prompt_en": "a sandwich on the top of an orange, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "sandwich",
- "object_b": "orange",
- "relationship": "on the top of"
- }
- }
- }
- },
- {
- "prompt_en": "a sandwich on the bottom of an orange, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "sandwich",
- "object_b": "orange",
- "relationship": "on the bottom of"
- }
- }
- }
- },
- {
- "prompt_en": "an orange on the top of a carrot, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "orange",
- "object_b": "carrot",
- "relationship": "on the top of"
- }
- }
- }
- },
- {
- "prompt_en": "an orange on the bottom of a carrot, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "orange",
- "object_b": "carrot",
- "relationship": "on the bottom of"
- }
- }
- }
- },
- {
- "prompt_en": "a carrot on the top of a hot dog, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "carrot",
- "object_b": "hot dog",
- "relationship": "on the top of"
- }
- }
- }
- },
- {
- "prompt_en": "a carrot on the bottom of a hot dog, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "carrot",
- "object_b": "hot dog",
- "relationship": "on the bottom of"
- }
- }
- }
- },
- {
- "prompt_en": "a hot dog on the top of a pizza, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "hot dog",
- "object_b": "pizza",
- "relationship": "on the top of"
- }
- }
- }
- },
- {
- "prompt_en": "a hot dog on the bottom of a pizza, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "hot dog",
- "object_b": "pizza",
- "relationship": "on the bottom of"
- }
- }
- }
- },
- {
- "prompt_en": "a pizza on the top of a donut, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "pizza",
- "object_b": "donut",
- "relationship": "on the top of"
- }
- }
- }
- },
- {
- "prompt_en": "a pizza on the bottom of a donut, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "pizza",
- "object_b": "donut",
- "relationship": "on the bottom of"
- }
- }
- }
- },
- {
- "prompt_en": "a donut on the top of broccoli, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "donut",
- "object_b": "broccoli",
- "relationship": "on the top of"
- }
- }
- }
- },
- {
- "prompt_en": "a donut on the bottom of broccoli, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "donut",
- "object_b": "broccoli",
- "relationship": "on the bottom of"
- }
- }
- }
- },
- {
- "prompt_en": "broccoli on the top of a banana, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "broccoli",
- "object_b": "banana",
- "relationship": "on the top of"
- }
- }
- }
- },
- {
- "prompt_en": "broccoli on the bottom of a banana, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "broccoli",
- "object_b": "banana",
- "relationship": "on the bottom of"
- }
- }
- }
- },
- {
- "prompt_en": "skis on the top of a snowboard, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "skis",
- "object_b": "snowboard",
- "relationship": "on the top of"
- }
- }
- }
- },
- {
- "prompt_en": "skis on the bottom of a snowboard, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "skis",
- "object_b": "snowboard",
- "relationship": "on the bottom of"
- }
- }
- }
- },
- {
- "prompt_en": "a snowboard on the top of a kite, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "snowboard",
- "object_b": "kite",
- "relationship": "on the top of"
- }
- }
- }
- },
- {
- "prompt_en": "a snowboard on the bottom of a kite, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "snowboard",
- "object_b": "kite",
- "relationship": "on the bottom of"
- }
- }
- }
- },
- {
- "prompt_en": "a kite on the top of a skateboard, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "kite",
- "object_b": "skateboard",
- "relationship": "on the top of"
- }
- }
- }
- },
- {
- "prompt_en": "a kite on the bottom of a skateboard, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "kite",
- "object_b": "skateboard",
- "relationship": "on the bottom of"
- }
- }
- }
- },
- {
- "prompt_en": "a skateboard on the top of a surfboard, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "skateboard",
- "object_b": "surfboard",
- "relationship": "on the top of"
- }
- }
- }
- },
- {
- "prompt_en": "a skateboard on the bottom of a surfboard, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "skateboard",
- "object_b": "surfboard",
- "relationship": "on the bottom of"
- }
- }
- }
- },
- {
- "prompt_en": "a surfboard on the top of skis, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "surfboard",
- "object_b": "skis",
- "relationship": "on the top of"
- }
- }
- }
- },
- {
- "prompt_en": "a surfboard on the bottom of skis, front view",
- "dimension": [
- "spatial_relationship"
- ],
- "auxiliary_info": {
- "spatial_relationship": {
- "spatial_relationship": {
- "object_a": "surfboard",
- "object_b": "skis",
- "relationship": "on the bottom of"
- }
- }
- }
- }
-]
diff --git a/eval/pab/vbench/cal_vbench.py b/eval/pab/vbench/cal_vbench.py
deleted file mode 100644
index ec1cbbab64e9977983ae8c3349df1d8e0f03bdb0..0000000000000000000000000000000000000000
--- a/eval/pab/vbench/cal_vbench.py
+++ /dev/null
@@ -1,154 +0,0 @@
-import argparse
-import json
-import os
-
-SEMANTIC_WEIGHT = 1
-QUALITY_WEIGHT = 4
-
-QUALITY_LIST = [
- "subject consistency",
- "background consistency",
- "temporal flickering",
- "motion smoothness",
- "aesthetic quality",
- "imaging quality",
- "dynamic degree",
-]
-
-SEMANTIC_LIST = [
- "object class",
- "multiple objects",
- "human action",
- "color",
- "spatial relationship",
- "scene",
- "appearance style",
- "temporal style",
- "overall consistency",
-]
-
-NORMALIZE_DIC = {
- "subject consistency": {"Min": 0.1462, "Max": 1.0},
- "background consistency": {"Min": 0.2615, "Max": 1.0},
- "temporal flickering": {"Min": 0.6293, "Max": 1.0},
- "motion smoothness": {"Min": 0.706, "Max": 0.9975},
- "dynamic degree": {"Min": 0.0, "Max": 1.0},
- "aesthetic quality": {"Min": 0.0, "Max": 1.0},
- "imaging quality": {"Min": 0.0, "Max": 1.0},
- "object class": {"Min": 0.0, "Max": 1.0},
- "multiple objects": {"Min": 0.0, "Max": 1.0},
- "human action": {"Min": 0.0, "Max": 1.0},
- "color": {"Min": 0.0, "Max": 1.0},
- "spatial relationship": {"Min": 0.0, "Max": 1.0},
- "scene": {"Min": 0.0, "Max": 0.8222},
- "appearance style": {"Min": 0.0009, "Max": 0.2855},
- "temporal style": {"Min": 0.0, "Max": 0.364},
- "overall consistency": {"Min": 0.0, "Max": 0.364},
-}
-
-DIM_WEIGHT = {
- "subject consistency": 1,
- "background consistency": 1,
- "temporal flickering": 1,
- "motion smoothness": 1,
- "aesthetic quality": 1,
- "imaging quality": 1,
- "dynamic degree": 0.5,
- "object class": 1,
- "multiple objects": 1,
- "human action": 1,
- "color": 1,
- "spatial relationship": 1,
- "scene": 1,
- "appearance style": 1,
- "temporal style": 1,
- "overall consistency": 1,
-}
-
-ordered_scaled_res = [
- "total score",
- "quality score",
- "semantic score",
- "subject consistency",
- "background consistency",
- "temporal flickering",
- "motion smoothness",
- "dynamic degree",
- "aesthetic quality",
- "imaging quality",
- "object class",
- "multiple objects",
- "human action",
- "color",
- "spatial relationship",
- "scene",
- "appearance style",
- "temporal style",
- "overall consistency",
-]
-
-
-def parse_args():
- parser = argparse.ArgumentParser()
- parser.add_argument("--score_dir", required=True, type=str)
- args = parser.parse_args()
- return args
-
-
-if __name__ == "__main__":
- args = parse_args()
- res_postfix = "_eval_results.json"
- info_postfix = "_full_info.json"
- files = os.listdir(args.score_dir)
- res_files = [x for x in files if res_postfix in x]
- info_files = [x for x in files if info_postfix in x]
- assert len(res_files) == len(info_files), f"got {len(res_files)} res files, but {len(info_files)} info files"
-
- full_results = {}
- for res_file in res_files:
- # first check if results is normal
- info_file = res_file.split(res_postfix)[0] + info_postfix
- with open(os.path.join(args.score_dir, info_file), "r", encoding="utf-8") as f:
- info = json.load(f)
- assert len(info[0]["video_list"]) > 0, f"Error: {info_file} has 0 video list"
- # read results
- with open(os.path.join(args.score_dir, res_file), "r", encoding="utf-8") as f:
- data = json.load(f)
- for key, val in data.items():
- full_results[key] = format(val[0], ".4f")
-
- scaled_results = {}
- dims = set()
- for key, val in full_results.items():
- dim = key.replace("_", " ") if "_" in key else key
- scaled_score = (float(val) - NORMALIZE_DIC[dim]["Min"]) / (
- NORMALIZE_DIC[dim]["Max"] - NORMALIZE_DIC[dim]["Min"]
- )
- scaled_score *= DIM_WEIGHT[dim]
- scaled_results[dim] = scaled_score
- dims.add(dim)
-
- assert len(dims) == len(NORMALIZE_DIC), f"{set(NORMALIZE_DIC.keys())-dims} not calculated yet"
-
- quality_score = sum([scaled_results[i] for i in QUALITY_LIST]) / sum([DIM_WEIGHT[i] for i in QUALITY_LIST])
- semantic_score = sum([scaled_results[i] for i in SEMANTIC_LIST]) / sum([DIM_WEIGHT[i] for i in SEMANTIC_LIST])
- scaled_results["quality score"] = quality_score
- scaled_results["semantic score"] = semantic_score
- scaled_results["total score"] = (quality_score * QUALITY_WEIGHT + semantic_score * SEMANTIC_WEIGHT) / (
- QUALITY_WEIGHT + SEMANTIC_WEIGHT
- )
-
- formated_scaled_results = {"items": []}
- for key in ordered_scaled_res:
- formated_score = format(scaled_results[key] * 100, ".2f") + "%"
- formated_scaled_results["items"].append({key: formated_score})
-
- output_file_path = os.path.join(args.score_dir, "all_results.json")
- with open(output_file_path, "w") as outfile:
- json.dump(full_results, outfile, indent=4, sort_keys=True)
- print(f"results saved to: {output_file_path}")
-
- scaled_file_path = os.path.join(args.score_dir, "scaled_results.json")
- with open(scaled_file_path, "w") as outfile:
- json.dump(formated_scaled_results, outfile, indent=4, sort_keys=True)
- print(f"results saved to: {scaled_file_path}")
diff --git a/eval/pab/vbench/run_vbench.py b/eval/pab/vbench/run_vbench.py
deleted file mode 100644
index 32df0825502614fc3b2a1f7f56e3b2082ccb207c..0000000000000000000000000000000000000000
--- a/eval/pab/vbench/run_vbench.py
+++ /dev/null
@@ -1,52 +0,0 @@
-import argparse
-
-import torch
-from vbench import VBench
-
-full_info_path = "./vbench/VBench_full_info.json"
-
-dimensions = [
- "subject_consistency",
- "imaging_quality",
- "background_consistency",
- "motion_smoothness",
- "overall_consistency",
- "human_action",
- "multiple_objects",
- "spatial_relationship",
- "object_class",
- "color",
- "aesthetic_quality",
- "appearance_style",
- "temporal_flickering",
- "scene",
- "temporal_style",
- "dynamic_degree",
-]
-
-
-def parse_args():
- parser = argparse.ArgumentParser()
- parser.add_argument("--video_path", required=True, type=str)
- args = parser.parse_args()
- return args
-
-
-if __name__ == "__main__":
- args = parse_args()
- save_path = args.video_path.replace("/samples/", "/vbench_out/")
-
- kwargs = {}
- kwargs["imaging_quality_preprocessing_mode"] = "longer" # use VBench/evaluate.py default
-
- for dimension in dimensions:
- my_VBench = VBench(torch.device("cuda"), full_info_path, save_path)
- my_VBench.evaluate(
- videos_path=args.video_path,
- name=dimension,
- local=False,
- read_frame=False,
- dimension_list=[dimension],
- mode="vbench_standard",
- **kwargs,
- )
diff --git a/examples/cogvideo/sample.py b/examples/cogvideo/sample.py
deleted file mode 100644
index e9a394c2882eaf9debabdd3184d5a29e651e04cc..0000000000000000000000000000000000000000
--- a/examples/cogvideo/sample.py
+++ /dev/null
@@ -1,14 +0,0 @@
-from videosys import CogVideoConfig, VideoSysEngine
-
-
-def run_base():
- config = CogVideoConfig(world_size=1)
- engine = VideoSysEngine(config)
-
- prompt = "Sunset over the sea."
- video = engine.generate(prompt).video[0]
- engine.save_video(video, f"./outputs/{prompt}.mp4")
-
-
-if __name__ == "__main__":
- run_base()
diff --git a/examples/latte/sample.py b/examples/latte/sample.py
deleted file mode 100644
index 45f421d831402611f457ec73fe8739162fbe113b..0000000000000000000000000000000000000000
--- a/examples/latte/sample.py
+++ /dev/null
@@ -1,24 +0,0 @@
-from videosys import LatteConfig, VideoSysEngine
-
-
-def run_base():
- config = LatteConfig(world_size=1)
- engine = VideoSysEngine(config)
-
- prompt = "Sunset over the sea."
- video = engine.generate(prompt).video[0]
- engine.save_video(video, f"./outputs/{prompt}.mp4")
-
-
-def run_pab():
- config = LatteConfig(world_size=1)
- engine = VideoSysEngine(config)
-
- prompt = "Sunset over the sea."
- video = engine.generate(prompt).video[0]
- engine.save_video(video, f"./outputs/{prompt}.mp4")
-
-
-if __name__ == "__main__":
- run_base()
- # run_pab()
diff --git a/examples/open_sora/sample.py b/examples/open_sora/sample.py
deleted file mode 100644
index 17a89f921d09ec5b71aaf98e210afc664aaa2385..0000000000000000000000000000000000000000
--- a/examples/open_sora/sample.py
+++ /dev/null
@@ -1,24 +0,0 @@
-from videosys import OpenSoraConfig, VideoSysEngine
-
-
-def run_base():
- config = OpenSoraConfig(world_size=1)
- engine = VideoSysEngine(config)
-
- prompt = "Sunset over the sea."
- video = engine.generate(prompt).video[0]
- engine.save_video(video, f"./outputs/{prompt}.mp4")
-
-
-def run_pab():
- config = OpenSoraConfig(world_size=1, enable_pab=True)
- engine = VideoSysEngine(config)
-
- prompt = "Sunset over the sea."
- video = engine.generate(prompt).video[0]
- engine.save_video(video, f"./outputs/{prompt}.mp4")
-
-
-if __name__ == "__main__":
- run_base()
- run_pab()
diff --git a/examples/open_sora_plan/sample.py b/examples/open_sora_plan/sample.py
deleted file mode 100644
index b3f3e9681a08906b28329db6a46c98c0a9ce2684..0000000000000000000000000000000000000000
--- a/examples/open_sora_plan/sample.py
+++ /dev/null
@@ -1,24 +0,0 @@
-from videosys import OpenSoraPlanConfig, VideoSysEngine
-
-
-def run_base():
- config = OpenSoraPlanConfig(world_size=1)
- engine = VideoSysEngine(config)
-
- prompt = "Sunset over the sea."
- video = engine.generate(prompt).video[0]
- engine.save_video(video, f"./outputs/{prompt}.mp4")
-
-
-def run_pab():
- config = OpenSoraPlanConfig(world_size=1)
- engine = VideoSysEngine(config)
-
- prompt = "Sunset over the sea."
- video = engine.generate(prompt).video[0]
- engine.save_video(video, f"./outputs/{prompt}.mp4")
-
-
-if __name__ == "__main__":
- run_base()
- # run_pab()
diff --git a/videosys/__init__.py b/videosys/__init__.py
index 6fd86b4acbb1c9e577d3d6b9298b2d9824695e3c..859fb7c37d8d5d081c3edc9173d18e265354f31c 100644
--- a/videosys/__init__.py
+++ b/videosys/__init__.py
@@ -1,19 +1,15 @@
from .core.engine import VideoSysEngine
from .core.parallel_mgr import initialize
-from .models.cogvideo.pipeline import CogVideoConfig, CogVideoPipeline
-from .models.latte.pipeline import LatteConfig, LattePipeline
-from .models.open_sora.pipeline import OpenSoraConfig, OpenSoraPipeline
-from .models.open_sora_plan.pipeline import OpenSoraPlanConfig, OpenSoraPlanPipeline
+from .pipelines.cogvideox import CogVideoXConfig, CogVideoXPABConfig, CogVideoXPipeline
+from .pipelines.latte import LatteConfig, LattePABConfig, LattePipeline
+from .pipelines.open_sora import OpenSoraConfig, OpenSoraPABConfig, OpenSoraPipeline
+from .pipelines.open_sora_plan import OpenSoraPlanConfig, OpenSoraPlanPABConfig, OpenSoraPlanPipeline
__all__ = [
"initialize",
"VideoSysEngine",
- "LattePipeline",
- "LatteConfig",
- "OpenSoraPlanPipeline",
- "OpenSoraPlanConfig",
- "OpenSoraPipeline",
- "OpenSoraConfig",
- "CogVideoConfig",
- "CogVideoPipeline",
-]
+ "LattePipeline", "LatteConfig", "LattePABConfig",
+ "OpenSoraPlanPipeline", "OpenSoraPlanConfig", "OpenSoraPlanPABConfig",
+ "OpenSoraPipeline", "OpenSoraConfig", "OpenSoraPABConfig",
+ "CogVideoXConfig", "CogVideoXPipeline", "CogVideoXPABConfig"
+] # fmt: skip
diff --git a/videosys/core/engine.py b/videosys/core/engine.py
index de0976159e51a9b74330e2d7b0879d54efaa6ece..6d4408ef7e1ebacbe1156a68d4450a0b045e6875 100644
--- a/videosys/core/engine.py
+++ b/videosys/core/engine.py
@@ -2,7 +2,6 @@ import os
from functools import partial
from typing import Any, Optional
-import imageio
import torch
import videosys
@@ -120,8 +119,7 @@ class VideoSysEngine:
result.get()
def save_video(self, video, output_path):
- os.makedirs(os.path.dirname(output_path), exist_ok=True)
- imageio.mimwrite(output_path, video, fps=24)
+ return self.driver_worker.save_video(video, output_path)
def shutdown(self):
if (worker_monitor := getattr(self, "worker_monitor", None)) is not None:
@@ -129,4 +127,4 @@ class VideoSysEngine:
torch.distributed.destroy_process_group()
def __del__(self):
- self.shutdown()
\ No newline at end of file
+ self.shutdown()
diff --git a/videosys/core/pab_mgr.py b/videosys/core/pab_mgr.py
index 56ce857c85efe6769795d4044bc4bad6c42f61c7..6bf6e69a71dce30ffc080879058d18b95419868e 100644
--- a/videosys/core/pab_mgr.py
+++ b/videosys/core/pab_mgr.py
@@ -1,8 +1,3 @@
-import random
-
-import numpy as np
-import torch
-
from videosys.utils.logging import logger
PAB_MANAGER = None
@@ -12,71 +7,56 @@ class PABConfig:
def __init__(
self,
steps: int,
- cross_broadcast: bool,
- cross_threshold: list,
- cross_gap: int,
- spatial_broadcast: bool,
- spatial_threshold: list,
- spatial_gap: int,
- temporal_broadcast: bool,
- temporal_threshold: list,
- temporal_gap: int,
- diffusion_skip: bool,
- diffusion_timestep_respacing: list,
- diffusion_skip_timestep: list,
- mlp_skip: bool,
- mlp_spatial_skip_config: dict,
- mlp_temporal_skip_config: dict,
- full_broadcast: bool = False,
- full_threshold: list = None,
- full_gap: int = 1,
+ cross_broadcast: bool = False,
+ cross_threshold: list = None,
+ cross_range: int = None,
+ spatial_broadcast: bool = False,
+ spatial_threshold: list = None,
+ spatial_range: int = None,
+ temporal_broadcast: bool = False,
+ temporal_threshold: list = None,
+ temporal_range: int = None,
+ mlp_broadcast: bool = False,
+ mlp_spatial_broadcast_config: dict = None,
+ mlp_temporal_broadcast_config: dict = None,
):
self.steps = steps
self.cross_broadcast = cross_broadcast
self.cross_threshold = cross_threshold
- self.cross_gap = cross_gap
+ self.cross_range = cross_range
self.spatial_broadcast = spatial_broadcast
self.spatial_threshold = spatial_threshold
- self.spatial_gap = spatial_gap
+ self.spatial_range = spatial_range
self.temporal_broadcast = temporal_broadcast
self.temporal_threshold = temporal_threshold
- self.temporal_gap = temporal_gap
-
- self.diffusion_skip = diffusion_skip
- self.diffusion_timestep_respacing = diffusion_timestep_respacing
- self.diffusion_skip_timestep = diffusion_skip_timestep
+ self.temporal_range = temporal_range
- self.mlp_skip = mlp_skip
- self.mlp_spatial_skip_config = mlp_spatial_skip_config
- self.mlp_temporal_skip_config = mlp_temporal_skip_config
-
- self.temporal_mlp_outputs = {}
- self.spatial_mlp_outputs = {}
-
- self.full_broadcast = full_broadcast
- self.full_threshold = full_threshold
- self.full_gap = full_gap
+ self.mlp_broadcast = mlp_broadcast
+ self.mlp_spatial_broadcast_config = mlp_spatial_broadcast_config
+ self.mlp_temporal_broadcast_config = mlp_temporal_broadcast_config
+ self.mlp_temporal_outputs = {}
+ self.mlp_spatial_outputs = {}
class PABManager:
def __init__(self, config: PABConfig):
self.config: PABConfig = config
- init_prompt = f"Init PABManager. steps: {config.steps}."
- init_prompt += f" spatial_broadcast: {config.spatial_broadcast}, spatial_threshold: {config.spatial_threshold}, spatial_gap: {config.spatial_gap}."
- init_prompt += f" temporal_broadcast: {config.temporal_broadcast}, temporal_threshold: {config.temporal_threshold}, temporal_gap: {config.temporal_gap}."
- init_prompt += f" cross_broadcast: {config.cross_broadcast}, cross_threshold: {config.cross_threshold}, cross_gap: {config.cross_gap}."
- init_prompt += f" full_broadcast: {config.full_broadcast}, full_threshold: {config.full_threshold}, full_gap: {config.full_gap}."
+ init_prompt = f"Init Pyramid Attention Broadcast. steps: {config.steps}."
+ init_prompt += f" spatial broadcast: {config.spatial_broadcast}, spatial range: {config.spatial_range}, spatial threshold: {config.spatial_threshold}."
+ init_prompt += f" temporal broadcast: {config.temporal_broadcast}, temporal range: {config.temporal_range}, temporal_threshold: {config.temporal_threshold}."
+ init_prompt += f" cross broadcast: {config.cross_broadcast}, cross range: {config.cross_range}, cross threshold: {config.cross_threshold}."
+ init_prompt += f" mlp broadcast: {config.mlp_broadcast}."
logger.info(init_prompt)
def if_broadcast_cross(self, timestep: int, count: int):
if (
self.config.cross_broadcast
and (timestep is not None)
- and (count % self.config.cross_gap != 0)
+ and (count % self.config.cross_range != 0)
and (self.config.cross_threshold[0] < timestep < self.config.cross_threshold[1])
):
flag = True
@@ -89,7 +69,7 @@ class PABManager:
if (
self.config.temporal_broadcast
and (timestep is not None)
- and (count % self.config.temporal_gap != 0)
+ and (count % self.config.temporal_range != 0)
and (self.config.temporal_threshold[0] < timestep < self.config.temporal_threshold[1])
):
flag = True
@@ -102,7 +82,7 @@ class PABManager:
if (
self.config.spatial_broadcast
and (timestep is not None)
- and (count % self.config.spatial_gap != 0)
+ and (count % self.config.spatial_range != 0)
and (self.config.spatial_threshold[0] < timestep < self.config.spatial_threshold[1])
):
flag = True
@@ -111,19 +91,6 @@ class PABManager:
count = (count + 1) % self.config.steps
return flag, count
- def if_broadcast_full(self, timestep: int, count: int, block_idx: int):
- if (
- self.config.full_broadcast
- and (timestep is not None)
- and (count % self.config.full_gap != 0)
- and (self.config.full_threshold[0] < timestep < self.config.full_threshold[1])
- ):
- flag = True
- else:
- flag = False
- count = (count + 1) % self.config.steps
- return flag, count
-
@staticmethod
def _is_t_in_skip_config(all_timesteps, timestep, config):
is_t_in_skip_config = False
@@ -139,18 +106,18 @@ class PABManager:
return is_t_in_skip_config, skip_range
def if_skip_mlp(self, timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
- if not self.config.mlp_skip:
+ if not self.config.mlp_broadcast:
return False, None, False, None
if is_temporal:
- cur_config = self.config.mlp_temporal_skip_config
+ cur_config = self.config.mlp_temporal_broadcast_config
else:
- cur_config = self.config.mlp_spatial_skip_config
+ cur_config = self.config.mlp_spatial_broadcast_config
is_t_in_skip_config, skip_range = self._is_t_in_skip_config(all_timesteps, timestep, cur_config)
next_flag = False
if (
- self.config.mlp_skip
+ self.config.mlp_broadcast
and (timestep is not None)
and (timestep in cur_config)
and (block_idx in cur_config[timestep]["block"])
@@ -159,7 +126,7 @@ class PABManager:
next_flag = True
count = count + 1
elif (
- self.config.mlp_skip
+ self.config.mlp_broadcast
and (timestep is not None)
and (is_t_in_skip_config)
and (block_idx in cur_config[skip_range[0]]["block"])
@@ -173,22 +140,22 @@ class PABManager:
def save_skip_output(self, timestep, block_idx, ff_output, is_temporal=False):
if is_temporal:
- self.config.temporal_mlp_outputs[(timestep, block_idx)] = ff_output
+ self.config.mlp_temporal_outputs[(timestep, block_idx)] = ff_output
else:
- self.config.spatial_mlp_outputs[(timestep, block_idx)] = ff_output
+ self.config.mlp_spatial_outputs[(timestep, block_idx)] = ff_output
def get_mlp_output(self, skip_range, timestep, block_idx, is_temporal=False):
skip_start_t = skip_range[0]
if is_temporal:
skip_output = (
- self.config.temporal_mlp_outputs.get((skip_start_t, block_idx), None)
- if self.config.temporal_mlp_outputs is not None
+ self.config.mlp_temporal_outputs.get((skip_start_t, block_idx), None)
+ if self.config.mlp_temporal_outputs is not None
else None
)
else:
skip_output = (
- self.config.spatial_mlp_outputs.get((skip_start_t, block_idx), None)
- if self.config.spatial_mlp_outputs is not None
+ self.config.mlp_spatial_outputs.get((skip_start_t, block_idx), None)
+ if self.config.mlp_spatial_outputs is not None
else None
)
@@ -196,9 +163,9 @@ class PABManager:
if timestep == skip_range[-1]:
# TODO: save memory
if is_temporal:
- del self.config.temporal_mlp_outputs[(skip_start_t, block_idx)]
+ del self.config.mlp_temporal_outputs[(skip_start_t, block_idx)]
else:
- del self.config.spatial_mlp_outputs[(skip_start_t, block_idx)]
+ del self.config.mlp_spatial_outputs[(skip_start_t, block_idx)]
else:
raise ValueError(
f"No stored MLP output found | t {timestep} |[{skip_range[0]}, {skip_range[-1]}] | block {block_idx}"
@@ -207,10 +174,10 @@ class PABManager:
return skip_output
def get_spatial_mlp_outputs(self):
- return self.config.spatial_mlp_outputs
+ return self.config.mlp_spatial_outputs
def get_temporal_mlp_outputs(self):
- return self.config.temporal_mlp_outputs
+ return self.config.mlp_temporal_outputs
def set_pab_manager(config: PABConfig):
@@ -250,11 +217,6 @@ def if_broadcast_spatial(timestep: int, count: int, block_idx: int):
return False, count
return PAB_MANAGER.if_broadcast_spatial(timestep, count, block_idx)
-def if_broadcast_full(timestep: int, count: int, block_idx: int):
- if not enable_pab():
- return False, count
- return PAB_MANAGER.if_broadcast_full(timestep, count, block_idx)
-
def if_broadcast_mlp(timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
if not enable_pab():
@@ -268,97 +230,3 @@ def save_mlp_output(timestep: int, block_idx: int, ff_output, is_temporal=False)
def get_mlp_output(skip_range, timestep, block_idx: int, is_temporal=False):
return PAB_MANAGER.get_mlp_output(skip_range, timestep, block_idx, is_temporal)
-
-
-def get_diffusion_skip():
- return enable_pab() and PAB_MANAGER.config.diffusion_skip
-
-
-def get_diffusion_timestep_respacing():
- return PAB_MANAGER.config.diffusion_timestep_respacing
-
-
-def get_diffusion_skip_timestep():
- return enable_pab() and PAB_MANAGER.config.diffusion_skip_timestep
-
-
-def space_timesteps(time_steps, time_bins):
- num_bins = len(time_bins)
- bin_size = time_steps // num_bins
-
- result = []
-
- for i, bin_count in enumerate(time_bins):
- start = i * bin_size
- end = start + bin_size
-
- bin_steps = np.linspace(start, end, bin_count, endpoint=False, dtype=int).tolist()
- result.extend(bin_steps)
-
- result_tensor = torch.tensor(result, dtype=torch.int32)
- sorted_tensor = torch.sort(result_tensor, descending=True).values
-
- return sorted_tensor
-
-
-def skip_diffusion_timestep(timesteps, diffusion_skip_timestep):
- if isinstance(timesteps, list):
- # If timesteps is a list, we assume each element is a tensor
- timesteps_np = [t.cpu().numpy() for t in timesteps]
- device = timesteps[0].device
- else:
- # If timesteps is a tensor
- timesteps_np = timesteps.cpu().numpy()
- device = timesteps.device
-
- num_bins = len(diffusion_skip_timestep)
-
- if isinstance(timesteps_np, list):
- bin_size = len(timesteps_np) // num_bins
- new_timesteps = []
-
- for i in range(num_bins):
- bin_start = i * bin_size
- bin_end = (i + 1) * bin_size if i != num_bins - 1 else len(timesteps_np)
- bin_timesteps = timesteps_np[bin_start:bin_end]
-
- if diffusion_skip_timestep[i] == 0:
- # If the bin is marked with 0, keep all timesteps
- new_timesteps.extend(bin_timesteps)
- elif diffusion_skip_timestep[i] == 1:
- # If the bin is marked with 1, omit the last timestep in the bin
- new_timesteps.extend(bin_timesteps[1:])
-
- new_timesteps_tensor = [torch.tensor(t, device=device) for t in new_timesteps]
- else:
- bin_size = len(timesteps_np) // num_bins
- new_timesteps = []
-
- for i in range(num_bins):
- bin_start = i * bin_size
- bin_end = (i + 1) * bin_size if i != num_bins - 1 else len(timesteps_np)
- bin_timesteps = timesteps_np[bin_start:bin_end]
-
- if diffusion_skip_timestep[i] == 0:
- # If the bin is marked with 0, keep all timesteps
- new_timesteps.extend(bin_timesteps)
- elif diffusion_skip_timestep[i] == 1:
- # If the bin is marked with 1, omit the last timestep in the bin
- new_timesteps.extend(bin_timesteps[1:])
- elif diffusion_skip_timestep[i] != 0:
- # If the bin is marked with a non-zero value, randomly omit n timesteps
- if len(bin_timesteps) > diffusion_skip_timestep[i]:
- indices_to_remove = set(random.sample(range(len(bin_timesteps)), diffusion_skip_timestep[i]))
- timesteps_to_keep = [
- timestep for idx, timestep in enumerate(bin_timesteps) if idx not in indices_to_remove
- ]
- else:
- timesteps_to_keep = bin_timesteps # 如果bin_timesteps的长度小于等于n,则不删除任何元素
- new_timesteps.extend(timesteps_to_keep)
-
- new_timesteps_tensor = torch.tensor(new_timesteps, device=device)
-
- if isinstance(timesteps, list):
- return new_timesteps_tensor
- else:
- return new_timesteps_tensor
diff --git a/videosys/datasets/dataloader.py b/videosys/datasets/dataloader.py
deleted file mode 100644
index 22a7be3dd188cb425910571510634ea697ab6550..0000000000000000000000000000000000000000
--- a/videosys/datasets/dataloader.py
+++ /dev/null
@@ -1,94 +0,0 @@
-import random
-from typing import Iterator, Optional
-
-import numpy as np
-import torch
-from torch.utils.data import DataLoader, Dataset, DistributedSampler
-from torch.utils.data.distributed import DistributedSampler
-
-from videosys.core.parallel_mgr import ParallelManager
-
-
-class StatefulDistributedSampler(DistributedSampler):
- def __init__(
- self,
- dataset: Dataset,
- num_replicas: Optional[int] = None,
- rank: Optional[int] = None,
- shuffle: bool = True,
- seed: int = 0,
- drop_last: bool = False,
- ) -> None:
- super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last)
- self.start_index: int = 0
-
- def __iter__(self) -> Iterator:
- iterator = super().__iter__()
- indices = list(iterator)
- indices = indices[self.start_index :]
- return iter(indices)
-
- def __len__(self) -> int:
- return self.num_samples - self.start_index
-
- def set_start_index(self, start_index: int) -> None:
- self.start_index = start_index
-
-
-def prepare_dataloader(
- dataset,
- batch_size,
- shuffle=False,
- seed=1024,
- drop_last=False,
- pin_memory=False,
- num_workers=0,
- pg_manager: Optional[ParallelManager] = None,
- **kwargs,
-):
- r"""
- Prepare a dataloader for distributed training. The dataloader will be wrapped by
- `torch.utils.data.DataLoader` and `StatefulDistributedSampler`.
-
-
- Args:
- dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
- shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
- seed (int, optional): Random worker seed for sampling, defaults to 1024.
- add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True.
- drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
- is not divisible by the batch size. If False and the size of dataset is not divisible by
- the batch size, then the last batch will be smaller, defaults to False.
- pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
- num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
- kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
- `DataLoader `_.
-
- Returns:
- :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
- """
- _kwargs = kwargs.copy()
- sampler = StatefulDistributedSampler(
- dataset,
- num_replicas=pg_manager.size(pg_manager.dp_axis),
- rank=pg_manager.coordinate(pg_manager.dp_axis),
- shuffle=shuffle,
- )
-
- # Deterministic dataloader
- def seed_worker(worker_id):
- worker_seed = seed
- np.random.seed(worker_seed)
- torch.manual_seed(worker_seed)
- random.seed(worker_seed)
-
- return DataLoader(
- dataset,
- batch_size=batch_size,
- sampler=sampler,
- worker_init_fn=seed_worker,
- drop_last=drop_last,
- pin_memory=pin_memory,
- num_workers=num_workers,
- **_kwargs,
- )
diff --git a/videosys/datasets/image_transform.py b/videosys/datasets/image_transform.py
deleted file mode 100644
index 7efa8bb45c5b51adf072c5bc5b710f7e2e272409..0000000000000000000000000000000000000000
--- a/videosys/datasets/image_transform.py
+++ /dev/null
@@ -1,42 +0,0 @@
-# Adapted from DiT
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-# --------------------------------------------------------
-# References:
-# DiT: https://github.com/facebookresearch/DiT
-# --------------------------------------------------------
-
-
-import numpy as np
-import torchvision.transforms as transforms
-from PIL import Image
-
-
-def center_crop_arr(pil_image, image_size):
- """
- Center cropping implementation from ADM.
- https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
- """
- while min(*pil_image.size) >= 2 * image_size:
- pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
-
- scale = image_size / min(*pil_image.size)
- pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
-
- arr = np.array(pil_image)
- crop_y = (arr.shape[0] - image_size) // 2
- crop_x = (arr.shape[1] - image_size) // 2
- return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size])
-
-
-def get_transforms_image(image_size=256):
- transform = transforms.Compose(
- [
- transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size)),
- transforms.RandomHorizontalFlip(),
- transforms.ToTensor(),
- transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
- ]
- )
- return transform
diff --git a/videosys/datasets/video_transform.py b/videosys/datasets/video_transform.py
deleted file mode 100644
index 36f0fb440026d078835a19f5389a86930e697010..0000000000000000000000000000000000000000
--- a/videosys/datasets/video_transform.py
+++ /dev/null
@@ -1,441 +0,0 @@
-# Adapted from OpenSora and Latte
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-# --------------------------------------------------------
-# References:
-# OpenSora: https://github.com/hpcaitech/Open-Sora
-# Latte: https://github.com/Vchitect/Latte
-# --------------------------------------------------------
-
-import numbers
-import random
-
-import numpy as np
-import torch
-from PIL import Image
-
-
-def _is_tensor_video_clip(clip):
- if not torch.is_tensor(clip):
- raise TypeError("clip should be Tensor. Got %s" % type(clip))
-
- if not clip.ndimension() == 4:
- raise ValueError("clip should be 4D. Got %dD" % clip.dim())
-
- return True
-
-
-def center_crop_arr(pil_image, image_size):
- """
- Center cropping implementation from ADM.
- https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
- """
- while min(*pil_image.size) >= 2 * image_size:
- pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
-
- scale = image_size / min(*pil_image.size)
- pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
-
- arr = np.array(pil_image)
- crop_y = (arr.shape[0] - image_size) // 2
- crop_x = (arr.shape[1] - image_size) // 2
- return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size])
-
-
-def crop(clip, i, j, h, w):
- """
- Args:
- clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
- """
- if len(clip.size()) != 4:
- raise ValueError("clip should be a 4D tensor")
- return clip[..., i : i + h, j : j + w]
-
-
-def resize(clip, target_size, interpolation_mode):
- if len(target_size) != 2:
- raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
- return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False)
-
-
-def resize_scale(clip, target_size, interpolation_mode):
- if len(target_size) != 2:
- raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
- H, W = clip.size(-2), clip.size(-1)
- scale_ = target_size[0] / min(H, W)
- return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)
-
-
-def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
- """
- Do spatial cropping and resizing to the video clip
- Args:
- clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
- i (int): i in (i,j) i.e coordinates of the upper left corner.
- j (int): j in (i,j) i.e coordinates of the upper left corner.
- h (int): Height of the cropped region.
- w (int): Width of the cropped region.
- size (tuple(int, int)): height and width of resized clip
- Returns:
- clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W)
- """
- if not _is_tensor_video_clip(clip):
- raise ValueError("clip should be a 4D torch.tensor")
- clip = crop(clip, i, j, h, w)
- clip = resize(clip, size, interpolation_mode)
- return clip
-
-
-def center_crop(clip, crop_size):
- if not _is_tensor_video_clip(clip):
- raise ValueError("clip should be a 4D torch.tensor")
- h, w = clip.size(-2), clip.size(-1)
- th, tw = crop_size
- if h < th or w < tw:
- raise ValueError("height and width must be no smaller than crop_size")
-
- i = int(round((h - th) / 2.0))
- j = int(round((w - tw) / 2.0))
- return crop(clip, i, j, th, tw)
-
-
-def center_crop_using_short_edge(clip):
- if not _is_tensor_video_clip(clip):
- raise ValueError("clip should be a 4D torch.tensor")
- h, w = clip.size(-2), clip.size(-1)
- if h < w:
- th, tw = h, h
- i = 0
- j = int(round((w - tw) / 2.0))
- else:
- th, tw = w, w
- i = int(round((h - th) / 2.0))
- j = 0
- return crop(clip, i, j, th, tw)
-
-
-def random_shift_crop(clip):
- """
- Slide along the long edge, with the short edge as crop size
- """
- if not _is_tensor_video_clip(clip):
- raise ValueError("clip should be a 4D torch.tensor")
- h, w = clip.size(-2), clip.size(-1)
-
- if h <= w:
- short_edge = h
- else:
- short_edge = w
-
- th, tw = short_edge, short_edge
-
- i = torch.randint(0, h - th + 1, size=(1,)).item()
- j = torch.randint(0, w - tw + 1, size=(1,)).item()
- return crop(clip, i, j, th, tw)
-
-
-def to_tensor(clip):
- """
- Convert tensor data type from uint8 to float, divide value by 255.0 and
- permute the dimensions of clip tensor
- Args:
- clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
- Return:
- clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
- """
- _is_tensor_video_clip(clip)
- if not clip.dtype == torch.uint8:
- raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
- # return clip.float().permute(3, 0, 1, 2) / 255.0
- return clip.float() / 255.0
-
-
-def normalize(clip, mean, std, inplace=False):
- """
- Args:
- clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
- mean (tuple): pixel RGB mean. Size is (3)
- std (tuple): pixel standard deviation. Size is (3)
- Returns:
- normalized clip (torch.tensor): Size is (T, C, H, W)
- """
- if not _is_tensor_video_clip(clip):
- raise ValueError("clip should be a 4D torch.tensor")
- if not inplace:
- clip = clip.clone()
- mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
- # print(mean)
- std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
- clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
- return clip
-
-
-def hflip(clip):
- """
- Args:
- clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
- Returns:
- flipped clip (torch.tensor): Size is (T, C, H, W)
- """
- if not _is_tensor_video_clip(clip):
- raise ValueError("clip should be a 4D torch.tensor")
- return clip.flip(-1)
-
-
-class RandomCropVideo:
- def __init__(self, size):
- if isinstance(size, numbers.Number):
- self.size = (int(size), int(size))
- else:
- self.size = size
-
- def __call__(self, clip):
- """
- Args:
- clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
- Returns:
- torch.tensor: randomly cropped video clip.
- size is (T, C, OH, OW)
- """
- i, j, h, w = self.get_params(clip)
- return crop(clip, i, j, h, w)
-
- def get_params(self, clip):
- h, w = clip.shape[-2:]
- th, tw = self.size
-
- if h < th or w < tw:
- raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")
-
- if w == tw and h == th:
- return 0, 0, h, w
-
- i = torch.randint(0, h - th + 1, size=(1,)).item()
- j = torch.randint(0, w - tw + 1, size=(1,)).item()
-
- return i, j, th, tw
-
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}(size={self.size})"
-
-
-class CenterCropResizeVideo:
- """
- First use the short side for cropping length,
- center crop video, then resize to the specified size
- """
-
- def __init__(
- self,
- size,
- interpolation_mode="bilinear",
- ):
- if isinstance(size, tuple):
- if len(size) != 2:
- raise ValueError(f"size should be tuple (height, width), instead got {size}")
- self.size = size
- else:
- self.size = (size, size)
-
- self.interpolation_mode = interpolation_mode
-
- def __call__(self, clip):
- """
- Args:
- clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
- Returns:
- torch.tensor: scale resized / center cropped video clip.
- size is (T, C, crop_size, crop_size)
- """
- clip_center_crop = center_crop_using_short_edge(clip)
- clip_center_crop_resize = resize(
- clip_center_crop, target_size=self.size, interpolation_mode=self.interpolation_mode
- )
- return clip_center_crop_resize
-
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
-
-
-class UCFCenterCropVideo:
- """
- First scale to the specified size in equal proportion to the short edge,
- then center cropping
- """
-
- def __init__(
- self,
- size,
- interpolation_mode="bilinear",
- ):
- if isinstance(size, tuple):
- if len(size) != 2:
- raise ValueError(f"size should be tuple (height, width), instead got {size}")
- self.size = size
- else:
- self.size = (size, size)
-
- self.interpolation_mode = interpolation_mode
-
- def __call__(self, clip):
- """
- Args:
- clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
- Returns:
- torch.tensor: scale resized / center cropped video clip.
- size is (T, C, crop_size, crop_size)
- """
- clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
- clip_center_crop = center_crop(clip_resize, self.size)
- return clip_center_crop
-
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
-
-
-class KineticsRandomCropResizeVideo:
- """
- Slide along the long edge, with the short edge as crop size. And resie to the desired size.
- """
-
- def __init__(
- self,
- size,
- interpolation_mode="bilinear",
- ):
- if isinstance(size, tuple):
- if len(size) != 2:
- raise ValueError(f"size should be tuple (height, width), instead got {size}")
- self.size = size
- else:
- self.size = (size, size)
-
- self.interpolation_mode = interpolation_mode
-
- def __call__(self, clip):
- clip_random_crop = random_shift_crop(clip)
- clip_resize = resize(clip_random_crop, self.size, self.interpolation_mode)
- return clip_resize
-
-
-class CenterCropVideo:
- def __init__(
- self,
- size,
- interpolation_mode="bilinear",
- ):
- if isinstance(size, tuple):
- if len(size) != 2:
- raise ValueError(f"size should be tuple (height, width), instead got {size}")
- self.size = size
- else:
- self.size = (size, size)
-
- self.interpolation_mode = interpolation_mode
-
- def __call__(self, clip):
- """
- Args:
- clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
- Returns:
- torch.tensor: center cropped video clip.
- size is (T, C, crop_size, crop_size)
- """
- clip_center_crop = center_crop(clip, self.size)
- return clip_center_crop
-
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
-
-
-class NormalizeVideo:
- """
- Normalize the video clip by mean subtraction and division by standard deviation
- Args:
- mean (3-tuple): pixel RGB mean
- std (3-tuple): pixel RGB standard deviation
- inplace (boolean): whether do in-place normalization
- """
-
- def __init__(self, mean, std, inplace=False):
- self.mean = mean
- self.std = std
- self.inplace = inplace
-
- def __call__(self, clip):
- """
- Args:
- clip (torch.tensor): video clip must be normalized. Size is (C, T, H, W)
- """
- return normalize(clip, self.mean, self.std, self.inplace)
-
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})"
-
-
-class ToTensorVideo:
- """
- Convert tensor data type from uint8 to float, divide value by 255.0 and
- permute the dimensions of clip tensor
- """
-
- def __init__(self):
- pass
-
- def __call__(self, clip):
- """
- Args:
- clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
- Return:
- clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
- """
- return to_tensor(clip)
-
- def __repr__(self) -> str:
- return self.__class__.__name__
-
-
-class RandomHorizontalFlipVideo:
- """
- Flip the video clip along the horizontal direction with a given probability
- Args:
- p (float): probability of the clip being flipped. Default value is 0.5
- """
-
- def __init__(self, p=0.5):
- self.p = p
-
- def __call__(self, clip):
- """
- Args:
- clip (torch.tensor): Size is (T, C, H, W)
- Return:
- clip (torch.tensor): Size is (T, C, H, W)
- """
- if random.random() < self.p:
- clip = hflip(clip)
- return clip
-
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}(p={self.p})"
-
-
-# ------------------------------------------------------------
-# --------------------- Sampling ---------------------------
-# ------------------------------------------------------------
-class TemporalRandomCrop(object):
- """Temporally crop the given frame indices at a random location.
-
- Args:
- size (int): Desired length of frames will be seen in the model.
- """
-
- def __init__(self, size):
- self.size = size
-
- def __call__(self, total_frames):
- rand_end = max(0, total_frames - self.size - 1)
- begin_index = random.randint(0, rand_end)
- end_index = min(begin_index + self.size, total_frames)
- return begin_index, end_index
diff --git a/videosys/diffusion/__init__.py b/videosys/diffusion/__init__.py
deleted file mode 100644
index 0d16b842bf5a1bdc145b923693143e25a7e2ce81..0000000000000000000000000000000000000000
--- a/videosys/diffusion/__init__.py
+++ /dev/null
@@ -1,41 +0,0 @@
-# Modified from OpenAI's diffusion repos and Meta DiT
-# DiT: https://github.com/facebookresearch/DiT/tree/main
-# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
-# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
-# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
-
-from . import gaussian_diffusion as gd
-from .respace import SpacedDiffusion, space_timesteps
-
-
-def create_diffusion(
- timestep_respacing,
- noise_schedule="linear",
- use_kl=False,
- sigma_small=False,
- predict_xstart=False,
- learn_sigma=True,
- rescale_learned_sigmas=False,
- diffusion_steps=1000,
-):
- betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
- if use_kl:
- loss_type = gd.LossType.RESCALED_KL
- elif rescale_learned_sigmas:
- loss_type = gd.LossType.RESCALED_MSE
- else:
- loss_type = gd.LossType.MSE
- if timestep_respacing is None or timestep_respacing == "":
- timestep_respacing = [diffusion_steps]
- return SpacedDiffusion(
- use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
- betas=betas,
- model_mean_type=(gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X),
- model_var_type=(
- (gd.ModelVarType.FIXED_LARGE if not sigma_small else gd.ModelVarType.FIXED_SMALL)
- if not learn_sigma
- else gd.ModelVarType.LEARNED_RANGE
- ),
- loss_type=loss_type
- # rescale_timesteps=rescale_timesteps,
- )
diff --git a/videosys/diffusion/diffusion_utils.py b/videosys/diffusion/diffusion_utils.py
deleted file mode 100644
index 056471c0b0b560d17d18b95f9b8ef3dbc1b8317e..0000000000000000000000000000000000000000
--- a/videosys/diffusion/diffusion_utils.py
+++ /dev/null
@@ -1,79 +0,0 @@
-# Modified from OpenAI's diffusion repos
-# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
-# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
-# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
-
-import numpy as np
-import torch as th
-
-
-def normal_kl(mean1, logvar1, mean2, logvar2):
- """
- Compute the KL divergence between two gaussians.
- Shapes are automatically broadcasted, so batches can be compared to
- scalars, among other use cases.
- """
- tensor = None
- for obj in (mean1, logvar1, mean2, logvar2):
- if isinstance(obj, th.Tensor):
- tensor = obj
- break
- assert tensor is not None, "at least one argument must be a Tensor"
-
- # Force variances to be Tensors. Broadcasting helps convert scalars to
- # Tensors, but it does not work for th.exp().
- logvar1, logvar2 = [x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) for x in (logvar1, logvar2)]
-
- return 0.5 * (-1.0 + logvar2 - logvar1 + th.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * th.exp(-logvar2))
-
-
-def approx_standard_normal_cdf(x):
- """
- A fast approximation of the cumulative distribution function of the
- standard normal.
- """
- return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
-
-
-def continuous_gaussian_log_likelihood(x, *, means, log_scales):
- """
- Compute the log-likelihood of a continuous Gaussian distribution.
- :param x: the targets
- :param means: the Gaussian mean Tensor.
- :param log_scales: the Gaussian log stddev Tensor.
- :return: a tensor like x of log probabilities (in nats).
- """
- centered_x = x - means
- inv_stdv = th.exp(-log_scales)
- normalized_x = centered_x * inv_stdv
- log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
- return log_probs
-
-
-def discretized_gaussian_log_likelihood(x, *, means, log_scales):
- """
- Compute the log-likelihood of a Gaussian distribution discretizing to a
- given image.
- :param x: the target images. It is assumed that this was uint8 values,
- rescaled to the range [-1, 1].
- :param means: the Gaussian mean Tensor.
- :param log_scales: the Gaussian log stddev Tensor.
- :return: a tensor like x of log probabilities (in nats).
- """
- assert x.shape == means.shape == log_scales.shape
- centered_x = x - means
- inv_stdv = th.exp(-log_scales)
- plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
- cdf_plus = approx_standard_normal_cdf(plus_in)
- min_in = inv_stdv * (centered_x - 1.0 / 255.0)
- cdf_min = approx_standard_normal_cdf(min_in)
- log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
- log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
- cdf_delta = cdf_plus - cdf_min
- log_probs = th.where(
- x < -0.999,
- log_cdf_plus,
- th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
- )
- assert log_probs.shape == x.shape
- return log_probs
diff --git a/videosys/diffusion/gaussian_diffusion.py b/videosys/diffusion/gaussian_diffusion.py
deleted file mode 100644
index cf734a641690a1c1f4f5256eea1792afd71b800c..0000000000000000000000000000000000000000
--- a/videosys/diffusion/gaussian_diffusion.py
+++ /dev/null
@@ -1,829 +0,0 @@
-# Modified from OpenAI's diffusion repos
-# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
-# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
-# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
-
-
-import enum
-import math
-
-import numpy as np
-import torch as th
-
-from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
-
-
-def mean_flat(tensor):
- """
- Take the mean over all non-batch dimensions.
- """
- return tensor.mean(dim=list(range(1, len(tensor.shape))))
-
-
-class ModelMeanType(enum.Enum):
- """
- Which type of output the model predicts.
- """
-
- PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
- START_X = enum.auto() # the model predicts x_0
- EPSILON = enum.auto() # the model predicts epsilon
-
-
-class ModelVarType(enum.Enum):
- """
- What is used as the model's output variance.
- The LEARNED_RANGE option has been added to allow the model to predict
- values between FIXED_SMALL and FIXED_LARGE, making its job easier.
- """
-
- LEARNED = enum.auto()
- FIXED_SMALL = enum.auto()
- FIXED_LARGE = enum.auto()
- LEARNED_RANGE = enum.auto()
-
-
-class LossType(enum.Enum):
- MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
- RESCALED_MSE = enum.auto() # use raw MSE loss (with RESCALED_KL when learning variances)
- KL = enum.auto() # use the variational lower-bound
- RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
-
- def is_vb(self):
- return self == LossType.KL or self == LossType.RESCALED_KL
-
-
-def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
- betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
- warmup_time = int(num_diffusion_timesteps * warmup_frac)
- betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
- return betas
-
-
-def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
- """
- This is the deprecated API for creating beta schedules.
- See get_named_beta_schedule() for the new library of schedules.
- """
- if beta_schedule == "quad":
- betas = (
- np.linspace(
- beta_start**0.5,
- beta_end**0.5,
- num_diffusion_timesteps,
- dtype=np.float64,
- )
- ** 2
- )
- elif beta_schedule == "linear":
- betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
- elif beta_schedule == "warmup10":
- betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
- elif beta_schedule == "warmup50":
- betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
- elif beta_schedule == "const":
- betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
- elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
- betas = 1.0 / np.linspace(num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64)
- else:
- raise NotImplementedError(beta_schedule)
- assert betas.shape == (num_diffusion_timesteps,)
- return betas
-
-
-def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
- """
- Get a pre-defined beta schedule for the given name.
- The beta schedule library consists of beta schedules which remain similar
- in the limit of num_diffusion_timesteps.
- Beta schedules may be added, but should not be removed or changed once
- they are committed to maintain backwards compatibility.
- """
- if schedule_name == "linear":
- # Linear schedule from Ho et al, extended to work for any number of
- # diffusion steps.
- scale = 1000 / num_diffusion_timesteps
- return get_beta_schedule(
- "linear",
- beta_start=scale * 0.0001,
- beta_end=scale * 0.02,
- num_diffusion_timesteps=num_diffusion_timesteps,
- )
- elif schedule_name == "squaredcos_cap_v2":
- return betas_for_alpha_bar(
- num_diffusion_timesteps,
- lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
- )
- else:
- raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
-
-
-def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
- """
- Create a beta schedule that discretizes the given alpha_t_bar function,
- which defines the cumulative product of (1-beta) over time from t = [0,1].
- :param num_diffusion_timesteps: the number of betas to produce.
- :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
- produces the cumulative product of (1-beta) up to that
- part of the diffusion process.
- :param max_beta: the maximum beta to use; use values lower than 1 to
- prevent singularities.
- """
- betas = []
- for i in range(num_diffusion_timesteps):
- t1 = i / num_diffusion_timesteps
- t2 = (i + 1) / num_diffusion_timesteps
- betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
- return np.array(betas)
-
-
-class GaussianDiffusion:
- """
- Utilities for training and sampling diffusion models.
- Original ported from this codebase:
- https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
- :param betas: a 1-D numpy array of betas for each diffusion timestep,
- starting at T and going to 1.
- """
-
- def __init__(self, *, betas, model_mean_type, model_var_type, loss_type):
- self.model_mean_type = model_mean_type
- self.model_var_type = model_var_type
- self.loss_type = loss_type
-
- # Use float64 for accuracy.
- betas = np.array(betas, dtype=np.float64)
- self.betas = betas
- assert len(betas.shape) == 1, "betas must be 1-D"
- assert (betas > 0).all() and (betas <= 1).all()
-
- self.num_timesteps = int(betas.shape[0])
-
- alphas = 1.0 - betas
- self.alphas_cumprod = np.cumprod(alphas, axis=0)
- self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
- self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
- assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
-
- # calculations for diffusion q(x_t | x_{t-1}) and others
- self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
- self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
- self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
- self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
- self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
-
- # calculations for posterior q(x_{t-1} | x_t, x_0)
- self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
- # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
- self.posterior_log_variance_clipped = (
- np.log(np.append(self.posterior_variance[1], self.posterior_variance[1:]))
- if len(self.posterior_variance) > 1
- else np.array([])
- )
-
- self.posterior_mean_coef1 = betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
- self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
-
- def q_mean_variance(self, x_start, t):
- """
- Get the distribution q(x_t | x_0).
- :param x_start: the [N x C x ...] tensor of noiseless inputs.
- :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
- :return: A tuple (mean, variance, log_variance), all of x_start's shape.
- """
- mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
- variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
- log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
- return mean, variance, log_variance
-
- def q_sample(self, x_start, t, noise=None):
- """
- Diffuse the data for a given number of diffusion steps.
- In other words, sample from q(x_t | x_0).
- :param x_start: the initial data batch.
- :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
- :param noise: if specified, the split-out normal noise.
- :return: A noisy version of x_start.
- """
- if noise is None:
- noise = th.randn_like(x_start)
- assert noise.shape == x_start.shape
- return (
- _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
- + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
- )
-
- def q_posterior_mean_variance(self, x_start, x_t, t):
- """
- Compute the mean and variance of the diffusion posterior:
- q(x_{t-1} | x_t, x_0)
- """
- assert x_start.shape == x_t.shape
- posterior_mean = (
- _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
- + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
- )
- posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
- posterior_log_variance_clipped = _extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
- assert (
- posterior_mean.shape[0]
- == posterior_variance.shape[0]
- == posterior_log_variance_clipped.shape[0]
- == x_start.shape[0]
- )
- return posterior_mean, posterior_variance, posterior_log_variance_clipped
-
- def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
- """
- Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
- the initial x, x_0.
- :param model: the model, which takes a signal and a batch of timesteps
- as input.
- :param x: the [N x C x ...] tensor at time t.
- :param t: a 1-D Tensor of timesteps.
- :param clip_denoised: if True, clip the denoised signal into [-1, 1].
- :param denoised_fn: if not None, a function which applies to the
- x_start prediction before it is used to sample. Applies before
- clip_denoised.
- :param model_kwargs: if not None, a dict of extra keyword arguments to
- pass to the model. This can be used for conditioning.
- :return: a dict with the following keys:
- - 'mean': the model mean output.
- - 'variance': the model variance output.
- - 'log_variance': the log of 'variance'.
- - 'pred_xstart': the prediction for x_0.
- """
- if model_kwargs is None:
- model_kwargs = {}
-
- B, C = x.shape[:2]
- assert t.shape == (B,)
- model_output = model(x, t, **model_kwargs)
- if isinstance(model_output, tuple):
- model_output, extra = model_output
- else:
- extra = None
-
- if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
- assert model_output.shape == (B, C * 2, *x.shape[2:])
- model_output, model_var_values = th.split(model_output, C, dim=1)
- min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
- max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
- # The model_var_values is [-1, 1] for [min_var, max_var].
- frac = (model_var_values + 1) / 2
- model_log_variance = frac * max_log + (1 - frac) * min_log
- model_variance = th.exp(model_log_variance)
- else:
- model_variance, model_log_variance = {
- # for fixedlarge, we set the initial (log-)variance like so
- # to get a better decoder log likelihood.
- ModelVarType.FIXED_LARGE: (
- np.append(self.posterior_variance[1], self.betas[1:]),
- np.log(np.append(self.posterior_variance[1], self.betas[1:])),
- ),
- ModelVarType.FIXED_SMALL: (
- self.posterior_variance,
- self.posterior_log_variance_clipped,
- ),
- }[self.model_var_type]
- model_variance = _extract_into_tensor(model_variance, t, x.shape)
- model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
-
- def process_xstart(x):
- if denoised_fn is not None:
- x = denoised_fn(x)
- if clip_denoised:
- return x.clamp(-1, 1)
- return x
-
- if self.model_mean_type == ModelMeanType.START_X:
- pred_xstart = process_xstart(model_output)
- else:
- pred_xstart = process_xstart(self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output))
- model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
-
- assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
- return {
- "mean": model_mean,
- "variance": model_variance,
- "log_variance": model_log_variance,
- "pred_xstart": pred_xstart,
- "extra": extra,
- }
-
- def _predict_xstart_from_eps(self, x_t, t, eps):
- assert x_t.shape == eps.shape
- return (
- _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
- - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
- )
-
- def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
- return (
- _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
- ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
-
- def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
- """
- Compute the mean for the previous step, given a function cond_fn that
- computes the gradient of a conditional log probability with respect to
- x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
- condition on y.
- This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
- """
- gradient = cond_fn(x, t, **model_kwargs)
- new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
- return new_mean
-
- def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
- """
- Compute what the p_mean_variance output would have been, should the
- model's score function be conditioned by cond_fn.
- See condition_mean() for details on cond_fn.
- Unlike condition_mean(), this instead uses the conditioning strategy
- from Song et al (2020).
- """
- alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
-
- eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
- eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
-
- out = p_mean_var.copy()
- out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
- out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
- return out
-
- def p_sample(
- self,
- model,
- x,
- t,
- clip_denoised=True,
- denoised_fn=None,
- cond_fn=None,
- model_kwargs=None,
- ):
- """
- Sample x_{t-1} from the model at the given timestep.
- :param model: the model to sample from.
- :param x: the current tensor at x_{t-1}.
- :param t: the value of t, starting at 0 for the first diffusion step.
- :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
- :param denoised_fn: if not None, a function which applies to the
- x_start prediction before it is used to sample.
- :param cond_fn: if not None, this is a gradient function that acts
- similarly to the model.
- :param model_kwargs: if not None, a dict of extra keyword arguments to
- pass to the model. This can be used for conditioning.
- :return: a dict containing the following keys:
- - 'sample': a random sample from the model.
- - 'pred_xstart': a prediction of x_0.
- """
- out = self.p_mean_variance(
- model,
- x,
- t,
- clip_denoised=clip_denoised,
- denoised_fn=denoised_fn,
- model_kwargs=model_kwargs,
- )
- noise = th.randn_like(x)
- nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0
- if cond_fn is not None:
- out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
- sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
- return {"sample": sample, "pred_xstart": out["pred_xstart"]}
-
- def p_sample_loop(
- self,
- model,
- shape,
- noise=None,
- clip_denoised=True,
- denoised_fn=None,
- cond_fn=None,
- model_kwargs=None,
- device=None,
- progress=False,
- ):
- """
- Generate samples from the model.
- :param model: the model module.
- :param shape: the shape of the samples, (N, C, H, W).
- :param noise: if specified, the noise from the encoder to sample.
- Should be of the same shape as `shape`.
- :param clip_denoised: if True, clip x_start predictions to [-1, 1].
- :param denoised_fn: if not None, a function which applies to the
- x_start prediction before it is used to sample.
- :param cond_fn: if not None, this is a gradient function that acts
- similarly to the model.
- :param model_kwargs: if not None, a dict of extra keyword arguments to
- pass to the model. This can be used for conditioning.
- :param device: if specified, the device to create the samples on.
- If not specified, use a model parameter's device.
- :param progress: if True, show a tqdm progress bar.
- :return: a non-differentiable batch of samples.
- """
- final = None
- for sample in self.p_sample_loop_progressive(
- model,
- shape,
- noise=noise,
- clip_denoised=clip_denoised,
- denoised_fn=denoised_fn,
- cond_fn=cond_fn,
- model_kwargs=model_kwargs,
- device=device,
- progress=progress,
- ):
- final = sample
- return final["sample"]
-
- def p_sample_loop_progressive(
- self,
- model,
- shape,
- noise=None,
- clip_denoised=True,
- denoised_fn=None,
- cond_fn=None,
- model_kwargs=None,
- device=None,
- progress=False,
- ):
- """
- Generate samples from the model and yield intermediate samples from
- each timestep of diffusion.
- Arguments are the same as p_sample_loop().
- Returns a generator over dicts, where each dict is the return value of
- p_sample().
- """
- if device is None:
- device = next(model.parameters()).device
- assert isinstance(shape, (tuple, list))
- if noise is not None:
- img = noise
- else:
- img = th.randn(*shape, device=device)
- indices = list(range(self.num_timesteps))[::-1]
-
- if progress:
- # Lazy import so that we don't depend on tqdm.
- from tqdm.auto import tqdm
-
- indices = tqdm(indices)
-
- for i in indices:
- t = th.tensor([i] * shape[0], device=device)
- with th.no_grad():
- out = self.p_sample(
- model,
- img,
- t,
- clip_denoised=clip_denoised,
- denoised_fn=denoised_fn,
- cond_fn=cond_fn,
- model_kwargs=model_kwargs,
- )
- yield out
- img = out["sample"]
-
- def ddim_sample(
- self,
- model,
- x,
- t,
- clip_denoised=True,
- denoised_fn=None,
- cond_fn=None,
- model_kwargs=None,
- eta=0.0,
- ):
- """
- Sample x_{t-1} from the model using DDIM.
- Same usage as p_sample().
- """
- out = self.p_mean_variance(
- model,
- x,
- t,
- clip_denoised=clip_denoised,
- denoised_fn=denoised_fn,
- model_kwargs=model_kwargs,
- )
- if cond_fn is not None:
- out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
-
- # Usually our model outputs epsilon, but we re-derive it
- # in case we used x_start or x_prev prediction.
- eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
-
- alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
- alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
- sigma = eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * th.sqrt(1 - alpha_bar / alpha_bar_prev)
- # Equation 12.
- noise = th.randn_like(x)
- mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps
- nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0
- sample = mean_pred + nonzero_mask * sigma * noise
- return {"sample": sample, "pred_xstart": out["pred_xstart"]}
-
- def ddim_reverse_sample(
- self,
- model,
- x,
- t,
- clip_denoised=True,
- denoised_fn=None,
- cond_fn=None,
- model_kwargs=None,
- eta=0.0,
- ):
- """
- Sample x_{t+1} from the model using DDIM reverse ODE.
- """
- assert eta == 0.0, "Reverse ODE only for deterministic path"
- out = self.p_mean_variance(
- model,
- x,
- t,
- clip_denoised=clip_denoised,
- denoised_fn=denoised_fn,
- model_kwargs=model_kwargs,
- )
- if cond_fn is not None:
- out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
- # Usually our model outputs epsilon, but we re-derive it
- # in case we used x_start or x_prev prediction.
- eps = (
- _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - out["pred_xstart"]
- ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
- alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
-
- # Equation 12. reversed
- mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
-
- return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
-
- def ddim_sample_loop(
- self,
- model,
- shape,
- noise=None,
- clip_denoised=True,
- denoised_fn=None,
- cond_fn=None,
- model_kwargs=None,
- device=None,
- progress=False,
- eta=0.0,
- ):
- """
- Generate samples from the model using DDIM.
- Same usage as p_sample_loop().
- """
- final = None
- for sample in self.ddim_sample_loop_progressive(
- model,
- shape,
- noise=noise,
- clip_denoised=clip_denoised,
- denoised_fn=denoised_fn,
- cond_fn=cond_fn,
- model_kwargs=model_kwargs,
- device=device,
- progress=progress,
- eta=eta,
- ):
- final = sample
- return final["sample"]
-
- def ddim_sample_loop_progressive(
- self,
- model,
- shape,
- noise=None,
- clip_denoised=True,
- denoised_fn=None,
- cond_fn=None,
- model_kwargs=None,
- device=None,
- progress=False,
- eta=0.0,
- ):
- """
- Use DDIM to sample from the model and yield intermediate samples from
- each timestep of DDIM.
- Same usage as p_sample_loop_progressive().
- """
- if device is None:
- device = next(model.parameters()).device
- assert isinstance(shape, (tuple, list))
- if noise is not None:
- img = noise
- else:
- img = th.randn(*shape, device=device)
- indices = list(range(self.num_timesteps))[::-1]
-
- if progress:
- # Lazy import so that we don't depend on tqdm.
- from tqdm.auto import tqdm
-
- indices = tqdm(indices)
-
- for i in indices:
- t = th.tensor([i] * shape[0], device=device)
- with th.no_grad():
- out = self.ddim_sample(
- model,
- img,
- t,
- clip_denoised=clip_denoised,
- denoised_fn=denoised_fn,
- cond_fn=cond_fn,
- model_kwargs=model_kwargs,
- eta=eta,
- )
- yield out
- img = out["sample"]
-
- def _vb_terms_bpd(self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None):
- """
- Get a term for the variational lower-bound.
- The resulting units are bits (rather than nats, as one might expect).
- This allows for comparison to other papers.
- :return: a dict with the following keys:
- - 'output': a shape [N] tensor of NLLs or KLs.
- - 'pred_xstart': the x_0 predictions.
- """
- true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)
- out = self.p_mean_variance(model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs)
- kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"], out["log_variance"])
- kl = mean_flat(kl) / np.log(2.0)
-
- decoder_nll = -discretized_gaussian_log_likelihood(
- x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
- )
- assert decoder_nll.shape == x_start.shape
- decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
-
- # At the first timestep return the decoder NLL,
- # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
- output = th.where((t == 0), decoder_nll, kl)
- return {"output": output, "pred_xstart": out["pred_xstart"]}
-
- def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
- """
- Compute training losses for a single timestep.
- :param model: the model to evaluate loss on.
- :param x_start: the [N x C x ...] tensor of inputs.
- :param t: a batch of timestep indices.
- :param model_kwargs: if not None, a dict of extra keyword arguments to
- pass to the model. This can be used for conditioning.
- :param noise: if specified, the specific Gaussian noise to try to remove.
- :return: a dict with the key "loss" containing a tensor of shape [N].
- Some mean or variance settings may also have other keys.
- """
- if model_kwargs is None:
- model_kwargs = {}
- if noise is None:
- noise = th.randn_like(x_start)
- x_t = self.q_sample(x_start, t, noise=noise)
-
- terms = {}
-
- if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
- terms["loss"] = self._vb_terms_bpd(
- model=model,
- x_start=x_start,
- x_t=x_t,
- t=t,
- clip_denoised=False,
- model_kwargs=model_kwargs,
- )["output"]
- if self.loss_type == LossType.RESCALED_KL:
- terms["loss"] *= self.num_timesteps
- elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
- model_output = model(x_t, t, **model_kwargs)
-
- if self.model_var_type in [
- ModelVarType.LEARNED,
- ModelVarType.LEARNED_RANGE,
- ]:
- B, C = x_t.shape[:2]
- assert model_output.shape == (B, C * 2, *x_t.shape[2:])
- model_output, model_var_values = th.split(model_output, C, dim=1)
- # Learn the variance using the variational bound, but don't let
- # it affect our mean prediction.
- frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
- terms["vb"] = self._vb_terms_bpd(
- model=lambda *args, r=frozen_out: r,
- x_start=x_start,
- x_t=x_t,
- t=t,
- clip_denoised=False,
- )["output"]
- if self.loss_type == LossType.RESCALED_MSE:
- # Divide by 1000 for equivalence with initial implementation.
- # Without a factor of 1/1000, the VB term hurts the MSE term.
- terms["vb"] *= self.num_timesteps / 1000.0
-
- target = {
- ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0],
- ModelMeanType.START_X: x_start,
- ModelMeanType.EPSILON: noise,
- }[self.model_mean_type]
- assert model_output.shape == target.shape == x_start.shape
- terms["mse"] = mean_flat((target - model_output) ** 2)
- if "vb" in terms:
- terms["loss"] = terms["mse"] + terms["vb"]
- else:
- terms["loss"] = terms["mse"]
- else:
- raise NotImplementedError(self.loss_type)
-
- return terms
-
- def _prior_bpd(self, x_start):
- """
- Get the prior KL term for the variational lower-bound, measured in
- bits-per-dim.
- This term can't be optimized, as it only depends on the encoder.
- :param x_start: the [N x C x ...] tensor of inputs.
- :return: a batch of [N] KL values (in bits), one per batch element.
- """
- batch_size = x_start.shape[0]
- t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
- qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
- kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
- return mean_flat(kl_prior) / np.log(2.0)
-
- def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
- """
- Compute the entire variational lower-bound, measured in bits-per-dim,
- as well as other related quantities.
- :param model: the model to evaluate loss on.
- :param x_start: the [N x C x ...] tensor of inputs.
- :param clip_denoised: if True, clip denoised samples.
- :param model_kwargs: if not None, a dict of extra keyword arguments to
- pass to the model. This can be used for conditioning.
- :return: a dict containing the following keys:
- - total_bpd: the total variational lower-bound, per batch element.
- - prior_bpd: the prior term in the lower-bound.
- - vb: an [N x T] tensor of terms in the lower-bound.
- - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
- - mse: an [N x T] tensor of epsilon MSEs for each timestep.
- """
- device = x_start.device
- batch_size = x_start.shape[0]
-
- vb = []
- xstart_mse = []
- mse = []
- for t in list(range(self.num_timesteps))[::-1]:
- t_batch = th.tensor([t] * batch_size, device=device)
- noise = th.randn_like(x_start)
- x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
- # Calculate VLB term at the current timestep
- with th.no_grad():
- out = self._vb_terms_bpd(
- model,
- x_start=x_start,
- x_t=x_t,
- t=t_batch,
- clip_denoised=clip_denoised,
- model_kwargs=model_kwargs,
- )
- vb.append(out["output"])
- xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
- eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
- mse.append(mean_flat((eps - noise) ** 2))
-
- vb = th.stack(vb, dim=1)
- xstart_mse = th.stack(xstart_mse, dim=1)
- mse = th.stack(mse, dim=1)
-
- prior_bpd = self._prior_bpd(x_start)
- total_bpd = vb.sum(dim=1) + prior_bpd
- return {
- "total_bpd": total_bpd,
- "prior_bpd": prior_bpd,
- "vb": vb,
- "xstart_mse": xstart_mse,
- "mse": mse,
- }
-
-
-def _extract_into_tensor(arr, timesteps, broadcast_shape):
- """
- Extract values from a 1-D numpy array for a batch of indices.
- :param arr: the 1-D numpy array.
- :param timesteps: a tensor of indices into the array to extract.
- :param broadcast_shape: a larger shape of K dimensions with the batch
- dimension equal to the length of timesteps.
- :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
- """
- res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
- while len(res.shape) < len(broadcast_shape):
- res = res[..., None]
- return res + th.zeros(broadcast_shape, device=timesteps.device)
diff --git a/videosys/diffusion/respace.py b/videosys/diffusion/respace.py
deleted file mode 100644
index e5754aa70a9f221a2320ba7a56ab0cb5f4ed9188..0000000000000000000000000000000000000000
--- a/videosys/diffusion/respace.py
+++ /dev/null
@@ -1,119 +0,0 @@
-# Modified from OpenAI's diffusion repos
-# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
-# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
-# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
-
-import numpy as np
-import torch as th
-
-from .gaussian_diffusion import GaussianDiffusion
-
-
-def space_timesteps(num_timesteps, section_counts):
- """
- Create a list of timesteps to use from an original diffusion process,
- given the number of timesteps we want to take from equally-sized portions
- of the original process.
- For example, if there's 300 timesteps and the section counts are [10,15,20]
- then the first 100 timesteps are strided to be 10 timesteps, the second 100
- are strided to be 15 timesteps, and the final 100 are strided to be 20.
- If the stride is a string starting with "ddim", then the fixed striding
- from the DDIM paper is used, and only one section is allowed.
- :param num_timesteps: the number of diffusion steps in the original
- process to divide up.
- :param section_counts: either a list of numbers, or a string containing
- comma-separated numbers, indicating the step count
- per section. As a special case, use "ddimN" where N
- is a number of steps to use the striding from the
- DDIM paper.
- :return: a set of diffusion steps from the original process to use.
- """
- if isinstance(section_counts, str):
- if section_counts.startswith("ddim"):
- desired_count = int(section_counts[len("ddim") :])
- for i in range(1, num_timesteps):
- if len(range(0, num_timesteps, i)) == desired_count:
- return set(range(0, num_timesteps, i))
- raise ValueError(f"cannot create exactly {num_timesteps} steps with an integer stride")
- section_counts = [int(x) for x in section_counts.split(",")]
- size_per = num_timesteps // len(section_counts)
- extra = num_timesteps % len(section_counts)
- start_idx = 0
- all_steps = []
- for i, section_count in enumerate(section_counts):
- size = size_per + (1 if i < extra else 0)
- if size < section_count:
- raise ValueError(f"cannot divide section of {size} steps into {section_count}")
- if section_count <= 1:
- frac_stride = 1
- else:
- frac_stride = (size - 1) / (section_count - 1)
- cur_idx = 0.0
- taken_steps = []
- for _ in range(section_count):
- taken_steps.append(start_idx + round(cur_idx))
- cur_idx += frac_stride
- all_steps += taken_steps
- start_idx += size
- return set(all_steps)
-
-
-class SpacedDiffusion(GaussianDiffusion):
- """
- A diffusion process which can skip steps in a base diffusion process.
- :param use_timesteps: a collection (sequence or set) of timesteps from the
- original diffusion process to retain.
- :param kwargs: the kwargs to create the base diffusion process.
- """
-
- def __init__(self, use_timesteps, **kwargs):
- self.use_timesteps = set(use_timesteps)
- self.timestep_map = []
- self.original_num_steps = len(kwargs["betas"])
-
- base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
- last_alpha_cumprod = 1.0
- new_betas = []
- for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
- if i in self.use_timesteps:
- new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
- last_alpha_cumprod = alpha_cumprod
- self.timestep_map.append(i)
- kwargs["betas"] = np.array(new_betas)
- super().__init__(**kwargs)
-
- def p_mean_variance(self, model, *args, **kwargs): # pylint: disable=signature-differs
- return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
-
- def training_losses(self, model, *args, **kwargs): # pylint: disable=signature-differs
- return super().training_losses(self._wrap_model(model), *args, **kwargs)
-
- def condition_mean(self, cond_fn, *args, **kwargs):
- return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
-
- def condition_score(self, cond_fn, *args, **kwargs):
- return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
-
- def _wrap_model(self, model):
- if isinstance(model, _WrappedModel):
- return model
- return _WrappedModel(model, self.timestep_map, self.original_num_steps)
-
- def _scale_timesteps(self, t):
- # Scaling is done by the wrapped model.
- return t
-
-
-class _WrappedModel:
- def __init__(self, model, timestep_map, original_num_steps):
- self.model = model
- self.timestep_map = timestep_map
- # self.rescale_timesteps = rescale_timesteps
- self.original_num_steps = original_num_steps
-
- def __call__(self, x, ts, **kwargs):
- map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
- new_ts = map_tensor[ts]
- # if self.rescale_timesteps:
- # new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
- return self.model(x, new_ts, **kwargs)
diff --git a/videosys/diffusion/timestep_sampler.py b/videosys/diffusion/timestep_sampler.py
deleted file mode 100644
index fdaa45acfcf239d7b6aaf5a83ee12fd553bc06b8..0000000000000000000000000000000000000000
--- a/videosys/diffusion/timestep_sampler.py
+++ /dev/null
@@ -1,143 +0,0 @@
-# Modified from OpenAI's diffusion repos
-# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
-# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
-# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
-
-from abc import ABC, abstractmethod
-
-import numpy as np
-import torch as th
-import torch.distributed as dist
-
-
-def create_named_schedule_sampler(name, diffusion):
- """
- Create a ScheduleSampler from a library of pre-defined samplers.
- :param name: the name of the sampler.
- :param diffusion: the diffusion object to sample for.
- """
- if name == "uniform":
- return UniformSampler(diffusion)
- elif name == "loss-second-moment":
- return LossSecondMomentResampler(diffusion)
- else:
- raise NotImplementedError(f"unknown schedule sampler: {name}")
-
-
-class ScheduleSampler(ABC):
- """
- A distribution over timesteps in the diffusion process, intended to reduce
- variance of the objective.
- By default, samplers perform unbiased importance sampling, in which the
- objective's mean is unchanged.
- However, subclasses may override sample() to change how the resampled
- terms are reweighted, allowing for actual changes in the objective.
- """
-
- @abstractmethod
- def weights(self):
- """
- Get a numpy array of weights, one per diffusion step.
- The weights needn't be normalized, but must be positive.
- """
-
- def sample(self, batch_size, device):
- """
- Importance-sample timesteps for a batch.
- :param batch_size: the number of timesteps.
- :param device: the torch device to save to.
- :return: a tuple (timesteps, weights):
- - timesteps: a tensor of timestep indices.
- - weights: a tensor of weights to scale the resulting losses.
- """
- w = self.weights()
- p = w / np.sum(w)
- indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
- indices = th.from_numpy(indices_np).long().to(device)
- weights_np = 1 / (len(p) * p[indices_np])
- weights = th.from_numpy(weights_np).float().to(device)
- return indices, weights
-
-
-class UniformSampler(ScheduleSampler):
- def __init__(self, diffusion):
- self.diffusion = diffusion
- self._weights = np.ones([diffusion.num_timesteps])
-
- def weights(self):
- return self._weights
-
-
-class LossAwareSampler(ScheduleSampler):
- def update_with_local_losses(self, local_ts, local_losses):
- """
- Update the reweighting using losses from a model.
- Call this method from each rank with a batch of timesteps and the
- corresponding losses for each of those timesteps.
- This method will perform synchronization to make sure all of the ranks
- maintain the exact same reweighting.
- :param local_ts: an integer Tensor of timesteps.
- :param local_losses: a 1D Tensor of losses.
- """
- batch_sizes = [th.tensor([0], dtype=th.int32, device=local_ts.device) for _ in range(dist.get_world_size())]
- dist.all_gather(
- batch_sizes,
- th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
- )
-
- # Pad all_gather batches to be the maximum batch size.
- batch_sizes = [x.item() for x in batch_sizes]
- max_bs = max(batch_sizes)
-
- timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
- loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
- dist.all_gather(timestep_batches, local_ts)
- dist.all_gather(loss_batches, local_losses)
- timesteps = [x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]]
- losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
- self.update_with_all_losses(timesteps, losses)
-
- @abstractmethod
- def update_with_all_losses(self, ts, losses):
- """
- Update the reweighting using losses from a model.
- Sub-classes should override this method to update the reweighting
- using losses from the model.
- This method directly updates the reweighting without synchronizing
- between workers. It is called by update_with_local_losses from all
- ranks with identical arguments. Thus, it should have deterministic
- behavior to maintain state across workers.
- :param ts: a list of int timesteps.
- :param losses: a list of float losses, one per timestep.
- """
-
-
-class LossSecondMomentResampler(LossAwareSampler):
- def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
- self.diffusion = diffusion
- self.history_per_term = history_per_term
- self.uniform_prob = uniform_prob
- self._loss_history = np.zeros([diffusion.num_timesteps, history_per_term], dtype=np.float64)
- self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int)
-
- def weights(self):
- if not self._warmed_up():
- return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
- weights = np.sqrt(np.mean(self._loss_history**2, axis=-1))
- weights /= np.sum(weights)
- weights *= 1 - self.uniform_prob
- weights += self.uniform_prob / len(weights)
- return weights
-
- def update_with_all_losses(self, ts, losses):
- for t, loss in zip(ts, losses):
- if self._loss_counts[t] == self.history_per_term:
- # Shift out the oldest loss term.
- self._loss_history[t, :-1] = self._loss_history[t, 1:]
- self._loss_history[t, -1] = loss
- else:
- self._loss_history[t, self._loss_counts[t]] = loss
- self._loss_counts[t] += 1
-
- def _warmed_up(self):
- return (self._loss_counts == self.history_per_term).all()
diff --git a/eval/pab/commom_metrics/__init__.py b/videosys/models/autoencoders/__init__.py
similarity index 100%
rename from eval/pab/commom_metrics/__init__.py
rename to videosys/models/autoencoders/__init__.py
diff --git a/videosys/models/cogvideo/autoencoder_kl.py b/videosys/models/autoencoders/autoencoder_kl_cogvideox.py
similarity index 66%
rename from videosys/models/cogvideo/autoencoder_kl.py
rename to videosys/models/autoencoders/autoencoder_kl_cogvideox.py
index b5e52a2b80b50346adcd95a4dca9693884a5a3d3..aefcd039aabcfb9c7746ed46e48368d0f0926154 100644
--- a/videosys/models/cogvideo/autoencoder_kl.py
+++ b/videosys/models/autoencoders/autoencoder_kl_cogvideox.py
@@ -20,16 +20,16 @@ from diffusers.models.activations import get_activation
from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from diffusers.models.modeling_utils import ModelMixin
-from diffusers.utils import logging
from diffusers.utils.accelerate_utils import apply_forward_hook
-from .modules import CogVideoXDownsample3D, CogVideoXUpsample3D
+from videosys.utils.logging import logger
-logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+from ..modules.downsampling import CogVideoXDownsample3D
+from ..modules.upsampling import CogVideoXUpsample3D
class CogVideoXSafeConv3d(nn.Conv3d):
- """
+ r"""
A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model.
"""
@@ -61,12 +61,12 @@ class CogVideoXCausalConv3d(nn.Module):
r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model.
Args:
- in_channels (int): Number of channels in the input tensor.
- out_channels (int): Number of output channels.
- kernel_size (Union[int, Tuple[int, int, int]]): Size of the convolutional kernel.
- stride (int, optional): Stride of the convolution. Default is 1.
- dilation (int, optional): Dilation rate of the convolution. Default is 1.
- pad_mode (str, optional): Padding mode. Default is "constant".
+ in_channels (`int`): Number of channels in the input tensor.
+ out_channels (`int`): Number of output channels produced by the convolution.
+ kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel.
+ stride (`int`, defaults to `1`): Stride of the convolution.
+ dilation (`int`, defaults to `1`): Dilation rate of the convolution.
+ pad_mode (`str`, defaults to `"constant"`): Padding mode.
"""
def __init__(
@@ -111,19 +111,10 @@ class CogVideoXCausalConv3d(nn.Module):
self.conv_cache = None
def fake_context_parallel_forward(self, inputs: torch.Tensor) -> torch.Tensor:
- dim = self.temporal_dim
kernel_size = self.time_kernel_size
- if kernel_size == 1:
- return inputs
-
- inputs = inputs.transpose(0, dim)
-
- if self.conv_cache is not None:
- inputs = torch.cat([self.conv_cache.transpose(0, dim).to(inputs.device), inputs], dim=0)
- else:
- inputs = torch.cat([inputs[:1]] * (kernel_size - 1) + [inputs], dim=0)
-
- inputs = inputs.transpose(0, dim).contiguous()
+ if kernel_size > 1:
+ cached_inputs = [self.conv_cache] if self.conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
+ inputs = torch.cat(cached_inputs + [inputs], dim=2)
return inputs
def _clear_fake_context_parallel_cache(self):
@@ -131,16 +122,17 @@ class CogVideoXCausalConv3d(nn.Module):
self.conv_cache = None
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
- input_parallel = self.fake_context_parallel_forward(inputs)
+ inputs = self.fake_context_parallel_forward(inputs)
self._clear_fake_context_parallel_cache()
- self.conv_cache = input_parallel[:, :, -self.time_kernel_size + 1 :].contiguous().detach().clone().cpu()
+ # Note: we could move these to the cpu for a lower maximum memory usage but its only a few
+ # hundred megabytes and so let's not do it for now
+ self.conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
- input_parallel = F.pad(input_parallel, padding_2d, mode="constant", value=0)
+ inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
- output_parallel = self.conv(input_parallel)
- output = output_parallel
+ output = self.conv(inputs)
return output
@@ -156,6 +148,8 @@ class CogVideoXSpatialNorm3D(nn.Module):
The number of channels for input to group normalization layer, and output of the spatial norm layer.
zq_channels (`int`):
The number of channels for the quantized vector as described in the paper.
+ groups (`int`):
+ Number of groups to separate the channels into for group normalization.
"""
def __init__(
@@ -190,17 +184,26 @@ class CogVideoXResnetBlock3D(nn.Module):
A 3D ResNet block used in the CogVideoX model.
Args:
- in_channels (int): Number of input channels.
- out_channels (Optional[int], optional):
- Number of output channels. If None, defaults to `in_channels`. Default is None.
- dropout (float, optional): Dropout rate. Default is 0.0.
- temb_channels (int, optional): Number of time embedding channels. Default is 512.
- groups (int, optional): Number of groups for group normalization. Default is 32.
- eps (float, optional): Epsilon value for normalization layers. Default is 1e-6.
- non_linearity (str, optional): Activation function to use. Default is "swish".
- conv_shortcut (bool, optional): If True, use a convolutional shortcut. Default is False.
- spatial_norm_dim (Optional[int], optional): Dimension of the spatial normalization. Default is None.
- pad_mode (str, optional): Padding mode. Default is "first".
+ in_channels (`int`):
+ Number of input channels.
+ out_channels (`int`, *optional*):
+ Number of output channels. If None, defaults to `in_channels`.
+ dropout (`float`, defaults to `0.0`):
+ Dropout rate.
+ temb_channels (`int`, defaults to `512`):
+ Number of time embedding channels.
+ groups (`int`, defaults to `32`):
+ Number of groups to separate the channels into for group normalization.
+ eps (`float`, defaults to `1e-6`):
+ Epsilon value for normalization layers.
+ non_linearity (`str`, defaults to `"swish"`):
+ Activation function to use.
+ conv_shortcut (bool, defaults to `False`):
+ Whether or not to use a convolution shortcut.
+ spatial_norm_dim (`int`, *optional*):
+ The dimension to use for spatial norm if it is to be used instead of group norm.
+ pad_mode (str, defaults to `"first"`):
+ Padding mode.
"""
def __init__(
@@ -302,18 +305,28 @@ class CogVideoXDownBlock3D(nn.Module):
A downsampling block used in the CogVideoX model.
Args:
- in_channels (int): Number of input channels.
- out_channels (int): Number of output channels.
- temb_channels (int): Number of time embedding channels.
- dropout (float, optional): Dropout rate. Default is 0.0.
- num_layers (int, optional): Number of layers in the block. Default is 1.
- resnet_eps (float, optional): Epsilon value for the ResNet layers. Default is 1e-6.
- resnet_act_fn (str, optional): Activation function for the ResNet layers. Default is "swish".
- resnet_groups (int, optional): Number of groups for group normalization in the ResNet layers. Default is 32.
- add_downsample (bool, optional): If True, add a downsampling layer at the end of the block. Default is True.
- downsample_padding (int, optional): Padding for the downsampling layer. Default is 0.
- compress_time (bool, optional): If True, apply temporal compression. Default is False.
- pad_mode (str, optional): Padding mode. Default is "first".
+ in_channels (`int`):
+ Number of input channels.
+ out_channels (`int`, *optional*):
+ Number of output channels. If None, defaults to `in_channels`.
+ temb_channels (`int`, defaults to `512`):
+ Number of time embedding channels.
+ num_layers (`int`, defaults to `1`):
+ Number of resnet layers.
+ dropout (`float`, defaults to `0.0`):
+ Dropout rate.
+ resnet_eps (`float`, defaults to `1e-6`):
+ Epsilon value for normalization layers.
+ resnet_act_fn (`str`, defaults to `"swish"`):
+ Activation function to use.
+ resnet_groups (`int`, defaults to `32`):
+ Number of groups to separate the channels into for group normalization.
+ add_downsample (`bool`, defaults to `True`):
+ Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension.
+ compress_time (`bool`, defaults to `False`):
+ Whether or not to downsample across temporal dimension.
+ pad_mode (str, defaults to `"first"`):
+ Padding mode.
"""
_supports_gradient_checkpointing = True
@@ -398,15 +411,24 @@ class CogVideoXMidBlock3D(nn.Module):
A middle block used in the CogVideoX model.
Args:
- in_channels (int): Number of input channels.
- temb_channels (int): Number of time embedding channels.
- dropout (float, optional): Dropout rate. Default is 0.0.
- num_layers (int, optional): Number of layers in the block. Default is 1.
- resnet_eps (float, optional): Epsilon value for the ResNet layers. Default is 1e-6.
- resnet_act_fn (str, optional): Activation function for the ResNet layers. Default is "swish".
- resnet_groups (int, optional): Number of groups for group normalization in the ResNet layers. Default is 32.
- spatial_norm_dim (Optional[int], optional): Dimension of the spatial normalization. Default is None.
- pad_mode (str, optional): Padding mode. Default is "first".
+ in_channels (`int`):
+ Number of input channels.
+ temb_channels (`int`, defaults to `512`):
+ Number of time embedding channels.
+ dropout (`float`, defaults to `0.0`):
+ Dropout rate.
+ num_layers (`int`, defaults to `1`):
+ Number of resnet layers.
+ resnet_eps (`float`, defaults to `1e-6`):
+ Epsilon value for normalization layers.
+ resnet_act_fn (`str`, defaults to `"swish"`):
+ Activation function to use.
+ resnet_groups (`int`, defaults to `32`):
+ Number of groups to separate the channels into for group normalization.
+ spatial_norm_dim (`int`, *optional*):
+ The dimension to use for spatial norm if it is to be used instead of group norm.
+ pad_mode (str, defaults to `"first"`):
+ Padding mode.
"""
_supports_gradient_checkpointing = True
@@ -473,19 +495,30 @@ class CogVideoXUpBlock3D(nn.Module):
An upsampling block used in the CogVideoX model.
Args:
- in_channels (int): Number of input channels.
- out_channels (int): Number of output channels.
- temb_channels (int): Number of time embedding channels.
- dropout (float, optional): Dropout rate. Default is 0.0.
- num_layers (int, optional): Number of layers in the block. Default is 1.
- resnet_eps (float, optional): Epsilon value for the ResNet layers. Default is 1e-6.
- resnet_act_fn (str, optional): Activation function for the ResNet layers. Default is "swish".
- resnet_groups (int, optional): Number of groups for group normalization in the ResNet layers. Default is 32.
- spatial_norm_dim (int, optional): Dimension of the spatial normalization. Default is 16.
- add_upsample (bool, optional): If True, add an upsampling layer at the end of the block. Default is True.
- upsample_padding (int, optional): Padding for the upsampling layer. Default is 1.
- compress_time (bool, optional): If True, apply temporal compression. Default is False.
- pad_mode (str, optional): Padding mode. Default is "first".
+ in_channels (`int`):
+ Number of input channels.
+ out_channels (`int`, *optional*):
+ Number of output channels. If None, defaults to `in_channels`.
+ temb_channels (`int`, defaults to `512`):
+ Number of time embedding channels.
+ dropout (`float`, defaults to `0.0`):
+ Dropout rate.
+ num_layers (`int`, defaults to `1`):
+ Number of resnet layers.
+ resnet_eps (`float`, defaults to `1e-6`):
+ Epsilon value for normalization layers.
+ resnet_act_fn (`str`, defaults to `"swish"`):
+ Activation function to use.
+ resnet_groups (`int`, defaults to `32`):
+ Number of groups to separate the channels into for group normalization.
+ spatial_norm_dim (`int`, defaults to `16`):
+ The dimension to use for spatial norm if it is to be used instead of group norm.
+ add_upsample (`bool`, defaults to `True`):
+ Whether or not to use a upsampling layer. If not used, output dimension would be same as input dimension.
+ compress_time (`bool`, defaults to `False`):
+ Whether or not to downsample across temporal dimension.
+ pad_mode (str, defaults to `"first"`):
+ Padding mode.
"""
def __init__(
@@ -576,14 +609,12 @@ class CogVideoXEncoder3D(nn.Module):
options.
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
The number of output channels for each block.
+ act_fn (`str`, *optional*, defaults to `"silu"`):
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
layers_per_block (`int`, *optional*, defaults to 2):
The number of layers per block.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups for normalization.
- act_fn (`str`, *optional*, defaults to `"silu"`):
- The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
- double_z (`bool`, *optional*, defaults to `True`):
- Whether to double the number of output channels for the last block.
"""
_supports_gradient_checkpointing = True
@@ -712,14 +743,12 @@ class CogVideoXDecoder3D(nn.Module):
The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
The number of output channels for each block.
+ act_fn (`str`, *optional*, defaults to `"silu"`):
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
layers_per_block (`int`, *optional*, defaults to 2):
The number of layers per block.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups for normalization.
- act_fn (`str`, *optional*, defaults to `"silu"`):
- The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
- norm_type (`str`, *optional*, defaults to `"group"`):
- The normalization type to use. Can be either `"group"` or `"spatial"`.
"""
_supports_gradient_checkpointing = True
@@ -860,7 +889,7 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
Tuple of block output channels.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
- scaling_factor (`float`, *optional*, defaults to 0.18215):
+ scaling_factor (`float`, *optional*, defaults to `1.15258426`):
The component-wise standard deviation of the trained latent space computed using the first batch of the
training set. This is used to scale the latent space to have unit variance when training the diffusion
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
@@ -900,7 +929,8 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
norm_eps: float = 1e-6,
norm_num_groups: int = 32,
temporal_compression_ratio: float = 4,
- sample_size: int = 256,
+ sample_height: int = 480,
+ sample_width: int = 720,
scaling_factor: float = 1.15258426,
shift_factor: Optional[float] = None,
latents_mean: Optional[Tuple[float]] = None,
@@ -939,25 +969,105 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.use_slicing = False
self.use_tiling = False
- self.tile_sample_min_size = self.config.sample_size
- sample_size = (
- self.config.sample_size[0]
- if isinstance(self.config.sample_size, (list, tuple))
- else self.config.sample_size
+ # Can be increased to decode more latent frames at once, but comes at a reasonable memory cost and it is not
+ # recommended because the temporal parts of the VAE, here, are tricky to understand.
+ # If you decode X latent frames together, the number of output frames is:
+ # (X + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) => X + 6 frames
+ #
+ # Example with num_latent_frames_batch_size = 2:
+ # - 12 latent frames: (0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11) are processed together
+ # => (12 // 2 frame slices) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
+ # => 6 * 8 = 48 frames
+ # - 13 latent frames: (0, 1, 2) (special case), (3, 4), (5, 6), (7, 8), (9, 10), (11, 12) are processed together
+ # => (1 frame slice) * ((3 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) +
+ # ((13 - 3) // 2) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
+ # => 1 * 9 + 5 * 8 = 49 frames
+ # It has been implemented this way so as to not have "magic values" in the code base that would be hard to explain. Note that
+ # setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive with different
+ # number of temporal frames.
+ self.num_latent_frames_batch_size = 2
+
+ # We make the minimum height and width of sample for tiling half that of the generally supported
+ self.tile_sample_min_height = sample_height // 2
+ self.tile_sample_min_width = sample_width // 2
+ self.tile_latent_min_height = int(
+ self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
)
- self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
- self.tile_overlap_factor = 0.25
+ self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
+
+ # These are experimental overlap factors that were chosen based on experimentation and seem to work best for
+ # 720x480 (WxH) resolution. The above resolution is the strongly recommended generation resolution in CogVideoX
+ # and so the tiling implementation has only been tested on those specific resolutions.
+ self.tile_overlap_factor_height = 1 / 6
+ self.tile_overlap_factor_width = 1 / 5
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
module.gradient_checkpointing = value
- def clear_fake_context_parallel_cache(self):
+ def _clear_fake_context_parallel_cache(self):
for name, module in self.named_modules():
if isinstance(module, CogVideoXCausalConv3d):
logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
module._clear_fake_context_parallel_cache()
+ def enable_tiling(
+ self,
+ tile_sample_min_height: Optional[int] = None,
+ tile_sample_min_width: Optional[int] = None,
+ tile_overlap_factor_height: Optional[float] = None,
+ tile_overlap_factor_width: Optional[float] = None,
+ ) -> None:
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+
+ Args:
+ tile_sample_min_height (`int`, *optional*):
+ The minimum height required for a sample to be separated into tiles across the height dimension.
+ tile_sample_min_width (`int`, *optional*):
+ The minimum width required for a sample to be separated into tiles across the width dimension.
+ tile_overlap_factor_height (`int`, *optional*):
+ The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
+ no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher
+ value might cause more tiles to be processed leading to slow down of the decoding process.
+ tile_overlap_factor_width (`int`, *optional*):
+ The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there
+ are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher
+ value might cause more tiles to be processed leading to slow down of the decoding process.
+ """
+ self.use_tiling = True
+ self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
+ self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+ self.tile_latent_min_height = int(
+ self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
+ )
+ self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
+ self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
+ self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
+
+ def disable_tiling(self) -> None:
+ r"""
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.use_tiling = False
+
+ def enable_slicing(self) -> None:
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.use_slicing = True
+
+ def disable_slicing(self) -> None:
+ r"""
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.use_slicing = False
+
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
@@ -982,8 +1092,34 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
+ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+ batch_size, num_channels, num_frames, height, width = z.shape
+
+ if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
+ return self.tiled_decode(z, return_dict=return_dict)
+
+ frame_batch_size = self.num_latent_frames_batch_size
+ dec = []
+ for i in range(num_frames // frame_batch_size):
+ remaining_frames = num_frames % frame_batch_size
+ start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
+ end_frame = frame_batch_size * (i + 1) + remaining_frames
+ z_intermediate = z[:, :, start_frame:end_frame]
+ if self.post_quant_conv is not None:
+ z_intermediate = self.post_quant_conv(z_intermediate)
+ z_intermediate = self.decoder(z_intermediate)
+ dec.append(z_intermediate)
+
+ self._clear_fake_context_parallel_cache()
+ dec = torch.cat(dec, dim=2)
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
@apply_forward_hook
- def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
"""
Decode a batch of images.
@@ -996,13 +1132,111 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
+ """
+ if self.use_slicing and z.shape[0] > 1:
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+ decoded = torch.cat(decoded_slices)
+ else:
+ decoded = self._decode(z).sample
+
+ if not return_dict:
+ return (decoded,)
+ return DecoderOutput(sample=decoded)
+
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
+ for y in range(blend_extent):
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
+ y / blend_extent
+ )
+ return b
+
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[4], b.shape[4], blend_extent)
+ for x in range(blend_extent):
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
+ x / blend_extent
+ )
+ return b
+
+ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+ r"""
+ Decode a batch of images using a tiled decoder.
+
+ Args:
+ z (`torch.Tensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
"""
- if self.post_quant_conv is not None:
- z = self.post_quant_conv(z)
- dec = self.decoder(z)
+ # Rough memory assessment:
+ # - In CogVideoX-2B, there are a total of 24 CausalConv3d layers.
+ # - The biggest intermediate dimensions are: [1, 128, 9, 480, 720].
+ # - Assume fp16 (2 bytes per value).
+ # Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB
+ #
+ # Memory assessment when using tiling:
+ # - Assume everything as above but now HxW is 240x360 by tiling in half
+ # Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB
+
+ batch_size, num_channels, num_frames, height, width = z.shape
+
+ overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
+ overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
+ blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
+ blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
+ row_limit_height = self.tile_sample_min_height - blend_extent_height
+ row_limit_width = self.tile_sample_min_width - blend_extent_width
+ frame_batch_size = self.num_latent_frames_batch_size
+
+ # Split z into overlapping tiles and decode them separately.
+ # The tiles have an overlap to avoid seams between tiles.
+ rows = []
+ for i in range(0, height, overlap_height):
+ row = []
+ for j in range(0, width, overlap_width):
+ time = []
+ for k in range(num_frames // frame_batch_size):
+ remaining_frames = num_frames % frame_batch_size
+ start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
+ end_frame = frame_batch_size * (k + 1) + remaining_frames
+ tile = z[
+ :,
+ :,
+ start_frame:end_frame,
+ i : i + self.tile_latent_min_height,
+ j : j + self.tile_latent_min_width,
+ ]
+ if self.post_quant_conv is not None:
+ tile = self.post_quant_conv(tile)
+ tile = self.decoder(tile)
+ time.append(tile)
+ self._clear_fake_context_parallel_cache()
+ row.append(torch.cat(time, dim=2))
+ rows.append(row)
+
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_extent_width)
+ result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
+ result_rows.append(torch.cat(result_row, dim=4))
+
+ dec = torch.cat(result_rows, dim=3)
+
if not return_dict:
return (dec,)
+
return DecoderOutput(sample=dec)
def forward(
diff --git a/videosys/models/open_sora/vae.py b/videosys/models/autoencoders/autoencoder_kl_open_sora.py
similarity index 98%
rename from videosys/models/open_sora/vae.py
rename to videosys/models/autoencoders/autoencoder_kl_open_sora.py
index ae3f92292c4b266f76554a225f55c67c379c3cd1..2919b2960475bb856e67afc878599fbe187340c6 100644
--- a/videosys/models/open_sora/vae.py
+++ b/videosys/models/autoencoders/autoencoder_kl_open_sora.py
@@ -18,8 +18,6 @@ from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder
from einops import rearrange
from transformers import PretrainedConfig, PreTrainedModel
-from .utils import load_checkpoint
-
class DiagonalGaussianDistribution(object):
def __init__(
@@ -474,7 +472,7 @@ class VAE_Temporal(nn.Module):
return recon_video, posterior, z
-def VAE_Temporal_SD(from_pretrained=None, **kwargs):
+def VAE_Temporal_SD(**kwargs):
model = VAE_Temporal(
in_out_channels=4,
latent_embed_dim=4,
@@ -485,8 +483,6 @@ def VAE_Temporal_SD(from_pretrained=None, **kwargs):
temporal_downsample=(False, True, True),
**kwargs,
)
- if from_pretrained is not None:
- load_checkpoint(model, from_pretrained)
return model
@@ -634,7 +630,7 @@ class VideoAutoencoderPipeline(PreTrainedModel):
micro_batch_size=4,
subfolder="vae",
)
- self.temporal_vae = VAE_Temporal_SD(from_pretrained=None)
+ self.temporal_vae = VAE_Temporal_SD()
self.cal_loss = config.cal_loss
self.micro_frame_size = config.micro_frame_size
self.micro_z_frame_size = self.temporal_vae.get_latent_size([config.micro_frame_size, None, None])[0]
@@ -763,7 +759,4 @@ def OpenSoraVAE_V1_2(
else:
config = VideoAutoencoderPipelineConfig(**kwargs)
model = VideoAutoencoderPipeline(config)
-
- if from_pretrained:
- load_checkpoint(model, from_pretrained)
return model
diff --git a/videosys/models/open_sora_plan/ae.py b/videosys/models/autoencoders/autoencoder_kl_open_sora_plan.py
similarity index 52%
rename from videosys/models/open_sora_plan/ae.py
rename to videosys/models/autoencoders/autoencoder_kl_open_sora_plan.py
index cd023d44fbd91c65e1d0e4092f58d6cb8ae87e5a..162060c5f46bddf9fb329f7b3066cf66fd123877 100644
--- a/videosys/models/open_sora_plan/ae.py
+++ b/videosys/models/autoencoders/autoencoder_kl_open_sora_plan.py
@@ -6,20 +6,24 @@
# References:
# Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
# --------------------------------------------------------
-
import glob
-import importlib
import os
from typing import Optional, Tuple, Union
import numpy as np
import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.nn.functional as F
from diffusers import ConfigMixin, ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
+from diffusers.utils import logging
from einops import rearrange
from torch import nn
+logging.set_verbosity_error()
+
def Normalize(in_channels, num_groups=32):
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
@@ -80,13 +84,7 @@ class DiagonalGaussianDistribution(object):
def resolve_str_to_obj(str_val, append=True):
- if append:
- str_val = "videosys.models.open_sora_plan.modules." + str_val
- if "opensora.models.ae.videobase." in str_val:
- str_val = str_val.replace("opensora.models.ae.videobase.", "videosys.models.open_sora_plan.")
- module_name, class_name = str_val.rsplit(".", 1)
- module = importlib.import_module(module_name)
- return getattr(module, class_name)
+ return globals()[str_val]
class VideoBaseAE_PL(ModelMixin, ConfigMixin):
@@ -130,7 +128,6 @@ class VideoBaseAE_PL(ModelMixin, ConfigMixin):
model.init_from_ckpt(last_ckpt_file)
return model
else:
- print(f"Loading model from {pretrained_model_name_or_path}")
return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
@@ -431,8 +428,6 @@ class CausalVAEModel(VideoBaseAE_PL):
self.learning_rate = lr
self.lr_g_factor = 1.0
- self.loss = resolve_str_to_obj(loss_type, append=False)(**loss_params)
-
self.encoder = Encoder(
z_channels=z_channels,
hidden_size=hidden_size,
@@ -471,8 +466,6 @@ class CausalVAEModel(VideoBaseAE_PL):
quant_conv_cls = resolve_str_to_obj(q_conv)
self.quant_conv = quant_conv_cls(2 * z_channels, 2 * embed_dim, 1)
self.post_quant_conv = quant_conv_cls(embed_dim, z_channels, 1)
- if hasattr(self.loss, "discriminator"):
- self.automatic_optimization = False
def encode(self, x):
if self.use_tiling and (
@@ -855,3 +848,793 @@ def getae_wrapper(ae):
ae = videobase_ae.get(ae, None)
assert ae is not None
return ae
+
+
+def video_to_image(func):
+ def wrapper(self, x, *args, **kwargs):
+ if x.dim() == 5:
+ t = x.shape[2]
+ x = rearrange(x, "b c t h w -> (b t) c h w")
+ x = func(self, x, *args, **kwargs)
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
+ return x
+
+ return wrapper
+
+
+class Block(nn.Module):
+ def __init__(self, *args, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+
+
+class LinearAttention(Block):
+ def __init__(self, dim, heads=4, dim_head=32):
+ super().__init__()
+ self.heads = heads
+ hidden_dim = dim_head * heads
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+ qkv = self.to_qkv(x)
+ q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
+ k = k.softmax(dim=-1)
+ context = torch.einsum("bhdn,bhen->bhde", k, v)
+ out = torch.einsum("bhde,bhdn->bhen", context, q)
+ out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
+ return self.to_out(out)
+
+
+class LinAttnBlock(LinearAttention):
+ """to match AttnBlock usage"""
+
+ def __init__(self, in_channels):
+ super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
+
+
+class AttnBlock3D(Block):
+ """Compatible with old versions, there are issues, use with caution."""
+
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
+ self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
+ self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
+ self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, t, h, w = q.shape
+ q = q.reshape(b * t, c, h * w)
+ q = q.permute(0, 2, 1) # b,hw,c
+ k = k.reshape(b * t, c, h * w) # b,c,hw
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w_ = w_ * (int(c) ** (-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b * t, c, h * w)
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = h_.reshape(b, c, t, h, w)
+
+ h_ = self.proj_out(h_)
+
+ return x + h_
+
+
+class AttnBlock3DFix(nn.Module):
+ """
+ Thanks to https://github.com/PKU-YuanGroup/Open-Sora-Plan/pull/172.
+ """
+
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
+ self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
+ self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
+ self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ # q: (b c t h w) -> (b t c h w) -> (b*t c h*w) -> (b*t h*w c)
+ b, c, t, h, w = q.shape
+ q = q.permute(0, 2, 1, 3, 4)
+ q = q.reshape(b * t, c, h * w)
+ q = q.permute(0, 2, 1)
+
+ # k: (b c t h w) -> (b t c h w) -> (b*t c h*w)
+ k = k.permute(0, 2, 1, 3, 4)
+ k = k.reshape(b * t, c, h * w)
+
+ # w: (b*t hw hw)
+ w_ = torch.bmm(q, k)
+ w_ = w_ * (int(c) ** (-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ # v: (b c t h w) -> (b t c h w) -> (bt c hw)
+ # w_: (bt hw hw) -> (bt hw hw)
+ v = v.permute(0, 2, 1, 3, 4)
+ v = v.reshape(b * t, c, h * w)
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+
+ # h_: (b*t c hw) -> (b t c h w) -> (b c t h w)
+ h_ = h_.reshape(b, t, c, h, w)
+ h_ = h_.permute(0, 2, 1, 3, 4)
+
+ h_ = self.proj_out(h_)
+
+ return x + h_
+
+
+class AttnBlock(Block):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+
+ @video_to_image
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, h, w = q.shape
+ q = q.reshape(b, c, h * w)
+ q = q.permute(0, 2, 1) # b,hw,c
+ k = k.reshape(b, c, h * w) # b,c,hw
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w_ = w_ * (int(c) ** (-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b, c, h * w)
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = h_.reshape(b, c, h, w)
+
+ h_ = self.proj_out(h_)
+
+ return x + h_
+
+
+class TemporalAttnBlock(Block):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.k = torch.nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.v = torch.nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.proj_out = torch.nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, t, h, w = q.shape
+ q = rearrange(q, "b c t h w -> (b h w) t c")
+ k = rearrange(k, "b c t h w -> (b h w) c t")
+ v = rearrange(v, "b c t h w -> (b h w) c t")
+ w_ = torch.bmm(q, k)
+ w_ = w_ * (int(c) ** (-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ w_ = w_.permute(0, 2, 1)
+ h_ = torch.bmm(v, w_)
+ h_ = rearrange(h_, "(b h w) c t -> b c t h w", h=h, w=w)
+ h_ = self.proj_out(h_)
+
+ return x + h_
+
+
+def make_attn(in_channels, attn_type="vanilla"):
+ assert attn_type in ["vanilla", "linear", "none", "vanilla3D"], f"attn_type {attn_type} unknown"
+ print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
+ print(attn_type)
+ if attn_type == "vanilla":
+ return AttnBlock(in_channels)
+ elif attn_type == "vanilla3D":
+ return AttnBlock3D(in_channels)
+ elif attn_type == "none":
+ return nn.Identity(in_channels)
+ else:
+ return LinAttnBlock(in_channels)
+
+
+class Conv2d(nn.Conv2d):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int]] = 3,
+ stride: Union[int, Tuple[int]] = 1,
+ padding: Union[str, int, Tuple[int]] = 0,
+ dilation: Union[int, Tuple[int]] = 1,
+ groups: int = 1,
+ bias: bool = True,
+ padding_mode: str = "zeros",
+ device=None,
+ dtype=None,
+ ) -> None:
+ super().__init__(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding,
+ dilation,
+ groups,
+ bias,
+ padding_mode,
+ device,
+ dtype,
+ )
+
+ @video_to_image
+ def forward(self, x):
+ return super().forward(x)
+
+
+class CausalConv3d(nn.Module):
+ def __init__(
+ self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], init_method="random", **kwargs
+ ):
+ super().__init__()
+ self.kernel_size = cast_tuple(kernel_size, 3)
+ self.time_kernel_size = self.kernel_size[0]
+ self.chan_in = chan_in
+ self.chan_out = chan_out
+ stride = kwargs.pop("stride", 1)
+ padding = kwargs.pop("padding", 0)
+ padding = list(cast_tuple(padding, 3))
+ padding[0] = 0
+ stride = cast_tuple(stride, 3)
+ self.conv = nn.Conv3d(chan_in, chan_out, self.kernel_size, stride=stride, padding=padding)
+ self._init_weights(init_method)
+
+ def _init_weights(self, init_method):
+ torch.tensor(self.kernel_size)
+ if init_method == "avg":
+ assert self.kernel_size[1] == 1 and self.kernel_size[2] == 1, "only support temporal up/down sample"
+ assert self.chan_in == self.chan_out, "chan_in must be equal to chan_out"
+ weight = torch.zeros((self.chan_out, self.chan_in, *self.kernel_size))
+
+ eyes = torch.concat(
+ [
+ torch.eye(self.chan_in).unsqueeze(-1) * 1 / 3,
+ torch.eye(self.chan_in).unsqueeze(-1) * 1 / 3,
+ torch.eye(self.chan_in).unsqueeze(-1) * 1 / 3,
+ ],
+ dim=-1,
+ )
+ weight[:, :, :, 0, 0] = eyes
+
+ self.conv.weight = nn.Parameter(
+ weight,
+ requires_grad=True,
+ )
+ elif init_method == "zero":
+ self.conv.weight = nn.Parameter(
+ torch.zeros((self.chan_out, self.chan_in, *self.kernel_size)),
+ requires_grad=True,
+ )
+ if self.conv.bias is not None:
+ nn.init.constant_(self.conv.bias, 0)
+
+ def forward(self, x):
+ # 1 + 16 16 as video, 1 as image
+ first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.time_kernel_size - 1, 1, 1)) # b c t h w
+ x = torch.concatenate((first_frame_pad, x), dim=2) # 3 + 16
+ return self.conv(x)
+
+
+class GroupNorm(Block):
+ def __init__(self, num_channels, num_groups=32, eps=1e-6, *args, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+ self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=num_channels, eps=1e-6, affine=True)
+
+ def forward(self, x):
+ return self.norm(x)
+
+
+def Normalize(in_channels, num_groups=32):
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class ActNorm(nn.Module):
+ def __init__(self, num_features, logdet=False, affine=True, allow_reverse_init=False):
+ assert affine
+ super().__init__()
+ self.logdet = logdet
+ self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
+ self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
+ self.allow_reverse_init = allow_reverse_init
+
+ self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))
+
+ def initialize(self, input):
+ with torch.no_grad():
+ flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
+ mean = flatten.mean(1).unsqueeze(1).unsqueeze(2).unsqueeze(3).permute(1, 0, 2, 3)
+ std = flatten.std(1).unsqueeze(1).unsqueeze(2).unsqueeze(3).permute(1, 0, 2, 3)
+
+ self.loc.data.copy_(-mean)
+ self.scale.data.copy_(1 / (std + 1e-6))
+
+ def forward(self, input, reverse=False):
+ if reverse:
+ return self.reverse(input)
+ if len(input.shape) == 2:
+ input = input[:, :, None, None]
+ squeeze = True
+ else:
+ squeeze = False
+
+ _, _, height, width = input.shape
+
+ if self.training and self.initialized.item() == 0:
+ self.initialize(input)
+ self.initialized.fill_(1)
+
+ h = self.scale * (input + self.loc)
+
+ if squeeze:
+ h = h.squeeze(-1).squeeze(-1)
+
+ if self.logdet:
+ log_abs = torch.log(torch.abs(self.scale))
+ logdet = height * width * torch.sum(log_abs)
+ logdet = logdet * torch.ones(input.shape[0]).to(input)
+ return h, logdet
+
+ return h
+
+ def reverse(self, output):
+ if self.training and self.initialized.item() == 0:
+ if not self.allow_reverse_init:
+ raise RuntimeError(
+ "Initializing ActNorm in reverse direction is "
+ "disabled by default. Use allow_reverse_init=True to enable."
+ )
+ else:
+ self.initialize(output)
+ self.initialized.fill_(1)
+
+ if len(output.shape) == 2:
+ output = output[:, :, None, None]
+ squeeze = True
+ else:
+ squeeze = False
+
+ h = output / self.scale - self.loc
+
+ if squeeze:
+ h = h.squeeze(-1).squeeze(-1)
+ return h
+
+
+def nonlinearity(x):
+ return x * torch.sigmoid(x)
+
+
+def cast_tuple(t, length=1):
+ return t if isinstance(t, tuple) else ((t,) * length)
+
+
+def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True):
+ n_dims = len(x.shape)
+ if src_dim < 0:
+ src_dim = n_dims + src_dim
+ if dest_dim < 0:
+ dest_dim = n_dims + dest_dim
+ assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims
+ dims = list(range(n_dims))
+ del dims[src_dim]
+ permutation = []
+ ctr = 0
+ for i in range(n_dims):
+ if i == dest_dim:
+ permutation.append(src_dim)
+ else:
+ permutation.append(dims[ctr])
+ ctr += 1
+ x = x.permute(permutation)
+ if make_contiguous:
+ x = x.contiguous()
+ return x
+
+
+class Codebook(nn.Module):
+ def __init__(self, n_codes, embedding_dim):
+ super().__init__()
+ self.register_buffer("embeddings", torch.randn(n_codes, embedding_dim))
+ self.register_buffer("N", torch.zeros(n_codes))
+ self.register_buffer("z_avg", self.embeddings.data.clone())
+
+ self.n_codes = n_codes
+ self.embedding_dim = embedding_dim
+ self._need_init = True
+
+ def _tile(self, x):
+ d, ew = x.shape
+ if d < self.n_codes:
+ n_repeats = (self.n_codes + d - 1) // d
+ std = 0.01 / np.sqrt(ew)
+ x = x.repeat(n_repeats, 1)
+ x = x + torch.randn_like(x) * std
+ return x
+
+ def _init_embeddings(self, z):
+ # z: [b, c, t, h, w]
+ self._need_init = False
+ flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2)
+ y = self._tile(flat_inputs)
+
+ y.shape[0]
+ _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes]
+ if dist.is_initialized():
+ dist.broadcast(_k_rand, 0)
+ self.embeddings.data.copy_(_k_rand)
+ self.z_avg.data.copy_(_k_rand)
+ self.N.data.copy_(torch.ones(self.n_codes))
+
+ def forward(self, z):
+ # z: [b, c, t, h, w]
+ if self._need_init and self.training:
+ self._init_embeddings(z)
+ flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2)
+ distances = (
+ (flat_inputs**2).sum(dim=1, keepdim=True)
+ - 2 * flat_inputs @ self.embeddings.t()
+ + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True)
+ )
+
+ encoding_indices = torch.argmin(distances, dim=1)
+ encode_onehot = F.one_hot(encoding_indices, self.n_codes).type_as(flat_inputs)
+ encoding_indices = encoding_indices.view(z.shape[0], *z.shape[2:])
+
+ embeddings = F.embedding(encoding_indices, self.embeddings)
+ embeddings = shift_dim(embeddings, -1, 1)
+
+ commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach())
+
+ # EMA codebook update
+ if self.training:
+ n_total = encode_onehot.sum(dim=0)
+ encode_sum = flat_inputs.t() @ encode_onehot
+ if dist.is_initialized():
+ dist.all_reduce(n_total)
+ dist.all_reduce(encode_sum)
+
+ self.N.data.mul_(0.99).add_(n_total, alpha=0.01)
+ self.z_avg.data.mul_(0.99).add_(encode_sum.t(), alpha=0.01)
+
+ n = self.N.sum()
+ weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n
+ encode_normalized = self.z_avg / weights.unsqueeze(1)
+ self.embeddings.data.copy_(encode_normalized)
+
+ y = self._tile(flat_inputs)
+ _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes]
+ if dist.is_initialized():
+ dist.broadcast(_k_rand, 0)
+
+ usage = (self.N.view(self.n_codes, 1) >= 1).float()
+ self.embeddings.data.mul_(usage).add_(_k_rand * (1 - usage))
+
+ embeddings_st = (embeddings - z).detach() + z
+
+ avg_probs = torch.mean(encode_onehot, dim=0)
+ perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
+
+ return dict(
+ embeddings=embeddings_st,
+ encodings=encoding_indices,
+ commitment_loss=commitment_loss,
+ perplexity=perplexity,
+ )
+
+ def dictionary_lookup(self, encodings):
+ embeddings = F.embedding(encodings, self.embeddings)
+ return embeddings
+
+
+class ResnetBlock2D(Block):
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = in_channels if out_channels is None else out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels)
+ self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ self.norm2 = Normalize(out_channels)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ else:
+ self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+ @video_to_image
+ def forward(self, x):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+ x = x + h
+ return x
+
+
+class ResnetBlock3D(Block):
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = in_channels if out_channels is None else out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels)
+ self.conv1 = CausalConv3d(in_channels, out_channels, 3, padding=1)
+ self.norm2 = Normalize(out_channels)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = CausalConv3d(out_channels, out_channels, 3, padding=1)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = CausalConv3d(in_channels, out_channels, 3, padding=1)
+ else:
+ self.nin_shortcut = CausalConv3d(in_channels, out_channels, 1, padding=0)
+
+ def forward(self, x):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+ return x + h
+
+
+class Upsample(Block):
+ def __init__(self, in_channels, out_channels):
+ super().__init__()
+ self.with_conv = True
+ if self.with_conv:
+ self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+ @video_to_image
+ def forward(self, x):
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ if self.with_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(Block):
+ def __init__(self, in_channels, out_channels):
+ super().__init__()
+ self.with_conv = True
+ if self.with_conv:
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
+
+ @video_to_image
+ def forward(self, x):
+ if self.with_conv:
+ pad = (0, 1, 0, 1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ else:
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+ return x
+
+
+class SpatialDownsample2x(Block):
+ def __init__(
+ self,
+ chan_in,
+ chan_out,
+ kernel_size: Union[int, Tuple[int]] = (3, 3),
+ stride: Union[int, Tuple[int]] = (2, 2),
+ ):
+ super().__init__()
+ kernel_size = cast_tuple(kernel_size, 2)
+ stride = cast_tuple(stride, 2)
+ self.chan_in = chan_in
+ self.chan_out = chan_out
+ self.kernel_size = kernel_size
+ self.conv = CausalConv3d(self.chan_in, self.chan_out, (1,) + self.kernel_size, stride=(1,) + stride, padding=0)
+
+ def forward(self, x):
+ pad = (0, 1, 0, 1, 0, 0)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ return x
+
+
+class SpatialUpsample2x(Block):
+ def __init__(
+ self,
+ chan_in,
+ chan_out,
+ kernel_size: Union[int, Tuple[int]] = (3, 3),
+ stride: Union[int, Tuple[int]] = (1, 1),
+ ):
+ super().__init__()
+ self.chan_in = chan_in
+ self.chan_out = chan_out
+ self.kernel_size = kernel_size
+ self.conv = CausalConv3d(self.chan_in, self.chan_out, (1,) + self.kernel_size, stride=(1,) + stride, padding=1)
+
+ def forward(self, x):
+ t = x.shape[2]
+ x = rearrange(x, "b c t h w -> b (c t) h w")
+ x = F.interpolate(x, scale_factor=(2, 2), mode="nearest")
+ x = rearrange(x, "b (c t) h w -> b c t h w", t=t)
+ x = self.conv(x)
+ return x
+
+
+class TimeDownsample2x(Block):
+ def __init__(self, chan_in, chan_out, kernel_size: int = 3):
+ super().__init__()
+ self.kernel_size = kernel_size
+ self.conv = nn.AvgPool3d((kernel_size, 1, 1), stride=(2, 1, 1))
+
+ def forward(self, x):
+ first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.kernel_size - 1, 1, 1))
+ x = torch.concatenate((first_frame_pad, x), dim=2)
+ return self.conv(x)
+
+
+class TimeUpsample2x(Block):
+ def __init__(self, chan_in, chan_out):
+ super().__init__()
+
+ def forward(self, x):
+ if x.size(2) > 1:
+ x, x_ = x[:, :, :1], x[:, :, 1:]
+ x_ = F.interpolate(x_, scale_factor=(2, 1, 1), mode="trilinear")
+ x = torch.concat([x, x_], dim=2)
+ return x
+
+
+class TimeDownsampleRes2x(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size: int = 3,
+ mix_factor: float = 2.0,
+ ):
+ super().__init__()
+ self.kernel_size = cast_tuple(kernel_size, 3)
+ self.avg_pool = nn.AvgPool3d((kernel_size, 1, 1), stride=(2, 1, 1))
+ self.conv = nn.Conv3d(in_channels, out_channels, self.kernel_size, stride=(2, 1, 1), padding=(0, 1, 1))
+ self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor]))
+
+ def forward(self, x):
+ alpha = torch.sigmoid(self.mix_factor)
+ first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.kernel_size[0] - 1, 1, 1))
+ x = torch.concatenate((first_frame_pad, x), dim=2)
+ return alpha * self.avg_pool(x) + (1 - alpha) * self.conv(x)
+
+
+class TimeUpsampleRes2x(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size: int = 3,
+ mix_factor: float = 2.0,
+ ):
+ super().__init__()
+ self.conv = CausalConv3d(in_channels, out_channels, kernel_size, padding=1)
+ self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor]))
+
+ def forward(self, x):
+ alpha = torch.sigmoid(self.mix_factor)
+ if x.size(2) > 1:
+ x, x_ = x[:, :, :1], x[:, :, 1:]
+ x_ = F.interpolate(x_, scale_factor=(2, 1, 1), mode="trilinear")
+ x = torch.concat([x, x_], dim=2)
+ return alpha * x + (1 - alpha) * self.conv(x)
+
+
+class TimeDownsampleResAdv2x(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size: int = 3,
+ mix_factor: float = 1.5,
+ ):
+ super().__init__()
+ self.kernel_size = cast_tuple(kernel_size, 3)
+ self.avg_pool = nn.AvgPool3d((kernel_size, 1, 1), stride=(2, 1, 1))
+ self.attn = TemporalAttnBlock(in_channels)
+ self.res = ResnetBlock3D(in_channels=in_channels, out_channels=in_channels, dropout=0.0)
+ self.conv = nn.Conv3d(in_channels, out_channels, self.kernel_size, stride=(2, 1, 1), padding=(0, 1, 1))
+ self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor]))
+
+ def forward(self, x):
+ first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.kernel_size[0] - 1, 1, 1))
+ x = torch.concatenate((first_frame_pad, x), dim=2)
+ alpha = torch.sigmoid(self.mix_factor)
+ return alpha * self.avg_pool(x) + (1 - alpha) * self.conv(self.attn((self.res(x))))
+
+
+class TimeUpsampleResAdv2x(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size: int = 3,
+ mix_factor: float = 1.5,
+ ):
+ super().__init__()
+ self.res = ResnetBlock3D(in_channels=in_channels, out_channels=in_channels, dropout=0.0)
+ self.attn = TemporalAttnBlock(in_channels)
+ self.norm = Normalize(in_channels=in_channels)
+ self.conv = CausalConv3d(in_channels, out_channels, kernel_size, padding=1)
+ self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor]))
+
+ def forward(self, x):
+ if x.size(2) > 1:
+ x, x_ = x[:, :, :1], x[:, :, 1:]
+ x_ = F.interpolate(x_, scale_factor=(2, 1, 1), mode="trilinear")
+ x = torch.concat([x, x_], dim=2)
+ alpha = torch.sigmoid(self.mix_factor)
+ return alpha * x + (1 - alpha) * self.conv(self.attn(self.res(x)))
diff --git a/videosys/models/cogvideo/__init__.py b/videosys/models/cogvideo/__init__.py
deleted file mode 100644
index 9e80e1830a6f3a5e53b9c315e377dce38fca82ed..0000000000000000000000000000000000000000
--- a/videosys/models/cogvideo/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-from .pipeline import CogVideoConfig, CogVideoPipeline
-
-__all__ = [
- "CogVideoConfig",
- "CogVideoPipeline",
-]
diff --git a/videosys/models/cogvideo/cogvideox_transformer_3d.py b/videosys/models/cogvideo/cogvideox_transformer_3d.py
deleted file mode 100644
index 975e86b568aa0efb864d014f0ec698f597d87901..0000000000000000000000000000000000000000
--- a/videosys/models/cogvideo/cogvideox_transformer_3d.py
+++ /dev/null
@@ -1,339 +0,0 @@
-# Adapted from CogVideo
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-# --------------------------------------------------------
-# References:
-# CogVideo: https://github.com/THUDM/CogVideo
-# diffusers: https://github.com/huggingface/diffusers
-# --------------------------------------------------------
-
-from typing import Any, Dict, Optional, Union
-
-import torch
-from diffusers.configuration_utils import ConfigMixin, register_to_config
-from diffusers.models.attention import Attention, FeedForward
-from diffusers.models.embeddings import TimestepEmbedding, Timesteps
-from diffusers.models.modeling_outputs import Transformer2DModelOutput
-from diffusers.models.modeling_utils import ModelMixin
-from diffusers.utils import is_torch_version, logging
-from diffusers.utils.torch_utils import maybe_allow_in_graph
-from torch import nn
-
-from .modules import AdaLayerNorm, CogVideoXLayerNormZero, CogVideoXPatchEmbed, get_3d_sincos_pos_embed
-
-logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-
-
-@maybe_allow_in_graph
-class CogVideoXBlock(nn.Module):
- r"""
- Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
-
- Parameters:
- dim (`int`): The number of channels in the input and output.
- num_attention_heads (`int`): The number of heads to use for multi-head attention.
- attention_head_dim (`int`): The number of channels in each head.
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
- attention_bias (:
- obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
- qk_norm (`bool`, defaults to `True`):
- Whether or not to use normalization after query and key projections in Attention.
- norm_elementwise_affine (`bool`, defaults to `True`):
- Whether to use learnable elementwise affine parameters for normalization.
- norm_eps (`float`, defaults to `1e-5`):
- Epsilon value for normalization layers.
- final_dropout (`bool` defaults to `False`):
- Whether to apply a final dropout after the last feed-forward layer.
- ff_inner_dim (`int`, *optional*, defaults to `None`):
- Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
- ff_bias (`bool`, defaults to `True`):
- Whether or not to use bias in Feed-forward layer.
- attention_out_bias (`bool`, defaults to `True`):
- Whether or not to use bias in Attention output projection layer.
- """
-
- def __init__(
- self,
- dim: int,
- num_attention_heads: int,
- attention_head_dim: int,
- time_embed_dim: int,
- dropout: float = 0.0,
- activation_fn: str = "gelu-approximate",
- attention_bias: bool = False,
- qk_norm: bool = True,
- norm_elementwise_affine: bool = True,
- norm_eps: float = 1e-5,
- final_dropout: bool = True,
- ff_inner_dim: Optional[int] = None,
- ff_bias: bool = True,
- attention_out_bias: bool = True,
- ):
- super().__init__()
-
- # 1. Self Attention
- self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
-
- self.attn1 = Attention(
- query_dim=dim,
- dim_head=attention_head_dim,
- heads=num_attention_heads,
- qk_norm="layer_norm" if qk_norm else None,
- eps=1e-6,
- bias=attention_bias,
- out_bias=attention_out_bias,
- )
-
- # 2. Feed Forward
- self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
-
- self.ff = FeedForward(
- dim,
- dropout=dropout,
- activation_fn=activation_fn,
- final_dropout=final_dropout,
- inner_dim=ff_inner_dim,
- bias=ff_bias,
- )
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- encoder_hidden_states: torch.Tensor,
- temb: torch.Tensor,
- ) -> torch.Tensor:
- norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
- hidden_states, encoder_hidden_states, temb
- )
-
- # attention
- text_length = norm_encoder_hidden_states.size(1)
-
- # CogVideoX uses concatenated text + video embeddings with self-attention instead of using
- # them in cross-attention individually
- norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
- attn_output = self.attn1(
- hidden_states=norm_hidden_states,
- encoder_hidden_states=None,
- )
-
- hidden_states = hidden_states + gate_msa * attn_output[:, text_length:]
- encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_output[:, :text_length]
-
- # norm & modulate
- norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
- hidden_states, encoder_hidden_states, temb
- )
-
- # feed-forward
- norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
- ff_output = self.ff(norm_hidden_states)
-
- hidden_states = hidden_states + gate_ff * ff_output[:, text_length:]
- encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_length]
- return hidden_states, encoder_hidden_states
-
-
-class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
- """
- A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
-
- Parameters:
- num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
- attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
- in_channels (`int`, *optional*):
- The number of channels in the input.
- out_channels (`int`, *optional*):
- The number of channels in the output.
- num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
- cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
- attention_bias (`bool`, *optional*):
- Configure if the `TransformerBlocks` attention should contain a bias parameter.
- sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
- This is fixed during training since it is used to learn a number of position embeddings.
- patch_size (`int`, *optional*):
- The size of the patches to use in the patch embedding layer.
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
- num_embeds_ada_norm ( `int`, *optional*):
- The number of diffusion steps used during training. Pass if at least one of the norm_layers is
- `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
- added to the hidden states. During inference, you can denoise for up to but not more steps than
- `num_embeds_ada_norm`.
- norm_type (`str`, *optional*, defaults to `"layer_norm"`):
- The type of normalization to use. Options are `"layer_norm"` or `"ada_layer_norm"`.
- norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
- Whether or not to use elementwise affine in normalization layers.
- norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use in normalization layers.
- caption_channels (`int`, *optional*):
- The number of channels in the caption embeddings.
- video_length (`int`, *optional*):
- The number of frames in the video-like data.
- """
-
- _supports_gradient_checkpointing = True
-
- @register_to_config
- def __init__(
- self,
- num_attention_heads: int = 30,
- attention_head_dim: int = 64,
- in_channels: Optional[int] = 16,
- out_channels: Optional[int] = 16,
- flip_sin_to_cos: bool = True,
- freq_shift: int = 0,
- time_embed_dim: int = 512,
- text_embed_dim: int = 4096,
- num_layers: int = 30,
- dropout: float = 0.0,
- attention_bias: bool = True,
- sample_width: int = 90,
- sample_height: int = 60,
- sample_frames: int = 49,
- patch_size: int = 2,
- temporal_compression_ratio: int = 4,
- max_text_seq_length: int = 226,
- activation_fn: str = "gelu-approximate",
- timestep_activation_fn: str = "silu",
- norm_elementwise_affine: bool = True,
- norm_eps: float = 1e-5,
- spatial_interpolation_scale: float = 1.875,
- temporal_interpolation_scale: float = 1.0,
- ):
- super().__init__()
- inner_dim = num_attention_heads * attention_head_dim
-
- post_patch_height = sample_height // patch_size
- post_patch_width = sample_width // patch_size
- post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
- self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
-
- # 1. Patch embedding
- self.patch_embed = CogVideoXPatchEmbed(patch_size, in_channels, inner_dim, text_embed_dim, bias=True)
- self.embedding_dropout = nn.Dropout(dropout)
-
- # 2. 3D positional embeddings
- spatial_pos_embedding = get_3d_sincos_pos_embed(
- inner_dim,
- (post_patch_width, post_patch_height),
- post_time_compression_frames,
- spatial_interpolation_scale,
- temporal_interpolation_scale,
- )
- spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1)
- pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False)
- pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding)
- self.register_buffer("pos_embedding", pos_embedding, persistent=False)
-
- # 3. Time embeddings
- self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
- self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
-
- # 4. Define spatio-temporal transformers blocks
- self.transformer_blocks = nn.ModuleList(
- [
- CogVideoXBlock(
- dim=inner_dim,
- num_attention_heads=num_attention_heads,
- attention_head_dim=attention_head_dim,
- time_embed_dim=time_embed_dim,
- dropout=dropout,
- activation_fn=activation_fn,
- attention_bias=attention_bias,
- norm_elementwise_affine=norm_elementwise_affine,
- norm_eps=norm_eps,
- )
- for _ in range(num_layers)
- ]
- )
- self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
-
- # 5. Output blocks
- self.norm_out = AdaLayerNorm(
- embedding_dim=time_embed_dim,
- output_dim=2 * inner_dim,
- norm_elementwise_affine=norm_elementwise_affine,
- norm_eps=norm_eps,
- chunk_dim=1,
- )
- self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
-
- self.gradient_checkpointing = False
-
- def _set_gradient_checkpointing(self, module, value=False):
- self.gradient_checkpointing = value
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- encoder_hidden_states: torch.Tensor,
- timestep: Union[int, float, torch.LongTensor],
- timestep_cond: Optional[torch.Tensor] = None,
- return_dict: bool = True,
- ):
- batch_size, num_frames, channels, height, width = hidden_states.shape
-
- # 1. Time embedding
- timesteps = timestep
- t_emb = self.time_proj(timesteps)
-
- # timesteps does not contain any weights and will always return f32 tensors
- # but time_embedding might actually be running in fp16. so we need to cast here.
- # there might be better ways to encapsulate this.
- t_emb = t_emb.to(dtype=hidden_states.dtype)
- emb = self.time_embedding(t_emb, timestep_cond)
-
- # 2. Patch embedding
- hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
-
- # 3. Position embedding
- seq_length = height * width * num_frames // (self.config.patch_size**2)
-
- pos_embeds = self.pos_embedding[:, : self.config.max_text_seq_length + seq_length]
- hidden_states = hidden_states + pos_embeds
- hidden_states = self.embedding_dropout(hidden_states)
-
- encoder_hidden_states = hidden_states[:, : self.config.max_text_seq_length]
- hidden_states = hidden_states[:, self.config.max_text_seq_length :]
-
- # 5. Transformer blocks
- for i, block in enumerate(self.transformer_blocks):
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(block),
- hidden_states,
- encoder_hidden_states,
- emb,
- **ckpt_kwargs,
- )
- else:
- hidden_states, encoder_hidden_states = block(
- hidden_states=hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- temb=emb,
- )
-
- hidden_states = self.norm_final(hidden_states)
-
- # 6. Final block
- hidden_states = self.norm_out(hidden_states, temb=emb)
- hidden_states = self.proj_out(hidden_states)
-
- # 7. Unpatchify
- p = self.config.patch_size
- output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
- output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
-
- if not return_dict:
- return (output,)
- return Transformer2DModelOutput(sample=output)
diff --git a/videosys/models/cogvideo/modules.py b/videosys/models/cogvideo/modules.py
deleted file mode 100644
index 8d5dc49515ce5423026c130581fc16f4155333d4..0000000000000000000000000000000000000000
--- a/videosys/models/cogvideo/modules.py
+++ /dev/null
@@ -1,317 +0,0 @@
-# Adapted from CogVideo
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-# --------------------------------------------------------
-# References:
-# CogVideo: https://github.com/THUDM/CogVideo
-# diffusers: https://github.com/huggingface/diffusers
-# --------------------------------------------------------
-
-from typing import Optional, Tuple, Union
-
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from diffusers.models.embeddings import get_1d_sincos_pos_embed_from_grid, get_2d_sincos_pos_embed_from_grid
-
-
-class CogVideoXDownsample3D(nn.Module):
- # Todo: Wait for paper relase.
- r"""
- A 3D Downsampling layer using in [CogVideoX]() by Tsinghua University & ZhipuAI
-
- Args:
- in_channels (`int`):
- Number of channels in the input image.
- out_channels (`int`):
- Number of channels produced by the convolution.
- kernel_size (`int`, defaults to `3`):
- Size of the convolving kernel.
- stride (`int`, defaults to `2`):
- Stride of the convolution.
- padding (`int`, defaults to `0`):
- Padding added to all four sides of the input.
- compress_time (`bool`, defaults to `False`):
- Whether or not to compress the time dimension.
- """
-
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- kernel_size: int = 3,
- stride: int = 2,
- padding: int = 0,
- compress_time: bool = False,
- ):
- super().__init__()
-
- self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
- self.compress_time = compress_time
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- if self.compress_time:
- batch_size, channels, frames, height, width = x.shape
-
- # (batch_size, channels, frames, height, width) -> (batch_size, height, width, channels, frames) -> (batch_size * height * width, channels, frames)
- x = x.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, channels, frames)
-
- if x.shape[-1] % 2 == 1:
- x_first, x_rest = x[..., 0], x[..., 1:]
- if x_rest.shape[-1] > 0:
- # (batch_size * height * width, channels, frames - 1) -> (batch_size * height * width, channels, (frames - 1) // 2)
- x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2)
-
- x = torch.cat([x_first[..., None], x_rest], dim=-1)
- # (batch_size * height * width, channels, (frames // 2) + 1) -> (batch_size, height, width, channels, (frames // 2) + 1) -> (batch_size, channels, (frames // 2) + 1, height, width)
- x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
- else:
- # (batch_size * height * width, channels, frames) -> (batch_size * height * width, channels, frames // 2)
- x = F.avg_pool1d(x, kernel_size=2, stride=2)
- # (batch_size * height * width, channels, frames // 2) -> (batch_size, height, width, channels, frames // 2) -> (batch_size, channels, frames // 2, height, width)
- x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
-
- # Pad the tensor
- pad = (0, 1, 0, 1)
- x = F.pad(x, pad, mode="constant", value=0)
- batch_size, channels, frames, height, width = x.shape
- # (batch_size, channels, frames, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size * frames, channels, height, width)
- x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channels, height, width)
- x = self.conv(x)
- # (batch_size * frames, channels, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size, channels, frames, height, width)
- x = x.reshape(batch_size, frames, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4)
- return x
-
-
-class CogVideoXUpsample3D(nn.Module):
- r"""
- A 3D Upsample layer using in CogVideoX by Tsinghua University & ZhipuAI # Todo: Wait for paper relase.
-
- Args:
- in_channels (`int`):
- Number of channels in the input image.
- out_channels (`int`):
- Number of channels produced by the convolution.
- kernel_size (`int`, defaults to `3`):
- Size of the convolving kernel.
- stride (`int`, defaults to `1`):
- Stride of the convolution.
- padding (`int`, defaults to `1`):
- Padding added to all four sides of the input.
- compress_time (`bool`, defaults to `False`):
- Whether or not to compress the time dimension.
- """
-
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- kernel_size: int = 3,
- stride: int = 1,
- padding: int = 1,
- compress_time: bool = False,
- ) -> None:
- super().__init__()
-
- self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
- self.compress_time = compress_time
-
- def forward(self, inputs: torch.Tensor) -> torch.Tensor:
- if self.compress_time:
- if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1:
- # split first frame
- x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:]
-
- x_first = F.interpolate(x_first, scale_factor=2.0)
- x_rest = F.interpolate(x_rest, scale_factor=2.0)
- x_first = x_first[:, :, None, :, :]
- inputs = torch.cat([x_first, x_rest], dim=2)
- elif inputs.shape[2] > 1:
- inputs = F.interpolate(inputs, scale_factor=2.0)
- else:
- inputs = inputs.squeeze(2)
- inputs = F.interpolate(inputs, scale_factor=2.0)
- inputs = inputs[:, :, None, :, :]
- else:
- # only interpolate 2D
- b, c, t, h, w = inputs.shape
- inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
- inputs = F.interpolate(inputs, scale_factor=2.0)
- inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4)
-
- b, c, t, h, w = inputs.shape
- inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
- inputs = self.conv(inputs)
- inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4)
-
- return inputs
-
-
-def get_3d_sincos_pos_embed(
- embed_dim: int,
- spatial_size: Union[int, Tuple[int, int]],
- temporal_size: int,
- spatial_interpolation_scale: float = 1.0,
- temporal_interpolation_scale: float = 1.0,
-) -> np.ndarray:
- r"""
- Args:
- embed_dim (`int`):
- spatial_size (`int` or `Tuple[int, int]`):
- temporal_size (`int`):
- spatial_interpolation_scale (`float`, defaults to 1.0):
- temporal_interpolation_scale (`float`, defaults to 1.0):
- """
- if embed_dim % 4 != 0:
- raise ValueError("`embed_dim` must be divisible by 4")
- if isinstance(spatial_size, int):
- spatial_size = (spatial_size, spatial_size)
-
- embed_dim_spatial = 3 * embed_dim // 4
- embed_dim_temporal = embed_dim // 4
-
- # 1. Spatial
- grid_h = np.arange(spatial_size[1], dtype=np.float32) / spatial_interpolation_scale
- grid_w = np.arange(spatial_size[0], dtype=np.float32) / spatial_interpolation_scale
- grid = np.meshgrid(grid_w, grid_h) # here w goes first
- grid = np.stack(grid, axis=0)
-
- grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]])
- pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid)
-
- # 2. Temporal
- grid_t = np.arange(temporal_size, dtype=np.float32) / temporal_interpolation_scale
- pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t)
-
- # 3. Concat
- pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
- pos_embed_spatial = np.repeat(pos_embed_spatial, temporal_size, axis=0) # [T, H*W, D // 4 * 3]
-
- pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
- pos_embed_temporal = np.repeat(pos_embed_temporal, spatial_size[0] * spatial_size[1], axis=1) # [T, H*W, D // 4]
-
- pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1) # [T, H*W, D]
- return pos_embed
-
-
-class CogVideoXPatchEmbed(nn.Module):
- def __init__(
- self,
- patch_size: int = 2,
- in_channels: int = 16,
- embed_dim: int = 1920,
- text_embed_dim: int = 4096,
- bias: bool = True,
- ) -> None:
- super().__init__()
- self.patch_size = patch_size
-
- self.proj = nn.Conv2d(
- in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
- )
- self.text_proj = nn.Linear(text_embed_dim, embed_dim)
-
- def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
- r"""
- Args:
- text_embeds (`torch.Tensor`):
- Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
- image_embeds (`torch.Tensor`):
- Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
- """
- text_embeds = self.text_proj(text_embeds)
-
- batch, num_frames, channels, height, width = image_embeds.shape
- image_embeds = image_embeds.reshape(-1, channels, height, width)
- image_embeds = self.proj(image_embeds)
- image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
- image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
- image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
-
- embeds = torch.cat(
- [text_embeds, image_embeds], dim=1
- ).contiguous() # [batch, seq_length + num_frames x height x width, channels]
- return embeds
-
-
-class CogVideoXLayerNormZero(nn.Module):
- def __init__(
- self,
- conditioning_dim: int,
- embedding_dim: int,
- elementwise_affine: bool = True,
- eps: float = 1e-5,
- bias: bool = True,
- ) -> None:
- super().__init__()
-
- self.silu = nn.SiLU()
- self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias)
- self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
-
- def forward(
- self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1)
- hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
- encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :]
- return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :]
-
-
-class AdaLayerNorm(nn.Module):
- r"""
- Norm layer modified to incorporate timestep embeddings.
-
- Parameters:
- embedding_dim (`int`): The size of each embedding vector.
- num_embeddings (`int`, *optional*): The size of the embeddings dictionary.
- output_dim (`int`, *optional*):
- norm_elementwise_affine (`bool`, defaults to `False):
- norm_eps (`bool`, defaults to `False`):
- chunk_dim (`int`, defaults to `0`):
- """
-
- def __init__(
- self,
- embedding_dim: int,
- num_embeddings: Optional[int] = None,
- output_dim: Optional[int] = None,
- norm_elementwise_affine: bool = False,
- norm_eps: float = 1e-5,
- chunk_dim: int = 0,
- ):
- super().__init__()
-
- self.chunk_dim = chunk_dim
- output_dim = output_dim or embedding_dim * 2
-
- if num_embeddings is not None:
- self.emb = nn.Embedding(num_embeddings, embedding_dim)
- else:
- self.emb = None
-
- self.silu = nn.SiLU()
- self.linear = nn.Linear(embedding_dim, output_dim)
- self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
-
- def forward(
- self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None
- ) -> torch.Tensor:
- if self.emb is not None:
- temb = self.emb(timestep)
-
- temb = self.linear(self.silu(temb))
-
- if self.chunk_dim == 1:
- # This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the
- # other if-branch. This branch is specific to CogVideoX for now.
- shift, scale = temb.chunk(2, dim=1)
- shift = shift[:, None, :]
- scale = scale[:, None, :]
- else:
- scale, shift = temb.chunk(2, dim=0)
-
- x = self.norm(x) * (1 + scale) + shift
- return x
diff --git a/videosys/models/cogvideo/retrieve_timesteps.py b/videosys/models/cogvideo/retrieve_timesteps.py
deleted file mode 100644
index 9702ec47a610e7f3f778a98572ffcac6cfb7a6d0..0000000000000000000000000000000000000000
--- a/videosys/models/cogvideo/retrieve_timesteps.py
+++ /dev/null
@@ -1,74 +0,0 @@
-# Adapted from CogVideo
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-# --------------------------------------------------------
-# References:
-# CogVideo: https://github.com/THUDM/CogVideo
-# diffusers: https://github.com/huggingface/diffusers
-# --------------------------------------------------------
-
-import inspect
-from typing import List, Optional, Union
-
-import torch
-
-
-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
-def retrieve_timesteps(
- scheduler,
- num_inference_steps: Optional[int] = None,
- device: Optional[Union[str, torch.device]] = None,
- timesteps: Optional[List[int]] = None,
- sigmas: Optional[List[float]] = None,
- **kwargs,
-):
- """
- Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
- custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
-
- Args:
- scheduler (`SchedulerMixin`):
- The scheduler to get timesteps from.
- num_inference_steps (`int`):
- The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
- must be `None`.
- device (`str` or `torch.device`, *optional*):
- The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
- timesteps (`List[int]`, *optional*):
- Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
- `num_inference_steps` and `sigmas` must be `None`.
- sigmas (`List[float]`, *optional*):
- Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
- `num_inference_steps` and `timesteps` must be `None`.
-
- Returns:
- `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
- second element is the number of inference steps.
- """
- if timesteps is not None and sigmas is not None:
- raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
- if timesteps is not None:
- accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
- if not accepts_timesteps:
- raise ValueError(
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
- f" timestep schedules. Please check whether you are using the correct scheduler."
- )
- scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
- timesteps = scheduler.timesteps
- num_inference_steps = len(timesteps)
- elif sigmas is not None:
- accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
- if not accept_sigmas:
- raise ValueError(
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
- f" sigmas schedules. Please check whether you are using the correct scheduler."
- )
- scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
- timesteps = scheduler.timesteps
- num_inference_steps = len(timesteps)
- else:
- scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
- timesteps = scheduler.timesteps
- return timesteps, num_inference_steps
diff --git a/videosys/models/latte/__init__.py b/videosys/models/latte/__init__.py
deleted file mode 100644
index 3e277da72ca2bcc598f1b4334bc6303e6207f175..0000000000000000000000000000000000000000
--- a/videosys/models/latte/__init__.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from .pipeline import LatteConfig, LattePABConfig, LattePipeline
-
-__all__ = [
- "LattePipeline",
- "LattePABConfig",
- "LatteConfig",
-]
diff --git a/eval/pab/experiments/__init__.py b/videosys/models/modules/__init__.py
similarity index 100%
rename from eval/pab/experiments/__init__.py
rename to videosys/models/modules/__init__.py
diff --git a/videosys/models/modules/activations.py b/videosys/models/modules/activations.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf24149506f7b1841de46a04328cd425cf2986eb
--- /dev/null
+++ b/videosys/models/modules/activations.py
@@ -0,0 +1,3 @@
+import torch.nn as nn
+
+approx_gelu = lambda: nn.GELU(approximate="tanh")
diff --git a/videosys/models/modules/attentions.py b/videosys/models/modules/attentions.py
new file mode 100644
index 0000000000000000000000000000000000000000..70c2352b160443b35cac2c48942ea3c554280a72
--- /dev/null
+++ b/videosys/models/modules/attentions.py
@@ -0,0 +1,131 @@
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+
+from videosys.models.modules.normalization import LlamaRMSNorm
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ qk_norm: bool = False,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ norm_layer: nn.Module = LlamaRMSNorm,
+ enable_flash_attn: bool = False,
+ rope=None,
+ qk_norm_legacy: bool = False,
+ ) -> None:
+ super().__init__()
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
+ self.dim = dim
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.scale = self.head_dim**-0.5
+ self.enable_flash_attn = enable_flash_attn
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.qk_norm_legacy = qk_norm_legacy
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ self.rope = False
+ if rope is not None:
+ self.rope = True
+ self.rotary_emb = rope
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ B, N, C = x.shape
+ # flash attn is not memory efficient for small sequences, this is empirical
+ enable_flash_attn = self.enable_flash_attn and (N > B)
+ qkv = self.qkv(x)
+ qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
+
+ qkv = qkv.view(qkv_shape).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv.unbind(0)
+ if self.qk_norm_legacy:
+ # WARNING: this may be a bug
+ if self.rope:
+ q = self.rotary_emb(q)
+ k = self.rotary_emb(k)
+ q, k = self.q_norm(q), self.k_norm(k)
+ else:
+ q, k = self.q_norm(q), self.k_norm(k)
+ if self.rope:
+ q = self.rotary_emb(q)
+ k = self.rotary_emb(k)
+
+ if enable_flash_attn:
+ from flash_attn import flash_attn_func
+
+ # (B, #heads, N, #dim) -> (B, N, #heads, #dim)
+ q = q.permute(0, 2, 1, 3)
+ k = k.permute(0, 2, 1, 3)
+ v = v.permute(0, 2, 1, 3)
+ x = flash_attn_func(
+ q,
+ k,
+ v,
+ dropout_p=self.attn_drop.p if self.training else 0.0,
+ softmax_scale=self.scale,
+ )
+ else:
+ dtype = q.dtype
+ q = q * self.scale
+ attn = q @ k.transpose(-2, -1) # translate attn to float32
+ attn = attn.to(torch.float32)
+ attn = attn.softmax(dim=-1)
+ attn = attn.to(dtype) # cast back attn to original dtype
+ attn = self.attn_drop(attn)
+ x = attn @ v
+
+ x_output_shape = (B, N, C)
+ if not enable_flash_attn:
+ x = x.transpose(1, 2)
+ x = x.reshape(x_output_shape)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class MultiHeadCrossAttention(nn.Module):
+ def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0):
+ super(MultiHeadCrossAttention, self).__init__()
+ assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
+
+ self.d_model = d_model
+ self.num_heads = num_heads
+ self.head_dim = d_model // num_heads
+
+ self.q_linear = nn.Linear(d_model, d_model)
+ self.kv_linear = nn.Linear(d_model, d_model * 2)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(d_model, d_model)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x, cond, mask=None):
+ # query/value: img tokens; key: condition; mask: if padding tokens
+ B, N, C = x.shape
+
+ q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
+ kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
+ k, v = kv.unbind(2)
+
+ attn_bias = None
+ # TODO: support torch computation
+ import xformers.ops
+
+ if mask is not None:
+ attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
+ x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
+
+ x = x.view(B, -1, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
diff --git a/videosys/models/modules/downsampling.py b/videosys/models/modules/downsampling.py
new file mode 100644
index 0000000000000000000000000000000000000000..9455a32209a970a353ef66fc68282638a1eb6422
--- /dev/null
+++ b/videosys/models/modules/downsampling.py
@@ -0,0 +1,71 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class CogVideoXDownsample3D(nn.Module):
+ # Todo: Wait for paper relase.
+ r"""
+ A 3D Downsampling layer using in [CogVideoX]() by Tsinghua University & ZhipuAI
+
+ Args:
+ in_channels (`int`):
+ Number of channels in the input image.
+ out_channels (`int`):
+ Number of channels produced by the convolution.
+ kernel_size (`int`, defaults to `3`):
+ Size of the convolving kernel.
+ stride (`int`, defaults to `2`):
+ Stride of the convolution.
+ padding (`int`, defaults to `0`):
+ Padding added to all four sides of the input.
+ compress_time (`bool`, defaults to `False`):
+ Whether or not to compress the time dimension.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int = 3,
+ stride: int = 2,
+ padding: int = 0,
+ compress_time: bool = False,
+ ):
+ super().__init__()
+
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
+ self.compress_time = compress_time
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if self.compress_time:
+ batch_size, channels, frames, height, width = x.shape
+
+ # (batch_size, channels, frames, height, width) -> (batch_size, height, width, channels, frames) -> (batch_size * height * width, channels, frames)
+ x = x.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, channels, frames)
+
+ if x.shape[-1] % 2 == 1:
+ x_first, x_rest = x[..., 0], x[..., 1:]
+ if x_rest.shape[-1] > 0:
+ # (batch_size * height * width, channels, frames - 1) -> (batch_size * height * width, channels, (frames - 1) // 2)
+ x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2)
+
+ x = torch.cat([x_first[..., None], x_rest], dim=-1)
+ # (batch_size * height * width, channels, (frames // 2) + 1) -> (batch_size, height, width, channels, (frames // 2) + 1) -> (batch_size, channels, (frames // 2) + 1, height, width)
+ x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
+ else:
+ # (batch_size * height * width, channels, frames) -> (batch_size * height * width, channels, frames // 2)
+ x = F.avg_pool1d(x, kernel_size=2, stride=2)
+ # (batch_size * height * width, channels, frames // 2) -> (batch_size, height, width, channels, frames // 2) -> (batch_size, channels, frames // 2, height, width)
+ x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
+
+ # Pad the tensor
+ pad = (0, 1, 0, 1)
+ x = F.pad(x, pad, mode="constant", value=0)
+ batch_size, channels, frames, height, width = x.shape
+ # (batch_size, channels, frames, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size * frames, channels, height, width)
+ x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channels, height, width)
+ x = self.conv(x)
+ # (batch_size * frames, channels, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size, channels, frames, height, width)
+ x = x.reshape(batch_size, frames, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4)
+ return x
diff --git a/videosys/models/open_sora/modules.py b/videosys/models/modules/embeddings.py
similarity index 52%
rename from videosys/models/open_sora/modules.py
rename to videosys/models/modules/embeddings.py
index 6c127a6cdd5bca058b11b840248b8195f9e47713..13dd629e2cc0a89b927a094cab77891df9bcea8e 100644
--- a/videosys/models/open_sora/modules.py
+++ b/videosys/models/modules/embeddings.py
@@ -1,16 +1,8 @@
-# Adapted from OpenSora
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-# --------------------------------------------------------
-# References:
-# OpenSora: https://github.com/hpcaitech/Open-Sora
-# --------------------------------------------------------
-
import functools
import math
-from typing import Optional
+from typing import Optional, Tuple, Union
+import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -18,40 +10,48 @@ import torch.utils.checkpoint
from einops import rearrange
from timm.models.vision_transformer import Mlp
-approx_gelu = lambda: nn.GELU(approximate="tanh")
-
-class LlamaRMSNorm(nn.Module):
- def __init__(self, hidden_size, eps=1e-6):
- """
- LlamaRMSNorm is equivalent to T5LayerNorm
- """
+class CogVideoXPatchEmbed(nn.Module):
+ def __init__(
+ self,
+ patch_size: int = 2,
+ in_channels: int = 16,
+ embed_dim: int = 1920,
+ text_embed_dim: int = 4096,
+ bias: bool = True,
+ ) -> None:
super().__init__()
- self.weight = nn.Parameter(torch.ones(hidden_size))
- self.variance_epsilon = eps
-
- def forward(self, hidden_states):
- input_dtype = hidden_states.dtype
- hidden_states = hidden_states.to(torch.float32)
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
- return self.weight * hidden_states.to(input_dtype)
-
-
-def get_layernorm(hidden_size: torch.Tensor, eps: float, affine: bool):
- return nn.LayerNorm(hidden_size, eps, elementwise_affine=affine)
-
+ self.patch_size = patch_size
-def t2i_modulate(x, shift, scale):
- return x * (1 + scale) + shift
+ self.proj = nn.Conv2d(
+ in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
+ )
+ self.text_proj = nn.Linear(text_embed_dim, embed_dim)
+
+ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
+ r"""
+ Args:
+ text_embeds (`torch.Tensor`):
+ Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
+ image_embeds (`torch.Tensor`):
+ Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
+ """
+ text_embeds = self.text_proj(text_embeds)
+ batch, num_frames, channels, height, width = image_embeds.shape
+ image_embeds = image_embeds.reshape(-1, channels, height, width)
+ image_embeds = self.proj(image_embeds)
+ image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
+ image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
+ image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
-# ===============================================
-# General-purpose Layers
-# ===============================================
+ embeds = torch.cat(
+ [text_embeds, image_embeds], dim=1
+ ).contiguous() # [batch, seq_length + num_frames x height x width, channels]
+ return embeds
-class PatchEmbed3D(nn.Module):
+class OpenSoraPatchEmbed3D(nn.Module):
"""Video to Patch Embedding.
Args:
@@ -104,176 +104,6 @@ class PatchEmbed3D(nn.Module):
return x
-class Attention(nn.Module):
- def __init__(
- self,
- dim: int,
- num_heads: int = 8,
- qkv_bias: bool = False,
- qk_norm: bool = False,
- attn_drop: float = 0.0,
- proj_drop: float = 0.0,
- norm_layer: nn.Module = LlamaRMSNorm,
- enable_flash_attn: bool = False,
- rope=None,
- qk_norm_legacy: bool = False,
- ) -> None:
- super().__init__()
- assert dim % num_heads == 0, "dim should be divisible by num_heads"
- self.dim = dim
- self.num_heads = num_heads
- self.head_dim = dim // num_heads
- self.scale = self.head_dim**-0.5
- self.enable_flash_attn = enable_flash_attn
-
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
- self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
- self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
- self.qk_norm_legacy = qk_norm_legacy
- self.attn_drop = nn.Dropout(attn_drop)
- self.proj = nn.Linear(dim, dim)
- self.proj_drop = nn.Dropout(proj_drop)
-
- self.rope = False
- if rope is not None:
- self.rope = True
- self.rotary_emb = rope
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- B, N, C = x.shape
- # flash attn is not memory efficient for small sequences, this is empirical
- enable_flash_attn = self.enable_flash_attn and (N > B)
- qkv = self.qkv(x)
- qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
-
- qkv = qkv.view(qkv_shape).permute(2, 0, 3, 1, 4)
- q, k, v = qkv.unbind(0)
- if self.qk_norm_legacy:
- # WARNING: this may be a bug
- if self.rope:
- q = self.rotary_emb(q)
- k = self.rotary_emb(k)
- q, k = self.q_norm(q), self.k_norm(k)
- else:
- q, k = self.q_norm(q), self.k_norm(k)
- if self.rope:
- q = self.rotary_emb(q)
- k = self.rotary_emb(k)
-
- if enable_flash_attn:
- from flash_attn import flash_attn_func
-
- # (B, #heads, N, #dim) -> (B, N, #heads, #dim)
- q = q.permute(0, 2, 1, 3)
- k = k.permute(0, 2, 1, 3)
- v = v.permute(0, 2, 1, 3)
- x = flash_attn_func(
- q,
- k,
- v,
- dropout_p=self.attn_drop.p if self.training else 0.0,
- softmax_scale=self.scale,
- )
- else:
- dtype = q.dtype
- q = q * self.scale
- attn = q @ k.transpose(-2, -1) # translate attn to float32
- attn = attn.to(torch.float32)
- attn = attn.softmax(dim=-1)
- attn = attn.to(dtype) # cast back attn to original dtype
- attn = self.attn_drop(attn)
- x = attn @ v
-
- x_output_shape = (B, N, C)
- if not enable_flash_attn:
- x = x.transpose(1, 2)
- x = x.reshape(x_output_shape)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
-
-class MultiHeadCrossAttention(nn.Module):
- def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0):
- super(MultiHeadCrossAttention, self).__init__()
- assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
-
- self.d_model = d_model
- self.num_heads = num_heads
- self.head_dim = d_model // num_heads
-
- self.q_linear = nn.Linear(d_model, d_model)
- self.kv_linear = nn.Linear(d_model, d_model * 2)
- self.attn_drop = nn.Dropout(attn_drop)
- self.proj = nn.Linear(d_model, d_model)
- self.proj_drop = nn.Dropout(proj_drop)
-
- def forward(self, x, cond, mask=None):
- # query/value: img tokens; key: condition; mask: if padding tokens
- B, N, C = x.shape
-
- q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
- kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
- k, v = kv.unbind(2)
-
- attn_bias = None
- # TODO: support torch computation
- import xformers.ops
-
- if mask is not None:
- attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
- x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
-
- x = x.view(B, -1, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
-
-class T2IFinalLayer(nn.Module):
- """
- The final layer of PixArt.
- """
-
- def __init__(self, hidden_size, num_patch, out_channels, d_t=None, d_s=None):
- super().__init__()
- self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
- self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True)
- self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size**0.5)
- self.out_channels = out_channels
- self.d_t = d_t
- self.d_s = d_s
-
- def t_mask_select(self, x_mask, x, masked_x, T, S):
- # x: [B, (T, S), C]
- # mased_x: [B, (T, S), C]
- # x_mask: [B, T]
- x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
- masked_x = rearrange(masked_x, "B (T S) C -> B T S C", T=T, S=S)
- x = torch.where(x_mask[:, :, None, None], x, masked_x)
- x = rearrange(x, "B T S C -> B (T S) C")
- return x
-
- def forward(self, x, t, x_mask=None, t0=None, T=None, S=None):
- if T is None:
- T = self.d_t
- if S is None:
- S = self.d_s
- shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1)
- x = t2i_modulate(self.norm_final(x), shift, scale)
- if x_mask is not None:
- shift_zero, scale_zero = (self.scale_shift_table[None] + t0[:, None]).chunk(2, dim=1)
- x_zero = t2i_modulate(self.norm_final(x), shift_zero, scale_zero)
- x = self.t_mask_select(x_mask, x, x_zero, T, S)
- x = self.linear(x)
- return x
-
-
-# ===============================================
-# Embedding Layers for Timesteps and Class Labels
-# ===============================================
-
-
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
@@ -350,7 +180,7 @@ class SizeEmbedder(TimestepEmbedder):
return next(self.parameters()).dtype
-class CaptionEmbedder(nn.Module):
+class OpenSoraCaptionEmbedder(nn.Module):
"""
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
"""
@@ -398,7 +228,7 @@ class CaptionEmbedder(nn.Module):
return caption
-class PositionEmbedding2D(nn.Module):
+class OpenSoraPositionEmbedding2D(nn.Module):
def __init__(self, dim: int) -> None:
super().__init__()
self.dim = dim
@@ -448,3 +278,135 @@ class PositionEmbedding2D(nn.Module):
base_size: Optional[int] = None,
) -> torch.Tensor:
return self._get_cached_emb(x.device, x.dtype, h, w, scale, base_size)
+
+
+def get_3d_rotary_pos_embed(
+ embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ """
+ RoPE for video tokens with 3D structure.
+
+ Args:
+ embed_dim: (`int`):
+ The embedding dimension size, corresponding to hidden_size_head.
+ crops_coords (`Tuple[int]`):
+ The top-left and bottom-right coordinates of the crop.
+ grid_size (`Tuple[int]`):
+ The grid size of the spatial positional embedding (height, width).
+ temporal_size (`int`):
+ The size of the temporal dimension.
+ theta (`float`):
+ Scaling factor for frequency computation.
+ use_real (`bool`):
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+
+ Returns:
+ `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
+ """
+ start, stop = crops_coords
+ grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
+ grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
+ grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
+
+ # Compute dimensions for each axis
+ dim_t = embed_dim // 4
+ dim_h = embed_dim // 8 * 3
+ dim_w = embed_dim // 8 * 3
+
+ # Temporal frequencies
+ freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t))
+ grid_t = torch.from_numpy(grid_t).float()
+ freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t)
+ freqs_t = freqs_t.repeat_interleave(2, dim=-1)
+
+ # Spatial frequencies for height and width
+ freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h))
+ freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w))
+ grid_h = torch.from_numpy(grid_h).float()
+ grid_w = torch.from_numpy(grid_w).float()
+ freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h)
+ freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w)
+ freqs_h = freqs_h.repeat_interleave(2, dim=-1)
+ freqs_w = freqs_w.repeat_interleave(2, dim=-1)
+
+ # Broadcast and concatenate tensors along specified dimension
+ def broadcast(tensors, dim=-1):
+ num_tensors = len(tensors)
+ shape_lens = {len(t.shape) for t in tensors}
+ assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
+ shape_len = list(shape_lens)[0]
+ dim = (dim + shape_len) if dim < 0 else dim
+ dims = list(zip(*(list(t.shape) for t in tensors)))
+ expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
+ assert all(
+ [*(len(set(t[1])) <= 2 for t in expandable_dims)]
+ ), "invalid dimensions for broadcastable concatenation"
+ max_dims = [(t[0], max(t[1])) for t in expandable_dims]
+ expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
+ expanded_dims.insert(dim, (dim, dims[dim]))
+ expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
+ tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
+ return torch.cat(tensors, dim=dim)
+
+ freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
+
+ t, h, w, d = freqs.shape
+ freqs = freqs.view(t * h * w, d)
+
+ # Generate sine and cosine components
+ sin = freqs.sin()
+ cos = freqs.cos()
+
+ if use_real:
+ return cos, sin
+ else:
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+ return freqs_cis
+
+
+def apply_rotary_emb(
+ x: torch.Tensor,
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
+ use_real: bool = True,
+ use_real_unbind_dim: int = -1,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
+ to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
+ reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
+ tensors contain rotary embeddings and are returned as real tensors.
+
+ Args:
+ x (`torch.Tensor`):
+ Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
+ freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
+ """
+ if use_real:
+ cos, sin = freqs_cis # [S, D]
+ cos = cos[None, None]
+ sin = sin[None, None]
+ cos, sin = cos.to(x.device), sin.to(x.device)
+
+ if use_real_unbind_dim == -1:
+ # Use for example in Lumina
+ x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
+ elif use_real_unbind_dim == -2:
+ # Use for example in Stable Audio
+ x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
+ x_rotated = torch.cat([-x_imag, x_real], dim=-1)
+ else:
+ raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
+
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
+
+ return out
+ else:
+ x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
+ freqs_cis = freqs_cis.unsqueeze(2)
+ x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
+
+ return x_out.type_as(x)
diff --git a/videosys/models/modules/normalization.py b/videosys/models/modules/normalization.py
new file mode 100644
index 0000000000000000000000000000000000000000..7985e56f450ca85d00b9b964730a304063d02ba8
--- /dev/null
+++ b/videosys/models/modules/normalization.py
@@ -0,0 +1,102 @@
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+
+
+class LlamaRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ LlamaRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+
+class CogVideoXLayerNormZero(nn.Module):
+ def __init__(
+ self,
+ conditioning_dim: int,
+ embedding_dim: int,
+ elementwise_affine: bool = True,
+ eps: float = 1e-5,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias)
+ self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
+
+ def forward(
+ self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1)
+ hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
+ encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :]
+ return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :]
+
+
+class AdaLayerNorm(nn.Module):
+ r"""
+ Norm layer modified to incorporate timestep embeddings.
+
+ Parameters:
+ embedding_dim (`int`): The size of each embedding vector.
+ num_embeddings (`int`, *optional*): The size of the embeddings dictionary.
+ output_dim (`int`, *optional*):
+ norm_elementwise_affine (`bool`, defaults to `False):
+ norm_eps (`bool`, defaults to `False`):
+ chunk_dim (`int`, defaults to `0`):
+ """
+
+ def __init__(
+ self,
+ embedding_dim: int,
+ num_embeddings: Optional[int] = None,
+ output_dim: Optional[int] = None,
+ norm_elementwise_affine: bool = False,
+ norm_eps: float = 1e-5,
+ chunk_dim: int = 0,
+ ):
+ super().__init__()
+
+ self.chunk_dim = chunk_dim
+ output_dim = output_dim or embedding_dim * 2
+
+ if num_embeddings is not None:
+ self.emb = nn.Embedding(num_embeddings, embedding_dim)
+ else:
+ self.emb = None
+
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(embedding_dim, output_dim)
+ self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
+
+ def forward(
+ self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ if self.emb is not None:
+ temb = self.emb(timestep)
+
+ temb = self.linear(self.silu(temb))
+
+ if self.chunk_dim == 1:
+ # This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the
+ # other if-branch. This branch is specific to CogVideoX for now.
+ shift, scale = temb.chunk(2, dim=1)
+ shift = shift[:, None, :]
+ scale = scale[:, None, :]
+ else:
+ scale, shift = temb.chunk(2, dim=0)
+
+ x = self.norm(x) * (1 + scale) + shift
+ return x
diff --git a/videosys/models/modules/upsampling.py b/videosys/models/modules/upsampling.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9a61b780e7e9d2908006bfc1b4b617736f5d71b
--- /dev/null
+++ b/videosys/models/modules/upsampling.py
@@ -0,0 +1,67 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class CogVideoXUpsample3D(nn.Module):
+ r"""
+ A 3D Upsample layer using in CogVideoX by Tsinghua University & ZhipuAI # Todo: Wait for paper relase.
+
+ Args:
+ in_channels (`int`):
+ Number of channels in the input image.
+ out_channels (`int`):
+ Number of channels produced by the convolution.
+ kernel_size (`int`, defaults to `3`):
+ Size of the convolving kernel.
+ stride (`int`, defaults to `1`):
+ Stride of the convolution.
+ padding (`int`, defaults to `1`):
+ Padding added to all four sides of the input.
+ compress_time (`bool`, defaults to `False`):
+ Whether or not to compress the time dimension.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int = 3,
+ stride: int = 1,
+ padding: int = 1,
+ compress_time: bool = False,
+ ) -> None:
+ super().__init__()
+
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
+ self.compress_time = compress_time
+
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+ if self.compress_time:
+ if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1:
+ # split first frame
+ x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:]
+
+ x_first = F.interpolate(x_first, scale_factor=2.0)
+ x_rest = F.interpolate(x_rest, scale_factor=2.0)
+ x_first = x_first[:, :, None, :, :]
+ inputs = torch.cat([x_first, x_rest], dim=2)
+ elif inputs.shape[2] > 1:
+ inputs = F.interpolate(inputs, scale_factor=2.0)
+ else:
+ inputs = inputs.squeeze(2)
+ inputs = F.interpolate(inputs, scale_factor=2.0)
+ inputs = inputs[:, :, None, :, :]
+ else:
+ # only interpolate 2D
+ b, c, t, h, w = inputs.shape
+ inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
+ inputs = F.interpolate(inputs, scale_factor=2.0)
+ inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4)
+
+ b, c, t, h, w = inputs.shape
+ inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
+ inputs = self.conv(inputs)
+ inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4)
+
+ return inputs
diff --git a/videosys/models/open_sora/__init__.py b/videosys/models/open_sora/__init__.py
deleted file mode 100644
index b8d92196e09ed1e3707a96162b26e40c751a6d4a..0000000000000000000000000000000000000000
--- a/videosys/models/open_sora/__init__.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from .pipeline import OpenSoraConfig, OpenSoraPABConfig, OpenSoraPipeline
-
-__all__ = [
- "OpenSoraConfig",
- "OpenSoraPABConfig",
- "OpenSoraPipeline",
-]
diff --git a/videosys/models/open_sora/embed.py b/videosys/models/open_sora/embed.py
deleted file mode 100644
index a9c238bc714f0a8ffb62edb36743e488060d120a..0000000000000000000000000000000000000000
--- a/videosys/models/open_sora/embed.py
+++ /dev/null
@@ -1,585 +0,0 @@
-# Adapted from OpenSora and DiT
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-# --------------------------------------------------------
-# References:
-# DiT: https://github.com/facebookresearch/DiT
-# OpenSora: https://github.com/hpcaitech/Open-Sora
-# --------------------------------------------------------
-
-import html
-import math
-import re
-
-import ftfy
-import numpy
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import transformers
-from timm.models.vision_transformer import Mlp
-from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
-
-from videosys.modules.embed import get_1d_sincos_pos_embed_from_grid, get_2d_sincos_pos_embed_from_grid
-
-transformers.logging.set_verbosity_error()
-
-
-# ===============================================
-# Text Embed
-# ===============================================
-
-
-class AbstractEncoder(nn.Module):
- def __init__(self):
- super().__init__()
-
- def encode(self, *args, **kwargs):
- raise NotImplementedError
-
-
-class FrozenCLIPEmbedder(AbstractEncoder):
- """Uses the CLIP transformer encoder for text (from Hugging Face)"""
-
- def __init__(self, path="openai/clip-vit-huge-patch14", device="cuda", max_length=77):
- super().__init__()
- self.tokenizer = CLIPTokenizer.from_pretrained(path)
- self.transformer = CLIPTextModel.from_pretrained(path)
- self.device = device
- self.max_length = max_length
- self._freeze()
-
- def _freeze(self):
- self.transformer = self.transformer.eval()
- for param in self.parameters():
- param.requires_grad = False
-
- def forward(self, text):
- batch_encoding = self.tokenizer(
- text,
- truncation=True,
- max_length=self.max_length,
- return_length=True,
- return_overflowing_tokens=False,
- padding="max_length",
- return_tensors="pt",
- )
- tokens = batch_encoding["input_ids"].to(self.device)
- outputs = self.transformer(input_ids=tokens)
-
- z = outputs.last_hidden_state
- pooled_z = outputs.pooler_output
- return z, pooled_z
-
- def encode(self, text):
- return self(text)
-
-
-class TextEmbedder(nn.Module):
- """
- Embeds text prompt into vector representations. Also handles text dropout for classifier-free guidance.
- """
-
- def __init__(self, path, hidden_size, dropout_prob=0.1):
- super().__init__()
- self.text_encoder = FrozenCLIPEmbedder(path=path)
- self.dropout_prob = dropout_prob
-
- output_dim = self.text_encoder.transformer.config.hidden_size
- self.output_projection = nn.Linear(output_dim, hidden_size)
-
- def token_drop(self, text_prompts, force_drop_ids=None):
- """
- Drops text to enable classifier-free guidance.
- """
- if force_drop_ids is None:
- drop_ids = numpy.random.uniform(0, 1, len(text_prompts)) < self.dropout_prob
- else:
- # TODO
- drop_ids = force_drop_ids == 1
- labels = list(numpy.where(drop_ids, "", text_prompts))
- # print(labels)
- return labels
-
- def forward(self, text_prompts, train, force_drop_ids=None):
- use_dropout = self.dropout_prob > 0
- if (train and use_dropout) or (force_drop_ids is not None):
- text_prompts = self.token_drop(text_prompts, force_drop_ids)
- embeddings, pooled_embeddings = self.text_encoder(text_prompts)
- # return embeddings, pooled_embeddings
- text_embeddings = self.output_projection(pooled_embeddings)
- return text_embeddings
-
-
-class CaptionEmbedder(nn.Module):
- """
- copied from https://github.com/hpcaitech/Open-Sora
-
- Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
- """
-
- def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate="tanh"), token_num=120):
- super().__init__()
-
- self.y_proj = Mlp(
- in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer, drop=0
- )
- self.register_buffer("y_embedding", nn.Parameter(torch.randn(token_num, in_channels) / in_channels**0.5))
- self.uncond_prob = uncond_prob
-
- def token_drop(self, caption, force_drop_ids=None):
- """
- Drops labels to enable classifier-free guidance.
- """
- if force_drop_ids is None:
- drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob
- else:
- drop_ids = force_drop_ids == 1
- caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
- return caption
-
- def forward(self, caption, train, force_drop_ids=None):
- if train:
- assert caption.shape[2:] == self.y_embedding.shape
- use_dropout = self.uncond_prob > 0
- if (train and use_dropout) or (force_drop_ids is not None):
- caption = self.token_drop(caption, force_drop_ids)
- caption = self.y_proj(caption)
- return caption
-
-
-class T5Embedder:
- available_models = ["DeepFloyd/t5-v1_1-xxl"]
-
- def __init__(
- self,
- device,
- from_pretrained=None,
- *,
- cache_dir=None,
- hf_token=None,
- use_text_preprocessing=True,
- t5_model_kwargs=None,
- torch_dtype=None,
- use_offload_folder=None,
- model_max_length=120,
- local_files_only=False,
- ):
- self.device = torch.device(device)
- self.torch_dtype = torch_dtype or torch.bfloat16
- self.cache_dir = cache_dir
-
- if t5_model_kwargs is None:
- t5_model_kwargs = {
- "low_cpu_mem_usage": True,
- "torch_dtype": self.torch_dtype,
- }
-
- if use_offload_folder is not None:
- t5_model_kwargs["offload_folder"] = use_offload_folder
- t5_model_kwargs["device_map"] = {
- "shared": self.device,
- "encoder.embed_tokens": self.device,
- "encoder.block.0": self.device,
- "encoder.block.1": self.device,
- "encoder.block.2": self.device,
- "encoder.block.3": self.device,
- "encoder.block.4": self.device,
- "encoder.block.5": self.device,
- "encoder.block.6": self.device,
- "encoder.block.7": self.device,
- "encoder.block.8": self.device,
- "encoder.block.9": self.device,
- "encoder.block.10": self.device,
- "encoder.block.11": self.device,
- "encoder.block.12": "disk",
- "encoder.block.13": "disk",
- "encoder.block.14": "disk",
- "encoder.block.15": "disk",
- "encoder.block.16": "disk",
- "encoder.block.17": "disk",
- "encoder.block.18": "disk",
- "encoder.block.19": "disk",
- "encoder.block.20": "disk",
- "encoder.block.21": "disk",
- "encoder.block.22": "disk",
- "encoder.block.23": "disk",
- "encoder.final_layer_norm": "disk",
- "encoder.dropout": "disk",
- }
- else:
- t5_model_kwargs["device_map"] = {
- "shared": self.device,
- "encoder": self.device,
- }
-
- self.use_text_preprocessing = use_text_preprocessing
- self.hf_token = hf_token
-
- assert from_pretrained in self.available_models
- self.tokenizer = AutoTokenizer.from_pretrained(
- from_pretrained,
- cache_dir=cache_dir,
- local_files_only=local_files_only,
- )
- self.model = T5EncoderModel.from_pretrained(
- from_pretrained,
- cache_dir=cache_dir,
- local_files_only=local_files_only,
- **t5_model_kwargs,
- ).eval()
- self.model_max_length = model_max_length
-
- def get_text_embeddings(self, texts):
- text_tokens_and_mask = self.tokenizer(
- texts,
- max_length=self.model_max_length,
- padding="max_length",
- truncation=True,
- return_attention_mask=True,
- add_special_tokens=True,
- return_tensors="pt",
- )
-
- input_ids = text_tokens_and_mask["input_ids"].to(self.device)
- attention_mask = text_tokens_and_mask["attention_mask"].to(self.device)
- with torch.no_grad():
- text_encoder_embs = self.model(
- input_ids=input_ids,
- attention_mask=attention_mask,
- )["last_hidden_state"].detach()
- return text_encoder_embs, attention_mask
-
-
-class T5Encoder:
- def __init__(
- self,
- from_pretrained="DeepFloyd/t5-v1_1-xxl",
- model_max_length=120,
- device="cuda",
- dtype=torch.float,
- shardformer=False,
- ):
- assert from_pretrained is not None, "Please specify the path to the T5 model"
-
- self.t5 = T5Embedder(
- device=device,
- torch_dtype=dtype,
- from_pretrained=from_pretrained,
- model_max_length=model_max_length,
- )
- self.t5.model.to(dtype=dtype)
- self.y_embedder = None
-
- self.model_max_length = model_max_length
- self.output_dim = self.t5.model.config.d_model
-
- if shardformer:
- self.shardformer_t5()
-
- def shardformer_t5(self):
- from colossalai.shardformer import ShardConfig, ShardFormer
-
- from videosys.core.shardformer.t5.policy import T5EncoderPolicy
- from videosys.utils.utils import requires_grad
-
- shard_config = ShardConfig(
- tensor_parallel_process_group=None,
- pipeline_stage_manager=None,
- enable_tensor_parallelism=False,
- enable_fused_normalization=False,
- enable_flash_attention=False,
- enable_jit_fused=True,
- enable_sequence_parallelism=False,
- enable_sequence_overlap=False,
- )
- shard_former = ShardFormer(shard_config=shard_config)
- optim_model, _ = shard_former.optimize(self.t5.model, policy=T5EncoderPolicy())
- self.t5.model = optim_model.half()
-
- # ensure the weights are frozen
- requires_grad(self.t5.model, False)
-
- def encode(self, text):
- caption_embs, emb_masks = self.t5.get_text_embeddings(text)
- caption_embs = caption_embs[:, None]
- return dict(y=caption_embs, mask=emb_masks)
-
- def null(self, n):
- null_y = self.y_embedder.y_embedding[None].repeat(n, 1, 1)[:, None]
- return null_y
-
-
-def basic_clean(text):
- text = ftfy.fix_text(text)
- text = html.unescape(html.unescape(text))
- return text.strip()
-
-
-BAD_PUNCT_REGEX = re.compile(
- r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
-) # noqa
-
-
-def clean_caption(caption):
- import urllib.parse as ul
-
- from bs4 import BeautifulSoup
-
- caption = str(caption)
- caption = ul.unquote_plus(caption)
- caption = caption.strip().lower()
- caption = re.sub("", "person", caption)
- # urls:
- caption = re.sub(
- r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
- "",
- caption,
- ) # regex for urls
- caption = re.sub(
- r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
- "",
- caption,
- ) # regex for urls
- # html:
- caption = BeautifulSoup(caption, features="html.parser").text
-
- # @
- caption = re.sub(r"@[\w\d]+\b", "", caption)
-
- # 31C0—31EF CJK Strokes
- # 31F0—31FF Katakana Phonetic Extensions
- # 3200—32FF Enclosed CJK Letters and Months
- # 3300—33FF CJK Compatibility
- # 3400—4DBF CJK Unified Ideographs Extension A
- # 4DC0—4DFF Yijing Hexagram Symbols
- # 4E00—9FFF CJK Unified Ideographs
- caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
- caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
- caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
- caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
- caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
- caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
- caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
- #######################################################
-
- # все виды тире / all types of dash --> "-"
- caption = re.sub(
- r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
- "-",
- caption,
- )
-
- # кавычки к одному стандарту
- caption = re.sub(r"[`´«»“”¨]", '"', caption)
- caption = re.sub(r"[‘’]", "'", caption)
-
- # "
- caption = re.sub(r""?", "", caption)
- # &
- caption = re.sub(r"&", "", caption)
-
- # ip adresses:
- caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
-
- # article ids:
- caption = re.sub(r"\d:\d\d\s+$", "", caption)
-
- # \n
- caption = re.sub(r"\\n", " ", caption)
-
- # "#123"
- caption = re.sub(r"#\d{1,3}\b", "", caption)
- # "#12345.."
- caption = re.sub(r"#\d{5,}\b", "", caption)
- # "123456.."
- caption = re.sub(r"\b\d{6,}\b", "", caption)
- # filenames:
- caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
-
- #
- caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
- caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
-
- caption = re.sub(BAD_PUNCT_REGEX, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
- caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
-
- # this-is-my-cute-cat / this_is_my_cute_cat
- regex2 = re.compile(r"(?:\-|\_)")
- if len(re.findall(regex2, caption)) > 3:
- caption = re.sub(regex2, " ", caption)
-
- caption = basic_clean(caption)
-
- caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
- caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
- caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
-
- caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
- caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
- caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
- caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
- caption = re.sub(r"\bpage\s+\d+\b", "", caption)
-
- caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
-
- caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
-
- caption = re.sub(r"\b\s+\:\s+", r": ", caption)
- caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
- caption = re.sub(r"\s+", " ", caption)
-
- caption.strip()
-
- caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
- caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
- caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
- caption = re.sub(r"^\.\S+$", "", caption)
-
- return caption.strip()
-
-
-def text_preprocessing(text, use_text_preprocessing: bool = True):
- if use_text_preprocessing:
- # The exact text cleaning as was in the training stage:
- text = clean_caption(text)
- text = clean_caption(text)
- return text
- else:
- return text.lower().strip()
-
-
-class TimestepEmbedder(nn.Module):
- """
- Embeds scalar timesteps into vector representations.
- """
-
- def __init__(self, hidden_size, frequency_embedding_size=256):
- super().__init__()
- self.mlp = nn.Sequential(
- nn.Linear(frequency_embedding_size, hidden_size, bias=True),
- nn.SiLU(),
- nn.Linear(hidden_size, hidden_size, bias=True),
- )
- self.frequency_embedding_size = frequency_embedding_size
-
- @staticmethod
- def timestep_embedding(t, dim, max_period=10000):
- """
- Create sinusoidal timestep embeddings.
- :param t: a 1-D Tensor of N indices, one per batch element.
- These may be fractional.
- :param dim: the dimension of the output.
- :param max_period: controls the minimum frequency of the embeddings.
- :return: an (N, D) Tensor of positional embeddings.
- """
- # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
- half = dim // 2
- freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half)
- freqs = freqs.to(device=t.device)
- args = t[:, None].float() * freqs[None]
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
- if dim % 2:
- embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
- return embedding
-
- def forward(self, t, dtype):
- t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
- if t_freq.dtype != dtype:
- t_freq = t_freq.to(dtype)
- t_emb = self.mlp(t_freq)
- return t_emb
-
-
-# ===============================================
-# Sine/Cosine Positional Embedding Functions
-# ===============================================
-
-
-def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, scale=1.0, base_size=None):
- """
- grid_size: int of the grid height and width
- return:
- pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
- """
- if not isinstance(grid_size, tuple):
- grid_size = (grid_size, grid_size)
-
- grid_h = np.arange(grid_size[0], dtype=np.float32) / scale
- grid_w = np.arange(grid_size[1], dtype=np.float32) / scale
- if base_size is not None:
- grid_h *= base_size / grid_size[0]
- grid_w *= base_size / grid_size[1]
- grid = np.meshgrid(grid_w, grid_h) # here w goes first
- grid = np.stack(grid, axis=0)
-
- grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
- pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
- if cls_token and extra_tokens > 0:
- pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
- return pos_embed
-
-
-def get_1d_sincos_pos_embed(embed_dim, length, scale=1.0):
- pos = np.arange(0, length)[..., None] / scale
- return get_1d_sincos_pos_embed_from_grid(embed_dim, pos)
-
-
-# ===============================================
-# Patch Embed
-# ===============================================
-
-
-class PatchEmbed3D(nn.Module):
- """Video to Patch Embedding.
-
- Args:
- patch_size (int): Patch token size. Default: (2,4,4).
- in_chans (int): Number of input video channels. Default: 3.
- embed_dim (int): Number of linear projection output channels. Default: 96.
- norm_layer (nn.Module, optional): Normalization layer. Default: None
- """
-
- def __init__(
- self,
- patch_size=(2, 4, 4),
- in_chans=3,
- embed_dim=96,
- norm_layer=None,
- flatten=True,
- ):
- super().__init__()
- self.patch_size = patch_size
- self.flatten = flatten
-
- self.in_chans = in_chans
- self.embed_dim = embed_dim
-
- self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
- if norm_layer is not None:
- self.norm = norm_layer(embed_dim)
- else:
- self.norm = None
-
- def forward(self, x):
- """Forward function."""
- # padding
- _, _, D, H, W = x.size()
- if W % self.patch_size[2] != 0:
- x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))
- if H % self.patch_size[1] != 0:
- x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))
- if D % self.patch_size[0] != 0:
- x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))
-
- x = self.proj(x) # (B C T H W)
- if self.norm is not None:
- D, Wh, Ww = x.size(2), x.size(3), x.size(4)
- x = x.flatten(2).transpose(1, 2)
- x = self.norm(x)
- x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)
- if self.flatten:
- x = x.flatten(2).transpose(1, 2) # BCTHW -> BNC
- return x
diff --git a/videosys/models/open_sora/inference_utils.py b/videosys/models/open_sora/inference_utils.py
deleted file mode 100644
index de95fcf717902b4f2c432bc3302829cf719ec980..0000000000000000000000000000000000000000
--- a/videosys/models/open_sora/inference_utils.py
+++ /dev/null
@@ -1,348 +0,0 @@
-# Adapted from OpenSora
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-# --------------------------------------------------------
-# References:
-# OpenSora: https://github.com/hpcaitech/Open-Sora
-# --------------------------------------------------------
-
-import json
-import os
-import re
-
-import torch
-
-from .datasets import IMG_FPS, read_from_path
-
-
-def prepare_multi_resolution_info(info_type, batch_size, image_size, num_frames, fps, device, dtype):
- if info_type is None:
- return dict()
- elif info_type == "PixArtMS":
- hw = torch.tensor([image_size], device=device, dtype=dtype).repeat(batch_size, 1)
- ar = torch.tensor([[image_size[0] / image_size[1]]], device=device, dtype=dtype).repeat(batch_size, 1)
- return dict(ar=ar, hw=hw)
- elif info_type in ["STDiT2", "OpenSora"]:
- fps = fps if num_frames > 1 else IMG_FPS
- fps = torch.tensor([fps], device=device, dtype=dtype).repeat(batch_size)
- height = torch.tensor([image_size[0]], device=device, dtype=dtype).repeat(batch_size)
- width = torch.tensor([image_size[1]], device=device, dtype=dtype).repeat(batch_size)
- num_frames = torch.tensor([num_frames], device=device, dtype=dtype).repeat(batch_size)
- ar = torch.tensor([image_size[0] / image_size[1]], device=device, dtype=dtype).repeat(batch_size)
- return dict(height=height, width=width, num_frames=num_frames, ar=ar, fps=fps)
- else:
- raise NotImplementedError
-
-
-def load_prompts(prompt_path, start_idx=None, end_idx=None):
- with open(prompt_path, "r") as f:
- prompts = [line.strip() for line in f.readlines()]
- prompts = prompts[start_idx:end_idx]
- return prompts
-
-
-def get_save_path_name(
- save_dir,
- sample_name=None, # prefix
- sample_idx=None, # sample index
- prompt=None, # used prompt
- prompt_as_path=False, # use prompt as path
- num_sample=1, # number of samples to generate for one prompt
- k=None, # kth sample
-):
- if sample_name is None:
- sample_name = "" if prompt_as_path else "sample"
- sample_name_suffix = prompt if prompt_as_path else f"_{sample_idx:04d}"
- save_path = os.path.join(save_dir, f"{sample_name}{sample_name_suffix[:50]}")
- if num_sample != 1:
- save_path = f"{save_path}-{k}"
- return save_path
-
-
-def get_eval_save_path_name(
- save_dir,
- id, # add id parameter
- sample_name=None, # prefix
- sample_idx=None, # sample index
- prompt=None, # used prompt
- prompt_as_path=False, # use prompt as path
- num_sample=1, # number of samples to generate for one prompt
- k=None, # kth sample
-):
- if sample_name is None:
- sample_name = "" if prompt_as_path else "sample"
- save_path = os.path.join(save_dir, f"{id}")
- if num_sample != 1:
- save_path = f"{save_path}-{k}"
- return save_path
-
-
-def append_score_to_prompts(prompts, aes=None, flow=None, camera_motion=None):
- new_prompts = []
- for prompt in prompts:
- new_prompt = prompt
- if aes is not None and "aesthetic score:" not in prompt:
- new_prompt = f"{new_prompt} aesthetic score: {aes:.1f}."
- if flow is not None and "motion score:" not in prompt:
- new_prompt = f"{new_prompt} motion score: {flow:.1f}."
- if camera_motion is not None and "camera motion:" not in prompt:
- new_prompt = f"{new_prompt} camera motion: {camera_motion}."
- new_prompts.append(new_prompt)
- return new_prompts
-
-
-def extract_json_from_prompts(prompts, reference, mask_strategy):
- ret_prompts = []
- for i, prompt in enumerate(prompts):
- parts = re.split(r"(?=[{])", prompt)
- assert len(parts) <= 2, f"Invalid prompt: {prompt}"
- ret_prompts.append(parts[0])
- if len(parts) > 1:
- additional_info = json.loads(parts[1])
- for key in additional_info:
- assert key in ["reference_path", "mask_strategy"], f"Invalid key: {key}"
- if key == "reference_path":
- reference[i] = additional_info[key]
- elif key == "mask_strategy":
- mask_strategy[i] = additional_info[key]
- return ret_prompts, reference, mask_strategy
-
-
-def collect_references_batch(reference_paths, vae, image_size):
- refs_x = [] # refs_x: [batch, ref_num, C, T, H, W]
- for reference_path in reference_paths:
- if reference_path == "":
- refs_x.append([])
- continue
- ref_path = reference_path.split(";")
- ref = []
- for r_path in ref_path:
- r = read_from_path(r_path, image_size, transform_name="resize_crop")
- r_x = vae.encode(r.unsqueeze(0).to(vae.device, vae.dtype))
- r_x = r_x.squeeze(0)
- ref.append(r_x)
- refs_x.append(ref)
- return refs_x
-
-
-def extract_prompts_loop(prompts, num_loop):
- ret_prompts = []
- for prompt in prompts:
- if prompt.startswith("|0|"):
- prompt_list = prompt.split("|")[1:]
- text_list = []
- for i in range(0, len(prompt_list), 2):
- start_loop = int(prompt_list[i])
- text = prompt_list[i + 1]
- end_loop = int(prompt_list[i + 2]) if i + 2 < len(prompt_list) else num_loop + 1
- text_list.extend([text] * (end_loop - start_loop))
- prompt = text_list[num_loop]
- ret_prompts.append(prompt)
- return ret_prompts
-
-
-def split_prompt(prompt_text):
- if prompt_text.startswith("|0|"):
- # this is for prompts which look like
- # |0| a beautiful day |1| a sunny day |2| a rainy day
- # we want to parse it into a list of prompts with the loop index
- prompt_list = prompt_text.split("|")[1:]
- text_list = []
- loop_idx = []
- for i in range(0, len(prompt_list), 2):
- start_loop = int(prompt_list[i])
- text = prompt_list[i + 1].strip()
- text_list.append(text)
- loop_idx.append(start_loop)
- return text_list, loop_idx
- else:
- return [prompt_text], None
-
-
-def merge_prompt(text_list, loop_idx_list=None):
- if loop_idx_list is None:
- return text_list[0]
- else:
- prompt = ""
- for i, text in enumerate(text_list):
- prompt += f"|{loop_idx_list[i]}|{text}"
- return prompt
-
-
-MASK_DEFAULT = ["0", "0", "0", "0", "1", "0"]
-
-
-def parse_mask_strategy(mask_strategy):
- mask_batch = []
- if mask_strategy == "" or mask_strategy is None:
- return mask_batch
-
- mask_strategy = mask_strategy.split(";")
- for mask in mask_strategy:
- mask_group = mask.split(",")
- num_group = len(mask_group)
- assert num_group >= 1 and num_group <= 6, f"Invalid mask strategy: {mask}"
- mask_group.extend(MASK_DEFAULT[num_group:])
- for i in range(5):
- mask_group[i] = int(mask_group[i])
- mask_group[5] = float(mask_group[5])
- mask_batch.append(mask_group)
- return mask_batch
-
-
-def find_nearest_point(value, point, max_value):
- t = value // point
- if value % point > point / 2 and t < max_value // point - 1:
- t += 1
- return t * point
-
-
-def apply_mask_strategy(z, refs_x, mask_strategys, loop_i, align=None):
- masks = []
- no_mask = True
- for i, mask_strategy in enumerate(mask_strategys):
- no_mask = False
- mask = torch.ones(z.shape[2], dtype=torch.float, device=z.device)
- mask_strategy = parse_mask_strategy(mask_strategy)
- for mst in mask_strategy:
- loop_id, m_id, m_ref_start, m_target_start, m_length, edit_ratio = mst
- if loop_id != loop_i:
- continue
- ref = refs_x[i][m_id]
-
- if m_ref_start < 0:
- # ref: [C, T, H, W]
- m_ref_start = ref.shape[1] + m_ref_start
- if m_target_start < 0:
- # z: [B, C, T, H, W]
- m_target_start = z.shape[2] + m_target_start
- if align is not None:
- m_ref_start = find_nearest_point(m_ref_start, align, ref.shape[1])
- m_target_start = find_nearest_point(m_target_start, align, z.shape[2])
- m_length = min(m_length, z.shape[2] - m_target_start, ref.shape[1] - m_ref_start)
- z[i, :, m_target_start : m_target_start + m_length] = ref[:, m_ref_start : m_ref_start + m_length]
- mask[m_target_start : m_target_start + m_length] = edit_ratio
- masks.append(mask)
- if no_mask:
- return None
- masks = torch.stack(masks)
- return masks
-
-
-def append_generated(vae, generated_video, refs_x, mask_strategy, loop_i, condition_frame_length, condition_frame_edit):
- ref_x = vae.encode(generated_video)
- for j, refs in enumerate(refs_x):
- if refs is None:
- refs_x[j] = [ref_x[j]]
- else:
- refs.append(ref_x[j])
- if mask_strategy[j] is None or mask_strategy[j] == "":
- mask_strategy[j] = ""
- else:
- mask_strategy[j] += ";"
- mask_strategy[
- j
- ] += f"{loop_i},{len(refs)-1},-{condition_frame_length},0,{condition_frame_length},{condition_frame_edit}"
- return refs_x, mask_strategy
-
-
-def dframe_to_frame(num):
- assert num % 5 == 0, f"Invalid num: {num}"
- return num // 5 * 17
-
-
-OPENAI_CLIENT = None
-REFINE_PROMPTS = None
-REFINE_PROMPTS_PATH = "assets/texts/t2v_pllava.txt"
-REFINE_PROMPTS_TEMPLATE = """
-You need to refine user's input prompt. The user's input prompt is used for video generation task. You need to refine the user's prompt to make it more suitable for the task. Here are some examples of refined prompts:
-{}
-
-The refined prompt should pay attention to all objects in the video. The description should be useful for AI to re-generate the video. The description should be no more than six sentences. The refined prompt should be in English.
-"""
-RANDOM_PROMPTS = None
-RANDOM_PROMPTS_TEMPLATE = """
-You need to generate one input prompt for video generation task. The prompt should be suitable for the task. Here are some examples of refined prompts:
-{}
-
-The prompt should pay attention to all objects in the video. The description should be useful for AI to re-generate the video. The description should be no more than six sentences. The prompt should be in English.
-"""
-
-
-def get_openai_response(sys_prompt, usr_prompt, model="gpt-4o"):
- global OPENAI_CLIENT
- if OPENAI_CLIENT is None:
- from openai import OpenAI
-
- OPENAI_CLIENT = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
-
- completion = OPENAI_CLIENT.chat.completions.create(
- model=model,
- messages=[
- {
- "role": "system",
- "content": sys_prompt,
- }, # <-- This is the system message that provides context to the model
- {
- "role": "user",
- "content": usr_prompt,
- }, # <-- This is the user message for which the model will generate a response
- ],
- )
-
- return completion.choices[0].message.content
-
-
-def get_random_prompt_by_openai():
- global RANDOM_PROMPTS
- if RANDOM_PROMPTS is None:
- examples = load_prompts(REFINE_PROMPTS_PATH)
- RANDOM_PROMPTS = RANDOM_PROMPTS_TEMPLATE.format("\n".join(examples))
-
- response = get_openai_response(RANDOM_PROMPTS, "Generate one example.")
- return response
-
-
-def refine_prompt_by_openai(prompt):
- global REFINE_PROMPTS
- if REFINE_PROMPTS is None:
- examples = load_prompts(REFINE_PROMPTS_PATH)
- REFINE_PROMPTS = REFINE_PROMPTS_TEMPLATE.format("\n".join(examples))
-
- response = get_openai_response(REFINE_PROMPTS, prompt)
- return response
-
-
-def has_openai_key():
- return "OPENAI_API_KEY" in os.environ
-
-
-def refine_prompts_by_openai(prompts):
- new_prompts = []
- for prompt in prompts:
- try:
- if prompt.strip() == "":
- new_prompt = get_random_prompt_by_openai()
- print(f"[Info] Empty prompt detected, generate random prompt: {new_prompt}")
- else:
- new_prompt = refine_prompt_by_openai(prompt)
- print(f"[Info] Refine prompt: {prompt} -> {new_prompt}")
- new_prompts.append(new_prompt)
- except Exception as e:
- print(f"[Warning] Failed to refine prompt: {prompt} due to {e}")
- new_prompts.append(prompt)
- return new_prompts
-
-
-def add_watermark(
- input_video_path, watermark_image_path="./assets/images/watermark/watermark.png", output_video_path=None
-):
- # execute this command in terminal with subprocess
- # return if the process is successful
- if output_video_path is None:
- output_video_path = input_video_path.replace(".mp4", "_watermark.mp4")
- cmd = f'ffmpeg -y -i {input_video_path} -i {watermark_image_path} -filter_complex "[1][0]scale2ref=oh*mdar:ih*0.1[logo][video];[video][logo]overlay" {output_video_path}'
- exit_code = os.system(cmd)
- is_success = exit_code == 0
- return is_success
diff --git a/videosys/models/open_sora/pipeline.py b/videosys/models/open_sora/pipeline.py
deleted file mode 100644
index 4eba025dcaee1a0f61d03e7946871417d7a32088..0000000000000000000000000000000000000000
--- a/videosys/models/open_sora/pipeline.py
+++ /dev/null
@@ -1,427 +0,0 @@
-import re
-from typing import Optional, Tuple, Union
-
-import torch
-from diffusers.models import AutoencoderKL
-
-from videosys.core.pab_mgr import PABConfig, set_pab_manager
-from videosys.core.pipeline import VideoSysPipeline, VideoSysPipelineOutput
-from videosys.utils.utils import save_video
-
-from .datasets import get_image_size, get_num_frames
-from .inference_utils import (
- append_generated,
- append_score_to_prompts,
- apply_mask_strategy,
- collect_references_batch,
- dframe_to_frame,
- extract_json_from_prompts,
- extract_prompts_loop,
- merge_prompt,
- prepare_multi_resolution_info,
- split_prompt,
-)
-from .rflow import RFLOW
-from .stdit3 import STDiT3_XL_2
-from .text_encoder import T5Encoder, text_preprocessing
-from .vae import OpenSoraVAE_V1_2
-
-
-class OpenSoraPABConfig(PABConfig):
- def __init__(
- self,
- steps: int = 50,
- spatial_broadcast: bool = True,
- spatial_threshold: list = [450, 930],
- spatial_gap: int = 2,
- temporal_broadcast: bool = True,
- temporal_threshold: list = [450, 930],
- temporal_gap: int = 4,
- cross_broadcast: bool = True,
- cross_threshold: list = [450, 930],
- cross_gap: int = 6,
- diffusion_skip: bool = False,
- diffusion_timestep_respacing: list = None,
- diffusion_skip_timestep: list = None,
- mlp_skip: bool = True,
- mlp_spatial_skip_config: dict = {
- 676: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
- 788: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
- 864: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
- },
- mlp_temporal_skip_config: dict = {
- 676: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
- 788: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
- 864: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
- },
- ):
- super().__init__(
- steps=steps,
- spatial_broadcast=spatial_broadcast,
- spatial_threshold=spatial_threshold,
- spatial_gap=spatial_gap,
- temporal_broadcast=temporal_broadcast,
- temporal_threshold=temporal_threshold,
- temporal_gap=temporal_gap,
- cross_broadcast=cross_broadcast,
- cross_threshold=cross_threshold,
- cross_gap=cross_gap,
- diffusion_skip=diffusion_skip,
- diffusion_timestep_respacing=diffusion_timestep_respacing,
- diffusion_skip_timestep=diffusion_skip_timestep,
- mlp_skip=mlp_skip,
- mlp_spatial_skip_config=mlp_spatial_skip_config,
- mlp_temporal_skip_config=mlp_temporal_skip_config,
- )
-
-
-class OpenSoraConfig:
- def __init__(
- self,
- world_size: int = 1,
- transformer: str = "hpcai-tech/OpenSora-STDiT-v3",
- vae: str = "hpcai-tech/OpenSora-VAE-v1.2",
- text_encoder: str = "DeepFloyd/t5-v1_1-xxl",
- # ======= scheduler =======
- num_sampling_steps: int = 30,
- cfg_scale: float = 7.0,
- # ======= vae ========
- tiling_size: int = 4,
- # ======= pab ========
- enable_pab: bool = False,
- pab_config: PABConfig = OpenSoraPABConfig(),
- ):
- # ======= engine ========
- self.world_size = world_size
-
- # ======= pipeline ========
- self.pipeline_cls = OpenSoraPipeline
- self.transformer = transformer
- self.vae = vae
- self.text_encoder = text_encoder
-
- # ======= scheduler ========
- self.num_sampling_steps = num_sampling_steps
- self.cfg_scale = cfg_scale
-
- # ======= vae ========
- self.tiling_size = tiling_size
-
- # ======= pab ========
- self.enable_pab = enable_pab
- self.pab_config = pab_config
-
-
-class OpenSoraPipeline(VideoSysPipeline):
- r"""
- Pipeline for text-to-image generation using PixArt-Alpha.
-
- This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
- library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
-
- Args:
- vae ([`AutoencoderKL`]):
- Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
- text_encoder ([`T5EncoderModel`]):
- Frozen text-encoder. PixArt-Alpha uses
- [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
- [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
- tokenizer (`T5Tokenizer`):
- Tokenizer of class
- [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
- transformer ([`Transformer2DModel`]):
- A text conditioned `Transformer2DModel` to denoise the encoded image latents.
- scheduler ([`SchedulerMixin`]):
- A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
- """
- bad_punct_regex = re.compile(
- r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
- ) # noqa
-
- _optional_components = ["tokenizer", "text_encoder"]
- model_cpu_offload_seq = "text_encoder->transformer->vae"
-
- def __init__(
- self,
- config: OpenSoraConfig,
- text_encoder: Optional[T5Encoder] = None,
- vae: Optional[AutoencoderKL] = None,
- transformer: Optional[STDiT3_XL_2] = None,
- scheduler: Optional[RFLOW] = None,
- device: torch.device = torch.device("cuda"),
- dtype: torch.dtype = torch.bfloat16,
- ):
- super().__init__()
- self._config = config
- self._device = device
- self._dtype = dtype
-
- # initialize the model if not provided
- if text_encoder is None:
- text_encoder = T5Encoder(
- from_pretrained=config.text_encoder, model_max_length=300, device=device, dtype=dtype
- )
- if vae is None:
- vae = OpenSoraVAE_V1_2(
- from_pretrained="hpcai-tech/OpenSora-VAE-v1.2",
- micro_frame_size=17,
- micro_batch_size=config.tiling_size,
- ).to(dtype)
- if transformer is None:
- transformer = STDiT3_XL_2(
- from_pretrained="hpcai-tech/OpenSora-STDiT-v3",
- qk_norm=True,
- enable_flash_attn=True,
- enable_layernorm_kernel=True,
- in_channels=vae.out_channels,
- caption_channels=text_encoder.output_dim,
- model_max_length=text_encoder.model_max_length,
- ).to(device, dtype)
- text_encoder.y_embedder = transformer.y_embedder
- if scheduler is None:
- scheduler = RFLOW(
- use_timestep_transform=True, num_sampling_steps=config.num_sampling_steps, cfg_scale=config.cfg_scale
- )
-
- # pab
- if config.enable_pab:
- set_pab_manager(config.pab_config)
-
- # set eval and device
- self.set_eval_and_device(device, text_encoder, vae, transformer)
-
- self.register_modules(text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler)
-
- @torch.no_grad()
- def generate(
- self,
- prompt: str,
- resolution="480p",
- aspect_ratio="9:16",
- num_frames: int = 51,
- loop: int = 1,
- llm_refine: bool = False,
- negative_prompt: str = "",
- ms: Optional[str] = "",
- refs: Optional[str] = "",
- aes: float = 6.5,
- flow: Optional[float] = None,
- camera_motion: Optional[float] = None,
- condition_frame_length: int = 5,
- align: int = 5,
- condition_frame_edit: float = 0.0,
- return_dict: bool = True,
- verbose: bool = True,
- ) -> Union[VideoSysPipelineOutput, Tuple]:
- """
- Function invoked when calling the pipeline for generation.
-
- Args:
- prompt (`str` or `List[str]`, *optional*):
- The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
- instead.
- resolution (`str`, *optional*, defaults to `"480p"`):
- The resolution of the generated video.
- aspect_ratio (`str`, *optional*, defaults to `"9:16"`):
- The aspect ratio of the generated video.
- num_frames (`int`, *optional*, defaults to 51):
- The number of frames to generate.
- negative_prompt (`str` or `List[str]`, *optional*):
- The prompt or prompts not to guide the image generation. If not defined, one has to pass
- `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
- less than `1`).
- num_inference_steps (`int`, *optional*, defaults to 100):
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
- expense of slower inference.
- timesteps (`List[int]`, *optional*):
- Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
- timesteps are used. Must be in descending order.
- guidance_scale (`float`, *optional*, defaults to 7.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
- num_images_per_prompt (`int`, *optional*, defaults to 1):
- The number of images to generate per prompt.
- height (`int`, *optional*, defaults to self.unet.config.sample_size):
- The height in pixels of the generated image.
- width (`int`, *optional*, defaults to self.unet.config.sample_size):
- The width in pixels of the generated image.
- eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
- [`schedulers.DDIMScheduler`], will be ignored for others.
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
- One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
- to make generation deterministic.
- latents (`torch.FloatTensor`, *optional*):
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
- prompt_embeds (`torch.FloatTensor`, *optional*):
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
- provided, text embeddings will be generated from `prompt` input argument.
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
- Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not
- provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
- output_type (`str`, *optional*, defaults to `"pil"`):
- The output format of the generate image. Choose between
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
- return_dict (`bool`, *optional*, defaults to `True`):
- Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
- callback (`Callable`, *optional*):
- A function that will be called every `callback_steps` steps during inference. The function will be
- called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
- callback_steps (`int`, *optional*, defaults to 1):
- The frequency at which the `callback` function will be called. If not specified, the callback will be
- called at every step.
- clean_caption (`bool`, *optional*, defaults to `True`):
- Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
- be installed. If the dependencies are not installed, the embeddings will be created from the raw
- prompt.
- mask_feature (`bool` defaults to `True`): If set to `True`, the text embeddings will be masked.
-
- Examples:
-
- Returns:
- [`~pipelines.ImagePipelineOutput`] or `tuple`:
- If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
- returned where the first element is a list with the generated images
- """
- # == basic ==
- fps = 24
- image_size = get_image_size(resolution, aspect_ratio)
- num_frames = get_num_frames(num_frames)
-
- # == prepare batch prompts ==
- batch_prompts = [prompt]
- ms = [ms]
- refs = [refs]
-
- # == get json from prompts ==
- batch_prompts, refs, ms = extract_json_from_prompts(batch_prompts, refs, ms)
-
- # == get reference for condition ==
- refs = collect_references_batch(refs, self.vae, image_size)
-
- # == multi-resolution info ==
- model_args = prepare_multi_resolution_info(
- "OpenSora", len(batch_prompts), image_size, num_frames, fps, self._device, self._dtype
- )
-
- # == process prompts step by step ==
- # 0. split prompt
- # each element in the list is [prompt_segment_list, loop_idx_list]
- batched_prompt_segment_list = []
- batched_loop_idx_list = []
- for prompt in batch_prompts:
- prompt_segment_list, loop_idx_list = split_prompt(prompt)
- batched_prompt_segment_list.append(prompt_segment_list)
- batched_loop_idx_list.append(loop_idx_list)
-
- # 1. refine prompt by openai
- # if llm_refine:
- # only call openai API when
- # 1. seq parallel is not enabled
- # 2. seq parallel is enabled and the process is rank 0
- # if not enable_sequence_parallelism or (enable_sequence_parallelism and coordinator.is_master()):
- # for idx, prompt_segment_list in enumerate(batched_prompt_segment_list):
- # batched_prompt_segment_list[idx] = refine_prompts_by_openai(prompt_segment_list)
-
- # # sync the prompt if using seq parallel
- # if enable_sequence_parallelism:
- # coordinator.block_all()
- # prompt_segment_length = [
- # len(prompt_segment_list) for prompt_segment_list in batched_prompt_segment_list
- # ]
-
- # # flatten the prompt segment list
- # batched_prompt_segment_list = [
- # prompt_segment
- # for prompt_segment_list in batched_prompt_segment_list
- # for prompt_segment in prompt_segment_list
- # ]
-
- # # create a list of size equal to world size
- # broadcast_obj_list = [batched_prompt_segment_list] * coordinator.world_size
- # dist.broadcast_object_list(broadcast_obj_list, 0)
-
- # # recover the prompt list
- # batched_prompt_segment_list = []
- # segment_start_idx = 0
- # all_prompts = broadcast_obj_list[0]
- # for num_segment in prompt_segment_length:
- # batched_prompt_segment_list.append(
- # all_prompts[segment_start_idx : segment_start_idx + num_segment]
- # )
- # segment_start_idx += num_segment
-
- # 2. append score
- for idx, prompt_segment_list in enumerate(batched_prompt_segment_list):
- batched_prompt_segment_list[idx] = append_score_to_prompts(
- prompt_segment_list,
- aes=aes,
- flow=flow,
- camera_motion=camera_motion,
- )
-
- # 3. clean prompt with T5
- for idx, prompt_segment_list in enumerate(batched_prompt_segment_list):
- batched_prompt_segment_list[idx] = [text_preprocessing(prompt) for prompt in prompt_segment_list]
-
- # 4. merge to obtain the final prompt
- batch_prompts = []
- for prompt_segment_list, loop_idx_list in zip(batched_prompt_segment_list, batched_loop_idx_list):
- batch_prompts.append(merge_prompt(prompt_segment_list, loop_idx_list))
-
- # == Iter over loop generation ==
- video_clips = []
- for loop_i in range(loop):
- # == get prompt for loop i ==
- batch_prompts_loop = extract_prompts_loop(batch_prompts, loop_i)
-
- # == add condition frames for loop ==
- if loop_i > 0:
- refs, ms = append_generated(
- self.vae, video_clips[-1], refs, ms, loop_i, condition_frame_length, condition_frame_edit
- )
-
- # == sampling ==
- input_size = (num_frames, *image_size)
- latent_size = self.vae.get_latent_size(input_size)
- z = torch.randn(
- len(batch_prompts), self.vae.out_channels, *latent_size, device=self._device, dtype=self._dtype
- )
- masks = apply_mask_strategy(z, refs, ms, loop_i, align=align)
- samples = self.scheduler.sample(
- self.transformer,
- self.text_encoder,
- z=z,
- prompts=batch_prompts_loop,
- device=self._device,
- additional_args=model_args,
- progress=verbose,
- mask=masks,
- )
- samples = self.vae.decode(samples.to(self._dtype), num_frames=num_frames)
- video_clips.append(samples)
-
- for i in range(1, loop):
- video_clips[i] = video_clips[i][:, dframe_to_frame(condition_frame_length) :]
- video = torch.cat(video_clips, dim=1)
-
- low, high = -1, 1
- video.clamp_(min=low, max=high)
- video.sub_(low).div_(max(high - low, 1e-5))
- video = video.mul(255).add_(0.5).clamp_(0, 255).permute(0, 2, 3, 4, 1).to("cpu", torch.uint8)
-
- # Offload all models
- self.maybe_free_model_hooks()
-
- if not return_dict:
- return (video,)
-
- return VideoSysPipelineOutput(video=video)
-
- def save_video(self, video, output_path):
- save_video(video, output_path, fps=24)
diff --git a/videosys/models/open_sora/text_encoder.py b/videosys/models/open_sora/text_encoder.py
deleted file mode 100644
index cc3d5b313d84090f2ad1dc53bf06861fc0818998..0000000000000000000000000000000000000000
--- a/videosys/models/open_sora/text_encoder.py
+++ /dev/null
@@ -1,330 +0,0 @@
-# Adapted from OpenSora
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-# --------------------------------------------------------
-# References:
-# OpenSora: https://github.com/hpcaitech/Open-Sora
-# --------------------------------------------------------
-
-import html
-import os
-import re
-
-import ftfy
-import torch
-from transformers import AutoTokenizer, T5EncoderModel
-
-os.environ["TOKENIZERS_PARALLELISM"] = "true"
-
-
-class T5Embedder:
- available_models = ["DeepFloyd/t5-v1_1-xxl"]
-
- def __init__(
- self,
- device,
- from_pretrained=None,
- *,
- cache_dir=None,
- hf_token=None,
- use_text_preprocessing=True,
- t5_model_kwargs=None,
- torch_dtype=None,
- use_offload_folder=None,
- model_max_length=120,
- local_files_only=False,
- ):
- self.device = torch.device(device)
- self.torch_dtype = torch_dtype or torch.bfloat16
- self.cache_dir = cache_dir
-
- if t5_model_kwargs is None:
- t5_model_kwargs = {
- "low_cpu_mem_usage": True,
- "torch_dtype": self.torch_dtype,
- }
-
- if use_offload_folder is not None:
- t5_model_kwargs["offload_folder"] = use_offload_folder
- t5_model_kwargs["device_map"] = {
- "shared": self.device,
- "encoder.embed_tokens": self.device,
- "encoder.block.0": self.device,
- "encoder.block.1": self.device,
- "encoder.block.2": self.device,
- "encoder.block.3": self.device,
- "encoder.block.4": self.device,
- "encoder.block.5": self.device,
- "encoder.block.6": self.device,
- "encoder.block.7": self.device,
- "encoder.block.8": self.device,
- "encoder.block.9": self.device,
- "encoder.block.10": self.device,
- "encoder.block.11": self.device,
- "encoder.block.12": "disk",
- "encoder.block.13": "disk",
- "encoder.block.14": "disk",
- "encoder.block.15": "disk",
- "encoder.block.16": "disk",
- "encoder.block.17": "disk",
- "encoder.block.18": "disk",
- "encoder.block.19": "disk",
- "encoder.block.20": "disk",
- "encoder.block.21": "disk",
- "encoder.block.22": "disk",
- "encoder.block.23": "disk",
- "encoder.final_layer_norm": "disk",
- "encoder.dropout": "disk",
- }
- else:
- t5_model_kwargs["device_map"] = {
- "shared": self.device,
- "encoder": self.device,
- }
-
- self.use_text_preprocessing = use_text_preprocessing
- self.hf_token = hf_token
-
- assert from_pretrained in self.available_models
- self.tokenizer = AutoTokenizer.from_pretrained(
- from_pretrained,
- cache_dir=cache_dir,
- local_files_only=local_files_only,
- )
- self.model = T5EncoderModel.from_pretrained(
- from_pretrained,
- cache_dir=cache_dir,
- local_files_only=local_files_only,
- **t5_model_kwargs,
- ).eval()
- self.model_max_length = model_max_length
-
- def get_text_embeddings(self, texts):
- text_tokens_and_mask = self.tokenizer(
- texts,
- max_length=self.model_max_length,
- padding="max_length",
- truncation=True,
- return_attention_mask=True,
- add_special_tokens=True,
- return_tensors="pt",
- )
-
- input_ids = text_tokens_and_mask["input_ids"].to(self.device)
- attention_mask = text_tokens_and_mask["attention_mask"].to(self.device)
- with torch.no_grad():
- text_encoder_embs = self.model(
- input_ids=input_ids,
- attention_mask=attention_mask,
- )["last_hidden_state"].detach()
- return text_encoder_embs, attention_mask
-
-
-class T5Encoder:
- def __init__(
- self,
- from_pretrained=None,
- model_max_length=120,
- device="cuda",
- dtype=torch.float,
- cache_dir=None,
- shardformer=False,
- local_files_only=False,
- ):
- assert from_pretrained is not None, "Please specify the path to the T5 model"
-
- self.t5 = T5Embedder(
- device=device,
- torch_dtype=dtype,
- from_pretrained=from_pretrained,
- cache_dir=cache_dir,
- model_max_length=model_max_length,
- local_files_only=local_files_only,
- )
- self.t5.model.to(dtype=dtype)
- self.y_embedder = None
-
- self.model_max_length = model_max_length
- self.output_dim = self.t5.model.config.d_model
- self.dtype = dtype
-
- if shardformer:
- self.shardformer_t5()
-
- def eval(self):
- self.t5.model.eval()
-
- def to(self, device):
- self.t5.model.to(device)
-
- def shardformer_t5(self):
- from colossalai.shardformer import ShardConfig, ShardFormer
-
- from videosys.core.shardformer.t5.policy import T5EncoderPolicy
- from videosys.utils.utils import requires_grad
-
- shard_config = ShardConfig(
- tensor_parallel_process_group=None,
- pipeline_stage_manager=None,
- enable_tensor_parallelism=False,
- enable_fused_normalization=False,
- enable_flash_attention=False,
- enable_jit_fused=True,
- enable_sequence_parallelism=False,
- enable_sequence_overlap=False,
- )
- shard_former = ShardFormer(shard_config=shard_config)
- optim_model, _ = shard_former.optimize(self.t5.model, policy=T5EncoderPolicy())
- self.t5.model = optim_model.to(self.dtype)
-
- # ensure the weights are frozen
- requires_grad(self.t5.model, False)
-
- def encode(self, text):
- caption_embs, emb_masks = self.t5.get_text_embeddings(text)
- caption_embs = caption_embs[:, None]
- return dict(y=caption_embs, mask=emb_masks)
-
- def null(self, n):
- null_y = self.y_embedder.y_embedding[None].repeat(n, 1, 1)[:, None]
- return null_y
-
-
-def basic_clean(text):
- text = ftfy.fix_text(text)
- text = html.unescape(html.unescape(text))
- return text.strip()
-
-
-BAD_PUNCT_REGEX = re.compile(
- r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
-) # noqa
-
-
-def clean_caption(caption):
- import urllib.parse as ul
-
- from bs4 import BeautifulSoup
-
- caption = str(caption)
- caption = ul.unquote_plus(caption)
- caption = caption.strip().lower()
- caption = re.sub("", "person", caption)
- # urls:
- caption = re.sub(
- r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
- "",
- caption,
- ) # regex for urls
- caption = re.sub(
- r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
- "",
- caption,
- ) # regex for urls
- # html:
- caption = BeautifulSoup(caption, features="html.parser").text
-
- # @
- caption = re.sub(r"@[\w\d]+\b", "", caption)
-
- # 31C0—31EF CJK Strokes
- # 31F0—31FF Katakana Phonetic Extensions
- # 3200—32FF Enclosed CJK Letters and Months
- # 3300—33FF CJK Compatibility
- # 3400—4DBF CJK Unified Ideographs Extension A
- # 4DC0—4DFF Yijing Hexagram Symbols
- # 4E00—9FFF CJK Unified Ideographs
- caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
- caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
- caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
- caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
- caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
- caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
- caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
- #######################################################
-
- # все виды тире / all types of dash --> "-"
- caption = re.sub(
- r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
- "-",
- caption,
- )
-
- # кавычки к одному стандарту
- caption = re.sub(r"[`´«»“”¨]", '"', caption)
- caption = re.sub(r"[‘’]", "'", caption)
-
- # "
- caption = re.sub(r""?", "", caption)
- # &
- caption = re.sub(r"&", "", caption)
-
- # ip adresses:
- caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
-
- # article ids:
- caption = re.sub(r"\d:\d\d\s+$", "", caption)
-
- # \n
- caption = re.sub(r"\\n", " ", caption)
-
- # "#123"
- caption = re.sub(r"#\d{1,3}\b", "", caption)
- # "#12345.."
- caption = re.sub(r"#\d{5,}\b", "", caption)
- # "123456.."
- caption = re.sub(r"\b\d{6,}\b", "", caption)
- # filenames:
- caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
-
- #
- caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
- caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
-
- caption = re.sub(BAD_PUNCT_REGEX, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
- caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
-
- # this-is-my-cute-cat / this_is_my_cute_cat
- regex2 = re.compile(r"(?:\-|\_)")
- if len(re.findall(regex2, caption)) > 3:
- caption = re.sub(regex2, " ", caption)
-
- caption = basic_clean(caption)
-
- caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
- caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
- caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
-
- caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
- caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
- caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
- caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
- caption = re.sub(r"\bpage\s+\d+\b", "", caption)
-
- caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
-
- caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
-
- caption = re.sub(r"\b\s+\:\s+", r": ", caption)
- caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
- caption = re.sub(r"\s+", " ", caption)
-
- caption.strip()
-
- caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
- caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
- caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
- caption = re.sub(r"^\.\S+$", "", caption)
-
- return caption.strip()
-
-
-def text_preprocessing(text, use_text_preprocessing: bool = True):
- if use_text_preprocessing:
- # The exact text cleaning as was in the training stage:
- text = clean_caption(text)
- text = clean_caption(text)
- return text
- else:
- return text.lower().strip()
diff --git a/videosys/models/open_sora/utils.py b/videosys/models/open_sora/utils.py
deleted file mode 100644
index 8f32611730933876070a704c923d9823bcaaebb0..0000000000000000000000000000000000000000
--- a/videosys/models/open_sora/utils.py
+++ /dev/null
@@ -1,179 +0,0 @@
-# Adapted from OpenSora
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-# --------------------------------------------------------
-# References:
-# OpenSora: https://github.com/hpcaitech/Open-Sora
-# --------------------------------------------------------
-
-import os
-from collections.abc import Iterable
-
-import torch
-import torch.distributed as dist
-from colossalai.checkpoint_io import GeneralCheckpointIO
-from torch.utils.checkpoint import checkpoint, checkpoint_sequential
-from torchvision.datasets.utils import download_url
-
-from videosys.utils.logging import logger
-
-hf_endpoint = os.environ.get("HF_ENDPOINT")
-if hf_endpoint is None:
- hf_endpoint = "https://huggingface.co"
-
-pretrained_models = {
- "DiT-XL-2-512x512.pt": "https://dl.fbaipublicfiles.com/DiT/models/DiT-XL-2-512x512.pt",
- "DiT-XL-2-256x256.pt": "https://dl.fbaipublicfiles.com/DiT/models/DiT-XL-2-256x256.pt",
- "Latte-XL-2-256x256-ucf101.pt": hf_endpoint + "/maxin-cn/Latte/resolve/main/ucf101.pt",
- "PixArt-XL-2-256x256.pth": hf_endpoint + "/PixArt-alpha/PixArt-alpha/resolve/main/PixArt-XL-2-256x256.pth",
- "PixArt-XL-2-SAM-256x256.pth": hf_endpoint + "/PixArt-alpha/PixArt-alpha/resolve/main/PixArt-XL-2-SAM-256x256.pth",
- "PixArt-XL-2-512x512.pth": hf_endpoint + "/PixArt-alpha/PixArt-alpha/resolve/main/PixArt-XL-2-512x512.pth",
- "PixArt-XL-2-1024-MS.pth": hf_endpoint + "/PixArt-alpha/PixArt-alpha/resolve/main/PixArt-XL-2-1024-MS.pth",
- "OpenSora-v1-16x256x256.pth": hf_endpoint + "/hpcai-tech/Open-Sora/resolve/main/OpenSora-v1-16x256x256.pth",
- "OpenSora-v1-HQ-16x256x256.pth": hf_endpoint + "/hpcai-tech/Open-Sora/resolve/main/OpenSora-v1-HQ-16x256x256.pth",
- "OpenSora-v1-HQ-16x512x512.pth": hf_endpoint + "/hpcai-tech/Open-Sora/resolve/main/OpenSora-v1-HQ-16x512x512.pth",
- "PixArt-Sigma-XL-2-256x256.pth": hf_endpoint
- + "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-256x256.pth",
- "PixArt-Sigma-XL-2-512-MS.pth": hf_endpoint
- + "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-512-MS.pth",
- "PixArt-Sigma-XL-2-1024-MS.pth": hf_endpoint
- + "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-1024-MS.pth",
- "PixArt-Sigma-XL-2-2K-MS.pth": hf_endpoint + "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-2K-MS.pth",
-}
-
-
-def load_from_sharded_state_dict(model, ckpt_path, model_name="model", strict=False):
- ckpt_io = GeneralCheckpointIO()
- ckpt_io.load_model(model, os.path.join(ckpt_path, model_name), strict=strict)
-
-
-def reparameter(ckpt, name=None, model=None):
- model_name = name
- name = os.path.basename(name)
- if not dist.is_initialized() or dist.get_rank() == 0:
- logger.info("loading pretrained model: %s", model_name)
- if name in ["DiT-XL-2-512x512.pt", "DiT-XL-2-256x256.pt"]:
- ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2)
- del ckpt["pos_embed"]
- if name in ["Latte-XL-2-256x256-ucf101.pt"]:
- ckpt = ckpt["ema"]
- ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2)
- del ckpt["pos_embed"]
- del ckpt["temp_embed"]
- if name in [
- "PixArt-XL-2-256x256.pth",
- "PixArt-XL-2-SAM-256x256.pth",
- "PixArt-XL-2-512x512.pth",
- "PixArt-XL-2-1024-MS.pth",
- "PixArt-Sigma-XL-2-256x256.pth",
- "PixArt-Sigma-XL-2-512-MS.pth",
- "PixArt-Sigma-XL-2-1024-MS.pth",
- "PixArt-Sigma-XL-2-2K-MS.pth",
- ]:
- ckpt = ckpt["state_dict"]
- ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2)
- if "pos_embed" in ckpt:
- del ckpt["pos_embed"]
-
- if name in [
- "PixArt-1B-2.pth",
- ]:
- ckpt = ckpt["state_dict"]
- if "pos_embed" in ckpt:
- del ckpt["pos_embed"]
-
- # no need pos_embed
- if "pos_embed_temporal" in ckpt:
- del ckpt["pos_embed_temporal"]
- if "pos_embed" in ckpt:
- del ckpt["pos_embed"]
- # different text length
- if "y_embedder.y_embedding" in ckpt:
- if ckpt["y_embedder.y_embedding"].shape[0] < model.y_embedder.y_embedding.shape[0]:
- logger.info(
- "Extend y_embedding from %s to %s",
- ckpt["y_embedder.y_embedding"].shape[0],
- model.y_embedder.y_embedding.shape[0],
- )
- additional_length = model.y_embedder.y_embedding.shape[0] - ckpt["y_embedder.y_embedding"].shape[0]
- new_y_embedding = torch.zeros(additional_length, model.y_embedder.y_embedding.shape[1])
- new_y_embedding[:] = ckpt["y_embedder.y_embedding"][-1]
- ckpt["y_embedder.y_embedding"] = torch.cat([ckpt["y_embedder.y_embedding"], new_y_embedding], dim=0)
- elif ckpt["y_embedder.y_embedding"].shape[0] > model.y_embedder.y_embedding.shape[0]:
- logger.info(
- "Shrink y_embedding from %s to %s",
- ckpt["y_embedder.y_embedding"].shape[0],
- model.y_embedder.y_embedding.shape[0],
- )
- ckpt["y_embedder.y_embedding"] = ckpt["y_embedder.y_embedding"][: model.y_embedder.y_embedding.shape[0]]
- # stdit3 special case
- if type(model).__name__ == "STDiT3" and "PixArt-Sigma" in name:
- ckpt_keys = list(ckpt.keys())
- for key in ckpt_keys:
- if "blocks." in key:
- ckpt[key.replace("blocks.", "spatial_blocks.")] = ckpt[key]
- del ckpt[key]
-
- return ckpt
-
-
-def find_model(model_name, model=None):
- """
- Finds a pre-trained DiT model, downloading it if necessary. Alternatively, loads a model from a local path.
- """
- if model_name in pretrained_models: # Find/download our pre-trained DiT checkpoints
- model_ckpt = download_model(model_name)
- model_ckpt = reparameter(model_ckpt, model_name, model=model)
- else: # Load a custom DiT checkpoint:
- assert os.path.isfile(model_name), f"Could not find DiT checkpoint at {model_name}"
- model_ckpt = torch.load(model_name, map_location=lambda storage, loc: storage)
- model_ckpt = reparameter(model_ckpt, model_name, model=model)
- return model_ckpt
-
-
-def download_model(model_name=None, local_path=None, url=None):
- """
- Downloads a pre-trained DiT model from the web.
- """
- if model_name is not None:
- assert model_name in pretrained_models
- local_path = f"pretrained_models/{model_name}"
- web_path = pretrained_models[model_name]
- else:
- assert local_path is not None
- assert url is not None
- web_path = url
- if not os.path.isfile(local_path):
- os.makedirs("pretrained_models", exist_ok=True)
- dir_name = os.path.dirname(local_path)
- file_name = os.path.basename(local_path)
- download_url(web_path, dir_name, file_name)
- model = torch.load(local_path, map_location=lambda storage, loc: storage)
- return model
-
-
-def load_checkpoint(model, ckpt_path, save_as_pt=False, model_name="model", strict=False):
- if ckpt_path.endswith(".pt") or ckpt_path.endswith(".pth"):
- state_dict = find_model(ckpt_path, model=model)
- missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=strict)
- logger.info("Missing keys: %s", missing_keys)
- logger.info("Unexpected keys: %s", unexpected_keys)
- elif os.path.isdir(ckpt_path):
- load_from_sharded_state_dict(model, ckpt_path, model_name, strict=strict)
- logger.info("Model checkpoint loaded from %s", ckpt_path)
- if save_as_pt:
- save_path = os.path.join(ckpt_path, model_name + "_ckpt.pt")
- torch.save(model.state_dict(), save_path)
- logger.info("Model checkpoint saved to %s", save_path)
- else:
- raise ValueError(f"Invalid checkpoint path: {ckpt_path}")
-
-
-def auto_grad_checkpoint(module, *args, **kwargs):
- if getattr(module, "grad_checkpointing", False):
- if not isinstance(module, Iterable):
- return checkpoint(module, *args, use_reentrant=False, **kwargs)
- gc_step = module[0].grad_checkpointing_step
- return checkpoint_sequential(module, gc_step, *args, use_reentrant=False, **kwargs)
- return module(*args, **kwargs)
diff --git a/videosys/models/open_sora_plan/__init__.py b/videosys/models/open_sora_plan/__init__.py
deleted file mode 100644
index c19791336c70c4e58bae7a5bb02d0e74f36daf1f..0000000000000000000000000000000000000000
--- a/videosys/models/open_sora_plan/__init__.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from .pipeline import OpenSoraPlanConfig, OpenSoraPlanPABConfig, OpenSoraPlanPipeline
-
-__all__ = [
- "OpenSoraPlanPipeline",
- "OpenSoraPlanConfig",
- "OpenSoraPlanPABConfig",
-]
diff --git a/videosys/models/open_sora_plan/losses.py b/videosys/models/open_sora_plan/losses.py
deleted file mode 100644
index 106f4eb7ab063d41719721f97657ff1ee9d90103..0000000000000000000000000000000000000000
--- a/videosys/models/open_sora_plan/losses.py
+++ /dev/null
@@ -1,677 +0,0 @@
-# Adapted from Open-Sora-Plan
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-# --------------------------------------------------------
-# References:
-# Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
-# --------------------------------------------------------
-
-import functools
-import hashlib
-import os
-from collections import namedtuple
-
-import requests
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from einops import rearrange
-from torch import nn
-from torchvision import models
-from tqdm import tqdm
-
-from videosys.models.open_sora_plan.modules.normalize import ActNorm
-
-URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"}
-
-CKPT_MAP = {"vgg_lpips": "vgg.pth"}
-
-MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"}
-
-
-def download(url, local_path, chunk_size=1024):
- os.makedirs(os.path.split(local_path)[0], exist_ok=True)
- with requests.get(url, stream=True) as r:
- total_size = int(r.headers.get("content-length", 0))
- with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
- with open(local_path, "wb") as f:
- for data in r.iter_content(chunk_size=chunk_size):
- if data:
- f.write(data)
- pbar.update(chunk_size)
-
-
-def md5_hash(path):
- with open(path, "rb") as f:
- content = f.read()
- return hashlib.md5(content).hexdigest()
-
-
-def get_ckpt_path(name, root, check=False):
- assert name in URL_MAP
- path = os.path.join(root, CKPT_MAP[name])
- if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
- print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
- download(URL_MAP[name], path)
- md5 = md5_hash(path)
- assert md5 == MD5_MAP[name], md5
- return path
-
-
-class LPIPS(nn.Module):
- # Learned perceptual metric
- def __init__(self, use_dropout=True):
- super().__init__()
- self.scaling_layer = ScalingLayer()
- self.chns = [64, 128, 256, 512, 512] # vg16 features
- self.net = vgg16(pretrained=True, requires_grad=False)
- self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
- self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
- self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
- self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
- self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
- self.load_from_pretrained()
- for param in self.parameters():
- param.requires_grad = False
-
- def load_from_pretrained(self, name="vgg_lpips"):
- ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips")
- self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
- print("loaded pretrained LPIPS loss from {}".format(ckpt))
-
- @classmethod
- def from_pretrained(cls, name="vgg_lpips"):
- if name != "vgg_lpips":
- raise NotImplementedError
- model = cls()
- ckpt = get_ckpt_path(name)
- model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
- return model
-
- def forward(self, input, target):
- in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
- outs0, outs1 = self.net(in0_input), self.net(in1_input)
- feats0, feats1, diffs = {}, {}, {}
- lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
- for kk in range(len(self.chns)):
- feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
- diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
-
- res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
- val = res[0]
- for l in range(1, len(self.chns)):
- val += res[l]
- return val
-
-
-class ScalingLayer(nn.Module):
- def __init__(self):
- super(ScalingLayer, self).__init__()
- self.register_buffer("shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None])
- self.register_buffer("scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None])
-
- def forward(self, inp):
- return (inp - self.shift) / self.scale
-
-
-class NetLinLayer(nn.Module):
- """A single linear layer which does a 1x1 conv"""
-
- def __init__(self, chn_in, chn_out=1, use_dropout=False):
- super(NetLinLayer, self).__init__()
- layers = (
- [
- nn.Dropout(),
- ]
- if (use_dropout)
- else []
- )
- layers += [
- nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
- ]
- self.model = nn.Sequential(*layers)
-
-
-class vgg16(torch.nn.Module):
- def __init__(self, requires_grad=False, pretrained=True):
- super(vgg16, self).__init__()
- vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
- self.slice1 = torch.nn.Sequential()
- self.slice2 = torch.nn.Sequential()
- self.slice3 = torch.nn.Sequential()
- self.slice4 = torch.nn.Sequential()
- self.slice5 = torch.nn.Sequential()
- self.N_slices = 5
- for x in range(4):
- self.slice1.add_module(str(x), vgg_pretrained_features[x])
- for x in range(4, 9):
- self.slice2.add_module(str(x), vgg_pretrained_features[x])
- for x in range(9, 16):
- self.slice3.add_module(str(x), vgg_pretrained_features[x])
- for x in range(16, 23):
- self.slice4.add_module(str(x), vgg_pretrained_features[x])
- for x in range(23, 30):
- self.slice5.add_module(str(x), vgg_pretrained_features[x])
- if not requires_grad:
- for param in self.parameters():
- param.requires_grad = False
-
- def forward(self, X):
- h = self.slice1(X)
- h_relu1_2 = h
- h = self.slice2(h)
- h_relu2_2 = h
- h = self.slice3(h)
- h_relu3_3 = h
- h = self.slice4(h)
- h_relu4_3 = h
- h = self.slice5(h)
- h_relu5_3 = h
- vgg_outputs = namedtuple("VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"])
- out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
- return out
-
-
-def normalize_tensor(x, eps=1e-10):
- norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
- return x / (norm_factor + eps)
-
-
-def spatial_average(x, keepdim=True):
- return x.mean([2, 3], keepdim=keepdim)
-
-
-def weights_init(m):
- classname = m.__class__.__name__
- if classname.find("Conv") != -1:
- nn.init.normal_(m.weight.data, 0.0, 0.02)
- elif classname.find("BatchNorm") != -1:
- nn.init.normal_(m.weight.data, 1.0, 0.02)
- nn.init.constant_(m.bias.data, 0)
-
-
-def weights_init_conv(m):
- if hasattr(m, "conv"):
- m = m.conv
- classname = m.__class__.__name__
- if classname.find("Conv") != -1:
- nn.init.normal_(m.weight.data, 0.0, 0.02)
- elif classname.find("BatchNorm") != -1:
- nn.init.normal_(m.weight.data, 1.0, 0.02)
- nn.init.constant_(m.bias.data, 0)
-
-
-class NLayerDiscriminator(nn.Module):
- """Defines a PatchGAN discriminator as in Pix2Pix
- --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
- """
-
- def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
- """Construct a PatchGAN discriminator
- Parameters:
- input_nc (int) -- the number of channels in input images
- ndf (int) -- the number of filters in the last conv layer
- n_layers (int) -- the number of conv layers in the discriminator
- norm_layer -- normalization layer
- """
- super(NLayerDiscriminator, self).__init__()
- if not use_actnorm:
- norm_layer = nn.BatchNorm2d
- else:
- norm_layer = ActNorm
- if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
- use_bias = norm_layer.func != nn.BatchNorm2d
- else:
- use_bias = norm_layer != nn.BatchNorm2d
-
- kw = 4
- padw = 1
- sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
- nf_mult = 1
- nf_mult_prev = 1
- for n in range(1, n_layers): # gradually increase the number of filters
- nf_mult_prev = nf_mult
- nf_mult = min(2**n, 8)
- sequence += [
- nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
- norm_layer(ndf * nf_mult),
- nn.LeakyReLU(0.2, True),
- ]
-
- nf_mult_prev = nf_mult
- nf_mult = min(2**n_layers, 8)
- sequence += [
- nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
- norm_layer(ndf * nf_mult),
- nn.LeakyReLU(0.2, True),
- ]
-
- sequence += [
- nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
- ] # output 1 channel prediction map
- self.main = nn.Sequential(*sequence)
-
- def forward(self, input):
- """Standard forward."""
- return self.main(input)
-
-
-class NLayerDiscriminator3D(nn.Module):
- """Defines a 3D PatchGAN discriminator as in Pix2Pix but for 3D inputs."""
-
- def __init__(self, input_nc=1, ndf=64, n_layers=3, use_actnorm=False):
- """
- Construct a 3D PatchGAN discriminator
-
- Parameters:
- input_nc (int) -- the number of channels in input volumes
- ndf (int) -- the number of filters in the last conv layer
- n_layers (int) -- the number of conv layers in the discriminator
- use_actnorm (bool) -- flag to use actnorm instead of batchnorm
- """
- super(NLayerDiscriminator3D, self).__init__()
- if not use_actnorm:
- norm_layer = nn.BatchNorm3d
- else:
- raise NotImplementedError("Not implemented.")
- if type(norm_layer) == functools.partial:
- use_bias = norm_layer.func != nn.BatchNorm3d
- else:
- use_bias = norm_layer != nn.BatchNorm3d
-
- kw = 3
- padw = 1
- sequence = [nn.Conv3d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
- nf_mult = 1
- nf_mult_prev = 1
- for n in range(1, n_layers): # gradually increase the number of filters
- nf_mult_prev = nf_mult
- nf_mult = min(2**n, 8)
- sequence += [
- nn.Conv3d(
- ndf * nf_mult_prev,
- ndf * nf_mult,
- kernel_size=(kw, kw, kw),
- stride=(2 if n == 1 else 1, 2, 2),
- padding=padw,
- bias=use_bias,
- ),
- norm_layer(ndf * nf_mult),
- nn.LeakyReLU(0.2, True),
- ]
-
- nf_mult_prev = nf_mult
- nf_mult = min(2**n_layers, 8)
- sequence += [
- nn.Conv3d(
- ndf * nf_mult_prev, ndf * nf_mult, kernel_size=(kw, kw, kw), stride=1, padding=padw, bias=use_bias
- ),
- norm_layer(ndf * nf_mult),
- nn.LeakyReLU(0.2, True),
- ]
-
- sequence += [
- nn.Conv3d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
- ] # output 1 channel prediction map
- self.main = nn.Sequential(*sequence)
-
- def forward(self, input):
- """Standard forward."""
- return self.main(input)
-
-
-def hinge_d_loss(logits_real, logits_fake):
- loss_real = torch.mean(F.relu(1.0 - logits_real))
- loss_fake = torch.mean(F.relu(1.0 + logits_fake))
- d_loss = 0.5 * (loss_real + loss_fake)
- return d_loss
-
-
-def vanilla_d_loss(logits_real, logits_fake):
- d_loss = 0.5 * (
- torch.mean(torch.nn.functional.softplus(-logits_real)) + torch.mean(torch.nn.functional.softplus(logits_fake))
- )
- return d_loss
-
-
-def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
- assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
- loss_real = torch.mean(F.relu(1.0 - logits_real), dim=[1, 2, 3])
- loss_fake = torch.mean(F.relu(1.0 + logits_fake), dim=[1, 2, 3])
- loss_real = (weights * loss_real).sum() / weights.sum()
- loss_fake = (weights * loss_fake).sum() / weights.sum()
- d_loss = 0.5 * (loss_real + loss_fake)
- return d_loss
-
-
-def adopt_weight(weight, global_step, threshold=0, value=0.0):
- if global_step < threshold:
- weight = value
- return weight
-
-
-def measure_perplexity(predicted_indices, n_embed):
- # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
- # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
- encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
- avg_probs = encodings.mean(0)
- perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
- cluster_use = torch.sum(avg_probs > 0)
- return perplexity, cluster_use
-
-
-def l1(x, y):
- return torch.abs(x - y)
-
-
-def l2(x, y):
- return torch.pow((x - y), 2)
-
-
-class LPIPSWithDiscriminator(nn.Module):
- def __init__(
- self,
- disc_start,
- logvar_init=0.0,
- kl_weight=1.0,
- pixelloss_weight=1.0,
- perceptual_weight=1.0,
- # --- Discriminator Loss ---
- disc_num_layers=3,
- disc_in_channels=3,
- disc_factor=1.0,
- disc_weight=1.0,
- use_actnorm=False,
- disc_conditional=False,
- disc_loss="hinge",
- ):
- super().__init__()
- assert disc_loss in ["hinge", "vanilla"]
- self.kl_weight = kl_weight
- self.pixel_weight = pixelloss_weight
- self.perceptual_loss = LPIPS().eval()
- self.perceptual_weight = perceptual_weight
- self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
-
- self.discriminator = NLayerDiscriminator(
- input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=use_actnorm
- ).apply(weights_init)
- self.discriminator_iter_start = disc_start
- self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
- self.disc_factor = disc_factor
- self.discriminator_weight = disc_weight
- self.disc_conditional = disc_conditional
-
- def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
- if last_layer is not None:
- nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
- g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
- else:
- nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
- g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
-
- d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
- d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
- d_weight = d_weight * self.discriminator_weight
- return d_weight
-
- def forward(
- self,
- inputs,
- reconstructions,
- posteriors,
- optimizer_idx,
- global_step,
- split="train",
- weights=None,
- last_layer=None,
- cond=None,
- ):
- inputs = rearrange(inputs, "b c t h w -> (b t) c h w").contiguous()
- reconstructions = rearrange(reconstructions, "b c t h w -> (b t) c h w").contiguous()
- rec_loss = torch.abs(inputs - reconstructions)
- if self.perceptual_weight > 0:
- p_loss = self.perceptual_loss(inputs, reconstructions)
- rec_loss = rec_loss + self.perceptual_weight * p_loss
- nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
- weighted_nll_loss = nll_loss
- if weights is not None:
- weighted_nll_loss = weights * nll_loss
- weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
- nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
- kl_loss = posteriors.kl()
- kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
-
- # GAN Part
- if optimizer_idx == 0:
- # generator update
- if cond is None:
- assert not self.disc_conditional
- logits_fake = self.discriminator(reconstructions.contiguous())
- else:
- assert self.disc_conditional
- logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
- g_loss = -torch.mean(logits_fake)
-
- if self.disc_factor > 0.0:
- try:
- d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
- except RuntimeError:
- assert not self.training
- d_weight = torch.tensor(0.0)
- else:
- d_weight = torch.tensor(0.0)
-
- disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
- loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
- log = {
- "{}/total_loss".format(split): loss.clone().detach().mean(),
- "{}/logvar".format(split): self.logvar.detach(),
- "{}/kl_loss".format(split): kl_loss.detach().mean(),
- "{}/nll_loss".format(split): nll_loss.detach().mean(),
- "{}/rec_loss".format(split): rec_loss.detach().mean(),
- "{}/d_weight".format(split): d_weight.detach(),
- "{}/disc_factor".format(split): torch.tensor(disc_factor),
- "{}/g_loss".format(split): g_loss.detach().mean(),
- }
- return loss, log
-
- if optimizer_idx == 1:
- if cond is None:
- logits_real = self.discriminator(inputs.contiguous().detach())
- logits_fake = self.discriminator(reconstructions.contiguous().detach())
- else:
- logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
- logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
-
- disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
- d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
-
- log = {
- "{}/disc_loss".format(split): d_loss.clone().detach().mean(),
- "{}/logits_real".format(split): logits_real.detach().mean(),
- "{}/logits_fake".format(split): logits_fake.detach().mean(),
- }
- return d_loss, log
-
-
-class LPIPSWithDiscriminator3D(nn.Module):
- def __init__(
- self,
- disc_start,
- logvar_init=0.0,
- kl_weight=1.0,
- pixelloss_weight=1.0,
- perceptual_weight=1.0,
- # --- Discriminator Loss ---
- disc_num_layers=3,
- disc_in_channels=3,
- disc_factor=1.0,
- disc_weight=1.0,
- use_actnorm=False,
- disc_conditional=False,
- disc_loss="hinge",
- ):
- super().__init__()
- assert disc_loss in ["hinge", "vanilla"]
- self.kl_weight = kl_weight
- self.pixel_weight = pixelloss_weight
- self.perceptual_loss = LPIPS().eval()
- self.perceptual_weight = perceptual_weight
- self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
-
- self.discriminator = NLayerDiscriminator3D(
- input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=use_actnorm
- ).apply(weights_init)
- self.discriminator_iter_start = disc_start
- self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
- self.disc_factor = disc_factor
- self.discriminator_weight = disc_weight
- self.disc_conditional = disc_conditional
-
- def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
- if last_layer is not None:
- nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
- g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
- else:
- nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
- g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
-
- d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
- d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
- d_weight = d_weight * self.discriminator_weight
- return d_weight
-
- def forward(
- self,
- inputs,
- reconstructions,
- posteriors,
- optimizer_idx,
- global_step,
- split="train",
- weights=None,
- last_layer=None,
- cond=None,
- ):
- t = inputs.shape[2]
- inputs = rearrange(inputs, "b c t h w -> (b t) c h w").contiguous()
- reconstructions = rearrange(reconstructions, "b c t h w -> (b t) c h w").contiguous()
- rec_loss = torch.abs(inputs - reconstructions)
- if self.perceptual_weight > 0:
- p_loss = self.perceptual_loss(inputs, reconstructions)
- rec_loss = rec_loss + self.perceptual_weight * p_loss
- nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
- weighted_nll_loss = nll_loss
- if weights is not None:
- weighted_nll_loss = weights * nll_loss
- weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
- nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
- kl_loss = posteriors.kl()
- kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
- inputs = rearrange(inputs, "(b t) c h w -> b c t h w", t=t).contiguous()
- reconstructions = rearrange(reconstructions, "(b t) c h w -> b c t h w", t=t).contiguous()
- # GAN Part
- if optimizer_idx == 0:
- # generator update
- if cond is None:
- assert not self.disc_conditional
- logits_fake = self.discriminator(reconstructions)
- else:
- assert self.disc_conditional
- logits_fake = self.discriminator(torch.cat((reconstructions, cond), dim=1))
- g_loss = -torch.mean(logits_fake)
-
- if self.disc_factor > 0.0:
- try:
- d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
- except RuntimeError as e:
- assert not self.training, print(e)
- d_weight = torch.tensor(0.0)
- else:
- d_weight = torch.tensor(0.0)
-
- disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
- loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
- log = {
- "{}/total_loss".format(split): loss.clone().detach().mean(),
- "{}/logvar".format(split): self.logvar.detach(),
- "{}/kl_loss".format(split): kl_loss.detach().mean(),
- "{}/nll_loss".format(split): nll_loss.detach().mean(),
- "{}/rec_loss".format(split): rec_loss.detach().mean(),
- "{}/d_weight".format(split): d_weight.detach(),
- "{}/disc_factor".format(split): torch.tensor(disc_factor),
- "{}/g_loss".format(split): g_loss.detach().mean(),
- }
- return loss, log
-
- if optimizer_idx == 1:
- if cond is None:
- logits_real = self.discriminator(inputs.contiguous().detach())
- logits_fake = self.discriminator(reconstructions.contiguous().detach())
- else:
- logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
- logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
-
- disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
- d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
-
- log = {
- "{}/disc_loss".format(split): d_loss.clone().detach().mean(),
- "{}/logits_real".format(split): logits_real.detach().mean(),
- "{}/logits_fake".format(split): logits_fake.detach().mean(),
- }
- return d_loss, log
-
-
-class SimpleLPIPS(nn.Module):
- def __init__(
- self,
- logvar_init=0.0,
- kl_weight=1.0,
- pixelloss_weight=1.0,
- perceptual_weight=1.0,
- disc_loss="hinge",
- ):
- super().__init__()
- assert disc_loss in ["hinge", "vanilla"]
- self.kl_weight = kl_weight
- self.pixel_weight = pixelloss_weight
- self.perceptual_loss = LPIPS().eval()
- self.perceptual_weight = perceptual_weight
- self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
-
- def forward(
- self,
- inputs,
- reconstructions,
- posteriors,
- split="train",
- weights=None,
- ):
- inputs = rearrange(inputs, "b c t h w -> (b t) c h w").contiguous()
- reconstructions = rearrange(reconstructions, "b c t h w -> (b t) c h w").contiguous()
- rec_loss = torch.abs(inputs - reconstructions)
- if self.perceptual_weight > 0:
- p_loss = self.perceptual_loss(inputs, reconstructions)
- rec_loss = rec_loss + self.perceptual_weight * p_loss
- nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
- weighted_nll_loss = nll_loss
- if weights is not None:
- weighted_nll_loss = weights * nll_loss
- weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
- nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
- kl_loss = posteriors.kl()
- kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
- loss = weighted_nll_loss + self.kl_weight * kl_loss
- log = {
- "{}/total_loss".format(split): loss.clone().detach().mean(),
- "{}/logvar".format(split): self.logvar.detach(),
- "{}/kl_loss".format(split): kl_loss.detach().mean(),
- "{}/nll_loss".format(split): nll_loss.detach().mean(),
- "{}/rec_loss".format(split): rec_loss.detach().mean(),
- }
- if self.perceptual_weight > 0:
- log.update({"{}/p_loss".format(split): p_loss.detach().mean()})
- return loss, log
diff --git a/videosys/models/open_sora_plan/modules/__init__.py b/videosys/models/open_sora_plan/modules/__init__.py
deleted file mode 100644
index 519c44694083780c478bfd10079f1b2accd80652..0000000000000000000000000000000000000000
--- a/videosys/models/open_sora_plan/modules/__init__.py
+++ /dev/null
@@ -1,17 +0,0 @@
-from .attention import AttnBlock, AttnBlock3D, AttnBlock3DFix, LinAttnBlock, LinearAttention, TemporalAttnBlock
-from .block import Block
-from .conv import CausalConv3d, Conv2d
-from .normalize import GroupNorm, Normalize
-from .resnet_block import ResnetBlock2D, ResnetBlock3D
-from .updownsample import (
- Downsample,
- SpatialDownsample2x,
- SpatialUpsample2x,
- TimeDownsample2x,
- TimeDownsampleRes2x,
- TimeDownsampleResAdv2x,
- TimeUpsample2x,
- TimeUpsampleRes2x,
- TimeUpsampleResAdv2x,
- Upsample,
-)
diff --git a/videosys/models/open_sora_plan/modules/attention.py b/videosys/models/open_sora_plan/modules/attention.py
deleted file mode 100644
index 97ea9364d493cf52750afe2f399a41072f834fc0..0000000000000000000000000000000000000000
--- a/videosys/models/open_sora_plan/modules/attention.py
+++ /dev/null
@@ -1,227 +0,0 @@
-# Adapted from Open-Sora-Plan
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-# --------------------------------------------------------
-# References:
-# Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
-# --------------------------------------------------------
-
-import torch
-import torch.nn as nn
-from einops import rearrange
-
-from .block import Block
-from .conv import CausalConv3d
-from .normalize import Normalize
-from .ops import video_to_image
-
-
-class LinearAttention(Block):
- def __init__(self, dim, heads=4, dim_head=32):
- super().__init__()
- self.heads = heads
- hidden_dim = dim_head * heads
- self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
- self.to_out = nn.Conv2d(hidden_dim, dim, 1)
-
- def forward(self, x):
- b, c, h, w = x.shape
- qkv = self.to_qkv(x)
- q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
- k = k.softmax(dim=-1)
- context = torch.einsum("bhdn,bhen->bhde", k, v)
- out = torch.einsum("bhde,bhdn->bhen", context, q)
- out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
- return self.to_out(out)
-
-
-class LinAttnBlock(LinearAttention):
- """to match AttnBlock usage"""
-
- def __init__(self, in_channels):
- super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
-
-
-class AttnBlock3D(Block):
- """Compatible with old versions, there are issues, use with caution."""
-
- def __init__(self, in_channels):
- super().__init__()
- self.in_channels = in_channels
-
- self.norm = Normalize(in_channels)
- self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
- self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
- self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
- self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
-
- def forward(self, x):
- h_ = x
- h_ = self.norm(h_)
- q = self.q(h_)
- k = self.k(h_)
- v = self.v(h_)
-
- # compute attention
- b, c, t, h, w = q.shape
- q = q.reshape(b * t, c, h * w)
- q = q.permute(0, 2, 1) # b,hw,c
- k = k.reshape(b * t, c, h * w) # b,c,hw
- w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
- w_ = w_ * (int(c) ** (-0.5))
- w_ = torch.nn.functional.softmax(w_, dim=2)
-
- # attend to values
- v = v.reshape(b * t, c, h * w)
- w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
- h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
- h_ = h_.reshape(b, c, t, h, w)
-
- h_ = self.proj_out(h_)
-
- return x + h_
-
-
-class AttnBlock3DFix(nn.Module):
- """
- Thanks to https://github.com/PKU-YuanGroup/Open-Sora-Plan/pull/172.
- """
-
- def __init__(self, in_channels):
- super().__init__()
- self.in_channels = in_channels
-
- self.norm = Normalize(in_channels)
- self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
- self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
- self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
- self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
-
- def forward(self, x):
- h_ = x
- h_ = self.norm(h_)
- q = self.q(h_)
- k = self.k(h_)
- v = self.v(h_)
-
- # compute attention
- # q: (b c t h w) -> (b t c h w) -> (b*t c h*w) -> (b*t h*w c)
- b, c, t, h, w = q.shape
- q = q.permute(0, 2, 1, 3, 4)
- q = q.reshape(b * t, c, h * w)
- q = q.permute(0, 2, 1)
-
- # k: (b c t h w) -> (b t c h w) -> (b*t c h*w)
- k = k.permute(0, 2, 1, 3, 4)
- k = k.reshape(b * t, c, h * w)
-
- # w: (b*t hw hw)
- w_ = torch.bmm(q, k)
- w_ = w_ * (int(c) ** (-0.5))
- w_ = torch.nn.functional.softmax(w_, dim=2)
-
- # attend to values
- # v: (b c t h w) -> (b t c h w) -> (bt c hw)
- # w_: (bt hw hw) -> (bt hw hw)
- v = v.permute(0, 2, 1, 3, 4)
- v = v.reshape(b * t, c, h * w)
- w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
- h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
-
- # h_: (b*t c hw) -> (b t c h w) -> (b c t h w)
- h_ = h_.reshape(b, t, c, h, w)
- h_ = h_.permute(0, 2, 1, 3, 4)
-
- h_ = self.proj_out(h_)
-
- return x + h_
-
-
-class AttnBlock(Block):
- def __init__(self, in_channels):
- super().__init__()
- self.in_channels = in_channels
-
- self.norm = Normalize(in_channels)
- self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
- self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
- self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
- self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-
- @video_to_image
- def forward(self, x):
- h_ = x
- h_ = self.norm(h_)
- q = self.q(h_)
- k = self.k(h_)
- v = self.v(h_)
-
- # compute attention
- b, c, h, w = q.shape
- q = q.reshape(b, c, h * w)
- q = q.permute(0, 2, 1) # b,hw,c
- k = k.reshape(b, c, h * w) # b,c,hw
- w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
- w_ = w_ * (int(c) ** (-0.5))
- w_ = torch.nn.functional.softmax(w_, dim=2)
-
- # attend to values
- v = v.reshape(b, c, h * w)
- w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
- h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
- h_ = h_.reshape(b, c, h, w)
-
- h_ = self.proj_out(h_)
-
- return x + h_
-
-
-class TemporalAttnBlock(Block):
- def __init__(self, in_channels):
- super().__init__()
- self.in_channels = in_channels
-
- self.norm = Normalize(in_channels)
- self.q = torch.nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
- self.k = torch.nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
- self.v = torch.nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
- self.proj_out = torch.nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-
- def forward(self, x):
- h_ = x
- h_ = self.norm(h_)
- q = self.q(h_)
- k = self.k(h_)
- v = self.v(h_)
-
- # compute attention
- b, c, t, h, w = q.shape
- q = rearrange(q, "b c t h w -> (b h w) t c")
- k = rearrange(k, "b c t h w -> (b h w) c t")
- v = rearrange(v, "b c t h w -> (b h w) c t")
- w_ = torch.bmm(q, k)
- w_ = w_ * (int(c) ** (-0.5))
- w_ = torch.nn.functional.softmax(w_, dim=2)
-
- # attend to values
- w_ = w_.permute(0, 2, 1)
- h_ = torch.bmm(v, w_)
- h_ = rearrange(h_, "(b h w) c t -> b c t h w", h=h, w=w)
- h_ = self.proj_out(h_)
-
- return x + h_
-
-
-def make_attn(in_channels, attn_type="vanilla"):
- assert attn_type in ["vanilla", "linear", "none", "vanilla3D"], f"attn_type {attn_type} unknown"
- print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
- print(attn_type)
- if attn_type == "vanilla":
- return AttnBlock(in_channels)
- elif attn_type == "vanilla3D":
- return AttnBlock3D(in_channels)
- elif attn_type == "none":
- return nn.Identity(in_channels)
- else:
- return LinAttnBlock(in_channels)
diff --git a/videosys/models/open_sora_plan/modules/block.py b/videosys/models/open_sora_plan/modules/block.py
deleted file mode 100644
index 423e1b6f62f4121515a09806e00b38ba68f56516..0000000000000000000000000000000000000000
--- a/videosys/models/open_sora_plan/modules/block.py
+++ /dev/null
@@ -1,15 +0,0 @@
-# Adapted from Open-Sora-Plan
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-# --------------------------------------------------------
-# References:
-# Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
-# --------------------------------------------------------
-
-import torch.nn as nn
-
-
-class Block(nn.Module):
- def __init__(self, *args, **kwargs) -> None:
- super().__init__(*args, **kwargs)
diff --git a/videosys/models/open_sora_plan/modules/conv.py b/videosys/models/open_sora_plan/modules/conv.py
deleted file mode 100644
index 787f4c263f1caf25a0f868ddf73e6e7e99a59ee1..0000000000000000000000000000000000000000
--- a/videosys/models/open_sora_plan/modules/conv.py
+++ /dev/null
@@ -1,102 +0,0 @@
-# Adapted from Open-Sora-Plan
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-# --------------------------------------------------------
-# References:
-# Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
-# --------------------------------------------------------
-
-from typing import Tuple, Union
-
-import torch
-import torch.nn as nn
-
-from .ops import cast_tuple, video_to_image
-
-
-class Conv2d(nn.Conv2d):
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- kernel_size: Union[int, Tuple[int]] = 3,
- stride: Union[int, Tuple[int]] = 1,
- padding: Union[str, int, Tuple[int]] = 0,
- dilation: Union[int, Tuple[int]] = 1,
- groups: int = 1,
- bias: bool = True,
- padding_mode: str = "zeros",
- device=None,
- dtype=None,
- ) -> None:
- super().__init__(
- in_channels,
- out_channels,
- kernel_size,
- stride,
- padding,
- dilation,
- groups,
- bias,
- padding_mode,
- device,
- dtype,
- )
-
- @video_to_image
- def forward(self, x):
- return super().forward(x)
-
-
-class CausalConv3d(nn.Module):
- def __init__(
- self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], init_method="random", **kwargs
- ):
- super().__init__()
- self.kernel_size = cast_tuple(kernel_size, 3)
- self.time_kernel_size = self.kernel_size[0]
- self.chan_in = chan_in
- self.chan_out = chan_out
- stride = kwargs.pop("stride", 1)
- padding = kwargs.pop("padding", 0)
- padding = list(cast_tuple(padding, 3))
- padding[0] = 0
- stride = cast_tuple(stride, 3)
- self.conv = nn.Conv3d(chan_in, chan_out, self.kernel_size, stride=stride, padding=padding)
- self._init_weights(init_method)
-
- def _init_weights(self, init_method):
- torch.tensor(self.kernel_size)
- if init_method == "avg":
- assert self.kernel_size[1] == 1 and self.kernel_size[2] == 1, "only support temporal up/down sample"
- assert self.chan_in == self.chan_out, "chan_in must be equal to chan_out"
- weight = torch.zeros((self.chan_out, self.chan_in, *self.kernel_size))
-
- eyes = torch.concat(
- [
- torch.eye(self.chan_in).unsqueeze(-1) * 1 / 3,
- torch.eye(self.chan_in).unsqueeze(-1) * 1 / 3,
- torch.eye(self.chan_in).unsqueeze(-1) * 1 / 3,
- ],
- dim=-1,
- )
- weight[:, :, :, 0, 0] = eyes
-
- self.conv.weight = nn.Parameter(
- weight,
- requires_grad=True,
- )
- elif init_method == "zero":
- self.conv.weight = nn.Parameter(
- torch.zeros((self.chan_out, self.chan_in, *self.kernel_size)),
- requires_grad=True,
- )
- if self.conv.bias is not None:
- nn.init.constant_(self.conv.bias, 0)
-
- def forward(self, x):
- # 1 + 16 16 as video, 1 as image
- first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.time_kernel_size - 1, 1, 1)) # b c t h w
- x = torch.concatenate((first_frame_pad, x), dim=2) # 3 + 16
- return self.conv(x)
diff --git a/videosys/models/open_sora_plan/modules/normalize.py b/videosys/models/open_sora_plan/modules/normalize.py
deleted file mode 100644
index 0ee61c7b25501d19ad7ba3091e73df9750f5a68a..0000000000000000000000000000000000000000
--- a/videosys/models/open_sora_plan/modules/normalize.py
+++ /dev/null
@@ -1,98 +0,0 @@
-# Adapted from Open-Sora-Plan
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-# --------------------------------------------------------
-# References:
-# Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
-# --------------------------------------------------------
-
-import torch
-import torch.nn as nn
-
-from .block import Block
-
-
-class GroupNorm(Block):
- def __init__(self, num_channels, num_groups=32, eps=1e-6, *args, **kwargs) -> None:
- super().__init__(*args, **kwargs)
- self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=num_channels, eps=1e-6, affine=True)
-
- def forward(self, x):
- return self.norm(x)
-
-
-def Normalize(in_channels, num_groups=32):
- return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
-
-
-class ActNorm(nn.Module):
- def __init__(self, num_features, logdet=False, affine=True, allow_reverse_init=False):
- assert affine
- super().__init__()
- self.logdet = logdet
- self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
- self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
- self.allow_reverse_init = allow_reverse_init
-
- self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))
-
- def initialize(self, input):
- with torch.no_grad():
- flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
- mean = flatten.mean(1).unsqueeze(1).unsqueeze(2).unsqueeze(3).permute(1, 0, 2, 3)
- std = flatten.std(1).unsqueeze(1).unsqueeze(2).unsqueeze(3).permute(1, 0, 2, 3)
-
- self.loc.data.copy_(-mean)
- self.scale.data.copy_(1 / (std + 1e-6))
-
- def forward(self, input, reverse=False):
- if reverse:
- return self.reverse(input)
- if len(input.shape) == 2:
- input = input[:, :, None, None]
- squeeze = True
- else:
- squeeze = False
-
- _, _, height, width = input.shape
-
- if self.training and self.initialized.item() == 0:
- self.initialize(input)
- self.initialized.fill_(1)
-
- h = self.scale * (input + self.loc)
-
- if squeeze:
- h = h.squeeze(-1).squeeze(-1)
-
- if self.logdet:
- log_abs = torch.log(torch.abs(self.scale))
- logdet = height * width * torch.sum(log_abs)
- logdet = logdet * torch.ones(input.shape[0]).to(input)
- return h, logdet
-
- return h
-
- def reverse(self, output):
- if self.training and self.initialized.item() == 0:
- if not self.allow_reverse_init:
- raise RuntimeError(
- "Initializing ActNorm in reverse direction is "
- "disabled by default. Use allow_reverse_init=True to enable."
- )
- else:
- self.initialize(output)
- self.initialized.fill_(1)
-
- if len(output.shape) == 2:
- output = output[:, :, None, None]
- squeeze = True
- else:
- squeeze = False
-
- h = output / self.scale - self.loc
-
- if squeeze:
- h = h.squeeze(-1).squeeze(-1)
- return h
diff --git a/videosys/models/open_sora_plan/modules/ops.py b/videosys/models/open_sora_plan/modules/ops.py
deleted file mode 100644
index 8fd636ae92511f108def36706b79f52a45c939fd..0000000000000000000000000000000000000000
--- a/videosys/models/open_sora_plan/modules/ops.py
+++ /dev/null
@@ -1,54 +0,0 @@
-# Adapted from Open-Sora-Plan
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-# --------------------------------------------------------
-# References:
-# Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
-# --------------------------------------------------------
-
-import torch
-from einops import rearrange
-
-
-def video_to_image(func):
- def wrapper(self, x, *args, **kwargs):
- if x.dim() == 5:
- t = x.shape[2]
- x = rearrange(x, "b c t h w -> (b t) c h w")
- x = func(self, x, *args, **kwargs)
- x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
- return x
-
- return wrapper
-
-
-def nonlinearity(x):
- return x * torch.sigmoid(x)
-
-
-def cast_tuple(t, length=1):
- return t if isinstance(t, tuple) else ((t,) * length)
-
-
-def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True):
- n_dims = len(x.shape)
- if src_dim < 0:
- src_dim = n_dims + src_dim
- if dest_dim < 0:
- dest_dim = n_dims + dest_dim
- assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims
- dims = list(range(n_dims))
- del dims[src_dim]
- permutation = []
- ctr = 0
- for i in range(n_dims):
- if i == dest_dim:
- permutation.append(src_dim)
- else:
- permutation.append(dims[ctr])
- ctr += 1
- x = x.permute(permutation)
- if make_contiguous:
- x = x.contiguous()
- return x
diff --git a/videosys/models/open_sora_plan/modules/quant.py b/videosys/models/open_sora_plan/modules/quant.py
deleted file mode 100644
index e7b9dcf26b95ad81aa5474feec8397b3c01916bb..0000000000000000000000000000000000000000
--- a/videosys/models/open_sora_plan/modules/quant.py
+++ /dev/null
@@ -1,111 +0,0 @@
-# Adapted from Open-Sora-Plan
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-# --------------------------------------------------------
-# References:
-# Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
-# --------------------------------------------------------
-
-import numpy as np
-import torch
-import torch.distributed as dist
-import torch.nn as nn
-import torch.nn.functional as F
-
-from .ops import shift_dim
-
-
-class Codebook(nn.Module):
- def __init__(self, n_codes, embedding_dim):
- super().__init__()
- self.register_buffer("embeddings", torch.randn(n_codes, embedding_dim))
- self.register_buffer("N", torch.zeros(n_codes))
- self.register_buffer("z_avg", self.embeddings.data.clone())
-
- self.n_codes = n_codes
- self.embedding_dim = embedding_dim
- self._need_init = True
-
- def _tile(self, x):
- d, ew = x.shape
- if d < self.n_codes:
- n_repeats = (self.n_codes + d - 1) // d
- std = 0.01 / np.sqrt(ew)
- x = x.repeat(n_repeats, 1)
- x = x + torch.randn_like(x) * std
- return x
-
- def _init_embeddings(self, z):
- # z: [b, c, t, h, w]
- self._need_init = False
- flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2)
- y = self._tile(flat_inputs)
-
- y.shape[0]
- _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes]
- if dist.is_initialized():
- dist.broadcast(_k_rand, 0)
- self.embeddings.data.copy_(_k_rand)
- self.z_avg.data.copy_(_k_rand)
- self.N.data.copy_(torch.ones(self.n_codes))
-
- def forward(self, z):
- # z: [b, c, t, h, w]
- if self._need_init and self.training:
- self._init_embeddings(z)
- flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2)
- distances = (
- (flat_inputs**2).sum(dim=1, keepdim=True)
- - 2 * flat_inputs @ self.embeddings.t()
- + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True)
- )
-
- encoding_indices = torch.argmin(distances, dim=1)
- encode_onehot = F.one_hot(encoding_indices, self.n_codes).type_as(flat_inputs)
- encoding_indices = encoding_indices.view(z.shape[0], *z.shape[2:])
-
- embeddings = F.embedding(encoding_indices, self.embeddings)
- embeddings = shift_dim(embeddings, -1, 1)
-
- commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach())
-
- # EMA codebook update
- if self.training:
- n_total = encode_onehot.sum(dim=0)
- encode_sum = flat_inputs.t() @ encode_onehot
- if dist.is_initialized():
- dist.all_reduce(n_total)
- dist.all_reduce(encode_sum)
-
- self.N.data.mul_(0.99).add_(n_total, alpha=0.01)
- self.z_avg.data.mul_(0.99).add_(encode_sum.t(), alpha=0.01)
-
- n = self.N.sum()
- weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n
- encode_normalized = self.z_avg / weights.unsqueeze(1)
- self.embeddings.data.copy_(encode_normalized)
-
- y = self._tile(flat_inputs)
- _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes]
- if dist.is_initialized():
- dist.broadcast(_k_rand, 0)
-
- usage = (self.N.view(self.n_codes, 1) >= 1).float()
- self.embeddings.data.mul_(usage).add_(_k_rand * (1 - usage))
-
- embeddings_st = (embeddings - z).detach() + z
-
- avg_probs = torch.mean(encode_onehot, dim=0)
- perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
-
- return dict(
- embeddings=embeddings_st,
- encodings=encoding_indices,
- commitment_loss=commitment_loss,
- perplexity=perplexity,
- )
-
- def dictionary_lookup(self, encodings):
- embeddings = F.embedding(encodings, self.embeddings)
- return embeddings
diff --git a/videosys/models/open_sora_plan/modules/resnet_block.py b/videosys/models/open_sora_plan/modules/resnet_block.py
deleted file mode 100644
index 987c690e525f28114d8321d0cb7c043a4b2a7e8b..0000000000000000000000000000000000000000
--- a/videosys/models/open_sora_plan/modules/resnet_block.py
+++ /dev/null
@@ -1,87 +0,0 @@
-# Adapted from Open-Sora-Plan
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-# --------------------------------------------------------
-# References:
-# Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
-# --------------------------------------------------------
-
-import torch
-
-from .block import Block
-from .conv import CausalConv3d
-from .normalize import Normalize
-from .ops import nonlinearity, video_to_image
-
-
-class ResnetBlock2D(Block):
- def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout):
- super().__init__()
- self.in_channels = in_channels
- self.out_channels = in_channels if out_channels is None else out_channels
- self.use_conv_shortcut = conv_shortcut
-
- self.norm1 = Normalize(in_channels)
- self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
- self.norm2 = Normalize(out_channels)
- self.dropout = torch.nn.Dropout(dropout)
- self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
- if self.in_channels != self.out_channels:
- if self.use_conv_shortcut:
- self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
- else:
- self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
-
- @video_to_image
- def forward(self, x):
- h = x
- h = self.norm1(h)
- h = nonlinearity(h)
- h = self.conv1(h)
- h = self.norm2(h)
- h = nonlinearity(h)
- h = self.dropout(h)
- h = self.conv2(h)
- if self.in_channels != self.out_channels:
- if self.use_conv_shortcut:
- x = self.conv_shortcut(x)
- else:
- x = self.nin_shortcut(x)
- x = x + h
- return x
-
-
-class ResnetBlock3D(Block):
- def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout):
- super().__init__()
- self.in_channels = in_channels
- self.out_channels = in_channels if out_channels is None else out_channels
- self.use_conv_shortcut = conv_shortcut
-
- self.norm1 = Normalize(in_channels)
- self.conv1 = CausalConv3d(in_channels, out_channels, 3, padding=1)
- self.norm2 = Normalize(out_channels)
- self.dropout = torch.nn.Dropout(dropout)
- self.conv2 = CausalConv3d(out_channels, out_channels, 3, padding=1)
- if self.in_channels != self.out_channels:
- if self.use_conv_shortcut:
- self.conv_shortcut = CausalConv3d(in_channels, out_channels, 3, padding=1)
- else:
- self.nin_shortcut = CausalConv3d(in_channels, out_channels, 1, padding=0)
-
- def forward(self, x):
- h = x
- h = self.norm1(h)
- h = nonlinearity(h)
- h = self.conv1(h)
- h = self.norm2(h)
- h = nonlinearity(h)
- h = self.dropout(h)
- h = self.conv2(h)
- if self.in_channels != self.out_channels:
- if self.use_conv_shortcut:
- x = self.conv_shortcut(x)
- else:
- x = self.nin_shortcut(x)
- return x + h
diff --git a/videosys/models/open_sora_plan/modules/updownsample.py b/videosys/models/open_sora_plan/modules/updownsample.py
deleted file mode 100644
index db27de1d95206d80336472a1acc4a99165ebbb98..0000000000000000000000000000000000000000
--- a/videosys/models/open_sora_plan/modules/updownsample.py
+++ /dev/null
@@ -1,215 +0,0 @@
-# Adapted from Open-Sora-Plan
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-# --------------------------------------------------------
-# References:
-# Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
-# --------------------------------------------------------
-
-from typing import Tuple, Union
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from einops import rearrange
-
-from .attention import TemporalAttnBlock
-from .block import Block
-from .conv import CausalConv3d
-from .normalize import Normalize
-from .ops import cast_tuple, video_to_image
-from .resnet_block import ResnetBlock3D
-
-
-class Upsample(Block):
- def __init__(self, in_channels, out_channels):
- super().__init__()
- self.with_conv = True
- if self.with_conv:
- self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
-
- @video_to_image
- def forward(self, x):
- x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
- if self.with_conv:
- x = self.conv(x)
- return x
-
-
-class Downsample(Block):
- def __init__(self, in_channels, out_channels):
- super().__init__()
- self.with_conv = True
- if self.with_conv:
- # no asymmetric padding in torch conv, must do it ourselves
- self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
-
- @video_to_image
- def forward(self, x):
- if self.with_conv:
- pad = (0, 1, 0, 1)
- x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
- x = self.conv(x)
- else:
- x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
- return x
-
-
-class SpatialDownsample2x(Block):
- def __init__(
- self,
- chan_in,
- chan_out,
- kernel_size: Union[int, Tuple[int]] = (3, 3),
- stride: Union[int, Tuple[int]] = (2, 2),
- ):
- super().__init__()
- kernel_size = cast_tuple(kernel_size, 2)
- stride = cast_tuple(stride, 2)
- self.chan_in = chan_in
- self.chan_out = chan_out
- self.kernel_size = kernel_size
- self.conv = CausalConv3d(self.chan_in, self.chan_out, (1,) + self.kernel_size, stride=(1,) + stride, padding=0)
-
- def forward(self, x):
- pad = (0, 1, 0, 1, 0, 0)
- x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
- x = self.conv(x)
- return x
-
-
-class SpatialUpsample2x(Block):
- def __init__(
- self,
- chan_in,
- chan_out,
- kernel_size: Union[int, Tuple[int]] = (3, 3),
- stride: Union[int, Tuple[int]] = (1, 1),
- ):
- super().__init__()
- self.chan_in = chan_in
- self.chan_out = chan_out
- self.kernel_size = kernel_size
- self.conv = CausalConv3d(self.chan_in, self.chan_out, (1,) + self.kernel_size, stride=(1,) + stride, padding=1)
-
- def forward(self, x):
- t = x.shape[2]
- x = rearrange(x, "b c t h w -> b (c t) h w")
- x = F.interpolate(x, scale_factor=(2, 2), mode="nearest")
- x = rearrange(x, "b (c t) h w -> b c t h w", t=t)
- x = self.conv(x)
- return x
-
-
-class TimeDownsample2x(Block):
- def __init__(self, chan_in, chan_out, kernel_size: int = 3):
- super().__init__()
- self.kernel_size = kernel_size
- self.conv = nn.AvgPool3d((kernel_size, 1, 1), stride=(2, 1, 1))
-
- def forward(self, x):
- first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.kernel_size - 1, 1, 1))
- x = torch.concatenate((first_frame_pad, x), dim=2)
- return self.conv(x)
-
-
-class TimeUpsample2x(Block):
- def __init__(self, chan_in, chan_out):
- super().__init__()
-
- def forward(self, x):
- if x.size(2) > 1:
- x, x_ = x[:, :, :1], x[:, :, 1:]
- x_ = F.interpolate(x_, scale_factor=(2, 1, 1), mode="trilinear")
- x = torch.concat([x, x_], dim=2)
- return x
-
-
-class TimeDownsampleRes2x(nn.Module):
- def __init__(
- self,
- in_channels,
- out_channels,
- kernel_size: int = 3,
- mix_factor: float = 2.0,
- ):
- super().__init__()
- self.kernel_size = cast_tuple(kernel_size, 3)
- self.avg_pool = nn.AvgPool3d((kernel_size, 1, 1), stride=(2, 1, 1))
- self.conv = nn.Conv3d(in_channels, out_channels, self.kernel_size, stride=(2, 1, 1), padding=(0, 1, 1))
- self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor]))
-
- def forward(self, x):
- alpha = torch.sigmoid(self.mix_factor)
- first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.kernel_size[0] - 1, 1, 1))
- x = torch.concatenate((first_frame_pad, x), dim=2)
- return alpha * self.avg_pool(x) + (1 - alpha) * self.conv(x)
-
-
-class TimeUpsampleRes2x(nn.Module):
- def __init__(
- self,
- in_channels,
- out_channels,
- kernel_size: int = 3,
- mix_factor: float = 2.0,
- ):
- super().__init__()
- self.conv = CausalConv3d(in_channels, out_channels, kernel_size, padding=1)
- self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor]))
-
- def forward(self, x):
- alpha = torch.sigmoid(self.mix_factor)
- if x.size(2) > 1:
- x, x_ = x[:, :, :1], x[:, :, 1:]
- x_ = F.interpolate(x_, scale_factor=(2, 1, 1), mode="trilinear")
- x = torch.concat([x, x_], dim=2)
- return alpha * x + (1 - alpha) * self.conv(x)
-
-
-class TimeDownsampleResAdv2x(nn.Module):
- def __init__(
- self,
- in_channels,
- out_channels,
- kernel_size: int = 3,
- mix_factor: float = 1.5,
- ):
- super().__init__()
- self.kernel_size = cast_tuple(kernel_size, 3)
- self.avg_pool = nn.AvgPool3d((kernel_size, 1, 1), stride=(2, 1, 1))
- self.attn = TemporalAttnBlock(in_channels)
- self.res = ResnetBlock3D(in_channels=in_channels, out_channels=in_channels, dropout=0.0)
- self.conv = nn.Conv3d(in_channels, out_channels, self.kernel_size, stride=(2, 1, 1), padding=(0, 1, 1))
- self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor]))
-
- def forward(self, x):
- first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.kernel_size[0] - 1, 1, 1))
- x = torch.concatenate((first_frame_pad, x), dim=2)
- alpha = torch.sigmoid(self.mix_factor)
- return alpha * self.avg_pool(x) + (1 - alpha) * self.conv(self.attn((self.res(x))))
-
-
-class TimeUpsampleResAdv2x(nn.Module):
- def __init__(
- self,
- in_channels,
- out_channels,
- kernel_size: int = 3,
- mix_factor: float = 1.5,
- ):
- super().__init__()
- self.res = ResnetBlock3D(in_channels=in_channels, out_channels=in_channels, dropout=0.0)
- self.attn = TemporalAttnBlock(in_channels)
- self.norm = Normalize(in_channels=in_channels)
- self.conv = CausalConv3d(in_channels, out_channels, kernel_size, padding=1)
- self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor]))
-
- def forward(self, x):
- if x.size(2) > 1:
- x, x_ = x[:, :, :1], x[:, :, 1:]
- x_ = F.interpolate(x_, scale_factor=(2, 1, 1), mode="trilinear")
- x = torch.concat([x, x_], dim=2)
- alpha = torch.sigmoid(self.mix_factor)
- return alpha * x + (1 - alpha) * self.conv(self.attn(self.res(x)))
diff --git a/tests/__init__.py b/videosys/models/transformers/__init__.py
similarity index 100%
rename from tests/__init__.py
rename to videosys/models/transformers/__init__.py
diff --git a/videosys/models/transformers/cogvideox_transformer_3d.py b/videosys/models/transformers/cogvideox_transformer_3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8af5482f55fd6497423b0c3a0566af36a798dc8
--- /dev/null
+++ b/videosys/models/transformers/cogvideox_transformer_3d.py
@@ -0,0 +1,534 @@
+# Adapted from CogVideo
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# CogVideo: https://github.com/THUDM/CogVideo
+# diffusers: https://github.com/huggingface/diffusers
+# --------------------------------------------------------
+
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models.attention import Attention, FeedForward
+from diffusers.models.embeddings import TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed
+from diffusers.models.modeling_outputs import Transformer2DModelOutput
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.utils import is_torch_version
+from diffusers.utils.torch_utils import maybe_allow_in_graph
+from torch import nn
+
+from videosys.core.pab_mgr import enable_pab, if_broadcast_spatial
+from videosys.models.modules.embeddings import apply_rotary_emb
+
+from ..modules.embeddings import CogVideoXPatchEmbed
+from ..modules.normalization import AdaLayerNorm, CogVideoXLayerNormZero
+
+
+class CogVideoXAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
+ query and key vectors, but does not include spatial normalization.
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ text_seq_length = encoder_hidden_states.size(1)
+
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # Apply RoPE if needed
+ if image_rotary_emb is not None:
+ query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
+ if not attn.is_cross_attention:
+ key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
+
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ encoder_hidden_states, hidden_states = hidden_states.split(
+ [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
+ )
+ return hidden_states, encoder_hidden_states
+
+
+class FusedCogVideoXAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
+ query and key vectors, but does not include spatial normalization.
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ text_seq_length = encoder_hidden_states.size(1)
+
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ qkv = attn.to_qkv(hidden_states)
+ split_size = qkv.shape[-1] // 3
+ query, key, value = torch.split(qkv, split_size, dim=-1)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # Apply RoPE if needed
+ if image_rotary_emb is not None:
+ query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
+ if not attn.is_cross_attention:
+ key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
+
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ encoder_hidden_states, hidden_states = hidden_states.split(
+ [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
+ )
+ return hidden_states, encoder_hidden_states
+
+
+@maybe_allow_in_graph
+class CogVideoXBlock(nn.Module):
+ r"""
+ Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
+
+ Parameters:
+ dim (`int`):
+ The number of channels in the input and output.
+ num_attention_heads (`int`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`):
+ The number of channels in each head.
+ time_embed_dim (`int`):
+ The number of channels in timestep embedding.
+ dropout (`float`, defaults to `0.0`):
+ The dropout probability to use.
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
+ Activation function to be used in feed-forward.
+ attention_bias (`bool`, defaults to `False`):
+ Whether or not to use bias in attention projection layers.
+ qk_norm (`bool`, defaults to `True`):
+ Whether or not to use normalization after query and key projections in Attention.
+ norm_elementwise_affine (`bool`, defaults to `True`):
+ Whether to use learnable elementwise affine parameters for normalization.
+ norm_eps (`float`, defaults to `1e-5`):
+ Epsilon value for normalization layers.
+ final_dropout (`bool` defaults to `False`):
+ Whether to apply a final dropout after the last feed-forward layer.
+ ff_inner_dim (`int`, *optional*, defaults to `None`):
+ Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
+ ff_bias (`bool`, defaults to `True`):
+ Whether or not to use bias in Feed-forward layer.
+ attention_out_bias (`bool`, defaults to `True`):
+ Whether or not to use bias in Attention output projection layer.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ time_embed_dim: int,
+ dropout: float = 0.0,
+ activation_fn: str = "gelu-approximate",
+ attention_bias: bool = False,
+ qk_norm: bool = True,
+ norm_elementwise_affine: bool = True,
+ norm_eps: float = 1e-5,
+ final_dropout: bool = True,
+ ff_inner_dim: Optional[int] = None,
+ ff_bias: bool = True,
+ attention_out_bias: bool = True,
+ block_idx: int = 0,
+ ):
+ super().__init__()
+
+ # 1. Self Attention
+ self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
+
+ self.attn1 = Attention(
+ query_dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ qk_norm="layer_norm" if qk_norm else None,
+ eps=1e-6,
+ bias=attention_bias,
+ out_bias=attention_out_bias,
+ processor=CogVideoXAttnProcessor2_0(),
+ )
+
+ # 2. Feed Forward
+ self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
+
+ self.ff = FeedForward(
+ dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ final_dropout=final_dropout,
+ inner_dim=ff_inner_dim,
+ bias=ff_bias,
+ )
+
+ # pab
+ self.attn_count = 0
+ self.last_attn = None
+ self.block_idx = block_idx
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ timestep=None,
+ ) -> torch.Tensor:
+ text_seq_length = encoder_hidden_states.size(1)
+
+ # norm & modulate
+ norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
+ hidden_states, encoder_hidden_states, temb
+ )
+
+ # attention
+ if enable_pab():
+ broadcast_attn, self.attn_count = if_broadcast_spatial(int(timestep[0]), self.attn_count, self.block_idx)
+ if enable_pab() and broadcast_attn:
+ attn_hidden_states, attn_encoder_hidden_states = self.last_attn
+ else:
+ attn_hidden_states, attn_encoder_hidden_states = self.attn1(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ )
+ if enable_pab():
+ self.last_attn = (attn_hidden_states, attn_encoder_hidden_states)
+
+ hidden_states = hidden_states + gate_msa * attn_hidden_states
+ encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
+
+ # norm & modulate
+ norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
+ hidden_states, encoder_hidden_states, temb
+ )
+
+ # feed-forward
+ norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
+ ff_output = self.ff(norm_hidden_states)
+
+ hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
+ encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
+
+ return hidden_states, encoder_hidden_states
+
+
+class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
+ """
+ A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
+
+ Parameters:
+ num_attention_heads (`int`, defaults to `30`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, defaults to `64`):
+ The number of channels in each head.
+ in_channels (`int`, defaults to `16`):
+ The number of channels in the input.
+ out_channels (`int`, *optional*, defaults to `16`):
+ The number of channels in the output.
+ flip_sin_to_cos (`bool`, defaults to `True`):
+ Whether to flip the sin to cos in the time embedding.
+ time_embed_dim (`int`, defaults to `512`):
+ Output dimension of timestep embeddings.
+ text_embed_dim (`int`, defaults to `4096`):
+ Input dimension of text embeddings from the text encoder.
+ num_layers (`int`, defaults to `30`):
+ The number of layers of Transformer blocks to use.
+ dropout (`float`, defaults to `0.0`):
+ The dropout probability to use.
+ attention_bias (`bool`, defaults to `True`):
+ Whether or not to use bias in the attention projection layers.
+ sample_width (`int`, defaults to `90`):
+ The width of the input latents.
+ sample_height (`int`, defaults to `60`):
+ The height of the input latents.
+ sample_frames (`int`, defaults to `49`):
+ The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
+ instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
+ but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
+ K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
+ patch_size (`int`, defaults to `2`):
+ The size of the patches to use in the patch embedding layer.
+ temporal_compression_ratio (`int`, defaults to `4`):
+ The compression ratio across the temporal dimension. See documentation for `sample_frames`.
+ max_text_seq_length (`int`, defaults to `226`):
+ The maximum sequence length of the input text embeddings.
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
+ Activation function to use in feed-forward.
+ timestep_activation_fn (`str`, defaults to `"silu"`):
+ Activation function to use when generating the timestep embeddings.
+ norm_elementwise_affine (`bool`, defaults to `True`):
+ Whether or not to use elementwise affine in normalization layers.
+ norm_eps (`float`, defaults to `1e-5`):
+ The epsilon value to use in normalization layers.
+ spatial_interpolation_scale (`float`, defaults to `1.875`):
+ Scaling factor to apply in 3D positional embeddings across spatial dimensions.
+ temporal_interpolation_scale (`float`, defaults to `1.0`):
+ Scaling factor to apply in 3D positional embeddings across temporal dimensions.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ num_attention_heads: int = 30,
+ attention_head_dim: int = 64,
+ in_channels: int = 16,
+ out_channels: Optional[int] = 16,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ time_embed_dim: int = 512,
+ text_embed_dim: int = 4096,
+ num_layers: int = 30,
+ dropout: float = 0.0,
+ attention_bias: bool = True,
+ sample_width: int = 90,
+ sample_height: int = 60,
+ sample_frames: int = 49,
+ patch_size: int = 2,
+ temporal_compression_ratio: int = 4,
+ max_text_seq_length: int = 226,
+ activation_fn: str = "gelu-approximate",
+ timestep_activation_fn: str = "silu",
+ norm_elementwise_affine: bool = True,
+ norm_eps: float = 1e-5,
+ spatial_interpolation_scale: float = 1.875,
+ temporal_interpolation_scale: float = 1.0,
+ use_rotary_positional_embeddings: bool = False,
+ ):
+ super().__init__()
+ inner_dim = num_attention_heads * attention_head_dim
+
+ post_patch_height = sample_height // patch_size
+ post_patch_width = sample_width // patch_size
+ post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
+ self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
+
+ # 1. Patch embedding
+ self.patch_embed = CogVideoXPatchEmbed(patch_size, in_channels, inner_dim, text_embed_dim, bias=True)
+ self.embedding_dropout = nn.Dropout(dropout)
+
+ # 2. 3D positional embeddings
+ spatial_pos_embedding = get_3d_sincos_pos_embed(
+ inner_dim,
+ (post_patch_width, post_patch_height),
+ post_time_compression_frames,
+ spatial_interpolation_scale,
+ temporal_interpolation_scale,
+ )
+ spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1)
+ pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False)
+ pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding)
+ self.register_buffer("pos_embedding", pos_embedding, persistent=False)
+
+ # 3. Time embeddings
+ self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
+ self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
+
+ # 4. Define spatio-temporal transformers blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ CogVideoXBlock(
+ dim=inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ time_embed_dim=time_embed_dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ attention_bias=attention_bias,
+ norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+ self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
+
+ # 5. Output blocks
+ self.norm_out = AdaLayerNorm(
+ embedding_dim=time_embed_dim,
+ output_dim=2 * inner_dim,
+ norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
+ chunk_dim=1,
+ )
+ self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
+
+ self.gradient_checkpointing = False
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ self.gradient_checkpointing = value
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ timestep: Union[int, float, torch.LongTensor],
+ timestep_cond: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ return_dict: bool = True,
+ ):
+ batch_size, num_frames, channels, height, width = hidden_states.shape
+
+ # 1. Time embedding
+ timesteps = timestep
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=hidden_states.dtype)
+ emb = self.time_embedding(t_emb, timestep_cond)
+
+ # 2. Patch embedding
+ hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
+
+ # 3. Position embedding
+ text_seq_length = encoder_hidden_states.shape[1]
+ if not self.config.use_rotary_positional_embeddings:
+ seq_length = height * width * num_frames // (self.config.patch_size**2)
+
+ pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
+ hidden_states = hidden_states + pos_embeds
+ hidden_states = self.embedding_dropout(hidden_states)
+
+ encoder_hidden_states = hidden_states[:, :text_seq_length]
+ hidden_states = hidden_states[:, text_seq_length:]
+
+ # 4. Transformer blocks
+ for i, block in enumerate(self.transformer_blocks):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ encoder_hidden_states,
+ emb,
+ image_rotary_emb,
+ **ckpt_kwargs,
+ )
+ else:
+ hidden_states, encoder_hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=emb,
+ image_rotary_emb=image_rotary_emb,
+ timestep=timesteps if enable_pab() else None,
+ )
+
+ if not self.config.use_rotary_positional_embeddings:
+ # CogVideoX-2B
+ hidden_states = self.norm_final(hidden_states)
+ else:
+ # CogVideoX-5B
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+ hidden_states = self.norm_final(hidden_states)
+ hidden_states = hidden_states[:, text_seq_length:]
+
+ # 5. Final block
+ hidden_states = self.norm_out(hidden_states, temb=emb)
+ hidden_states = self.proj_out(hidden_states)
+
+ # 6. Unpatchify
+ p = self.config.patch_size
+ output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
+ output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
+
+ if not return_dict:
+ return (output,)
+ return Transformer2DModelOutput(sample=output)
diff --git a/videosys/models/latte/latte_t2v.py b/videosys/models/transformers/latte_transformer_3d.py
similarity index 100%
rename from videosys/models/latte/latte_t2v.py
rename to videosys/models/transformers/latte_transformer_3d.py
diff --git a/videosys/models/open_sora_plan/latte.py b/videosys/models/transformers/open_sora_plan_transformer_3d.py
similarity index 100%
rename from videosys/models/open_sora_plan/latte.py
rename to videosys/models/transformers/open_sora_plan_transformer_3d.py
diff --git a/videosys/models/open_sora/stdit3.py b/videosys/models/transformers/open_sora_transformer_3d.py
similarity index 87%
rename from videosys/models/open_sora/stdit3.py
rename to videosys/models/transformers/open_sora_transformer_3d.py
index cfea9dd5501c4dd341c277eb3023ca3854242b05..8e595741f5e09dff21f3bca369f327ccbaafed60 100644
--- a/videosys/models/open_sora/stdit3.py
+++ b/videosys/models/transformers/open_sora_transformer_3d.py
@@ -9,6 +9,7 @@
import os
+from collections.abc import Iterable
from functools import partial
import numpy as np
@@ -17,6 +18,7 @@ import torch.nn as nn
from einops import rearrange
from timm.models.layers import DropPath
from timm.models.vision_transformer import Mlp
+from torch.utils.checkpoint import checkpoint, checkpoint_sequential
from transformers import PretrainedConfig, PreTrainedModel
from videosys.core.comm import (
@@ -43,22 +45,68 @@ from videosys.core.parallel_mgr import (
get_data_parallel_group,
get_sequence_parallel_group,
)
-from videosys.utils.utils import batch_func
-
-from .modules import (
- Attention,
- CaptionEmbedder,
- MultiHeadCrossAttention,
- PatchEmbed3D,
- PositionEmbedding2D,
+from videosys.models.modules.activations import approx_gelu
+from videosys.models.modules.attentions import Attention, MultiHeadCrossAttention
+from videosys.models.modules.embeddings import (
+ OpenSoraCaptionEmbedder,
+ OpenSoraPatchEmbed3D,
+ OpenSoraPositionEmbedding2D,
SizeEmbedder,
- T2IFinalLayer,
TimestepEmbedder,
- approx_gelu,
- get_layernorm,
- t2i_modulate,
)
-from .utils import auto_grad_checkpoint, load_checkpoint
+from videosys.utils.utils import batch_func
+
+
+def t2i_modulate(x, shift, scale):
+ return x * (1 + scale) + shift
+
+
+class T2IFinalLayer(nn.Module):
+ """
+ The final layer of PixArt.
+ """
+
+ def __init__(self, hidden_size, num_patch, out_channels, d_t=None, d_s=None):
+ super().__init__()
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True)
+ self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size**0.5)
+ self.out_channels = out_channels
+ self.d_t = d_t
+ self.d_s = d_s
+
+ def t_mask_select(self, x_mask, x, masked_x, T, S):
+ # x: [B, (T, S), C]
+ # mased_x: [B, (T, S), C]
+ # x_mask: [B, T]
+ x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
+ masked_x = rearrange(masked_x, "B (T S) C -> B T S C", T=T, S=S)
+ x = torch.where(x_mask[:, :, None, None], x, masked_x)
+ x = rearrange(x, "B T S C -> B (T S) C")
+ return x
+
+ def forward(self, x, t, x_mask=None, t0=None, T=None, S=None):
+ if T is None:
+ T = self.d_t
+ if S is None:
+ S = self.d_s
+ shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1)
+ x = t2i_modulate(self.norm_final(x), shift, scale)
+ if x_mask is not None:
+ shift_zero, scale_zero = (self.scale_shift_table[None] + t0[:, None]).chunk(2, dim=1)
+ x_zero = t2i_modulate(self.norm_final(x), shift_zero, scale_zero)
+ x = self.t_mask_select(x_mask, x, x_zero, T, S)
+ x = self.linear(x)
+ return x
+
+
+def auto_grad_checkpoint(module, *args, **kwargs):
+ if getattr(module, "grad_checkpointing", False):
+ if not isinstance(module, Iterable):
+ return checkpoint(module, *args, use_reentrant=False, **kwargs)
+ gc_step = module[0].grad_checkpointing_step
+ return checkpoint_sequential(module, gc_step, *args, use_reentrant=False, **kwargs)
+ return module(*args, **kwargs)
class STDiT3Block(nn.Module):
@@ -82,7 +130,7 @@ class STDiT3Block(nn.Module):
attn_cls = Attention
mha_cls = MultiHeadCrossAttention
- self.norm1 = get_layernorm(hidden_size, eps=1e-6, affine=False)
+ self.norm1 = nn.LayerNorm(hidden_size, eps=1e-6, elementwise_affine=False)
self.attn = attn_cls(
hidden_size,
num_heads=num_heads,
@@ -92,7 +140,7 @@ class STDiT3Block(nn.Module):
enable_flash_attn=enable_flash_attn,
)
self.cross_attn = mha_cls(hidden_size, num_heads)
- self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False)
+ self.norm2 = nn.LayerNorm(hidden_size, eps=1e-6, elementwise_affine=False)
self.mlp = Mlp(
in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0
)
@@ -332,21 +380,21 @@ class STDiT3(PreTrainedModel):
# input size related
self.patch_size = config.patch_size
self.input_sq_size = config.input_sq_size
- self.pos_embed = PositionEmbedding2D(config.hidden_size)
+ self.pos_embed = OpenSoraPositionEmbedding2D(config.hidden_size)
from rotary_embedding_torch import RotaryEmbedding
self.rope = RotaryEmbedding(dim=self.hidden_size // self.num_heads)
# embedding
- self.x_embedder = PatchEmbed3D(config.patch_size, config.in_channels, config.hidden_size)
+ self.x_embedder = OpenSoraPatchEmbed3D(config.patch_size, config.in_channels, config.hidden_size)
self.t_embedder = TimestepEmbedder(config.hidden_size)
self.fps_embedder = SizeEmbedder(self.hidden_size)
self.t_block = nn.Sequential(
nn.SiLU(),
nn.Linear(config.hidden_size, 6 * config.hidden_size, bias=True),
)
- self.y_embedder = CaptionEmbedder(
+ self.y_embedder = OpenSoraCaptionEmbedder(
in_channels=config.caption_channels,
hidden_size=config.hidden_size,
uncond_prob=config.class_dropout_prob,
@@ -598,6 +646,4 @@ def STDiT3_XL_2(from_pretrained=None, **kwargs):
else:
config = STDiT3Config(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs)
model = STDiT3(config)
- if from_pretrained is not None:
- load_checkpoint(model, from_pretrained)
return model
diff --git a/videosys/modules/attn.py b/videosys/modules/attn.py
deleted file mode 100644
index 424c2b78a340696c26a530941ea013caa352889a..0000000000000000000000000000000000000000
--- a/videosys/modules/attn.py
+++ /dev/null
@@ -1,217 +0,0 @@
-from dataclasses import dataclass
-from typing import Iterable, List, Optional, Sequence, Tuple
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.utils.checkpoint
-
-from videosys.modules.layers import LlamaRMSNorm
-
-
-class Attention(nn.Module):
- def __init__(
- self,
- dim: int,
- num_heads: int = 8,
- qkv_bias: bool = False,
- qk_norm: bool = False,
- attn_drop: float = 0.0,
- proj_drop: float = 0.0,
- norm_layer: nn.Module = LlamaRMSNorm,
- enable_flashattn: bool = False,
- rope=None,
- ) -> None:
- super().__init__()
- assert dim % num_heads == 0, "dim should be divisible by num_heads"
- self.dim = dim
- self.num_heads = num_heads
- self.head_dim = dim // num_heads
- self.scale = self.head_dim**-0.5
- self.enable_flashattn = enable_flashattn
-
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
- self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
- self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
- self.attn_drop = nn.Dropout(attn_drop)
- self.proj = nn.Linear(dim, dim)
- self.proj_drop = nn.Dropout(proj_drop)
-
- self.rope = False
- if rope is not None:
- self.rope = True
- self.rotary_emb = rope
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- B, N, C = x.shape
-
- qkv = self.qkv(x)
- qkv = qkv.view(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 1, 3, 4)
- q, k, v = qkv.unbind(0)
- if self.rope:
- q = self.rotary_emb(q)
- k = self.rotary_emb(k)
- q, k = self.q_norm(q), self.k_norm(k)
-
- if self.enable_flashattn:
- from flash_attn import flash_attn_func
-
- x = flash_attn_func(
- q,
- k,
- v,
- dropout_p=self.attn_drop.p if self.training else 0.0,
- softmax_scale=self.scale,
- )
- else:
- q, k, v = map(lambda t: t.permute(0, 2, 1, 3), (q, k, v))
- x = F.scaled_dot_product_attention(
- q, k, v, scale=self.scale, dropout_p=self.attn_drop.p if self.training else 0.0
- )
-
- x_output_shape = (B, N, C)
- if not self.enable_flashattn:
- x = x.transpose(1, 2)
- x = x.reshape(x_output_shape)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
-
-class MultiHeadCrossAttention(nn.Module):
- def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0, enable_flashattn=False):
- super(MultiHeadCrossAttention, self).__init__()
- assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
-
- self.d_model = d_model
- self.num_heads = num_heads
- self.head_dim = d_model // num_heads
- self.enable_flashattn = enable_flashattn
-
- self.q_linear = nn.Linear(d_model, d_model)
- self.kv_linear = nn.Linear(d_model, d_model * 2)
- self.attn_drop = nn.Dropout(attn_drop)
- self.proj = nn.Linear(d_model, d_model)
- self.proj_drop = nn.Dropout(proj_drop)
- self.last_out = None
- self.count = 0
-
- def forward(self, x, cond, mask=None, timestep=None):
- # query/value: img tokens; key: condition; mask: if padding tokens
- B, N, C = x.shape
-
- q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
- kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
- k, v = kv.unbind(2)
- x = self.flash_attn_impl(q, k, v, mask, B, N, C)
-
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
- def flash_attn_impl(self, q, k, v, mask, B, N, C):
- from flash_attn import flash_attn_varlen_func
-
- q_seqinfo = _SeqLenInfo.from_seqlens([N] * B)
- k_seqinfo = _SeqLenInfo.from_seqlens(mask)
-
- x = flash_attn_varlen_func(
- q.view(-1, self.num_heads, self.head_dim),
- k.view(-1, self.num_heads, self.head_dim),
- v.view(-1, self.num_heads, self.head_dim),
- cu_seqlens_q=q_seqinfo.seqstart.cuda(),
- cu_seqlens_k=k_seqinfo.seqstart.cuda(),
- max_seqlen_q=q_seqinfo.max_seqlen,
- max_seqlen_k=k_seqinfo.max_seqlen,
- dropout_p=self.attn_drop.p if self.training else 0.0,
- )
- x = x.view(B, N, C)
- return x
-
- def torch_impl(self, q, k, v, mask, B, N, C):
- q = q.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
- k = k.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
- v = v.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
-
- attn_mask = torch.zeros(B, N, k.shape[2], dtype=torch.float32, device=q.device)
- for i, m in enumerate(mask):
- attn_mask[i, :, m:] = -1e8
-
- scale = 1 / q.shape[-1] ** 0.5
- q = q * scale
- attn = q @ k.transpose(-2, -1)
- attn = attn.to(torch.float32)
- if mask is not None:
- attn = attn + attn_mask.unsqueeze(1)
- attn = attn.softmax(-1)
- attn = attn.to(v.dtype)
- out = attn @ v
-
- x = out.transpose(1, 2).contiguous().view(B, N, C)
- return x
-
-
-@dataclass
-class _SeqLenInfo:
- """
- copied from xformers
-
- (Internal) Represents the division of a dimension into blocks.
- For example, to represents a dimension of length 7 divided into
- three blocks of lengths 2, 3 and 2, use `from_seqlength([2, 3, 2])`.
- The members will be:
- max_seqlen: 3
- min_seqlen: 2
- seqstart_py: [0, 2, 5, 7]
- seqstart: torch.IntTensor([0, 2, 5, 7])
- """
-
- seqstart: torch.Tensor
- max_seqlen: int
- min_seqlen: int
- seqstart_py: List[int]
-
- def to(self, device: torch.device) -> None:
- self.seqstart = self.seqstart.to(device, non_blocking=True)
-
- def intervals(self) -> Iterable[Tuple[int, int]]:
- yield from zip(self.seqstart_py, self.seqstart_py[1:])
-
- @classmethod
- def from_seqlens(cls, seqlens: Iterable[int]) -> "_SeqLenInfo":
- """
- Input tensors are assumed to be in shape [B, M, *]
- """
- assert not isinstance(seqlens, torch.Tensor)
- seqstart_py = [0]
- max_seqlen = -1
- min_seqlen = -1
- for seqlen in seqlens:
- min_seqlen = min(min_seqlen, seqlen) if min_seqlen != -1 else seqlen
- max_seqlen = max(max_seqlen, seqlen)
- seqstart_py.append(seqstart_py[len(seqstart_py) - 1] + seqlen)
- seqstart = torch.tensor(seqstart_py, dtype=torch.int32)
- return cls(
- max_seqlen=max_seqlen,
- min_seqlen=min_seqlen,
- seqstart=seqstart,
- seqstart_py=seqstart_py,
- )
-
- def split(self, x: torch.Tensor, batch_sizes: Optional[Sequence[int]] = None) -> List[torch.Tensor]:
- if self.seqstart_py[-1] != x.shape[1] or x.shape[0] != 1:
- raise ValueError(
- f"Invalid `torch.Tensor` of shape {x.shape}, expected format "
- f"(B, M, *) with B=1 and M={self.seqstart_py[-1]}\n"
- f" seqstart: {self.seqstart_py}"
- )
- if batch_sizes is None:
- batch_sizes = [1] * (len(self.seqstart_py) - 1)
- split_chunks = []
- it = 0
- for batch_size in batch_sizes:
- split_chunks.append(self.seqstart_py[it + batch_size] - self.seqstart_py[it])
- it += batch_size
- return [
- tensor.reshape([bs, -1, *tensor.shape[2:]]) for bs, tensor in zip(batch_sizes, x.split(split_chunks, dim=1))
- ]
diff --git a/videosys/modules/embed.py b/videosys/modules/embed.py
deleted file mode 100644
index 2a166fadaa695c314cd5279754b2b1389136d547..0000000000000000000000000000000000000000
--- a/videosys/modules/embed.py
+++ /dev/null
@@ -1,145 +0,0 @@
-# Modified from Meta DiT
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-# --------------------------------------------------------
-# References:
-# DiT: https://github.com/facebookresearch/DiT/tree/main
-# GLIDE: https://github.com/openai/glide-text2im
-# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
-# --------------------------------------------------------
-
-
-import math
-
-import numpy as np
-import torch
-from torch import nn
-
-
-class TimestepEmbedder(nn.Module):
- """
- Embeds scalar timesteps into vector representations.
- """
-
- def __init__(self, hidden_size, frequency_embedding_size=256):
- super().__init__()
- self.mlp = nn.Sequential(
- nn.Linear(frequency_embedding_size, hidden_size, bias=True),
- nn.SiLU(),
- nn.Linear(hidden_size, hidden_size, bias=True),
- )
- self.frequency_embedding_size = frequency_embedding_size
-
- @staticmethod
- def timestep_embedding(t, dim, max_period=10000):
- """
- Create sinusoidal timestep embeddings.
- :param t: a 1-D Tensor of N indices, one per batch element.
- These may be fractional.
- :param dim: the dimension of the output.
- :param max_period: controls the minimum frequency of the embeddings.
- :return: an (N, D) Tensor of positional embeddings.
- """
- # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
- half = dim // 2
- freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
- device=t.device
- )
- args = t[:, None].float() * freqs[None]
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
- if dim % 2:
- embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
- return embedding
-
- def forward(self, t):
- t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
- t_emb = self.mlp(t_freq)
- return t_emb
-
-
-class LabelEmbedder(nn.Module):
- """
- Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
- """
-
- def __init__(self, num_classes, hidden_size, dropout_prob):
- super().__init__()
- use_cfg_embedding = dropout_prob > 0
- self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
- self.num_classes = num_classes
- self.dropout_prob = dropout_prob
-
- def token_drop(self, labels, force_drop_ids=None):
- """
- Drops labels to enable classifier-free guidance.
- """
- if force_drop_ids is None:
- drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
- else:
- drop_ids = force_drop_ids == 1
- labels = torch.where(drop_ids, self.num_classes, labels)
- return labels
-
- def forward(self, labels, train, force_drop_ids=None):
- use_dropout = self.dropout_prob > 0
- if (train and use_dropout) or (force_drop_ids is not None):
- labels = self.token_drop(labels, force_drop_ids)
- embeddings = self.embedding_table(labels)
- return embeddings
-
-
-#################################################################################
-# Sine/Cosine Positional Embedding Functions #
-#################################################################################
-# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
-
-
-def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
- """
- grid_size: int of the grid height and width
- return:
- pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
- """
- grid_h = np.arange(grid_size, dtype=np.float32)
- grid_w = np.arange(grid_size, dtype=np.float32)
- grid = np.meshgrid(grid_w, grid_h) # here w goes first
- grid = np.stack(grid, axis=0)
-
- grid = grid.reshape([2, 1, grid_size, grid_size])
- pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
- if cls_token and extra_tokens > 0:
- pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
- return pos_embed
-
-
-def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
- assert embed_dim % 2 == 0
-
- # use half of dimensions to encode grid_h
- emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
- emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
-
- emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
- return emb
-
-
-def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
- """
- embed_dim: output dimension for each position
- pos: a list of positions to be encoded: size (M,)
- out: (M, D)
- """
- assert embed_dim % 2 == 0
- omega = np.arange(embed_dim // 2, dtype=np.float64)
- omega /= embed_dim / 2.0
- omega = 1.0 / 10000**omega # (D/2,)
-
- pos = pos.reshape(-1) # (M,)
- out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
-
- emb_sin = np.sin(out) # (M, D/2)
- emb_cos = np.cos(out) # (M, D/2)
-
- emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
- return emb
diff --git a/videosys/modules/layers.py b/videosys/modules/layers.py
deleted file mode 100644
index b717ba12d1685e1ed9e843bb6907db044a229824..0000000000000000000000000000000000000000
--- a/videosys/modules/layers.py
+++ /dev/null
@@ -1,80 +0,0 @@
-# Modified from Meta DiT
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-# --------------------------------------------------------
-# References:
-# DiT: https://github.com/facebookresearch/DiT/tree/main
-# GLIDE: https://github.com/openai/glide-text2im
-# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
-# --------------------------------------------------------
-
-
-import torch
-import torch.nn as nn
-import torch.utils.checkpoint
-
-
-def get_layernorm(hidden_size: torch.Tensor, eps: float, affine: bool, use_kernel: bool):
- if use_kernel:
- try:
- from apex.normalization import FusedLayerNorm
-
- return FusedLayerNorm(hidden_size, elementwise_affine=affine, eps=eps)
- except ImportError:
- raise RuntimeError("FusedLayerNorm not available. Please install apex.")
- else:
- return nn.LayerNorm(hidden_size, eps, elementwise_affine=affine)
-
-
-def modulate(norm_func, x, shift, scale, use_kernel=False):
- # Suppose x is (N, T, D), shift is (N, D), scale is (N, D)
- dtype = x.dtype
- x = norm_func(x.to(torch.float32)).to(dtype)
- if use_kernel:
- try:
- from videosys.kernels.fused_modulate import fused_modulate
-
- x = fused_modulate(x, scale, shift)
- except ImportError:
- raise RuntimeError("FusedModulate kernel not available. Please install triton.")
- else:
- x = x * (scale.unsqueeze(1) + 1) + shift.unsqueeze(1)
- x = x.to(dtype)
-
- return x
-
-
-class FinalLayer(nn.Module):
- """
- The final layer of DiT.
- """
-
- def __init__(self, hidden_size, patch_size, out_channels):
- super().__init__()
- self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
- self.linear = nn.Linear(hidden_size, patch_size * out_channels, bias=True)
- self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
-
- def forward(self, x, c):
- shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
- x = modulate(self.norm_final, x, shift, scale)
- x = self.linear(x)
- return x
-
-
-class LlamaRMSNorm(nn.Module):
- def __init__(self, hidden_size, eps=1e-6):
- """
- LlamaRMSNorm is equivalent to T5LayerNorm
- """
- super().__init__()
- self.weight = nn.Parameter(torch.ones(hidden_size))
- self.variance_epsilon = eps
-
- def forward(self, hidden_states):
- input_dtype = hidden_states.dtype
- hidden_states = hidden_states.to(torch.float32)
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
- return self.weight * hidden_states.to(input_dtype)
diff --git a/videosys/modules/__init__.py b/videosys/pipelines/__init__.py
similarity index 100%
rename from videosys/modules/__init__.py
rename to videosys/pipelines/__init__.py
diff --git a/videosys/pipelines/cogvideox/__init__.py b/videosys/pipelines/cogvideox/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc64485945b9dd5e02f14229131f3bf58d8e03e3
--- /dev/null
+++ b/videosys/pipelines/cogvideox/__init__.py
@@ -0,0 +1,3 @@
+from .pipeline_cogvideox import CogVideoXConfig, CogVideoXPABConfig, CogVideoXPipeline
+
+__all__ = ["CogVideoXConfig", "CogVideoXPipeline", "CogVideoXPABConfig"]
diff --git a/videosys/models/cogvideo/pipeline.py b/videosys/pipelines/cogvideox/pipeline_cogvideox.py
similarity index 78%
rename from videosys/models/cogvideo/pipeline.py
rename to videosys/pipelines/cogvideox/pipeline_cogvideox.py
index 5334d9289a8df333dce27608a659b137d2b798eb..4c70fe8d4ab3ead4d591d5fb1765a0400881b9f4 100644
--- a/videosys/models/cogvideo/pipeline.py
+++ b/videosys/pipelines/cogvideox/pipeline_cogvideox.py
@@ -14,100 +14,52 @@ from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
-from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
from transformers import T5EncoderModel, T5Tokenizer
+from videosys.core.pab_mgr import PABConfig, set_pab_manager, update_steps
from videosys.core.pipeline import VideoSysPipeline, VideoSysPipelineOutput
+from videosys.models.autoencoders.autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
+from videosys.models.modules.embeddings import get_3d_rotary_pos_embed
+from videosys.models.transformers.cogvideox_transformer_3d import CogVideoXTransformer3DModel
+from videosys.schedulers.scheduling_ddim_cogvideox import CogVideoXDDIMScheduler
+from videosys.schedulers.scheduling_dpm_cogvideox import CogVideoXDPMScheduler
+from videosys.utils.logging import logger
from videosys.utils.utils import save_video
-from .autoencoder_kl import AutoencoderKLCogVideoX
-from .cogvideox_transformer_3d import CogVideoXTransformer3DModel
-from .retrieve_timesteps import retrieve_timesteps
-from .scheduling import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
-logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-
-from videosys.core.pab_mgr import (
- PABConfig,
- get_diffusion_skip,
- get_diffusion_skip_timestep,
- set_pab_manager,
- skip_diffusion_timestep,
- update_steps,
-)
-
-
-
-class CogVideoPABConfig(PABConfig):
+class CogVideoXPABConfig(PABConfig):
def __init__(
self,
- steps: int = 150,
+ steps: int = 50,
spatial_broadcast: bool = True,
spatial_threshold: list = [100, 850],
- spatial_gap: int = 2,
- temporal_broadcast: bool = True,
- temporal_threshold: list = [100, 850],
- temporal_gap: int = 4,
- cross_broadcast: bool = True,
- cross_threshold: list = [100, 850],
- cross_gap: int = 6,
- diffusion_skip: bool = False,
- diffusion_timestep_respacing: list = None,
- diffusion_skip_timestep: list = None,
- mlp_skip: bool = True,
- mlp_spatial_skip_config: dict = {
- 738: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
- 714: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
- },
- mlp_temporal_skip_config: dict = {
- 738: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
- 714: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
- },
- full_broadcast: bool = True,
- full_threshold: list = [100, 850],
- full_gap: int = 3,
+ spatial_range: int = 2,
):
super().__init__(
steps=steps,
spatial_broadcast=spatial_broadcast,
spatial_threshold=spatial_threshold,
- spatial_gap=spatial_gap,
- temporal_broadcast=temporal_broadcast,
- temporal_threshold=temporal_threshold,
- temporal_gap=temporal_gap,
- cross_broadcast=cross_broadcast,
- cross_threshold=cross_threshold,
- cross_gap=cross_gap,
- diffusion_skip=diffusion_skip,
- diffusion_timestep_respacing=diffusion_timestep_respacing,
- diffusion_skip_timestep=diffusion_skip_timestep,
- mlp_skip=mlp_skip,
- mlp_spatial_skip_config=mlp_spatial_skip_config,
- mlp_temporal_skip_config=mlp_temporal_skip_config,
- full_broadcast=full_broadcast,
- full_threshold=full_threshold,
- full_gap=full_gap,
+ spatial_range=spatial_range,
)
-
-class CogVideoConfig:
+class CogVideoXConfig:
def __init__(
self,
- world_size: int = 1,
model_path: str = "THUDM/CogVideoX-2b",
+ world_size: int = 1,
num_inference_steps: int = 50,
guidance_scale: float = 6.0,
enable_pab: bool = False,
- pab_config = CogVideoPABConfig()
+ pab_config=CogVideoXPABConfig(),
):
# ======= engine ========
self.world_size = world_size
# ======= pipeline ========
- self.pipeline_cls = CogVideoPipeline
+ self.pipeline_cls = CogVideoXPipeline
# ======= model ========
self.model_path = model_path
@@ -117,7 +69,7 @@ class CogVideoConfig:
self.pab_config = pab_config
-class CogVideoPipeline(VideoSysPipeline):
+class CogVideoXPipeline(VideoSysPipeline):
_callback_tensor_inputs = [
"latents",
"prompt_embeds",
@@ -126,7 +78,7 @@ class CogVideoPipeline(VideoSysPipeline):
def __init__(
self,
- config: CogVideoConfig,
+ config: CogVideoXConfig,
tokenizer: Optional[T5Tokenizer] = None,
text_encoder: Optional[T5EncoderModel] = None,
vae: Optional[AutoencoderKLCogVideoX] = None,
@@ -165,10 +117,10 @@ class CogVideoPipeline(VideoSysPipeline):
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
+ # pab
if config.enable_pab:
set_pab_manager(config.pab_config)
-
self.vae_scale_factor_spatial = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
@@ -185,7 +137,7 @@ class CogVideoPipeline(VideoSysPipeline):
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
- device = device or self._device
+ device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
@@ -257,7 +209,7 @@ class CogVideoPipeline(VideoSysPipeline):
dtype: (`torch.dtype`, *optional*):
torch dtype
"""
- device = device or self._device
+ device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
@@ -323,36 +275,13 @@ class CogVideoPipeline(VideoSysPipeline):
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
- torch.cuda.empty_cache()
return latents
-
- def decode_latents(self, latents: torch.Tensor, num_seconds: int):
- print("hhhhhhhh")
+
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
latents = 1 / self.vae.config.scaling_factor * latents
- frames = []
- num_frames = latents.size(2)
- segment_size = num_frames // num_frames # 每段处理的帧数
-
- for i in range(num_frames): # 显存问题,逐帧处理
- start_frame = i * segment_size
- end_frame = start_frame + segment_size if i < num_frames-1 else num_frames
-
- current_latents = latents[:, :, start_frame:end_frame, :, :]
- try:
- current_frames = self.vae.decode(current_latents).sample
- frames.append(current_frames)
- except RuntimeError as e:
- logger.error(f"CUDA out of memory error: {str(e)}")
- raise e
-
- # 清理缓存
- torch.cuda.empty_cache()
-
- self.vae.clear_fake_context_parallel_cache()
-
- frames = torch.cat(frames, dim=2)
+ frames = self.vae.decode(latents).sample
return frames
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
@@ -425,6 +354,46 @@ class CogVideoPipeline(VideoSysPipeline):
f" {negative_prompt_embeds.shape}."
)
+ def fuse_qkv_projections(self) -> None:
+ r"""Enables fused QKV projections."""
+ self.fusing_transformer = True
+ self.transformer.fuse_qkv_projections()
+
+ def unfuse_qkv_projections(self) -> None:
+ r"""Disable QKV projection fusion if enabled."""
+ if not self.fusing_transformer:
+ logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
+ else:
+ self.transformer.unfuse_qkv_projections()
+ self.fusing_transformer = False
+
+ def _prepare_rotary_positional_embeddings(
+ self,
+ height: int,
+ width: int,
+ num_frames: int,
+ device: torch.device,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+ grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+ base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+ base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+
+ grid_crops_coords = get_resize_crop_region_for_grid(
+ (grid_height, grid_width), base_size_width, base_size_height
+ )
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+ embed_dim=self.transformer.config.attention_head_dim,
+ crops_coords=grid_crops_coords,
+ grid_size=(grid_height, grid_width),
+ temporal_size=num_frames,
+ use_real=True,
+ )
+
+ freqs_cos = freqs_cos.to(device=device)
+ freqs_sin = freqs_sin.to(device=device)
+ return freqs_cos, freqs_sin
+
@property
def guidance_scale(self):
return self._guidance_scale
@@ -444,7 +413,7 @@ class CogVideoPipeline(VideoSysPipeline):
negative_prompt: Optional[Union[str, List[str]]] = None,
height: int = 480,
width: int = 720,
- num_frames: int = 48,
+ num_frames: int = 49,
num_inference_steps: int = 50,
timesteps: Optional[List[int]] = None,
guidance_scale: float = 6,
@@ -538,10 +507,12 @@ class CogVideoPipeline(VideoSysPipeline):
[`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""
- fps = 8
- assert (
- num_frames <= 48 and num_frames % fps == 0 and fps == 8
- ), f"The number of frames must be divisible by {fps=} and less than 48 frames (for now). Other values are not supported in CogVideoX."
+
+ if num_frames > 49:
+ raise ValueError(
+ "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
+ )
+ update_steps(num_inference_steps)
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
@@ -598,7 +569,6 @@ class CogVideoPipeline(VideoSysPipeline):
# 5. Prepare latents.
latent_channels = self.transformer.config.in_channels
- num_frames += 1
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
latent_channels,
@@ -611,10 +581,17 @@ class CogVideoPipeline(VideoSysPipeline):
latents,
)
- # 6. Prepare extra step kwargs.
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
- # 7. Denoising loop
+ # 7. Create rotary embeds if required
+ image_rotary_emb = (
+ self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
+ if self.transformer.config.use_rotary_positional_embeddings
+ else None
+ )
+
+ # 8. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -635,6 +612,7 @@ class CogVideoPipeline(VideoSysPipeline):
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
+ image_rotary_emb=image_rotary_emb,
return_dict=False,
)[0]
noise_pred = noise_pred.float()
@@ -678,11 +656,14 @@ class CogVideoPipeline(VideoSysPipeline):
progress_bar.update()
if not output_type == "latent":
- video = self.decode_latents(latents, num_frames // fps)
+ video = self.decode_latents(latents)
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
else:
video = latents
+ # Offload all models
+ self.maybe_free_model_hooks()
+
if not return_dict:
return (video,)
@@ -690,3 +671,82 @@ class CogVideoPipeline(VideoSysPipeline):
def save_video(self, video, output_path):
save_video(video, output_path, fps=8)
+
+
+# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
+def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
+ tw = tgt_width
+ th = tgt_height
+ h, w = src
+ r = h / w
+ if r > (th / tw):
+ resize_height = th
+ resize_width = int(round(th / h * w))
+ else:
+ resize_width = tw
+ resize_height = int(round(tw / w * h))
+
+ crop_top = int(round((th - resize_height) / 2.0))
+ crop_left = int(round((tw - resize_width) / 2.0))
+
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
diff --git a/videosys/pipelines/latte/__init__.py b/videosys/pipelines/latte/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..31884084dbc53ab3a55c28d35307eb57b9b30750
--- /dev/null
+++ b/videosys/pipelines/latte/__init__.py
@@ -0,0 +1,3 @@
+from .pipeline_latte import LatteConfig, LattePABConfig, LattePipeline
+
+__all__ = ["LatteConfig", "LattePipeline", "LattePABConfig"]
diff --git a/videosys/models/latte/pipeline.py b/videosys/pipelines/latte/pipeline_latte.py
similarity index 94%
rename from videosys/models/latte/pipeline.py
rename to videosys/pipelines/latte/pipeline_latte.py
index ad742a89db778539d53c9f20fa73f6c715e0c9e3..a70e8aeaad3d61bbdb6695a6d54f9f114d449cfe 100644
--- a/videosys/models/latte/pipeline.py
+++ b/videosys/pipelines/latte/pipeline_latte.py
@@ -25,20 +25,12 @@ from diffusers.schedulers import DDIMScheduler
from diffusers.utils.torch_utils import randn_tensor
from transformers import T5EncoderModel, T5Tokenizer
-from videosys.core.pab_mgr import (
- PABConfig,
- get_diffusion_skip,
- get_diffusion_skip_timestep,
- set_pab_manager,
- skip_diffusion_timestep,
- update_steps,
-)
+from videosys.core.pab_mgr import PABConfig, set_pab_manager, update_steps
from videosys.core.pipeline import VideoSysPipeline, VideoSysPipelineOutput
+from videosys.models.transformers.latte_transformer_3d import LatteT2V
from videosys.utils.logging import logger
from videosys.utils.utils import save_video
-from .latte_t2v import LatteT2V
-
class LattePABConfig(PABConfig):
def __init__(
@@ -46,25 +38,22 @@ class LattePABConfig(PABConfig):
steps: int = 50,
spatial_broadcast: bool = True,
spatial_threshold: list = [100, 800],
- spatial_gap: int = 2,
+ spatial_range: int = 2,
temporal_broadcast: bool = True,
temporal_threshold: list = [100, 800],
- temporal_gap: int = 3,
+ temporal_range: int = 3,
cross_broadcast: bool = True,
cross_threshold: list = [100, 800],
- cross_gap: int = 6,
- diffusion_skip: bool = False,
- diffusion_timestep_respacing: list = None,
- diffusion_skip_timestep: list = None,
- mlp_skip: bool = True,
- mlp_spatial_skip_config: dict = {
+ cross_range: int = 6,
+ mlp_broadcast: bool = True,
+ mlp_spatial_broadcast_config: dict = {
720: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
640: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
560: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
480: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
400: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
},
- mlp_temporal_skip_config: dict = {
+ mlp_temporal_broadcast_config: dict = {
720: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
640: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
560: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
@@ -76,27 +65,24 @@ class LattePABConfig(PABConfig):
steps=steps,
spatial_broadcast=spatial_broadcast,
spatial_threshold=spatial_threshold,
- spatial_gap=spatial_gap,
+ spatial_range=spatial_range,
temporal_broadcast=temporal_broadcast,
temporal_threshold=temporal_threshold,
- temporal_gap=temporal_gap,
+ temporal_range=temporal_range,
cross_broadcast=cross_broadcast,
cross_threshold=cross_threshold,
- cross_gap=cross_gap,
- diffusion_skip=diffusion_skip,
- diffusion_timestep_respacing=diffusion_timestep_respacing,
- diffusion_skip_timestep=diffusion_skip_timestep,
- mlp_skip=mlp_skip,
- mlp_spatial_skip_config=mlp_spatial_skip_config,
- mlp_temporal_skip_config=mlp_temporal_skip_config,
+ cross_range=cross_range,
+ mlp_broadcast=mlp_broadcast,
+ mlp_spatial_broadcast_config=mlp_spatial_broadcast_config,
+ mlp_temporal_broadcast_config=mlp_temporal_broadcast_config,
)
class LatteConfig:
def __init__(
self,
- world_size: int = 1,
model_path: str = "maxin-cn/Latte-1",
+ world_size: int = 1,
enable_vae_temporal_decoder: bool = True,
# ======= scheduler ========
beta_start: float = 0.0001,
@@ -738,33 +724,7 @@ class LattePipeline(VideoSysPipeline):
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
- # timesteps = self.scheduler.timesteps # NOTE change timestep_respacing here
-
- if get_diffusion_skip() and get_diffusion_skip_timestep() is not None:
- # TODO add assertion for timestep_respacing
- # timestep_respacing = get_diffusion_skip_timestep()
- # timesteps = space_timesteps(1000, timestep_respacing)
-
- diffusion_skip_timestep = get_diffusion_skip_timestep()
- timesteps = skip_diffusion_timestep(self.scheduler.timesteps, diffusion_skip_timestep)
-
- self.scheduler.set_timesteps(num_inference_steps, device=device)
- orignal_timesteps = self.scheduler.timesteps
-
- if verbose and dist.get_rank() == 0:
- print("============================")
- print("skip diffusion steps!!!")
- print("============================")
- print(f"orignal sample timesteps: {orignal_timesteps}")
- print(f"orignal diffusion steps: {len(orignal_timesteps)}")
- print("============================")
- print(f"skip diffusion steps: {get_diffusion_skip_timestep()}")
- print(f"sample timesteps: {timesteps}")
- print(f"num_inference_steps: {len(timesteps)}")
- print("============================")
- else:
- self.scheduler.set_timesteps(num_inference_steps, device=device)
- timesteps = self.scheduler.timesteps
+ timesteps = self.scheduler.timesteps
# 5. Prepare latents.
latent_channels = self.transformer.config.in_channels
diff --git a/videosys/pipelines/open_sora/__init__.py b/videosys/pipelines/open_sora/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..46536c737c316a966a6c0603e5b75fa564b14ce2
--- /dev/null
+++ b/videosys/pipelines/open_sora/__init__.py
@@ -0,0 +1,3 @@
+from .pipeline_open_sora import OpenSoraConfig, OpenSoraPABConfig, OpenSoraPipeline
+
+__all__ = ["OpenSoraConfig", "OpenSoraPipeline", "OpenSoraPABConfig"]
diff --git a/videosys/models/open_sora/datasets.py b/videosys/pipelines/open_sora/data_process.py
similarity index 95%
rename from videosys/models/open_sora/datasets.py
rename to videosys/pipelines/open_sora/data_process.py
index a75a711cfb55ff7b3802b7a596101712dc266de5..cd14da9e4b8dea29d064d70790cc8046afd1fa88 100644
--- a/videosys/models/open_sora/datasets.py
+++ b/videosys/pipelines/open_sora/data_process.py
@@ -786,3 +786,22 @@ def read_image_from_path(path, transform=None, transform_name="center", num_fram
video = image.unsqueeze(0).repeat(num_frames, 1, 1, 1)
video = video.permute(1, 0, 2, 3)
return video
+
+
+def prepare_multi_resolution_info(info_type, batch_size, image_size, num_frames, fps, device, dtype):
+ if info_type is None:
+ return dict()
+ elif info_type == "PixArtMS":
+ hw = torch.tensor([image_size], device=device, dtype=dtype).repeat(batch_size, 1)
+ ar = torch.tensor([[image_size[0] / image_size[1]]], device=device, dtype=dtype).repeat(batch_size, 1)
+ return dict(ar=ar, hw=hw)
+ elif info_type in ["STDiT2", "OpenSora"]:
+ fps = fps if num_frames > 1 else IMG_FPS
+ fps = torch.tensor([fps], device=device, dtype=dtype).repeat(batch_size)
+ height = torch.tensor([image_size[0]], device=device, dtype=dtype).repeat(batch_size)
+ width = torch.tensor([image_size[1]], device=device, dtype=dtype).repeat(batch_size)
+ num_frames = torch.tensor([num_frames], device=device, dtype=dtype).repeat(batch_size)
+ ar = torch.tensor([image_size[0] / image_size[1]], device=device, dtype=dtype).repeat(batch_size)
+ return dict(height=height, width=width, num_frames=num_frames, ar=ar, fps=fps)
+ else:
+ raise NotImplementedError
diff --git a/videosys/pipelines/open_sora/pipeline_open_sora.py b/videosys/pipelines/open_sora/pipeline_open_sora.py
new file mode 100644
index 0000000000000000000000000000000000000000..7931a95cae35788151944ec92350635e9bc7264d
--- /dev/null
+++ b/videosys/pipelines/open_sora/pipeline_open_sora.py
@@ -0,0 +1,898 @@
+import html
+import json
+import os
+import re
+from typing import Optional, Tuple, Union
+
+import ftfy
+import torch
+from diffusers.models import AutoencoderKL
+from transformers import AutoTokenizer, T5EncoderModel
+
+from videosys.core.pab_mgr import PABConfig, set_pab_manager
+from videosys.core.pipeline import VideoSysPipeline, VideoSysPipelineOutput
+from videosys.models.autoencoders.autoencoder_kl_open_sora import OpenSoraVAE_V1_2
+from videosys.models.transformers.open_sora_transformer_3d import STDiT3_XL_2
+from videosys.schedulers.scheduling_rflow_open_sora import RFLOW
+from videosys.utils.utils import save_video
+
+from .data_process import get_image_size, get_num_frames, prepare_multi_resolution_info, read_from_path
+
+os.environ["TOKENIZERS_PARALLELISM"] = "true"
+
+
+BAD_PUNCT_REGEX = re.compile(
+ r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
+) # noqa
+
+
+class OpenSoraPABConfig(PABConfig):
+ def __init__(
+ self,
+ steps: int = 50,
+ spatial_broadcast: bool = True,
+ spatial_threshold: list = [450, 930],
+ spatial_range: int = 2,
+ temporal_broadcast: bool = True,
+ temporal_threshold: list = [450, 930],
+ temporal_range: int = 4,
+ cross_broadcast: bool = True,
+ cross_threshold: list = [450, 930],
+ cross_range: int = 6,
+ mlp_broadcast: bool = True,
+ mlp_spatial_broadcast_config: dict = {
+ 676: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
+ 788: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
+ 864: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
+ },
+ mlp_temporal_broadcast_config: dict = {
+ 676: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
+ 788: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
+ 864: {"block": [0, 1, 2, 3, 4], "skip_count": 2},
+ },
+ ):
+ super().__init__(
+ steps=steps,
+ spatial_broadcast=spatial_broadcast,
+ spatial_threshold=spatial_threshold,
+ spatial_range=spatial_range,
+ temporal_broadcast=temporal_broadcast,
+ temporal_threshold=temporal_threshold,
+ temporal_range=temporal_range,
+ cross_broadcast=cross_broadcast,
+ cross_threshold=cross_threshold,
+ cross_range=cross_range,
+ mlp_broadcast=mlp_broadcast,
+ mlp_spatial_broadcast_config=mlp_spatial_broadcast_config,
+ mlp_temporal_broadcast_config=mlp_temporal_broadcast_config,
+ )
+
+
+class OpenSoraConfig:
+ def __init__(
+ self,
+ model_path: str = "hpcai-tech/OpenSora-STDiT-v3",
+ world_size: int = 1,
+ vae: str = "hpcai-tech/OpenSora-VAE-v1.2",
+ text_encoder: str = "DeepFloyd/t5-v1_1-xxl",
+ # ======= scheduler =======
+ num_sampling_steps: int = 30,
+ cfg_scale: float = 7.0,
+ # ======= vae ========
+ tiling_size: int = 4,
+ # ======= pab ========
+ enable_pab: bool = False,
+ pab_config: PABConfig = OpenSoraPABConfig(),
+ ):
+ # ======= engine ========
+ self.world_size = world_size
+
+ # ======= pipeline ========
+ self.pipeline_cls = OpenSoraPipeline
+ self.transformer = model_path
+ self.vae = vae
+ self.text_encoder = text_encoder
+
+ # ======= scheduler ========
+ self.num_sampling_steps = num_sampling_steps
+ self.cfg_scale = cfg_scale
+
+ # ======= vae ========
+ self.tiling_size = tiling_size
+
+ # ======= pab ========
+ self.enable_pab = enable_pab
+ self.pab_config = pab_config
+
+
+class OpenSoraPipeline(VideoSysPipeline):
+ r"""
+ Pipeline for text-to-image generation using PixArt-Alpha.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`T5EncoderModel`]):
+ Frozen text-encoder. PixArt-Alpha uses
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
+ [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
+ tokenizer (`T5Tokenizer`):
+ Tokenizer of class
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+ transformer ([`Transformer2DModel`]):
+ A text conditioned `Transformer2DModel` to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ """
+ bad_punct_regex = re.compile(
+ r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
+ ) # noqa
+
+ _optional_components = ["tokenizer", "text_encoder"]
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+
+ def __init__(
+ self,
+ config: OpenSoraConfig,
+ text_encoder: Optional[T5EncoderModel] = None,
+ tokenizer: Optional[AutoTokenizer] = None,
+ vae: Optional[AutoencoderKL] = None,
+ transformer: Optional[STDiT3_XL_2] = None,
+ scheduler: Optional[RFLOW] = None,
+ device: torch.device = torch.device("cuda"),
+ dtype: torch.dtype = torch.bfloat16,
+ ):
+ super().__init__()
+ self._config = config
+ self._device = device
+ self._dtype = dtype
+
+ # initialize the model if not provided
+ if text_encoder is None:
+ text_encoder = T5EncoderModel.from_pretrained(config.text_encoder).to(dtype)
+ if tokenizer is None:
+ tokenizer = AutoTokenizer.from_pretrained(config.text_encoder)
+ if vae is None:
+ vae = OpenSoraVAE_V1_2(
+ from_pretrained="hpcai-tech/OpenSora-VAE-v1.2",
+ micro_frame_size=17,
+ micro_batch_size=config.tiling_size,
+ ).to(dtype)
+ if transformer is None:
+ transformer = STDiT3_XL_2(
+ from_pretrained="hpcai-tech/OpenSora-STDiT-v3",
+ qk_norm=True,
+ enable_flash_attn=True,
+ enable_layernorm_kernel=True,
+ in_channels=vae.out_channels,
+ caption_channels=text_encoder.config.d_model,
+ model_max_length=300,
+ ).to(device, dtype)
+ if scheduler is None:
+ scheduler = RFLOW(
+ use_timestep_transform=True, num_sampling_steps=config.num_sampling_steps, cfg_scale=config.cfg_scale
+ )
+
+ # pab
+ if config.enable_pab:
+ set_pab_manager(config.pab_config)
+
+ # set eval and device
+ self.set_eval_and_device(device, text_encoder, vae, transformer)
+
+ self.register_modules(
+ text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler, tokenizer=tokenizer
+ )
+
+ def get_text_embeddings(self, texts):
+ text_tokens_and_mask = self.tokenizer(
+ texts,
+ max_length=300,
+ padding="max_length",
+ truncation=True,
+ return_attention_mask=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+
+ input_ids = text_tokens_and_mask["input_ids"].to(self.device)
+ attention_mask = text_tokens_and_mask["attention_mask"].to(self.device)
+ with torch.no_grad():
+ text_encoder_embs = self.text_encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ )["last_hidden_state"].detach()
+ return text_encoder_embs, attention_mask
+
+ def encode_prompt(self, text):
+ caption_embs, emb_masks = self.get_text_embeddings(text)
+ caption_embs = caption_embs[:, None]
+ return dict(y=caption_embs, mask=emb_masks)
+
+ def null_embed(self, n):
+ null_y = self.transformer.y_embedder.y_embedding[None].repeat(n, 1, 1)[:, None]
+ return null_y
+
+ @staticmethod
+ def _basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+ def _clean_caption(self, caption):
+ import urllib.parse as ul
+
+ from bs4 import BeautifulSoup
+
+ caption = str(caption)
+ caption = ul.unquote_plus(caption)
+ caption = caption.strip().lower()
+ caption = re.sub("", "person", caption)
+ # urls:
+ caption = re.sub(
+ r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ caption = re.sub(
+ r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ # html:
+ caption = BeautifulSoup(caption, features="html.parser").text
+
+ # @
+ caption = re.sub(r"@[\w\d]+\b", "", caption)
+
+ # 31C0—31EF CJK Strokes
+ # 31F0—31FF Katakana Phonetic Extensions
+ # 3200—32FF Enclosed CJK Letters and Months
+ # 3300—33FF CJK Compatibility
+ # 3400—4DBF CJK Unified Ideographs Extension A
+ # 4DC0—4DFF Yijing Hexagram Symbols
+ # 4E00—9FFF CJK Unified Ideographs
+ caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
+ caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
+ caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
+ caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
+ caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
+ caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
+ caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
+ #######################################################
+
+ # все виды тире / all types of dash --> "-"
+ caption = re.sub(
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
+ "-",
+ caption,
+ )
+
+ # кавычки к одному стандарту
+ caption = re.sub(r"[`´«»“”¨]", '"', caption)
+ caption = re.sub(r"[‘’]", "'", caption)
+
+ # "
+ caption = re.sub(r""?", "", caption)
+ # &
+ caption = re.sub(r"&", "", caption)
+
+ # ip adresses:
+ caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
+
+ # article ids:
+ caption = re.sub(r"\d:\d\d\s+$", "", caption)
+
+ # \n
+ caption = re.sub(r"\\n", " ", caption)
+
+ # "#123"
+ caption = re.sub(r"#\d{1,3}\b", "", caption)
+ # "#12345.."
+ caption = re.sub(r"#\d{5,}\b", "", caption)
+ # "123456.."
+ caption = re.sub(r"\b\d{6,}\b", "", caption)
+ # filenames:
+ caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
+
+ #
+ caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
+ caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
+
+ caption = re.sub(BAD_PUNCT_REGEX, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
+ caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
+
+ # this-is-my-cute-cat / this_is_my_cute_cat
+ regex2 = re.compile(r"(?:\-|\_)")
+ if len(re.findall(regex2, caption)) > 3:
+ caption = re.sub(regex2, " ", caption)
+
+ caption = self._basic_clean(caption)
+
+ caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
+ caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
+ caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
+
+ caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
+ caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
+ caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
+ caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
+ caption = re.sub(r"\bpage\s+\d+\b", "", caption)
+
+ caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
+
+ caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
+
+ caption = re.sub(r"\b\s+\:\s+", r": ", caption)
+ caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
+ caption = re.sub(r"\s+", " ", caption)
+
+ caption.strip()
+
+ caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
+ caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
+ caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
+ caption = re.sub(r"^\.\S+$", "", caption)
+
+ return caption.strip()
+
+ def text_preprocessing(self, text, use_text_preprocessing: bool = True):
+ if use_text_preprocessing:
+ # The exact text cleaning as was in the training stage:
+ text = self._clean_caption(text)
+ text = self._clean_caption(text)
+ return text
+ else:
+ return text.lower().strip()
+
+ @torch.no_grad()
+ def generate(
+ self,
+ prompt: str,
+ resolution="480p",
+ aspect_ratio="9:16",
+ num_frames: int = 51,
+ loop: int = 1,
+ llm_refine: bool = False,
+ negative_prompt: str = "",
+ ms: Optional[str] = "",
+ refs: Optional[str] = "",
+ aes: float = 6.5,
+ flow: Optional[float] = None,
+ camera_motion: Optional[float] = None,
+ condition_frame_length: int = 5,
+ align: int = 5,
+ condition_frame_edit: float = 0.0,
+ return_dict: bool = True,
+ verbose: bool = True,
+ ) -> Union[VideoSysPipelineOutput, Tuple]:
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ resolution (`str`, *optional*, defaults to `"480p"`):
+ The resolution of the generated video.
+ aspect_ratio (`str`, *optional*, defaults to `"9:16"`):
+ The aspect ratio of the generated video.
+ num_frames (`int`, *optional*, defaults to 51):
+ The number of frames to generate.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_inference_steps (`int`, *optional*, defaults to 100):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
+ timesteps are used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 7.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size):
+ The width in pixels of the generated image.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not
+ provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ clean_caption (`bool`, *optional*, defaults to `True`):
+ Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
+ be installed. If the dependencies are not installed, the embeddings will be created from the raw
+ prompt.
+ mask_feature (`bool` defaults to `True`): If set to `True`, the text embeddings will be masked.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
+ returned where the first element is a list with the generated images
+ """
+ # == basic ==
+ fps = 24
+ image_size = get_image_size(resolution, aspect_ratio)
+ num_frames = get_num_frames(num_frames)
+
+ # == prepare batch prompts ==
+ batch_prompts = [prompt]
+ ms = [ms]
+ refs = [refs]
+
+ # == get json from prompts ==
+ batch_prompts, refs, ms = extract_json_from_prompts(batch_prompts, refs, ms)
+
+ # == get reference for condition ==
+ refs = collect_references_batch(refs, self.vae, image_size)
+
+ # == multi-resolution info ==
+ model_args = prepare_multi_resolution_info(
+ "OpenSora", len(batch_prompts), image_size, num_frames, fps, self._device, self._dtype
+ )
+
+ # == process prompts step by step ==
+ # 0. split prompt
+ # each element in the list is [prompt_segment_list, loop_idx_list]
+ batched_prompt_segment_list = []
+ batched_loop_idx_list = []
+ for prompt in batch_prompts:
+ prompt_segment_list, loop_idx_list = split_prompt(prompt)
+ batched_prompt_segment_list.append(prompt_segment_list)
+ batched_loop_idx_list.append(loop_idx_list)
+
+ # 1. refine prompt by openai
+ # if llm_refine:
+ # only call openai API when
+ # 1. seq parallel is not enabled
+ # 2. seq parallel is enabled and the process is rank 0
+ # if not enable_sequence_parallelism or (enable_sequence_parallelism and coordinator.is_master()):
+ # for idx, prompt_segment_list in enumerate(batched_prompt_segment_list):
+ # batched_prompt_segment_list[idx] = refine_prompts_by_openai(prompt_segment_list)
+
+ # # sync the prompt if using seq parallel
+ # if enable_sequence_parallelism:
+ # coordinator.block_all()
+ # prompt_segment_length = [
+ # len(prompt_segment_list) for prompt_segment_list in batched_prompt_segment_list
+ # ]
+
+ # # flatten the prompt segment list
+ # batched_prompt_segment_list = [
+ # prompt_segment
+ # for prompt_segment_list in batched_prompt_segment_list
+ # for prompt_segment in prompt_segment_list
+ # ]
+
+ # # create a list of size equal to world size
+ # broadcast_obj_list = [batched_prompt_segment_list] * coordinator.world_size
+ # dist.broadcast_object_list(broadcast_obj_list, 0)
+
+ # # recover the prompt list
+ # batched_prompt_segment_list = []
+ # segment_start_idx = 0
+ # all_prompts = broadcast_obj_list[0]
+ # for num_segment in prompt_segment_length:
+ # batched_prompt_segment_list.append(
+ # all_prompts[segment_start_idx : segment_start_idx + num_segment]
+ # )
+ # segment_start_idx += num_segment
+
+ # 2. append score
+ for idx, prompt_segment_list in enumerate(batched_prompt_segment_list):
+ batched_prompt_segment_list[idx] = append_score_to_prompts(
+ prompt_segment_list,
+ aes=aes,
+ flow=flow,
+ camera_motion=camera_motion,
+ )
+
+ # 3. clean prompt with T5
+ for idx, prompt_segment_list in enumerate(batched_prompt_segment_list):
+ batched_prompt_segment_list[idx] = [self.text_preprocessing(prompt) for prompt in prompt_segment_list]
+
+ # 4. merge to obtain the final prompt
+ batch_prompts = []
+ for prompt_segment_list, loop_idx_list in zip(batched_prompt_segment_list, batched_loop_idx_list):
+ batch_prompts.append(merge_prompt(prompt_segment_list, loop_idx_list))
+
+ # == Iter over loop generation ==
+ video_clips = []
+ for loop_i in range(loop):
+ # == get prompt for loop i ==
+ batch_prompts_loop = extract_prompts_loop(batch_prompts, loop_i)
+
+ # == add condition frames for loop ==
+ if loop_i > 0:
+ refs, ms = append_generated(
+ self.vae, video_clips[-1], refs, ms, loop_i, condition_frame_length, condition_frame_edit
+ )
+
+ # == sampling ==
+ input_size = (num_frames, *image_size)
+ latent_size = self.vae.get_latent_size(input_size)
+ z = torch.randn(
+ len(batch_prompts), self.vae.out_channels, *latent_size, device=self._device, dtype=self._dtype
+ )
+ model_args.update(self.encode_prompt(batch_prompts_loop))
+ y_null = self.null_embed(len(batch_prompts_loop))
+
+ masks = apply_mask_strategy(z, refs, ms, loop_i, align=align)
+ samples = self.scheduler.sample(
+ self.transformer,
+ z=z,
+ model_args=model_args,
+ y_null=y_null,
+ device=self._device,
+ progress=verbose,
+ mask=masks,
+ )
+ samples = self.vae.decode(samples.to(self._dtype), num_frames=num_frames)
+ video_clips.append(samples)
+
+ for i in range(1, loop):
+ video_clips[i] = video_clips[i][:, dframe_to_frame(condition_frame_length) :]
+ video = torch.cat(video_clips, dim=1)
+
+ low, high = -1, 1
+ video.clamp_(min=low, max=high)
+ video.sub_(low).div_(max(high - low, 1e-5))
+ video = video.mul(255).add_(0.5).clamp_(0, 255).permute(0, 2, 3, 4, 1).to("cpu", torch.uint8)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return VideoSysPipelineOutput(video=video)
+
+ def save_video(self, video, output_path):
+ save_video(video, output_path, fps=24)
+
+
+def load_prompts(prompt_path, start_idx=None, end_idx=None):
+ with open(prompt_path, "r") as f:
+ prompts = [line.strip() for line in f.readlines()]
+ prompts = prompts[start_idx:end_idx]
+ return prompts
+
+
+def get_save_path_name(
+ save_dir,
+ sample_name=None, # prefix
+ sample_idx=None, # sample index
+ prompt=None, # used prompt
+ prompt_as_path=False, # use prompt as path
+ num_sample=1, # number of samples to generate for one prompt
+ k=None, # kth sample
+):
+ if sample_name is None:
+ sample_name = "" if prompt_as_path else "sample"
+ sample_name_suffix = prompt if prompt_as_path else f"_{sample_idx:04d}"
+ save_path = os.path.join(save_dir, f"{sample_name}{sample_name_suffix[:50]}")
+ if num_sample != 1:
+ save_path = f"{save_path}-{k}"
+ return save_path
+
+
+def get_eval_save_path_name(
+ save_dir,
+ id, # add id parameter
+ sample_name=None, # prefix
+ sample_idx=None, # sample index
+ prompt=None, # used prompt
+ prompt_as_path=False, # use prompt as path
+ num_sample=1, # number of samples to generate for one prompt
+ k=None, # kth sample
+):
+ if sample_name is None:
+ sample_name = "" if prompt_as_path else "sample"
+ save_path = os.path.join(save_dir, f"{id}")
+ if num_sample != 1:
+ save_path = f"{save_path}-{k}"
+ return save_path
+
+
+def append_score_to_prompts(prompts, aes=None, flow=None, camera_motion=None):
+ new_prompts = []
+ for prompt in prompts:
+ new_prompt = prompt
+ if aes is not None and "aesthetic score:" not in prompt:
+ new_prompt = f"{new_prompt} aesthetic score: {aes:.1f}."
+ if flow is not None and "motion score:" not in prompt:
+ new_prompt = f"{new_prompt} motion score: {flow:.1f}."
+ if camera_motion is not None and "camera motion:" not in prompt:
+ new_prompt = f"{new_prompt} camera motion: {camera_motion}."
+ new_prompts.append(new_prompt)
+ return new_prompts
+
+
+def extract_json_from_prompts(prompts, reference, mask_strategy):
+ ret_prompts = []
+ for i, prompt in enumerate(prompts):
+ parts = re.split(r"(?=[{])", prompt)
+ assert len(parts) <= 2, f"Invalid prompt: {prompt}"
+ ret_prompts.append(parts[0])
+ if len(parts) > 1:
+ additional_info = json.loads(parts[1])
+ for key in additional_info:
+ assert key in ["reference_path", "mask_strategy"], f"Invalid key: {key}"
+ if key == "reference_path":
+ reference[i] = additional_info[key]
+ elif key == "mask_strategy":
+ mask_strategy[i] = additional_info[key]
+ return ret_prompts, reference, mask_strategy
+
+
+def collect_references_batch(reference_paths, vae, image_size):
+ refs_x = [] # refs_x: [batch, ref_num, C, T, H, W]
+ for reference_path in reference_paths:
+ if reference_path == "":
+ refs_x.append([])
+ continue
+ ref_path = reference_path.split(";")
+ ref = []
+ for r_path in ref_path:
+ r = read_from_path(r_path, image_size, transform_name="resize_crop")
+ r_x = vae.encode(r.unsqueeze(0).to(vae.device, vae.dtype))
+ r_x = r_x.squeeze(0)
+ ref.append(r_x)
+ refs_x.append(ref)
+ return refs_x
+
+
+def extract_prompts_loop(prompts, num_loop):
+ ret_prompts = []
+ for prompt in prompts:
+ if prompt.startswith("|0|"):
+ prompt_list = prompt.split("|")[1:]
+ text_list = []
+ for i in range(0, len(prompt_list), 2):
+ start_loop = int(prompt_list[i])
+ text = prompt_list[i + 1]
+ end_loop = int(prompt_list[i + 2]) if i + 2 < len(prompt_list) else num_loop + 1
+ text_list.extend([text] * (end_loop - start_loop))
+ prompt = text_list[num_loop]
+ ret_prompts.append(prompt)
+ return ret_prompts
+
+
+def split_prompt(prompt_text):
+ if prompt_text.startswith("|0|"):
+ # this is for prompts which look like
+ # |0| a beautiful day |1| a sunny day |2| a rainy day
+ # we want to parse it into a list of prompts with the loop index
+ prompt_list = prompt_text.split("|")[1:]
+ text_list = []
+ loop_idx = []
+ for i in range(0, len(prompt_list), 2):
+ start_loop = int(prompt_list[i])
+ text = prompt_list[i + 1].strip()
+ text_list.append(text)
+ loop_idx.append(start_loop)
+ return text_list, loop_idx
+ else:
+ return [prompt_text], None
+
+
+def merge_prompt(text_list, loop_idx_list=None):
+ if loop_idx_list is None:
+ return text_list[0]
+ else:
+ prompt = ""
+ for i, text in enumerate(text_list):
+ prompt += f"|{loop_idx_list[i]}|{text}"
+ return prompt
+
+
+MASK_DEFAULT = ["0", "0", "0", "0", "1", "0"]
+
+
+def parse_mask_strategy(mask_strategy):
+ mask_batch = []
+ if mask_strategy == "" or mask_strategy is None:
+ return mask_batch
+
+ mask_strategy = mask_strategy.split(";")
+ for mask in mask_strategy:
+ mask_group = mask.split(",")
+ num_group = len(mask_group)
+ assert num_group >= 1 and num_group <= 6, f"Invalid mask strategy: {mask}"
+ mask_group.extend(MASK_DEFAULT[num_group:])
+ for i in range(5):
+ mask_group[i] = int(mask_group[i])
+ mask_group[5] = float(mask_group[5])
+ mask_batch.append(mask_group)
+ return mask_batch
+
+
+def find_nearest_point(value, point, max_value):
+ t = value // point
+ if value % point > point / 2 and t < max_value // point - 1:
+ t += 1
+ return t * point
+
+
+def apply_mask_strategy(z, refs_x, mask_strategys, loop_i, align=None):
+ masks = []
+ no_mask = True
+ for i, mask_strategy in enumerate(mask_strategys):
+ no_mask = False
+ mask = torch.ones(z.shape[2], dtype=torch.float, device=z.device)
+ mask_strategy = parse_mask_strategy(mask_strategy)
+ for mst in mask_strategy:
+ loop_id, m_id, m_ref_start, m_target_start, m_length, edit_ratio = mst
+ if loop_id != loop_i:
+ continue
+ ref = refs_x[i][m_id]
+
+ if m_ref_start < 0:
+ # ref: [C, T, H, W]
+ m_ref_start = ref.shape[1] + m_ref_start
+ if m_target_start < 0:
+ # z: [B, C, T, H, W]
+ m_target_start = z.shape[2] + m_target_start
+ if align is not None:
+ m_ref_start = find_nearest_point(m_ref_start, align, ref.shape[1])
+ m_target_start = find_nearest_point(m_target_start, align, z.shape[2])
+ m_length = min(m_length, z.shape[2] - m_target_start, ref.shape[1] - m_ref_start)
+ z[i, :, m_target_start : m_target_start + m_length] = ref[:, m_ref_start : m_ref_start + m_length]
+ mask[m_target_start : m_target_start + m_length] = edit_ratio
+ masks.append(mask)
+ if no_mask:
+ return None
+ masks = torch.stack(masks)
+ return masks
+
+
+def append_generated(vae, generated_video, refs_x, mask_strategy, loop_i, condition_frame_length, condition_frame_edit):
+ ref_x = vae.encode(generated_video)
+ for j, refs in enumerate(refs_x):
+ if refs is None:
+ refs_x[j] = [ref_x[j]]
+ else:
+ refs.append(ref_x[j])
+ if mask_strategy[j] is None or mask_strategy[j] == "":
+ mask_strategy[j] = ""
+ else:
+ mask_strategy[j] += ";"
+ mask_strategy[
+ j
+ ] += f"{loop_i},{len(refs)-1},-{condition_frame_length},0,{condition_frame_length},{condition_frame_edit}"
+ return refs_x, mask_strategy
+
+
+def dframe_to_frame(num):
+ assert num % 5 == 0, f"Invalid num: {num}"
+ return num // 5 * 17
+
+
+OPENAI_CLIENT = None
+REFINE_PROMPTS = None
+REFINE_PROMPTS_PATH = "assets/texts/t2v_pllava.txt"
+REFINE_PROMPTS_TEMPLATE = """
+You need to refine user's input prompt. The user's input prompt is used for video generation task. You need to refine the user's prompt to make it more suitable for the task. Here are some examples of refined prompts:
+{}
+
+The refined prompt should pay attention to all objects in the video. The description should be useful for AI to re-generate the video. The description should be no more than six sentences. The refined prompt should be in English.
+"""
+RANDOM_PROMPTS = None
+RANDOM_PROMPTS_TEMPLATE = """
+You need to generate one input prompt for video generation task. The prompt should be suitable for the task. Here are some examples of refined prompts:
+{}
+
+The prompt should pay attention to all objects in the video. The description should be useful for AI to re-generate the video. The description should be no more than six sentences. The prompt should be in English.
+"""
+
+
+def get_openai_response(sys_prompt, usr_prompt, model="gpt-4o"):
+ global OPENAI_CLIENT
+ if OPENAI_CLIENT is None:
+ from openai import OpenAI
+
+ OPENAI_CLIENT = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
+
+ completion = OPENAI_CLIENT.chat.completions.create(
+ model=model,
+ messages=[
+ {
+ "role": "system",
+ "content": sys_prompt,
+ }, # <-- This is the system message that provides context to the model
+ {
+ "role": "user",
+ "content": usr_prompt,
+ }, # <-- This is the user message for which the model will generate a response
+ ],
+ )
+
+ return completion.choices[0].message.content
+
+
+def get_random_prompt_by_openai():
+ global RANDOM_PROMPTS
+ if RANDOM_PROMPTS is None:
+ examples = load_prompts(REFINE_PROMPTS_PATH)
+ RANDOM_PROMPTS = RANDOM_PROMPTS_TEMPLATE.format("\n".join(examples))
+
+ response = get_openai_response(RANDOM_PROMPTS, "Generate one example.")
+ return response
+
+
+def refine_prompt_by_openai(prompt):
+ global REFINE_PROMPTS
+ if REFINE_PROMPTS is None:
+ examples = load_prompts(REFINE_PROMPTS_PATH)
+ REFINE_PROMPTS = REFINE_PROMPTS_TEMPLATE.format("\n".join(examples))
+
+ response = get_openai_response(REFINE_PROMPTS, prompt)
+ return response
+
+
+def has_openai_key():
+ return "OPENAI_API_KEY" in os.environ
+
+
+def refine_prompts_by_openai(prompts):
+ new_prompts = []
+ for prompt in prompts:
+ try:
+ if prompt.strip() == "":
+ new_prompt = get_random_prompt_by_openai()
+ print(f"[Info] Empty prompt detected, generate random prompt: {new_prompt}")
+ else:
+ new_prompt = refine_prompt_by_openai(prompt)
+ print(f"[Info] Refine prompt: {prompt} -> {new_prompt}")
+ new_prompts.append(new_prompt)
+ except Exception as e:
+ print(f"[Warning] Failed to refine prompt: {prompt} due to {e}")
+ new_prompts.append(prompt)
+ return new_prompts
+
+
+def add_watermark(
+ input_video_path, watermark_image_path="./assets/images/watermark/watermark.png", output_video_path=None
+):
+ # execute this command in terminal with subprocess
+ # return if the process is successful
+ if output_video_path is None:
+ output_video_path = input_video_path.replace(".mp4", "_watermark.mp4")
+ cmd = f'ffmpeg -y -i {input_video_path} -i {watermark_image_path} -filter_complex "[1][0]scale2ref=oh*mdar:ih*0.1[logo][video];[video][logo]overlay" {output_video_path}'
+ exit_code = os.system(cmd)
+ is_success = exit_code == 0
+ return is_success
diff --git a/videosys/pipelines/open_sora_plan/__init__.py b/videosys/pipelines/open_sora_plan/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a1ddb8e63b7e16899daa5c2f906d8fb52ea0c5f
--- /dev/null
+++ b/videosys/pipelines/open_sora_plan/__init__.py
@@ -0,0 +1,3 @@
+from .pipeline_open_sora_plan import OpenSoraPlanConfig, OpenSoraPlanPABConfig, OpenSoraPlanPipeline
+
+__all__ = ["OpenSoraPlanConfig", "OpenSoraPlanPipeline", "OpenSoraPlanPABConfig"]
diff --git a/videosys/models/open_sora_plan/pipeline.py b/videosys/pipelines/open_sora_plan/pipeline_open_sora_plan.py
similarity index 96%
rename from videosys/models/open_sora_plan/pipeline.py
rename to videosys/pipelines/open_sora_plan/pipeline_open_sora_plan.py
index 004486d671306fe769713195ab438c15c62d72c0..530e966245d597bcad2b48f88ee0f067c38dd7f2 100644
--- a/videosys/models/open_sora_plan/pipeline.py
+++ b/videosys/pipelines/open_sora_plan/pipeline_open_sora_plan.py
@@ -24,20 +24,13 @@ from diffusers.schedulers import PNDMScheduler
from diffusers.utils.torch_utils import randn_tensor
from transformers import T5EncoderModel, T5Tokenizer
-from videosys.core.pab_mgr import (
- PABConfig,
- get_diffusion_skip,
- get_diffusion_skip_timestep,
- set_pab_manager,
- skip_diffusion_timestep,
- update_steps,
-)
+from videosys.core.pab_mgr import PABConfig, set_pab_manager, update_steps
from videosys.core.pipeline import VideoSysPipeline, VideoSysPipelineOutput
from videosys.utils.logging import logger
from videosys.utils.utils import save_video
-from .ae import ae_stride_config, getae_wrapper
-from .latte import LatteT2V
+from ...models.autoencoders.autoencoder_kl_open_sora_plan import ae_stride_config, getae_wrapper
+from ...models.transformers.open_sora_plan_transformer_3d import LatteT2V
EXAMPLE_DOC_STRING = """
Examples:
@@ -62,18 +55,15 @@ class OpenSoraPlanPABConfig(PABConfig):
steps: int = 150,
spatial_broadcast: bool = True,
spatial_threshold: list = [100, 850],
- spatial_gap: int = 2,
+ spatial_range: int = 2,
temporal_broadcast: bool = True,
temporal_threshold: list = [100, 850],
- temporal_gap: int = 4,
+ temporal_range: int = 4,
cross_broadcast: bool = True,
cross_threshold: list = [100, 850],
- cross_gap: int = 6,
- diffusion_skip: bool = False,
- diffusion_timestep_respacing: list = None,
- diffusion_skip_timestep: list = None,
- mlp_skip: bool = True,
- mlp_spatial_skip_config: dict = {
+ cross_range: int = 6,
+ mlp_broadcast: bool = True,
+ mlp_spatial_broadcast_config: dict = {
738: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
714: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
690: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
@@ -89,7 +79,7 @@ class OpenSoraPlanPABConfig(PABConfig):
450: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
426: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
},
- mlp_temporal_skip_config: dict = {
+ mlp_temporal_broadcast_config: dict = {
738: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
714: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
690: {"block": [0, 1, 2, 3, 4, 5, 6], "skip_count": 2},
@@ -110,27 +100,24 @@ class OpenSoraPlanPABConfig(PABConfig):
steps=steps,
spatial_broadcast=spatial_broadcast,
spatial_threshold=spatial_threshold,
- spatial_gap=spatial_gap,
+ spatial_range=spatial_range,
temporal_broadcast=temporal_broadcast,
temporal_threshold=temporal_threshold,
- temporal_gap=temporal_gap,
+ temporal_range=temporal_range,
cross_broadcast=cross_broadcast,
cross_threshold=cross_threshold,
- cross_gap=cross_gap,
- diffusion_skip=diffusion_skip,
- diffusion_timestep_respacing=diffusion_timestep_respacing,
- diffusion_skip_timestep=diffusion_skip_timestep,
- mlp_skip=mlp_skip,
- mlp_spatial_skip_config=mlp_spatial_skip_config,
- mlp_temporal_skip_config=mlp_temporal_skip_config,
+ cross_range=cross_range,
+ mlp_broadcast=mlp_broadcast,
+ mlp_spatial_broadcast_config=mlp_spatial_broadcast_config,
+ mlp_temporal_broadcast_config=mlp_temporal_broadcast_config,
)
class OpenSoraPlanConfig:
def __init__(
self,
- world_size: int = 1,
model_path: str = "LanguageBind/Open-Sora-Plan-v1.1.0",
+ world_size: int = 1,
num_frames: int = 65,
ae: str = "CausalVAEModel_4x8x8",
text_encoder: str = "DeepFloyd/t5-v1_1-xxl",
@@ -799,18 +786,6 @@ class OpenSoraPlanPipeline(VideoSysPipeline):
# 7. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
-
- if get_diffusion_skip() and get_diffusion_skip_timestep() is not None:
- diffusion_skip_timestep = get_diffusion_skip_timestep()
-
- # warmup_timesteps = timesteps[:num_warmup_steps]
- # after_warmup_timesteps = skip_diffusion_timestep(timesteps[num_warmup_steps:], diffusion_skip_timestep)
- # timesteps = torch.cat((warmup_timesteps, after_warmup_timesteps))
-
- timesteps = skip_diffusion_timestep(timesteps, diffusion_skip_timestep)
-
- self.scheduler.set_timesteps(num_inference_steps, device=device)
-
progress_wrap = tqdm.tqdm if verbose and dist.get_rank() == 0 else (lambda x: x)
for i, t in progress_wrap(list(enumerate(timesteps))):
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
diff --git a/videosys/schedulers/__init__.py b/videosys/schedulers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/videosys/schedulers/scheduling_ddim_cogvideox.py b/videosys/schedulers/scheduling_ddim_cogvideox.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed1c99dd4b5d5fe6143da0155acb1979704bae55
--- /dev/null
+++ b/videosys/schedulers/scheduling_ddim_cogvideox.py
@@ -0,0 +1,443 @@
+# Adapted from CogVideo
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# CogVideo: https://github.com/THUDM/CogVideo
+# diffusers: https://github.com/huggingface/diffusers
+# --------------------------------------------------------
+
+
+import math
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
+from diffusers.utils import BaseOutput
+
+
+@dataclass
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
+class DDIMSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's `step` function output.
+
+ Args:
+ prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+ The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
+ `pred_original_sample` can be used to preview progress or for guidance.
+ """
+
+ prev_sample: torch.Tensor
+ pred_original_sample: Optional[torch.Tensor] = None
+
+
+# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
+def betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ max_beta=0.999,
+ alpha_transform_type="cosine",
+):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
+ Choose from `cosine` or `exp`
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+ if alpha_transform_type == "cosine":
+
+ def alpha_bar_fn(t):
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ elif alpha_transform_type == "exp":
+
+ def alpha_bar_fn(t):
+ return math.exp(t * -12.0)
+
+ else:
+ raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
+ return torch.tensor(betas, dtype=torch.float32)
+
+
+def rescale_zero_terminal_snr(alphas_cumprod):
+ """
+ Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
+
+
+ Args:
+ betas (`torch.Tensor`):
+ the betas that the scheduler is being initialized with.
+
+ Returns:
+ `torch.Tensor`: rescaled betas with zero terminal SNR
+ """
+
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+ # Store old values.
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+ # Shift so the last timestep is zero.
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
+
+ # Scale so the first timestep is back to the old value.
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+ # Convert alphas_bar_sqrt to betas
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
+
+ return alphas_bar
+
+
+class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin):
+ """
+ `DDIMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
+ non-Markovian guidance.
+
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
+ methods the library implements for all schedulers such as loading and saving.
+
+ Args:
+ num_train_timesteps (`int`, defaults to 1000):
+ The number of diffusion steps to train the model.
+ beta_start (`float`, defaults to 0.0001):
+ The starting `beta` value of inference.
+ beta_end (`float`, defaults to 0.02):
+ The final `beta` value.
+ beta_schedule (`str`, defaults to `"linear"`):
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ trained_betas (`np.ndarray`, *optional*):
+ Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
+ clip_sample (`bool`, defaults to `True`):
+ Clip the predicted sample for numerical stability.
+ clip_sample_range (`float`, defaults to 1.0):
+ The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
+ set_alpha_to_one (`bool`, defaults to `True`):
+ Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
+ there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
+ otherwise it uses the alpha value at step 0.
+ steps_offset (`int`, defaults to 0):
+ An offset added to the inference steps, as required by some model families.
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
+ `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
+ Video](https://imagen.research.google/video/paper.pdf) paper).
+ thresholding (`bool`, defaults to `False`):
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
+ as Stable Diffusion.
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
+ sample_max_value (`float`, defaults to 1.0):
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
+ timestep_spacing (`str`, defaults to `"leading"`):
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
+ """
+
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.00085,
+ beta_end: float = 0.0120,
+ beta_schedule: str = "scaled_linear",
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ clip_sample: bool = True,
+ set_alpha_to_one: bool = True,
+ steps_offset: int = 0,
+ prediction_type: str = "epsilon",
+ clip_sample_range: float = 1.0,
+ sample_max_value: float = 1.0,
+ timestep_spacing: str = "leading",
+ rescale_betas_zero_snr: bool = False,
+ snr_shift_scale: float = 3.0,
+ ):
+ if trained_betas is not None:
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float64) ** 2
+ )
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+ # Modify: SNR shift following SD3
+ self.alphas_cumprod = self.alphas_cumprod / (snr_shift_scale + (1 - snr_shift_scale) * self.alphas_cumprod)
+
+ # Rescale for zero SNR
+ if rescale_betas_zero_snr:
+ self.alphas_cumprod = rescale_zero_terminal_snr(self.alphas_cumprod)
+
+ # At every step in ddim, we are looking into the previous alphas_cumprod
+ # For the final step, there is no previous alphas_cumprod because we are already at 0
+ # `set_alpha_to_one` decides whether we set this parameter simply to one or
+ # whether we use the final alpha of the "non-previous" one.
+ self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0
+
+ # setable values
+ self.num_inference_steps = None
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
+
+ def _get_variance(self, timestep, prev_timestep):
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+
+ return variance
+
+ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ sample (`torch.Tensor`):
+ The input sample.
+ timestep (`int`, *optional*):
+ The current timestep in the diffusion chain.
+
+ Returns:
+ `torch.Tensor`:
+ A scaled input sample.
+ """
+ return sample
+
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ """
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+
+ Args:
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model.
+ """
+
+ if num_inference_steps > self.config.num_train_timesteps:
+ raise ValueError(
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
+ f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+ f" maximal {self.config.num_train_timesteps} timesteps."
+ )
+
+ self.num_inference_steps = num_inference_steps
+
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+ if self.config.timestep_spacing == "linspace":
+ timesteps = (
+ np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
+ .round()[::-1]
+ .copy()
+ .astype(np.int64)
+ )
+ elif self.config.timestep_spacing == "leading":
+ step_ratio = self.config.num_train_timesteps // self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
+ timesteps += self.config.steps_offset
+ elif self.config.timestep_spacing == "trailing":
+ step_ratio = self.config.num_train_timesteps / self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
+ timesteps -= 1
+ else:
+ raise ValueError(
+ f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'."
+ )
+
+ self.timesteps = torch.from_numpy(timesteps).to(device)
+
+ def step(
+ self,
+ model_output: torch.Tensor,
+ timestep: int,
+ sample: torch.Tensor,
+ eta: float = 0.0,
+ use_clipped_model_output: bool = False,
+ generator=None,
+ variance_noise: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ ) -> Union[DDIMSchedulerOutput, Tuple]:
+ """
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.Tensor`):
+ The direct output from learned diffusion model.
+ timestep (`float`):
+ The current discrete timestep in the diffusion chain.
+ sample (`torch.Tensor`):
+ A current instance of a sample created by the diffusion process.
+ eta (`float`):
+ The weight of noise for added noise in diffusion step.
+ use_clipped_model_output (`bool`, defaults to `False`):
+ If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
+ because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
+ clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
+ `use_clipped_model_output` has no effect.
+ generator (`torch.Generator`, *optional*):
+ A random number generator.
+ variance_noise (`torch.Tensor`):
+ Alternative to generating noise with `generator` by directly providing the noise for the variance
+ itself. Useful for methods such as [`CycleDiffusion`].
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`.
+
+ Returns:
+ [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`:
+ If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a
+ tuple is returned where the first element is the sample tensor.
+
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
+ # Ideally, read DDIM paper in-detail understanding
+
+ # Notation ( ->
+ # - pred_noise_t -> e_theta(x_t, t)
+ # - pred_original_sample -> f_theta(x_t, t) or x_0
+ # - std_dev_t -> sigma_t
+ # - eta -> η
+ # - pred_sample_direction -> "direction pointing to x_t"
+ # - pred_prev_sample -> "x_t-1"
+
+ # 1. get previous step value (=t-1)
+ prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
+
+ # 2. compute alphas, betas
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+
+ beta_prod_t = 1 - alpha_prod_t
+
+ # 3. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ # To make style tests pass, commented out `pred_epsilon` as it is an unused variable
+ if self.config.prediction_type == "epsilon":
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+ # pred_epsilon = model_output
+ elif self.config.prediction_type == "sample":
+ pred_original_sample = model_output
+ # pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
+ elif self.config.prediction_type == "v_prediction":
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+ # pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
+ " `v_prediction`"
+ )
+
+ a_t = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5
+ b_t = alpha_prod_t_prev**0.5 - alpha_prod_t**0.5 * a_t
+
+ prev_sample = a_t * sample + b_t * pred_original_sample
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
+ def add_noise(
+ self,
+ original_samples: torch.Tensor,
+ noise: torch.Tensor,
+ timesteps: torch.IntTensor,
+ ) -> torch.Tensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+ # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
+ # for the subsequent add_noise calls
+ self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
+ alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
+ timesteps = timesteps.to(original_samples.device)
+
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
+ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as sample
+ self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
+ alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
+ timesteps = timesteps.to(sample.device)
+
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(sample.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
+ return velocity
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/videosys/models/cogvideo/scheduling.py b/videosys/schedulers/scheduling_dpm_cogvideox.py
similarity index 57%
rename from videosys/models/cogvideo/scheduling.py
rename to videosys/schedulers/scheduling_dpm_cogvideox.py
index 06a4e0f01f250060a5791a494a50c7186908b55b..3209dbe77fe84750577f4f0ebdd930576fe0d7e2 100644
--- a/videosys/models/cogvideo/scheduling.py
+++ b/videosys/schedulers/scheduling_dpm_cogvideox.py
@@ -6,9 +6,6 @@
# References:
# CogVideo: https://github.com/THUDM/CogVideo
# diffusers: https://github.com/huggingface/diffusers
-
-# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
-# and https://github.com/hojonathanho/diffusion
# --------------------------------------------------------
@@ -21,6 +18,7 @@ import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
from diffusers.utils import BaseOutput
+from diffusers.utils.torch_utils import randn_tensor
@dataclass
@@ -118,334 +116,6 @@ def rescale_zero_terminal_snr(alphas_cumprod):
return alphas_bar
-class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin):
- """
- `DDIMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
- non-Markovian guidance.
-
- This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
- methods the library implements for all schedulers such as loading and saving.
-
- Args:
- num_train_timesteps (`int`, defaults to 1000):
- The number of diffusion steps to train the model.
- beta_start (`float`, defaults to 0.0001):
- The starting `beta` value of inference.
- beta_end (`float`, defaults to 0.02):
- The final `beta` value.
- beta_schedule (`str`, defaults to `"linear"`):
- The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
- `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
- trained_betas (`np.ndarray`, *optional*):
- Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
- clip_sample (`bool`, defaults to `True`):
- Clip the predicted sample for numerical stability.
- clip_sample_range (`float`, defaults to 1.0):
- The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
- set_alpha_to_one (`bool`, defaults to `True`):
- Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
- there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
- otherwise it uses the alpha value at step 0.
- steps_offset (`int`, defaults to 0):
- An offset added to the inference steps, as required by some model families.
- prediction_type (`str`, defaults to `epsilon`, *optional*):
- Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
- `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
- Video](https://imagen.research.google/video/paper.pdf) paper).
- thresholding (`bool`, defaults to `False`):
- Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
- as Stable Diffusion.
- dynamic_thresholding_ratio (`float`, defaults to 0.995):
- The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
- sample_max_value (`float`, defaults to 1.0):
- The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
- timestep_spacing (`str`, defaults to `"leading"`):
- The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
- Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
- rescale_betas_zero_snr (`bool`, defaults to `False`):
- Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
- dark samples instead of limiting it to samples with medium brightness. Loosely related to
- [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
- """
-
- _compatibles = [e.name for e in KarrasDiffusionSchedulers]
- order = 1
-
- @register_to_config
- def __init__(
- self,
- num_train_timesteps: int = 1000,
- beta_start: float = 0.00085,
- beta_end: float = 0.0120,
- beta_schedule: str = "scaled_linear",
- trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
- clip_sample: bool = True,
- set_alpha_to_one: bool = True,
- steps_offset: int = 0,
- prediction_type: str = "epsilon",
- clip_sample_range: float = 1.0,
- sample_max_value: float = 1.0,
- timestep_spacing: str = "leading",
- rescale_betas_zero_snr: bool = False,
- snr_shift_scale: float = 3.0,
- ):
- if trained_betas is not None:
- self.betas = torch.tensor(trained_betas, dtype=torch.float32)
- elif beta_schedule == "linear":
- self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
- elif beta_schedule == "scaled_linear":
- # this schedule is very specific to the latent diffusion model.
- self.betas = (
- torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float64) ** 2
- )
- elif beta_schedule == "squaredcos_cap_v2":
- # Glide cosine schedule
- self.betas = betas_for_alpha_bar(num_train_timesteps)
- else:
- raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
-
- self.alphas = 1.0 - self.betas
- self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
-
- # Modify: SNR shift following SD3
- self.alphas_cumprod = self.alphas_cumprod / (snr_shift_scale + (1 - snr_shift_scale) * self.alphas_cumprod)
-
- # Rescale for zero SNR
- if rescale_betas_zero_snr:
- self.alphas_cumprod = rescale_zero_terminal_snr(self.alphas_cumprod)
-
- # At every step in ddim, we are looking into the previous alphas_cumprod
- # For the final step, there is no previous alphas_cumprod because we are already at 0
- # `set_alpha_to_one` decides whether we set this parameter simply to one or
- # whether we use the final alpha of the "non-previous" one.
- self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
-
- # standard deviation of the initial noise distribution
- self.init_noise_sigma = 1.0
-
- # setable values
- self.num_inference_steps = None
- self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
-
- def _get_variance(self, timestep, prev_timestep):
- alpha_prod_t = self.alphas_cumprod[timestep]
- alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
- beta_prod_t = 1 - alpha_prod_t
- beta_prod_t_prev = 1 - alpha_prod_t_prev
-
- variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
-
- return variance
-
- def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
- """
- Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
- current timestep.
-
- Args:
- sample (`torch.Tensor`):
- The input sample.
- timestep (`int`, *optional*):
- The current timestep in the diffusion chain.
-
- Returns:
- `torch.Tensor`:
- A scaled input sample.
- """
- return sample
-
- def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
- """
- Sets the discrete timesteps used for the diffusion chain (to be run before inference).
-
- Args:
- num_inference_steps (`int`):
- The number of diffusion steps used when generating samples with a pre-trained model.
- """
-
- if num_inference_steps > self.config.num_train_timesteps:
- raise ValueError(
- f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
- f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
- f" maximal {self.config.num_train_timesteps} timesteps."
- )
-
- self.num_inference_steps = num_inference_steps
-
- # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
- if self.config.timestep_spacing == "linspace":
- timesteps = (
- np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
- .round()[::-1]
- .copy()
- .astype(np.int64)
- )
- elif self.config.timestep_spacing == "leading":
- step_ratio = self.config.num_train_timesteps // self.num_inference_steps
- # creates integer timesteps by multiplying by ratio
- # casting to int to avoid issues when num_inference_step is power of 3
- timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
- timesteps += self.config.steps_offset
- elif self.config.timestep_spacing == "trailing":
- step_ratio = self.config.num_train_timesteps / self.num_inference_steps
- # creates integer timesteps by multiplying by ratio
- # casting to int to avoid issues when num_inference_step is power of 3
- timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
- timesteps -= 1
- else:
- raise ValueError(
- f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'."
- )
-
- self.timesteps = torch.from_numpy(timesteps).to(device)
-
- def step(
- self,
- model_output: torch.Tensor,
- timestep: int,
- sample: torch.Tensor,
- eta: float = 0.0,
- use_clipped_model_output: bool = False,
- generator=None,
- variance_noise: Optional[torch.Tensor] = None,
- return_dict: bool = True,
- ) -> Union[DDIMSchedulerOutput, Tuple]:
- """
- Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
- process from the learned model outputs (most often the predicted noise).
-
- Args:
- model_output (`torch.Tensor`):
- The direct output from learned diffusion model.
- timestep (`float`):
- The current discrete timestep in the diffusion chain.
- sample (`torch.Tensor`):
- A current instance of a sample created by the diffusion process.
- eta (`float`):
- The weight of noise for added noise in diffusion step.
- use_clipped_model_output (`bool`, defaults to `False`):
- If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
- because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
- clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
- `use_clipped_model_output` has no effect.
- generator (`torch.Generator`, *optional*):
- A random number generator.
- variance_noise (`torch.Tensor`):
- Alternative to generating noise with `generator` by directly providing the noise for the variance
- itself. Useful for methods such as [`CycleDiffusion`].
- return_dict (`bool`, *optional*, defaults to `True`):
- Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`.
-
- Returns:
- [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`:
- If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a
- tuple is returned where the first element is the sample tensor.
-
- """
- if self.num_inference_steps is None:
- raise ValueError(
- "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
- )
-
- # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
- # Ideally, read DDIM paper in-detail understanding
-
- # Notation ( ->
- # - pred_noise_t -> e_theta(x_t, t)
- # - pred_original_sample -> f_theta(x_t, t) or x_0
- # - std_dev_t -> sigma_t
- # - eta -> η
- # - pred_sample_direction -> "direction pointing to x_t"
- # - pred_prev_sample -> "x_t-1"
-
- # 1. get previous step value (=t-1)
- prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
-
- # 2. compute alphas, betas
- alpha_prod_t = self.alphas_cumprod[timestep]
- alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
-
- beta_prod_t = 1 - alpha_prod_t
-
- # 3. compute predicted original sample from predicted noise also called
- # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
- # To make style tests pass, commented out `pred_epsilon` as it is an unused variable
- if self.config.prediction_type == "epsilon":
- pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
- # pred_epsilon = model_output
- elif self.config.prediction_type == "sample":
- pred_original_sample = model_output
- # pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
- elif self.config.prediction_type == "v_prediction":
- pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
- # pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
- else:
- raise ValueError(
- f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
- " `v_prediction`"
- )
-
- a_t = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5
- b_t = alpha_prod_t_prev**0.5 - alpha_prod_t**0.5 * a_t
-
- prev_sample = a_t * sample + b_t * pred_original_sample
-
- if not return_dict:
- return (prev_sample,)
-
- return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
-
- # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
- def add_noise(
- self,
- original_samples: torch.Tensor,
- noise: torch.Tensor,
- timesteps: torch.IntTensor,
- ) -> torch.Tensor:
- # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
- # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
- # for the subsequent add_noise calls
- self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
- alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
- timesteps = timesteps.to(original_samples.device)
-
- sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
- sqrt_alpha_prod = sqrt_alpha_prod.flatten()
- while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
- sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
-
- sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
- sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
- while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
- sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
-
- noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
- return noisy_samples
-
- # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
- def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
- # Make sure alphas_cumprod and timestep have same device and dtype as sample
- self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
- alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
- timesteps = timesteps.to(sample.device)
-
- sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
- sqrt_alpha_prod = sqrt_alpha_prod.flatten()
- while len(sqrt_alpha_prod.shape) < len(sample.shape):
- sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
-
- sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
- sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
- while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
- sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
-
- velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
- return velocity
-
- def __len__(self):
- return self.config.num_train_timesteps
-
-
class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin):
"""
`DDIMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
diff --git a/videosys/models/open_sora/rflow.py b/videosys/schedulers/scheduling_rflow_open_sora.py
similarity index 85%
rename from videosys/models/open_sora/rflow.py
rename to videosys/schedulers/scheduling_rflow_open_sora.py
index d9b8f5bfac237dcadded659d7c3ba6bcc2515e77..5ba6c6cadcf735e44a52c8a5430915195a5d4788 100644
--- a/videosys/models/open_sora/rflow.py
+++ b/videosys/schedulers/scheduling_rflow_open_sora.py
@@ -13,8 +13,20 @@ from einops import rearrange
from torch.distributions import LogisticNormal
from tqdm import tqdm
-from videosys.core.pab_mgr import get_diffusion_skip, get_diffusion_skip_timestep, skip_diffusion_timestep
-from videosys.diffusion.gaussian_diffusion import _extract_into_tensor
+
+def _extract_into_tensor(arr, timesteps, broadcast_shape):
+ """
+ Extract values from a 1-D numpy array for a batch of indices.
+ :param arr: the 1-D numpy array.
+ :param timesteps: a tensor of indices into the array to extract.
+ :param broadcast_shape: a larger shape of K dimensions with the batch
+ dimension equal to the length of timesteps.
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
+ """
+ res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
+ while len(res.shape) < len(broadcast_shape):
+ res = res[..., None]
+ return res + torch.zeros(broadcast_shape, device=timesteps.device)
def mean_flat(tensor: torch.Tensor, mask=None):
@@ -176,11 +188,10 @@ class RFLOW:
def sample(
self,
model,
- text_encoder,
z,
- prompts,
+ model_args,
+ y_null,
device,
- additional_args=None,
mask=None,
guidance_scale=None,
progress=True,
@@ -190,13 +201,8 @@ class RFLOW:
if guidance_scale is None:
guidance_scale = self.cfg_scale
- n = len(prompts)
# text encoding
- model_args = text_encoder.encode(prompts)
- y_null = text_encoder.null(n)
model_args["y"] = torch.cat([model_args["y"], y_null], 0)
- if additional_args is not None:
- model_args.update(additional_args)
# prepare timesteps
timesteps = [(1.0 - i / self.num_sampling_steps) * self.num_timesteps for i in range(self.num_sampling_steps)]
@@ -204,24 +210,7 @@ class RFLOW:
timesteps = [int(round(t)) for t in timesteps]
timesteps = [torch.tensor([t] * z.shape[0], device=device) for t in timesteps]
if self.use_timestep_transform:
- timesteps = [timestep_transform(t, additional_args, num_timesteps=self.num_timesteps) for t in timesteps]
-
- if get_diffusion_skip() and get_diffusion_skip_timestep() is not None:
- orignal_timesteps = timesteps
- diffusion_skip_timestep = get_diffusion_skip_timestep()
- timesteps = skip_diffusion_timestep(timesteps, diffusion_skip_timestep)
-
- if verbose and dist.get_rank() == 0:
- print("============================")
- print("skip diffusion steps!!!")
- print("============================")
- print(f"orignal sample timesteps: {orignal_timesteps}")
- print(f"orignal diffusion steps: {len(orignal_timesteps)}")
- print("============================")
- print(f"skip diffusion steps: {get_diffusion_skip_timestep()}")
- print(f"sample timesteps: {timesteps}")
- print(f"num_inference_steps: {len(timesteps)}")
- print("============================")
+ timesteps = [timestep_transform(t, model_args, num_timesteps=self.num_timesteps) for t in timesteps]
if mask is not None:
noise_added = torch.zeros_like(mask, dtype=torch.bool)
diff --git a/videosys/utils/ckpt_utils.py b/videosys/utils/ckpt_utils.py
deleted file mode 100644
index e8b33c7909dc9eb83a6013f5042c07d5ef3dc71b..0000000000000000000000000000000000000000
--- a/videosys/utils/ckpt_utils.py
+++ /dev/null
@@ -1,115 +0,0 @@
-import functools
-import json
-import operator
-import os
-from typing import Tuple
-
-import torch
-import torch.distributed as dist
-import torch.nn as nn
-from colossalai.booster import Booster
-from colossalai.cluster import DistCoordinator
-from torch.optim import Optimizer
-from torch.optim.lr_scheduler import _LRScheduler
-
-from videosys.core.comm import model_sharding
-
-
-def load_json(file_path: str):
- with open(file_path, "r") as f:
- return json.load(f)
-
-
-def save_json(data, file_path: str):
- with open(file_path, "w") as f:
- json.dump(data, f, indent=4)
-
-
-def remove_padding(tensor: torch.Tensor, original_shape: Tuple) -> torch.Tensor:
- return tensor[: functools.reduce(operator.mul, original_shape)]
-
-
-def model_gathering(model: torch.nn.Module, model_shape_dict: dict):
- global_rank = dist.get_rank()
- global_size = dist.get_world_size()
- for name, param in model.named_parameters():
- all_params = [torch.empty_like(param.data) for _ in range(global_size)]
- dist.all_gather(all_params, param.data, group=dist.group.WORLD)
- if global_rank == 0:
- all_params = torch.cat(all_params)
- param.data = remove_padding(all_params, model_shape_dict[name]).view(model_shape_dict[name])
- dist.barrier()
-
-
-def record_model_param_shape(model: torch.nn.Module) -> dict:
- param_shape = {}
- for name, param in model.named_parameters():
- param_shape[name] = param.shape
- return param_shape
-
-
-def save(
- booster: Booster,
- model: nn.Module,
- ema: nn.Module,
- optimizer: Optimizer,
- lr_scheduler: _LRScheduler,
- epoch: int,
- step: int,
- global_step: int,
- batch_size: int,
- coordinator: DistCoordinator,
- save_dir: str,
- shape_dict: dict,
- shard_ema: bool = False,
-):
- torch.cuda.empty_cache()
- global_rank = dist.get_rank()
- save_dir = os.path.join(save_dir, f"epoch{epoch}-global_step{global_step}")
- os.makedirs(os.path.join(save_dir, "model"), exist_ok=True)
- booster.save_model(model, os.path.join(save_dir, "model"), shard=True)
-
- # Gather the sharded ema model before saving
- if shard_ema:
- model_gathering(ema, shape_dict)
-
- # ema is not boosted, so we don't need to use booster.save_model
- if global_rank == 0:
- torch.save(ema.state_dict(), os.path.join(save_dir, "ema.pt"))
- # Shard ema model when using zero2 plugin
- if shard_ema:
- model_sharding(ema)
- if optimizer is not None:
- booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096)
- if lr_scheduler is not None:
- booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
- running_states = {
- "epoch": epoch,
- "step": step,
- "global_step": global_step,
- "sample_start_index": step * batch_size,
- }
- if coordinator.is_master():
- save_json(running_states, os.path.join(save_dir, "running_states.json"))
- dist.barrier()
-
-
-def load(
- booster: Booster,
- model: nn.Module,
- ema: nn.Module,
- optimizer: Optimizer,
- lr_scheduler: _LRScheduler,
- load_dir: str,
-) -> Tuple[int, int, int]:
- booster.load_model(model, os.path.join(load_dir, "model"))
- # ema is not boosted, so we don't use booster.load_model
- ema.load_state_dict(torch.load(os.path.join(load_dir, "ema.pt"), map_location=torch.device("cpu")))
- if optimizer is not None:
- booster.load_optimizer(optimizer, os.path.join(load_dir, "optimizer"))
- if lr_scheduler is not None:
- booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, "lr_scheduler"))
- running_states = load_json(os.path.join(load_dir, "running_states.json"))
- dist.barrier()
- torch.cuda.empty_cache()
- return running_states["epoch"], running_states["step"], running_states["sample_start_index"]
diff --git a/videosys/utils/debug_utils.py b/videosys/utils/debug_utils.py
deleted file mode 100644
index 7c6d8f45786f647aa87c064aedc01831ff8d498d..0000000000000000000000000000000000000000
--- a/videosys/utils/debug_utils.py
+++ /dev/null
@@ -1,7 +0,0 @@
-import torch.distributed as dist
-
-
-# Print debug information on selected rank
-def print_rank(var_name, var_value, rank=0):
- if dist.get_rank() == rank:
- print(f"[Rank {rank}] {var_name}: {var_value}")
diff --git a/videosys/utils/download.py b/videosys/utils/download.py
deleted file mode 100644
index 75cf8dd8bf91e66d568ca2055a8138d6fb8977bb..0000000000000000000000000000000000000000
--- a/videosys/utils/download.py
+++ /dev/null
@@ -1,79 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-"""
-Functions for downloading pre-trained DiT models
-"""
-import json
-import os
-
-import torch
-from torchvision.datasets.utils import download_url
-
-pretrained_models = {"DiT-XL-2-512x512.pt", "DiT-XL-2-256x256.pt"}
-
-
-def find_model(model_name):
- """
- Finds a pre-trained DiT model, downloading it if necessary. Alternatively, loads a model from a local path.
- """
- if model_name in pretrained_models: # Find/download our pre-trained DiT checkpoints
- return download_model(model_name)
- else: # Load a custom DiT checkpoint:
- if not os.path.isfile(model_name):
- # if the model_name is a directory, then we assume we should load it in the Hugging Face manner
- # i.e. the model weights are sharded into multiple files and there is an index.json file
- # walk through the files in the directory and find the index.json file
- index_file = [os.path.join(model_name, f) for f in os.listdir(model_name) if "index.json" in f]
- assert len(index_file) == 1, f"Could not find index.json in {model_name}"
-
- # process index json
- with open(index_file[0], "r") as f:
- index_data = json.load(f)
-
- bin_to_weight_mapping = dict()
- for k, v in index_data["weight_map"].items():
- if v in bin_to_weight_mapping:
- bin_to_weight_mapping[v].append(k)
- else:
- bin_to_weight_mapping[v] = [k]
-
- # make state dict
- state_dict = dict()
- for bin_name, weight_list in bin_to_weight_mapping.items():
- bin_path = os.path.join(model_name, bin_name)
- bin_state_dict = torch.load(bin_path, map_location=lambda storage, loc: storage)
- for weight in weight_list:
- state_dict[weight] = bin_state_dict[weight]
- return state_dict
- else:
- # if it is a file, we just load it directly in the typical PyTorch manner
- assert os.path.exists(model_name), f"Could not find DiT checkpoint at {model_name}"
- checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage)
- if "ema" in checkpoint: # supports checkpoints from train.py
- checkpoint = checkpoint["ema"]
- return checkpoint
-
-
-def download_model(model_name):
- """
- Downloads a pre-trained DiT model from the web.
- """
- assert model_name in pretrained_models
- local_path = f"pretrained_models/{model_name}"
- if not os.path.isfile(local_path):
- os.makedirs("pretrained_models", exist_ok=True)
- web_path = f"https://dl.fbaipublicfiles.com/DiT/models/{model_name}"
- download_url(web_path, "pretrained_models")
- model = torch.load(local_path, map_location=lambda storage, loc: storage)
- return model
-
-
-if __name__ == "__main__":
- # Download all DiT checkpoints
- for model in pretrained_models:
- download_model(model)
- print("Done.")
diff --git a/videosys/utils/train_utils.py b/videosys/utils/train_utils.py
deleted file mode 100644
index 8906872698c59eb4cf76e00692ea2d6418224190..0000000000000000000000000000000000000000
--- a/videosys/utils/train_utils.py
+++ /dev/null
@@ -1,65 +0,0 @@
-from collections import OrderedDict
-
-import torch
-import torch.distributed as dist
-from colossalai.zero.low_level.low_level_optim import LowLevelZeroOptimizer
-
-
-def get_model_numel(model: torch.nn.Module) -> int:
- return sum(p.numel() for p in model.parameters())
-
-
-def format_numel_str(numel: int) -> str:
- B = 1024**3
- M = 1024**2
- K = 1024
- if numel >= B:
- return f"{numel / B:.2f} B"
- elif numel >= M:
- return f"{numel / M:.2f} M"
- elif numel >= K:
- return f"{numel / K:.2f} K"
- else:
- return f"{numel}"
-
-
-def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
- dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
- tensor.div_(dist.get_world_size())
- return tensor
-
-
-@torch.no_grad()
-def update_ema(
- ema_model: torch.nn.Module, model: torch.nn.Module, optimizer=None, decay: float = 0.9999, sharded: bool = True
-) -> None:
- """
- Step the EMA model towards the current model.
- """
- ema_params = OrderedDict(ema_model.named_parameters())
- model_params = OrderedDict(model.named_parameters())
-
- for name, param in model_params.items():
- if name == "pos_embed":
- continue
- if param.requires_grad == False:
- continue
- if not sharded:
- param_data = param.data
- ema_params[name].mul_(decay).add_(param_data, alpha=1 - decay)
- else:
- if param.data.dtype != torch.float32 and isinstance(optimizer, LowLevelZeroOptimizer):
- param_id = id(param)
- master_param = optimizer._param_store.working_to_master_param[param_id]
- param_data = master_param.data
- else:
- param_data = param.data
- ema_params[name].mul_(decay).add_(param_data, alpha=1 - decay)
-
-
-def requires_grad(model: torch.nn.Module, flag: bool = True) -> None:
- """
- Set requires_grad flag for all parameters in a model.
- """
- for p in model.parameters():
- p.requires_grad = flag