File size: 6,931 Bytes
fc11025
11f0819
6c6d175
 
57e9cd2
11f0819
0d5cec5
8d8989f
 
 
eb35849
00221d8
 
0d5cec5
 
 
 
a4619fb
 
 
 
 
 
 
3531ae4
 
 
d61e896
a4619fb
4ab04be
 
a4619fb
 
 
d4fa487
a4619fb
 
 
 
 
 
 
4ab04be
 
a4619fb
57e9cd2
d925e0c
0d5cec5
d925e0c
 
 
 
 
10f0a80
d925e0c
6f37091
57e9cd2
6f37091
 
a4619fb
0d5cec5
 
e271162
 
 
 
 
 
 
0d5cec5
e271162
6c6d175
 
0d5cec5
6c6d175
 
 
 
 
 
 
57e9cd2
d925e0c
 
 
6c6d175
 
57e9cd2
 
0d5cec5
6c6d175
57e9cd2
 
6c6d175
57e9cd2
 
6f37091
11f0819
0d5cec5
57e9cd2
 
0d5cec5
 
432ab9c
0d5cec5
 
 
bd8e175
 
0d5cec5
 
 
 
6f37091
432ab9c
a4619fb
 
0d5cec5
6f37091
57e9cd2
432ab9c
3cf4ace
10f0a80
 
 
 
57e9cd2
c14f2e8
 
10f0a80
57e9cd2
4ab04be
a4619fb
 
 
 
 
 
 
 
 
 
d4fa487
 
 
432ab9c
57e9cd2
 
0d5cec5
6f37091
57e9cd2
 
 
11f0819
57e9cd2
0d5cec5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import os
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import spaces

# PyTorch設定(パフォーマンスと再現性向上のため)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cuda.matmul.allow_tf32 = True

HF_TOKEN = os.getenv("HF_TOKEN")

# モデルのキャッシュ用辞書(ロード済みなら再利用)
loaded_models = {}

def get_model_and_tokenizer(model_name):
    # 既にロード済みならそのまま返す
    if model_name in loaded_models:
        return loaded_models[model_name]
    # ロードされていなければロードする
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, attn_implementation="flash_attention_2", use_auth_token=HF_TOKEN
    )
    model = AutoModelForCausalLM.from_pretrained(model_name, use_auth_token=HF_TOKEN)
    loaded_models[model_name] = (model, tokenizer)
    return model, tokenizer

def disable_generate_button():
    # 生成ボタンを無効化し、テキストを「モデルをロード中……」に変更する
    return gr.update(interactive=False, value="モデルをロード中……")

def load_model(model_name):
    """
    プルダウン変更時や起動時に呼ばれ、モデルをロードして生成ボタンを有効化する。
    """
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, attn_implementation="flash_attention_2", use_auth_token=HF_TOKEN
    )
    model = AutoModelForCausalLM.from_pretrained(model_name, use_auth_token=HF_TOKEN)
    loaded_models[model_name] = (model, tokenizer)
    status_message = f"Model '{model_name}' loaded successfully."
    # ロード完了後、生成ボタンを有効化し、テキストを「続きを生成」に戻す
    return status_message, gr.update(interactive=True, value="続きを生成")

@spaces.GPU
def generate_text(
    model_name,
    input_text, 
    max_length=150, 
    temperature=0.7, 
    top_k=50, 
    top_p=0.95, 
    repetition_penalty=1.0
):
    """ユーザー入力に基づいてテキストを生成し、元のテキストに追加する関数"""
    try:
        if not input_text.strip():
            return ""
        # 既にロード済みのモデルとトークナイザーを使用
        model, tokenizer = get_model_and_tokenizer(model_name)
        
        # GPUが利用可能ならGPUへ移動。bf16がサポートされている場合はbf16を使用
        device = "cuda" if torch.cuda.is_available() else "cpu"
        if device == "cuda" and hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported():
            model.to(device, dtype=torch.bfloat16)
        else:
            model.to(device)

        # 入力テキストのトークン化
        input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
        input_token_count = input_ids.shape[1]
        
        # 総トークン数の上限を入力トークン数 + max_length(max_lengthは追加するトークン数として扱う)
        total_max_length = input_token_count + max_length
        
        # テキスト生成
        output_ids = model.generate(
            input_ids,
            max_length=total_max_length,
            do_sample=True,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            pad_token_id=tokenizer.eos_token_id,
            num_return_sequences=1
        )
        
        # 生成されたテキストをデコードし、入力部分を除いた生成分を抽出
        generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        new_text = generated_text[len(input_text):]
        
        # 入力テキストに生成したテキストを追加して返す
        return input_text + new_text
    except Exception as e:
        return f"{input_text}\n\nエラーが発生しました: {str(e)}"

# Gradioインターフェースの作成
with gr.Blocks() as demo:
    gr.Markdown("# テキスト続き生成アシスタント")
    gr.Markdown("モデルを選択し、テキストボックスに文章を入力してパラメータを調整後、「続きを生成」ボタンをクリックすると、選択したモデルがその続きを生成します。")
    
    # モデル選択用プルダウンメニュー
    model_dropdown = gr.Dropdown(
        choices=[
            "Local-Novel-LLM-project/Vecteus-v1-abliterated", 
            "Local-Novel-LLM-project/Ninja-V3", 
            "Local-Novel-LLM-project/kagemusya-7B-v1"
        ],
        label="モデルを選択してください",
        value="Local-Novel-LLM-project/Vecteus-v1-abliterated"
    )
    
    # 隠しコンポーネント:モデルロードの状況を表示(ユーザーには見せなくても良い)
    load_status = gr.Textbox(visible=False)
    
    # テキスト入力ボックス
    input_text = gr.Textbox(label="テキストを入力してください", placeholder="ここにテキストを入力...", lines=10)
    
    # 生成パラメータの設定UI
    max_length_slider = gr.Slider(minimum=10, maximum=1000, value=100, step=10, label="追加するトークン数")
    temperature_slider = gr.Slider(minimum=0.1, maximum=1.5, value=0.7, step=0.1, label="創造性(温度)")
    top_k_slider = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="top_k")
    top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.05, label="top_p")
    repetition_penalty_slider = gr.Slider(minimum=1.0, maximum=2.0, value=1.0, step=0.1, label="繰り返しペナルティ")
    
    # 生成ボタンは初期状態で無効化
    generate_btn = gr.Button("モデルをロード中……", variant="primary", interactive=False)
    clear_btn = gr.Button("クリア")
    
    # プルダウン変更時に、まず生成ボタンを無効化(テキストを「モデルをロード中……」に変更)し、その後モデルをロードして生成ボタンを再有効化するイベントチェーンを設定
    model_dropdown.change(
        fn=disable_generate_button,
        inputs=None,
        outputs=generate_btn
    ).then(
        fn=load_model,
        inputs=model_dropdown,
        outputs=[load_status, generate_btn]
    )
    
    # 起動時にも load_model を実行する(初期値のモデルでロード)
    demo.load(fn=load_model, inputs=model_dropdown, outputs=[load_status, generate_btn])
    
    # 生成ボタン押下時のイベント設定
    generate_btn.click(
        fn=generate_text, 
        inputs=[model_dropdown, input_text, max_length_slider, temperature_slider, top_k_slider, top_p_slider, repetition_penalty_slider], 
        outputs=input_text
    )
    
    clear_btn.click(lambda: "", None, input_text)

# アプリの起動
demo.launch()