# Source: mpt2026 / webui / Main.py
# Uploaded by chaowenguo — "Upload 108 files" (commit e58b6a6, verified)
import os
import platform
import sys
from uuid import uuid4
import streamlit as st
from loguru import logger
# Put the project root (two directory levels above this file's real path) on
# sys.path so the `app.*` imports below resolve when this file is run
# directly. The path is added only once, and the resulting search path is
# echoed to stdout for debugging.
root_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
if root_dir not in sys.path:
    sys.path.append(root_dir)
    print("******** sys.path ********", sys.path, "", sep="\n")
from app.config import config
from app.models.schema import (
MaterialInfo,
VideoAspect,
VideoConcatMode,
VideoParams,
VideoTransitionMode,
)
from app.services import llm, voice
from app.services import task as tm
from app.utils import utils
# Page-level settings: browser title/icon, wide layout and the "⋮" menu links.
# NOTE(review): older Streamlit versions require set_page_config to be the
# first st.* call of the script — keep it ahead of any other Streamlit call.
st.set_page_config(
    page_title="MoneyPrinterTurbo",
    page_icon="🤖",
    layout="wide",
    initial_sidebar_state="auto",
    menu_items={
        "Report a bug": "https://github.com/harry0703/MoneyPrinterTurbo/issues",
        "About": "# MoneyPrinterTurbo\nSimply provide a topic or keyword for a video, and it will "
        "automatically generate the video copy, video materials, video subtitles, "
        "and video background music before synthesizing a high-definition short "
        "video.\n\nhttps://github.com/harry0703/MoneyPrinterTurbo",
    },
)
# Injected CSS: remove the top padding above the page title (h1 elements).
streamlit_style = """
<style>
h1 {
padding-top: 0 !important;
}
</style>
"""
st.markdown(streamlit_style, unsafe_allow_html=True)
# Resource directories, all relative to the project root.
font_dir = os.path.join(root_dir, "resource", "fonts")
song_dir = os.path.join(root_dir, "resource", "songs")
i18n_dir = os.path.join(root_dir, "webui", "i18n")
# NOTE(review): config_file is not referenced in this view — confirm it is still used.
config_file = os.path.join(root_dir, "webui", ".streamlit", "webui.toml")
system_locale = utils.get_system_locale()
# Seed session-state defaults only when the key is not present yet, so values
# already in the session are kept across reruns.
if "video_subject" not in st.session_state:
    st.session_state["video_subject"] = ""
if "video_script" not in st.session_state:
    st.session_state["video_script"] = ""
if "video_terms" not in st.session_state:
    st.session_state["video_terms"] = ""
if "ui_language" not in st.session_state:
    # UI language: the configured value, falling back to the system locale.
    st.session_state["ui_language"] = config.ui.get("language", system_locale)
# Load the UI translation files from the i18n directory.
locales = utils.load_locales(i18n_dir)
# Top bar: the title on the left (3/4 width) and the language selector on the right.
title_col, lang_col = st.columns([3, 1])
with title_col:
    st.title(f"MoneyPrinterTurbo v{config.project_version}")
with lang_col:
    # Build "<code> - <Language>" labels and preselect the current UI language
    # (index 0 when it is not among the loaded locales).
    display_languages = []
    selected_index = 0
    for i, code in enumerate(locales.keys()):
        display_languages.append(f"{code} - {locales[code].get('Language')}")
        if code == st.session_state.get("ui_language", ""):
            selected_index = i
    selected_language = st.selectbox(
        "Language / 语言",
        options=display_languages,
        index=selected_index,
        key="top_language_selector",
        label_visibility="collapsed",
    )
    if selected_language:
        # Recover the locale code from the label and store it both in the
        # session state and in config.ui.
        code = selected_language.split(" - ")[0].strip()
        st.session_state["ui_language"] = code
        config.ui["language"] = code
# Locale codes offered by the "Script Language" selector; the "Auto Detect"
# entry (empty code) is prepended where that selector is built.
support_locales = [
    "zh-CN", "zh-HK", "zh-TW",  # Chinese variants
    "de-DE", "en-US", "fr-FR",
    "vi-VN", "th-TH",
]
def get_all_fonts():
    """Return the base names of all ``.ttf``/``.ttc`` files under ``font_dir``.

    The directory is walked recursively; the result is sorted alphabetically.
    """
    font_suffixes = (".ttf", ".ttc")
    found = []
    for _dirpath, _dirnames, filenames in os.walk(font_dir):
        found.extend(name for name in filenames if name.endswith(font_suffixes))
    return sorted(found)
def get_all_songs():
    """Return the base names of all ``.mp3`` files under ``song_dir``.

    The directory is walked recursively. The names are sorted so the result
    does not depend on the filesystem's ``os.walk`` order, matching
    ``get_all_fonts``.
    """
    songs = [
        name
        for _dirpath, _dirnames, filenames in os.walk(song_dir)
        for name in filenames
        if name.endswith(".mp3")
    ]
    songs.sort()
    return songs
