xuyaxiong commited on
Commit
7043cd3
1 Parent(s): 89be3bf

添加:训练脚本

Browse files
Files changed (1) hide show
  1. train.ipynb +100 -0
train.ipynb ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import fastbook\n",
10
+ "fastbook.setup_book()"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": 3,
16
+ "metadata": {},
17
+ "outputs": [],
18
+ "source": [
19
+ "from fastbook import *\n",
20
+ "from fastai.vision.all import *\n",
21
+ "import os"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "execution_count": null,
27
+ "metadata": {},
28
+ "outputs": [],
29
+ "source": [
30
+ "pattern_get_class = re.compile(r'PetImages/(\\w+)/\\d+.jpg')\n",
31
+ "path = '../kagglecatsanddogs_5340/'\n",
32
+ "fnames = get_image_files(path)\n",
33
+ "\n",
34
+ "dls = ImageDataLoaders.from_path_re(\n",
35
+ " path, fnames, pattern_get_class, valid_pct=0.2, seed=42, item_tfms=Resize(224))\n",
36
+ "dls.show_batch()"
37
+ ]
38
+ },
39
+ {
40
+ "cell_type": "code",
41
+ "execution_count": 5,
42
+ "metadata": {},
43
+ "outputs": [],
44
+ "source": [
45
+ "learn = vision_learner(dls, resnet34, metrics=error_rate)\n",
46
+ "learn.fine_tune(1)\n",
47
+ "learn.export('model.pkl')"
48
+ ]
49
+ },
50
+ {
51
+ "cell_type": "code",
52
+ "execution_count": null,
53
+ "metadata": {},
54
+ "outputs": [],
55
+ "source": [
56
+ "path = 'cat.jpg'\n",
57
+ "img = PILImage.create(path)\n",
58
+ "img.to_thumb(192)"
59
+ ]
60
+ },
61
+ {
62
+ "cell_type": "code",
63
+ "execution_count": null,
64
+ "metadata": {},
65
+ "outputs": [],
66
+ "source": [
67
+ "is_cat,_,probs = learn.predict(img)\n",
68
+ "print(f\"Is this a cat?: {is_cat}.\")\n",
69
+ "print(f\"Probability it's a cat: {probs[0].item():.6f}\")"
70
+ ]
71
+ }
72
+ ],
73
+ "metadata": {
74
+ "kernelspec": {
75
+ "display_name": "tesseract",
76
+ "language": "python",
77
+ "name": "python3"
78
+ },
79
+ "language_info": {
80
+ "codemirror_mode": {
81
+ "name": "ipython",
82
+ "version": 3
83
+ },
84
+ "file_extension": ".py",
85
+ "mimetype": "text/x-python",
86
+ "name": "python",
87
+ "nbconvert_exporter": "python",
88
+ "pygments_lexer": "ipython3",
89
+ "version": "3.9.15"
90
+ },
91
+ "orig_nbformat": 4,
92
+ "vscode": {
93
+ "interpreter": {
94
+ "hash": "c8334dffe72b6a881969c3515475442b0cf3f3c8c06d8151aebf952bb4134fbe"
95
+ }
96
+ }
97
+ },
98
+ "nbformat": 4,
99
+ "nbformat_minor": 2
100
+ }