Spaces:
Running
on
Zero
Running
on
Zero
import gradio as gr | |
from jinja2 import Template | |
from llama_cpp import Llama | |
import os | |
import configparser | |
from utils.dl_utils import dl_guff_model | |
# モデルディレクトリが存在しない場合は作成 | |
if not os.path.exists("models"): | |
os.makedirs("models") | |
# 使用するモデルのファイル名を指定 | |
model_filename = "Llama-3.1-70B-EZO-1.1-it-Q4_K_M.gguf" | |
model_path = os.path.join("models", model_filename) | |
# モデルファイルが存在しない場合はダウンロード | |
if not os.path.exists(model_path): | |
dl_guff_model("models", f"https://huggingface.co/mmnga/Llama-3.1-70B-EZO-1.1-it-gguf/resolve/main/{model_filename}") | |
# 設定をINIファイルに保存する関数 | |
def save_settings_to_ini(settings, filename='character_settings.ini'): | |
config = configparser.ConfigParser() | |
config['Settings'] = { | |
'name': settings['name'], | |
'gender': settings['gender'], | |
'situation': '\n'.join(settings['situation']), | |
'orders': '\n'.join(settings['orders']), | |
'dirty_talk_list': '\n'.join(settings['dirty_talk_list']), | |
'example_quotes': '\n'.join(settings['example_quotes']) | |
} | |
with open(filename, 'w', encoding='utf-8') as configfile: | |
config.write(configfile) | |
# INIファイルから設定を読み込む関数 | |
def load_settings_from_ini(filename='character_settings.ini'): | |
if not os.path.exists(filename): | |
return None | |
config = configparser.ConfigParser() | |
config.read(filename, encoding='utf-8') | |
if 'Settings' not in config: | |
return None | |
try: | |
settings = { | |
'name': config['Settings']['name'], | |
'gender': config['Settings']['gender'], | |
'situation': config['Settings']['situation'].split('\n'), | |
'orders': config['Settings']['orders'].split('\n'), | |
'dirty_talk_list': config['Settings']['dirty_talk_list'].split('\n'), | |
'example_quotes': config['Settings']['example_quotes'].split('\n') | |
} | |
return settings | |
except KeyError: | |
return None | |
# LlamaCppのラッパークラス | |
class LlamaCppAdapter: | |
def __init__(self, model_path, n_ctx=4096): | |
print(f"モデルの初期化: {model_path}") | |
self.llama = Llama(model_path=model_path, n_ctx=n_ctx, n_gpu_layers=-1) | |
def generate(self, prompt, max_new_tokens=4096, temperature=0.5, top_p=0.7, top_k=80, stop=["<END>"]): | |
return self._generate(prompt, max_new_tokens, temperature, top_p, top_k, stop) | |
def _generate(self, prompt: str, max_new_tokens: int, temperature: float, top_p: float, top_k: int, stop: list): | |
return self.llama( | |
prompt, | |
temperature=temperature, | |
max_tokens=max_new_tokens, | |
top_p=top_p, | |
top_k=top_k, | |
stop=stop, | |
repeat_penalty=1.2, | |
) | |
# キャラクターメーカークラス | |
class CharacterMaker: | |
def __init__(self): | |
self.llama = LlamaCppAdapter(model_path) | |
self.history = [] | |
self.settings = load_settings_from_ini() | |
if not self.settings: | |
self.settings = { | |
"name": "ナツ", | |
"gender": "女性", | |
"situation": [ | |
"あなたは人工知能アシスタントです。", | |
"ユーザーの日常生活をサポートし、より良い生活を送るお手伝いをします。", | |
"AIアシスタント『ナツ』として、ユーザーの健康と幸福をケアし、様々な質問に答えたり課題解決を手伝ったりします。" | |
], | |
"orders": [ | |
"丁寧な言葉遣いを心がけてください。", | |
"ユーザーとの対話を通じてサポートを提供します。", | |
"ユーザーのことは『ユーザー様』と呼んでください。" | |
], | |
"conversation_topics": [ | |
"健康管理", | |
"目標設定", | |
"時間管理" | |
], | |
"example_quotes": [ | |
"ユーザー様の健康と幸福が何より大切です。どのようなサポートが必要でしょうか?", | |
"私はユーザー様の生活をより良いものにするためのアシスタントです。お手伝いできることがありましたらお申し付けください。", | |
"目標達成に向けて一緒に頑張りましょう。具体的な計画を立てるお手伝いをさせていただきます。", | |
"効率的な時間管理のコツをお教えします。まずは1日のスケジュールを確認してみましょう。", | |
"ストレス解消法についてアドバイスいたします。リラックスするための簡単な呼吸法から始めてみませんか?" | |
] | |
} | |
save_settings_to_ini(self.settings) | |
def make(self, input_str: str): | |
prompt = self._generate_aki(input_str) | |
print(prompt) | |
print("-----------------") | |
res = self.llama.generate(prompt, max_new_tokens=1000, stop=["<END>", "\n"]) | |
res_text = res["choices"][0]["text"] | |
self.history.append({"user": input_str, "assistant": res_text}) | |
return res_text | |
def make_prompt(self, name: str, gender: str, situation: list, orders: list, dirty_talk_list: list, example_quotes: list, input_str: str): | |
with open('test_prompt.jinja2', 'r', encoding='utf-8') as f: | |
prompt = f.readlines() | |
fix_example_quotes = [quote+"<END>" for quote in example_quotes] | |
prompt = "".join(prompt) | |
prompt = Template(prompt).render(name=name, gender=gender, situation=situation, orders=orders, dirty_talk_list=dirty_talk_list, example_quotes=fix_example_quotes, histories=self.history, input_str=input_str) | |
return prompt | |
def _generate_aki(self, input_str: str): | |
prompt = self.make_prompt( | |
self.settings["name"], | |
self.settings["gender"], | |
self.settings["situation"], | |
self.settings["orders"], | |
self.settings["dirty_talk_list"], | |
self.settings["example_quotes"], | |
input_str | |
) | |
print(prompt) | |
return prompt | |
def update_settings(self, new_settings): | |
self.settings.update(new_settings) | |
save_settings_to_ini(self.settings) | |
def reset(self): | |
self.history = [] | |
self.llama = LlamaCppAdapter(model_path) | |
character_maker = CharacterMaker() | |
# 設定を更新する関数 | |
def update_settings(name, gender, situation, orders, dirty_talk_list, example_quotes): | |
new_settings = { | |
"name": name, | |
"gender": gender, | |
"situation": [s.strip() for s in situation.split('\n') if s.strip()], | |
"orders": [o.strip() for o in orders.split('\n') if o.strip()], | |
"dirty_talk_list": [d.strip() for d in dirty_talk_list.split('\n') if d.strip()], | |
"example_quotes": [e.strip() for e in example_quotes.split('\n') if e.strip()] | |
} | |
character_maker.update_settings(new_settings) | |
return "設定が更新されました。" | |
# チャット機能の関数 | |
def chat_with_character(message, history): | |
character_maker.history = [{"user": h[0], "assistant": h[1]} for h in history] | |
response = character_maker.make(message) | |
return response | |
# チャットをクリアする関数 | |
def clear_chat(): | |
character_maker.reset() | |
return [] | |
# カスタムCSS | |
custom_css = """ | |
#chatbot { | |
height: 60vh !important; | |
overflow-y: auto; | |
} | |
""" | |
# カスタムJavaScript(HTML内に埋め込む) | |
custom_js = """ | |
<script> | |
function adjustChatbotHeight() { | |
var chatbot = document.querySelector('#chatbot'); | |
if (chatbot) { | |
chatbot.style.height = window.innerHeight * 0.6 + 'px'; | |
} | |
} | |
// ページ読み込み時と画面サイズ変更時にチャットボットの高さを調整 | |
window.addEventListener('load', adjustChatbotHeight); | |
window.addEventListener('resize', adjustChatbotHeight); | |
</script> | |
""" | |
# Gradioインターフェースの設定 | |
with gr.Blocks(css=custom_css) as iface: | |
chatbot = gr.Chatbot(elem_id="chatbot") | |
with gr.Tab("チャット"): | |
gr.ChatInterface( | |
chat_with_character, | |
chatbot=chatbot, | |
textbox=gr.Textbox(placeholder="メッセージを入力してください...", container=False, scale=7), | |
theme="soft", | |
retry_btn="もう一度生成", | |
undo_btn="前のメッセージを取り消す", | |
clear_btn="チャットをクリア", | |
) | |
with gr.Tab("設定"): | |
gr.Markdown("## キャラクター設定") | |
name_input = gr.Textbox(label="名前", value=character_maker.settings["name"]) | |
gender_input = gr.Textbox(label="性別", value=character_maker.settings["gender"]) | |
situation_input = gr.Textbox(label="状況設定", value="\n".join(character_maker.settings["situation"]), lines=5) | |
orders_input = gr.Textbox(label="指示", value="\n".join(character_maker.settings["orders"]), lines=5) | |
dirty_talk_input = gr.Textbox(label="淫語リスト", value="\n".join(character_maker.settings["dirty_talk_list"]), lines=5) | |
example_quotes_input = gr.Textbox(label="例文", value="\n".join(character_maker.settings["example_quotes"]), lines=5) | |
update_button = gr.Button("設定を更新") | |
update_output = gr.Textbox(label="更新状態") | |
update_button.click( | |
update_settings, | |
inputs=[name_input, gender_input, situation_input, orders_input, dirty_talk_input, example_quotes_input], | |
outputs=[update_output] | |
) | |
# Gradioアプリの起動 | |
if __name__ == "__main__": | |
iface.launch( | |
share=True, | |
allowed_paths=["models"], | |
favicon_path="custom.html" | |
) |