Traningafri / app.py
Sakalti's picture
Update app.py
4a4175e verified
raw
history blame
3.3 kB
# 必要なライブラリをインストールしておいてください
# pip install streamlit transformers torch huggingface_hub datasets
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from huggingface_hub import HfApi, HfFolder, Repository
import torch
import os
# Streamlit App
st.title("Hugging Face Model Training App")
st.write("castorini/afriberta-corpusを使って、ユーザーが入力したモデル名でファインチューニング")
# ユーザー入力
model_name = st.text_input("トレーニングするモデル名 (例: Qwen/Qwen2.5-1.5B-Instruct)")
dataset_name = "castorini/afriberta-corpus"
hf_token = st.text_input("Hugging Face Write トークン", type="password")
repo_name = st.text_input("Hugging Faceリポジトリ名") # ユーザーが入力できるリポジトリ名
output_dir = "./finetuned_model"
if st.button("トレーニング開始"):
if not model_name or not hf_token or not repo_name:
st.warning("モデル名、トークン、リポジトリ名を入力してください")
else:
# トークンの設定
HfFolder.save_token(hf_token)
# モデルとトークナイザーのロード
st.write("モデルとトークナイザーをロード中...")
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# データセットの準備(スワヒリ語)
st.write("データセットのロード中...")
from datasets import load_dataset
dataset = load_dataset(dataset_name, 'swahili', split="train") # 言語を指定
# トレーニング用のデータセットの準備
def tokenize_function(examples):
return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=128)
tokenized_dataset = dataset.map(tokenize_function, batched=True)
tokenized_dataset = tokenized_dataset.rename_column("text", "labels")
# トレーニング設定
training_args = TrainingArguments(
output_dir=output_dir,
eval_strategy="epoch",
learning_rate=2e-5,
per_device_train_batch_size=8,
num_train_epochs=1,
save_steps=10_000,
save_total_limit=2,
)
# トレーナーの作成
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_dataset,
)
# トレーニングの実行
st.write("トレーニング開始...")
trainer.train()
# トレーニング済みモデルの保存
st.write("トレーニング完了。モデルを保存中...")
trainer.save_model(output_dir)
# Hugging Face Hub にデプロイ
api = HfApi()
api.create_repo(repo_name, token=hf_token)
repo = Repository(local_dir=output_dir, clone_from=repo_name, use_auth_token=hf_token)
st.write("Hugging Face Hubにデプロイ中...")
repo.push_to_hub(commit_message="トレーニング済みモデルをデプロイ")
st.success(f"{repo_name}としてHugging Face Hubにデプロイ完了しました!")