def open_task_folder(task_id):
    """Open ``<root_dir>/storage/tasks/<task_id>`` in the desktop file browser.

    Only Windows and macOS are handled; on any other platform, or when the
    folder does not exist, this does nothing. Errors are logged, not raised.

    Args:
        task_id: Name of the task folder under ``storage/tasks``.
    """
    try:
        # Named system_name so the `sys` module is not shadowed.
        system_name = platform.system()
        path = os.path.join(root_dir, "storage", "tasks", task_id)
        if not os.path.exists(path):
            return
        if system_name == "Windows":
            # No shell involved: paths with spaces or shell metacharacters
            # (which broke `os.system(f"start {path}")`) are handled correctly.
            os.startfile(path)
        elif system_name == "Darwin":
            import subprocess

            # Argument list instead of a shell string, for the same reason.
            subprocess.run(["open", path], check=False)
    except Exception as e:
        logger.error(e)
def scroll_to_bottom():
    """Scroll the Streamlit main section(s) to the bottom.

    Renders a zero-size HTML component whose script sets ``scrollTop`` to
    ``scrollHeight`` for every ``section.main`` in the parent document.
    """
    script_html = """
<script>
console.log("scroll_to_bottom");
function scroll(dummy_var_to_force_repeat_execution){
var sections = parent.document.querySelectorAll('section.main');
console.log(sections);
for(let index = 0; index<sections.length; index++) {
sections[index].scrollTop = sections[index].scrollHeight;
}
}
scroll(1);
</script>
"""
    st.components.v1.html(script_html, width=0, height=0)
def init_log():
    """Reset loguru and install a single colourised DEBUG sink on stdout.

    ``logger.remove()`` drops every sink registered so far (including the
    page-log sink added when a video is generated), so after this call there
    is exactly one stdout sink. Paths under ``root_dir`` are shown relative
    to the project root.
    """
    logger.remove()
    _lvl = "DEBUG"

    def format_record(record):
        """Rewrite *record* paths relative to ``root_dir``; return the format template."""
        file_path = record["file"].path
        try:
            # Show the source file relative to the project root.
            record["file"].path = f"./{os.path.relpath(file_path, root_dir)}"
        except ValueError:
            # On Windows os.path.relpath raises ValueError when the two paths
            # are on different drives; keep the absolute path in that case
            # instead of failing to format the record.
            pass
        # Also shorten absolute project paths that appear in the message text.
        record["message"] = record["message"].replace(root_dir, ".")
        _format = (
            "<green>{time:%Y-%m-%d %H:%M:%S}</> | "
            + "<level>{level}</> | "
            + '"{file.path}:{line}":<blue> {function}</> '
            + "- <level>{message}</>"
            + "\n"
        )
        return _format

    logger.add(
        sys.stdout,
        level=_lvl,
        format=format_record,
        colorize=True,
    )
init_log()
# `locales` was already loaded from i18n_dir before the language selector and
# nothing since has reassigned it; the second utils.load_locales(i18n_dir)
# call that used to be here only re-read the same files on every rerun.
def tr(key):
    """Translate *key* into the current UI language.

    Looks up ``locales[<ui_language>]["Translation"][key]`` and returns *key*
    unchanged when the language, its translation table, or the key is missing.
    """
    translations = locales.get(st.session_state["ui_language"], {}).get(
        "Translation", {}
    )
    return translations.get(key, key)
