File size: 6,570 Bytes
d48bb37
1bf4dfb
cb6fd2e
d48bb37
1e482da
 
 
 
 
 
 
 
 
 
 
 
 
 
d48bb37
 
1e482da
 
 
 
 
d48bb37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172f090
 
d48bb37
172f090
 
 
 
 
d48bb37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e482da
d48bb37
 
 
1e482da
d48bb37
1e482da
 
 
 
 
d48bb37
 
 
1e482da
d48bb37
 
1e482da
 
 
 
 
 
 
 
 
1bf4dfb
1e482da
 
 
 
 
 
 
1bf4dfb
1e482da
cb6fd2e
1e482da
 
 
 
 
 
 
d48bb37
 
 
 
 
 
 
1e482da
 
 
d48bb37
 
1e482da
 
172f090
1e482da
172f090
1e482da
cb6fd2e
 
d48bb37
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments, TrainerCallback
from datasets import load_dataset, DatasetDict
import os
import time

# Shared training-progress state: written by the progress helpers/callbacks
# during training and read by the Gradio UI.
progress_info = dict(
    status="待機中",  # human-readable status line
    progress=0,  # fraction of steps completed
    time_remaining=None,  # remaining-time estimate; None until one is known
)

def update_progress(trainer, epoch, step, total_steps, time_remaining):
    """Write a progress snapshot into the module-level ``progress_info`` dict.

    Args:
        trainer: Object exposing ``args.num_train_epochs`` (presumably a
            transformers ``Trainer``).
        epoch: Zero-based epoch index; shown 1-based in the status text.
        step: Zero-based step index; shown 1-based in the status text.
        total_steps: Total step count; the progress fraction divides by it.
        time_remaining: Stored unchanged under ``"time_remaining"``.
    """
    epoch_part = f"エポック {epoch + 1} / {trainer.args.num_train_epochs}"
    step_number = step + 1
    progress_info["status"] = f"{epoch_part}, ステップ {step_number} / {total_steps}"
    progress_info["progress"] = step_number / total_steps
    progress_info["time_remaining"] = time_remaining

def train_and_deploy(write_token, repo_name, license_text):
    """Fine-tune pythia-14m on FBK-MT/mosel and push the result to the Hub.

    Progress is reported through the module-level ``progress_info`` dict.

    Args:
        write_token: Hugging Face token with write access, used for the push.
        repo_name: Hub repository id, passed as ``hub_model_id``.
        license_text: Contents of a ``LICENSE`` file uploaded with the model.
            Nothing is written when it is empty or whitespace-only.

    Returns:
        A (Japanese) confirmation message naming the target repository.

    Raises:
        KeyError: If the dataset has no ``train`` split, or its records have
            no ``text`` field.
    """
    import dataclasses  # only used to probe which TrainingArguments fields exist

    progress_info["status"] = "トレーニング開始"
    progress_info["progress"] = 0
    progress_info["time_remaining"] = None

    # Kept for anything reading the token from the environment; it is also
    # passed explicitly to TrainingArguments below.
    os.environ['HF_WRITE_TOKEN'] = write_token

    output_dir = "./results"

    # Load the model and tokenizer.
    model_name = "EleutherAI/pythia-14m"  # model being fine-tuned
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # padding="max_length" below raises when the tokenizer has no pad token
    # (the Pythia tokenizer may ship without one); reuse EOS for padding.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load the FBK-MT/mosel dataset.
    dataset = load_dataset("FBK-MT/mosel")

    print(f"Dataset keys: {dataset.keys()}")
    if "train" not in dataset:
        raise KeyError("The dataset does not contain a 'train' split.")

    # Without a test split, hold out 10% of the train split for evaluation.
    if "test" not in dataset:
        split = dataset["train"].train_test_split(test_size=0.1)
        dataset = DatasetDict({
            "train": split["train"],
            "test": split["test"]
        })

    print(f"Sample keys in 'train' split: {dataset['train'][0].keys()}")

    def tokenize_function(examples):
        """Tokenize a batch of records and attach causal-LM labels."""
        try:
            texts = examples['text']
        except KeyError as e:
            print(f"KeyError: {e}")
            print(f"Available keys: {examples.keys()}")
            raise
        encoded = tokenizer(texts, padding="max_length", truncation=True, max_length=128)
        # A causal LM only returns a loss when ``labels`` are supplied (the model
        # shifts them internally). Padding positions get -100 so the loss ignores them.
        encoded["labels"] = [
            [token if attended else -100 for token, attended in zip(ids, mask)]
            for ids, mask in zip(encoded["input_ids"], encoded["attention_mask"])
        ]
        return encoded

    tokenized_datasets = dataset.map(tokenize_function, batched=True)

    # Newer transformers releases renamed ``evaluation_strategy`` to
    # ``eval_strategy`` and dropped the old spelling; use whichever exists.
    arg_names = {field.name for field in dataclasses.fields(TrainingArguments)}
    eval_key = "eval_strategy" if "eval_strategy" in arg_names else "evaluation_strategy"

    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        save_strategy="epoch",
        logging_dir="./logs",
        logging_steps=10,
        num_train_epochs=3,  # number of training epochs
        push_to_hub=True,  # push checkpoints/model to the Hugging Face Hub
        hub_token=write_token,
        hub_model_id=repo_name,  # repository name entered by the user
        **{eval_key: "epoch"},
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["test"],
        callbacks=[CustomCallback()],
    )

    start_time = time.time()
    trainer.train()
    total_time = time.time() - start_time
    progress_info["status"] = f"トレーニング完了(所要時間: {total_time:.2f}秒)"
    progress_info["progress"] = 1
    progress_info["time_remaining"] = 0

    # The Hub push uploads output_dir, so the license must live there; a
    # LICENSE in the working directory would never reach the repository.
    if license_text and license_text.strip():
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "LICENSE"), "w", encoding="utf-8") as f:
            f.write(license_text)

    trainer.push_to_hub()

    return f"モデルが'{repo_name}'リポジトリにデプロイされました!"

