{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/ubuntu/AI_Image_Classification/venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n", "Resolving data files: 100%|██████████| 25/25 [00:00<00:00, 203606.99it/s]\n", "Resolving data files: 100%|██████████| 26/26 [00:00<00:00, 203076.17it/s]\n" ] } ], "source": [ "from datasets import load_dataset\n", "\n", "data = load_dataset(\"dataset\")" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'image': ,\n", " 'label': 0}" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data['train'][0]" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['ai_gen', 'human']" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "labels = data[\"train\"].features[\"label\"].names\n", "labels" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# Class-name <-> id lookups passed to the model config; ids are kept as strings.\n", "label2id = {label: str(i) for i, label in enumerate(labels)}\n", "id2label = {str(i): label for i, label in enumerate(labels)}" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'ai_gen': '0', 'human': '1'}" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "label2id" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "from transformers import AutoImageProcessor\n", "\n", "checkpoint = 
\"google/vit-base-patch16-224-in21k\"\n", "image_processor = AutoImageProcessor.from_pretrained(checkpoint)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "from torchvision.transforms import CenterCrop, Compose, Normalize, RandomResizedCrop, Resize, ToTensor\n", "\n", "normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)\n", "size = (\n", "    image_processor.size[\"shortest_edge\"]\n", "    if \"shortest_edge\" in image_processor.size\n", "    else (image_processor.size[\"height\"], image_processor.size[\"width\"])\n", ")\n", "# Training pipeline: random-crop augmentation.\n", "_transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize])\n", "# Evaluation pipeline: deterministic resize + center crop, so validation/test\n", "# metrics are not perturbed by random crops from run to run.\n", "_eval_transforms = Compose([Resize(size), CenterCrop(size), ToTensor(), normalize])" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "def transforms(examples, image_transform=None):\n", "    \"\"\"Batched `with_transform` hook turning the PIL images in a batch into tensors.\n", "\n", "    Adds a `pixel_values` list and removes the raw `image` column from `examples`.\n", "    `image_transform` defaults to the training pipeline `_transforms`.\n", "    \"\"\"\n", "    if image_transform is None:\n", "        image_transform = _transforms\n", "    examples[\"pixel_values\"] = [image_transform(img.convert(\"RGB\")) for img in examples[\"image\"]]\n", "    del examples[\"image\"]\n", "    return examples\n", "\n", "\n", "def eval_transforms(examples):\n", "    \"\"\"Same as `transforms`, but with the deterministic `_eval_transforms` pipeline.\"\"\"\n", "    return transforms(examples, image_transform=_eval_transforms)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "# Augment the training split only; every other split gets deterministic preprocessing.\n", "data = data.with_transform(transforms)\n", "for split in data.keys() - {\"train\"}:\n", "    data[split] = data[split].with_transform(eval_transforms)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['image', 'label'],\n", " num_rows: 18000\n", " })\n", " validation: Dataset({\n", " features: ['image', 'label'],\n", " num_rows: 20715\n", " })\n", " test: Dataset({\n", " features: ['image', 'label'],\n", " num_rows: 13354\n", " })\n", "})" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "from transformers import DefaultDataCollator\n", "\n", "data_collator = DefaultDataCollator()" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/ubuntu/AI_Image_Classification/venv/lib/python3.10/site-packages/torch/_utils.py:831: 
UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n", " return self.fget.__get__(instance, owner)()\n" ] } ], "source": [ "from transformers import AutoModelForImageClassification, Trainer, TrainingArguments\n", "\n", "model = AutoModelForImageClassification.from_pretrained(\n", "    \"umm-maybe/AI-image-detector\",\n", "    num_labels=len(labels),\n", "    id2label=id2label,\n", "    label2id=label2id,\n", ")" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "import evaluate\n", "import numpy as np\n", "\n", "accuracy = evaluate.load(\"accuracy\")\n", "\n", "\n", "def compute_metrics(eval_pred):\n", "    \"\"\"`Trainer` metric hook: accuracy of the arg-max over axis 1 of the logits.\"\"\"\n", "    logits, references = eval_pred\n", "    predicted_ids = np.argmax(logits, axis=1)\n", "    return accuracy.compute(predictions=predicted_ids, references=references)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " [843/843 43:11, Epoch 2/3]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
EpochTraining LossValidation LossAccuracy
00.0347000.0134690.996480
10.0009000.0211300.994234
20.0241000.0107350.997529

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Some non-default generation parameters are set in the model config. These should go into a GenerationConfig file (https://huggingface.co/docs/transformers/generation_strategies#save-a-custom-decoding-strategy-with-your-model) instead. This warning will be raised to an exception in v4.41.\n", "Non-default generation parameters: {'max_length': 128}\n", "Some non-default generation parameters are set in the model config. These should go into a GenerationConfig file (https://huggingface.co/docs/transformers/generation_strategies#save-a-custom-decoding-strategy-with-your-model) instead. This warning will be raised to an exception in v4.41.\n", "Non-default generation parameters: {'max_length': 128}\n", "Some non-default generation parameters are set in the model config. These should go into a GenerationConfig file (https://huggingface.co/docs/transformers/generation_strategies#save-a-custom-decoding-strategy-with-your-model) instead. 
This warning will be raised to an exception in v4.41.\n", "Non-default generation parameters: {'max_length': 128}\n" ] }, { "data": { "text/plain": [ "TrainOutput(global_step=843, training_loss=0.034790725539037115, metrics={'train_runtime': 2594.7053, 'train_samples_per_second': 20.812, 'train_steps_per_second': 0.325, 'total_flos': 4.2268994172435825e+18, 'train_loss': 0.034790725539037115, 'epoch': 3.0})" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# `load_best_model_at_end` + `metric_for_best_model` choose the checkpoint by its score on\n", "# `eval_dataset`, so that must be the validation split. Selecting on the test split would\n", "# leak it into training decisions and make the reported test accuracy optimistic.\n", "training_args = TrainingArguments(\n", "    output_dir=\"ai_detector\",\n", "    # Keep the raw `image` column (not a model input) so `transforms` can read it.\n", "    remove_unused_columns=False,\n", "    evaluation_strategy=\"epoch\",\n", "    save_strategy=\"epoch\",\n", "    learning_rate=5e-5,\n", "    per_device_train_batch_size=16,\n", "    gradient_accumulation_steps=4,\n", "    per_device_eval_batch_size=16,\n", "    num_train_epochs=3,\n", "    warmup_ratio=0.1,\n", "    logging_steps=10,\n", "    load_best_model_at_end=True,\n", "    metric_for_best_model=\"accuracy\",\n", "    # push_to_hub=True,\n", ")\n", "\n", "trainer = Trainer(\n", "    model=model,\n", "    args=training_args,\n", "    data_collator=data_collator,\n", "    train_dataset=data[\"train\"],\n", "    eval_dataset=data[\"validation\"],\n", "    tokenizer=image_processor,\n", "    compute_metrics=compute_metrics,\n", ")\n", "\n", "trainer.train()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Score the selected model once on the held-out test split for an unbiased estimate.\n", "test_metrics = trainer.evaluate(eval_dataset=data[\"test\"], metric_key_prefix=\"test\")\n", "test_metrics" ] } ], "metadata": { "kernelspec": { "display_name": "venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 2 }