# MM-StoryAgent Gradio demo — generates storytelling videos from a story setting.
from pathlib import Path
from copy import deepcopy
import shutil
import os
from datetime import datetime
import time
import uuid
import subprocess
import gradio as gr
import yaml
import torch.multiprocessing as mp
# CUDA contexts cannot survive fork(); force the 'spawn' start method before
# any worker process is created.
mp.set_start_method('spawn', force=True)
from mm_story_agent import MMStoryAgent
# Install the bundled ImageMagick policy (presumably relaxes permissions the
# video pipeline needs — TODO confirm against policy.xml).  Use subprocess
# with an argument list instead of os.system so no shell is involved;
# check=False keeps the original best-effort behavior (failure is ignored,
# e.g. when not running with write access to /etc).
subprocess.run(["cp", "policy.xml", "/etc/ImageMagick-6/"], check=False)
# Load the default pipeline configuration; the per-section defaults below seed
# the initial values of the Gradio widgets.
with open("configs/mm_story_agent.yaml", "r") as reader:
    config = yaml.load(reader, Loader=yaml.FullLoader)
default_story_setting = config["story_setting"]
default_story_gen_config = config["story_gen_config"]
default_slideshow_effect = config["slideshow_effect"]
default_image_config = config["image_generation"]
default_sound_config = config["sound_generation"]
default_music_config = config["music_generation"]
def set_generating_progress_text(text):
    """Show the status Markdown widget, rendering *text* as an <h3> heading."""
    heading = f"<h3>{text}</h3>"
    return gr.update(value=heading, visible=True)
def set_text_invisible():
    """Hide the status Markdown widget."""
    return gr.update(visible=False)
def deep_update(original, updates): | |
for key, value in updates.items(): | |
if isinstance(value, dict): | |
original[key] = deep_update(original.get(key, {}), value) | |
else: | |
original[key] = value | |
return original | |
def update_page(direction, page, story_data):
    """Move one page forward or backward through the story.

    Args:
        direction: "next" or "prev".
        page: current 0-based page index.
        story_data: list of page texts.

    Returns:
        (new_page_index, page_text, story_data) — the 3-tuple shape the
        click handlers consume.
    """
    # Guard: story_data starts out as an empty gr.State and the buttons are
    # clickable before any story is generated; the original indexed into the
    # empty list and raised IndexError.
    if not story_data:
        return page, "", story_data
    if direction == 'next' and page < len(story_data) - 1:
        page += 1
    elif direction == 'prev' and page > 0:
        page -= 1
    return page, story_data[page], story_data
def write_story_fn(story_topic, main_role, scene,
                   num_outline, temperature,
                   current_page,
                   config,
                   progress=gr.Progress(track_tqdm=True)):
    """Generate the story script for the given setting.

    Side effects: assigns a fresh timestamped ``config["story_dir"]`` and
    deletes generated-story directories older than two days.

    Returns:
        (story_data pages, story_accordion update, current page text,
         video_output update) — matching the handler's four outputs.
    """
    config["story_dir"] = f"generated_stories/{time.strftime('%Y%m%d-%H%M%S')}-{uuid.uuid1().hex}"
    current_date = datetime.now()
    # Best-effort housekeeping: remove story directories older than 2 days.
    if Path("generated_stories").exists():
        for story_dir in Path("generated_stories").iterdir():
            try:
                story_date = datetime.strptime(story_dir.name[:8], '%Y%m%d')
            except ValueError:
                # Skip entries that don't follow the YYYYMMDD-... naming
                # scheme instead of crashing the whole generation (the
                # original strptime call raised on any stray file/dir).
                continue
            if (current_date - story_date).days >= 2:
                shutil.rmtree(story_dir, ignore_errors=True)
    deep_update(config, {
        "story_setting": {
            "story_topic": story_topic,
            "main_role": main_role,
            "scene": scene,
        },
        "story_gen_config": {
            "num_outline": num_outline,
            "temperature": temperature
        },
    })
    story_gen_agent = MMStoryAgent()
    pages = story_gen_agent.write_story(config)
    # story_data, story_accordion, story_content, video_output
    return pages, gr.update(visible=True), pages[current_page], gr.update()