# Collapsible "Basic Settings" panel; it is not rendered at all when
# config.app["hide_config"] is true.
if not config.app.get("hide_config", False):
    with st.expander(tr("Basic Settings"), expanded=False):
        config_panels = st.columns(3)
        left_config_panel = config_panels[0]
        middle_config_panel = config_panels[1]
        right_config_panel = config_panels[2]
        # Left panel - log settings
        with left_config_panel:
            # Whether to hide this settings panel (read by the check above on
            # later runs of the script)
            hide_config = st.checkbox(
                tr("Hide Basic Settings"), value=config.app.get("hide_config", False)
            )
            config.app["hide_config"] = hide_config
            # Whether to disable the on-page log display
            hide_log = st.checkbox(
                tr("Hide Log"), value=config.ui.get("hide_log", False)
            )
            config.ui["hide_log"] = hide_log
        # Middle panel - LLM settings
        with middle_config_panel:
            st.write(tr("LLM Settings"))
            llm_providers = [
                "OpenAI",
                "Moonshot",
                "Azure",
                "Qwen",
                "DeepSeek",
                "Gemini",
                "Ollama",
                "G4f",
                "OneAPI",
                "Cloudflare",
                "ERNIE",
                "Pollinations",
            ]
            # Preselect the configured provider (case-insensitive match; the
            # first entry when nothing matches).
            saved_llm_provider = config.app.get("llm_provider", "OpenAI").lower()
            saved_llm_provider_index = 0
            for i, provider in enumerate(llm_providers):
                if provider.lower() == saved_llm_provider:
                    saved_llm_provider_index = i
                    break
            llm_provider = st.selectbox(
                tr("LLM Provider"),
                options=llm_providers,
                index=saved_llm_provider_index,
            )
            llm_helper = st.container()
            llm_provider = llm_provider.lower()
            config.app["llm_provider"] = llm_provider
            # Per-provider settings live under "<provider>_<field>" config keys.
            llm_api_key = config.app.get(f"{llm_provider}_api_key", "")
            llm_secret_key = config.app.get(
                f"{llm_provider}_secret_key", ""
            )  # only for baidu ernie
            llm_base_url = config.app.get(f"{llm_provider}_base_url", "")
            llm_model_name = config.app.get(f"{llm_provider}_model_name", "")
            llm_account_id = config.app.get(f"{llm_provider}_account_id", "")
            # Each provider branch below fills in default model / base-URL
            # values when none are configured and sets markdown help in `tips`.
            # NOTE(review): the `with llm_helper:` blocks only wrap an
            # assignment, so nothing is rendered into llm_helper; the tips are
            # displayed by the st.info(tips) call below (Chinese UI only).
            tips = ""
            if llm_provider == "ollama":
                if not llm_model_name:
                    llm_model_name = "qwen:7b"
                if not llm_base_url:
                    llm_base_url = "http://localhost:11434/v1"
                with llm_helper:
                    tips = """
##### Ollama配置说明
- **API Key**: 随便填写,比如 123
- **Base Url**: 一般为 http://localhost:11434/v1
- 如果 `MoneyPrinterTurbo` 和 `Ollama` **不在同一台机器上**,需要填写 `Ollama` 机器的IP地址
- 如果 `MoneyPrinterTurbo` 是 `Docker` 部署,建议填写 `http://host.docker.internal:11434/v1`
- **Model Name**: 使用 `ollama list` 查看,比如 `qwen:7b`
"""
            if llm_provider == "openai":
                if not llm_model_name:
                    llm_model_name = "gpt-3.5-turbo"
                with llm_helper:
                    tips = """
##### OpenAI 配置说明
> 需要VPN开启全局流量模式
- **API Key**: [点击到官网申请](https://platform.openai.com/api-keys)
- **Base Url**: 可以留空
- **Model Name**: 填写**有权限**的模型,[点击查看模型列表](https://platform.openai.com/settings/organization/limits)
"""
            if llm_provider == "moonshot":
                if not llm_model_name:
                    llm_model_name = "moonshot-v1-8k"
                with llm_helper:
                    tips = """
##### Moonshot 配置说明
- **API Key**: [点击到官网申请](https://platform.moonshot.cn/console/api-keys)
- **Base Url**: 固定为 https://api.moonshot.cn/v1
- **Model Name**: 比如 moonshot-v1-8k,[点击查看模型列表](https://platform.moonshot.cn/docs/intro#%E6%A8%A1%E5%9E%8B%E5%88%97%E8%A1%A8)
"""
            if llm_provider == "oneapi":
                if not llm_model_name:
                    llm_model_name = (
                        "claude-3-5-sonnet-20240620"  # default model, adjust as needed
                    )
                with llm_helper:
                    tips = """
##### OneAPI 配置说明
- **API Key**: 填写您的 OneAPI 密钥
- **Base Url**: 填写 OneAPI 的基础 URL
- **Model Name**: 填写您要使用的模型名称,例如 claude-3-5-sonnet-20240620
"""
            if llm_provider == "qwen":
                if not llm_model_name:
                    llm_model_name = "qwen-max"
                with llm_helper:
                    tips = """
##### 通义千问Qwen 配置说明
- **API Key**: [点击到官网申请](https://dashscope.console.aliyun.com/apiKey)
- **Base Url**: 留空
- **Model Name**: 比如 qwen-max,[点击查看模型列表](https://help.aliyun.com/zh/dashscope/developer-reference/model-introduction#3ef6d0bcf91wy)
"""
            if llm_provider == "g4f":
                if not llm_model_name:
                    llm_model_name = "gpt-3.5-turbo"
                with llm_helper:
                    tips = """
##### gpt4free 配置说明
> [GitHub开源项目](https://github.com/xtekky/gpt4free),可以免费使用GPT模型,但是**稳定性较差**
- **API Key**: 随便填写,比如 123
- **Base Url**: 留空
- **Model Name**: 比如 gpt-3.5-turbo,[点击查看模型列表](https://github.com/xtekky/gpt4free/blob/main/g4f/models.py#L308)
"""
            if llm_provider == "azure":
                with llm_helper:
                    tips = """
##### Azure 配置说明
> [点击查看如何部署模型](https://learn.microsoft.com/zh-cn/azure/ai-services/openai/how-to/create-resource)
- **API Key**: [点击到Azure后台创建](https://portal.azure.com/#view/Microsoft_Azure_ProjectOxford/CognitiveServicesHub/~/OpenAI)
- **Base Url**: 留空
- **Model Name**: 填写你实际的部署名
"""
            if llm_provider == "gemini":
                if not llm_model_name:
                    llm_model_name = "gemini-1.0-pro"
                with llm_helper:
                    tips = """
##### Gemini 配置说明
> 需要VPN开启全局流量模式
- **API Key**: [点击到官网申请](https://ai.google.dev/)
- **Base Url**: 留空
- **Model Name**: 比如 gemini-1.0-pro
"""
            if llm_provider == "deepseek":
                if not llm_model_name:
                    llm_model_name = "deepseek-chat"
                if not llm_base_url:
                    llm_base_url = "https://api.deepseek.com"
                with llm_helper:
                    tips = """
##### DeepSeek 配置说明
- **API Key**: [点击到官网申请](https://platform.deepseek.com/api_keys)
- **Base Url**: 固定为 https://api.deepseek.com
- **Model Name**: 固定为 deepseek-chat
"""
            if llm_provider == "ernie":
                with llm_helper:
                    tips = """
##### 百度文心一言 配置说明
- **API Key**: [点击到官网申请](https://console.bce.baidu.com/qianfan/ais/console/applicationConsole/application)
- **Secret Key**: [点击到官网申请](https://console.bce.baidu.com/qianfan/ais/console/applicationConsole/application)
- **Base Url**: 填写 **请求地址** [点击查看文档](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/jlil56u11#%E8%AF%B7%E6%B1%82%E8%AF%B4%E6%98%8E)
"""
            if llm_provider == "pollinations":
                if not llm_model_name:
                    llm_model_name = "default"
                with llm_helper:
                    tips = """
##### Pollinations AI Configuration
- **API Key**: Optional - Leave empty for public access
- **Base Url**: Default is https://text.pollinations.ai/openai
- **Model Name**: Use 'openai-fast' or specify a model name
"""
            # The provider recommendation and the tips are only shown when the
            # UI language is "zh".
            if tips and config.ui["language"] == "zh":
                st.warning(
                    "中国用户建议使用 **DeepSeek** 或 **Moonshot** 作为大模型提供商\n- 国内可直接访问,不需要VPN \n- 注册就送额度,基本够用"
                )
                st.info(tips)
            st_llm_api_key = st.text_input(
                tr("API Key"), value=llm_api_key, type="password"
            )
            st_llm_base_url = st.text_input(tr("Base Url"), value=llm_base_url)
            st_llm_model_name = ""
            # ERNIE has no model-name field; every other provider does.
            if llm_provider != "ernie":
                st_llm_model_name = st.text_input(
                    tr("Model Name"),
                    value=llm_model_name,
                    key=f"{llm_provider}_model_name_input",
                )
                # NOTE(review): redundant with the identical save a few lines below.
                if st_llm_model_name:
                    config.app[f"{llm_provider}_model_name"] = st_llm_model_name
            else:
                st_llm_model_name = None
            # Only non-empty inputs overwrite the stored values.
            if st_llm_api_key:
                config.app[f"{llm_provider}_api_key"] = st_llm_api_key
            if st_llm_base_url:
                config.app[f"{llm_provider}_base_url"] = st_llm_base_url
            if st_llm_model_name:
                config.app[f"{llm_provider}_model_name"] = st_llm_model_name
            if llm_provider == "ernie":
                st_llm_secret_key = st.text_input(
                    tr("Secret Key"), value=llm_secret_key, type="password"
                )
                # Stored even when empty, unlike the keys above.
                config.app[f"{llm_provider}_secret_key"] = st_llm_secret_key
            if llm_provider == "cloudflare":
                st_llm_account_id = st.text_input(
                    tr("Account ID"), value=llm_account_id
                )
                if st_llm_account_id:
                    config.app[f"{llm_provider}_account_id"] = st_llm_account_id
        # Right panel - video source API keys
        with right_config_panel:

            def get_keys_from_config(cfg_key):
                """Return the keys stored under *cfg_key* as one ", "-joined string.

                A single string value is treated as a one-element list.
                """
                api_keys = config.app.get(cfg_key, [])
                if isinstance(api_keys, str):
                    api_keys = [api_keys]
                api_key = ", ".join(api_keys)
                return api_key

            def save_keys_to_config(cfg_key, value):
                """Store comma-separated *value* as a list under *cfg_key*.

                All spaces are removed first; an empty value leaves the stored
                keys unchanged.
                """
                value = value.replace(" ", "")
                if value:
                    config.app[cfg_key] = value.split(",")

            st.write(tr("Video Source Settings"))
            pexels_api_key = get_keys_from_config("pexels_api_keys")
            pexels_api_key = st.text_input(
                tr("Pexels API Key"), value=pexels_api_key, type="password"
            )
            save_keys_to_config("pexels_api_keys", pexels_api_key)
            pixabay_api_key = get_keys_from_config("pixabay_api_keys")
            pixabay_api_key = st.text_input(
                tr("Pixabay API Key"), value=pixabay_api_key, type="password"
            )
            save_keys_to_config("pixabay_api_keys", pixabay_api_key)
