try:
    import spaces
    USING_SPACES = True
except ImportError:
    USING_SPACES = False
import zero
import gradio as gr
import sys
import threading
import queue
import contextlib
from io import TextIOBase
import datetime
import subprocess
import os
from inference import postprocess_inst_names
from inference import inference_patch
from convert import abc2xml, xml2, pdf2img


def gpu_decorator(func):
    """Wrap *func* with ``spaces.GPU`` when the ``spaces`` package was importable.

    When ``spaces`` is not installed, *func* is returned unchanged.
    """
    return spaces.GPU(func) if USING_SPACES else func


# Load the allowed prompt combinations. Each usable line of prompts.txt has
# the form "<period>_<composer>_<instrumentation>".
with open('prompts.txt', 'r', encoding='utf-8') as f:
    prompts = f.readlines()

valid_combinations = set()
for prompt in prompts:
    parts = prompt.strip().split('_')
    if len(parts) < 3:
        # Blank or malformed line: skip it instead of crashing at import time.
        continue
    valid_combinations.add((parts[0], parts[1], parts[2]))

# Choices for the dropdowns.
periods = sorted({p for p, _, _ in valid_combinations})
composers = sorted({c for _, c, _ in valid_combinations})
instruments = sorted({i for _, _, i in valid_combinations})


def update_components(period, composer):
    """Recompute the composer and instrumentation dropdowns.

    Args:
        period: Currently selected period, or a falsy value when none is selected.
        composer: Currently selected composer, or a falsy value when none is selected.

    Returns:
        A two-element list of ``gr.update`` objects: the first for the composer
        dropdown, the second for the instrumentation dropdown.
    """
    if not period:
        # Without a period nothing can be chosen: empty and disable both.
        return [
            gr.update(choices=[], value=None, interactive=False),
            gr.update(choices=[], value=None, interactive=False)
        ]

    valid_composers = sorted({c for p, c, _ in valid_combinations if p == period})
    if composer:
        valid_instruments = sorted(
            {i for p, c, i in valid_combinations if p == period and c == composer}
        )
    else:
        valid_instruments = []

    return [
        gr.update(
            choices=valid_composers,
            # Keep the current composer only if it is still valid for the period.
            value=composer if composer in valid_composers else None,
            interactive=True
        ),
        gr.update(
            choices=valid_instruments,
            value=None,
            interactive=bool(valid_instruments)
        )
    ]


class RealtimeStream(TextIOBase):
    """Text stream that forwards every written chunk to a queue.

    Used as a replacement for ``sys.stdout`` so that text printed during
    inference can be streamed to the front end.
    """

    def __init__(self, queue):
        self.queue = queue

    def write(self, text):
        """Put *text* on the queue and return the number of characters written."""
        self.queue.put(text)
        return len(text)


def convert_files(abc_content, period, composer, instrumentation):
    """Write the ABC score to disk and convert it to XML, PDF, MIDI, WAV and PNG pages.

    Two ABC files are written: the raw content, and a version passed through
    ``postprocess_inst_names`` (suffix ``_postinst``). All files are created in
    the current working directory, named after a timestamp and the prompt.

    Args:
        abc_content: ABC notation text to save and convert.
        period: Period part of the prompt.
        composer: Composer part of the prompt.
        instrumentation: Instrumentation part of the prompt.

    Returns:
        A dict with keys ``'abc'``, ``'xml'``, ``'pdf'``, ``'mid'``, ``'wav'``
        (file paths), ``'pages'`` (number of page images), ``'current_page'``
        (always 0) and ``'base'`` (file name stem of the page images).

    Raises:
        gr.Error: If any prompt part is missing, or if any conversion step fails.
    """
    if not all([period, composer, instrumentation]):
        raise gr.Error("Please complete a valid generation first before saving")

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    prompt_str = f"{period}_{composer}_{instrumentation}"
    filename_base = f"{timestamp}_{prompt_str}"

    abc_filename = f"{filename_base}.abc"
    with open(abc_filename, "w", encoding="utf-8") as f:
        f.write(abc_content)

    # Instrumentation-name replacement, saved as a separate ABC file.
    postprocessed_inst_abc = postprocess_inst_names(abc_content)
    filename_base_postinst = f"{filename_base}_postinst"
    with open(filename_base_postinst + ".abc", "w", encoding="utf-8") as f:
        f.write(postprocessed_inst_abc)

    # Convert the files.
    file_paths = {'abc': abc_filename}
    try:
        # abc -> xml
        abc2xml(filename_base)
        abc2xml(filename_base_postinst)

        # xml -> pdf
        xml2(filename_base, 'pdf')

        # xml -> mid
        xml2(filename_base, 'mid')
        xml2(filename_base_postinst, 'mid')

        # xml -> wav
        xml2(filename_base, 'wav')
        xml2(filename_base_postinst, 'wav')

        # Render the PDF pages to PNG images.
        images = pdf2img(filename_base)
        for i, image in enumerate(images):
            image.save(f"{filename_base}_page_{i+1}.png", "PNG")

        file_paths.update({
            'xml': f"{filename_base_postinst}.xml",
            'pdf': f"{filename_base}.pdf",
            'mid': f"{filename_base_postinst}.mid",
            'wav': f"{filename_base_postinst}.wav",
            'pages': len(images),
            'current_page': 0,
            'base': filename_base
        })
    except Exception as e:
        # Chain the cause so the original traceback is not lost.
        raise gr.Error(f"文件处理失败: {str(e)}") from e

    return file_paths