def modality_assets_generation_fn(
        height, width, image_seed, sound_guidance_scale, sound_seed,
        n_candidate_per_text,
        config,
        story_data):
    """Generate image and sound assets for every story page.

    Merges the UI-chosen image/sound parameters into ``config``, runs the
    agent, and reveals the image gallery with the generated images.
    """
    overrides = {
        "image_generation": {
            "obj_cfg": {"height": height, "width": width},
            "call_cfg": {"seed": image_seed},
        },
        "sound_generation": {
            "call_cfg": {
                "guidance_scale": sound_guidance_scale,
                "seed": sound_seed,
                "n_candidate_per_text": n_candidate_per_text,
            },
        },
    }
    deep_update(config, overrides)
    agent = MMStoryAgent()
    images = agent.generate_modality_assets(config, story_data)
    # image gallery: one row containing every generated image
    return gr.update(visible=True, value=images,
                     columns=[len(images)], rows=[1], height="auto")
def compose_storytelling_video_fn(
        fade_duration, slide_duration, zoom_speed, move_ratio,
        sound_volume, music_volume, bg_speech_ratio, fps,
        config,
        story_data,
        progress=gr.Progress(track_tqdm=True)):
    """Compose the final storytelling video from the generated assets.

    Merges the UI-chosen slideshow parameters into ``config``, runs the
    agent, and returns the path to the rendered video file.
    """
    effect_overrides = {
        "slideshow_effect": {
            "fade_duration": fade_duration,
            "slide_duration": slide_duration,
            "zoom_speed": zoom_speed,
            "move_ratio": move_ratio,
            "sound_volume": sound_volume,
            "music_volume": music_volume,
            "bg_speech_ratio": bg_speech_ratio,
            "fps": fps,
        },
    }
    deep_update(config, effect_overrides)
    agent = MMStoryAgent()
    agent.compose_storytelling_video(config, story_data)
    # video_output widget expects the rendered file path
    return Path(config["story_dir"]) / "output.mp4"
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Page header.
    gr.HTML("""
    <h1 style="text-align: center;">MM-StoryAgent</h1>
    <p style="font-size: 16px;">This is a demo for generating attractive storytelling videos based on the given story setting.</p>
    <p style="font-size: 16px;">Depending on the chapter number, the generation may take a long time. Please be patient.</p>
    """)
    # Per-session copy of the YAML config so concurrent sessions don't share
    # mutable state.
    config = gr.State(deepcopy(config))
    with gr.Row():
        # Left column: story setting and optional generation parameters.
        with gr.Column():
            story_topic = gr.Textbox(label="Story Topic", value=default_story_setting["story_topic"])
            main_role = gr.Textbox(label="Main Role", value=default_story_setting["main_role"])
            scene = gr.Textbox(label="Scene", value=default_story_setting["scene"])
            chapter_num = gr.Number(label="Chapter Number", value=default_story_gen_config["num_outline"])
            temperature = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Temperature", value=default_story_gen_config["temperature"])
            with gr.Accordion("Detailed Image Configuration (Optional)", open=False):
                height = gr.Slider(label="Height", minimum=256, maximum=1024, step=32, value=default_image_config["obj_cfg"]['height'])
                width = gr.Slider(label="Width", minimum=256, maximum=1024, step=32, value=default_image_config["obj_cfg"]['width'])
                image_seed = gr.Number(label="Image Seed", value=default_image_config["call_cfg"]['seed'])
            with gr.Accordion("Detailed Sound Configuration (Optional)", open=False):
                sound_guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=7.0, step=0.5, value=default_sound_config["call_cfg"]['guidance_scale'])
                sound_seed = gr.Number(label="Sound Seed", value=default_sound_config["call_cfg"]['seed'])
                n_candidate_per_text = gr.Slider(label="Number of Candidates per Text", minimum=0, maximum=5, step=1, value=default_sound_config["call_cfg"]['n_candidate_per_text'])
            with gr.Accordion("Detailed Slideshow Effect (Optional)", open=False):
                fade_duration = gr.Slider(label="Fade Duration", minimum=0.1, maximum=1.5, step=0.1, value=default_slideshow_effect['fade_duration'])
                slide_duration = gr.Slider(label="Slide Duration", minimum=0.1, maximum=1.0, step=0.1, value=default_slideshow_effect['slide_duration'])
                zoom_speed = gr.Slider(label="Zoom Speed", minimum=0.1, maximum=2.0, step=0.1, value=default_slideshow_effect['zoom_speed'])
                move_ratio = gr.Slider(label="Move Ratio", minimum=0.8, maximum=1.0, step=0.05, value=default_slideshow_effect['move_ratio'])
                sound_volume = gr.Slider(label="Sound Volume", minimum=0.0, maximum=1.0, step=0.1, value=default_slideshow_effect['sound_volume'])
                music_volume = gr.Slider(label="Music Volume", minimum=0.0, maximum=1.0, step=0.1, value=default_slideshow_effect['music_volume'])
                bg_speech_ratio = gr.Slider(label="Background / Speech Ratio", minimum=0.0, maximum=1.0, step=0.1, value=default_slideshow_effect['bg_speech_ratio'])
                fps = gr.Slider(label="FPS", minimum=1, maximum=30, step=1, value=default_slideshow_effect['fps'])
        # Right column: generation status, story pager, and outputs.
        with gr.Column():
            story_data = gr.State([])
            story_generation_information = gr.Markdown(
                label="Story Generation Status",
                value="<h3>Generating Story Script ......</h3>",
                visible=False)
            with gr.Accordion(label="Story Content", open=False, visible=False) as story_accordion:
                with gr.Row():
                    prev_button = gr.Button("Previous Page",)
                    next_button = gr.Button("Next Page",)
                story_content = gr.Textbox(label="Page Content")
            video_generation_information = gr.Markdown(label="Generation Status", value="<h3>Generating Video ......</h3>", visible=False)
            image_gallery = gr.Gallery(label="Images", show_label=False, visible=False)
            video_generation_btn = gr.Button("Generate Video")
            video_output = gr.Video(label="Generated Story", interactive=False)
            current_page = gr.State(0)

    # BUG FIX: update_page returns a 3-tuple (page, text, story_data), but the
    # original handlers listed only two outputs, so Gradio rejected the extra
    # return value on every click.  All three outputs are listed now.
    prev_button.click(
        fn=update_page,
        inputs=[gr.State("prev"), current_page, story_data],
        outputs=[current_page, story_content, story_data]
    )
    next_button.click(
        fn=update_page,
        inputs=[gr.State("next"), current_page, story_data],
        outputs=[current_page, story_content, story_data])

    # Pipeline: status text -> story script -> status -> image/sound assets
    # -> status -> video composition -> hide gallery -> final status.
    # (possibly) update role description and scripts
    video_generation_btn.click(
        fn=set_generating_progress_text,
        inputs=[gr.State("Generating Story ...")],
        outputs=video_generation_information
    ).then(
        fn=write_story_fn,
        inputs=[story_topic, main_role, scene,
                chapter_num, temperature,
                current_page,
                config
                ],
        outputs=[story_data, story_accordion, story_content, video_output]
    ).then(
        fn=set_generating_progress_text,
        inputs=[gr.State("Generating Modality Assets ...")],
        outputs=video_generation_information
    ).then(
        fn=modality_assets_generation_fn,
        inputs=[height, width, image_seed, sound_guidance_scale, sound_seed,
                n_candidate_per_text,
                config,
                story_data],
        outputs=[image_gallery]
    ).then(
        fn=set_generating_progress_text,
        inputs=[gr.State("Composing Video ...")],
        outputs=video_generation_information
    ).then(
        fn=compose_storytelling_video_fn,
        inputs=[fade_duration, slide_duration, zoom_speed, move_ratio,
                sound_volume, music_volume, bg_speech_ratio, fps,
                config,
                story_data],
        outputs=[video_output]
    ).then(
        fn=lambda: gr.update(visible=False),
        inputs=[],
        outputs=[image_gallery]
    ).then(
        fn=set_generating_progress_text,
        inputs=[gr.State("Generation Finished!")],
        outputs=video_generation_information
    )
# Script entry point: start the Gradio server.
if __name__ == "__main__":
    demo.launch()