Spaces:
Running
Running
from InquirerPy import inquirer | |
import sys | |
import os | |
from pathlib import Path | |
from msdl.config import ( | |
CLOUD_LLM_DOCKERFILE, | |
LOCAL_LLM_DOCKERFILE, | |
) | |
from msdl.i18n import ( | |
t, | |
get_available_languages, | |
set_language, | |
get_env_variable, | |
) | |
from msdl.utils import ( | |
clean_api_key, | |
get_model_formats, | |
get_existing_api_key, | |
save_api_key_to_env, | |
validate_api_key, | |
) | |
SEARCH_ENGINES = { | |
"DuckDuckGoSearch": { | |
"name": "DuckDuckGo", | |
"key": "DUCKDUCKGO", | |
"requires_key": False, | |
"env_var": None | |
}, | |
"BingSearch": { | |
"name": "Bing", | |
"key": "BING", | |
"requires_key": True, | |
"env_var": "BING_SEARCH_API_KEY" | |
}, | |
"BraveSearch": { | |
"name": "Brave", | |
"key": "BRAVE", | |
"requires_key": True, | |
"env_var": "BRAVE_SEARCH_API_KEY" | |
}, | |
"GoogleSearch": { | |
"name": "Google Serper", | |
"key": "GOOGLE", | |
"requires_key": True, | |
"env_var": "GOOGLE_SERPER_API_KEY" | |
}, | |
"TencentSearch": { | |
"name": "Tencent", | |
"key": "TENCENT", | |
"requires_key": True, | |
"env_vars": ["TENCENT_SEARCH_SECRET_ID", "TENCENT_SEARCH_SECRET_KEY"] | |
} | |
} | |
def get_language_choice(): | |
"""Get user's language preference""" | |
def _get_language_options(): | |
available_langs = get_available_languages() | |
lang_choices = { | |
"en": "English", | |
"zh_CN": "中文" | |
} | |
return [{"name": f"{lang_choices.get(lang, lang)}", "value": lang} for lang in available_langs] | |
current_lang = get_env_variable("LAUNCHER_INTERACTION_LANGUAGE") | |
if not current_lang: | |
lang_options = _get_language_options() | |
language = inquirer.select( | |
message=t("SELECT_INTERFACE_LANGUAGE"), | |
choices=lang_options, | |
default="en" | |
).execute() | |
if language: | |
set_language(language) | |
sys.stdout.flush() | |
restart_program() | |
def get_backend_language(): | |
"""Get user's backend language preference""" | |
return inquirer.select( | |
message=t("SELECT_BACKEND_LANGUAGE"), | |
choices=[ | |
{"name": t("CHINESE"), "value": "cn"}, | |
{"name": t("ENGLISH"), "value": "en"}, | |
], | |
default="cn", | |
).execute() | |
def get_model_choice(): | |
"""Get user's model deployment type preference""" | |
model_deployment_type = [ | |
{ | |
"name": t("CLOUD_MODEL"), | |
"value": CLOUD_LLM_DOCKERFILE | |
}, | |
{ | |
"name": t("LOCAL_MODEL"), | |
"value": LOCAL_LLM_DOCKERFILE | |
}, | |
] | |
return inquirer.select( | |
message=t("MODEL_DEPLOYMENT_TYPE"), | |
choices=model_deployment_type, | |
).execute() | |
def get_model_format(model): | |
"""Get user's model format preference""" | |
model_formats = get_model_formats(model) | |
return inquirer.select( | |
message=t("MODEL_FORMAT_CHOICE"), | |
choices=[{ | |
"name": format, | |
"value": format | |
} for format in model_formats], | |
).execute() | |
def _handle_api_key_input(env_var_name, message=None): | |
"""Handle API key input and validation for a given environment variable""" | |
if message is None: | |
message = t("PLEASE_INPUT_NEW_API_KEY", ENV_VAR_NAME=env_var_name) | |
print(message) | |
while True: | |
api_key = inquirer.secret( | |
message=t("PLEASE_INPUT_NEW_API_KEY_FROM_ZERO", ENV_VAR_NAME=env_var_name) | |
).execute() | |
cleaned_api_key = clean_api_key(api_key) | |
try: | |
save_api_key_to_env(env_var_name, cleaned_api_key, t) | |
break | |
except ValueError as e: | |
print(str(e)) | |
retry = inquirer.confirm( | |
message=t("RETRY_API_KEY_INPUT"), default=True | |
).execute() | |
if not retry: | |
print(t("API_KEY_INPUT_CANCELLED")) | |
sys.exit(1) | |
def handle_api_key_input(model, model_format): | |
"""Handle API key input and validation""" | |
if model != CLOUD_LLM_DOCKERFILE: | |
return | |
env_var_name = { | |
"internlm_silicon": "SILICON_API_KEY", | |
"gpt4": "OPENAI_API_KEY", | |
"qwen": "QWEN_API_KEY", | |
}.get(model_format) | |
existing_api_key = get_existing_api_key(env_var_name) | |
if existing_api_key: | |
use_existing = inquirer.confirm( | |
message=t("CONFIRM_USE_EXISTING_API_KEY", ENV_VAR_NAME=env_var_name), | |
default=True, | |
).execute() | |
if use_existing: | |
return | |
print(t("CONFIRM_OVERWRITE_EXISTING_API_KEY", ENV_VAR_NAME=env_var_name)) | |
try: | |
save_api_key_to_env(model_format, clean_api_key(inquirer.secret( | |
message=t("PLEASE_INPUT_NEW_API_KEY_FROM_ZERO", ENV_VAR_NAME=env_var_name) | |
).execute()), t) | |
except ValueError as e: | |
print(str(e)) | |
retry = inquirer.confirm( | |
message=t("RETRY_API_KEY_INPUT"), default=True | |
).execute() | |
if not retry: | |
print(t("API_KEY_INPUT_CANCELLED")) | |
sys.exit(1) | |
def get_search_engine(): | |
"""Get user's preferred search engine and handle API key if needed""" | |
search_engine = inquirer.select( | |
message=t("SELECT_SEARCH_ENGINE"), | |
choices=[{ | |
"name": f"{t(f'SEARCH_ENGINE_{info["key"]}')} ({t('NO_API_KEY_NEEDED') if not info['requires_key'] else t('API_KEY_REQUIRED')})", | |
"value": engine | |
} for engine, info in SEARCH_ENGINES.items()], | |
).execute() | |
engine_info = SEARCH_ENGINES[search_engine] | |
if engine_info['requires_key']: | |
if search_engine == "TencentSearch": | |
# Handle Tencent's special case with two keys | |
for env_var in engine_info['env_vars']: | |
is_id = "ID" in env_var | |
message = t("TENCENT_ID_REQUIRED") if is_id else t("TENCENT_KEY_REQUIRED") | |
existing_key = get_existing_api_key(env_var) | |
if existing_key: | |
use_existing = inquirer.confirm( | |
message=t("CONFIRM_USE_EXISTING_API_KEY", ENV_VAR_NAME=env_var), | |
default=True, | |
).execute() | |
if not use_existing: | |
_handle_api_key_input(env_var, message) | |
else: | |
_handle_api_key_input(env_var, message) | |
else: | |
# Handle standard case with single WEB_SEARCH_API_KEY | |
env_var = engine_info['env_var'] | |
existing_key = get_existing_api_key(env_var) | |
if existing_key: | |
use_existing = inquirer.confirm( | |
message=t("CONFIRM_USE_EXISTING_API_KEY", ENV_VAR_NAME=env_var), | |
default=True, | |
).execute() | |
if not use_existing: | |
_handle_api_key_input(env_var, t("WEB_SEARCH_KEY_REQUIRED")) | |
else: | |
_handle_api_key_input(env_var, t("WEB_SEARCH_KEY_REQUIRED")) | |
print(t("SEARCH_ENGINE_CONFIGURED", engine=engine_info['name'])) | |
return search_engine | |
def restart_program(): | |
"""Restart the current program with the same arguments""" | |
print(t("LANGUAGE_CHANGED_RESTARTING")) | |
python = sys.executable | |
os.execl(python, python, *sys.argv) | |
def get_user_choices(): | |
"""Get all user choices in a single function""" | |
# Get language preference | |
get_language_choice() | |
# Get backend language | |
backend_language = get_backend_language() | |
# Get model choice | |
model = get_model_choice() | |
# Get model format | |
model_format = get_model_format(model) | |
# Get search engine choice | |
search_engine = get_search_engine() | |
# Handle API key if needed | |
handle_api_key_input(model, model_format) | |
return backend_language, model, model_format, search_engine | |