def update_page(direction, data):
    """Move the sheet-music preview one page backwards or forwards.

    Args:
        direction: ``"prev"`` or ``"next"``; any other value leaves the page unchanged.
        data: State dict holding the keys ``'pages'``, ``'current_page'`` and
            ``'base'``, or a falsy value when nothing has been generated yet.
            It is updated in place.

    Returns:
        A 4-tuple ``(image_path, prev_button_update, next_button_update, data)``.
        ``image_path`` is ``None`` and both buttons are disabled when there is
        no state or no rendered page.
    """
    if not data or data.get('pages', 0) <= 0:
        # Nothing to show: do not hand out a path to a page that was never written.
        return None, gr.update(interactive=False), gr.update(interactive=False), data

    last_page_index = data['pages'] - 1
    if direction == "prev" and data['current_page'] > 0:
        data['current_page'] -= 1
    elif direction == "next" and data['current_page'] < last_page_index:
        data['current_page'] += 1

    current_page_index = data['current_page']
    # Page image files are numbered starting from 1.
    new_image = f"{data['base']}_page_{current_page_index+1}.png"

    # "prev" is disabled on the first page, "next" on the last page.
    prev_btn_state = gr.update(interactive=(current_page_index > 0))
    next_btn_state = gr.update(interactive=(current_page_index < last_page_index))

    return new_image, prev_btn_state, next_btn_state, data


@gpu_decorator
def generate_music(period, composer, instrumentation):
    """Run inference for the prompt and stream progress to the UI.

    This is a generator. Every ``yield`` produces the same five values:

    1. process_output: text printed to stdout during inference so far
    2. final_output: the final ABC text (or a status / error message)
    3. pdf_image: path of the PNG for the first page of the PDF
    4. audio_player: path of the WAV file
    5. pdf_state: state dict used for paging and downloads

    Args:
        period: Selected period.
        composer: Selected composer.
        instrumentation: Selected instrumentation.

    Raises:
        gr.Error: If the prompt is not in ``valid_combinations``, or if
            inference itself raised an exception.
    """
    if (period, composer, instrumentation) not in valid_combinations:
        # Reject invalid combinations right away.
        raise gr.Error("Invalid prompt combination! Please re-select from the period options")

    # # Ensure model weights were downloaded successfully
    # if not os.path.exists(model_weights_path):
    #     raise gr.Error(f"Model weights not available at {model_weights_path}")

    output_queue = queue.Queue()
    result_container = []
    error_container = []

    def run_inference():
        # The context manager restores sys.stdout on every exit path. The
        # exception is stored so the caller can report it instead of silently
        # continuing with an empty result.
        try:
            with contextlib.redirect_stdout(RealtimeStream(output_queue)):
                result = inference_patch(period, composer, instrumentation)
            result_container.append(result)
        except Exception as exc:
            error_container.append(exc)

    thread = threading.Thread(target=run_inference)
    thread.start()

    process_output = ""
    final_output_abc = ""
    pdf_image = None
    audio_file = None
    pdf_state = None

    # Stream intermediate output while inference is running.
    while thread.is_alive():
        try:
            text = output_queue.get(timeout=0.1)
        except queue.Empty:
            continue
        process_output += text
        # No final ABC yet and no converted files.
        yield process_output, final_output_abc, pdf_image, audio_file, pdf_state

    # The thread has finished: drain whatever is left in the queue.
    while not output_queue.empty():
        process_output += output_queue.get()

    if error_container:
        failure = error_container[0]
        raise gr.Error(f"Generation failed: {failure}") from failure

    # Final inference result.
    final_result = result_container[0] if result_container else ""

    # Tell the user that file conversion is in progress.
    final_output_abc = "Converting files..."
    yield process_output, final_output_abc, pdf_image, audio_file, pdf_state

    # Convert the files.
    try:
        file_paths = convert_files(final_result, period, composer, instrumentation)
        final_output_abc = final_result
        # First page image and the WAV file.
        if file_paths['pages'] > 0:
            pdf_image = f"{file_paths['base']}_page_1.png"
        audio_file = file_paths['wav']
        pdf_state = file_paths  # the conversion result dict is stored as the state
    except Exception as e:
        # On failure, show the error message in the output box.
        yield process_output, f"Error converting files: {str(e)}", None, None, None
        return

    # Last yield, carrying all the information.
    yield process_output, final_output_abc, pdf_image, audio_file, pdf_state


