Spaces:
Running
Running
# -*- coding: utf-8 -*- | |
import os | |
from huggingface_hub import hf_hub_download | |
API_KEY = os.environ["GOOGLE_API_KEY"] # 環境変数から取得 | |
import google.generativeai as genai | |
import torch | |
from pathlib import Path | |
from style_bert_vits2.nlp import bert_models | |
from style_bert_vits2.constants import Languages | |
from style_bert_vits2.tts_model import TTSModel | |
# ローカル実行時のみ有効化してください | |
# import sounddevice as sd | |
# import pytchat | |
import time | |
# --- Google Gemini API の設定 --- | |
genai.configure(api_key=API_KEY) | |
generation_config = { | |
"temperature": 1, | |
"top_p": 0.95, | |
"top_k": 40, | |
"max_output_tokens": 8192, | |
"response_mime_type": "text/plain", | |
} | |
safety_settings = [ | |
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"}, | |
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"}, | |
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"}, | |
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","threshold": "BLOCK_NONE"}, | |
] | |
model = genai.GenerativeModel( | |
model_name="gemini-2.0-flash-exp", | |
generation_config=generation_config, | |
safety_settings=safety_settings | |
) | |
chat_session = model.start_chat(history=[ | |
{"role":"user","parts":["今からあなたは明るい女の子です!"]}, | |
{"role":"model","parts":["こんにちは!"]} | |
]) | |
# --- BERT モデルのロード --- | |
bert_models.load_model(Languages.JP, "ku-nlp/deberta-v2-large-japanese-char-wwm") | |
bert_models.load_tokenizer(Languages.JP, "ku-nlp/deberta-v2-large-japanese-char-wwm") | |
# --- TTS モデル用ファイルパス(Hugging Face Hubからダウンロード) --- | |
model_file = hf_hub_download( | |
repo_id="buchi-stdesign/3DAItuber-model", | |
filename="Anneli_e116_s32000.safetensors", | |
repo_type="model", | |
token=os.environ.get("HUGGINGFACE_TOKEN") # トークンを明示的に指定 | |
) | |
config_file = hf_hub_download( | |
repo_id="buchi-stdesign/3DAItuber-model", | |
filename="config.json", | |
repo_type="model", | |
token=os.environ.get("HUGGINGFACE_TOKEN") # トークンを明示的に指定 | |
) | |
style_file = hf_hub_download( | |
repo_id="buchi-stdesign/3DAItuber-model", | |
filename="style_vectors.npy", | |
repo_type="model", | |
token=os.environ.get("HUGGINGFACE_TOKEN") # トークンを明示的に指定 | |
) | |
# デバイス設定 (CUDA 未サポート時は CPU にフォールバック) | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
print(f"[INFO] Using device: {device}") | |
# --- YouTube LiveChat 取得準備 --- | |
# ローカル実行時のみ有効化してください | |
# import pytchat | |
# livechat = pytchat.create(video_id="MYLhogwYrY4") | |
# --- オーディオ再生用ユーティリティ関数 --- | |
# ローカル実行時のみ有効化してください | |
# device_id = 10 # お使いの環境に合わせて変更してください | |
# def play_tts(text: str): | |
# """ | |
# テキスト → 音声 → 再生 を行う関数(ローカル用) | |
# """ | |
# tts = TTSModel( | |
# model_path=model_file, | |
# config_path=config_file, | |
# style_vec_path=style_file, | |
# device=device, | |
# ) | |
# sr, wav = tts.infer(text=text, length=0.85) | |
# sd.play(wav, sr, device=device_id) | |
# sd.wait() | |
# クラウド・Streamlit用:音声データを返す関数 | |
def tts_to_wav(text: str): | |
""" | |
テキスト → 音声データ(wav配列, サンプリングレート)を返す(クラウド用) | |
""" | |
tts = TTSModel( | |
model_path=model_file, | |
config_path=config_file, | |
style_vec_path=style_file, | |
device=device, | |
) | |
sr, wav = tts.infer(text=text, length=0.85) | |
return sr, wav | |
# --- ライブチャットに応答して音声再生ループ --- | |
# ローカル実行時のみ有効化してください | |
# while livechat.is_alive(): | |
# chatdata = livechat.get() | |
# for c in chatdata.items: | |
# user_msg = f"{c.datetime} {c.author.name} {c.message} {c.amountString}" | |
# print(user_msg) | |
# resp = chat_session.send_message(user_msg) | |
# print(resp.text) | |
# play_tts(resp.text) | |
# time.sleep(1) | |
# --- コンソール入力にも対応 --- | |
# ローカル実行時のみ有効化してください | |
# while True: | |
# user_input = input("You: ") | |
# resp = chat_session.send_message(user_input) | |
# print("Bot:", resp.text) | |
# play_tts(resp.text) | |