# Re-read the provider from config: the settings expander that normally sets
# it is skipped when hide_config is on.
llm_provider = config.app.get("llm_provider", "").lower()

# Three-column main layout: script settings | video & audio | subtitles.
panel = st.columns(3)
left_panel, middle_panel, right_panel = panel

# Filled in by the widgets below and passed to tm.start() on "Generate Video".
params = VideoParams(video_subject="")
uploaded_files = []
with left_panel:
    with st.container(border=True):
        st.write(tr("Video Script Settings"))
        params.video_subject = st.text_input(
            tr("Video Subject"),
            value=st.session_state["video_subject"],
            key="video_subject_input",
        ).strip()
        # Script language options: "Auto Detect" (empty code) followed by every
        # locale in support_locales, each labelled with its own code.
        video_languages = [
            (tr("Auto Detect"), ""),
        ]
        for code in support_locales:
            video_languages.append((code, code))
        selected_index = st.selectbox(
            tr("Script Language"),
            index=0,
            options=range(len(video_languages)),  # Use the index as the internal option value
            format_func=lambda x: video_languages[x][0],  # The label is displayed to the user
        )
        params.video_language = video_languages[selected_index][1]
        if st.button(
            tr("Generate Video Script and Keywords"), key="auto_generate_script"
        ):
            with st.spinner(tr("Generating Video Script and Keywords")):
                script = llm.generate_script(
                    video_subject=params.video_subject, language=params.video_language
                )
                if "Error: " in script:
                    # Report the failure without spending a second LLM call on
                    # deriving keywords from the error text.
                    st.error(tr(script))
                else:
                    terms = llm.generate_terms(params.video_subject, script)
                    if "Error: " in terms:
                        st.error(tr(terms))
                    else:
                        # Stored in session state so the text areas below show
                        # the generated values.
                        st.session_state["video_script"] = script
                        st.session_state["video_terms"] = ", ".join(terms)
        params.video_script = st.text_area(
            tr("Video Script"), value=st.session_state["video_script"], height=280
        )
        if st.button(tr("Generate Video Keywords"), key="auto_generate_terms"):
            if not params.video_script:
                # NOTE(review): the check is on the script, but the message asks
                # for the subject — confirm the intended wording.
                st.error(tr("Please Enter the Video Subject"))
                st.stop()
            with st.spinner(tr("Generating Video Keywords")):
                terms = llm.generate_terms(params.video_subject, params.video_script)
                if "Error: " in terms:
                    st.error(tr(terms))
                else:
                    st.session_state["video_terms"] = ", ".join(terms)
        params.video_terms = st.text_area(
            tr("Video Keywords"), value=st.session_state["video_terms"]
        )
