Spaces:
Sleeping
Sleeping
import time | |
import json | |
from pathlib import Path | |
import torch.multiprocessing as mp | |
from mm_story_agent.modality_agents.story_agent import QAOutlineStoryWriter | |
from mm_story_agent.modality_agents.speech_agent import CosyVoiceAgent | |
from mm_story_agent.modality_agents.sound_agent import AudioLDM2Agent | |
from mm_story_agent.modality_agents.music_agent import MusicGenAgent | |
from mm_story_agent.modality_agents.image_agent import StoryDiffusionAgent | |
from mm_story_agent.video_compose_agent import VideoComposeAgent | |
class MMStoryAgent:
    """Orchestrate story writing, per-modality asset generation and video composition.

    Asset generation for the image, sound, speech and music modalities runs in
    parallel: one ``torch.multiprocessing`` process per modality, each handed the
    device listed in ``self.modality_devices`` for that modality.
    """

    # Default layout: one CUDA device per modality.
    _DEFAULT_MODALITY_DEVICES = {
        "image": "cuda:0",
        "sound": "cuda:1",
        "music": "cuda:2",
        "speech": "cuda:3",
    }

    def __init__(self, modality_devices=None) -> None:
        """Set up the modality -> agent-class and modality -> device tables.

        Args:
            modality_devices: optional mapping ``{modality: device_string}`` that
                overrides entries of the default layout (``cuda:0`` .. ``cuda:3``).
                Modalities not present in the mapping keep their default device.
                Omitting it reproduces the original hard-coded layout.
        """
        self.modalities = ["image", "sound", "speech", "music"]
        self.modality_agent_class = {
            "image": StoryDiffusionAgent,
            "sound": AudioLDM2Agent,
            "speech": CosyVoiceAgent,
            "music": MusicGenAgent,
        }
        self.modality_devices = {
            **self._DEFAULT_MODALITY_DEVICES,
            **(modality_devices or {}),
        }
        # Kept for backward compatibility; nothing in this class populates it.
        self.agents = {}

    def call_modality_agent(self, agent, device, pages, save_path, return_dict):
        """Run one modality agent and record its result (child-process target).

        Args:
            agent: an instantiated modality agent exposing ``call(pages, device, save_path)``.
            device: device string the agent should run on.
            pages: list of story page texts.
            save_path: directory the agent writes its assets into.
            return_dict: shared mapping; the result is stored under the result's
                own ``"modality"`` key so the parent can collect it.
        """
        result = agent.call(pages, device, save_path)
        return_dict[result["modality"]] = result

    def write_story(self, config):
        """Generate the story pages from ``config["story_setting"]``.

        Returns:
            Whatever ``QAOutlineStoryWriter.call`` returns (used as ``pages``).
        """
        story_writer = QAOutlineStoryWriter(config["story_gen_config"])
        return story_writer.call(config["story_setting"])

    def _generate_single_modality(self, config, pages, modality):
        """Sequentially generate one modality's assets into ``story_dir/<modality>``."""
        save_path = Path(config["story_dir"]) / modality
        save_path.mkdir(exist_ok=True, parents=True)
        agent = self.modality_agent_class[modality](config[modality + "_generation"])
        # NOTE(review): this passes (pages, save_path) while call_modality_agent
        # passes (pages, device, save_path); confirm which signature the agents'
        # ``call`` actually has.
        agent.call(pages, save_path)

    def generate_speech(self, config, pages):
        """Generate speech assets into ``story_dir/speech``."""
        self._generate_single_modality(config, pages, "speech")

    def generate_sound(self, config, pages):
        """Generate sound-effect assets into ``story_dir/sound``."""
        self._generate_single_modality(config, pages, "sound")

    def generate_music(self, config, pages):
        """Generate music assets into ``story_dir/music``."""
        self._generate_single_modality(config, pages, "music")

    def generate_image(self, config, pages):
        """Generate image assets into ``story_dir/image``."""
        self._generate_single_modality(config, pages, "image")

    def _run_agents_in_parallel(self, agents, pages, story_dir):
        """Run every modality agent in its own process; return ``{modality: result}``.

        Results are only present for agents whose process stored one; a process
        that exits with a non-zero code is reported and simply contributes nothing.
        """
        # A Manager is a separate server process; the context manager shuts it
        # down instead of leaking one per call.
        with mp.Manager() as manager:
            return_dict = manager.dict()
            processes = {}
            for modality in self.modalities:
                process = mp.Process(
                    target=self.call_modality_agent,
                    args=(
                        agents[modality],
                        self.modality_devices[modality],
                        pages,
                        story_dir / modality,
                        return_dict,
                    ),
                    daemon=False,
                )
                processes[modality] = process
                process.start()
            for modality, process in processes.items():
                process.join()
                if process.exitcode != 0:
                    print(f"{modality} generation process exited with code {process.exitcode}")
            # Copy out of the proxy before the manager (and the proxy) goes away.
            return dict(return_dict.items())

    @staticmethod
    def _merge_results(pages, results):
        """Build the script data from collected per-modality results.

        Args:
            pages: list of story page texts.
            results: ``{modality: result_dict}`` as produced by the agents.

        Returns:
            ``(script_data, images)`` where ``images`` is the image result's
            ``"generation_results"``, or ``None`` if no image result was usable.
        """
        script_data = {"pages": [{"story": page} for page in pages]}
        images = None  # stays None if the image process produced nothing
        for result in results.values():
            # Best-effort: a malformed result is reported and skipped.
            try:
                modality = result["modality"]
                if modality == "image":
                    images = result["generation_results"]
                    for idx, page_data in enumerate(script_data["pages"]):
                        page_data["image_prompt"] = result["prompts"][idx]
                elif modality == "sound":
                    for idx, page_data in enumerate(script_data["pages"]):
                        page_data["sound_prompt"] = result["prompts"][idx]
                elif modality == "music":
                    script_data["music_prompt"] = result["prompt"]
            except Exception as e:
                print(f"Error occurred during generation: {e}")
        return script_data, images

    def generate_modality_assets(self, config, pages):
        """Generate all modality assets in parallel and write ``script_data.json``.

        Args:
            config: dict with ``"story_dir"`` and one ``"<modality>_generation"``
                entry per modality.
            pages: list of story page texts.

        Returns:
            The image agent's ``"generation_results"``, or ``None`` if the image
            modality produced no usable result.
        """
        story_dir = Path(config["story_dir"])
        for modality in self.modalities:
            (story_dir / modality).mkdir(exist_ok=True, parents=True)

        agents = {
            modality: self.modality_agent_class[modality](config[modality + "_generation"])
            for modality in self.modalities
        }

        results = self._run_agents_in_parallel(agents, pages, story_dir)
        script_data, images = self._merge_results(pages, results)

        # Explicit UTF-8: ensure_ascii=False emits raw non-ASCII text, which would
        # fail under a non-UTF-8 default locale encoding.
        with open(story_dir / "script_data.json", "w", encoding="utf-8") as writer:
            json.dump(script_data, writer, ensure_ascii=False, indent=4)
        return images

    def compose_storytelling_video(self, config, pages):
        """Compose the final video from the generated assets via ``VideoComposeAgent``."""
        video_compose_agent = VideoComposeAgent()
        video_compose_agent.call(pages, config)

    def call(self, config):
        """Run the full pipeline: write story, generate assets, compose video."""
        pages = self.write_story(config)
        self.generate_modality_assets(config, pages)
        self.compose_storytelling_video(config, pages)