# Provenance (copied from the Hugging Face Hub file page): uploaded by
# SupremoUGH, commit 799babb "autotrain set up" (unverified), file size 1.76 kB.
from transformers import (
AutoModelForImageClassification,
AutoImageProcessor,
TrainingArguments,
Trainer,
)
from datasets import load_dataset
import os
def compute_accuracy(eval_pred):
    """Return top-1 accuracy for one ``Trainer`` evaluation pass.

    Args:
        eval_pred: A ``(predictions, label_ids)`` pair, e.g. the
            ``EvalPrediction`` that ``Trainer`` hands to ``compute_metrics``.
            ``predictions`` holds per-class scores with classes on the last
            axis. If it is a tuple (model returned extra outputs), its first
            element is taken as the logits.

    Returns:
        ``{"accuracy": float}``: the fraction of examples whose arg-max class
        equals the label.
    """
    import numpy as np

    predictions, labels = eval_pred
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predicted_classes = np.argmax(np.asarray(predictions), axis=-1)
    return {"accuracy": float(np.mean(predicted_classes == np.asarray(labels)))}


def train():
    """Fine-tune an image-classification checkpoint on MNIST and save it.

    Loads the ``ylecun/mnist`` dataset and the
    ``SupremoUGH/image-classification-model`` checkpoint (model and image
    processor). Trains with ``Trainer``, evaluating loss and accuracy on the
    ``test`` split every 500 steps, with intermediate checkpoints written
    under ``./results``. Finally writes the fine-tuned model *and* its image
    processor to ``./saved_model`` so that directory can be reloaded with
    ``from_pretrained``.
    """
    model_id = "SupremoUGH/image-classification-model"

    # Load dataset; rows are read below through their "image" and "label" columns.
    dataset = load_dataset("ylecun/mnist")

    # Load processor and apply preprocessing to the dataset.
    processor = AutoImageProcessor.from_pretrained(model_id)

    def process(examples):
        # Force 3-channel RGB. MNIST digits are grayscale; the checkpoint
        # presumably expects RGB input (TODO confirm against its config).
        images = [img.convert("RGB") for img in examples["image"]]
        inputs = processor(images=images, return_tensors="pt")
        inputs["labels"] = examples["label"]
        return inputs

    # set_transform runs `process` lazily each time rows are read, instead of
    # materialising preprocessed tensors up front as `map` would.
    dataset.set_transform(process)

    # Load model and train it with certain training arguments.
    model = AutoModelForImageClassification.from_pretrained(model_id)
    training_args = TrainingArguments(
        output_dir="./results",
        # Keep the raw "image"/"label" columns. Otherwise Trainer drops them as
        # unused by the model's forward(), leaving `process` nothing to read.
        remove_unused_columns=False,
        per_device_train_batch_size=16,
        eval_strategy="steps",
        num_train_epochs=3,
        fp16=False,  # Train in full fp32 precision (no mixed precision).
        save_steps=500,
        eval_steps=500,
        logging_steps=100,
        learning_rate=2e-4,
        push_to_hub=False,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        compute_metrics=compute_accuracy,  # Report accuracy, not only loss.
    )
    trainer.train()

    # Save fine-tuned model.
    save_dir = "./saved_model"
    os.makedirs(save_dir, exist_ok=True)
    # trainer.save_model writes only from the main process, so multi-process
    # runs do not race on the same files.
    trainer.save_model(save_dir)
    if trainer.is_world_process_zero():
        # Save the processor as well. Without it,
        # AutoImageProcessor.from_pretrained(save_dir) cannot load the result.
        processor.save_pretrained(save_dir)
        print(f"Model saved to {save_dir}")
if __name__ == "__main__":
    # Script entry point: start fine-tuning only when this file is executed
    # directly, never as a side effect of importing it.
    train()