with middle_panel:
    with st.container(border=True):
        st.write(tr("Video Settings"))
        video_concat_modes = [
            (tr("Sequential"), "sequential"),
            (tr("Random"), "random"),
        ]
        video_sources = [
            (tr("Pexels"), "pexels"),
            (tr("Pixabay"), "pixabay"),
            (tr("Local file"), "local"),
            (tr("TikTok"), "douyin"),
            (tr("Bilibili"), "bilibili"),
            (tr("Xiaohongshu"), "xiaohongshu"),
        ]
        # Preselect the source stored in config. An unknown value (e.g. from a
        # hand-edited or outdated config) falls back to the first entry
        # instead of raising ValueError and breaking the whole page.
        saved_video_source_name = config.app.get("video_source", "pexels")
        saved_video_source_index = next(
            (
                idx
                for idx, (_, source_value) in enumerate(video_sources)
                if source_value == saved_video_source_name
            ),
            0,
        )
        selected_index = st.selectbox(
            tr("Video Source"),
            options=range(len(video_sources)),
            format_func=lambda x: video_sources[x][0],
            index=saved_video_source_index,
        )
        params.video_source = video_sources[selected_index][1]
        config.app["video_source"] = params.video_source
        if params.video_source == "local":
            uploaded_files = st.file_uploader(
                "Upload Local Files",
                type=["mp4", "mov", "avi", "flv", "mkv", "jpg", "jpeg", "png"],
                accept_multiple_files=True,
            )
        selected_index = st.selectbox(
            tr("Video Concat Mode"),
            index=1,
            options=range(len(video_concat_modes)),  # Use the index as the internal option value
            format_func=lambda x: video_concat_modes[x][0],  # The label is displayed to the user
        )
        params.video_concat_mode = VideoConcatMode(
            video_concat_modes[selected_index][1]
        )
        # Video transition mode
        video_transition_modes = [
            (tr("None"), VideoTransitionMode.none.value),
            (tr("Shuffle"), VideoTransitionMode.shuffle.value),
            (tr("FadeIn"), VideoTransitionMode.fade_in.value),
            (tr("FadeOut"), VideoTransitionMode.fade_out.value),
            (tr("SlideIn"), VideoTransitionMode.slide_in.value),
            (tr("SlideOut"), VideoTransitionMode.slide_out.value),
        ]
        selected_index = st.selectbox(
            tr("Video Transition Mode"),
            options=range(len(video_transition_modes)),
            format_func=lambda x: video_transition_modes[x][0],
            index=0,
        )
        params.video_transition_mode = VideoTransitionMode(
            video_transition_modes[selected_index][1]
        )
        video_aspect_ratios = [
            (tr("Portrait"), VideoAspect.portrait.value),
            (tr("Landscape"), VideoAspect.landscape.value),
        ]
        selected_index = st.selectbox(
            tr("Video Ratio"),
            options=range(len(video_aspect_ratios)),  # Use the index as the internal option value
            format_func=lambda x: video_aspect_ratios[x][0],  # The label is displayed to the user
        )
        params.video_aspect = VideoAspect(video_aspect_ratios[selected_index][1])
        params.video_clip_duration = st.selectbox(
            tr("Clip Duration"), options=[2, 3, 4, 5, 6, 7, 8, 9, 10], index=1
        )
        params.video_count = st.selectbox(
            tr("Number of Videos Generated Simultaneously"),
            options=[1, 2, 3, 4, 5],
            index=0,
        )
    with st.container(border=True):
        st.write(tr("Audio Settings"))
        # TTS server choices as (config value, display label)
        tts_servers = [
            ("azure-tts-v1", "Azure TTS V1"),
            ("azure-tts-v2", "Azure TTS V2"),
            ("siliconflow", "SiliconFlow TTS"),
        ]
        # Saved TTS server (default "azure-tts-v1"; first entry if unknown)
        saved_tts_server = config.ui.get("tts_server", "azure-tts-v1")
        saved_tts_server_index = 0
        for i, (server_value, _) in enumerate(tts_servers):
            if server_value == saved_tts_server:
                saved_tts_server_index = i
                break
        selected_tts_server_index = st.selectbox(
            tr("TTS Servers"),
            options=range(len(tts_servers)),
            format_func=lambda x: tts_servers[x][1],
            index=saved_tts_server_index,
        )
        selected_tts_server = tts_servers[selected_tts_server_index][0]
        config.ui["tts_server"] = selected_tts_server
        # Voice list for the selected TTS server
        filtered_voices = []
        if selected_tts_server == "siliconflow":
            filtered_voices = voice.get_siliconflow_voices()
        else:
            # Azure voices: V2 voice names contain "V2", V1 names do not.
            all_voices = voice.get_all_azure_voices(filter_locals=None)
            for v in all_voices:
                if selected_tts_server == "azure-tts-v2":
                    if "V2" in v:
                        filtered_voices.append(v)
                else:
                    if "V2" not in v:
                        filtered_voices.append(v)
        # Display names: translate Female/Male and drop "Neural".
        friendly_names = {
            v: v.replace("Female", tr("Female"))
            .replace("Male", tr("Male"))
            .replace("Neural", "")
            for v in filtered_voices
        }
        saved_voice_name = config.ui.get("voice_name", "")
        saved_voice_name_index = 0
        if saved_voice_name in friendly_names:
            # Keep the saved voice when the current server offers it ...
            saved_voice_name_index = list(friendly_names.keys()).index(saved_voice_name)
        else:
            # ... otherwise pick the first voice whose name starts with the UI
            # language code.
            for i, v in enumerate(filtered_voices):
                if v.lower().startswith(st.session_state["ui_language"].lower()):
                    saved_voice_name_index = i
                    break
        # Guard against an index past the end of friendly_names.
        if saved_voice_name_index >= len(friendly_names) and friendly_names:
            saved_voice_name_index = 0
        # Defined up front so the Azure/SiliconFlow checks below never raise
        # NameError when the selected server offers no voices.
        voice_name = ""
        if friendly_names:
            selected_friendly_name = st.selectbox(
                tr("Speech Synthesis"),
                options=list(friendly_names.values()),
                index=min(saved_voice_name_index, len(friendly_names) - 1),
            )
            voice_name = list(friendly_names.keys())[
                list(friendly_names.values()).index(selected_friendly_name)
            ]
            params.voice_name = voice_name
            config.ui["voice_name"] = voice_name
        else:
            # No voices for this server: tell the user instead of showing an empty selector.
            st.warning(
                tr(
                    "No voices available for the selected TTS server. Please select another server."
                )
            )
            params.voice_name = ""
            config.ui["voice_name"] = ""
        # The preview button is only offered when a voice can be selected.
        if friendly_names and st.button(tr("Play Voice")):
            play_content = params.video_subject
            if not play_content:
                play_content = params.video_script
            if not play_content:
                play_content = tr("Voice Example")
            with st.spinner(tr("Synthesizing Voice")):
                temp_dir = utils.storage_dir("temp", create=True)
                audio_file = os.path.join(temp_dir, f"tmp-voice-{str(uuid4())}.mp3")
                try:
                    sub_maker = voice.tts(
                        text=play_content,
                        voice_name=voice_name,
                        voice_rate=params.voice_rate,
                        voice_file=audio_file,
                        voice_volume=params.voice_volume,
                    )
                    # if the voice file generation failed, try again with a default content.
                    if not sub_maker:
                        play_content = "This is a example voice. if you hear this, the voice synthesis failed with the original content."
                        sub_maker = voice.tts(
                            text=play_content,
                            voice_name=voice_name,
                            voice_rate=params.voice_rate,
                            voice_file=audio_file,
                            voice_volume=params.voice_volume,
                        )
                    if sub_maker and os.path.exists(audio_file):
                        st.audio(audio_file, format="audio/mp3")
                finally:
                    # Always remove the temporary file, also when synthesis
                    # failed or raised (previously it was left behind then).
                    if os.path.exists(audio_file):
                        os.remove(audio_file)
        # Region/key inputs for the Azure V2 server or an Azure V2 voice
        if selected_tts_server == "azure-tts-v2" or (
            voice_name and voice.is_azure_v2_voice(voice_name)
        ):
            saved_azure_speech_region = config.azure.get("speech_region", "")
            saved_azure_speech_key = config.azure.get("speech_key", "")
            azure_speech_region = st.text_input(
                tr("Speech Region"),
                value=saved_azure_speech_region,
                key="azure_speech_region_input",
            )
            azure_speech_key = st.text_input(
                tr("Speech Key"),
                value=saved_azure_speech_key,
                type="password",
                key="azure_speech_key_input",
            )
            config.azure["speech_region"] = azure_speech_region
            config.azure["speech_key"] = azure_speech_key
        # API key input and usage notes for SiliconFlow
        if selected_tts_server == "siliconflow" or (
            voice_name and voice.is_siliconflow_voice(voice_name)
        ):
            saved_siliconflow_api_key = config.siliconflow.get("api_key", "")
            siliconflow_api_key = st.text_input(
                tr("SiliconFlow API Key"),
                value=saved_siliconflow_api_key,
                type="password",
                key="siliconflow_api_key_input",
            )
            st.info(
                tr("SiliconFlow TTS Settings")
                + ":\n"
                + "- "
                + tr("Speed: Range [0.25, 4.0], default is 1.0")
                + "\n"
                + "- "
                + tr("Volume: Uses Speech Volume setting, default 1.0 maps to gain 0")
            )
            config.siliconflow["api_key"] = siliconflow_api_key
        params.voice_volume = st.selectbox(
            tr("Speech Volume"),
            options=[0.6, 0.8, 1.0, 1.2, 1.5, 2.0, 3.0, 4.0, 5.0],
            index=2,
        )
        params.voice_rate = st.selectbox(
            tr("Speech Rate"),
            options=[0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.5, 1.8, 2.0],
            index=2,
        )
        bgm_options = [
            (tr("No Background Music"), ""),
            (tr("Random Background Music"), "random"),
            (tr("Custom Background Music"), "custom"),
        ]
        selected_index = st.selectbox(
            tr("Background Music"),
            index=1,
            options=range(len(bgm_options)),  # Use the index as the internal option value
            format_func=lambda x: bgm_options[x][0],  # The label is displayed to the user
        )
        # Get the selected background music type
        params.bgm_type = bgm_options[selected_index][1]
        # A custom file is only used when the entered path exists.
        if params.bgm_type == "custom":
            custom_bgm_file = st.text_input(
                tr("Custom Background Music File"), key="custom_bgm_file_input"
            )
            if custom_bgm_file and os.path.exists(custom_bgm_file):
                params.bgm_file = custom_bgm_file
        params.bgm_volume = st.selectbox(
            tr("Background Music Volume"),
            options=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
            index=2,
        )
