Spaces:
Running
on
L4
Running
on
L4
Xu Xuenan
commited on
Commit
•
a121edc
1
Parent(s):
644bfda
Initial commit
Browse files- .gitignore +165 -0
- app.py +258 -0
- configs/mm_story_agent.yaml +75 -0
- mm_story_agent/__init__.py +105 -0
- mm_story_agent/modality_agents/image_agent.py +663 -0
- mm_story_agent/modality_agents/llm.py +73 -0
- mm_story_agent/modality_agents/music_agent.py +78 -0
- mm_story_agent/modality_agents/sound_agent.py +106 -0
- mm_story_agent/modality_agents/speech_agent.py +90 -0
- mm_story_agent/modality_agents/story_agent.py +114 -0
- mm_story_agent/prompts_en.py +277 -0
- mm_story_agent/video_compose_agent.py +412 -0
- nls-1.0.0-py3-none-any.whl +0 -0
- policy.xml +99 -0
- requirements.txt +13 -0
.gitignore
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
share/python-wheels/
|
24 |
+
*.egg-info/
|
25 |
+
.installed.cfg
|
26 |
+
*.egg
|
27 |
+
MANIFEST
|
28 |
+
|
29 |
+
# PyInstaller
|
30 |
+
# Usually these files are written by a python script from a template
|
31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
32 |
+
*.manifest
|
33 |
+
*.spec
|
34 |
+
|
35 |
+
# Installer logs
|
36 |
+
pip-log.txt
|
37 |
+
pip-delete-this-directory.txt
|
38 |
+
|
39 |
+
# Unit test / coverage reports
|
40 |
+
htmlcov/
|
41 |
+
.tox/
|
42 |
+
.nox/
|
43 |
+
.coverage
|
44 |
+
.coverage.*
|
45 |
+
.cache
|
46 |
+
nosetests.xml
|
47 |
+
coverage.xml
|
48 |
+
*.cover
|
49 |
+
*.py,cover
|
50 |
+
.hypothesis/
|
51 |
+
.pytest_cache/
|
52 |
+
cover/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
.pybuilder/
|
76 |
+
target/
|
77 |
+
|
78 |
+
# Jupyter Notebook
|
79 |
+
.ipynb_checkpoints
|
80 |
+
|
81 |
+
# IPython
|
82 |
+
profile_default/
|
83 |
+
ipython_config.py
|
84 |
+
|
85 |
+
# pyenv
|
86 |
+
# For a library or package, you might want to ignore these files since the code is
|
87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
88 |
+
# .python-version
|
89 |
+
|
90 |
+
# pipenv
|
91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
94 |
+
# install all needed dependencies.
|
95 |
+
#Pipfile.lock
|
96 |
+
|
97 |
+
# poetry
|
98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
100 |
+
# commonly ignored for libraries.
|
101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
102 |
+
#poetry.lock
|
103 |
+
|
104 |
+
# pdm
|
105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
106 |
+
#pdm.lock
|
107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
108 |
+
# in version control.
|
109 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
110 |
+
.pdm.toml
|
111 |
+
.pdm-python
|
112 |
+
.pdm-build/
|
113 |
+
|
114 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
115 |
+
__pypackages__/
|
116 |
+
|
117 |
+
# Celery stuff
|
118 |
+
celerybeat-schedule
|
119 |
+
celerybeat.pid
|
120 |
+
|
121 |
+
# SageMath parsed files
|
122 |
+
*.sage.py
|
123 |
+
|
124 |
+
# Environments
|
125 |
+
.env
|
126 |
+
.venv
|
127 |
+
env/
|
128 |
+
venv/
|
129 |
+
ENV/
|
130 |
+
env.bak/
|
131 |
+
venv.bak/
|
132 |
+
|
133 |
+
# Spyder project settings
|
134 |
+
.spyderproject
|
135 |
+
.spyproject
|
136 |
+
|
137 |
+
# Rope project settings
|
138 |
+
.ropeproject
|
139 |
+
|
140 |
+
# mkdocs documentation
|
141 |
+
/site
|
142 |
+
|
143 |
+
# mypy
|
144 |
+
.mypy_cache/
|
145 |
+
.dmypy.json
|
146 |
+
dmypy.json
|
147 |
+
|
148 |
+
# Pyre type checker
|
149 |
+
.pyre/
|
150 |
+
|
151 |
+
# pytype static type analyzer
|
152 |
+
.pytype/
|
153 |
+
|
154 |
+
# Cython debug symbols
|
155 |
+
cython_debug/
|
156 |
+
|
157 |
+
# PyCharm
|
158 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
159 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
160 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
161 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
162 |
+
#.idea/
|
163 |
+
|
164 |
+
|
165 |
+
generated_stories/
|
app.py
ADDED
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
import argparse
|
3 |
+
import shutil
|
4 |
+
import time
|
5 |
+
import uuid
|
6 |
+
import subprocess
|
7 |
+
|
8 |
+
import gradio as gr
|
9 |
+
import yaml
|
10 |
+
import torch.multiprocessing as mp
|
11 |
+
|
12 |
+
mp.set_start_method('spawn', force=True)
|
13 |
+
|
14 |
+
from mm_story_agent import MMStoryAgent
|
15 |
+
|
16 |
+
try:
|
17 |
+
result = subprocess.run(["convert", "-version"], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
18 |
+
imagemagick_installed = True
|
19 |
+
except FileNotFoundError:
|
20 |
+
imagemagick_installed = False
|
21 |
+
|
22 |
+
if not imagemagick_installed:
|
23 |
+
import os
|
24 |
+
os.system("apt update -y")
|
25 |
+
os.system("apt install -y imagemagick")
|
26 |
+
os.system("cp policy.xml /etc/ImageMagick-6/")
|
27 |
+
|
28 |
+
|
29 |
+
with open("configs/mm_story_agent.yaml", "r") as reader:
|
30 |
+
config = yaml.load(reader, Loader=yaml.FullLoader)
|
31 |
+
|
32 |
+
|
33 |
+
default_story_setting = config["story_setting"]
|
34 |
+
default_story_gen_config = config["story_gen_config"]
|
35 |
+
default_slideshow_effect = config["slideshow_effect"]
|
36 |
+
default_image_config = config["image_generation"]
|
37 |
+
default_sound_config = config["sound_generation"]
|
38 |
+
default_music_config = config["music_generation"]
|
39 |
+
|
40 |
+
|
41 |
+
def set_generating_progress_text(text):
|
42 |
+
return gr.update(visible=True, value=f"<h3>{text} ...</h3>")
|
43 |
+
|
44 |
+
def set_text_invisible():
|
45 |
+
return gr.update(visible=False)
|
46 |
+
|
47 |
+
def deep_update(original, updates):
|
48 |
+
for key, value in updates.items():
|
49 |
+
if isinstance(value, dict):
|
50 |
+
original[key] = deep_update(original.get(key, {}), value)
|
51 |
+
else:
|
52 |
+
original[key] = value
|
53 |
+
return original
|
54 |
+
|
55 |
+
def update_page(direction, page, story_data):
|
56 |
+
|
57 |
+
orig_page = page
|
58 |
+
if direction == 'next' and page < len(story_data) - 1:
|
59 |
+
page = orig_page + 1
|
60 |
+
elif direction == 'prev' and page > 0:
|
61 |
+
page = orig_page - 1
|
62 |
+
|
63 |
+
return page, story_data[page], story_data
|
64 |
+
|
65 |
+
def write_story_fn(story_topic, main_role, scene,
|
66 |
+
num_outline, temperature,
|
67 |
+
current_page,
|
68 |
+
progress=gr.Progress(track_tqdm=True)):
|
69 |
+
config["story_dir"] = f"generated_stories/{time.strftime('%Y%m%d-%H%M%S') + '-' + str(uuid.uuid1().hex)}"
|
70 |
+
deep_update(config, {
|
71 |
+
"story_setting": {
|
72 |
+
"story_topic": story_topic,
|
73 |
+
"main_role": main_role,
|
74 |
+
"scene": scene,
|
75 |
+
},
|
76 |
+
"story_gen_config": {
|
77 |
+
"num_outline": num_outline,
|
78 |
+
"temperature": temperature
|
79 |
+
},
|
80 |
+
})
|
81 |
+
story_gen_agent = MMStoryAgent()
|
82 |
+
pages = story_gen_agent.write_story(config)
|
83 |
+
# story_data, story_accordion, story_content
|
84 |
+
return pages, gr.update(visible=True), pages[current_page], gr.update()
|
85 |
+
|
86 |
+
def modality_assets_generation_fn(
|
87 |
+
height, width, image_seed, sound_guidance_scale, sound_seed,
|
88 |
+
n_candidate_per_text, music_duration,
|
89 |
+
story_data):
|
90 |
+
deep_update(config, {
|
91 |
+
"image_generation": {
|
92 |
+
"obj_cfg": {
|
93 |
+
"height": height,
|
94 |
+
"width": width,
|
95 |
+
},
|
96 |
+
"call_cfg": {
|
97 |
+
"seed": image_seed
|
98 |
+
}
|
99 |
+
},
|
100 |
+
"sound_generation": {
|
101 |
+
"call_cfg": {
|
102 |
+
"guidance_scale": sound_guidance_scale,
|
103 |
+
"seed": sound_seed,
|
104 |
+
"n_candidate_per_text": n_candidate_per_text
|
105 |
+
}
|
106 |
+
},
|
107 |
+
"music_generation": {
|
108 |
+
"call_cfg": {
|
109 |
+
"duration": music_duration
|
110 |
+
}
|
111 |
+
}
|
112 |
+
})
|
113 |
+
story_gen_agent = MMStoryAgent()
|
114 |
+
images = story_gen_agent.generate_modality_assets(config, story_data)
|
115 |
+
# image gallery
|
116 |
+
return gr.update(visible=True, value=images, columns=[len(images)], rows=[1], height="auto")
|
117 |
+
|
118 |
+
def compose_storytelling_video_fn(
|
119 |
+
fade_duration, slide_duration, zoom_speed, move_ratio,
|
120 |
+
sound_volume, music_volume, bg_speech_ratio, fps,
|
121 |
+
story_data,
|
122 |
+
progress=gr.Progress(track_tqdm=True)):
|
123 |
+
deep_update(config, {
|
124 |
+
"slideshow_effect": {
|
125 |
+
"fade_duration": fade_duration,
|
126 |
+
"slide_duration": slide_duration,
|
127 |
+
"zoom_speed": zoom_speed,
|
128 |
+
"move_ratio": move_ratio,
|
129 |
+
"sound_volume": sound_volume,
|
130 |
+
"music_volume": music_volume,
|
131 |
+
"bg_speech_ratio": bg_speech_ratio,
|
132 |
+
"fps": fps
|
133 |
+
},
|
134 |
+
})
|
135 |
+
story_gen_agent = MMStoryAgent()
|
136 |
+
story_gen_agent.compose_storytelling_video(config, story_data)
|
137 |
+
|
138 |
+
# video_output
|
139 |
+
return Path(config["story_dir"]) / "output.mp4"
|
140 |
+
|
141 |
+
|
142 |
+
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
143 |
+
|
144 |
+
gr.HTML("""
|
145 |
+
<h1 style="text-align: center;">MM-StoryAgent</h1>
|
146 |
+
<p style="font-size: 16px;">This is a demo for generating attractive storytelling videos based on the given story setting.</p>
|
147 |
+
""")
|
148 |
+
|
149 |
+
with gr.Row():
|
150 |
+
with gr.Column():
|
151 |
+
story_topic = gr.Textbox(label="Story Topic", value=default_story_setting["story_topic"])
|
152 |
+
main_role = gr.Textbox(label="Main Role", value=default_story_setting["main_role"])
|
153 |
+
scene = gr.Textbox(label="Scene", value=default_story_setting["scene"])
|
154 |
+
chapter_num = gr.Number(label="Chapter Number", value=default_story_gen_config["num_outline"])
|
155 |
+
temperature = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Temperature", value=default_story_gen_config["temperature"])
|
156 |
+
|
157 |
+
with gr.Accordion("Detailed Image Configuration (Optional)", open=False):
|
158 |
+
height = gr.Slider(label="Height", minimum=256, maximum=1024, step=32, value=default_image_config["obj_cfg"]['height'])
|
159 |
+
width = gr.Slider(label="Width", minimum=256, maximum=1024, step=32, value=default_image_config["obj_cfg"]['width'])
|
160 |
+
image_seed = gr.Number(label="Image Seed", value=default_image_config["call_cfg"]['seed'])
|
161 |
+
|
162 |
+
with gr.Accordion("Detailed Sound Configuration (Optional)", open=False):
|
163 |
+
sound_guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=7.0, step=0.5, value=default_sound_config["call_cfg"]['guidance_scale'])
|
164 |
+
sound_seed = gr.Number(label="Sound Seed", value=default_sound_config["call_cfg"]['seed'])
|
165 |
+
n_candidate_per_text = gr.Slider(label="Number of Candidates per Text", minimum=0, maximum=5, step=1, value=default_sound_config["call_cfg"]['n_candidate_per_text'])
|
166 |
+
|
167 |
+
with gr.Accordion("Detailed Music Configuration (Optional)", open=False):
|
168 |
+
music_duration = gr.Number(label="Music Duration", min_width=30.0, maximum=120.0, value=default_music_config["call_cfg"]["duration"])
|
169 |
+
|
170 |
+
with gr.Accordion("Detailed Slideshow Effect (Optional)", open=False):
|
171 |
+
fade_duration = gr.Slider(label="Fade Duration", minimum=0.1, maximum=1.5, step=0.1, value=default_slideshow_effect['fade_duration'])
|
172 |
+
slide_duration = gr.Slider(label="Slide Duration", minimum=0.1, maximum=1.0, step=0.1, value=default_slideshow_effect['slide_duration'])
|
173 |
+
zoom_speed = gr.Slider(label="Zoom Speed", minimum=0.1, maximum=2.0, step=0.1, value=default_slideshow_effect['zoom_speed'])
|
174 |
+
move_ratio = gr.Slider(label="Move Ratio", minimum=0.8, maximum=1.0, step=0.05, value=default_slideshow_effect['move_ratio'])
|
175 |
+
sound_volume = gr.Slider(label="Sound Volume", minimum=0.0, maximum=1.0, step=0.1, value=default_slideshow_effect['sound_volume'])
|
176 |
+
music_volume = gr.Slider(label="Music Volume", minimum=0.0, maximum=1.0, step=0.1, value=default_slideshow_effect['music_volume'])
|
177 |
+
bg_speech_ratio = gr.Slider(label="Background / Speech Ratio", minimum=0.0, maximum=1.0, step=0.1, value=default_slideshow_effect['bg_speech_ratio'])
|
178 |
+
fps = gr.Slider(label="FPS", minimum=1, maximum=30, step=1, value=default_slideshow_effect['fps'])
|
179 |
+
|
180 |
+
|
181 |
+
with gr.Column():
|
182 |
+
story_data = gr.State([])
|
183 |
+
|
184 |
+
story_generation_information = gr.Markdown(
|
185 |
+
label="Story Generation Status",
|
186 |
+
value="<h3>Generating Story Script ......</h3>",
|
187 |
+
visible=False)
|
188 |
+
with gr.Accordion(label="Story Content", open=False, visible=False) as story_accordion:
|
189 |
+
with gr.Row():
|
190 |
+
prev_button = gr.Button("Previous Page",)
|
191 |
+
next_button = gr.Button("Next Page",)
|
192 |
+
story_content = gr.Textbox(label="Page Content")
|
193 |
+
video_generation_information = gr.Markdown(label="Generation Status", value="<h3>Generating Video ......</h3>", visible=False)
|
194 |
+
image_gallery = gr.Gallery(label="Images", show_label=False, visible=False)
|
195 |
+
video_generation_btn = gr.Button("Generate Video")
|
196 |
+
video_output = gr.Video(label="Generated Story", interactive=False)
|
197 |
+
|
198 |
+
current_page = gr.State(0)
|
199 |
+
|
200 |
+
prev_button.click(
|
201 |
+
fn=update_page,
|
202 |
+
inputs=[gr.State("prev"), current_page, story_data],
|
203 |
+
outputs=[current_page, story_content]
|
204 |
+
)
|
205 |
+
next_button.click(
|
206 |
+
fn=update_page,
|
207 |
+
inputs=[gr.State("next"), current_page, story_data],
|
208 |
+
outputs=[current_page, story_content,])
|
209 |
+
|
210 |
+
# (possibly) update role description and scripts
|
211 |
+
|
212 |
+
video_generation_btn.click(
|
213 |
+
fn=set_generating_progress_text,
|
214 |
+
inputs=[gr.State("Generating Story")],
|
215 |
+
outputs=video_generation_information
|
216 |
+
).then(
|
217 |
+
fn=write_story_fn,
|
218 |
+
inputs=[story_topic, main_role, scene,
|
219 |
+
chapter_num, temperature,
|
220 |
+
current_page],
|
221 |
+
outputs=[story_data, story_accordion, story_content, video_output]
|
222 |
+
).then(
|
223 |
+
fn=set_generating_progress_text,
|
224 |
+
inputs=[gr.State("Generating Modality Assets")],
|
225 |
+
outputs=video_generation_information
|
226 |
+
).then(
|
227 |
+
fn=modality_assets_generation_fn,
|
228 |
+
inputs=[height, width, image_seed, sound_guidance_scale, sound_seed,
|
229 |
+
n_candidate_per_text, music_duration,
|
230 |
+
story_data],
|
231 |
+
outputs=[image_gallery]
|
232 |
+
).then(
|
233 |
+
fn=set_generating_progress_text,
|
234 |
+
inputs=[gr.State("Composing Video")],
|
235 |
+
outputs=video_generation_information
|
236 |
+
).then(
|
237 |
+
fn=compose_storytelling_video_fn,
|
238 |
+
inputs=[fade_duration, slide_duration, zoom_speed, move_ratio,
|
239 |
+
sound_volume, music_volume, bg_speech_ratio, fps,
|
240 |
+
story_data],
|
241 |
+
outputs=[video_output]
|
242 |
+
).then(
|
243 |
+
fn=lambda : gr.update(visible=False),
|
244 |
+
inputs=[],
|
245 |
+
outputs=[image_gallery]
|
246 |
+
).then(
|
247 |
+
fn=set_generating_progress_text,
|
248 |
+
inputs=[gr.State("Generation Finished")],
|
249 |
+
outputs=video_generation_information
|
250 |
+
)
|
251 |
+
|
252 |
+
|
253 |
+
if __name__ == "__main__":
|
254 |
+
parser = argparse.ArgumentParser()
|
255 |
+
parser.add_argument("--share", default=False, action="store_true")
|
256 |
+
|
257 |
+
args = parser.parse_args()
|
258 |
+
demo.launch(share=args.share)
|
configs/mm_story_agent.yaml
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
story_dir: generated_stories/20240808_1130
|
2 |
+
audio_sample_rate: &audio_sample_rate 16000
|
3 |
+
audio_codec: mp3 # [mp3, aac, ...]
|
4 |
+
|
5 |
+
|
6 |
+
story_setting:
|
7 |
+
story_topic: "Time Management: A child learning how to manage their time effectively."
|
8 |
+
main_role: "(no main role specified)"
|
9 |
+
scene: "(no scene specified)"
|
10 |
+
|
11 |
+
story_gen_config:
|
12 |
+
max_conv_turns: 3
|
13 |
+
num_outline: 4
|
14 |
+
temperature: 0.5
|
15 |
+
|
16 |
+
caption_config:
|
17 |
+
font: resources/font/msyh.ttf
|
18 |
+
# bg_color: LightGrey
|
19 |
+
fontsize: 32
|
20 |
+
color: white
|
21 |
+
# stroke_color: white
|
22 |
+
# stroke_width: 0.5
|
23 |
+
max_single_caption_length: 50
|
24 |
+
|
25 |
+
sound_generation:
|
26 |
+
call_cfg:
|
27 |
+
guidance_scale: 3.5
|
28 |
+
seed: 0
|
29 |
+
ddim_steps: 200
|
30 |
+
n_candidate_per_text: 3
|
31 |
+
revise_cfg:
|
32 |
+
num_turns: 3
|
33 |
+
sample_rate: *audio_sample_rate
|
34 |
+
|
35 |
+
|
36 |
+
speech_generation:
|
37 |
+
call_cfg:
|
38 |
+
voice: longyuan
|
39 |
+
sample_rate: *audio_sample_rate
|
40 |
+
|
41 |
+
|
42 |
+
image_generation:
|
43 |
+
revise_cfg:
|
44 |
+
num_turns: 3
|
45 |
+
obj_cfg:
|
46 |
+
model_name: stabilityai/stable-diffusion-xl-base-1.0
|
47 |
+
id_length: 2
|
48 |
+
height: 512
|
49 |
+
width: 1024
|
50 |
+
call_cfg:
|
51 |
+
seed: 112536
|
52 |
+
guidance_scale: 10.0
|
53 |
+
style_name: "Storybook" # ['(No style)', 'Japanese Anime', 'Digital/Oil Painting', 'Pixar/Disney Character',
|
54 |
+
# 'Photographic', 'Comic book', 'Line art', 'Black and White Film Noir', 'Isometric Rooms']
|
55 |
+
|
56 |
+
music_generation:
|
57 |
+
revise_cfg:
|
58 |
+
num_turns: 3
|
59 |
+
call_cfg:
|
60 |
+
duration: 60.0
|
61 |
+
|
62 |
+
slideshow_effect:
|
63 |
+
fade_duration: 0.8
|
64 |
+
slide_duration: 0.4
|
65 |
+
zoom_speed: 0.5
|
66 |
+
move_ratio: 0.9
|
67 |
+
|
68 |
+
sound_volume: 0.6
|
69 |
+
music_volume: 0.5
|
70 |
+
bg_speech_ratio: 0.6
|
71 |
+
|
72 |
+
fps: 8
|
73 |
+
|
74 |
+
|
75 |
+
|
mm_story_agent/__init__.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import json
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import torch.multiprocessing as mp
|
6 |
+
|
7 |
+
from mm_story_agent.modality_agents.story_agent import QAOutlineStoryWriter
|
8 |
+
from mm_story_agent.modality_agents.speech_agent import CosyVoiceAgent
|
9 |
+
from mm_story_agent.modality_agents.sound_agent import AudioLDM2Agent
|
10 |
+
from mm_story_agent.modality_agents.music_agent import MusicGenAgent
|
11 |
+
from mm_story_agent.modality_agents.image_agent import StoryDiffusionAgent
|
12 |
+
from mm_story_agent.video_compose_agent import VideoComposeAgent
|
13 |
+
|
14 |
+
|
15 |
+
class MMStoryAgent:
|
16 |
+
|
17 |
+
def __init__(self) -> None:
|
18 |
+
self.modalities = ["image", "sound", "speech", "music"]
|
19 |
+
self.modality_agent_class = {
|
20 |
+
"image": StoryDiffusionAgent,
|
21 |
+
"sound": AudioLDM2Agent,
|
22 |
+
"speech": CosyVoiceAgent,
|
23 |
+
"music": MusicGenAgent
|
24 |
+
}
|
25 |
+
self.agents = {}
|
26 |
+
|
27 |
+
def call_modality_agent(self, agent, pages, save_path, return_dict):
|
28 |
+
result = agent.call(pages, save_path)
|
29 |
+
modality = result["modality"]
|
30 |
+
return_dict[modality] = result
|
31 |
+
|
32 |
+
def write_story(self, config):
|
33 |
+
story_writer = QAOutlineStoryWriter(config["story_gen_config"])
|
34 |
+
pages = story_writer.call(config["story_setting"])
|
35 |
+
# pages = [
|
36 |
+
# "In the heart of a dense forest, Flicker the Fox, nestled in his cozy den, stumbled upon an ancient computer hidden beneath a pile of soft moss and forgotten treasures. Surrounded by maps of unexplored territories and codes scribbled on parchment, Flicker's eyes widened with intrigue as he traced his paw over the mysterious machine.",
|
37 |
+
# "Flicker's den was a testament to his adventurous spirit, a haven filled with artifacts from his previous quests. The discovery of the computer, however, sparked a new kind of excitement within him, a curiosity that went beyond the physical boundaries of his forest home.",
|
38 |
+
# "With a determined gleam in his eye, Flicker trotted out of his den in search of his parents. He had questions about this relic that couldn't wait, eager to understand the secrets it held and how it functioned in a world so different from his own.",
|
39 |
+
# "Excited by his parents' encouragement, Flicker eagerly started his journey into the world of typing. His paws clumsily hit the wrong keys at first, resulting in a string of random letters and numbers on the screen. But with every mistake, Flicker's determination grew stronger.",
|
40 |
+
# "Days turned into weeks, and Flicker's persistence paid off. His paws now moved gracefully across the keyboard, his eyes focused on the screen as he typed out simple messages and commands. The once foreign device was becoming a familiar tool, and Flicker felt a sense of accomplishment wash over him.",
|
41 |
+
# "One evening, as the moon illuminated the forest, a wise old owl named Ollie perched on a branch outside Flicker's den. With a hoot and a smile, Ollie shared the magic of keyboard shortcuts, turning Flicker's typing sessions into thrilling adventures. Each shortcut was like a secret code, and Flicker couldn't wait to master them all.",
|
42 |
+
# "Eager to explore beyond the basics, Flicker's curiosity led him to the vast digital world of the internet. With guidance from his parents and Ollie, he learned how to navigate safely, discovering interactive games and educational videos that opened his eyes to the wonders beyond his forest.",
|
43 |
+
# "Each day, Flicker would sit before the screen, his paws dancing over the keys as he clicked through virtual tours of distant lands, watched videos of creatures he'd never seen, and played games that taught him about science and history. The computer became a window to a world far larger than he could have imagined.",
|
44 |
+
# ]
|
45 |
+
return pages
|
46 |
+
|
47 |
+
def generate_modality_assets(self, config, pages):
|
48 |
+
script_data = {"pages": [{"story": page} for page in pages]}
|
49 |
+
story_dir = Path(config["story_dir"])
|
50 |
+
|
51 |
+
for sub_dir in self.modalities:
|
52 |
+
(story_dir / sub_dir).mkdir(exist_ok=True, parents=True)
|
53 |
+
|
54 |
+
agents = {}
|
55 |
+
for modality in self.modalities:
|
56 |
+
agents[modality] = self.modality_agent_class[modality](config[modality + "_generation"])
|
57 |
+
|
58 |
+
processes = []
|
59 |
+
return_dict = mp.Manager().dict()
|
60 |
+
|
61 |
+
for modality in self.modalities:
|
62 |
+
p = mp.Process(target=self.call_modality_agent, args=(agents[modality], pages, story_dir / modality, return_dict))
|
63 |
+
processes.append(p)
|
64 |
+
p.start()
|
65 |
+
|
66 |
+
for p in processes:
|
67 |
+
p.join()
|
68 |
+
|
69 |
+
for modality, result in return_dict.items():
|
70 |
+
try:
|
71 |
+
if result["modality"] == "image":
|
72 |
+
images = result["generation_results"]
|
73 |
+
for idx in range(len(pages)):
|
74 |
+
script_data["pages"][idx]["image_prompt"] = result["prompts"][idx]
|
75 |
+
elif result["modality"] == "sound":
|
76 |
+
for idx in range(len(pages)):
|
77 |
+
script_data["pages"][idx]["sound_prompt"] = result["prompts"][idx]
|
78 |
+
elif result["modality"] == "music":
|
79 |
+
script_data["music_prompt"] = result["prompt"]
|
80 |
+
except Exception as e:
|
81 |
+
print(f"Error occurred during generation: {e}")
|
82 |
+
|
83 |
+
with open(story_dir / "script_data.json", "w") as writer:
|
84 |
+
json.dump(script_data, writer, ensure_ascii=False, indent=4)
|
85 |
+
|
86 |
+
return images
|
87 |
+
|
88 |
+
def compose_storytelling_video(self, config, pages):
|
89 |
+
video_compose_agent = VideoComposeAgent()
|
90 |
+
video_compose_agent.call(pages, config)
|
91 |
+
|
92 |
+
def call(self, config):
|
93 |
+
pages = self.write_story(config)
|
94 |
+
images = self.generate_modality_assets(config, pages)
|
95 |
+
self.compose_storytelling_video(config, pages)
|
96 |
+
|
97 |
+
|
98 |
+
if __name__ == "__main__":
|
99 |
+
|
100 |
+
from arg_parser import parse_yaml_and_cmd
|
101 |
+
|
102 |
+
config = parse_yaml_and_cmd()
|
103 |
+
mm_story_agent = MMStoryAgent()
|
104 |
+
|
105 |
+
mm_story_agent.call(config)
|
mm_story_agent/modality_agents/image_agent.py
ADDED
@@ -0,0 +1,663 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import random
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from diffusers import StableDiffusionXLPipeline, DDIMScheduler
|
10 |
+
|
11 |
+
from mm_story_agent.modality_agents.llm import QwenAgent
|
12 |
+
from mm_story_agent.prompts_en import role_extract_system, role_review_system, \
|
13 |
+
story_to_image_reviser_system, story_to_image_review_system
|
14 |
+
|
15 |
+
|
16 |
+
def setup_seed(seed):
|
17 |
+
torch.manual_seed(seed)
|
18 |
+
torch.cuda.manual_seed_all(seed)
|
19 |
+
np.random.seed(seed)
|
20 |
+
random.seed(seed)
|
21 |
+
torch.backends.cudnn.deterministic = True
|
22 |
+
|
23 |
+
|
24 |
+
class AttnProcessor(torch.nn.Module):
|
25 |
+
r"""
|
26 |
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
27 |
+
"""
|
28 |
+
def __init__(
|
29 |
+
self,
|
30 |
+
hidden_size=None,
|
31 |
+
cross_attention_dim=None,
|
32 |
+
):
|
33 |
+
super().__init__()
|
34 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
35 |
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
36 |
+
|
37 |
+
def __call__(
|
38 |
+
self,
|
39 |
+
attn,
|
40 |
+
hidden_states,
|
41 |
+
encoder_hidden_states=None,
|
42 |
+
attention_mask=None,
|
43 |
+
temb=None,
|
44 |
+
):
|
45 |
+
residual = hidden_states
|
46 |
+
|
47 |
+
if attn.spatial_norm is not None:
|
48 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
49 |
+
|
50 |
+
input_ndim = hidden_states.ndim
|
51 |
+
|
52 |
+
if input_ndim == 4:
|
53 |
+
batch_size, channel, height, width = hidden_states.shape
|
54 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
55 |
+
|
56 |
+
batch_size, sequence_length, _ = (
|
57 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
58 |
+
)
|
59 |
+
|
60 |
+
if attention_mask is not None:
|
61 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
62 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
63 |
+
# (batch, heads, source_length, target_length)
|
64 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
65 |
+
|
66 |
+
if attn.group_norm is not None:
|
67 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
68 |
+
|
69 |
+
query = attn.to_q(hidden_states)
|
70 |
+
|
71 |
+
if encoder_hidden_states is None:
|
72 |
+
encoder_hidden_states = hidden_states
|
73 |
+
elif attn.norm_cross:
|
74 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
75 |
+
|
76 |
+
key = attn.to_k(encoder_hidden_states)
|
77 |
+
value = attn.to_v(encoder_hidden_states)
|
78 |
+
|
79 |
+
inner_dim = key.shape[-1]
|
80 |
+
head_dim = inner_dim // attn.heads
|
81 |
+
|
82 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
83 |
+
|
84 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
85 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
86 |
+
|
87 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
88 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
89 |
+
hidden_states = F.scaled_dot_product_attention(
|
90 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
91 |
+
)
|
92 |
+
|
93 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
94 |
+
hidden_states = hidden_states.to(query.dtype)
|
95 |
+
|
96 |
+
# linear proj
|
97 |
+
hidden_states = attn.to_out[0](hidden_states)
|
98 |
+
# dropout
|
99 |
+
hidden_states = attn.to_out[1](hidden_states)
|
100 |
+
|
101 |
+
if input_ndim == 4:
|
102 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
103 |
+
|
104 |
+
if attn.residual_connection:
|
105 |
+
hidden_states = hidden_states + residual
|
106 |
+
|
107 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
108 |
+
|
109 |
+
return hidden_states
|
110 |
+
|
111 |
+
|
112 |
+
def cal_attn_mask_xl(total_length,
|
113 |
+
id_length,
|
114 |
+
sa32,
|
115 |
+
sa64,
|
116 |
+
height,
|
117 |
+
width,
|
118 |
+
device="cuda",
|
119 |
+
dtype=torch.float16):
|
120 |
+
nums_1024 = (height // 32) * (width // 32)
|
121 |
+
nums_4096 = (height // 16) * (width // 16)
|
122 |
+
bool_matrix1024 = torch.rand((1, total_length * nums_1024),device = device,dtype = dtype) < sa32
|
123 |
+
bool_matrix4096 = torch.rand((1, total_length * nums_4096),device = device,dtype = dtype) < sa64
|
124 |
+
bool_matrix1024 = bool_matrix1024.repeat(total_length,1)
|
125 |
+
bool_matrix4096 = bool_matrix4096.repeat(total_length,1)
|
126 |
+
for i in range(total_length):
|
127 |
+
bool_matrix1024[i:i+1,id_length*nums_1024:] = False
|
128 |
+
bool_matrix4096[i:i+1,id_length*nums_4096:] = False
|
129 |
+
bool_matrix1024[i:i+1,i*nums_1024:(i+1)*nums_1024] = True
|
130 |
+
bool_matrix4096[i:i+1,i*nums_4096:(i+1)*nums_4096] = True
|
131 |
+
mask1024 = bool_matrix1024.unsqueeze(1).repeat(1,nums_1024,1).reshape(-1,total_length * nums_1024)
|
132 |
+
mask4096 = bool_matrix4096.unsqueeze(1).repeat(1,nums_4096,1).reshape(-1,total_length * nums_4096)
|
133 |
+
return mask1024, mask4096
|
134 |
+
|
135 |
+
|
136 |
+
class SpatialAttnProcessor2_0(torch.nn.Module):
|
137 |
+
r"""
|
138 |
+
Attention processor for IP-Adapater for PyTorch 2.0.
|
139 |
+
Args:
|
140 |
+
hidden_size (`int`):
|
141 |
+
The hidden size of the attention layer.
|
142 |
+
cross_attention_dim (`int`):
|
143 |
+
The number of channels in the `encoder_hidden_states`.
|
144 |
+
text_context_len (`int`, defaults to 77):
|
145 |
+
The context length of the text features.
|
146 |
+
scale (`float`, defaults to 1.0):
|
147 |
+
the weight scale of image prompt.
|
148 |
+
"""
|
149 |
+
|
150 |
+
def __init__(self,
|
151 |
+
global_attn_args,
|
152 |
+
hidden_size=None,
|
153 |
+
cross_attention_dim=None,
|
154 |
+
id_length=4,
|
155 |
+
device="cuda",
|
156 |
+
dtype=torch.float16,
|
157 |
+
height=1280,
|
158 |
+
width=720,
|
159 |
+
sa32=0.5,
|
160 |
+
sa64=0.5,
|
161 |
+
):
|
162 |
+
super().__init__()
|
163 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
164 |
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
165 |
+
self.device = device
|
166 |
+
self.dtype = dtype
|
167 |
+
self.hidden_size = hidden_size
|
168 |
+
self.cross_attention_dim = cross_attention_dim
|
169 |
+
self.total_length = id_length + 1
|
170 |
+
self.id_length = id_length
|
171 |
+
self.id_bank = {}
|
172 |
+
self.height = height
|
173 |
+
self.width = width
|
174 |
+
self.sa32 = sa32
|
175 |
+
self.sa64 = sa64
|
176 |
+
self.write = True
|
177 |
+
|
178 |
+
self.global_attn_args = global_attn_args
|
179 |
+
|
180 |
+
|
181 |
+
def __call__(
|
182 |
+
self,
|
183 |
+
attn,
|
184 |
+
hidden_states,
|
185 |
+
encoder_hidden_states=None,
|
186 |
+
attention_mask=None,
|
187 |
+
temb=None
|
188 |
+
):
|
189 |
+
total_count = self.global_attn_args["total_count"]
|
190 |
+
attn_count = self.global_attn_args["attn_count"]
|
191 |
+
cur_step = self.global_attn_args["cur_step"]
|
192 |
+
mask1024 = self.global_attn_args["mask1024"]
|
193 |
+
mask4096 = self.global_attn_args["mask4096"]
|
194 |
+
|
195 |
+
if self.write:
|
196 |
+
self.id_bank[cur_step] = [hidden_states[:self.id_length], hidden_states[self.id_length:]]
|
197 |
+
else:
|
198 |
+
encoder_hidden_states = torch.cat((self.id_bank[cur_step][0].to(self.device),
|
199 |
+
hidden_states[:1],
|
200 |
+
self.id_bank[cur_step][1].to(self.device), hidden_states[1:]))
|
201 |
+
# skip in early step
|
202 |
+
if cur_step < 5:
|
203 |
+
hidden_states = self.__call2__(attn, hidden_states, encoder_hidden_states, attention_mask, temb)
|
204 |
+
else: # 256 1024 4096
|
205 |
+
random_number = random.random()
|
206 |
+
if cur_step < 20:
|
207 |
+
rand_num = 0.3
|
208 |
+
else:
|
209 |
+
rand_num = 0.1
|
210 |
+
if random_number > rand_num:
|
211 |
+
if not self.write:
|
212 |
+
if hidden_states.shape[1] == (self.height // 32) * (self.width // 32):
|
213 |
+
attention_mask = mask1024[mask1024.shape[0] // self.total_length * self.id_length:]
|
214 |
+
else:
|
215 |
+
attention_mask = mask4096[mask4096.shape[0] // self.total_length * self.id_length:]
|
216 |
+
else:
|
217 |
+
if hidden_states.shape[1] == (self.height // 32) * (self.width // 32):
|
218 |
+
attention_mask = mask1024[:mask1024.shape[0] // self.total_length * self.id_length,
|
219 |
+
:mask1024.shape[0] // self.total_length * self.id_length]
|
220 |
+
else:
|
221 |
+
attention_mask = mask4096[:mask4096.shape[0] // self.total_length * self.id_length,
|
222 |
+
:mask4096.shape[0] // self.total_length * self.id_length]
|
223 |
+
hidden_states = self.__call1__(attn, hidden_states, encoder_hidden_states, attention_mask, temb)
|
224 |
+
else:
|
225 |
+
hidden_states = self.__call2__(attn, hidden_states, None, attention_mask, temb)
|
226 |
+
attn_count += 1
|
227 |
+
if attn_count == total_count:
|
228 |
+
attn_count = 0
|
229 |
+
cur_step += 1
|
230 |
+
mask1024, mask4096 = cal_attn_mask_xl(self.total_length,
|
231 |
+
self.id_length,
|
232 |
+
self.sa32,
|
233 |
+
self.sa64,
|
234 |
+
self.height,
|
235 |
+
self.width,
|
236 |
+
device=self.device,
|
237 |
+
dtype=self.dtype)
|
238 |
+
self.global_attn_args["mask1024"] = mask1024
|
239 |
+
self.global_attn_args["mask4096"] = mask4096
|
240 |
+
|
241 |
+
self.global_attn_args["attn_count"] = attn_count
|
242 |
+
self.global_attn_args["cur_step"] = cur_step
|
243 |
+
|
244 |
+
return hidden_states
|
245 |
+
|
246 |
+
def __call1__(
|
247 |
+
self,
|
248 |
+
attn,
|
249 |
+
hidden_states,
|
250 |
+
encoder_hidden_states=None,
|
251 |
+
attention_mask=None,
|
252 |
+
temb=None,
|
253 |
+
):
|
254 |
+
residual = hidden_states
|
255 |
+
if attn.spatial_norm is not None:
|
256 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
257 |
+
input_ndim = hidden_states.ndim
|
258 |
+
|
259 |
+
if input_ndim == 4:
|
260 |
+
total_batch_size, channel, height, width = hidden_states.shape
|
261 |
+
hidden_states = hidden_states.view(total_batch_size, channel, height * width).transpose(1, 2)
|
262 |
+
total_batch_size, nums_token, channel = hidden_states.shape
|
263 |
+
img_nums = total_batch_size // 2
|
264 |
+
hidden_states = hidden_states.view(-1, img_nums, nums_token, channel).reshape(-1, img_nums * nums_token, channel)
|
265 |
+
|
266 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
267 |
+
|
268 |
+
if attn.group_norm is not None:
|
269 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
270 |
+
|
271 |
+
query = attn.to_q(hidden_states)
|
272 |
+
|
273 |
+
if encoder_hidden_states is None:
|
274 |
+
encoder_hidden_states = hidden_states # B, N, C
|
275 |
+
else:
|
276 |
+
encoder_hidden_states = encoder_hidden_states.view(-1, self.id_length + 1, nums_token, channel).reshape(
|
277 |
+
-1, (self.id_length + 1) * nums_token, channel)
|
278 |
+
|
279 |
+
key = attn.to_k(encoder_hidden_states)
|
280 |
+
value = attn.to_v(encoder_hidden_states)
|
281 |
+
|
282 |
+
|
283 |
+
inner_dim = key.shape[-1]
|
284 |
+
head_dim = inner_dim // attn.heads
|
285 |
+
|
286 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
287 |
+
|
288 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
289 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
290 |
+
hidden_states = F.scaled_dot_product_attention(
|
291 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
292 |
+
)
|
293 |
+
|
294 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(total_batch_size, -1, attn.heads * head_dim)
|
295 |
+
hidden_states = hidden_states.to(query.dtype)
|
296 |
+
|
297 |
+
|
298 |
+
|
299 |
+
# linear proj
|
300 |
+
hidden_states = attn.to_out[0](hidden_states)
|
301 |
+
# dropout
|
302 |
+
hidden_states = attn.to_out[1](hidden_states)
|
303 |
+
|
304 |
+
|
305 |
+
if input_ndim == 4:
|
306 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(total_batch_size, channel, height, width)
|
307 |
+
if attn.residual_connection:
|
308 |
+
hidden_states = hidden_states + residual
|
309 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
310 |
+
# print(hidden_states.shape)
|
311 |
+
return hidden_states
|
312 |
+
|
313 |
+
def __call2__(
|
314 |
+
self,
|
315 |
+
attn,
|
316 |
+
hidden_states,
|
317 |
+
encoder_hidden_states=None,
|
318 |
+
attention_mask=None,
|
319 |
+
temb=None):
|
320 |
+
residual = hidden_states
|
321 |
+
|
322 |
+
if attn.spatial_norm is not None:
|
323 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
324 |
+
|
325 |
+
input_ndim = hidden_states.ndim
|
326 |
+
|
327 |
+
if input_ndim == 4:
|
328 |
+
batch_size, channel, height, width = hidden_states.shape
|
329 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
330 |
+
|
331 |
+
batch_size, sequence_length, channel = (
|
332 |
+
hidden_states.shape
|
333 |
+
)
|
334 |
+
# print(hidden_states.shape)
|
335 |
+
if attention_mask is not None:
|
336 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
337 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
338 |
+
# (batch, heads, source_length, target_length)
|
339 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
340 |
+
|
341 |
+
if attn.group_norm is not None:
|
342 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
343 |
+
|
344 |
+
query = attn.to_q(hidden_states)
|
345 |
+
|
346 |
+
if encoder_hidden_states is None:
|
347 |
+
encoder_hidden_states = hidden_states # B, N, C
|
348 |
+
else:
|
349 |
+
encoder_hidden_states = encoder_hidden_states.view(-1, self.id_length + 1, sequence_length, channel).reshape(
|
350 |
+
-1, (self.id_length + 1) * sequence_length, channel)
|
351 |
+
|
352 |
+
key = attn.to_k(encoder_hidden_states)
|
353 |
+
value = attn.to_v(encoder_hidden_states)
|
354 |
+
|
355 |
+
inner_dim = key.shape[-1]
|
356 |
+
head_dim = inner_dim // attn.heads
|
357 |
+
|
358 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
359 |
+
|
360 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
361 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
362 |
+
|
363 |
+
hidden_states = F.scaled_dot_product_attention(
|
364 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
365 |
+
)
|
366 |
+
|
367 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
368 |
+
hidden_states = hidden_states.to(query.dtype)
|
369 |
+
|
370 |
+
# linear proj
|
371 |
+
hidden_states = attn.to_out[0](hidden_states)
|
372 |
+
# dropout
|
373 |
+
hidden_states = attn.to_out[1](hidden_states)
|
374 |
+
|
375 |
+
if input_ndim == 4:
|
376 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
377 |
+
|
378 |
+
if attn.residual_connection:
|
379 |
+
hidden_states = hidden_states + residual
|
380 |
+
|
381 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
382 |
+
|
383 |
+
return hidden_states
|
384 |
+
|
385 |
+
|
386 |
+
class StoryDiffusionSynthesizer:
|
387 |
+
|
388 |
+
def __init__(self,
|
389 |
+
num_pages: int,
|
390 |
+
height: int,
|
391 |
+
width: int,
|
392 |
+
model_name: str = "stabilityai/stable-diffusion-xl-base-1.0",
|
393 |
+
model_path: str = None,
|
394 |
+
id_length: int = 4,
|
395 |
+
num_steps: int = 50):
|
396 |
+
self.attn_args = {
|
397 |
+
"attn_count": 0,
|
398 |
+
"cur_step": 0,
|
399 |
+
"total_count": 0,
|
400 |
+
}
|
401 |
+
self.sa32 = 0.5
|
402 |
+
self.sa64 = 0.5
|
403 |
+
self.id_length = id_length
|
404 |
+
self.total_length = num_pages
|
405 |
+
self.height = height
|
406 |
+
self.width = width
|
407 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
408 |
+
self.dtype = torch.float16
|
409 |
+
self.num_steps = num_steps
|
410 |
+
self.styles = {
|
411 |
+
'(No style)': (
|
412 |
+
'{prompt}',
|
413 |
+
''),
|
414 |
+
'Japanese Anime': (
|
415 |
+
'anime artwork illustrating {prompt}. created by japanese anime studio. highly emotional. best quality, high resolution, (Anime Style, Manga Style:1.3), Low detail, sketch, concept art, line art, webtoon, manhua, hand drawn, defined lines, simple shades, minimalistic, High contrast, Linear compositions, Scalable artwork, Digital art, High Contrast Shadows',
|
416 |
+
'lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry'),
|
417 |
+
'Digital/Oil Painting': (
|
418 |
+
'{prompt} . (Extremely Detailed Oil Painting:1.2), glow effects, godrays, Hand drawn, render, 8k, octane render, cinema 4d, blender, dark, atmospheric 4k ultra detailed, cinematic sensual, Sharp focus, humorous illustration, big depth of field',
|
419 |
+
'anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry'),
|
420 |
+
'Pixar/Disney Character': (
|
421 |
+
'Create a Disney Pixar 3D style illustration on {prompt} . The scene is vibrant, motivational, filled with vivid colors and a sense of wonder.',
|
422 |
+
'lowres, bad anatomy, bad hands, text, bad eyes, bad arms, bad legs, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, blurry, grayscale, noisy, sloppy, messy, grainy, highly detailed, ultra textured, photo'),
|
423 |
+
'Photographic': (
|
424 |
+
'cinematic photo {prompt} . Hyperrealistic, Hyperdetailed, detailed skin, matte skin, soft lighting, realistic, best quality, ultra realistic, 8k, golden ratio, Intricate, High Detail, film photography, soft focus',
|
425 |
+
'drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry'),
|
426 |
+
'Comic book': (
|
427 |
+
'comic {prompt} . graphic illustration, comic art, graphic novel art, vibrant, highly detailed',
|
428 |
+
'photograph, deformed, glitch, noisy, realistic, stock photo, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry'),
|
429 |
+
'Line art': (
|
430 |
+
'line art drawing {prompt} . professional, sleek, modern, minimalist, graphic, line art, vector graphics',
|
431 |
+
'anime, photorealistic, 35mm film, deformed, glitch, blurry, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, mutated, realism, realistic, impressionism, expressionism, oil, acrylic, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry'),
|
432 |
+
'Black and White Film Noir': (
|
433 |
+
'{prompt} . (b&w, Monochromatic, Film Photography:1.3), film noir, analog style, soft lighting, subsurface scattering, realistic, heavy shadow, masterpiece, best quality, ultra realistic, 8k',
|
434 |
+
'anime, photorealistic, 35mm film, deformed, glitch, blurry, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, mutated, realism, realistic, impressionism, expressionism, oil, acrylic, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry'),
|
435 |
+
'Isometric Rooms': (
|
436 |
+
'Tiny cute isometric {prompt} . in a cutaway box, soft smooth lighting, soft colors, 100mm lens, 3d blender render',
|
437 |
+
'anime, photorealistic, 35mm film, deformed, glitch, blurry, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, mutated, realism, realistic, impressionism, expressionism, oil, acrylic, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry'),
|
438 |
+
'Storybook': (
|
439 |
+
"Cartoon style, cute illustration of {prompt}.",
|
440 |
+
'realism, photo, realistic, lowres, bad hands, bad eyes, bad arms, bad legs, error, missing fingers, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, grayscale, noisy, sloppy, messy, grainy, ultra textured'
|
441 |
+
)
|
442 |
+
}
|
443 |
+
|
444 |
+
pipe = StableDiffusionXLPipeline.from_pretrained(
|
445 |
+
model_path if model_path is not None else model_name,
|
446 |
+
torch_dtype=torch.float16,
|
447 |
+
use_safetensors=True
|
448 |
+
)
|
449 |
+
|
450 |
+
pipe = pipe.to(self.device)
|
451 |
+
|
452 |
+
# pipe.id_encoder.to(self.device)
|
453 |
+
|
454 |
+
pipe.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)
|
455 |
+
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
456 |
+
pipe.scheduler.set_timesteps(num_steps)
|
457 |
+
unet = pipe.unet
|
458 |
+
|
459 |
+
attn_procs = {}
|
460 |
+
### Insert PairedAttention
|
461 |
+
for name in unet.attn_processors.keys():
|
462 |
+
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
|
463 |
+
if name.startswith("mid_block"):
|
464 |
+
hidden_size = unet.config.block_out_channels[-1]
|
465 |
+
elif name.startswith("up_blocks"):
|
466 |
+
block_id = int(name[len("up_blocks.")])
|
467 |
+
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
|
468 |
+
elif name.startswith("down_blocks"):
|
469 |
+
block_id = int(name[len("down_blocks.")])
|
470 |
+
hidden_size = unet.config.block_out_channels[block_id]
|
471 |
+
if cross_attention_dim is None and (name.startswith("up_blocks") ) :
|
472 |
+
attn_procs[name] = SpatialAttnProcessor2_0(
|
473 |
+
id_length=self.id_length,
|
474 |
+
device=self.device,
|
475 |
+
height=self.height,
|
476 |
+
width=self.width,
|
477 |
+
sa32=self.sa32,
|
478 |
+
sa64=self.sa64,
|
479 |
+
global_attn_args=self.attn_args
|
480 |
+
)
|
481 |
+
self.attn_args["total_count"] += 1
|
482 |
+
else:
|
483 |
+
attn_procs[name] = AttnProcessor()
|
484 |
+
print("successsfully load consistent self-attention")
|
485 |
+
print(f"number of the processor : {self.attn_args['total_count']}")
|
486 |
+
# unet.set_attn_processor(copy.deepcopy(attn_procs))
|
487 |
+
unet.set_attn_processor(attn_procs)
|
488 |
+
mask1024, mask4096 = cal_attn_mask_xl(
|
489 |
+
self.total_length,
|
490 |
+
self.id_length,
|
491 |
+
self.sa32,
|
492 |
+
self.sa64,
|
493 |
+
self.height,
|
494 |
+
self.width,
|
495 |
+
device=self.device,
|
496 |
+
dtype=torch.float16,
|
497 |
+
)
|
498 |
+
|
499 |
+
self.attn_args.update({
|
500 |
+
"mask1024": mask1024,
|
501 |
+
"mask4096": mask4096
|
502 |
+
})
|
503 |
+
|
504 |
+
self.pipe = pipe
|
505 |
+
self.negative_prompt = "naked, deformed, bad anatomy, disfigured, poorly drawn face, mutation," \
|
506 |
+
"extra limb, ugly, disgusting, poorly drawn hands, missing limb, floating" \
|
507 |
+
"limbs, disconnected limbs, blurry, watermarks, oversaturated, distorted hands, amputation"
|
508 |
+
|
509 |
+
def set_attn_write(self,
|
510 |
+
value: bool):
|
511 |
+
unet = self.pipe.unet
|
512 |
+
for name, processor in unet.attn_processors.items():
|
513 |
+
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
|
514 |
+
if cross_attention_dim is None:
|
515 |
+
if name.startswith("up_blocks") :
|
516 |
+
assert isinstance(processor, SpatialAttnProcessor2_0)
|
517 |
+
processor.write = value
|
518 |
+
|
519 |
+
def apply_style(self, style_name: str, positives: list, negative: str = ""):
|
520 |
+
p, n = self.styles.get(style_name, self.styles["(No style)"])
|
521 |
+
return [p.replace("{prompt}", positive) for positive in positives], n + ' ' + negative
|
522 |
+
|
523 |
+
def apply_style_positive(self, style_name: str, positive: str):
|
524 |
+
p, n = self.styles.get(style_name, self.styles["(No style)"])
|
525 |
+
return p.replace("{prompt}", positive)
|
526 |
+
|
527 |
+
def call(self,
|
528 |
+
prompts: List[str],
|
529 |
+
input_id_images = None,
|
530 |
+
start_merge_step = None,
|
531 |
+
style_name: str = "Pixar/Disney Character",
|
532 |
+
guidance_scale: float = 5.0,
|
533 |
+
seed: int = 2047):
|
534 |
+
assert len(prompts) == self.total_length, "The number of prompts should be equal to the number of pages."
|
535 |
+
setup_seed(seed)
|
536 |
+
generator = torch.Generator(device=self.device).manual_seed(seed)
|
537 |
+
torch.cuda.empty_cache()
|
538 |
+
|
539 |
+
id_prompts = prompts[:self.id_length]
|
540 |
+
real_prompts = prompts[self.id_length:]
|
541 |
+
self.set_attn_write(True)
|
542 |
+
self.attn_args.update({
|
543 |
+
"cur_step": 0,
|
544 |
+
"attn_count": 0
|
545 |
+
})
|
546 |
+
id_prompts, negative_prompt = self.apply_style(style_name, id_prompts, self.negative_prompt)
|
547 |
+
id_images = self.pipe(
|
548 |
+
id_prompts,
|
549 |
+
input_id_images=input_id_images,
|
550 |
+
start_merge_step=start_merge_step,
|
551 |
+
num_inference_steps=self.num_steps,
|
552 |
+
guidance_scale=guidance_scale,
|
553 |
+
height=self.height,
|
554 |
+
width=self.width,
|
555 |
+
negative_prompt=negative_prompt,
|
556 |
+
generator=generator).images
|
557 |
+
|
558 |
+
self.set_attn_write(False)
|
559 |
+
real_images = []
|
560 |
+
for real_prompt in real_prompts:
|
561 |
+
self.attn_args["cur_step"] = 0
|
562 |
+
real_prompt = self.apply_style_positive(style_name, real_prompt)
|
563 |
+
real_images.append(self.pipe(
|
564 |
+
real_prompt,
|
565 |
+
num_inference_steps=self.num_steps,
|
566 |
+
guidance_scale=guidance_scale,
|
567 |
+
height=self.height,
|
568 |
+
width=self.width,
|
569 |
+
negative_prompt=negative_prompt,
|
570 |
+
generator=generator).images[0]
|
571 |
+
)
|
572 |
+
|
573 |
+
images = id_images + real_images
|
574 |
+
return images
|
575 |
+
|
576 |
+
|
577 |
+
class StoryDiffusionAgent:
|
578 |
+
|
579 |
+
def __init__(self, config, llm_type="qwen2") -> None:
|
580 |
+
self.config = config
|
581 |
+
if llm_type == "qwen2":
|
582 |
+
self.LLM = QwenAgent
|
583 |
+
|
584 |
+
def call(self, pages: List, save_path: str):
|
585 |
+
role_dict = self.extract_role_from_story(pages, **self.config["revise_cfg"])
|
586 |
+
image_prompts = self.generate_image_prompt_from_story(pages, **self.config["revise_cfg"])
|
587 |
+
image_prompts_with_role_desc = []
|
588 |
+
for image_prompt in image_prompts:
|
589 |
+
for role, role_desc in role_dict.items():
|
590 |
+
if role in image_prompt:
|
591 |
+
image_prompt = image_prompt.replace(role, role_desc)
|
592 |
+
image_prompts_with_role_desc.append(image_prompt)
|
593 |
+
generation_agent = StoryDiffusionSynthesizer(
|
594 |
+
num_pages=len(pages),
|
595 |
+
**self.config["obj_cfg"]
|
596 |
+
)
|
597 |
+
images = generation_agent.call(
|
598 |
+
image_prompts_with_role_desc,
|
599 |
+
**self.config["call_cfg"]
|
600 |
+
)
|
601 |
+
for idx, image in enumerate(images):
|
602 |
+
image.save(save_path / f"p{idx + 1}.png")
|
603 |
+
return {
|
604 |
+
"prompts": image_prompts_with_role_desc,
|
605 |
+
"modality": "image",
|
606 |
+
"generation_results": images,
|
607 |
+
}
|
608 |
+
|
609 |
+
def extract_role_from_story(
|
610 |
+
self,
|
611 |
+
pages: List,
|
612 |
+
num_turns: int = 3
|
613 |
+
):
|
614 |
+
role_extractor = self.LLM(role_extract_system, track_history=False)
|
615 |
+
role_reviewer = self.LLM(role_review_system, track_history=False)
|
616 |
+
roles = {}
|
617 |
+
review = ""
|
618 |
+
for turn in range(num_turns):
|
619 |
+
roles, success = role_extractor.run(json.dumps({
|
620 |
+
"story_content": pages,
|
621 |
+
"previous_result": roles,
|
622 |
+
"improvement_suggestions": review,
|
623 |
+
}, ensure_ascii=False
|
624 |
+
))
|
625 |
+
roles = json.loads(roles.strip("```json").strip("```"))
|
626 |
+
review, success = role_reviewer.run(json.dumps({
|
627 |
+
"story_content": pages,
|
628 |
+
"role_descriptions": roles
|
629 |
+
}, ensure_ascii=False))
|
630 |
+
if review == "Check passed.":
|
631 |
+
break
|
632 |
+
return roles
|
633 |
+
|
634 |
+
def generate_image_prompt_from_story(
|
635 |
+
self,
|
636 |
+
pages: List,
|
637 |
+
num_turns: int = 3
|
638 |
+
):
|
639 |
+
image_prompt_rewriter = self.LLM(story_to_image_reviser_system, track_history=False)
|
640 |
+
image_prompt_reviewer = self.LLM(story_to_image_review_system, track_history=False)
|
641 |
+
image_prompts = []
|
642 |
+
|
643 |
+
for page in pages:
|
644 |
+
review = ""
|
645 |
+
image_prompt = ""
|
646 |
+
for turn in range(num_turns):
|
647 |
+
image_prompt, success = image_prompt_rewriter.run(json.dumps({
|
648 |
+
"all_pages": pages,
|
649 |
+
"current_page": page,
|
650 |
+
"previous_result": image_prompt,
|
651 |
+
"improvement_suggestions": review,
|
652 |
+
}, ensure_ascii=False))
|
653 |
+
if image_prompt.startswith("Image description:"):
|
654 |
+
image_prompt = image_prompt[len("Image description:"):]
|
655 |
+
review, success = image_prompt_reviewer.run(json.dumps({
|
656 |
+
"all_pages": pages,
|
657 |
+
"current_page": page,
|
658 |
+
"image_description": image_prompt
|
659 |
+
}, ensure_ascii=False))
|
660 |
+
if review == "Check passed.":
|
661 |
+
break
|
662 |
+
image_prompts.append(image_prompt)
|
663 |
+
return image_prompts
|
mm_story_agent/modality_agents/llm.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Callable
|
2 |
+
import os
|
3 |
+
|
4 |
+
from dashscope import Generation
|
5 |
+
|
6 |
+
|
7 |
+
class QwenAgent(object):
|
8 |
+
|
9 |
+
def __init__(self,
|
10 |
+
system_prompt: str = None,
|
11 |
+
track_history: bool = True):
|
12 |
+
self.system_prompt = system_prompt
|
13 |
+
if system_prompt is None:
|
14 |
+
self.history = []
|
15 |
+
else:
|
16 |
+
self.history = [
|
17 |
+
{"role": "system", "content": system_prompt}
|
18 |
+
]
|
19 |
+
self.track_history = track_history
|
20 |
+
|
21 |
+
def basic_success_check(self, response):
|
22 |
+
if not response or not response.output or not response.output.text:
|
23 |
+
print(response)
|
24 |
+
return False
|
25 |
+
else:
|
26 |
+
return True
|
27 |
+
|
28 |
+
def run(self,
|
29 |
+
prompt: str,
|
30 |
+
top_p: float = 0.95,
|
31 |
+
temperature: float = 1.0,
|
32 |
+
seed: int = 1,
|
33 |
+
max_length: int = 1024,
|
34 |
+
max_try: int = 5,
|
35 |
+
success_check_fn: Callable = None
|
36 |
+
):
|
37 |
+
self.history.append({
|
38 |
+
"role": "user",
|
39 |
+
"content": prompt
|
40 |
+
})
|
41 |
+
success = False
|
42 |
+
try_times = 0
|
43 |
+
while try_times < max_try:
|
44 |
+
response = Generation.call(
|
45 |
+
model="qwen2-72b-instruct",
|
46 |
+
messages=self.history,
|
47 |
+
top_p=top_p,
|
48 |
+
temperature=temperature,
|
49 |
+
api_key=os.environ.get('DASHSCOPE_API_KEY'),
|
50 |
+
seed=seed,
|
51 |
+
max_length=max_length
|
52 |
+
)
|
53 |
+
if success_check_fn is None:
|
54 |
+
success_check_fn = lambda x: True
|
55 |
+
if self.basic_success_check(response) and success_check_fn(response.output.text):
|
56 |
+
response = response.output.text
|
57 |
+
self.history.append({
|
58 |
+
"role": "assistant",
|
59 |
+
"content": response
|
60 |
+
})
|
61 |
+
success = True
|
62 |
+
break
|
63 |
+
else:
|
64 |
+
try_times += 1
|
65 |
+
|
66 |
+
if not self.track_history:
|
67 |
+
if self.system_prompt is not None:
|
68 |
+
self.history = self.history[:1]
|
69 |
+
else:
|
70 |
+
self.history = []
|
71 |
+
|
72 |
+
return response, success
|
73 |
+
|
mm_story_agent/modality_agents/music_agent.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
import json
|
3 |
+
from typing import List, Union
|
4 |
+
|
5 |
+
import torchaudio
|
6 |
+
from audiocraft.models import MusicGen
|
7 |
+
from audiocraft.data.audio import audio_write
|
8 |
+
|
9 |
+
from mm_story_agent.modality_agents.llm import QwenAgent
|
10 |
+
from mm_story_agent.prompts_en import story_to_music_reviser_system, story_to_music_reviewer_system
|
11 |
+
|
12 |
+
|
13 |
+
class MusicGenSynthesizer:
|
14 |
+
|
15 |
+
def __init__(self,
|
16 |
+
model_name: str = 'facebook/musicgen-medium',
|
17 |
+
sample_rate: int = 16000,
|
18 |
+
) -> None:
|
19 |
+
self.model = MusicGen.get_pretrained(model_name)
|
20 |
+
self.sample_rate = sample_rate
|
21 |
+
|
22 |
+
def call(self,
|
23 |
+
prompt: Union[str, List[str]],
|
24 |
+
save_path: Union[str, Path],
|
25 |
+
duration: float = 60.0,
|
26 |
+
):
|
27 |
+
self.model.set_generation_params(duration=duration)
|
28 |
+
wav = self.model.generate([prompt], progress=True)[0].cpu()
|
29 |
+
wav = torchaudio.functional.resample(wav, self.model.sample_rate, self.sample_rate)
|
30 |
+
save_path = Path(save_path).parent / Path(save_path).stem
|
31 |
+
audio_write(save_path, wav, self.sample_rate)
|
32 |
+
|
33 |
+
|
34 |
+
class MusicGenAgent:
|
35 |
+
|
36 |
+
def __init__(self, config, llm_type="qwen2") -> None:
|
37 |
+
self.config = config
|
38 |
+
if llm_type == "qwen2":
|
39 |
+
self.LLM = QwenAgent
|
40 |
+
|
41 |
+
def generate_music_prompt_from_story(
|
42 |
+
self,
|
43 |
+
pages: List,
|
44 |
+
num_turns: int = 3
|
45 |
+
):
|
46 |
+
music_prompt_reviser = self.LLM(story_to_music_reviser_system, track_history=False)
|
47 |
+
music_prompt_reviewer = self.LLM(story_to_music_reviewer_system, track_history=False)
|
48 |
+
|
49 |
+
music_prompt = ""
|
50 |
+
review = ""
|
51 |
+
for turn in range(num_turns):
|
52 |
+
music_prompt, success = music_prompt_reviser.run(json.dumps({
|
53 |
+
"story": pages,
|
54 |
+
"previous_result": music_prompt,
|
55 |
+
"improvement_suggestions": review,
|
56 |
+
}, ensure_ascii=False))
|
57 |
+
review, success = music_prompt_reviewer.run(json.dumps({
|
58 |
+
"story_content": pages,
|
59 |
+
"music_description": music_prompt
|
60 |
+
}, ensure_ascii=False))
|
61 |
+
if review == "Check passed.":
|
62 |
+
break
|
63 |
+
|
64 |
+
return music_prompt
|
65 |
+
|
66 |
+
def call(self, pages: List, save_path: str):
|
67 |
+
save_path = Path(save_path)
|
68 |
+
music_prompt = self.generate_music_prompt_from_story(pages, **self.config["revise_cfg"])
|
69 |
+
generation_agent = MusicGenSynthesizer()
|
70 |
+
generation_agent.call(
|
71 |
+
prompt=music_prompt,
|
72 |
+
save_path=save_path / "music.wav",
|
73 |
+
**self.config["call_cfg"]
|
74 |
+
)
|
75 |
+
return {
|
76 |
+
"prompt": music_prompt,
|
77 |
+
"modality": "music"
|
78 |
+
}
|
mm_story_agent/modality_agents/sound_agent.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
from typing import List
|
3 |
+
import json
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import soundfile as sf
|
7 |
+
from diffusers import AudioLDM2Pipeline
|
8 |
+
|
9 |
+
from mm_story_agent.prompts_en import story_to_sound_reviser_system, story_to_sound_review_system
|
10 |
+
from mm_story_agent.modality_agents.llm import QwenAgent
|
11 |
+
|
12 |
+
|
13 |
+
class AudioLDM2Synthesizer:
|
14 |
+
|
15 |
+
def __init__(self,
|
16 |
+
model_path: str = None,
|
17 |
+
) -> None:
|
18 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
19 |
+
self.pipe = AudioLDM2Pipeline.from_pretrained(
|
20 |
+
model_path if model_path is not None else "cvssp/audioldm2",
|
21 |
+
torch_dtype=torch.float16
|
22 |
+
).to(self.device)
|
23 |
+
|
24 |
+
def call(self,
|
25 |
+
prompts: List[str],
|
26 |
+
n_candidate_per_text: int = 3,
|
27 |
+
seed: int = 0,
|
28 |
+
guidance_scale: float = 3.5,
|
29 |
+
ddim_steps: int = 100,
|
30 |
+
):
|
31 |
+
generator = torch.Generator(device=self.device).manual_seed(seed)
|
32 |
+
audios = self.pipe(
|
33 |
+
prompts,
|
34 |
+
num_inference_steps=ddim_steps,
|
35 |
+
audio_length_in_s=10.0,
|
36 |
+
guidance_scale=guidance_scale,
|
37 |
+
generator=generator,
|
38 |
+
num_waveforms_per_prompt=n_candidate_per_text).audios
|
39 |
+
|
40 |
+
audios = audios[::n_candidate_per_text]
|
41 |
+
|
42 |
+
return audios
|
43 |
+
|
44 |
+
|
45 |
+
class AudioLDM2Agent:
|
46 |
+
|
47 |
+
def __init__(self, config, llm_type="qwen2") -> None:
|
48 |
+
self.config = config
|
49 |
+
if llm_type == "qwen2":
|
50 |
+
self.LLM = QwenAgent
|
51 |
+
|
52 |
+
def call(self, pages: List, save_path: str):
|
53 |
+
sound_prompts = self.generate_sound_prompt_from_story(pages, **self.config["revise_cfg"])
|
54 |
+
save_paths = []
|
55 |
+
forward_prompts = []
|
56 |
+
save_path = Path(save_path)
|
57 |
+
for idx in range(len(pages)):
|
58 |
+
if sound_prompts[idx] != "No sounds.":
|
59 |
+
save_paths.append(save_path / f"p{idx + 1}.wav")
|
60 |
+
forward_prompts.append(sound_prompts[idx])
|
61 |
+
|
62 |
+
generation_agent = AudioLDM2Synthesizer()
|
63 |
+
if len(forward_prompts) > 0:
|
64 |
+
sounds = generation_agent.call(
|
65 |
+
forward_prompts,
|
66 |
+
**self.config["call_cfg"]
|
67 |
+
)
|
68 |
+
for sound, path in zip(sounds, save_paths):
|
69 |
+
sf.write(path.__str__(), sound, self.config["sample_rate"])
|
70 |
+
return {
|
71 |
+
"prompts": sound_prompts,
|
72 |
+
"modality": "sound"
|
73 |
+
}
|
74 |
+
|
75 |
+
def generate_sound_prompt_from_story(
|
76 |
+
self,
|
77 |
+
pages: List,
|
78 |
+
num_turns: int = 3
|
79 |
+
):
|
80 |
+
sound_prompt_reviser = self.LLM(story_to_sound_reviser_system, track_history=False)
|
81 |
+
sound_prompt_reviewer = self.LLM(story_to_sound_review_system, track_history=False)
|
82 |
+
|
83 |
+
sound_prompts = []
|
84 |
+
for page in pages:
|
85 |
+
review = ""
|
86 |
+
sound_prompt = ""
|
87 |
+
for turn in range(num_turns):
|
88 |
+
sound_prompt, success = sound_prompt_reviser.run(json.dumps({
|
89 |
+
"story": page,
|
90 |
+
"previous_result": sound_prompt,
|
91 |
+
"improvement_suggestions": review,
|
92 |
+
}, ensure_ascii=False))
|
93 |
+
if sound_prompt.startswith("Sound description:"):
|
94 |
+
sound_prompt = sound_prompt[len("Sound description:"):]
|
95 |
+
review, success = sound_prompt_reviewer.run(json.dumps({
|
96 |
+
"story": page,
|
97 |
+
"sound_description": sound_prompt
|
98 |
+
}, ensure_ascii=False))
|
99 |
+
if review == "Check passed.":
|
100 |
+
break
|
101 |
+
# else:
|
102 |
+
# print(review)
|
103 |
+
sound_prompts.append(sound_prompt)
|
104 |
+
|
105 |
+
return sound_prompts
|
106 |
+
|
mm_story_agent/modality_agents/speech_agent.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
from pathlib import Path
|
4 |
+
from typing import List
|
5 |
+
|
6 |
+
from aliyunsdkcore.client import AcsClient
|
7 |
+
from aliyunsdkcore.request import CommonRequest
|
8 |
+
import nls
|
9 |
+
|
10 |
+
|
11 |
+
class CosyVoiceSynthesizer:
|
12 |
+
|
13 |
+
def __init__(self) -> None:
|
14 |
+
self.access_key_id = os.environ.get('ALIYUN_ACCESS_KEY_ID')
|
15 |
+
self.access_key_secret = os.environ.get('ALIYUN_ACCESS_KEY_SECRET')
|
16 |
+
self.app_key = os.environ.get('ALIYUN_APP_KEY')
|
17 |
+
self.setup_token()
|
18 |
+
|
19 |
+
def setup_token(self):
|
20 |
+
client = AcsClient(self.access_key_id, self.access_key_secret,
|
21 |
+
'cn-shanghai')
|
22 |
+
request = CommonRequest()
|
23 |
+
request.set_method('POST')
|
24 |
+
request.set_domain('nls-meta.cn-shanghai.aliyuncs.com')
|
25 |
+
request.set_version('2019-02-28')
|
26 |
+
request.set_action_name('CreateToken')
|
27 |
+
|
28 |
+
try:
|
29 |
+
response = client.do_action_with_exception(request)
|
30 |
+
jss = json.loads(response)
|
31 |
+
if 'Token' in jss and 'Id' in jss['Token']:
|
32 |
+
token = jss['Token']['Id']
|
33 |
+
self.token = token
|
34 |
+
except Exception as e:
|
35 |
+
import traceback
|
36 |
+
raise RuntimeError(
|
37 |
+
f'Request token failed with error: {e}, with detail {traceback.format_exc()}'
|
38 |
+
)
|
39 |
+
|
40 |
+
def call(self, save_file, transcript, voice="longyuan", sample_rate=16000):
|
41 |
+
writer = open(save_file, "wb")
|
42 |
+
return_data = b''
|
43 |
+
|
44 |
+
def write_data(data, *args):
|
45 |
+
nonlocal return_data
|
46 |
+
return_data += data
|
47 |
+
if writer is not None:
|
48 |
+
writer.write(data)
|
49 |
+
|
50 |
+
def raise_error(error, *args):
|
51 |
+
raise RuntimeError(
|
52 |
+
f'Synthesizing speech failed with error: {error}')
|
53 |
+
|
54 |
+
def close_file(*args):
|
55 |
+
if writer is not None:
|
56 |
+
writer.close()
|
57 |
+
|
58 |
+
sdk = nls.NlsStreamInputTtsSynthesizer(
|
59 |
+
url='wss://nls-gateway-cn-beijing.aliyuncs.com/ws/v1',
|
60 |
+
token=self.token,
|
61 |
+
appkey=self.app_key,
|
62 |
+
on_data=write_data,
|
63 |
+
on_error=raise_error,
|
64 |
+
on_close=close_file,
|
65 |
+
)
|
66 |
+
|
67 |
+
sdk.startStreamInputTts(voice=voice, sample_rate=sample_rate, aformat='wav')
|
68 |
+
sdk.sendStreamInputTts(transcript,)
|
69 |
+
sdk.stopStreamInputTts()
|
70 |
+
|
71 |
+
|
72 |
+
class CosyVoiceAgent:
|
73 |
+
|
74 |
+
def __init__(self, config) -> None:
|
75 |
+
self.config = config
|
76 |
+
|
77 |
+
def call(self, pages: List, save_path: str):
|
78 |
+
save_path = Path(save_path)
|
79 |
+
generation_agent = CosyVoiceSynthesizer()
|
80 |
+
|
81 |
+
for idx, page in enumerate(pages):
|
82 |
+
generation_agent.call(
|
83 |
+
save_file=save_path / f"p{idx + 1}.wav",
|
84 |
+
transcript=page,
|
85 |
+
**self.config["call_cfg"]
|
86 |
+
)
|
87 |
+
|
88 |
+
return {
|
89 |
+
"modality": "speech"
|
90 |
+
}
|
mm_story_agent/modality_agents/story_agent.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import random
|
3 |
+
|
4 |
+
from tqdm import trange, tqdm
|
5 |
+
|
6 |
+
from mm_story_agent.modality_agents.llm import QwenAgent
|
7 |
+
from mm_story_agent.prompts_en import question_asker_system, expert_system, \
|
8 |
+
dlg_based_writer_system, dlg_based_writer_prompt, chapter_writer_system
|
9 |
+
|
10 |
+
|
11 |
+
def parse_list(output):
|
12 |
+
try:
|
13 |
+
pages = eval(output)
|
14 |
+
return True
|
15 |
+
except Exception:
|
16 |
+
return False
|
17 |
+
|
18 |
+
|
19 |
+
def json_parse_outline(outline):
|
20 |
+
outline = outline.strip("```json").strip("```")
|
21 |
+
try:
|
22 |
+
outline = json.loads(outline)
|
23 |
+
if not isinstance(outline, dict):
|
24 |
+
return False
|
25 |
+
if outline.keys() != {"story_title", "story_outline"}:
|
26 |
+
return False
|
27 |
+
for chapter in outline["story_outline"]:
|
28 |
+
if chapter.keys() != {"chapter_title", "chapter_summary"}:
|
29 |
+
return False
|
30 |
+
except json.decoder.JSONDecodeError:
|
31 |
+
return False
|
32 |
+
return True
|
33 |
+
|
34 |
+
|
35 |
+
class QAOutlineStoryWriter:
|
36 |
+
|
37 |
+
def __init__(self,
|
38 |
+
story_gen_config,
|
39 |
+
llm_type: str = "qwen2"):
|
40 |
+
if llm_type == "qwen2":
|
41 |
+
self.LLM = QwenAgent
|
42 |
+
self.story_gen_config = story_gen_config
|
43 |
+
|
44 |
+
def generate_outline(self, story_setting):
|
45 |
+
temperature = self.story_gen_config["temperature"]
|
46 |
+
max_conv_turns = self.story_gen_config["max_conv_turns"]
|
47 |
+
num_outline = self.story_gen_config["num_outline"]
|
48 |
+
asker = self.LLM(question_asker_system, track_history=False)
|
49 |
+
expert = self.LLM(expert_system, track_history=False)
|
50 |
+
|
51 |
+
dialogue = []
|
52 |
+
for turn in trange(max_conv_turns):
|
53 |
+
dialogue_history = "\n".join(dialogue)
|
54 |
+
|
55 |
+
question, success = asker.run(f"Story setting: {story_setting}\nDialogue history: \n{dialogue_history}\n", temperature=temperature)
|
56 |
+
question = question.strip()
|
57 |
+
if question == "Thank you for your help!":
|
58 |
+
break
|
59 |
+
dialogue.append(f"You: {question}")
|
60 |
+
answer, success = expert.run(f"Story setting: {story_setting}\nQuestion: \n{question}\nAnswer: ", temperature=temperature)
|
61 |
+
answer = answer.strip()
|
62 |
+
dialogue.append(f"Expert: {answer}")
|
63 |
+
|
64 |
+
# print("\n".join(dialogue))
|
65 |
+
writer = self.LLM(dlg_based_writer_system, track_history=False)
|
66 |
+
writer_prompt = dlg_based_writer_prompt.format(
|
67 |
+
story_setting=story_setting,
|
68 |
+
dialogue_history="\n".join(dialogue),
|
69 |
+
num_outline=num_outline
|
70 |
+
)
|
71 |
+
|
72 |
+
outline, success = writer.run(writer_prompt, success_check_fn=json_parse_outline)
|
73 |
+
outline = json.loads(outline)
|
74 |
+
# print(outline)
|
75 |
+
return outline
|
76 |
+
|
77 |
+
def generate_story_from_outline(self, outline):
|
78 |
+
temperature = self.story_gen_config["temperature"]
|
79 |
+
chapter_writer = self.LLM(chapter_writer_system, track_history=False)
|
80 |
+
all_pages = []
|
81 |
+
for idx, chapter in enumerate(tqdm(outline["story_outline"])):
|
82 |
+
chapter_detail, success = chapter_writer.run(
|
83 |
+
json.dumps(
|
84 |
+
{
|
85 |
+
"completed_story": all_pages,
|
86 |
+
"current_chapter": chapter
|
87 |
+
},
|
88 |
+
ensure_ascii=False
|
89 |
+
),
|
90 |
+
success_check_fn=parse_list,
|
91 |
+
temperature=temperature
|
92 |
+
)
|
93 |
+
while success is False:
|
94 |
+
chapter_detail, success = chapter_writer.run(
|
95 |
+
json.dumps(
|
96 |
+
{
|
97 |
+
"completed_story": all_pages,
|
98 |
+
"current_chapter": chapter
|
99 |
+
},
|
100 |
+
ensure_ascii=False
|
101 |
+
),
|
102 |
+
seed=random.randint(0, 100000),
|
103 |
+
temperature=temperature,
|
104 |
+
success_check_fn=parse_list
|
105 |
+
)
|
106 |
+
pages = [page.strip() for page in eval(chapter_detail)]
|
107 |
+
all_pages.extend(pages)
|
108 |
+
# print(all_pages)
|
109 |
+
return all_pages
|
110 |
+
|
111 |
+
def call(self, story_setting):
|
112 |
+
outline = self.generate_outline(story_setting)
|
113 |
+
pages = self.generate_story_from_outline(outline)
|
114 |
+
return pages
|
mm_story_agent/prompts_en.py
ADDED
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
instruction = """
|
3 |
+
1. Conciseness: Describe the plot of each chapter in a simple and straightforward manner, using a storybook tone without excessive details.
|
4 |
+
2. Narrative Style: There is no need for dialogue or interaction with the reader.
|
5 |
+
3. Coherent Plot: The story should have a coherent plot, with connections and reflections throughout. All chapters should contribute to the same overarching story, rather than being independent little tales.
|
6 |
+
4. Reasonableness: The plot should make sense, avoiding logical errors and unreasonable elements.
|
7 |
+
5. Educational Value: A good bedtime story should have educational significance, helping children learn proper values and behaviors.
|
8 |
+
6. Warm and Pleasant: The story should evoke a sense of ease, warmth, and joy, making children feel loved and cared for.
|
9 |
+
""".strip()
|
10 |
+
|
11 |
+
|
12 |
+
question_asker_system = """
|
13 |
+
## Basic requirements for children stories:
|
14 |
+
1. Storytelling Style: No need for dialogue or interaction with the reader.
|
15 |
+
2. Coherent Plot: The story plot should be coherent and consistent throughout.
|
16 |
+
3. Logical Consistency: The plot must be logical, without any logical errors or unreasonable elements.
|
17 |
+
4. Educational Significance: An excellent bedtime story should convey certain educational values, helping children learn proper values and behaviors.
|
18 |
+
5. Warm and Pleasant: The story should ideally evoke a feeling of lightness, warmth, and happiness, making children feel loved and cared for.
|
19 |
+
|
20 |
+
## Story setting format
|
21 |
+
The story setting is given as a JSON object, such as:
|
22 |
+
{
|
23 |
+
"story_topic": "xxx",
|
24 |
+
"main_role": "xxx",
|
25 |
+
"scene": "xxx",
|
26 |
+
...
|
27 |
+
}
|
28 |
+
|
29 |
+
You are a student learning to write children stories, discussing writing ideas with an expert.
|
30 |
+
Please ask the expert questions to discuss the information needed for writing a story following the given setting.
|
31 |
+
If you have no more questions, say "Thank you for your help!" to end the conversation.
|
32 |
+
Ask only one question at a time and avoid repeating previously asked questions. Your questions should relate to the given setting, such as the story topic.
|
33 |
+
""".strip()
|
34 |
+
|
35 |
+
|
36 |
+
expert_system = """
|
37 |
+
## Basic requirements for children stories:
|
38 |
+
1. Storytelling Style: No need for dialogue or interaction with the reader.
|
39 |
+
2. Coherent Plot: The story plot should be coherent and consistent throughout.
|
40 |
+
3. Logical Consistency: The plot must be logical, without any logical errors or unreasonable elements.
|
41 |
+
4. Educational Significance: An excellent bedtime story should convey certain educational values, helping children learn proper values and behaviors.
|
42 |
+
5. Warm and Pleasant: The story should ideally evoke a feeling of lightness, warmth, and happiness, making children feel loved and cared for.
|
43 |
+
|
44 |
+
## Story setting format
|
45 |
+
The story setting is given as a JSON object, such as:
|
46 |
+
{
|
47 |
+
"story_topic": "xxx",
|
48 |
+
"main_role": "xxx",
|
49 |
+
"scene": "xxx",
|
50 |
+
...
|
51 |
+
}
|
52 |
+
|
53 |
+
You are an expert in children story writing. You are discussing creative ideas with a student learning to write children stories. Please provide meaningful responses to the student's questions.
|
54 |
+
""".strip()
|
55 |
+
|
56 |
+
|
57 |
+
dlg_based_writer_system = """
|
58 |
+
Based on a dialogue, write an outline for a children storybook. This dialogue provides some points and ideas for writing the outline.
|
59 |
+
When writing the outline, basic requirements should be met:
|
60 |
+
{instruction}
|
61 |
+
|
62 |
+
## Output Format
|
63 |
+
Output a valid JSON object, following the format:
|
64 |
+
{{
|
65 |
+
"story_title": "xxx",
|
66 |
+
"story_outline": [{{"chapter_title":"xxx", "chapter_summary": "xxx"}}, {{"chapter_title":"xxx", "chapter_summary": "xxx"}}],
|
67 |
+
}}
|
68 |
+
""".strip().format(instruction=instruction)
|
69 |
+
|
70 |
+
dlg_based_writer_prompt = """
|
71 |
+
Story setting: {story_setting}
|
72 |
+
Dialogue history:
|
73 |
+
{dialogue_history}
|
74 |
+
Write a story outline with {num_outline} chapters.
|
75 |
+
""".strip()
|
76 |
+
|
77 |
+
|
78 |
+
chapter_writer_system = """
|
79 |
+
Based on the story outline, expand the given chapter summary into detailed story content.
|
80 |
+
|
81 |
+
## Input Content
|
82 |
+
The input consists of already written story content and the current chapter that needs to be expanded, in the following format:
|
83 |
+
{
|
84 |
+
"completed_story": ["xxx", "xxx"] // each element represents a page of story content.
|
85 |
+
"current_chapter": {"chapter_title": "xxx", "chapter_summary": "xxx"}
|
86 |
+
}
|
87 |
+
|
88 |
+
## Output Content
|
89 |
+
Output the expanded story content for the current chapter. The result should be a list where each element corresponds to the plot of one page of the storybook.
|
90 |
+
|
91 |
+
## Notes
|
92 |
+
1. Only expand the current chapter; do not overwrite content from other chapters.
|
93 |
+
2. The expanded content should not be too lengthy, with a maximum of 3 pages and no more than 2 sentences per page.
|
94 |
+
3. Maintain the tone of the story; do not add extra annotations, explanations, settings, or comments.
|
95 |
+
4. If the story is already complete, no further writing is necessary.
|
96 |
+
""".strip()
|
97 |
+
|
98 |
+
|
99 |
+
role_extract_system = """
|
100 |
+
Extract all main role names from the given story content and generate corresponding role descriptions. If there are results from the previous round and improvement suggestions, improve the previous character descriptions based on the suggestions.
|
101 |
+
|
102 |
+
## Steps
|
103 |
+
1. First, identify the main role's name in the story.
|
104 |
+
2. Then, identify other frequently occurring roles.
|
105 |
+
3. Generate descriptions for these roles. Ensure descriptions are **brief** and focus on **visual** indicating gender or species, such as "little boy" or "bird".
|
106 |
+
4. Ensure that descriptions do not exceed 20 words.
|
107 |
+
|
108 |
+
|
109 |
+
## Input Format
|
110 |
+
The input consists of the story content and possibly the previous output results with corresponding improvement suggestions, formatted as:
|
111 |
+
{
|
112 |
+
"story_content": "xxx",
|
113 |
+
"previous_result": {
|
114 |
+
"(role 1's name)": "xxx",
|
115 |
+
"(role 2's name)": "xxx"
|
116 |
+
}, // Empty indicates the first round
|
117 |
+
"improvement_suggestions": "xxx" // Empty indicates the first round
|
118 |
+
}
|
119 |
+
|
120 |
+
## Output Format
|
121 |
+
Output the character names and descriptions following this format:
|
122 |
+
{
|
123 |
+
"(role 1's name)": "xxx",
|
124 |
+
"(role 2's name)": "xxx"
|
125 |
+
}
|
126 |
+
Strictly follow the above steps and directly output the results without any additional content.
|
127 |
+
""".strip()
|
128 |
+
|
129 |
+
|
130 |
+
role_review_system = """
|
131 |
+
Review the role descriptions corresponding to the given story. If requirements are met, output "Check passed.". If not, provide improvement suggestions.
|
132 |
+
|
133 |
+
## Requirements for Role Descriptions
|
134 |
+
1. Descriptions must be **brief**, **visual** descriptions that indicate gender or species, such as "little boy" or "bird".
|
135 |
+
2. Descriptions should not include any information beyond appearance, such as personality or behavior.
|
136 |
+
3. The description of each role must not exceed 20 words.
|
137 |
+
|
138 |
+
## Input Format
|
139 |
+
The input consists of the story content and role extraction results, with a format of:
|
140 |
+
{
|
141 |
+
"story_content": "xxx",
|
142 |
+
"role_descriptions": {
|
143 |
+
"(Character 1's Name)": "xxx",
|
144 |
+
"(Character 2's Name)": "xxx"
|
145 |
+
}
|
146 |
+
}
|
147 |
+
|
148 |
+
## Output Format
|
149 |
+
Directly output improvement suggestions without any additional content if requirements are not met. Otherwise, output "Check passed."
|
150 |
+
""".strip()
|
151 |
+
|
152 |
+
|
153 |
+
story_to_image_reviser_system = """
|
154 |
+
Convert the given story content into image description. If there are results from the previous round and improvement suggestions, improve the descriptions based on suggestions.
|
155 |
+
|
156 |
+
## Input Format
|
157 |
+
The input consists of all story pages, the current page, and possibly the previous output results with corresponding improvement suggestions, formatted as:
|
158 |
+
{
|
159 |
+
"all_pages": ["xxx", "xxx"], // Each element is a page of story content
|
160 |
+
"current_page": "xxx",
|
161 |
+
"previous_result": "xxx", // If empty, indicates the first round
|
162 |
+
"improvement_suggestions": "xxx" // If empty, indicates the first round
|
163 |
+
}
|
164 |
+
|
165 |
+
## Output Format
|
166 |
+
Output a string describing the image corresponding to the current story content without any additional content.
|
167 |
+
|
168 |
+
## Notes
|
169 |
+
1. Keep it concise. Focus on the main visual elements, omit details.
|
170 |
+
2. Retain visual elements. Only describe static scenes, avoid the plot details.
|
171 |
+
3. Remove non-visual elements. Typical non-visual elements include dialogue, thoughts, and plot.
|
172 |
+
4. Retain role names.
|
173 |
+
""".strip()
|
174 |
+
|
175 |
+
story_to_image_review_system = """
|
176 |
+
Review the image description corresponding to the given story content. If the requirements are met, output "Check passed.". If not, provide improvement suggestions.
|
177 |
+
|
178 |
+
## Requirements for Image Descriptions
|
179 |
+
1. Keep it concise. Focus on the main visual elements, omit details.
|
180 |
+
2. Retain visual elements. Only describe static scenes, avoid the plot details.
|
181 |
+
3. Remove non-visual elements. Typical non-visual elements include dialogue, thoughts, and plot.
|
182 |
+
4. Retain role names.
|
183 |
+
|
184 |
+
## Input Format
|
185 |
+
The input consists of all story content, the current story content, and the corresponding image description, structured as:
|
186 |
+
{
|
187 |
+
"all_pages": ["xxx", "xxx"],
|
188 |
+
"current_page": "xxx",
|
189 |
+
"image_description": "xxx"
|
190 |
+
}
|
191 |
+
|
192 |
+
## Output Format
|
193 |
+
Directly output improvement suggestions without any additional content if requirements are not met. Otherwise, output "Check passed."
|
194 |
+
""".strip()
|
195 |
+
|
196 |
+
story_to_sound_reviser_system = """
|
197 |
+
Extract possible sound effects from the given story content. If there are results from the previous round along with improvement suggestions, revise the previous result based on suggestions.
|
198 |
+
|
199 |
+
## Input Format
|
200 |
+
The input consists of the story content, and may also include the previous result and corresponding improvement suggestions, formatted as:
|
201 |
+
{
|
202 |
+
"story": "xxx",
|
203 |
+
"previous_result": "xxx", // empty indicates the first round
|
204 |
+
"improvement_suggestions": "xxx" // empty indicates the first round
|
205 |
+
}
|
206 |
+
|
207 |
+
## Output Format
|
208 |
+
Output a string describing the sound effects without any additional content.
|
209 |
+
|
210 |
+
## Notes
|
211 |
+
1. The description must be sounds. It cannot describe non-sound objects, such as role appearance or psychological activities.
|
212 |
+
2. The number of sound effects must not exceed 3.
|
213 |
+
3. Exclude speech.
|
214 |
+
4. Exclude musical and instrumental sounds, such as background music.
|
215 |
+
5. Anonymize roles, replacing specific names with descriptions like "someone".
|
216 |
+
6. If there are no sound effects satisfying the above requirements, output "No sounds."
|
217 |
+
""".strip()
|
218 |
+
|
219 |
+
story_to_sound_review_system = """
|
220 |
+
Review sound effects corresponding to the given story content. If the requirements are met, output "Check passed.". If not, provide improvement suggestions.
|
221 |
+
|
222 |
+
## Requirements for Sound Descriptions
|
223 |
+
1. The description must be sounds. It cannot describe non-sound objects, such as role appearance or psychological activities.
|
224 |
+
2. The number of sounds must not exceed 3.
|
225 |
+
3. No speech should be included.
|
226 |
+
4. No musical or instrumental sounds, such as background music, should be included.
|
227 |
+
5. Roles must be anonymized. Role names should be replaced by descriptions like "someone".
|
228 |
+
6. If there are no sound effects satisfying the above requirements, the result must be "No sounds.".
|
229 |
+
|
230 |
+
## Input Format
|
231 |
+
The input consists of the story content and the corresponding sound description, formatted as:
|
232 |
+
{
|
233 |
+
"story": "xxx",
|
234 |
+
"sound_description": "xxx"
|
235 |
+
}
|
236 |
+
|
237 |
+
## Output Format
|
238 |
+
Directly output improvement suggestions without any additional content if requirements are not met. Otherwise, output "Check passed."
|
239 |
+
""".strip()
|
240 |
+
|
241 |
+
story_to_music_reviser_system = """
|
242 |
+
Generate suitable background music descriptions based on the story content. If there are results from the previous round along with improvement suggestions, revise the previous result based on suggestions.
|
243 |
+
|
244 |
+
## Input Format
|
245 |
+
The input consists of the story content, and may also include the previous result and corresponding improvement suggestions, formatted as:
|
246 |
+
{
|
247 |
+
"story": ["xxx", "xxx"], // Each element is a page of story content
|
248 |
+
"previous_result": "xxx", // empty indicates the first round
|
249 |
+
"improvement_suggestions": "xxx" // empty indicates the first round
|
250 |
+
}
|
251 |
+
|
252 |
+
## Output Format
|
253 |
+
Output a string describing the background music without any additional content.
|
254 |
+
|
255 |
+
## Notes
|
256 |
+
1. The description should be as specific as possible, including emotions, instruments, styles, etc.
|
257 |
+
2. Do not include specific role names.
|
258 |
+
""".strip()
|
259 |
+
|
260 |
+
|
261 |
+
story_to_music_reviewer_system = """
|
262 |
+
Review the background music description corresponding to the story content to check whether the description is suitable. If suitable, output "Check passed.". If not, provide improvement suggestions.
|
263 |
+
|
264 |
+
## Requirements for Background Music Descriptions
|
265 |
+
1. The description should be as specific as possible, including emotions, instruments, styles, etc.
|
266 |
+
2. Do not include specific role names.
|
267 |
+
|
268 |
+
## Input Format
|
269 |
+
The input consists of the story content and the corresponding music description, structured as:
|
270 |
+
{
|
271 |
+
"story": ["xxx", "xxx"], // Each element is a page of story content
|
272 |
+
"music_description": "xxx"
|
273 |
+
}
|
274 |
+
|
275 |
+
## Output Format
|
276 |
+
Directly output improvement suggestions without any additional content if requirements are not met. Otherwise, output "Check passed.".
|
277 |
+
""".strip()
|
mm_story_agent/video_compose_agent.py
ADDED
@@ -0,0 +1,412 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
from typing import List, Union
|
3 |
+
import random
|
4 |
+
import re
|
5 |
+
from datetime import timedelta
|
6 |
+
|
7 |
+
from tqdm import trange
|
8 |
+
import numpy as np
|
9 |
+
import librosa
|
10 |
+
import cv2
|
11 |
+
from zhon.hanzi import punctuation as zh_punc
|
12 |
+
|
13 |
+
from moviepy.editor import ImageClip, AudioFileClip, CompositeAudioClip, \
|
14 |
+
CompositeVideoClip, ColorClip, VideoFileClip, VideoClip, TextClip, concatenate_audioclips
|
15 |
+
import moviepy.video.compositing.transitions as transfx
|
16 |
+
from moviepy.audio.AudioClip import AudioArrayClip
|
17 |
+
from moviepy.audio.fx.all import audio_loop
|
18 |
+
from moviepy.video.tools.subtitles import SubtitlesClip
|
19 |
+
|
20 |
+
|
21 |
+
def generate_srt(timestamps: List,
|
22 |
+
captions: List,
|
23 |
+
save_path: Union[str, Path],
|
24 |
+
max_single_length: int = 30):
|
25 |
+
|
26 |
+
def format_time(seconds: float) -> str:
|
27 |
+
td = timedelta(seconds=seconds)
|
28 |
+
total_seconds = int(td.total_seconds())
|
29 |
+
millis = int((td.total_seconds() - total_seconds) * 1000)
|
30 |
+
hours, remainder = divmod(total_seconds, 3600)
|
31 |
+
minutes, seconds = divmod(remainder, 60)
|
32 |
+
return f"{hours:02}:{minutes:02}:{seconds:02},{millis:03}"
|
33 |
+
|
34 |
+
srt_content = []
|
35 |
+
num_caps = len(timestamps)
|
36 |
+
|
37 |
+
for idx in range(num_caps):
|
38 |
+
start_time, end_time = timestamps[idx]
|
39 |
+
caption_chunks = split_caption(captions[idx], max_single_length).split("\n")
|
40 |
+
num_chunks = len(caption_chunks)
|
41 |
+
|
42 |
+
if num_chunks == 0:
|
43 |
+
continue
|
44 |
+
|
45 |
+
segment_duration = (end_time - start_time) / num_chunks
|
46 |
+
|
47 |
+
for chunk_idx, chunk in enumerate(caption_chunks):
|
48 |
+
chunk_start_time = start_time + segment_duration * chunk_idx
|
49 |
+
chunk_end_time = start_time + segment_duration * (chunk_idx + 1)
|
50 |
+
start_time_str = format_time(chunk_start_time)
|
51 |
+
end_time_str = format_time(chunk_end_time)
|
52 |
+
srt_content.append(f"{len(srt_content) // 2 + 1}\n{start_time_str} --> {end_time_str}\n{chunk}\n\n")
|
53 |
+
|
54 |
+
with open(save_path, 'w') as srt_file:
|
55 |
+
srt_file.writelines(srt_content)
|
56 |
+
|
57 |
+
|
58 |
+
def add_caption(captions: List,
|
59 |
+
srt_path: Union[str, Path],
|
60 |
+
timestamps: List,
|
61 |
+
video_clip: VideoClip,
|
62 |
+
max_single_length: int = 30,
|
63 |
+
**caption_config):
|
64 |
+
generate_srt(timestamps, captions, srt_path, max_single_length)
|
65 |
+
|
66 |
+
generator = lambda txt: TextClip(txt, **caption_config)
|
67 |
+
subtitles = SubtitlesClip(srt_path.__str__(), generator)
|
68 |
+
captioned_clip = CompositeVideoClip([video_clip,
|
69 |
+
subtitles.set_position(("center", "bottom"), relative=True)])
|
70 |
+
return captioned_clip
|
71 |
+
|
72 |
+
|
73 |
+
def split_keep_separator(text, separator):
|
74 |
+
pattern = f'([{re.escape(separator)}])'
|
75 |
+
pieces = re.split(pattern, text)
|
76 |
+
return pieces
|
77 |
+
|
78 |
+
|
79 |
+
def split_caption(caption, max_length=30):
|
80 |
+
lines = []
|
81 |
+
if ord(caption[0]) >= ord("a") and ord(caption[0]) <= ord("z") or ord(caption[0]) >= ord("A") and ord(caption[0]) <= ord("Z"):
|
82 |
+
words = caption.split(" ")
|
83 |
+
current_words = []
|
84 |
+
for word in words:
|
85 |
+
if len(" ".join(current_words + [word])) <= max_length:
|
86 |
+
current_words += [word]
|
87 |
+
else:
|
88 |
+
if current_words:
|
89 |
+
lines.append(" ".join(current_words))
|
90 |
+
current_words = []
|
91 |
+
|
92 |
+
if current_words:
|
93 |
+
lines.append(" ".join(current_words))
|
94 |
+
else:
|
95 |
+
sentences = split_keep_separator(caption, zh_punc)
|
96 |
+
current_line = ""
|
97 |
+
for sentence in sentences:
|
98 |
+
if len(current_line + sentence) <= max_length:
|
99 |
+
current_line += sentence
|
100 |
+
else:
|
101 |
+
if current_line:
|
102 |
+
lines.append(current_line)
|
103 |
+
current_line = ""
|
104 |
+
if sentence.startswith(tuple(zh_punc)):
|
105 |
+
if lines:
|
106 |
+
lines[-1] += sentence[0]
|
107 |
+
current_line = sentence[1:]
|
108 |
+
else:
|
109 |
+
current_line = sentence
|
110 |
+
|
111 |
+
if current_line:
|
112 |
+
lines.append(current_line.strip())
|
113 |
+
|
114 |
+
return '\n'.join(lines)
|
115 |
+
|
116 |
+
|
117 |
+
def add_bottom_black_area(clip: VideoFileClip,
|
118 |
+
black_area_height: int = 64):
|
119 |
+
"""
|
120 |
+
Add a black area at the bottom of the video clip (for captions).
|
121 |
+
|
122 |
+
Args:
|
123 |
+
clip (VideoFileClip): Video clip to be processed.
|
124 |
+
black_area_height (int): Height of the black area.
|
125 |
+
|
126 |
+
Returns:
|
127 |
+
VideoFileClip: Processed video clip.
|
128 |
+
"""
|
129 |
+
black_bar = ColorClip(size=(clip.w, black_area_height), color=(0, 0, 0), duration=clip.duration)
|
130 |
+
extended_clip = CompositeVideoClip([clip, black_bar.set_position(("center", "bottom"))])
|
131 |
+
return extended_clip
|
132 |
+
|
133 |
+
|
134 |
+
def add_zoom_effect(clip, speed=1.0, mode='in', position='center'):
|
135 |
+
fps = clip.fps
|
136 |
+
duration = clip.duration
|
137 |
+
total_frames = int(duration * fps)
|
138 |
+
def main(getframe, t):
|
139 |
+
frame = getframe(t)
|
140 |
+
h, w = frame.shape[: 2]
|
141 |
+
i = t * fps
|
142 |
+
if mode == 'out':
|
143 |
+
i = total_frames - i
|
144 |
+
zoom = 1 + (i * ((0.1 * speed) / total_frames))
|
145 |
+
positions = {'center': [(w - (w * zoom)) / 2, (h - (h * zoom)) / 2],
|
146 |
+
'left': [0, (h - (h * zoom)) / 2],
|
147 |
+
'right': [(w - (w * zoom)), (h - (h * zoom)) / 2],
|
148 |
+
'top': [(w - (w * zoom)) / 2, 0],
|
149 |
+
'topleft': [0, 0],
|
150 |
+
'topright': [(w - (w * zoom)), 0],
|
151 |
+
'bottom': [(w - (w * zoom)) / 2, (h - (h * zoom))],
|
152 |
+
'bottomleft': [0, (h - (h * zoom))],
|
153 |
+
'bottomright': [(w - (w * zoom)), (h - (h * zoom))]}
|
154 |
+
tx, ty = positions[position]
|
155 |
+
M = np.array([[zoom, 0, tx], [0, zoom, ty]])
|
156 |
+
frame = cv2.warpAffine(frame, M, (w, h))
|
157 |
+
return frame
|
158 |
+
return clip.fl(main)
|
159 |
+
|
160 |
+
|
161 |
+
def add_move_effect(clip, direction="left", move_raito=0.95):
|
162 |
+
|
163 |
+
orig_width = clip.size[0]
|
164 |
+
orig_height = clip.size[1]
|
165 |
+
|
166 |
+
new_width = int(orig_width / move_raito)
|
167 |
+
new_height = int(orig_height / move_raito)
|
168 |
+
clip = clip.resize(width=new_width, height=new_height)
|
169 |
+
|
170 |
+
if direction == "left":
|
171 |
+
start_position = (0, 0)
|
172 |
+
end_position = (orig_width - new_width, 0)
|
173 |
+
elif direction == "right":
|
174 |
+
start_position = (orig_width - new_width, 0)
|
175 |
+
end_position = (0, 0)
|
176 |
+
|
177 |
+
duration = clip.duration
|
178 |
+
moving_clip = clip.set_position(
|
179 |
+
lambda t: (start_position[0] + (
|
180 |
+
end_position[0] - start_position[0]) / duration * t, start_position[1])
|
181 |
+
)
|
182 |
+
|
183 |
+
final_clip = CompositeVideoClip([moving_clip], size=(orig_width, orig_height))
|
184 |
+
|
185 |
+
return final_clip
|
186 |
+
|
187 |
+
|
188 |
+
def add_slide_effect(clips, slide_duration):
|
189 |
+
####### CAUTION: requires at least `slide_duration` of silence at the end of each clip #######
|
190 |
+
durations = [clip.duration for clip in clips]
|
191 |
+
first_clip = CompositeVideoClip(
|
192 |
+
[clips[0].fx(transfx.slide_out, duration=slide_duration, side="left")]
|
193 |
+
).set_start(0)
|
194 |
+
|
195 |
+
slide_out_sides = ["left"]
|
196 |
+
videos = [first_clip]
|
197 |
+
|
198 |
+
out_to_in_mapping = {"left": "right", "right": "left"}
|
199 |
+
|
200 |
+
for idx, clip in enumerate(clips[1: -1], start=1):
|
201 |
+
# For all other clips in the middle, we need them to slide in to the previous clip and out for the next one
|
202 |
+
|
203 |
+
# determine `slide_in_side` according to the `slide_out_side` of the previous clip
|
204 |
+
slide_in_side = out_to_in_mapping[slide_out_sides[-1]]
|
205 |
+
|
206 |
+
slide_out_side = "left" if random.random() <= 0.5 else "right"
|
207 |
+
slide_out_sides.append(slide_out_side)
|
208 |
+
|
209 |
+
videos.append(
|
210 |
+
(
|
211 |
+
CompositeVideoClip(
|
212 |
+
[clip.fx(transfx.slide_in, duration=slide_duration, side=slide_in_side)]
|
213 |
+
)
|
214 |
+
.set_start(sum(durations[:idx]) - (slide_duration) * idx)
|
215 |
+
.fx(transfx.slide_out, duration=slide_duration, side=slide_out_side)
|
216 |
+
)
|
217 |
+
)
|
218 |
+
|
219 |
+
last_clip = CompositeVideoClip(
|
220 |
+
[clips[-1].fx(transfx.slide_in, duration=slide_duration, side=out_to_in_mapping[slide_out_sides[-1]])]
|
221 |
+
).set_start(sum(durations[:-1]) - slide_duration * (len(clips) - 1))
|
222 |
+
videos.append(last_clip)
|
223 |
+
|
224 |
+
video = CompositeVideoClip(videos)
|
225 |
+
return video
|
226 |
+
|
227 |
+
|
228 |
+
def compose_video(story_dir: Union[str, Path],
|
229 |
+
save_path: Union[str, Path],
|
230 |
+
captions: List,
|
231 |
+
music_path: Union[str, Path],
|
232 |
+
num_pages: int,
|
233 |
+
fps: int = 10,
|
234 |
+
audio_sample_rate: int = 16000,
|
235 |
+
audio_codec: str = "mp3",
|
236 |
+
caption_config: dict = {},
|
237 |
+
max_single_caption_length: int = 30,
|
238 |
+
fade_duration: float = 1.0,
|
239 |
+
slide_duration: float = 0.4,
|
240 |
+
zoom_speed: float = 0.5,
|
241 |
+
move_ratio: float = 0.95,
|
242 |
+
sound_volume: float = 0.2,
|
243 |
+
music_volume: float = 0.2,
|
244 |
+
bg_speech_ratio: float = 0.4):
|
245 |
+
if not isinstance(story_dir, Path):
|
246 |
+
story_dir = Path(story_dir)
|
247 |
+
|
248 |
+
sound_dir = story_dir / "sound"
|
249 |
+
image_dir = story_dir / "image"
|
250 |
+
speech_dir = story_dir / "speech"
|
251 |
+
|
252 |
+
video_clips = []
|
253 |
+
# audio_durations = []
|
254 |
+
cur_duration = 0
|
255 |
+
timestamps = []
|
256 |
+
|
257 |
+
for page in trange(1, num_pages + 1):
|
258 |
+
##### speech track
|
259 |
+
slide_silence = AudioArrayClip(np.zeros((int(audio_sample_rate * slide_duration), 2)), fps=audio_sample_rate)
|
260 |
+
fade_silence = AudioArrayClip(np.zeros((int(audio_sample_rate * fade_duration), 2)), fps=audio_sample_rate)
|
261 |
+
|
262 |
+
if (speech_dir / f"p{page}.wav").exists(): # single speech file
|
263 |
+
single_utterance = True
|
264 |
+
speech_file = (speech_dir / f"./p{page}.wav").__str__()
|
265 |
+
speech_clip = AudioFileClip(speech_file, fps=audio_sample_rate)
|
266 |
+
# speech_clip = speech_clip.audio_fadein(fade_duration)
|
267 |
+
|
268 |
+
speech_clip = concatenate_audioclips([fade_silence, speech_clip, fade_silence])
|
269 |
+
else: # multiple speech files
|
270 |
+
single_utterance = False
|
271 |
+
speech_files = list(speech_dir.glob(f"p{page}_*.wav"))
|
272 |
+
speech_files = sorted(speech_files, key=lambda x: int(x.stem.split("_")[-1]))
|
273 |
+
speech_clips = []
|
274 |
+
for utt_idx, speech_file in enumerate(speech_files):
|
275 |
+
speech_clip = AudioFileClip(speech_file.__str__(), fps=audio_sample_rate)
|
276 |
+
# add multiple timestamps of the same speech clip
|
277 |
+
if utt_idx == 0:
|
278 |
+
timestamps.append([cur_duration + fade_duration,
|
279 |
+
cur_duration + fade_duration + speech_clip.duration])
|
280 |
+
cur_duration += speech_clip.duration + fade_duration
|
281 |
+
elif utt_idx == len(speech_files) - 1:
|
282 |
+
timestamps.append([
|
283 |
+
cur_duration,
|
284 |
+
cur_duration + speech_clip.duration
|
285 |
+
])
|
286 |
+
cur_duration += speech_clip.duration + fade_duration + slide_duration
|
287 |
+
else:
|
288 |
+
timestamps.append([
|
289 |
+
cur_duration,
|
290 |
+
cur_duration + speech_clip.duration
|
291 |
+
])
|
292 |
+
cur_duration += speech_clip.duration
|
293 |
+
speech_clips.append(speech_clip)
|
294 |
+
speech_clip = concatenate_audioclips([fade_silence] + speech_clips + [fade_silence])
|
295 |
+
speech_file = speech_files[0] # for energy calculation
|
296 |
+
|
297 |
+
# add slide silence
|
298 |
+
if page == 1:
|
299 |
+
speech_clip = concatenate_audioclips([speech_clip, slide_silence])
|
300 |
+
else:
|
301 |
+
speech_clip = concatenate_audioclips([slide_silence, speech_clip, slide_silence])
|
302 |
+
|
303 |
+
# add the timestamp of the whole clip as a single element
|
304 |
+
if single_utterance:
|
305 |
+
if page == 1:
|
306 |
+
timestamps.append([cur_duration + fade_duration,
|
307 |
+
cur_duration + speech_clip.duration - fade_duration - slide_duration])
|
308 |
+
cur_duration += speech_clip.duration - slide_duration
|
309 |
+
else:
|
310 |
+
timestamps.append([cur_duration + fade_duration + slide_duration,
|
311 |
+
cur_duration + speech_clip.duration - fade_duration - slide_duration])
|
312 |
+
cur_duration += speech_clip.duration - slide_duration
|
313 |
+
|
314 |
+
speech_array, _ = librosa.core.load(speech_file, sr=None)
|
315 |
+
speech_rms = librosa.feature.rms(y=speech_array)[0].mean()
|
316 |
+
|
317 |
+
# set image as the main content, align the duration
|
318 |
+
image_file = (image_dir / f"./p{page}.png").__str__()
|
319 |
+
image_clip = ImageClip(image_file)
|
320 |
+
image_clip = image_clip.set_duration(speech_clip.duration).set_fps(fps)
|
321 |
+
image_clip = image_clip.crossfadein(fade_duration).crossfadeout(fade_duration)
|
322 |
+
|
323 |
+
if random.random() <= 0.5: # zoom in or zoom out
|
324 |
+
if random.random() <= 0.5:
|
325 |
+
zoom_mode = "in"
|
326 |
+
else:
|
327 |
+
zoom_mode = "out"
|
328 |
+
image_clip = add_zoom_effect(image_clip, zoom_speed, zoom_mode)
|
329 |
+
else: # move left or right
|
330 |
+
if random.random() <= 0.5:
|
331 |
+
direction = "left"
|
332 |
+
else:
|
333 |
+
direction = "right"
|
334 |
+
image_clip = add_move_effect(image_clip, direction=direction, move_raito=move_ratio)
|
335 |
+
|
336 |
+
# sound track
|
337 |
+
sound_file = sound_dir / f"p{page}.wav"
|
338 |
+
if sound_file.exists():
|
339 |
+
sound_clip = AudioFileClip(sound_file.__str__(), fps=audio_sample_rate)
|
340 |
+
sound_clip = sound_clip.audio_fadein(fade_duration)
|
341 |
+
if sound_clip.duration < speech_clip.duration:
|
342 |
+
sound_clip = audio_loop(sound_clip, duration=speech_clip.duration)
|
343 |
+
else:
|
344 |
+
sound_clip = sound_clip.subclip(0, speech_clip.duration)
|
345 |
+
sound_array, _ = librosa.core.load(sound_file.__str__(), sr=None)
|
346 |
+
sound_rms = librosa.feature.rms(y=sound_array)[0].mean()
|
347 |
+
ratio = speech_rms / sound_rms * bg_speech_ratio
|
348 |
+
audio_clip = CompositeAudioClip([speech_clip, sound_clip.volumex(sound_volume * ratio).audio_fadeout(fade_duration)])
|
349 |
+
else:
|
350 |
+
audio_clip = speech_clip
|
351 |
+
|
352 |
+
video_clip = image_clip.set_audio(audio_clip)
|
353 |
+
video_clips.append(video_clip)
|
354 |
+
|
355 |
+
# audio_durations.append(audio_clip.duration)
|
356 |
+
|
357 |
+
# final_clip = concatenate_videoclips(video_clips, method="compose")
|
358 |
+
composite_clip = add_slide_effect(video_clips, slide_duration=slide_duration)
|
359 |
+
composite_clip = add_bottom_black_area(composite_clip, black_area_height=caption_config["area_height"])
|
360 |
+
del caption_config["area_height"]
|
361 |
+
composite_clip = add_caption(
|
362 |
+
captions,
|
363 |
+
story_dir / "captions.srt",
|
364 |
+
timestamps,
|
365 |
+
composite_clip,
|
366 |
+
max_single_caption_length,
|
367 |
+
**caption_config
|
368 |
+
)
|
369 |
+
|
370 |
+
# add music track, align the duration
|
371 |
+
music_clip = AudioFileClip(music_path.__str__(), fps=audio_sample_rate)
|
372 |
+
music_array, _ = librosa.core.load(music_path.__str__(), sr=None)
|
373 |
+
music_rms = librosa.feature.rms(y=music_array)[0].mean()
|
374 |
+
ratio = speech_rms / music_rms * bg_speech_ratio
|
375 |
+
if music_clip.duration < composite_clip.duration:
|
376 |
+
music_clip = audio_loop(music_clip, duration=composite_clip.duration)
|
377 |
+
else:
|
378 |
+
music_clip = music_clip.subclip(0, composite_clip.duration)
|
379 |
+
all_audio_clip = CompositeAudioClip([composite_clip.audio, music_clip.volumex(music_volume * ratio)])
|
380 |
+
composite_clip = composite_clip.set_audio(all_audio_clip)
|
381 |
+
|
382 |
+
composite_clip.write_videofile(save_path.__str__(),
|
383 |
+
audio_fps=audio_sample_rate,
|
384 |
+
audio_codec=audio_codec,)
|
385 |
+
|
386 |
+
|
387 |
+
class VideoComposeAgent:
|
388 |
+
|
389 |
+
def adjust_caption_config(self, width, height):
|
390 |
+
area_height = int(height * 0.06)
|
391 |
+
fontsize = int((width + height) / 2 * 0.025)
|
392 |
+
return {
|
393 |
+
"fontsize": fontsize,
|
394 |
+
"area_height": area_height
|
395 |
+
}
|
396 |
+
|
397 |
+
def call(self, pages, config):
|
398 |
+
height = config["image_generation"]["obj_cfg"]["height"]
|
399 |
+
width = config["image_generation"]["obj_cfg"]["width"]
|
400 |
+
config["caption_config"].update(self.adjust_caption_config(width, height))
|
401 |
+
compose_video(
|
402 |
+
story_dir=Path(config["story_dir"]),
|
403 |
+
save_path=Path(config["story_dir"]) / "output.mp4",
|
404 |
+
captions=pages,
|
405 |
+
music_path=Path(config["story_dir"]) / "music/music.wav",
|
406 |
+
num_pages=len(pages),
|
407 |
+
audio_sample_rate=config["audio_sample_rate"],
|
408 |
+
audio_codec=config["audio_codec"],
|
409 |
+
caption_config=config["caption_config"],
|
410 |
+
max_single_caption_length=config["max_single_caption_length"],
|
411 |
+
**config["slideshow_effect"]
|
412 |
+
)
|
nls-1.0.0-py3-none-any.whl
ADDED
Binary file (47 kB). View file
|
|
policy.xml
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
+
<!DOCTYPE policymap [
|
3 |
+
<!ELEMENT policymap (policy)*>
|
4 |
+
<!ATTLIST policymap xmlns CDATA #FIXED ''>
|
5 |
+
<!ELEMENT policy EMPTY>
|
6 |
+
<!ATTLIST policy xmlns CDATA #FIXED '' domain NMTOKEN #REQUIRED
|
7 |
+
name NMTOKEN #IMPLIED pattern CDATA #IMPLIED rights NMTOKEN #IMPLIED
|
8 |
+
stealth NMTOKEN #IMPLIED value CDATA #IMPLIED>
|
9 |
+
]>
|
10 |
+
<!--
|
11 |
+
Configure ImageMagick policies.
|
12 |
+
|
13 |
+
Domains include system, delegate, coder, filter, path, or resource.
|
14 |
+
|
15 |
+
Rights include none, read, write, execute and all. Use | to combine them,
|
16 |
+
for example: "read | write" to permit read from, or write to, a path.
|
17 |
+
|
18 |
+
Use a glob expression as a pattern.
|
19 |
+
|
20 |
+
Suppose we do not want users to process MPEG video images:
|
21 |
+
|
22 |
+
<policy domain="delegate" rights="none" pattern="mpeg:decode" />
|
23 |
+
|
24 |
+
Here we do not want users reading images from HTTP:
|
25 |
+
|
26 |
+
<policy domain="coder" rights="none" pattern="HTTP" />
|
27 |
+
|
28 |
+
The /repository file system is restricted to read only. We use a glob
|
29 |
+
expression to match all paths that start with /repository:
|
30 |
+
|
31 |
+
<policy domain="path" rights="read" pattern="/repository/*" />
|
32 |
+
|
33 |
+
Lets prevent users from executing any image filters:
|
34 |
+
|
35 |
+
<policy domain="filter" rights="none" pattern="*" />
|
36 |
+
|
37 |
+
Any large image is cached to disk rather than memory:
|
38 |
+
|
39 |
+
<policy domain="resource" name="area" value="1GP"/>
|
40 |
+
|
41 |
+
Use the default system font unless overwridden by the application:
|
42 |
+
|
43 |
+
<policy domain="system" name="font" value="/usr/share/fonts/favorite.ttf"/>
|
44 |
+
|
45 |
+
Define arguments for the memory, map, area, width, height and disk resources
|
46 |
+
with SI prefixes (.e.g 100MB). In addition, resource policies are maximums
|
47 |
+
for each instance of ImageMagick (e.g. policy memory limit 1GB, -limit 2GB
|
48 |
+
exceeds policy maximum so memory limit is 1GB).
|
49 |
+
|
50 |
+
Rules are processed in order. Here we want to restrict ImageMagick to only
|
51 |
+
read or write a small subset of proven web-safe image types:
|
52 |
+
|
53 |
+
<policy domain="delegate" rights="none" pattern="*" />
|
54 |
+
<policy domain="filter" rights="none" pattern="*" />
|
55 |
+
<policy domain="coder" rights="none" pattern="*" />
|
56 |
+
<policy domain="coder" rights="read|write" pattern="{GIF,JPEG,PNG,WEBP}" />
|
57 |
+
-->
|
58 |
+
<policymap>
|
59 |
+
<!-- <policy domain="resource" name="temporary-path" value="/tmp"/> -->
|
60 |
+
<policy domain="resource" name="memory" value="256MiB"/>
|
61 |
+
<policy domain="resource" name="map" value="512MiB"/>
|
62 |
+
<policy domain="resource" name="width" value="16KP"/>
|
63 |
+
<policy domain="resource" name="height" value="16KP"/>
|
64 |
+
<!-- <policy domain="resource" name="list-length" value="128"/> -->
|
65 |
+
<policy domain="resource" name="area" value="128MP"/>
|
66 |
+
<policy domain="resource" name="disk" value="1GiB"/>
|
67 |
+
<!-- <policy domain="resource" name="file" value="768"/> -->
|
68 |
+
<!-- <policy domain="resource" name="thread" value="4"/> -->
|
69 |
+
<!-- <policy domain="resource" name="throttle" value="0"/> -->
|
70 |
+
<!-- <policy domain="resource" name="time" value="3600"/> -->
|
71 |
+
<!-- <policy domain="coder" rights="none" pattern="MVG" /> -->
|
72 |
+
<!-- <policy domain="module" rights="none" pattern="{PS,PDF,XPS}" /> -->
|
73 |
+
<!-- <policy domain="path" rights="none" pattern="@*" /> -->
|
74 |
+
<!-- <policy domain="cache" name="memory-map" value="anonymous"/> -->
|
75 |
+
<!-- <policy domain="cache" name="synchronize" value="True"/> -->
|
76 |
+
<!-- <policy domain="cache" name="shared-secret" value="passphrase" stealth="true"/>
|
77 |
+
<!-- <policy domain="system" name="max-memory-request" value="256MiB"/> -->
|
78 |
+
<!-- <policy domain="system" name="shred" value="2"/> -->
|
79 |
+
<!-- <policy domain="system" name="precision" value="6"/> -->
|
80 |
+
<!-- <policy domain="system" name="font" value="/path/to/font.ttf"/> -->
|
81 |
+
<!-- <policy domain="system" name="pixel-cache-memory" value="anonymous"/> -->
|
82 |
+
<!-- <policy domain="system" name="shred" value="2"/> -->
|
83 |
+
<!-- <policy domain="system" name="precision" value="6"/> -->
|
84 |
+
<!-- not needed due to the need to use explicitly by mvg: -->
|
85 |
+
<!-- <policy domain="delegate" rights="none" pattern="MVG" /> -->
|
86 |
+
<!-- use curl -->
|
87 |
+
<policy domain="delegate" rights="none" pattern="URL" />
|
88 |
+
<policy domain="delegate" rights="none" pattern="HTTPS" />
|
89 |
+
<policy domain="delegate" rights="none" pattern="HTTP" />
|
90 |
+
<!-- in order to avoid to get image with password text -->
|
91 |
+
<!-- <policy domain="path" rights="none" pattern="@*"/> -->
|
92 |
+
<!-- disable ghostscript format types -->
|
93 |
+
<policy domain="coder" rights="none" pattern="PS" />
|
94 |
+
<policy domain="coder" rights="none" pattern="PS2" />
|
95 |
+
<policy domain="coder" rights="none" pattern="PS3" />
|
96 |
+
<policy domain="coder" rights="none" pattern="EPS" />
|
97 |
+
<policy domain="coder" rights="none" pattern="PDF" />
|
98 |
+
<policy domain="coder" rights="none" pattern="XPS" />
|
99 |
+
</policymap>
|
requirements.txt
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Pillow
|
2 |
+
PyYAML
|
3 |
+
pypinyin
|
4 |
+
soundfile
|
5 |
+
dashscope
|
6 |
+
tqdm
|
7 |
+
zhon
|
8 |
+
numpy
|
9 |
+
librosa
|
10 |
+
moviepy
|
11 |
+
opencv-python
|
12 |
+
nls-1.0.0-py3-none-any.whl
|
13 |
+
audiocraft
|