# Xu Xuenan · Multi-GPUs · 5152717
# (repository web-viewer residue: raw / history blame / 4.65 kB)
import time
import json
from pathlib import Path
import torch.multiprocessing as mp
from mm_story_agent.modality_agents.story_agent import QAOutlineStoryWriter
from mm_story_agent.modality_agents.speech_agent import CosyVoiceAgent
from mm_story_agent.modality_agents.sound_agent import AudioLDM2Agent
from mm_story_agent.modality_agents.music_agent import MusicGenAgent
from mm_story_agent.modality_agents.image_agent import StoryDiffusionAgent
from mm_story_agent.video_compose_agent import VideoComposeAgent
class MMStoryAgent:
    """Drive the story pipeline: write pages, generate assets, compose a video.

    ``call`` is the entry point. It writes the story pages, runs one agent per
    modality (image, sound, speech, music) in its own process, stores the
    prompts the agents report in ``script_data.json`` and finally hands the
    pages to the video composer.
    """

    def __init__(self) -> None:
        # Order in which modality directories, agents and processes are created.
        self.modalities = ["image", "sound", "speech", "music"]
        # Agent class instantiated for each modality.
        self.modality_agent_class = {
            "image": StoryDiffusionAgent,
            "sound": AudioLDM2Agent,
            "speech": CosyVoiceAgent,
            "music": MusicGenAgent,
        }
        # Device string passed to each agent's ``call``; hard-coded to four
        # distinct CUDA devices.
        self.modality_devices = {
            "image": "cuda:0",
            "sound": "cuda:1",
            "music": "cuda:2",
            "speech": "cuda:3",
        }
        # Not read by any method of this class; kept for external users.
        self.agents = {}

    def call_modality_agent(self, agent, device, pages, save_path, return_dict):
        """Run one agent and publish its result into ``return_dict``.

        Used as the target of each worker process. The result returned by
        ``agent.call(pages, device, save_path)`` is stored under the key found
        in its own ``"modality"`` entry.
        """
        result = agent.call(pages, device, save_path)
        return_dict[result["modality"]] = result

    def write_story(self, config):
        """Return the story pages produced by ``QAOutlineStoryWriter``.

        The writer is built from ``config["story_gen_config"]`` and called
        with ``config["story_setting"]``.
        """
        story_writer = QAOutlineStoryWriter(config["story_gen_config"])
        return story_writer.call(config["story_setting"])

    def _generate_single_modality(self, modality, config, pages):
        """Run a single modality's agent in the current process.

        Creates ``<story_dir>/<modality>``, builds the agent from
        ``config["<modality>_generation"]`` and calls it with the modality's
        configured device, i.e. the same ``call(pages, device, save_path)``
        form that ``call_modality_agent`` uses.
        """
        save_dir = Path(config["story_dir"]) / modality
        save_dir.mkdir(exist_ok=True, parents=True)
        agent = self.modality_agent_class[modality](config[modality + "_generation"])
        agent.call(pages, self.modality_devices[modality], save_dir)

    def generate_speech(self, config, pages):
        """Generate speech assets for ``pages`` under ``<story_dir>/speech``."""
        self._generate_single_modality("speech", config, pages)

    def generate_sound(self, config, pages):
        """Generate sound assets for ``pages`` under ``<story_dir>/sound``."""
        self._generate_single_modality("sound", config, pages)

    def generate_music(self, config, pages):
        """Generate music assets for ``pages`` under ``<story_dir>/music``."""
        self._generate_single_modality("music", config, pages)

    def generate_image(self, config, pages):
        """Generate image assets for ``pages`` under ``<story_dir>/image``."""
        self._generate_single_modality("image", config, pages)

    @staticmethod
    def _build_script_data(pages, results):
        """Merge per-modality results into the script dictionary.

        Args:
            pages: Story pages; each becomes ``{"story": page}``.
            results: Mapping of modality name to the result dict an agent
                returned. Only the ``"image"``, ``"sound"`` and ``"music"``
                results contribute; others are ignored.

        Returns:
            ``(script_data, images)``. ``images`` is the image result's
            ``"generation_results"`` value, or ``None`` when no image result
            is present.
        """
        script_data = {"pages": [{"story": page} for page in pages]}
        images = None
        for result in results.values():
            # Best effort: a malformed result must not discard the others.
            try:
                kind = result["modality"]
                if kind == "image":
                    images = result["generation_results"]
                    for idx, entry in enumerate(script_data["pages"]):
                        entry["image_prompt"] = result["prompts"][idx]
                elif kind == "sound":
                    for idx, entry in enumerate(script_data["pages"]):
                        entry["sound_prompt"] = result["prompts"][idx]
                elif kind == "music":
                    script_data["music_prompt"] = result["prompt"]
            except (LookupError, TypeError) as e:
                print(f"Error occurred during generation: {e}")
        return script_data, images

    def generate_modality_assets(self, config, pages):
        """Run every modality agent in its own process and save the script.

        All agents are constructed in the calling process, then one process
        per modality is started before any is joined, so the workers run
        concurrently. After they finish, the reported prompts are written to
        ``<story_dir>/script_data.json``.

        Returns:
            The image agent's ``"generation_results"``, or ``None`` if the
            image worker produced no result.
        """
        story_dir = Path(config["story_dir"])
        for modality in self.modalities:
            (story_dir / modality).mkdir(exist_ok=True, parents=True)

        agents = {
            modality: self.modality_agent_class[modality](config[modality + "_generation"])
            for modality in self.modalities
        }

        # The manager process is shut down on leaving the ``with`` block, so
        # the shared results are copied into a plain dict before that.
        with mp.Manager() as manager:
            return_dict = manager.dict()
            workers = []
            for modality in self.modalities:
                worker = mp.Process(
                    target=self.call_modality_agent,
                    args=(
                        agents[modality],
                        self.modality_devices[modality],
                        pages,
                        story_dir / modality,
                        return_dict,
                    ),
                    daemon=False,
                )
                workers.append((modality, worker))
                worker.start()

            for modality, worker in workers:
                worker.join()
                if worker.exitcode != 0:
                    # A crashed worker leaves no entry in return_dict; say so.
                    print(f"Generation process for {modality} exited with code {worker.exitcode}")

            results = dict(return_dict.items())

        script_data, images = self._build_script_data(pages, results)
        if images is None:
            print("No image generation results were returned")

        # ensure_ascii=False emits raw non-ASCII text, so pin the encoding
        # instead of relying on the platform default.
        with open(story_dir / "script_data.json", "w", encoding="utf-8") as writer:
            json.dump(script_data, writer, ensure_ascii=False, indent=4)
        return images

    def compose_storytelling_video(self, config, pages):
        """Compose the final video by calling ``VideoComposeAgent`` with the pages and config."""
        video_compose_agent = VideoComposeAgent()
        video_compose_agent.call(pages, config)

    def call(self, config):
        """Run the full pipeline: write story, generate assets, compose video."""
        pages = self.write_story(config)
        self.generate_modality_assets(config, pages)
        self.compose_storytelling_video(config, pages)