with right_panel:
    with st.container(border=True):
        st.write(tr("Subtitle Settings"))
        params.subtitle_enabled = st.checkbox(tr("Enable Subtitles"), value=True)
        # Preselect the saved font when it is still installed, else the first one.
        font_names = get_all_fonts()
        saved_font_name = config.ui.get("font_name", "MicrosoftYaHeiBold.ttc")
        saved_font_name_index = 0
        if saved_font_name in font_names:
            saved_font_name_index = font_names.index(saved_font_name)
        params.font_name = st.selectbox(
            tr("Font"), font_names, index=saved_font_name_index
        )
        config.ui["font_name"] = params.font_name
        subtitle_positions = [
            (tr("Top"), "top"),
            (tr("Center"), "center"),
            (tr("Bottom"), "bottom"),
            (tr("Custom"), "custom"),
        ]
        selected_index = st.selectbox(
            tr("Position"),
            index=2,
            options=range(len(subtitle_positions)),
            format_func=lambda x: subtitle_positions[x][0],
        )
        params.subtitle_position = subtitle_positions[selected_index][1]
        if params.subtitle_position == "custom":
            # Vertical position as a percentage from the top of the frame.
            custom_position = st.text_input(
                tr("Custom Position (% from top)"),
                value="70.0",
                key="custom_position_input",
            )
            try:
                params.custom_position = float(custom_position)
                if params.custom_position < 0 or params.custom_position > 100:
                    st.error(tr("Please enter a value between 0 and 100"))
            except ValueError:
                st.error(tr("Please enter a valid number"))
        font_cols = st.columns([0.3, 0.7])
        with font_cols[0]:
            saved_text_fore_color = config.ui.get("text_fore_color", "#FFFFFF")
            params.text_fore_color = st.color_picker(
                tr("Font Color"), saved_text_fore_color
            )
            config.ui["text_fore_color"] = params.text_fore_color
        with font_cols[1]:
            saved_font_size = config.ui.get("font_size", 60)
            # st.slider rejects a default outside [min, max] or of a different
            # numeric type than the bounds, so coerce a hand-edited config value
            # into the 30-100 integer range (falling back to 60).
            try:
                saved_font_size = min(max(int(saved_font_size), 30), 100)
            except (TypeError, ValueError):
                saved_font_size = 60
            params.font_size = st.slider(tr("Font Size"), 30, 100, saved_font_size)
            config.ui["font_size"] = params.font_size
        stroke_cols = st.columns([0.3, 0.7])
        with stroke_cols[0]:
            params.stroke_color = st.color_picker(tr("Stroke Color"), "#000000")
        with stroke_cols[1]:
            params.stroke_width = st.slider(tr("Stroke Width"), 0.0, 10.0, 1.5)
