Spaces:
Sleeping
Sleeping
| import random | |
| from typing import Iterator | |
| from data_handlers import get_melody_handler | |
| from .utils.g2p import preprocess_text | |
class MelodyController:
    """Pick or generate a melody and align lyric text onto its notes.

    ``melody_source_id`` is split on ``"-"``: the first field is the mode
    (``"gen"`` or ``"sample"``), the second the alignment type (e.g.
    ``"random"`` or ``"lyric"``), and the last field the dataset name, where
    ``"none"`` means no song database is loaded.
    """

    def __init__(self, melody_source_id: str, cache_dir: str):
        """Parse ``melody_source_id`` and load the song database if one is named.

        Args:
            melody_source_id: Hyphen-separated id of the form
                ``"<mode>-<align_type>-<dataset>"``.
            cache_dir: Forwarded to the dataset handler's constructor.

        Raises:
            ValueError: If the id has fewer than two ``"-"``-separated fields.
        """
        self.melody_source_id = melody_source_id
        self.song_id = None
        # Phrase iterator of the song chosen by get_melody_constraints();
        # stays None until that method has run in "sample" mode.
        self.reference_song = None

        parts = melody_source_id.split("-")
        if len(parts) < 2:
            raise ValueError(
                f"Invalid melody_source_id {melody_source_id!r}: "
                "expected '<mode>-<align_type>-<dataset>'."
            )
        self.mode = parts[0]
        self.align_type = parts[1]
        dataset_name = parts[-1]

        # Load the song database only when a dataset is named.
        if dataset_name == "none":
            self.database = None
        else:
            handler_cls = get_melody_handler(dataset_name)
            self.database = handler_cls(self.align_type, cache_dir)

    def get_melody_constraints(self, max_num_phrases: int = 5) -> str:
        """Return a lyric-format prompt based on melody structure.

        In ``"gen"`` mode there is no reference melody, so the prompt is empty.
        In ``"sample"`` mode a song id is chosen at random from the database;
        it is stored in ``self.song_id`` and the song's phrases in
        ``self.reference_song`` for a later ``generate_score`` call. The
        returned (Chinese) prompt lists a per-line character count for at most
        ``max_num_phrases`` phrases.

        Args:
            max_num_phrases: Maximum number of phrase rules to include.

        Returns:
            The prompt string, or ``""`` when there is nothing to constrain.

        Raises:
            RuntimeError: In ``"sample"`` mode when no song database is loaded.
            ValueError: If the mode is neither ``"gen"`` nor ``"sample"``.
        """
        if self.mode == "gen":
            return ""
        if self.mode != "sample":
            raise ValueError(f"Unsupported melody mode: {self.mode}")

        # A real exception rather than ``assert`` so the check survives ``python -O``.
        if self.database is None:
            raise RuntimeError("Song database is not loaded.")

        self.song_id = random.choice(self.database.get_song_ids())
        self.reference_song = self.database.iter_song_phrases(self.song_id)
        # Rendered below as "<c> characters" per line; presumably the lyric
        # character count of each phrase of the chosen song -- TODO confirm.
        phrase_length = self.database.get_phrase_length(self.song_id)
        if not phrase_length:
            return ""

        shown = phrase_length[:max_num_phrases]
        if not shown:
            # A header that promises rules but lists none would only confuse the model.
            return ""

        # Prompt text (zh): "Reply in lyric format; each line must follow these
        # character-count rules:" / "Line i: c characters" / "If there is not
        # enough information to answer, use as few lines as possible; do not
        # repeat, expand, or add unrelated content."
        rules = "".join(f"\n第{i}句:{c}个字" for i, c in enumerate(shown, 1))
        return (
            "\n请按照歌词格式回复,每句需遵循以下字数规则:"
            + rules
            + "\n如果没有足够的信息回答,请使用最少的句子,不要重复、不要扩展、不要加入无关内容。\n"
        )

    def generate_score(
        self, lyrics: str, language: str
    ) -> list[tuple[float, float, str, int]]:
        """Turn lyric text into a note-level score.

        Args:
            lyrics: Lyric text; split into per-note units by
                ``preprocess_text(lyrics, language)``.
            language: Language tag forwarded to ``preprocess_text``.

        Returns:
            ``[(start, end, lyric, pitch), ...]``.

        Raises:
            RuntimeError: In ``"sample"`` mode if ``get_melody_constraints()``
                has not been called first.
            ValueError: For any mode/alignment combination other than
                ``"gen"`` + ``"random"`` or ``"sample"``.
        """
        # Validate controller state before doing any text preprocessing.
        if self.mode == "sample":
            # ``is None`` rather than truthiness: a generator object is always
            # truthy, so ``not self.reference_song`` could never detect a missing song.
            if self.reference_song is None:
                raise RuntimeError(
                    "Must call get_melody_constraints() before generate_score() in sample mode."
                )
        elif not (self.mode == "gen" and self.align_type == "random"):
            raise ValueError(f"Unsupported melody_source_id: {self.melody_source_id}")

        text_list = preprocess_text(lyrics, language)
        if self.mode == "gen":
            return self._generate_random_score(text_list)
        # NOTE(review): if iter_song_phrases() returns a one-shot iterator it is
        # consumed here, so a second call without a fresh
        # get_melody_constraints() resumes where this one stopped.
        return self._align_text_to_score(
            text_list, self.reference_song, self.align_type
        )

    def _generate_random_score(self, text_list: list[str]):
        """Give each lyric unit a random pitch and duration, laid out back to back.

        Pitches are drawn uniformly from MIDI 57-69 inclusive (A3-A4).
        Durations are drawn from [0.1, 0.5] and rounded to 4 decimals; they are
        presumably seconds, like the reference scores' note times (TODO confirm).
        Each note starts where the previous one ended.

        Returns:
            ``[(start, end, lyric, pitch), ...]``, one entry per element of
            ``text_list``.
        """
        score = []
        st = 0.0  # float so every start/end matches the declared float type
        for lyric in text_list:
            pitch = random.randint(57, 69)
            duration = round(random.uniform(0.1, 0.5), 4)
            ed = st + duration
            score.append((st, ed, lyric, pitch))
            st = ed
        return score

    def _align_text_to_score(
        self,
        text_list: list[str],
        song_phrase_iterator: Iterator[dict],
        align_type: str,
    ):
        """Map lyric units onto the notes of a reference song, phrase by phrase.

        Each phrase dict provides parallel ``note_start_times``,
        ``note_end_times``, ``note_lyrics`` and ``note_midi`` sequences. Each
        note is handled as follows:

          * pitch 0: kept with its reference lyric and consumes no lyric unit.
          * reference lyric ``"-"`` or ``"——"`` with ``align_type == "lyric"``:
            emitted as a ``"-"`` note and consumes no lyric unit.
          * any other note: assigned the next lyric unit.

        Alignment stops once every lyric unit has been placed; the remaining
        notes of the current phrase are dropped.

        Args:
            text_list: Lyric units to place.
            song_phrase_iterator: Iterator (or any iterable) of phrase dicts.
            align_type: Alignment type; only ``"lyric"`` changes behavior here.

        Returns:
            ``[(start, end, lyric, pitch), ...]``.

        Raises:
            ValueError: If the song runs out of phrases before all lyric units
                are placed.
        """
        score = []
        num_units = len(text_list)
        text_idx = 0
        # iter() returns an iterator unchanged and also accepts lists/tuples.
        phrases = iter(song_phrase_iterator)

        while text_idx < num_units:
            try:
                reference = next(phrases)
            except StopIteration:
                # A bare StopIteration is easy to swallow by accident (and becomes
                # a RuntimeError inside generators); report the real problem.
                raise ValueError(
                    f"Reference song ran out of phrases after placing {text_idx} "
                    f"of {num_units} lyric units."
                ) from None
            for st, ed, ref_lyric, pitch in zip(
                reference["note_start_times"],
                reference["note_end_times"],
                reference["note_lyrics"],
                reference["note_midi"],
            ):
                assert ref_lyric not in [
                    "<AP>",
                    "<SP>",
                ], f"Proccessed {self.melody_source_id} score segments should not contain <AP> or <SP>."  # TODO: remove in PR, only for debug
                if pitch == 0:
                    score.append((st, ed, ref_lyric, pitch))
                elif ref_lyric in ["-", "——"] and align_type == "lyric":
                    score.append((st, ed, "-", pitch))
                else:
                    score.append((st, ed, text_list[text_idx], pitch))
                    text_idx += 1
                    if text_idx >= num_units:
                        break
        return score