Spaces: Running on Zero

Update app.py
app.py CHANGED

@@ -1,5 +1,5 @@
 import os
-
+os.environ['CUDA_VISIBLE_DEVICES'] = ''
 import spaces
 import sys
 import time
@@ -289,12 +289,10 @@ class LlamaAdapter:
             repeat_penalty=repeat_penalty
         )
 
-@spaces.GPU(duration=120)
 def load_model_gpu(model_type, model_path, n_gpu_layers, params):
     llama = LlamaAdapter(model_path, params, n_gpu_layers)
     print(f"Finished loading the {model_type} model {model_path}. (n_gpu_layers: {n_gpu_layers})")
     return llama
-
 
 class CharacterMaker:
     def __init__(self):
@@ -336,7 +334,7 @@ class CharacterMaker:
 
         try:
             # Load the new model
-            self.llama = load_model_gpu(model_type, model_path, n_gpu_layers, params)
+            self.llama = LlamaAdapter(model_path, params, n_gpu_layers)
            self.current_model = model_type
             self.model_loaded.set()
             print(f"Loaded the {model_type} model. Model path: {model_path}, GPU layers: {n_gpu_layers}")
@@ -344,6 +342,17 @@ class CharacterMaker:
             print(f"An error occurred while loading the {model_type} model: {e}")
             self.model_loaded.set()
 
+    @spaces.GPU(duration=120)
+    def chat_or_gen(self, text, gen_characters, gen_token_multiplier, instruction, mode):
+        if mode == "chat":
+            return self.generate_response(text)
+        elif mode == "gen":
+            return self.generate_text(text, gen_characters, gen_token_multiplier, instruction)
+
+
+    def generate_text_gen_pre(self, text, gen_characters, gen_token_multiplier, instruction):
+        return self.chat_or_gen(text, gen_characters, gen_token_multiplier, instruction, mode="gen")
+
     def generate_response(self, input_str):
         self.load_model('CHAT')
         if not self.model_loaded.wait(timeout=30) or not self.llama:
@@ -470,14 +479,14 @@ def chat_with_character(message, history):
         character_maker.chat_history = [{"role": "user" if i % 2 == 0 else "assistant", "content": msg} for i, msg in enumerate(sum(history, []))]
     else:
         character_maker.history = [{"user": h[0], "assistant": h[1]} for h in history]
-    return character_maker.generate_response(message)
+    return character_maker.chat_or_gen(text=message, gen_characters=None, gen_token_multiplier=None, instruction=None, mode="chat")
 
 def chat_with_character_stream(message, history):
     if character_maker.use_chat_format:
         character_maker.chat_history = [{"role": "user" if i % 2 == 0 else "assistant", "content": msg} for i, msg in enumerate(sum(history, []))]
     else:
         character_maker.history = [{"user": h[0], "assistant": h[1]} for h in history]
-    response = character_maker.generate_response(message)
+    response = character_maker.chat_or_gen(text=message, gen_characters=None, gen_token_multiplier=None, instruction=None, mode="chat")
     for i in range(len(response)):
         time.sleep(0.05)  # Adjust the per-character display interval
         yield response[:i+1]
@@ -702,7 +711,7 @@ def build_gradio_interface():
         generated_output = gr.Textbox(label="Generated text")
 
         generate_button.click(
-            character_maker.generate_text,
+            character_maker.generate_text_gen_pre,
             inputs=[gen_input_text, gen_characters, gen_token_multiplier, gen_instruction],
             outputs=[generated_output]
         )
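
The pattern behind this diff: on a "Running on Zero" Space the main process never holds a GPU; CUDA is attached only for the duration of a call into a function decorated with @spaces.GPU. That is why the commit masks CUDA in the main process via CUDA_VISIBLE_DEVICES='' and funnels both chat and text generation through the single decorated chat_or_gen entry point. Below is a minimal, self-contained sketch of that pattern; the function name run_on_gpu and its body are illustrative stand-ins, not code from app.py.

import os
# Hide CUDA from the CPU-only main process so imports do not try to initialize it.
os.environ['CUDA_VISIBLE_DEVICES'] = ''

import spaces  # Hugging Face ZeroGPU helper package


@spaces.GPU(duration=120)  # a GPU is attached only while this call runs, for up to 120 s
def run_on_gpu(prompt):
    # Illustrative placeholder: model loading and generation must both
    # happen inside the decorated call, never at module import time.
    return f"echo: {prompt}"


if __name__ == "__main__":
    print(run_on_gpu("hello"))

Decorating one dispatch method, as the commit does with chat_or_gen, keeps every CUDA touch inside the GPU-attached window, at the cost of paying the attach latency on each call.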