| | |
"""
Batch Generate Text2Music Examples using the LM
Generates a batch of examples (100 by default when run from the command line)
and saves each one as a JSON file under examples/text2music/
"""
| | import os |
| | import json |
| | import sys |
| | from pathlib import Path |
| |
|
| | |
# Directory containing this script; put it at the front of the import path so
# the local `acestep` package is importable when the script is run directly.
project_root = Path(__file__).parent
sys.path.insert(0, os.fspath(project_root))
| |
|
| | from acestep.llm_inference import LLMHandler |
| | from loguru import logger |
| | from tqdm import tqdm |
| |
|
| |
|
# Placeholder values the LM may emit for "no value"; fields holding these are omitted.
_MISSING_VALUES = (None, "N/A", "")


def _has_value(metadata, key):
    """Return True if ``metadata[key]`` exists and is not a placeholder (None, "N/A", "")."""
    return key in metadata and metadata[key] not in _MISSING_VALUES


def _as_int(value):
    """Return ``int(value)`` for int/str inputs that parse; otherwise return ``value`` unchanged."""
    if not isinstance(value, (int, str)):
        return value
    try:
        return int(value)
    except (ValueError, TypeError):
        return value


def _build_example_data(metadata):
    """Build the example dict written to JSON from the LM *metadata* mapping.

    ``caption`` and ``lyrics`` are always present (``""`` when missing or None).
    Optional fields are included only when they carry a non-placeholder value.
    Insertion order here is the key order in the written JSON file.
    """
    example_data = {
        "think": True,
        "caption": metadata.get("caption") or "",
        "lyrics": metadata.get("lyrics") or "",
    }
    # Numeric fields: coerce to int where possible, otherwise keep the raw value.
    for key in ("bpm", "duration"):
        if _has_value(metadata, key):
            example_data[key] = _as_int(metadata[key])
    # Remaining optional fields are copied through verbatim.
    for key in ("keyscale", "language", "timesignature"):
        if _has_value(metadata, key):
            example_data[key] = metadata[key]
    return example_data


def generate_examples(num_examples=50, output_dir="examples/text2music", start_index=1):
    """
    Generate examples using the 5Hz LM and save each one to its own JSON file.

    The LM is queried with audio_codes="NO USER INPUT" and each resulting
    metadata dict is written to ``<output_dir>/example_<NN>.json``. Files with
    the same name are overwritten.

    Args:
        num_examples: Number of examples to generate
        output_dir: Output directory for JSON files (created if missing)
        start_index: Starting index used in the example file names

    Returns:
        None. Progress and a success/failure summary are logged; the function
        returns early (after logging an error) if no LM model is available or
        the LM fails to initialize.
    """
    logger.info("Initializing LLM Handler...")
    llm_handler = LLMHandler()

    checkpoint_dir = os.path.join(project_root, "checkpoints")

    available_models = llm_handler.get_available_5hz_lm_models()
    if not available_models:
        logger.error("No 5Hz LM models found in checkpoints directory")
        return

    # Prefer the 0.6B model when present; otherwise fall back to the first listed.
    preferred_model = "acestep-5Hz-lm-0.6B"
    lm_model = preferred_model if preferred_model in available_models else available_models[0]
    logger.info(f"Using LM model: {lm_model}")

    status_msg, success = llm_handler.initialize(
        checkpoint_dir=checkpoint_dir,
        lm_model_path=lm_model,
        backend="vllm",
        device="auto",
        offload_to_cpu=False,
        dtype=None,
    )

    if not success:
        logger.error(f"Failed to initialize LM: {status_msg}")
        return

    logger.info(f"LM initialized successfully: {status_msg}")

    os.makedirs(output_dir, exist_ok=True)

    successful_count = 0
    failed_count = 0
    last_index = start_index + num_examples - 1

    for i in tqdm(range(num_examples), desc="Generating examples"):
        example_num = start_index + i
        output_file = os.path.join(output_dir, f"example_{example_num:02d}.json")

        logger.info(f"Generating example {example_num}/{last_index}...")

        try:
            metadata, status = llm_handler.understand_audio_from_codes(
                audio_codes="NO USER INPUT",
                use_constrained_decoding=True,
                temperature=0.85,
                cfg_scale=1.0,
                top_k=None,
                top_p=0.9,
            )

            if not metadata:
                logger.warning(f"Failed to generate example {example_num}: {status}")
                failed_count += 1
                continue

            example_data = _build_example_data(metadata)

            # Serialize before opening the file so a non-JSON-serializable
            # value cannot leave a truncated file behind on disk.
            payload = json.dumps(example_data, ensure_ascii=False, indent=4)
            with open(output_file, "w", encoding="utf-8") as f:
                f.write(payload)
        except Exception as e:
            # logger.exception also records the traceback, which str(e) alone loses.
            logger.exception(f"❌ Error generating example {example_num}: {e}")
            failed_count += 1
            continue

        # Count the example once it is on disk; the preview logging below is
        # outside the try so it can never turn a saved example into a failure.
        successful_count += 1
        logger.info(f"✅ Saved example {example_num} to {output_file}")
        caption = str(example_data["caption"])
        preview = caption if len(caption) <= 100 else f"{caption[:100]}..."
        logger.info(f" Caption preview: {preview}")

    logger.info(f"\n{'=' * 60}")
    logger.info("Generation complete!")
    logger.info(f"Successful: {successful_count}/{num_examples}")
    logger.info(f"Failed: {failed_count}/{num_examples}")
    logger.info(f"Output directory: {output_dir}")
    logger.info(f"{'=' * 60}\n")
| |
|
| |
|
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Generate text2music examples using LM")
    parser.add_argument("--num", type=int, default=100, help="Number of examples to generate (default: 100)")
    parser.add_argument("--output-dir", type=str, default="examples/text2music", help="Output directory (default: examples/text2music)")
    parser.add_argument("--start-index", type=int, default=1, help="Starting index for example files (default: 1)")

    args = parser.parse_args()

    # Reject values that would silently produce an empty run or odd file names
    # such as "example_-1.json"; parser.error prints usage and exits with code 2.
    if args.num < 1:
        parser.error("--num must be a positive integer")
    if args.start_index < 0:
        parser.error("--start-index must be a non-negative integer")

    generate_examples(
        num_examples=args.num,
        output_dir=args.output_dir,
        start_index=args.start_index,
    )
| |
|