def get_file(file_type, period, composer, instrumentation):
    """Return the most recently modified file with extension *file_type*.

    Only the current working directory is searched. ``period``, ``composer``
    and ``instrumentation`` are accepted but not used.

    Returns:
        The file name, or ``None`` when no file with that extension exists.
    """
    possible_files = [f for f in os.listdir('.') if f.endswith(f'.{file_type}')]
    if not possible_files:
        return None
    # Simply return the newest one.
    possible_files.sort(key=os.path.getmtime)
    return possible_files[-1]


def _state_file_getter(key):
    """Build a click handler that reads *key* from the state dict (None if no state)."""
    def getter(state):
        return state.get(key) if state else None
    return getter


css = """
/* 紧凑按钮样式 */
button[size="sm"] {
    padding: 4px 8px !important;
    margin: 2px !important;
    min-width: 60px;
}

/* PDF预览区 */
#pdf-preview {
    border-radius: 8px;  /* 圆角 */
    box-shadow: 0 2px 8px rgba(0,0,0,0.1);  /* 阴影 */
}

.page-btn {
    padding: 12px !important;  /* 增大点击区域 */
    margin: auto !important;   /* 垂直居中 */
}

/* 按钮悬停效果 */
.page-btn:hover {
    background: #f0f0f0 !important;
    transform: scale(1.05);
}

/* 布局调整 */
.gr-row {
    gap: 10px !important;  /* 元素间距 */
}

/* 音频播放器 */
.audio-panel {
    margin-top: 15px !important;
    max-width: 400px;
}

#audio-preview audio {
    height: 200px !important;
}

/* 保存功能区 */
.save-as-row {
    margin-top: 15px;
    padding: 10px;
    border-top: 1px solid #eee;
}

.save-as-label {
    font-weight: bold;
    margin-right: 10px;
    align-self: center;
}

.save-buttons {
    gap: 5px;  /* 按钮间距 */
}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("## NotaGen")

    # Holds the page count, current page and generated file paths.
    pdf_state = gr.State()

    with gr.Column():
        with gr.Row():
            # Left column
            with gr.Column():
                with gr.Row():
                    period_dd = gr.Dropdown(
                        choices=periods,
                        value=None,
                        label="Period",
                        interactive=True
                    )
                    composer_dd = gr.Dropdown(
                        choices=[],
                        value=None,
                        label="Composer",
                        interactive=False
                    )
                    instrument_dd = gr.Dropdown(
                        choices=[],
                        value=None,
                        label="Instrumentation",
                        interactive=False
                    )

                generate_btn = gr.Button("Generate!", variant="primary")

                process_output = gr.Textbox(
                    label="Generation process",
                    interactive=False,
                    lines=2,
                    max_lines=2,
                    placeholder="Generation progress will be shown here..."
                )

                final_output = gr.Textbox(
                    label="Post-processed ABC notation scores",
                    interactive=True,
                    lines=8,
                    max_lines=8,
                    placeholder="Post-processed ABC scores will be shown here..."
                )

                # Audio playback
                audio_player = gr.Audio(
                    label="Audio Preview",
                    format="wav",
                    interactive=False,
                    # container=False,
                    # elem_id="audio-preview"
                )

            # Right column
            with gr.Column():
                # Image container
                pdf_image = gr.Image(
                    label="Sheet Music Preview",
                    show_label=False,
                    height=650,
                    type="filepath",
                    elem_id="pdf-preview",
                    interactive=False,
                    show_download_button=False
                )

                # Paging buttons
                with gr.Row():
                    prev_btn = gr.Button(
                        "⬅️ Last Page",
                        variant="secondary",
                        size="sm",
                        elem_classes="page-btn"
                    )
                    next_btn = gr.Button(
                        "Next Page ➡️",
                        variant="secondary",
                        size="sm",
                        elem_classes="page-btn"
                    )

        # Save buttons
        with gr.Row():
            gr.Markdown("**Save As: (Scroll down to get the link)**")
            save_abc = gr.Button("🅰️ ABC", variant="secondary", size="sm")
            save_xml = gr.Button("🎼 XML", variant="secondary", size="sm")
            save_pdf = gr.Button("📑 PDF", variant="secondary", size="sm")
            save_mid = gr.Button("🎹 MIDI", variant="secondary", size="sm")
            save_wav = gr.Button("🎧 WAV", variant="secondary", size="sm")

        # save_status = gr.Textbox(
        #     label="Save Status",
        #     interactive=False,
        #     visible=True,
        #     max_lines=1
        # )

    # Linked dropdowns
    period_dd.change(
        update_components,
        inputs=[period_dd, composer_dd],
        outputs=[composer_dd, instrument_dd]
    )
    composer_dd.change(
        update_components,
        inputs=[period_dd, composer_dd],
        outputs=[composer_dd, instrument_dd]
    )

    # Generate button; outputs must match what generate_music yields each time.
    generate_btn.click(
        generate_music,
        inputs=[period_dd, composer_dd, instrument_dd],
        outputs=[process_output, final_output, pdf_image, audio_player, pdf_state]
    )

    # Paging: the direction is passed in through hidden components.
    prev_signal = gr.Textbox(value="prev", visible=False)
    next_signal = gr.Textbox(value="next", visible=False)

    prev_btn.click(
        update_page,
        inputs=[prev_signal, pdf_state],
        outputs=[pdf_image, prev_btn, next_btn, pdf_state]
    )
    next_btn.click(
        update_page,
        inputs=[next_signal, pdf_state],
        outputs=[pdf_image, prev_btn, next_btn, pdf_state]
    )

    # Save buttons: each returns the matching path stored in the state.
    for save_button, state_key, file_label in (
        (save_abc, 'abc', "abc"),
        (save_xml, 'xml', "xml"),
        (save_pdf, 'pdf', "pdf"),
        (save_mid, 'mid', "midi"),
        (save_wav, 'wav', "wav"),
    ):
        save_button.click(
            _state_file_getter(state_key),
            inputs=[pdf_state],
            outputs=gr.File(label=file_label, visible=True)
        )

if __name__ == "__main__":
    # Configure GPU/CPU handling
    import torch

    def is_cuda_working():
        """Return True if a small tensor operation succeeds on the CUDA device."""
        try:
            if torch.cuda.is_available():
                # Test CUDA initialization with a small operation
                test_tensor = torch.tensor([1.0], device="cuda")
                _ = test_tensor * 2
                return True
            return False
        except Exception as e:
            print(f"CUDA initialization test failed: {e}")
            return False

    # Check if running on Hugging Face Spaces
    if "SPACE_ID" in os.environ:
        cuda_working = is_cuda_working()

        if cuda_working:
            print("GPU is available and working. Using CUDA.")
            # You might want to set some environment variables or configurations here
            os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
        else:
            print("CUDA not working properly. Forcing CPU mode.")
            os.environ["CUDA_VISIBLE_DEVICES"] = ""
            torch.backends.cudnn.enabled = False

        # Launch with minimal parameters on Spaces
        demo.launch()
    else:
        # Running locally - use custom server settings and share
        print(f"Running locally with device: {'cuda' if torch.cuda.is_available() else 'cpu'}")
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=True  # make the app reachable from outside
        )