from transformers import AutoTokenizer, BertForTokenClassification, get_scheduler
from torch.optim import AdamW
from torch.utils.data import DataLoader
from datasets import load_dataset
from dotenv import load_dotenv
from tqdm.auto import tqdm
from typing import List
from datetime import datetime
import torch
import torch.nn as nn
import logging
import wandb
import yaml

logging.basicConfig(level=logging.INFO)


def train(json_path: str):

    # Load W&B / Hugging Face credentials from .env before wandb.init and push_to_hub
    load_dotenv(".env")

    ### Model & tokenizer loading

    tokenizer = AutoTokenizer.from_pretrained("jjzha/jobbert_knowledge_extraction")
    model = BertForTokenClassification.from_pretrained("Robzy/jobbert_knowledge_extraction")

    with open("./config.yaml", "r") as file:
        config = yaml.safe_load(file)

    num_epochs = config['training']['epochs']
    batch_size = config['training']['batch_size']
    lr = config['training']['learning_rate']

    current_time = datetime.now()

    run = wandb.init(
        # set the wandb project where this run will be logged
        project="in-demand",
        # track hyperparameters and run metadata
        config={
            "learning_rate": lr,
            "architecture": "BERT",
            "epochs": num_epochs,
            "batch_size": batch_size,
            "notes": "Datetime: " + current_time.strftime("%m/%d/%Y, %H:%M:%S")
        }
    )

    ### Data loading and preprocessing

    # dataset = load_dataset("json", data_files="data/test-short.json")
    dataset = load_dataset("json", data_files=json_path)

    # Tag strings are mapped to label ids with the model's own label2id;
    # token ids and attention masks are built in collate_fn below.
    id2label = model.config.id2label
    label2id = model.config.label2id

    def pad(list_of_lists, pad_value=0):
        # Right-pad each list to the batch's max length and build matching attention masks
        max_len = max(len(lst) for lst in list_of_lists)
        padded_lists = [lst + [pad_value] * (max_len - len(lst)) for lst in list_of_lists]
        attention_masks = [[1] * len(lst) + [0] * (max_len - len(lst)) for lst in list_of_lists]
        return torch.tensor(padded_lists), torch.tensor(attention_masks)

    def collate_fn(batch: List[dict]):
        # Convert pre-tokenized words to input ids and tag strings to label ids, then pad
        input_ids, attention_mask = pad([tokenizer.convert_tokens_to_ids(x['tokens']) for x in batch])
        tags_knowledge, _ = pad([[label2id[t] for t in b['tags_knowledge']] for b in batch])
        return {"input_ids": input_ids, "tags_knowledge": tags_knowledge, "attention_mask": attention_mask}

    ### Training settings

    train_dataloader = DataLoader(dataset['train'], batch_size=batch_size, collate_fn=collate_fn)

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)
    model.train()

    IGNORE_INDEX = -100
    criterion = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)

    optimizer = AdamW(model.parameters(), lr=lr)
    num_training_steps = num_epochs * len(train_dataloader)
    lr_scheduler = get_scheduler(
        name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
    )

    ### Training

    logging.info("Initiating training")

    progress_bar = tqdm(range(num_epochs), desc="Epochs")

    for epoch in range(num_epochs):
        logging.info(f"Epoch #{epoch}")

        batch_count = 1

        for batch in train_dataloader:
            logging.info(f"Batch #{batch_count} / {len(train_dataloader)}")

            tokens = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            tags_knowledge = batch['tags_knowledge'].to(device)

            outputs = model(tokens, attention_mask=attention_mask)
            pred = outputs.logits.reshape(-1, model.config.num_labels)  # (batch * seq_len, num_labels)
            # Labels flattened to match, with padding positions set to IGNORE_INDEX
            label = torch.where(attention_mask == 0, torch.tensor(IGNORE_INDEX, device=device), tags_knowledge).reshape(-1)

            # Compute accuracy, ignoring padding positions
            _, predicted_labels = torch.max(pred, dim=1)
            non_pad_elements = label != IGNORE_INDEX
            correct_predictions = (predicted_labels[non_pad_elements] == label[non_pad_elements]).sum().item()
            total_predictions = non_pad_elements.sum().item()
            accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0

            loss = criterion(pred, label)
            loss.backward()

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            wandb.log({"epoch": epoch, "accuracy": accuracy, "loss": loss.item()})

            batch_count += 1

        progress_bar.update(1)

    print("Training complete")

    ### Pushing model

    # Hugging Face
    model.push_to_hub("Robzy/jobbert_knowledge_extraction")

    # W&B
    artifact = wandb.Artifact(name="jobbert-knowledge-extraction", type="BERT")
    state_dict = model.state_dict()
    with artifact.new_file('model.pth', mode='wb') as f:
        torch.save(state_dict, f)

    # Log the artifact to W&B
    wandb.log_artifact(artifact)
    run.finish()


if __name__ == "__main__":
    train(json_path="./data/data.jsonl")
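
# ---------------------------------------------------------------------------
# Illustrative input formats (assumed, not defined in this script):
#
# config.yaml is expected to provide the three hyperparameters read above,
# for example:
#
#   training:
#     epochs: 3
#     batch_size: 16
#     learning_rate: 5.0e-5
#
# data.jsonl is expected to contain one pre-tokenized record per line, with
# tag strings that exist in model.config.label2id, for example:
#
#   {"tokens": ["Experience", "with", "Python", "required"],
#    "tags_knowledge": ["O", "O", "B", "O"]}
#
# The values shown are placeholders; only the keys (training.epochs,
# training.batch_size, training.learning_rate, tokens, tags_knowledge)
# are what train() actually reads.
# ---------------------------------------------------------------------------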