tori29umai commited on
Commit
20e3524
1 Parent(s): 160851b

Upload 5 files

Browse files
Files changed (5) hide show
  1. app.py +240 -0
  2. custom.html +12 -0
  3. requirements.txt +2 -0
  4. test_prompt.jinja2 +22 -0
  5. utils/dl_utils.py +19 -0
app.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from jinja2 import Template
3
+ from llama_cpp import Llama
4
+ import os
5
+ import configparser
6
+ from utils.dl_utils import dl_guff_model
7
+
8
+ # モデルディレクトリが存在しない場合は作成
9
+ if not os.path.exists("models"):
10
+ os.makedirs("models")
11
+
12
+ # 使用するモデルのファイル名を指定
13
+ model_filename = "Llama-3.1-70B-EZO-1.1-it-Q4_K_M.gguf"
14
+ model_path = os.path.join("models", model_filename)
15
+
16
+ # モデルファイルが存在しない場合はダウンロード
17
+ if not os.path.exists(model_path):
18
+ dl_guff_model("models", f"https://huggingface.co/mmnga/Llama-3.1-70B-EZO-1.1-it-gguf/resolve/main/{model_filename}")
19
+
20
+ # 設定をINIファイルに保存する関数
21
+ def save_settings_to_ini(settings, filename='character_settings.ini'):
22
+ config = configparser.ConfigParser()
23
+ config['Settings'] = {
24
+ 'name': settings['name'],
25
+ 'gender': settings['gender'],
26
+ 'situation': '\n'.join(settings['situation']),
27
+ 'orders': '\n'.join(settings['orders']),
28
+ 'dirty_talk_list': '\n'.join(settings['dirty_talk_list']),
29
+ 'example_quotes': '\n'.join(settings['example_quotes'])
30
+ }
31
+ with open(filename, 'w', encoding='utf-8') as configfile:
32
+ config.write(configfile)
33
+
34
+ # INIファイルから設定を読み込む関数
35
+ def load_settings_from_ini(filename='character_settings.ini'):
36
+ if not os.path.exists(filename):
37
+ return None
38
+
39
+ config = configparser.ConfigParser()
40
+ config.read(filename, encoding='utf-8')
41
+
42
+ if 'Settings' not in config:
43
+ return None
44
+
45
+ try:
46
+ settings = {
47
+ 'name': config['Settings']['name'],
48
+ 'gender': config['Settings']['gender'],
49
+ 'situation': config['Settings']['situation'].split('\n'),
50
+ 'orders': config['Settings']['orders'].split('\n'),
51
+ 'dirty_talk_list': config['Settings']['dirty_talk_list'].split('\n'),
52
+ 'example_quotes': config['Settings']['example_quotes'].split('\n')
53
+ }
54
+ return settings
55
+ except KeyError:
56
+ return None
57
+
58
+ # LlamaCppのラッパークラス
59
+ class LlamaCppAdapter:
60
+ def __init__(self, model_path, n_ctx=4096):
61
+ print(f"モデルの初期化: {model_path}")
62
+ self.llama = Llama(model_path=model_path, n_ctx=n_ctx, n_gpu_layers=-1)
63
+
64
+ def generate(self, prompt, max_new_tokens=4096, temperature=0.5, top_p=0.7, top_k=80, stop=["<END>"]):
65
+ return self._generate(prompt, max_new_tokens, temperature, top_p, top_k, stop)
66
+
67
+ def _generate(self, prompt: str, max_new_tokens: int, temperature: float, top_p: float, top_k: int, stop: list):
68
+ return self.llama(
69
+ prompt,
70
+ temperature=temperature,
71
+ max_tokens=max_new_tokens,
72
+ top_p=top_p,
73
+ top_k=top_k,
74
+ stop=stop,
75
+ repeat_penalty=1.2,
76
+ )
77
+
78
+ # キャラクターメーカークラス
79
+ class CharacterMaker:
80
+ def __init__(self):
81
+ self.llama = LlamaCppAdapter(model_path)
82
+ self.history = []
83
+ self.settings = load_settings_from_ini()
84
+ if not self.settings:
85
+ self.settings = {
86
+ "name": "ナツ",
87
+ "gender": "女性",
88
+ "situation": [
89
+ "あなたは人工知能アシスタントです。",
90
+ "ユーザーの日常生活をサポートし、より良い生活を送るお手伝いをします。",
91
+ "AIアシスタント『ナツ』として、ユーザーの健康と幸福をケアし、様々な質問に答えたり課題解決を手伝ったりします。"
92
+ ],
93
+ "orders": [
94
+ "丁寧な言葉遣いを心がけてください。",
95
+ "ユーザーとの対話を通じてサポートを提供します。",
96
+ "ユーザーのことは『ユーザー様』と呼んでください。"
97
+ ],
98
+ "conversation_topics": [
99
+ "健康管理",
100
+ "目標設定",
101
+ "時間管理"
102
+ ],
103
+ "example_quotes": [
104
+ "ユーザー様の健康と幸福が何より大切です。どのようなサポートが必要でしょうか?",
105
+ "私はユーザー様の生活をより良いものにするためのアシスタントです。お手伝いできることがありましたらお申し付けください。",
106
+ "目標達成に向けて一緒に頑張りましょう。具体的な計画を立てるお手伝いをさせていただきます。",
107
+ "効率的な時間管理のコツをお教えします。まずは1日のスケジュールを確認してみましょう。",
108
+ "ストレス解消法についてアドバイスいたします。リラックスするための簡単な呼吸法から始めてみませんか?"
109
+ ]
110
+ }
111
+ save_settings_to_ini(self.settings)
112
+
113
+ def make(self, input_str: str):
114
+ prompt = self._generate_aki(input_str)
115
+ print(prompt)
116
+ print("-----------------")
117
+ res = self.llama.generate(prompt, max_new_tokens=1000, stop=["<END>", "\n"])
118
+ res_text = res["choices"][0]["text"]
119
+ self.history.append({"user": input_str, "assistant": res_text})
120
+ return res_text
121
+
122
+ def make_prompt(self, name: str, gender: str, situation: list, orders: list, dirty_talk_list: list, example_quotes: list, input_str: str):
123
+ with open('test_prompt.jinja2', 'r', encoding='utf-8') as f:
124
+ prompt = f.readlines()
125
+ fix_example_quotes = [quote+"<END>" for quote in example_quotes]
126
+ prompt = "".join(prompt)
127
+ prompt = Template(prompt).render(name=name, gender=gender, situation=situation, orders=orders, dirty_talk_list=dirty_talk_list, example_quotes=fix_example_quotes, histories=self.history, input_str=input_str)
128
+ return prompt
129
+
130
+ def _generate_aki(self, input_str: str):
131
+ prompt = self.make_prompt(
132
+ self.settings["name"],
133
+ self.settings["gender"],
134
+ self.settings["situation"],
135
+ self.settings["orders"],
136
+ self.settings["dirty_talk_list"],
137
+ self.settings["example_quotes"],
138
+ input_str
139
+ )
140
+ print(prompt)
141
+ return prompt
142
+
143
+ def update_settings(self, new_settings):
144
+ self.settings.update(new_settings)
145
+ save_settings_to_ini(self.settings)
146
+
147
+ def reset(self):
148
+ self.history = []
149
+ self.llama = LlamaCppAdapter(model_path)
150
+
151
+ character_maker = CharacterMaker()
152
+
153
+ # 設定を更新する関数
154
+ def update_settings(name, gender, situation, orders, dirty_talk_list, example_quotes):
155
+ new_settings = {
156
+ "name": name,
157
+ "gender": gender,
158
+ "situation": [s.strip() for s in situation.split('\n') if s.strip()],
159
+ "orders": [o.strip() for o in orders.split('\n') if o.strip()],
160
+ "dirty_talk_list": [d.strip() for d in dirty_talk_list.split('\n') if d.strip()],
161
+ "example_quotes": [e.strip() for e in example_quotes.split('\n') if e.strip()]
162
+ }
163
+ character_maker.update_settings(new_settings)
164
+ return "設定が更新されました。"
165
+
166
+ # チャット機能の関数
167
+ def chat_with_character(message, history):
168
+ character_maker.history = [{"user": h[0], "assistant": h[1]} for h in history]
169
+ response = character_maker.make(message)
170
+ return response
171
+
172
+ # チャットをクリアする関数
173
+ def clear_chat():
174
+ character_maker.reset()
175
+ return []
176
+
177
+ # カスタムCSS
178
+ custom_css = """
179
+ #chatbot {
180
+ height: 60vh !important;
181
+ overflow-y: auto;
182
+ }
183
+ """
184
+
185
+ # カスタムJavaScript(HTML内に埋め込む)
186
+ custom_js = """
187
+ <script>
188
+ function adjustChatbotHeight() {
189
+ var chatbot = document.querySelector('#chatbot');
190
+ if (chatbot) {
191
+ chatbot.style.height = window.innerHeight * 0.6 + 'px';
192
+ }
193
+ }
194
+
195
+ // ページ読み込み時と画面サイズ変更時にチャットボットの高さを調整
196
+ window.addEventListener('load', adjustChatbotHeight);
197
+ window.addEventListener('resize', adjustChatbotHeight);
198
+ </script>
199
+ """
200
+
201
+ # Gradioインターフェースの設定
202
+ with gr.Blocks(css=custom_css) as iface:
203
+ chatbot = gr.Chatbot(elem_id="chatbot")
204
+
205
+ with gr.Tab("チャット"):
206
+ gr.ChatInterface(
207
+ chat_with_character,
208
+ chatbot=chatbot,
209
+ textbox=gr.Textbox(placeholder="メッセージを入力してください...", container=False, scale=7),
210
+ theme="soft",
211
+ retry_btn="もう一度生成",
212
+ undo_btn="前のメッセージを取り消す",
213
+ clear_btn="チャットをクリア",
214
+ )
215
+
216
+ with gr.Tab("設定"):
217
+ gr.Markdown("## キャラクター設定")
218
+ name_input = gr.Textbox(label="名前", value=character_maker.settings["name"])
219
+ gender_input = gr.Textbox(label="性別", value=character_maker.settings["gender"])
220
+ situation_input = gr.Textbox(label="状況設定", value="\n".join(character_maker.settings["situation"]), lines=5)
221
+ orders_input = gr.Textbox(label="指示", value="\n".join(character_maker.settings["orders"]), lines=5)
222
+ dirty_talk_input = gr.Textbox(label="淫語リスト", value="\n".join(character_maker.settings["dirty_talk_list"]), lines=5)
223
+ example_quotes_input = gr.Textbox(label="例文", value="\n".join(character_maker.settings["example_quotes"]), lines=5)
224
+
225
+ update_button = gr.Button("設定を更新")
226
+ update_output = gr.Textbox(label="更新状態")
227
+
228
+ update_button.click(
229
+ update_settings,
230
+ inputs=[name_input, gender_input, situation_input, orders_input, dirty_talk_input, example_quotes_input],
231
+ outputs=[update_output]
232
+ )
233
+
234
+ # Gradioアプリの起動
235
+ if __name__ == "__main__":
236
+ iface.launch(
237
+ share=True,
238
+ allowed_paths=["models"],
239
+ favicon_path="custom.html"
240
+ )
custom.html ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <script>
2
+ function adjustChatbotHeight() {
3
+ var chatbot = document.querySelector('#chatbot');
4
+ if (chatbot) {
5
+ chatbot.style.height = window.innerHeight * 0.6 + 'px';
6
+ }
7
+ }
8
+
9
+ // ページ読み込み時と画面サイズ変更時にチャットボットの高さを調整
10
+ window.addEventListener('load', adjustChatbotHeight);
11
+ window.addEventListener('resize', adjustChatbotHeight);
12
+ </script>
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ gradio
2
+ https://github.com/abetlen/llama-cpp-python/releases/download/v0.2.81-cu124/llama_cpp_python-0.2.81-cp310-cp310-linux_x86_64.whl
test_prompt.jinja2 ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ・キャラクター設定
2
+ 名前:{{name}}
3
+ 性別:{{gender}}
4
+
5
+ {%for situation in situation %}
6
+ {{situation}}{%endfor%}
7
+
8
+ ・今回のユーザーのオーダー
9
+ {%for order in orders %}
10
+ {{order}}{%endfor%}
11
+
12
+ ・使ってほしい淫語表現
13
+ {%for dirty_talk in dirty_talk_list %}
14
+ {{dirty_talk}}{%endfor%}
15
+ ・キャラクターの発言例
16
+ {%for example_quote in example_quotes %}
17
+ {{example_quote}}{%endfor%}
18
+
19
+ {%for history in histories %}user: {{history.user}}
20
+ {{name}}: {{history.assistant}}{%endfor%}
21
+ user: {{input_str}}
22
+ {{name}}:
utils/dl_utils.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+ from tqdm import tqdm
4
+
5
+
6
+ def dl_guff_model(model_dir, url):
7
+ file_name = url.split('/')[-1]
8
+ folder = model_dir
9
+ file_path = os.path.join(folder, file_name)
10
+ if not os.path.exists(file_path):
11
+ response = requests.get(url, allow_redirects=True)
12
+ if response.status_code == 200:
13
+ with open(file_path, 'wb') as f:
14
+ f.write(response.content)
15
+ print(f'Downloaded {file_name}')
16
+ else:
17
+ print(f'Failed to download {file_name}')
18
+ else:
19
+ print(f'{file_name} already exists.')