start_button = st.button(tr("Generate Video"), use_container_width=True, type="primary")
if start_button:
    # Save the settings collected above before validating and starting.
    config.save_config()
    task_id = str(uuid4())
    if not params.video_subject and not params.video_script:
        st.error(tr("Video Script and Subject Cannot Both Be Empty"))
        scroll_to_bottom()
        st.stop()
    # Only these sources pass validation; the remaining "Video Source" options
    # (douyin, bilibili, xiaohongshu) are rejected here.
    if params.video_source not in ["pexels", "pixabay", "local"]:
        st.error(tr("Please Select a Valid Video Source"))
        scroll_to_bottom()
        st.stop()
    if params.video_source == "pexels" and not config.app.get("pexels_api_keys", ""):
        st.error(tr("Please Enter the Pexels API Key"))
        scroll_to_bottom()
        st.stop()
    if params.video_source == "pixabay" and not config.app.get("pixabay_api_keys", ""):
        st.error(tr("Please Enter the Pixabay API Key"))
        scroll_to_bottom()
        st.stop()
    if uploaded_files:
        # Copy the uploads into storage and register them as local materials.
        local_videos_dir = utils.storage_dir("local_videos", create=True)
        for file in uploaded_files:
            # file.name is supplied by the client: keep only its last path
            # component so a crafted name cannot escape local_videos_dir.
            safe_name = os.path.basename(file.name)
            file_path = os.path.join(local_videos_dir, f"{file.file_id}_{safe_name}")
            with open(file_path, "wb") as f:
                f.write(file.getbuffer())
            m = MaterialInfo()
            m.provider = "local"
            m.url = file_path
            if not params.video_materials:
                params.video_materials = []
            params.video_materials.append(m)
    # Mirror log messages into an on-page code block unless the log is hidden.
    log_container = st.empty()
    log_records = []

    def log_received(msg):
        """loguru sink: append *msg* to the on-page log view."""
        # .get(): with the settings panel hidden, hide_log is never written to
        # config.ui in this run, so indexing could raise KeyError in the sink.
        if config.ui.get("hide_log", False):
            return
        with log_container:
            log_records.append(msg)
            st.code("\n".join(log_records))

    logger.add(log_received)
    st.toast(tr("Generating Video"))
    logger.info(tr("Start Generating Video"))
    logger.info(utils.to_json(params))
    scroll_to_bottom()
    result = tm.start(task_id=task_id, params=params)
    if not result or "videos" not in result:
        st.error(tr("Video Generation Failed"))
        logger.error(tr("Video Generation Failed"))
        scroll_to_bottom()
        st.stop()
    video_files = result.get("videos", [])
    st.success(tr("Video Generation Completed"))
    try:
        if video_files:
            # 2n+1 columns: players in the odd columns, spacers around them.
            player_cols = st.columns(len(video_files) * 2 + 1)
            for i, url in enumerate(video_files):
                player_cols[i * 2 + 1].video(url)
    except Exception as e:
        # The preview is best-effort; log the problem instead of hiding it.
        logger.warning(f"failed to render video preview: {e}")
    open_task_folder(task_id)
    logger.info(tr("Video Generation Completed"))
    scroll_to_bottom()
config.save_config()