DKatheesrupan committed on
Commit 79562ec · verified · 1 Parent(s): 9ebc59b

Upload 4 files

Files changed (4)
  1. app.py +125 -0
  2. evaluate_clip_openai.ipynb +694 -0
  3. requirements.txt +11 -0
  4. train_cat_vit.ipynb +842 -0
app.py ADDED
@@ -0,0 +1,125 @@
+ import os
+ from pathlib import Path
+
+ import gradio as gr
+ from transformers import pipeline
+
+
+ # ----------------------------
+ # Paths
+ # ----------------------------
+
+ BASE_DIR = Path(__file__).resolve().parent
+
+ # Adjust the model folder here if necessary
+ MODEL_PATH = BASE_DIR.parent / "cat-vit"
+
+ EXAMPLE_DIR = BASE_DIR / "example_images"
+
+
+ # ----------------------------
+ # Labels
+ # ----------------------------
+
+ CAT_LABELS = ["cheetah", "leopard", "lion", "puma", "tiger"]
+
+
+ # ----------------------------
+ # Load models
+ # ----------------------------
+
+ print("Loading custom model...")
+ vit_classifier = pipeline(
+     "image-classification",
+     model=str(MODEL_PATH)
+ )
+
+ print("Loading CLIP model...")
+ clip_classifier = pipeline(
+     task="zero-shot-image-classification",
+     model="openai/clip-vit-base-patch32"
+ )
+
+
+ # ----------------------------
+ # Helper functions
+ # ----------------------------
+
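+ # Note: if the fine-tuned checkpoint's config lacks an id2label map, the
+ # pipeline returns generic ids (LABEL_0 ... LABEL_4); map them back to names.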
+ def normalize_custom_labels(results):
+     id2label = {
+         "LABEL_0": "cheetah",
+         "LABEL_1": "leopard",
+         "LABEL_2": "lion",
+         "LABEL_3": "puma",
+         "LABEL_4": "tiger",
+     }
+
+     output = {}
+
+     for r in results:
+         label = r["label"]
+         score = float(r["score"])
+
+         if label in id2label:
+             label = id2label[label]
+         else:
+             label = label.lower()
+
+         output[label] = score
+
+     return output
+
+
+ # ----------------------------
+ # Main function
+ # ----------------------------
+
+ def classify_cat(image):
+     # Custom Model
+     vit_results = vit_classifier(image)
+     vit_output = normalize_custom_labels(vit_results)
+
+     # CLIP
+     clip_labels = [f"a photo of a {label}" for label in CAT_LABELS]
+     clip_results = clip_classifier(image, candidate_labels=clip_labels)
+
+     clip_output = {}
+     for r in clip_results:
+         label = r["label"].replace("a photo of a ", "").lower()
+         score = float(r["score"])
+         clip_output[label] = score
+
+     return vit_output, clip_output
+
+
+ # ----------------------------
+ # Example images
+ # ----------------------------
+
+ example_images = [
+     [str(EXAMPLE_DIR / "Cheetah_032.jpg")],
+     [str(EXAMPLE_DIR / "Leopard_001.jpg")],
+     [str(EXAMPLE_DIR / "Lion_003.jpg")],
+     [str(EXAMPLE_DIR / "Puma_001.jpg")],
+     [str(EXAMPLE_DIR / "Tiger_001.jpg")]
+ ]
+
+
+ # ----------------------------
+ # Interface
+ # ----------------------------
+
+ iface = gr.Interface(
+     fn=classify_cat,
+     inputs=gr.Image(type="filepath"),
+     outputs=[
+         gr.Label(label="Custom Model"),
+         gr.Label(label="CLIP")
+     ],
+     title="Big Cat Classification",
+     description="Compare Custom Model vs CLIP",
+     examples=example_images
+ )
+
+ if __name__ == "__main__":
+     iface.launch()
evaluate_clip_openai.ipynb ADDED
@@ -0,0 +1,694 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 60,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Requirement already satisfied: transformers in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (5.5.0)\n",
+ "Requirement already satisfied: torch in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (2.11.0)\n",
+ "Requirement already satisfied: pillow in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (12.1.1)\n",
+ "Requirement already satisfied: openai in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (2.30.0)\n",
+ "Requirement already satisfied: huggingface-hub<2.0,>=1.5.0 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from transformers) (1.6.0)\n",
+ "Requirement already satisfied: numpy>=1.17 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from transformers) (2.4.2)\n",
+ "Requirement already satisfied: packaging>=20.0 in c:\\users\\kathe\\appdata\\roaming\\python\\python314\\site-packages (from transformers) (26.0)\n",
+ "Requirement already satisfied: pyyaml>=5.1 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from transformers) (6.0.3)\n",
+ "Requirement already satisfied: regex>=2025.10.22 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from transformers) (2026.4.4)\n",
+ "Requirement already satisfied: tokenizers<=0.23.0,>=0.22.0 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from transformers) (0.22.2)\n",
+ "Requirement already satisfied: typer in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from transformers) (0.24.1)\n",
+ "Requirement already satisfied: safetensors>=0.4.3 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from transformers) (0.7.0)\n",
+ "Requirement already satisfied: tqdm>=4.27 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from transformers) (4.67.3)\n",
+ "Requirement already satisfied: filelock>=3.10.0 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from huggingface-hub<2.0,>=1.5.0->transformers) (3.25.0)\n",
+ "Requirement already satisfied: fsspec>=2023.5.0 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from huggingface-hub<2.0,>=1.5.0->transformers) (2026.2.0)\n",
+ "Requirement already satisfied: hf-xet<2.0.0,>=1.3.2 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from huggingface-hub<2.0,>=1.5.0->transformers) (1.3.2)\n",
+ "Requirement already satisfied: httpx<1,>=0.23.0 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from huggingface-hub<2.0,>=1.5.0->transformers) (0.28.1)\n",
+ "Requirement already satisfied: typing-extensions>=4.1.0 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from huggingface-hub<2.0,>=1.5.0->transformers) (4.15.0)\n",
+ "Requirement already satisfied: anyio in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from httpx<1,>=0.23.0->huggingface-hub<2.0,>=1.5.0->transformers) (4.12.1)\n",
+ "Requirement already satisfied: certifi in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from httpx<1,>=0.23.0->huggingface-hub<2.0,>=1.5.0->transformers) (2026.2.25)\n",
+ "Requirement already satisfied: httpcore==1.* in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from httpx<1,>=0.23.0->huggingface-hub<2.0,>=1.5.0->transformers) (1.0.9)\n",
+ "Requirement already satisfied: idna in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from httpx<1,>=0.23.0->huggingface-hub<2.0,>=1.5.0->transformers) (3.11)\n",
+ "Requirement already satisfied: h11>=0.16 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from httpcore==1.*->httpx<1,>=0.23.0->huggingface-hub<2.0,>=1.5.0->transformers) (0.16.0)\n",
+ "Requirement already satisfied: setuptools<82 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from torch) (81.0.0)\n",
+ "Requirement already satisfied: sympy>=1.13.3 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from torch) (1.14.0)\n",
+ "Requirement already satisfied: networkx>=2.5.1 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from torch) (3.6.1)\n",
+ "Requirement already satisfied: jinja2 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from torch) (3.1.6)\n",
+ "Requirement already satisfied: distro<2,>=1.7.0 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from openai) (1.9.0)\n",
+ "Requirement already satisfied: jiter<1,>=0.10.0 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from openai) (0.13.0)\n",
+ "Requirement already satisfied: pydantic<3,>=1.9.0 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from openai) (2.12.5)\n",
+ "Requirement already satisfied: sniffio in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from openai) (1.3.1)\n",
+ "Requirement already satisfied: annotated-types>=0.6.0 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from pydantic<3,>=1.9.0->openai) (0.7.0)\n",
+ "Requirement already satisfied: pydantic-core==2.41.5 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from pydantic<3,>=1.9.0->openai) (2.41.5)\n",
+ "Requirement already satisfied: typing-inspection>=0.4.2 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from pydantic<3,>=1.9.0->openai) (0.4.2)\n",
+ "Requirement already satisfied: mpmath<1.4,>=1.1.0 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from sympy>=1.13.3->torch) (1.3.0)\n",
+ "Requirement already satisfied: colorama in c:\\users\\kathe\\appdata\\roaming\\python\\python314\\site-packages (from tqdm>=4.27->transformers) (0.4.6)\n",
+ "Requirement already satisfied: MarkupSafe>=2.0 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from jinja2->torch) (3.0.3)\n",
+ "Requirement already satisfied: click>=8.2.1 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from typer->transformers) (8.3.1)\n",
+ "Requirement already satisfied: shellingham>=1.3.0 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from typer->transformers) (1.5.4)\n",
+ "Requirement already satisfied: rich>=12.3.0 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from typer->transformers) (14.3.3)\n",
+ "Requirement already satisfied: annotated-doc>=0.0.2 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from typer->transformers) (0.0.4)\n",
+ "Requirement already satisfied: markdown-it-py>=2.2.0 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from rich>=12.3.0->typer->transformers) (4.0.0)\n",
+ "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in c:\\users\\kathe\\appdata\\roaming\\python\\python314\\site-packages (from rich>=12.3.0->typer->transformers) (2.19.2)\n",
+ "Requirement already satisfied: mdurl~=0.1 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from markdown-it-py>=2.2.0->rich>=12.3.0->typer->transformers) (0.1.2)\n",
+ "Note: you may need to restart the kernel to use updated packages.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "[notice] A new release of pip is available: 25.3 -> 26.0.1\n",
+ "[notice] To update, run: python.exe -m pip install --upgrade pip\n"
+ ]
+ }
+ ],
+ "source": [
+ "%pip install transformers torch pillow openai\n",
+ "from transformers import pipeline\n",
+ "from PIL import Image\n",
+ "import os\n",
+ "import pandas as pd"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 63,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Model path exists: True\n",
+ "Image folder exists: True\n",
+ "Images: ['Cheetah_032.jpg', 'Leopard_001.jpg', 'Lion_003.jpg', 'Puma_001.jpg', 'Tiger_001.jpg']\n"
+ ]
+ }
+ ],
+ "source": [
+ "MODEL_PATH = \"./cat-vit\"\n",
+ "IMAGE_FOLDER = \"./Cats-classification-app/example_images\"\n",
+ "\n",
+ "labels = [\"cheetah\", \"leopard\", \"lion\", \"puma\", \"tiger\"]\n",
+ "clip_labels = [f\"a photo of a {label}\" for label in labels]\n",
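+ "# CLIP zero-shot classification tends to score better with natural-language prompts\n",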
+ "\n",
+ "print(\"Model path exists:\", os.path.exists(MODEL_PATH))\n",
+ "print(\"Image folder exists:\", os.path.exists(IMAGE_FOLDER))\n",
+ "print(\"Images:\", [f for f in os.listdir(IMAGE_FOLDER) if f.lower().endswith((\".jpg\", \".jpeg\", \".png\"))])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 64,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "1bf87a05bbc346c9b3f30eb950c1f3a5",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Loading weights: 0%| | 0/200 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "fad1554b05bf40d7b31480f8daa8ad35",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Loading weights: 0%| | 0/398 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\u001b[1mCLIPModel LOAD REPORT\u001b[0m from: openai/clip-vit-base-patch32\n",
+ "Key | Status | | \n",
+ "-------------------------------------+------------+--+-\n",
+ "text_model.embeddings.position_ids | UNEXPECTED | | \n",
+ "vision_model.embeddings.position_ids | UNEXPECTED | | \n",
+ "\n",
+ "Notes:\n",
+ "- UNEXPECTED:\tcan be ignored when loading from different task/architecture; not ok if you expect identical arch.\n"
+ ]
+ }
+ ],
+ "source": [
+ "custom_model = pipeline(\"image-classification\", model=MODEL_PATH)\n",
+ "\n",
+ "clip_model = pipeline(\n",
+ "    \"zero-shot-image-classification\",\n",
+ "    model=\"openai/clip-vit-base-patch32\"\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 65,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "bc19664791384aceb2502dfe76b5dd1d",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Loading weights: 0%| | 0/398 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\u001b[1mCLIPModel LOAD REPORT\u001b[0m from: openai/clip-vit-base-patch32\n",
+ "Key | Status | | \n",
+ "-------------------------------------+------------+--+-\n",
+ "text_model.embeddings.position_ids | UNEXPECTED | | \n",
+ "vision_model.embeddings.position_ids | UNEXPECTED | | \n",
+ "\n",
+ "Notes:\n",
+ "- UNEXPECTED:\tcan be ignored when loading from different task/architecture; not ok if you expect identical arch.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CLIP model loaded!\n"
+ ]
+ }
+ ],
+ "source": [
+ "clip_model = pipeline(\n",
+ "    \"zero-shot-image-classification\",\n",
+ "    model=\"openai/clip-vit-base-patch32\"\n",
+ ")\n",
+ "\n",
+ "print(\"CLIP model loaded!\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 66,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def get_true_label(filename):\n",
+ "    name = filename.lower()\n",
+ "\n",
+ "    if name.startswith(\"cheetah\"):\n",
+ "        return \"cheetah\"\n",
+ "    elif name.startswith(\"leopard\"):\n",
+ "        return \"leopard\"\n",
+ "    elif name.startswith(\"lion\"):\n",
+ "        return \"lion\"\n",
+ "    elif name.startswith(\"puma\"):\n",
+ "        return \"puma\"\n",
+ "    elif name.startswith(\"tiger\"):\n",
+ "        return \"tiger\"\n",
+ "    else:\n",
+ "        return \"unknown\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 67,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Found images: ['Cheetah_032.jpg', 'Leopard_001.jpg', 'Lion_003.jpg', 'Puma_001.jpg', 'Tiger_001.jpg']\n",
+ "results length: 5\n",
+ " image true_label custom_pred custom_score clip_pred clip_score \\\n",
+ "0 Cheetah_032.jpg cheetah cheetah 0.5264 cheetah 0.8319 \n",
+ "1 Leopard_001.jpg leopard leopard 0.5127 leopard 0.9232 \n",
+ "2 Lion_003.jpg lion lion 0.5408 lion 0.9949 \n",
+ "3 Puma_001.jpg puma puma 0.6112 puma 0.9986 \n",
+ "4 Tiger_001.jpg tiger tiger 0.6976 tiger 0.9892 \n",
+ "\n",
+ " custom_correct clip_correct \n",
+ "0 True True \n",
+ "1 True True \n",
+ "2 True True \n",
+ "3 True True \n",
+ "4 True True \n",
+ "columns: ['image', 'true_label', 'custom_pred', 'custom_score', 'clip_pred', 'clip_score', 'custom_correct', 'clip_correct']\n"
+ ]
+ }
+ ],
+ "source": [
+ "results = []\n",
+ "\n",
+ "id2label = {\n",
+ "    0: \"cheetah\",\n",
+ "    1: \"leopard\",\n",
+ "    2: \"lion\",\n",
+ "    3: \"puma\",\n",
+ "    4: \"tiger\"\n",
+ "}\n",
+ "\n",
+ "image_files = sorted([\n",
+ "    f for f in os.listdir(IMAGE_FOLDER)\n",
+ "    if f.lower().endswith((\".jpg\", \".jpeg\", \".png\"))\n",
+ "])\n",
+ "\n",
+ "print(\"Found images:\", image_files)\n",
+ "\n",
+ "for img_file in image_files:\n",
+ "    image_path = os.path.join(IMAGE_FOLDER, img_file)\n",
+ "    image = Image.open(image_path).convert(\"RGB\")\n",
+ "    true_label = get_true_label(img_file)\n",
+ "\n",
+ "    custom_result = custom_model(image)[0]\n",
+ "    raw_custom_label = custom_result[\"label\"]\n",
+ "    custom_score = float(custom_result[\"score\"])\n",
+ "\n",
+ "    if raw_custom_label.startswith(\"LABEL_\"):\n",
+ "        label_id = int(raw_custom_label.split(\"_\")[1])\n",
+ "        custom_pred = id2label[label_id]\n",
+ "    else:\n",
+ "        custom_pred = raw_custom_label.lower()\n",
+ "\n",
+ "    clip_result = clip_model(image, candidate_labels=clip_labels)[0]\n",
+ "    clip_pred = clip_result[\"label\"].replace(\"a photo of a \", \"\").lower()\n",
+ "    clip_score = float(clip_result[\"score\"])\n",
+ "\n",
+ "    results.append({\n",
+ "        \"image\": img_file,\n",
+ "        \"true_label\": true_label,\n",
+ "        \"custom_pred\": custom_pred,\n",
+ "        \"custom_score\": round(custom_score, 4),\n",
+ "        \"clip_pred\": clip_pred,\n",
+ "        \"clip_score\": round(clip_score, 4),\n",
+ "        \"custom_correct\": custom_pred == true_label,\n",
+ "        \"clip_correct\": clip_pred == true_label,\n",
+ "    })\n",
+ "\n",
+ "print(\"results length:\", len(results))\n",
+ "\n",
+ "df = pd.DataFrame(results)\n",
+ "print(df)\n",
+ "print(\"columns:\", df.columns.tolist())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 68,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Custom accuracy: 1.0\n",
+ "CLIP accuracy: 1.0\n"
+ ]
+ }
+ ],
+ "source": [
+ "custom_accuracy = df[\"custom_correct\"].mean()\n",
+ "clip_accuracy = df[\"clip_correct\"].mean()\n",
+ "\n",
+ "print(\"Custom accuracy:\", round(custom_accuracy, 4))\n",
+ "print(\"CLIP accuracy:\", round(clip_accuracy, 4))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 69,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Saved to comparison_results.csv\n"
+ ]
+ }
+ ],
+ "source": [
+ "df.to_csv(\"comparison_results.csv\", index=False)\n",
+ "print(\"Saved to comparison_results.csv\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 70,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "MODEL_PATH = \"./cat-vit\"\n",
+ "IMAGE_FOLDER = \"./Cats-classification-app/example_images\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 71,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import base64\n",
+ "import os\n",
+ "\n",
+ "from openai import OpenAI\n",
+ "\n",
+ "# Never commit a real API key; use a placeholder or set it in the environment\n",
+ "os.environ[\"OPENAI_API_KEY\"] = \"YOUR_OPENAI_API_KEY\"\n",
+ "\n",
+ "client = OpenAI()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 72,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def predict_openai_label(image_path):\n",
+ "    with open(image_path, \"rb\") as image_file:\n",
+ "        image_base64 = base64.b64encode(image_file.read()).decode(\"utf-8\")\n",
+ "\n",
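+ "    # Send the image inline as a base64 data URL (Responses API input_image)\n",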
+ "    response = client.responses.create(\n",
+ "        model=\"gpt-4.1-mini\",\n",
+ "        input=[\n",
+ "            {\n",
+ "                \"role\": \"user\",\n",
+ "                \"content\": [\n",
+ "                    {\n",
+ "                        \"type\": \"input_text\",\n",
+ "                        \"text\": \"Classify this image as exactly one of these labels: cheetah, leopard, lion, puma, tiger. Return only one label in lowercase.\"\n",
+ "                    },\n",
+ "                    {\n",
+ "                        \"type\": \"input_image\",\n",
+ "                        \"image_url\": f\"data:image/jpeg;base64,{image_base64}\"\n",
+ "                    }\n",
+ "                ]\n",
+ "            }\n",
+ "        ]\n",
+ "    )\n",
+ "\n",
+ "    return response.output_text.strip().lower()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 73,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<div>\n",
+ "<style scoped>\n",
+ " .dataframe tbody tr th:only-of-type {\n",
+ " vertical-align: middle;\n",
+ " }\n",
+ "\n",
+ " .dataframe tbody tr th {\n",
+ " vertical-align: top;\n",
+ " }\n",
+ "\n",
+ " .dataframe thead th {\n",
+ " text-align: right;\n",
+ " }\n",
+ "</style>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ " <thead>\n",
+ " <tr style=\"text-align: right;\">\n",
+ " <th></th>\n",
+ " <th>image</th>\n",
+ " <th>true_label</th>\n",
+ " <th>custom_pred</th>\n",
+ " <th>custom_score</th>\n",
+ " <th>clip_pred</th>\n",
+ " <th>clip_score</th>\n",
+ " <th>openai_pred</th>\n",
+ " <th>custom_correct</th>\n",
+ " <th>clip_correct</th>\n",
+ " <th>openai_correct</th>\n",
+ " </tr>\n",
+ " </thead>\n",
+ " <tbody>\n",
+ " <tr>\n",
+ " <th>0</th>\n",
+ " <td>Cheetah_032.jpg</td>\n",
+ " <td>cheetah</td>\n",
+ " <td>cheetah</td>\n",
+ " <td>0.5264</td>\n",
+ " <td>cheetah</td>\n",
+ " <td>0.8319</td>\n",
+ " <td>ERROR: name 'base64' is not defined</td>\n",
+ " <td>True</td>\n",
+ " <td>True</td>\n",
+ " <td>False</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>1</th>\n",
+ " <td>Leopard_001.jpg</td>\n",
+ " <td>leopard</td>\n",
+ " <td>leopard</td>\n",
+ " <td>0.5127</td>\n",
+ " <td>leopard</td>\n",
+ " <td>0.9232</td>\n",
+ " <td>ERROR: name 'base64' is not defined</td>\n",
+ " <td>True</td>\n",
+ " <td>True</td>\n",
+ " <td>False</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>2</th>\n",
+ " <td>Lion_003.jpg</td>\n",
+ " <td>lion</td>\n",
+ " <td>lion</td>\n",
+ " <td>0.5408</td>\n",
+ " <td>lion</td>\n",
+ " <td>0.9949</td>\n",
+ " <td>ERROR: name 'base64' is not defined</td>\n",
+ " <td>True</td>\n",
+ " <td>True</td>\n",
+ " <td>False</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>3</th>\n",
+ " <td>Puma_001.jpg</td>\n",
+ " <td>puma</td>\n",
+ " <td>puma</td>\n",
+ " <td>0.6112</td>\n",
+ " <td>puma</td>\n",
+ " <td>0.9986</td>\n",
+ " <td>ERROR: name 'base64' is not defined</td>\n",
+ " <td>True</td>\n",
+ " <td>True</td>\n",
+ " <td>False</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>4</th>\n",
+ " <td>Tiger_001.jpg</td>\n",
+ " <td>tiger</td>\n",
+ " <td>tiger</td>\n",
+ " <td>0.6976</td>\n",
+ " <td>tiger</td>\n",
+ " <td>0.9892</td>\n",
+ " <td>ERROR: name 'base64' is not defined</td>\n",
+ " <td>True</td>\n",
+ " <td>True</td>\n",
+ " <td>False</td>\n",
+ " </tr>\n",
+ " </tbody>\n",
+ "</table>\n",
+ "</div>"
+ ],
+ "text/plain": [
+ " image true_label custom_pred custom_score clip_pred clip_score \\\n",
+ "0 Cheetah_032.jpg cheetah cheetah 0.5264 cheetah 0.8319 \n",
+ "1 Leopard_001.jpg leopard leopard 0.5127 leopard 0.9232 \n",
+ "2 Lion_003.jpg lion lion 0.5408 lion 0.9949 \n",
+ "3 Puma_001.jpg puma puma 0.6112 puma 0.9986 \n",
+ "4 Tiger_001.jpg tiger tiger 0.6976 tiger 0.9892 \n",
+ "\n",
+ " openai_pred custom_correct clip_correct \\\n",
+ "0 ERROR: name 'base64' is not defined True True \n",
+ "1 ERROR: name 'base64' is not defined True True \n",
+ "2 ERROR: name 'base64' is not defined True True \n",
+ "3 ERROR: name 'base64' is not defined True True \n",
+ "4 ERROR: name 'base64' is not defined True True \n",
+ "\n",
+ " openai_correct \n",
+ "0 False \n",
+ "1 False \n",
+ "2 False \n",
+ "3 False \n",
+ "4 False "
+ ]
+ },
+ "execution_count": 73,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "results = []\n",
+ "\n",
+ "image_files = sorted([\n",
+ "    f for f in os.listdir(IMAGE_FOLDER)\n",
+ "    if f.lower().endswith((\".jpg\", \".jpeg\", \".png\"))\n",
+ "])\n",
+ "\n",
+ "for img_file in image_files:\n",
+ "    image_path = os.path.join(IMAGE_FOLDER, img_file)\n",
+ "    image = Image.open(image_path).convert(\"RGB\")\n",
+ "    true_label = get_true_label(img_file)\n",
+ "\n",
+ "    # Custom model\n",
+ "    custom_result = custom_model(image)[0]\n",
+ "    custom_pred = custom_result[\"label\"].lower()\n",
+ "    custom_score = float(custom_result[\"score\"])\n",
+ "\n",
+ "    # CLIP model\n",
+ "    clip_result = clip_model(image, candidate_labels=clip_labels)[0]\n",
+ "    clip_pred = clip_result[\"label\"].replace(\"a photo of a \", \"\").lower()\n",
+ "    clip_score = float(clip_result[\"score\"])\n",
+ "\n",
+ "    # OpenAI model\n",
+ "    try:\n",
+ "        openai_pred = predict_openai_label(image_path)\n",
+ "        openai_correct = openai_pred == true_label\n",
+ "    except Exception as e:\n",
+ "        openai_pred = f\"ERROR: {e}\"\n",
+ "        openai_correct = False\n",
+ "\n",
+ "    results.append({\n",
+ "        \"image\": img_file,\n",
+ "        \"true_label\": true_label,\n",
+ "        \"custom_pred\": custom_pred,\n",
+ "        \"custom_score\": round(custom_score, 4),\n",
+ "        \"clip_pred\": clip_pred,\n",
+ "        \"clip_score\": round(clip_score, 4),\n",
+ "        \"openai_pred\": openai_pred,\n",
+ "        \"custom_correct\": custom_pred == true_label,\n",
+ "        \"clip_correct\": clip_pred == true_label,\n",
+ "        \"openai_correct\": openai_correct,\n",
+ "    })\n",
+ "\n",
+ "df = pd.DataFrame(results)\n",
+ "df"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 74,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Custom accuracy: 1.0\n",
+ "CLIP accuracy: 1.0\n",
+ "OpenAI accuracy: 0.0\n"
+ ]
+ }
+ ],
+ "source": [
+ "custom_accuracy = df[\"custom_correct\"].mean()\n",
+ "clip_accuracy = df[\"clip_correct\"].mean()\n",
+ "openai_accuracy = df[\"openai_correct\"].mean()\n",
+ "\n",
+ "print(\"Custom accuracy:\", round(custom_accuracy, 4))\n",
+ "print(\"CLIP accuracy:\", round(clip_accuracy, 4))\n",
+ "print(\"OpenAI accuracy:\", round(openai_accuracy, 4))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 75,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Saved to ../comparison_results_with_openai.csv\n"
+ ]
+ }
+ ],
+ "source": [
+ "df.to_csv(\"../comparison_results_with_openai.csv\", index=False)\n",
+ "print(\"Saved to ../comparison_results_with_openai.csv\")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.14.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+ }
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ transformers
+ torch
+ torchvision
+ datasets
+ evaluate
+ accelerate
+ scikit-learn
+ pillow
+ gradio
+ openai
+ huggingface_hub
train_cat_vit.ipynb ADDED
@@ -0,0 +1,842 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "a0c0c143",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "[notice] A new release of pip is available: 25.3 -> 26.0.1\n",
+ "[notice] To update, run: python.exe -m pip install --upgrade pip\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Requirement already satisfied: matplotlib in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (3.10.8)\n",
+ "Requirement already satisfied: ipywidgets in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (8.1.8)\n",
+ "Requirement already satisfied: contourpy>=1.0.1 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from matplotlib) (1.3.3)\n",
+ "Requirement already satisfied: cycler>=0.10 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from matplotlib) (0.12.1)\n",
+ "Requirement already satisfied: fonttools>=4.22.0 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from matplotlib) (4.62.1)\n",
+ "Requirement already satisfied: kiwisolver>=1.3.1 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from matplotlib) (1.5.0)\n",
+ "Requirement already satisfied: numpy>=1.23 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from matplotlib) (2.4.2)\n",
+ "Requirement already satisfied: packaging>=20.0 in c:\\users\\kathe\\appdata\\roaming\\python\\python314\\site-packages (from matplotlib) (26.0)\n",
+ "Requirement already satisfied: pillow>=8 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from matplotlib) (12.1.1)\n",
+ "Requirement already satisfied: pyparsing>=3 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from matplotlib) (3.3.2)\n",
+ "Requirement already satisfied: python-dateutil>=2.7 in c:\\users\\kathe\\appdata\\roaming\\python\\python314\\site-packages (from matplotlib) (2.9.0.post0)\n",
+ "Requirement already satisfied: comm>=0.1.3 in c:\\users\\kathe\\appdata\\roaming\\python\\python314\\site-packages (from ipywidgets) (0.2.3)\n",
+ "Requirement already satisfied: ipython>=6.1.0 in c:\\users\\kathe\\appdata\\roaming\\python\\python314\\site-packages (from ipywidgets) (9.11.0)\n",
+ "Requirement already satisfied: traitlets>=4.3.1 in c:\\users\\kathe\\appdata\\roaming\\python\\python314\\site-packages (from ipywidgets) (5.14.3)\n",
+ "Requirement already satisfied: widgetsnbextension~=4.0.14 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from ipywidgets) (4.0.15)\n",
+ "Requirement already satisfied: jupyterlab_widgets~=3.0.15 in c:\\users\\kathe\\appdata\\local\\python\\pythoncore-3.14-64\\lib\\site-packages (from ipywidgets) (3.0.16)\n",
+ "Requirement already satisfied: colorama>=0.4.4 in c:\\users\\kathe\\appdata\\roaming\\python\\python314\\site-packages (from ipython>=6.1.0->ipywidgets) (0.4.6)\n",
+ "Requirement already satisfied: decorator>=5.1.0 in c:\\users\\kathe\\appdata\\roaming\\python\\python314\\site-packages (from ipython>=6.1.0->ipywidgets) (5.2.1)\n",
+ "Requirement already satisfied: ipython-pygments-lexers>=1.0.0 in c:\\users\\kathe\\appdata\\roaming\\python\\python314\\site-packages (from ipython>=6.1.0->ipywidgets) (1.1.1)\n",
+ "Requirement already satisfied: jedi>=0.18.2 in c:\\users\\kathe\\appdata\\roaming\\python\\python314\\site-packages (from ipython>=6.1.0->ipywidgets) (0.19.2)\n",
+ "Requirement already satisfied: matplotlib-inline>=0.1.6 in c:\\users\\kathe\\appdata\\roaming\\python\\python314\\site-packages (from ipython>=6.1.0->ipywidgets) (0.2.1)\n",
+ "Requirement already satisfied: prompt_toolkit<3.1.0,>=3.0.41 in c:\\users\\kathe\\appdata\\roaming\\python\\python314\\site-packages (from ipython>=6.1.0->ipywidgets) (3.0.52)\n",
+ "Requirement already satisfied: pygments>=2.14.0 in c:\\users\\kathe\\appdata\\roaming\\python\\python314\\site-packages (from ipython>=6.1.0->ipywidgets) (2.19.2)\n",
+ "Requirement already satisfied: stack_data>=0.6.0 in c:\\users\\kathe\\appdata\\roaming\\python\\python314\\site-packages (from ipython>=6.1.0->ipywidgets) (0.6.3)\n",
+ "Requirement already satisfied: wcwidth in c:\\users\\kathe\\appdata\\roaming\\python\\python314\\site-packages (from prompt_toolkit<3.1.0,>=3.0.41->ipython>=6.1.0->ipywidgets) (0.6.0)\n",
+ "Requirement already satisfied: parso<0.9.0,>=0.8.4 in c:\\users\\kathe\\appdata\\roaming\\python\\python314\\site-packages (from jedi>=0.18.2->ipython>=6.1.0->ipywidgets) (0.8.6)\n",
+ "Requirement already satisfied: six>=1.5 in c:\\users\\kathe\\appdata\\roaming\\python\\python314\\site-packages (from python-dateutil>=2.7->matplotlib) (1.17.0)\n",
+ "Requirement already satisfied: executing>=1.2.0 in c:\\users\\kathe\\appdata\\roaming\\python\\python314\\site-packages (from stack_data>=0.6.0->ipython>=6.1.0->ipywidgets) (2.2.1)\n",
+ "Requirement already satisfied: asttokens>=2.1.0 in c:\\users\\kathe\\appdata\\roaming\\python\\python314\\site-packages (from stack_data>=0.6.0->ipython>=6.1.0->ipywidgets) (3.0.1)\n",
+ "Requirement already satisfied: pure-eval in c:\\users\\kathe\\appdata\\roaming\\python\\python314\\site-packages (from stack_data>=0.6.0->ipython>=6.1.0->ipywidgets) (0.2.3)\n",
+ "Note: you may need to restart the kernel to use updated packages.\n",
+ "5.5.0\n",
+ "1.13.0\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Install packages\n",
+ "%pip install matplotlib ipywidgets\n",
+ "\n",
+ "# Imports\n",
+ "import numpy as np\n",
+ "import matplotlib.pyplot as plt\n",
+ "import torch\n",
+ "\n",
+ "from datasets import load_dataset, DatasetDict\n",
+ "from transformers import AutoImageProcessor, ViTForImageClassification\n",
+ "from transformers import Trainer, TrainingArguments\n",
+ "\n",
+ "import evaluate\n",
+ "import transformers\n",
+ "import accelerate\n",
+ "\n",
+ "\n",
+ "print(transformers.__version__)\n",
+ "print(accelerate.__version__)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "3e3aa822",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "339917e702894b88b0e14dd328b3c811",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Resolving data files: 0%| | 0/241 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "DatasetDict({\n",
+ " train: Dataset({\n",
+ " features: ['image', 'label'],\n",
+ " num_rows: 241\n",
+ " })\n",
+ "})"
+ ]
+ },
+ "execution_count": 2,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Load the dataset\n",
+ "dataset = load_dataset(\"imagefolder\", data_dir=\"Cats\")\n",
+ "dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "63ecc9fb",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Label names: ['Cheetah', 'Leopard', 'Lion', 'Puma', 'Tiger']\n",
+ "Label ids: [0, 1, 2, 3, 4]\n",
+ "Number of classes: 5\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Inspect the labels\n",
+ "label_names = dataset[\"train\"].features[\"label\"].names\n",
+ "labels = dataset[\"train\"].unique(\"label\")\n",
+ "\n",
+ "print(\"Label names:\", label_names)\n",
+ "print(\"Label ids:\", labels)\n",
+ "print(\"Number of classes:\", len(label_names))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "bb4293e4",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "DatasetDict({\n",
+ " train: Dataset({\n",
+ " features: ['image', 'label'],\n",
+ " num_rows: 192\n",
+ " })\n",
+ " validation: Dataset({\n",
+ " features: ['image', 'label'],\n",
+ " num_rows: 24\n",
+ " })\n",
+ " test: Dataset({\n",
+ " features: ['image', 'label'],\n",
+ " num_rows: 25\n",
+ " })\n",
+ "})"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Split into train / validation / test\n",
+ "split_dataset = dataset[\"train\"].train_test_split(test_size=0.2, seed=42)\n",
+ "eval_dataset = split_dataset[\"test\"].train_test_split(test_size=0.5, seed=42)\n",
+ "\n",
+ "our_dataset = DatasetDict({\n",
+ "    \"train\": split_dataset[\"train\"],\n",
+ "    \"validation\": eval_dataset[\"train\"],\n",
+ "    \"test\": eval_dataset[\"test\"]\n",
+ "})\n",
+ "\n",
+ "our_dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "a5d24190",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'Cheetah': '0', 'Leopard': '1', 'Lion': '2', 'Puma': '3', 'Tiger': '4'}\n",
+ "{'0': 'Cheetah', '1': 'Leopard', '2': 'Lion', '3': 'Puma', '4': 'Tiger'}\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Label mappings\n",
+ "label2id = {label: str(i) for i, label in enumerate(label_names)}\n",
+ "id2label = {str(i): label for i, label in enumerate(label_names)}\n",
+ "\n",
+ "print(label2id)\n",
+ "print(id2label)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "dc887218",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "ViTImageProcessor {\n",
+ " \"do_normalize\": true,\n",
+ " \"do_rescale\": true,\n",
+ " \"do_resize\": true,\n",
+ " \"image_mean\": [\n",
+ " 0.5,\n",
+ " 0.5,\n",
+ " 0.5\n",
+ " ],\n",
+ " \"image_processor_type\": \"ViTImageProcessor\",\n",
+ " \"image_std\": [\n",
+ " 0.5,\n",
+ " 0.5,\n",
+ " 0.5\n",
+ " ],\n",
+ " \"resample\": 2,\n",
+ " \"rescale_factor\": 0.00392156862745098,\n",
+ " \"size\": {\n",
+ " \"height\": 224,\n",
+ " \"width\": 224\n",
+ " }\n",
+ "}"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "#Image Processor\n",
+ "processor = AutoImageProcessor.from_pretrained(\"google/vit-base-patch16-224\")\n",
+ "processor"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "ac8ed1d2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#Transforms\n",
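+ "# with_transform applies the preprocessing lazily, per batch, at access time\n",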
+ "def transforms(batch):\n",
+ "    images = [img.convert(\"RGB\") for img in batch[\"image\"]]\n",
+ "    inputs = processor(images, return_tensors=\"pt\")\n",
+ "    inputs[\"labels\"] = batch[\"label\"]\n",
+ "    return inputs\n",
+ "\n",
+ "processed_dataset = our_dataset.with_transform(transforms)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "b566cf1c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#Collate Function\n",
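+ "# Stack per-example tensors into the batch dict format the Trainer expects\n",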
+ "def collate_fn(batch):\n",
+ "    return {\n",
+ "        \"pixel_values\": torch.stack([x[\"pixel_values\"] for x in batch]),\n",
+ "        \"labels\": torch.tensor([x[\"labels\"] for x in batch])\n",
+ "    }"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "3e90e19f",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "1a0b545f019c4a05a49c5b72a517da66",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Downloading builder script: 0.00B [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "483365e5f75c48c9a24e88e1d9a05ef6",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Downloading builder script: 0.00B [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "189c19a1212f46ab95b7e769b441e2d1",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Downloading builder script: 0.00B [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# Metrics\n",
+ "accuracy_metric = evaluate.load(\"accuracy\")\n",
+ "precision_metric = evaluate.load(\"precision\")\n",
+ "recall_metric = evaluate.load(\"recall\")\n",
+ "f1_metric = evaluate.load(\"f1\")\n",
+ "\n",
+ "def compute_metrics(eval_pred):\n",
+ "    logits, labels = eval_pred\n",
+ "    predictions = np.argmax(logits, axis=1)\n",
+ "\n",
+ "    accuracy = accuracy_metric.compute(predictions=predictions, references=labels)[\"accuracy\"]\n",
+ "    precision = precision_metric.compute(predictions=predictions, references=labels, average=\"weighted\")[\"precision\"]\n",
+ "    recall = recall_metric.compute(predictions=predictions, references=labels, average=\"weighted\")[\"recall\"]\n",
+ "    f1 = f1_metric.compute(predictions=predictions, references=labels, average=\"weighted\")[\"f1\"]\n",
+ "\n",
+ "    return {\n",
+ "        \"accuracy\": accuracy,\n",
+ "        \"precision\": precision,\n",
+ "        \"recall\": recall,\n",
+ "        \"f1\": f1\n",
+ "    }"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "87f65a9b",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "You passed `num_labels=5` which is incompatible to the `id2label` map of length `1000`.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "63d5a7739fff49af855c2f7278a74df7",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Loading weights: 0%| | 0/200 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\u001b[1mViTForImageClassification LOAD REPORT\u001b[0m from: google/vit-base-patch16-224\n",
+ "Key | Status | \n",
+ "------------------+----------+------------------------------------------------------------------------------------------\n",
+ "classifier.bias | MISMATCH | Reinit due to size mismatch - ckpt: torch.Size([1000]) vs model:torch.Size([5]) \n",
+ "classifier.weight | MISMATCH | Reinit due to size mismatch - ckpt: torch.Size([1000, 768]) vs model:torch.Size([5, 768])\n",
+ "\n",
+ "Notes:\n",
+ "- MISMATCH:\tckpt weights were loaded, but they did not match the original empty weight shapes.\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Load the model\n",
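+ "# ignore_mismatched_sizes=True drops the 1000-class ImageNet head and\n",
+ "# re-initializes a fresh 5-class classifier (hence the MISMATCH load report)\n",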
+ "model = ViTForImageClassification.from_pretrained(\n",
+ "    \"google/vit-base-patch16-224\",\n",
+ "    num_labels=len(label_names),\n",
+ "    id2label={int(k): v for k, v in id2label.items()},\n",
+ "    label2id=label2id,\n",
+ "    ignore_mismatched_sizes=True\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "78883db4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Freeze the backbone\n",
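+ "# Only the classification head stays trainable; every encoder weight is frozen\n",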
+ "for name, param in model.named_parameters():\n",
+ "    if not name.startswith(\"classifier\"):\n",
+ "        param.requires_grad = False"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "2dc7e9f0",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#TrainingArguments\n",
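+ "# load_best_model_at_end keeps the checkpoint with the best validation accuracy\n",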
+ "training_args = TrainingArguments(\n",
+ "    output_dir=\"./cat-vit\",\n",
+ "    per_device_train_batch_size=16,\n",
+ "    per_device_eval_batch_size=16,\n",
+ "    eval_strategy=\"epoch\",\n",
+ "    save_strategy=\"epoch\",\n",
+ "    logging_steps=20,\n",
+ "    num_train_epochs=5,\n",
+ "    learning_rate=3e-4,\n",
+ "    save_total_limit=2,\n",
+ "    remove_unused_columns=False,\n",
+ "    push_to_hub=True,\n",
+ "    load_best_model_at_end=True,\n",
+ "    metric_for_best_model=\"accuracy\",\n",
+ "    greater_is_better=True,\n",
+ "    report_to=\"none\",\n",
+ "    disable_tqdm=True,\n",
+ "    run_name=\"cat-vit-transfer-learning\"\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "1e3d4feb",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#Trainer\n",
+ "trainer = Trainer(\n",
+ "    model=model,\n",
+ "    args=training_args,\n",
+ "    train_dataset=processed_dataset[\"train\"],\n",
+ "    eval_dataset=processed_dataset[\"validation\"],\n",
+ "    data_collator=collate_fn,\n",
+ "    compute_metrics=compute_metrics,\n",
+ "    processing_class=processor\n",
+ ")\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "2a8b4894",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "c:\\Users\\kathe\\AppData\\Local\\Python\\pythoncore-3.14-64\\Lib\\site-packages\\torch\\utils\\data\\dataloader.py:775: UserWarning: 'pin_memory' argument is set as true but no accelerator is found, then device pinned memory won't be used.\n",
+ " super().__init__(loader)\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': '1.082', 'eval_accuracy': '0.875', 'eval_precision': '0.9018', 'eval_recall': '0.875', 'eval_f1': '0.8627', 'eval_runtime': '3.233', 'eval_samples_per_second': '7.423', 'eval_steps_per_second': '0.619', 'epoch': '1'}\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "262534a235fb4b32bd6cd2ed35146c98",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Writing model shards: 0%| | 0/1 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "c:\\Users\\kathe\\AppData\\Local\\Python\\pythoncore-3.14-64\\Lib\\site-packages\\torch\\utils\\data\\dataloader.py:775: UserWarning: 'pin_memory' argument is set as true but no accelerator is found, then device pinned memory won't be used.\n",
+ " super().__init__(loader)\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'loss': '1.151', 'grad_norm': '5.051', 'learning_rate': '0.000205', 'epoch': '1.667'}\n",
+ "{'eval_loss': '0.7125', 'eval_accuracy': '0.9167', 'eval_precision': '0.9278', 'eval_recall': '0.9167', 'eval_f1': '0.9139', 'eval_runtime': '3.441', 'eval_samples_per_second': '6.976', 'eval_steps_per_second': '0.581', 'epoch': '2'}\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "d8e710d66cef43cea6bff1d5e03e76f5",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Writing model shards: 0%| | 0/1 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "c:\\Users\\kathe\\AppData\\Local\\Python\\pythoncore-3.14-64\\Lib\\site-packages\\torch\\utils\\data\\dataloader.py:775: UserWarning: 'pin_memory' argument is set as true but no accelerator is found, then device pinned memory won't be used.\n",
+ " super().__init__(loader)\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': '0.5354', 'eval_accuracy': '0.9167', 'eval_precision': '0.9278', 'eval_recall': '0.9167', 'eval_f1': '0.9139', 'eval_runtime': '3.425', 'eval_samples_per_second': '7.006', 'eval_steps_per_second': '0.584', 'epoch': '3'}\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "efcce8388767474eac235d4f294c6ed0",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Writing model shards: 0%| | 0/1 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "c:\\Users\\kathe\\AppData\\Local\\Python\\pythoncore-3.14-64\\Lib\\site-packages\\torch\\utils\\data\\dataloader.py:775: UserWarning: 'pin_memory' argument is set as true but no accelerator is found, then device pinned memory won't be used.\n",
+ " super().__init__(loader)\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'loss': '0.5336', 'grad_norm': '3.152', 'learning_rate': '0.000105', 'epoch': '3.333'}\n",
+ "{'eval_loss': '0.4571', 'eval_accuracy': '0.9167', 'eval_precision': '0.9278', 'eval_recall': '0.9167', 'eval_f1': '0.9139', 'eval_runtime': '3.065', 'eval_samples_per_second': '7.83', 'eval_steps_per_second': '0.652', 'epoch': '4'}\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "05f39cf4209244939ce265accaebc9bf",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Writing model shards: 0%| | 0/1 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "c:\\Users\\kathe\\AppData\\Local\\Python\\pythoncore-3.14-64\\Lib\\site-packages\\torch\\utils\\data\\dataloader.py:775: UserWarning: 'pin_memory' argument is set as true but no accelerator is found, then device pinned memory won't be used.\n",
+ " super().__init__(loader)\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'loss': '0.3465', 'grad_norm': '2.518', 'learning_rate': '5e-06', 'epoch': '5'}\n",
+ "{'eval_loss': '0.4346', 'eval_accuracy': '0.9167', 'eval_precision': '0.9219', 'eval_recall': '0.9167', 'eval_f1': '0.9139', 'eval_runtime': '3.323', 'eval_samples_per_second': '7.222', 'eval_steps_per_second': '0.602', 'epoch': '5'}\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "d3f7a6a821604a4faa19a99e47d553a4",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Writing model shards: 0%| | 0/1 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'train_runtime': '167.4', 'train_samples_per_second': '5.736', 'train_steps_per_second': '0.358', 'train_loss': '0.6771', 'epoch': '5'}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "c:\\Users\\kathe\\AppData\\Local\\Python\\pythoncore-3.14-64\\Lib\\site-packages\\torch\\utils\\data\\dataloader.py:775: UserWarning: 'pin_memory' argument is set as true but no accelerator is found, then device pinned memory won't be used.\n",
+ " super().__init__(loader)\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': '0.6814', 'eval_accuracy': '0.96', 'eval_precision': '0.97', 'eval_recall': '0.96', 'eval_f1': '0.96', 'eval_runtime': '3.285', 'eval_samples_per_second': '7.611', 'eval_steps_per_second': '0.609', 'epoch': '5'}\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "{'eval_loss': 0.681404709815979,\n",
+ " 'eval_accuracy': 0.96,\n",
+ " 'eval_precision': 0.97,\n",
+ " 'eval_recall': 0.96,\n",
+ " 'eval_f1': 0.96,\n",
+ " 'eval_runtime': 3.2846,\n",
+ " 'eval_samples_per_second': 7.611,\n",
+ " 'eval_steps_per_second': 0.609,\n",
+ " 'epoch': 5.0}"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Train\n",
+ "trainer.train()\n",
+ "test_results = trainer.evaluate(processed_dataset[\"test\"])\n",
+ "test_results"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "026d1a8f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Test evaluation (see README)\n",
+ "#trainer.evaluate(processed_dataset['test'])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "ca5b4010",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "e12f1e64b3a2438f8479118470a17561",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Writing model shards: 0%| | 0/1 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "2e77b9cf34c34a489fc01d7137cb6714",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Writing model shards: 0%| | 0/1 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "96a93af3c1974290a0439d3d6089b7b5",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Processing Files (0 / 0): | | 0.00B / 0.00B "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "81545cef38fe4403ab0c287fffaf85de",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "New Data Upload: | | 0.00B / 0.00B "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "dcc7dd6286f348ceaa474c41639d3fa1",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Writing model shards: 0%| | 0/1 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "8be396ac33304e8aa6d04654792985b9",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Processing Files (0 / 0): | | 0.00B / 0.00B "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "ce487973666a465bb4f9111693e49fea",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "New Data Upload: | | 0.00B / 0.00B "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "CommitInfo(commit_url='https://huggingface.co/DKatheesrupan/cat-vit/commit/05e008b778df7e8e7dcbab9ef293490315c2609a', commit_message='cat-vit-classifier', commit_description='', oid='05e008b778df7e8e7dcbab9ef293490315c2609a', pr_url=None, repo_url=RepoUrl('https://huggingface.co/DKatheesrupan/cat-vit', endpoint='https://huggingface.co', repo_type='model', repo_id='DKatheesrupan/cat-vit'), pr_revision=None, pr_num=None)"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Push the model to the Hub\n",
+ "kwargs = {\n",
+ "    \"finetuned_from\": \"google/vit-base-patch16-224\",\n",
+ "    \"dataset\": \"custom cat dataset\",\n",
+ "    \"tasks\": \"image-classification\",\n",
+ "    \"tags\": [\"image-classification\", \"vision-transformer\", \"cats\"]\n",
+ "}\n",
+ "trainer.save_model()\n",
+ "trainer.push_to_hub(\"cat-vit-classifier\", **kwargs)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.14.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }