{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Image classification with fastai\n",
    "\n",
    "Fine-tune a pretrained ResNet-18 on the images under `dataset/train` (labels from `dataset/train.csv`), then classify one image from `dataset/test`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "import pandas as pd\n",
    "# fastai's documented usage is a star import. fastai.vision.all (not fastai.data.all)\n",
    "# is required: it provides ImageBlock, Resize, PILImage, vision_learner, resnet18, error_rate.\n",
    "from fastai.vision.all import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Every data path below is derived from DATA_DIR so all cells agree on where the data lives.\n",
    "# Adjust it if the notebook is run from a different working directory.\n",
    "DATA_DIR = Path('dataset')\n",
    "TRAIN_DIR = DATA_DIR / 'train'\n",
    "TEST_DIR = DATA_DIR / 'test'\n",
    "SEED = 42\n",
    "\n",
    "# Seed Python, NumPy and PyTorch so the train/valid split and training are repeatable.\n",
    "set_seed(SEED, reproducible=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the local annotation file: it maps each image's `image_ID` to its `label`.\n",
    "train_csv = pd.read_csv(DATA_DIR / 'train.csv')\n",
    "n_classes = train_csv['label'].nunique()\n",
    "print(f'{len(train_csv)} rows, {n_classes} classes')\n",
    "train_csv.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# image_ID -> label lookup, built once. Filtering the whole DataFrame with a boolean mask\n",
    "# for every image is O(n) per call and O(n^2) over the dataset.\n",
    "# drop_duplicates keeps the first row per image_ID, matching a first-match lookup.\n",
    "label_by_id = (\n",
    "    train_csv\n",
    "    .drop_duplicates('image_ID')\n",
    "    .set_index('image_ID')['label']\n",
    "    .to_dict()\n",
    ")\n",
    "\n",
    "\n",
    "def label_func(item):\n",
    "    \"\"\"Return the label of image file `item`, a Path located under TRAIN_DIR.\n",
    "\n",
    "    The lookup key is the path relative to TRAIN_DIR with '/' separators\n",
    "    (assumes image_ID in train.csv uses '/' as well), so it does not depend\n",
    "    on the OS path separator.\n",
    "    Raises KeyError if the image has no row in train.csv.\n",
    "    \"\"\"\n",
    "    rel_path = item.relative_to(TRAIN_DIR).as_posix()\n",
    "    return label_by_id[rel_path]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build the DataLoaders\n",
    "\n",
    "Images are collected from `TRAIN_DIR`, labelled via `label_func`, split 80/20 into train/valid, and resized to 224 px."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dblock = DataBlock(\n",
    "    blocks=(ImageBlock, CategoryBlock),\n",
    "    get_items=get_image_files,\n",
    "    get_y=label_func,\n",
    "    splitter=RandomSplitter(valid_pct=0.2, seed=SEED),\n",
    "    item_tfms=Resize(224),\n",
    ")\n",
    "dls = dblock.dataloaders(TRAIN_DIR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dls.show_batch(max_n=6)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = vision_learner(dls, resnet18, metrics=error_rate)\n",
    "learn.fine_tune(3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Predict a sample test image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SAMPLE_TEST_IMAGE = TEST_DIR / '0b84e400d4.jpg'\n",
    "\n",
    "test_image = PILImage.create(SAMPLE_TEST_IMAGE)\n",
    "# learn.predict returns (decoded label, class index, per-class probabilities).\n",
    "sport, sport_idx, probs = learn.predict(test_image)\n",
    "print(f\"This is {sport} (p={probs[sport_idx].item():.3f})\")"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}