3DAItuber / AIvtuber.py
buchi-stdesign's picture
Upload AIvtuber.py
5d02468 verified
# -*- coding: utf-8 -*-
import os
from huggingface_hub import hf_hub_download
API_KEY = os.environ["GOOGLE_API_KEY"] # 環境変数から取得
import google.generativeai as genai
import torch
from pathlib import Path
from style_bert_vits2.nlp import bert_models
from style_bert_vits2.constants import Languages
from style_bert_vits2.tts_model import TTSModel
# ローカル実行時のみ有効化してください
# import sounddevice as sd
# import pytchat
import time
# --- Google Gemini API の設定 ---
genai.configure(api_key=API_KEY)
generation_config = {
"temperature": 1,
"top_p": 0.95,
"top_k": 40,
"max_output_tokens": 8192,
"response_mime_type": "text/plain",
}
safety_settings = [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","threshold": "BLOCK_NONE"},
]
model = genai.GenerativeModel(
model_name="gemini-2.0-flash-exp",
generation_config=generation_config,
safety_settings=safety_settings
)
chat_session = model.start_chat(history=[
{"role":"user","parts":["今からあなたは明るい女の子です!"]},
{"role":"model","parts":["こんにちは!"]}
])
# --- BERT モデルのロード ---
bert_models.load_model(Languages.JP, "ku-nlp/deberta-v2-large-japanese-char-wwm")
bert_models.load_tokenizer(Languages.JP, "ku-nlp/deberta-v2-large-japanese-char-wwm")
# --- TTS モデル用ファイルパス(Hugging Face Hubからダウンロード) ---
model_file = hf_hub_download(
repo_id="buchi-stdesign/3DAItuber-model",
filename="Anneli_e116_s32000.safetensors",
repo_type="model",
token=os.environ.get("HUGGINGFACE_TOKEN") # トークンを明示的に指定
)
config_file = hf_hub_download(
repo_id="buchi-stdesign/3DAItuber-model",
filename="config.json",
repo_type="model",
token=os.environ.get("HUGGINGFACE_TOKEN") # トークンを明示的に指定
)
style_file = hf_hub_download(
repo_id="buchi-stdesign/3DAItuber-model",
filename="style_vectors.npy",
repo_type="model",
token=os.environ.get("HUGGINGFACE_TOKEN") # トークンを明示的に指定
)
# デバイス設定 (CUDA 未サポート時は CPU にフォールバック)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[INFO] Using device: {device}")
# --- YouTube LiveChat 取得準備 ---
# ローカル実行時のみ有効化してください
# import pytchat
# livechat = pytchat.create(video_id="MYLhogwYrY4")
# --- オーディオ再生用ユーティリティ関数 ---
# ローカル実行時のみ有効化してください
# device_id = 10 # お使いの環境に合わせて変更してください
# def play_tts(text: str):
# """
# テキスト → 音声 → 再生 を行う関数(ローカル用)
# """
# tts = TTSModel(
# model_path=model_file,
# config_path=config_file,
# style_vec_path=style_file,
# device=device,
# )
# sr, wav = tts.infer(text=text, length=0.85)
# sd.play(wav, sr, device=device_id)
# sd.wait()
# クラウド・Streamlit用:音声データを返す関数
def tts_to_wav(text: str):
"""
テキスト → 音声データ(wav配列, サンプリングレート)を返す(クラウド用)
"""
tts = TTSModel(
model_path=model_file,
config_path=config_file,
style_vec_path=style_file,
device=device,
)
sr, wav = tts.infer(text=text, length=0.85)
return sr, wav
# --- ライブチャットに応答して音声再生ループ ---
# ローカル実行時のみ有効化してください
# while livechat.is_alive():
# chatdata = livechat.get()
# for c in chatdata.items:
# user_msg = f"{c.datetime} {c.author.name} {c.message} {c.amountString}"
# print(user_msg)
# resp = chat_session.send_message(user_msg)
# print(resp.text)
# play_tts(resp.text)
# time.sleep(1)
# --- コンソール入力にも対応 ---
# ローカル実行時のみ有効化してください
# while True:
# user_input = input("You: ")
# resp = chat_session.send_message(user_input)
# print("Bot:", resp.text)
# play_tts(resp.text)