|
import logging
import os
import time
from multiprocessing import Manager, Process

from accelerate import Accelerator
from datasets import load_from_disk
from flask import Flask
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
|
|
|
|
|
logging.basicConfig(level=logging.INFO) |
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
HF_TOKEN = os.getenv("HF_TOKEN") |
|
|
|
|
|
accelerator = Accelerator() |
|
|
|
|
|
DATA_PATH = "processed_dataset" |
|
train_dataset = load_from_disk(DATA_PATH) |
|
|
|
|
|
model_name = "vishaljoshi24/crd3_text_gen_1" |
|
model = AutoModelForCausalLM.from_pretrained(model_name, token=HF_TOKEN) |
|
tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN) |
|
|
|
|
|
tokenizer.add_special_tokens({'pad_token': '[PAD]'}) |
|
tokenizer.pad_token = '[PAD]' |
|
model.resize_token_embeddings(len(tokenizer)) |
|
|
|
|
|
data_collator = DataCollatorForLanguageModeling( |
|
tokenizer=tokenizer, |
|
mlm=False |
|
) |
|
|
|
def train_model(): |
|
logger.info("Loading dataset...") |
|
train_dataset = load_from_disk(DATA_PATH) |
|
|
|
logger.info("Loading model and tokenizer...") |
|
tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=HF_TOKEN) |
|
tokenizer.add_special_tokens({'pad_token': '[PAD]'}) |
|
model = AutoModelForCausalLM.from_pretrained(model_name, use_auth_token=HF_TOKEN) |
|
model.resize_token_embeddings(len(tokenizer)) |
|
|
|
logger.info("Preparing data collator...") |
|
data_collator = DataCollatorForLanguageModeling( |
|
tokenizer=tokenizer, |
|
mlm=False |
|
) |
|
|
|
logger.info("Setting training arguments...") |
|
training_args = TrainingArguments( |
|
output_dir="./output", |
|
evaluation_strategy="no", |
|
learning_rate=1e-5, |
|
lr_scheduler_type="cosine", |
|
warmup_steps=500, |
|
per_device_train_batch_size=2, |
|
per_device_eval_batch_size=2, |
|
num_train_epochs=3, |
|
weight_decay=0.01, |
|
max_grad_norm=0.95, |
|
save_steps=500, |
|
gradient_accumulation_steps=8, |
|
logging_dir="./logs", |
|
push_to_hub=True, |
|
hub_model_id="vishaljoshi24/crd3_text_gen_2", |
|
report_to="none" |
|
) |
|
|
|
logger.info("Initializing Trainer...") |
|
trainer = Trainer( |
|
model=model, |
|
args=training_args, |
|
train_dataset=train_dataset, |
|
tokenizer=tokenizer, |
|
data_collator=data_collator, |
|
) |
|
|
|
logger.info("Starting training...") |
|
trainer.train() |
|
|
|
logger.info("Training complete. Pushing model to Hugging Face Hub...") |
|
model.push_to_hub("vishaljoshi24/crd3_text_gen_2", token=HF_TOKEN) |
|
tokenizer.push_to_hub("vishaljoshi24/crd3_text_gen_2", token=HF_TOKEN) |
|
logger.info("Model and tokenizer pushed successfully.") |
|
|
|
|
|
app = Flask(__name__) |
|
|
|
|
|
training_status = {"status": "Idle"} |
|
|
|
@app.route("/") |
|
def health_check(): |
|
return {"message": "Application is healthy", "training_status": training_status["status"]}, 200 |
|
|
|
def start_flask_server(): |
|
app.run(host="0.0.0.0", port=7860) |
|
|
|
if __name__ == "__main__": |
|
|
|
flask_process = Process(target=start_flask_server) |
|
flask_process.start() |
|
|
|
try: |
|
|
|
training_status["status"] = "Training in progress..." |
|
train_model() |
|
training_status["status"] = "Training complete" |
|
finally: |
|
flask_process.terminate() |
|
|