Spaces:
Runtime error
Runtime error
File size: 6,931 Bytes
fc11025 11f0819 6c6d175 57e9cd2 11f0819 0d5cec5 8d8989f eb35849 00221d8 0d5cec5 a4619fb 3531ae4 d61e896 a4619fb 4ab04be a4619fb d4fa487 a4619fb 4ab04be a4619fb 57e9cd2 d925e0c 0d5cec5 d925e0c 10f0a80 d925e0c 6f37091 57e9cd2 6f37091 a4619fb 0d5cec5 e271162 0d5cec5 e271162 6c6d175 0d5cec5 6c6d175 57e9cd2 d925e0c 6c6d175 57e9cd2 0d5cec5 6c6d175 57e9cd2 6c6d175 57e9cd2 6f37091 11f0819 0d5cec5 57e9cd2 0d5cec5 432ab9c 0d5cec5 bd8e175 0d5cec5 6f37091 432ab9c a4619fb 0d5cec5 6f37091 57e9cd2 432ab9c 3cf4ace 10f0a80 57e9cd2 c14f2e8 10f0a80 57e9cd2 4ab04be a4619fb d4fa487 432ab9c 57e9cd2 0d5cec5 6f37091 57e9cd2 11f0819 57e9cd2 0d5cec5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
import os
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import spaces
# PyTorch設定(パフォーマンスと再現性向上のため)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cuda.matmul.allow_tf32 = True
HF_TOKEN = os.getenv("HF_TOKEN")
# モデルのキャッシュ用辞書(ロード済みなら再利用)
loaded_models = {}
def get_model_and_tokenizer(model_name):
# 既にロード済みならそのまま返す
if model_name in loaded_models:
return loaded_models[model_name]
# ロードされていなければロードする
tokenizer = AutoTokenizer.from_pretrained(
model_name, attn_implementation="flash_attention_2", use_auth_token=HF_TOKEN
)
model = AutoModelForCausalLM.from_pretrained(model_name, use_auth_token=HF_TOKEN)
loaded_models[model_name] = (model, tokenizer)
return model, tokenizer
def disable_generate_button():
# 生成ボタンを無効化し、テキストを「モデルをロード中……」に変更する
return gr.update(interactive=False, value="モデルをロード中……")
def load_model(model_name):
"""
プルダウン変更時や起動時に呼ばれ、モデルをロードして生成ボタンを有効化する。
"""
tokenizer = AutoTokenizer.from_pretrained(
model_name, attn_implementation="flash_attention_2", use_auth_token=HF_TOKEN
)
model = AutoModelForCausalLM.from_pretrained(model_name, use_auth_token=HF_TOKEN)
loaded_models[model_name] = (model, tokenizer)
status_message = f"Model '{model_name}' loaded successfully."
# ロード完了後、生成ボタンを有効化し、テキストを「続きを生成」に戻す
return status_message, gr.update(interactive=True, value="続きを生成")
@spaces.GPU
def generate_text(
model_name,
input_text,
max_length=150,
temperature=0.7,
top_k=50,
top_p=0.95,
repetition_penalty=1.0
):
"""ユーザー入力に基づいてテキストを生成し、元のテキストに追加する関数"""
try:
if not input_text.strip():
return ""
# 既にロード済みのモデルとトークナイザーを使用
model, tokenizer = get_model_and_tokenizer(model_name)
# GPUが利用可能ならGPUへ移動。bf16がサポートされている場合はbf16を使用
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda" and hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported():
model.to(device, dtype=torch.bfloat16)
else:
model.to(device)
# 入力テキストのトークン化
input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
input_token_count = input_ids.shape[1]
# 総トークン数の上限を入力トークン数 + max_length(max_lengthは追加するトークン数として扱う)
total_max_length = input_token_count + max_length
# テキスト生成
output_ids = model.generate(
input_ids,
max_length=total_max_length,
do_sample=True,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
pad_token_id=tokenizer.eos_token_id,
num_return_sequences=1
)
# 生成されたテキストをデコードし、入力部分を除いた生成分を抽出
generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
new_text = generated_text[len(input_text):]
# 入力テキストに生成したテキストを追加して返す
return input_text + new_text
except Exception as e:
return f"{input_text}\n\nエラーが発生しました: {str(e)}"
# Gradioインターフェースの作成
with gr.Blocks() as demo:
gr.Markdown("# テキスト続き生成アシスタント")
gr.Markdown("モデルを選択し、テキストボックスに文章を入力してパラメータを調整後、「続きを生成」ボタンをクリックすると、選択したモデルがその続きを生成します。")
# モデル選択用プルダウンメニュー
model_dropdown = gr.Dropdown(
choices=[
"Local-Novel-LLM-project/Vecteus-v1-abliterated",
"Local-Novel-LLM-project/Ninja-V3",
"Local-Novel-LLM-project/kagemusya-7B-v1"
],
label="モデルを選択してください",
value="Local-Novel-LLM-project/Vecteus-v1-abliterated"
)
# 隠しコンポーネント:モデルロードの状況を表示(ユーザーには見せなくても良い)
load_status = gr.Textbox(visible=False)
# テキスト入力ボックス
input_text = gr.Textbox(label="テキストを入力してください", placeholder="ここにテキストを入力...", lines=10)
# 生成パラメータの設定UI
max_length_slider = gr.Slider(minimum=10, maximum=1000, value=100, step=10, label="追加するトークン数")
temperature_slider = gr.Slider(minimum=0.1, maximum=1.5, value=0.7, step=0.1, label="創造性(温度)")
top_k_slider = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="top_k")
top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.05, label="top_p")
repetition_penalty_slider = gr.Slider(minimum=1.0, maximum=2.0, value=1.0, step=0.1, label="繰り返しペナルティ")
# 生成ボタンは初期状態で無効化
generate_btn = gr.Button("モデルをロード中……", variant="primary", interactive=False)
clear_btn = gr.Button("クリア")
# プルダウン変更時に、まず生成ボタンを無効化(テキストを「モデルをロード中……」に変更)し、その後モデルをロードして生成ボタンを再有効化するイベントチェーンを設定
model_dropdown.change(
fn=disable_generate_button,
inputs=None,
outputs=generate_btn
).then(
fn=load_model,
inputs=model_dropdown,
outputs=[load_status, generate_btn]
)
# 起動時にも load_model を実行する(初期値のモデルでロード)
demo.load(fn=load_model, inputs=model_dropdown, outputs=[load_status, generate_btn])
# 生成ボタン押下時のイベント設定
generate_btn.click(
fn=generate_text,
inputs=[model_dropdown, input_text, max_length_slider, temperature_slider, top_k_slider, top_p_slider, repetition_penalty_slider],
outputs=input_text
)
clear_btn.click(lambda: "", None, input_text)
# アプリの起動
demo.launch() |