class CustomCallback(TrainerCallback):
    """Trainer callback that mirrors training progress into ``progress_info``.

    ``time_remaining`` is stored as a number of seconds (rounded to 2
    decimals). It is estimated from the average wall-clock time per step
    completed since ``on_train_begin``.
    """

    def __init__(self):
        super().__init__()
        # Wall-clock time and global step at train start; the basis of the ETA.
        # Recording the start step keeps the ETA correct when resuming.
        self._start_time = None
        self._start_step = 0

    @staticmethod
    def _epoch_number(state, step_finished):
        """Return the 1-based epoch the current step belongs to.

        ``state.epoch`` is fractional. Before a step it counts the progress so
        far. After a step it already includes that step, so an exact integer
        there means the finished step closed that epoch.
        """
        epoch = float(state.epoch or 0)
        whole = int(epoch)
        if step_finished and whole > 0 and epoch == whole:
            return whole
        return whole + 1

    @classmethod
    def _status_text(cls, args, state, step_number, total_steps, step_finished):
        """Build the status line shown in the UI."""
        epoch_number = cls._epoch_number(state, step_finished)
        return f"エポック {epoch_number} / {args.num_train_epochs}, ステップ {step_number} / {total_steps}"

    def on_train_begin(self, args, state, control, **kwargs):
        self._start_time = time.time()
        self._start_step = state.global_step
        progress_info["status"] = "トレーニング開始"
        progress_info["progress"] = 0
        progress_info["time_remaining"] = None

    def on_step_begin(self, args, state, control, **kwargs):
        total_steps = state.max_steps
        # global_step counts finished steps, so the step starting now is +1.
        step_number = state.global_step + 1
        progress_info["status"] = self._status_text(args, state, step_number, total_steps, step_finished=False)
        progress_info["progress"] = step_number / total_steps if total_steps else 0
        # time_remaining is left alone so the last estimate stays visible.

    def on_step_end(self, args, state, control, **kwargs):
        total_steps = state.max_steps
        # The training loop increments global_step before firing on_step_end,
        # so it already counts the step that just finished.
        done = state.global_step
        progress_info["status"] = self._status_text(args, state, done, total_steps, step_finished=True)
        progress_info["progress"] = done / total_steps if total_steps else 0

        steps_this_run = done - self._start_step
        if self._start_time is not None and steps_this_run > 0:
            seconds_per_step = (time.time() - self._start_time) / steps_this_run
            remaining = seconds_per_step * max(total_steps - done, 0)
            progress_info["time_remaining"] = round(remaining, 2)

# Gradio UI
def _ui_status():
    """Return the current status line from ``progress_info``."""
    return progress_info["status"]


def _ui_time_remaining():
    """Format ``progress_info["time_remaining"]`` for display.

    The stored value may be None (no estimate yet), a number of seconds, or a
    string that already carries its unit, depending on which code wrote it.
    """
    remaining = progress_info["time_remaining"]
    if remaining is None:
        return "待機中"
    if isinstance(remaining, str):
        return remaining
    return f"{remaining:.2f}秒"


with gr.Blocks() as demo:
    gr.Markdown("### pythia トレーニングとデプロイ")
    # type="password" keeps the write token masked in the browser.
    token_input = gr.Textbox(label="Hugging Face Write Token", placeholder="トークンを入力してください...", type="password")
    repo_input = gr.Textbox(label="リポジトリ名", placeholder="デプロイするリポジトリ名を入力してください...")
    license_input = gr.Textbox(label="ライセンス", placeholder="ライセンス情報を入力してください...")
    output = gr.Textbox(label="出力")
    # A callable value plus ``every`` makes Gradio re-evaluate it periodically,
    # so these boxes follow progress_info while training runs.
    status = gr.Textbox(label="ステータス", value=_ui_status, every=1)
    time_remaining = gr.Textbox(label="残り時間", value=_ui_time_remaining, every=1)
    train_button = gr.Button("デプロイ")

    def run_training(write_token, repo_name, license_text, progress=gr.Progress(track_tqdm=True)):
        """Run train_and_deploy with Trainer's tqdm bars forwarded to the UI.

        gr.Progress only takes effect as a default argument of an event
        handler, which is why this thin wrapper exists.
        """
        return train_and_deploy(write_token, repo_name, license_text)

    train_button.click(fn=run_training, inputs=[token_input, repo_input, license_input], outputs=output)

# The queue is needed for progress tracking and periodic (``every``) updates.
demo.queue()
demo.launch()