manasch committed on
Commit
17d2808
1 Parent(s): c37043e

add pace model training notebook

Files changed (3)
  1. .gitignore +2 -0
  2. app.py +48 -10
  3. notebooks/PaceModel.ipynb +584 -0
.gitignore CHANGED
@@ -3,3 +3,5 @@ __pycache__
 
 *.jpg
 *.png
+
+*.log
app.py CHANGED
@@ -18,9 +18,17 @@ class AudioPalette:
         self.image_captioning = ImageCaptioning()
 
     def generate(self, input_image: PIL.Image.Image):
-        generated_text = self.image_captioning.query(input_image)[0].get("generated_text")
         pace = self.pace_model.predict(input_image)
-        return pace + (" - " + generated_text if generated_text is not None else "")
+        print("Pace Prediction Done")
+
+        generated_text = self.image_captioning.query(input_image)[0].get("generated_text")
+        print("Captioning Done")
+
+        generated_text = generated_text if generated_text is not None else ""
+        temp = pace + " - " + generated_text
+        outputs = [temp, pace, generated_text]
+
+        return outputs
 
 def main():
     model = AudioPalette()
@@ -33,14 +41,44 @@ def main():
             show_label=True,
             container=True
         ),
-        outputs=gr.Textbox(
-            lines=1,
-            placeholder="Pace of the image and the caption",
-            label="Caption and Pace",
-            show_label=True,
-            container=True,
-            type="text"
-        ),
+        outputs=[
+            gr.Textbox(
+                lines=1,
+                placeholder="Pace of the image and the caption",
+                label="Caption and Pace",
+                show_label=True,
+                container=True,
+                type="text",
+                visible=True
+            ),
+            gr.Textbox(
+                lines=1,
+                placeholder="Pace of the image",
+                label="Pace",
+                show_label=True,
+                container=True,
+                type="text",
+                visible=False
+            ),
+            gr.Textbox(
+                lines=1,
+                placeholder="Caption for the image",
+                label="Caption",
+                show_label=True,
+                container=True,
+                type="text",
+                visible=False
+            ),
+            # gr.Audio(
+            #     label="Generated Audio",
+            #     show_label=True,
+            #     container=True,
+            #     visible=False,
+            #     format="wav",
+            #     autoplay=False,
+            #     show_download_button=True,
+            # )
+        ],
         cache_examples=False,
         live=False,
         title="Audio Palette",
notebooks/PaceModel.ipynb ADDED
@@ -0,0 +1,584 @@
+{
+  "nbformat": 4,
+  "nbformat_minor": 0,
+  "metadata": {
+    "colab": {
+      "provenance": [],
+      "gpuType": "T4"
+    },
+    "kernelspec": {
+      "name": "python3",
+      "display_name": "Python 3"
+    },
+    "language_info": {
+      "name": "python"
+    },
+    "accelerator": "GPU"
+  },
+  "cells": [
+    {
+      "cell_type": "code",
+      "source": [
+        "import json\n",
+        "import shutil\n",
+        "from pathlib import Path\n",
+        "from keras.applications.resnet50 import ResNet50\n",
+        "\n",
+        "import cv2\n",
+        "import matplotlib.pyplot as plt\n",
+        "import pandas as pd\n",
+        "\n",
+        "from google.colab import files\n",
+        "from google.colab.patches import cv2_imshow"
+      ],
+      "metadata": {
+        "id": "wzZknIqDEwBg"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "kaggle_token = Path(\"/root/.kaggle/kaggle.json\")\n",
+        "if not kaggle_token.parent.exists():\n",
+        "    kaggle_token.parent.mkdir()\n",
+        "if not kaggle_token.exists():\n",
+        "    print(\"Upload token:\")\n",
+        "    files.upload()\n",
+        "    shutil.move((Path.cwd() / \"kaggle.json\").as_posix(), kaggle_token.resolve().as_posix())"
+      ],
+      "metadata": {
+        "id": "c0SGUkuUEy6n"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "!chmod 600 /root/.kaggle/kaggle.json"
+      ],
+      "metadata": {
+        "id": "yufGnL24E0gf"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "!kaggle d download srbhshinde/flickr8k-sau"
+      ],
+      "metadata": {
+        "id": "84kESXJrE7Zn"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "!7z x flickr8k-sau.zip"
+      ],
+      "metadata": {
+        "id": "NOyPR6EnE80f"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "output = Path.cwd() / 'images'\n",
+        "if not output.exists():\n",
+        "    output.mkdir()\n",
+        "\n",
+        "fast = output / 'fast'\n",
+        "med = output / 'medium'\n",
+        "slow = output / 'slow'\n",
+        "\n",
+        "if not fast.exists():\n",
+        "    fast.mkdir()\n",
+        "\n",
+        "if not med.exists():\n",
+        "    med.mkdir()\n",
+        "\n",
+        "if not slow.exists():\n",
+        "    slow.mkdir()\n",
+        "\n",
+        "counter = 0\n",
+        "\n",
+        "with open(\"finalDataset.csv\") as f:\n",
+        "    f.readline()\n",
+        "    image_path = Path.cwd() / 'Flickr_Data' / 'Images'\n",
+        "    for line in f:\n",
+        "        idx, image_name, pace = line.strip().split(',')\n",
+        "        if pace == 'slow':\n",
+        "            shutil.copy2(image_path / image_name, slow)\n",
+        "        elif pace == 'fast':\n",
+        "            shutil.copy2(image_path / image_name, fast)\n",
+        "        else:\n",
+        "            shutil.copy2(image_path / image_name, med)\n",
+        "        counter += 1\n",
+        "\n",
+        "print(counter)"
+      ],
+      "metadata": {
+        "id": "dY034gfPE9wW"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "import matplotlib.pyplot as plt\n",
+        "import numpy as np\n",
+        "import os\n",
+        "import PIL\n",
+        "import tensorflow as tf\n",
+        "from tensorflow import keras\n",
+        "from tensorflow.keras import layers\n",
+        "from tensorflow.python.keras.layers import Dense, Flatten\n",
+        "from tensorflow.keras.models import Sequential\n",
+        "from tensorflow.keras.optimizers import Adam"
+      ],
+      "metadata": {
+        "id": "G4uVyuXJGKDj"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "a2Zst5ytENY3"
+      },
+      "outputs": [],
+      "source": [
+        "import pathlib\n",
+        "data_dir = 'images/'\n",
+        "data_dir = pathlib.Path(data_dir)\n",
+        "bg = list(data_dir.glob('medium/*'))\n",
+        "print(bg[0])\n",
+        "PIL.Image.open(str(bg[0]))"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "data_dir = 'images/'\n",
+        "img_height, img_width = 224,224\n",
+        "batch_size = 32\n",
+        "train_ds = tf.keras.preprocessing.image_dataset_from_directory(\n",
+        "    data_dir,\n",
+        "    validation_split = 0.2,\n",
+        "    subset = \"training\",\n",
+        "    seed = 345,\n",
+        "    label_mode = 'categorical',\n",
+        "    image_size = (img_height, img_width),\n",
+        "    batch_size = batch_size\n",
+        ")"
+      ],
+      "metadata": {
+        "id": "xlGw3fN5Eb24"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "val_ds = tf.keras.preprocessing.image_dataset_from_directory(\n",
+        "    data_dir,\n",
+        "    validation_split = 0.2,\n",
+        "    subset = \"validation\",\n",
+        "    seed = 345,\n",
+        "    label_mode = 'categorical',\n",
+        "    image_size = (img_height, img_width),\n",
+        "    batch_size = batch_size\n",
+        ")"
+      ],
+      "metadata": {
+        "id": "IN1O5n9WEd2n"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "class_names = train_ds.class_names\n",
+        "print(class_names)"
+      ],
+      "metadata": {
+        "id": "zVXI77J4Ehh3"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "resnet_model = Sequential()\n",
+        "\n",
+        "pretrained_model= ResNet50(\n",
+        "    include_top=False,\n",
+        "    input_shape=(224,224,3),\n",
+        "    pooling='avg',classes=211,\n",
+        "    weights='imagenet')\n",
+        "\n",
+        "for layer in pretrained_model.layers:\n",
+        "    layer.trainable=False\n",
+        "\n",
+        "resnet_model.add(pretrained_model)\n",
+        "resnet_model.add(Flatten())\n",
+        "resnet_model.add(Dense(1024, activation = 'relu'))\n",
+        "resnet_model.add(Dense(256, activation = 'relu'))\n",
+        "resnet_model.add(Dense(3, activation = 'softmax'))"
+      ],
+      "metadata": {
+        "id": "ZpR1kjgWEi3_"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "resnet_model.summary()"
+      ],
+      "metadata": {
+        "id": "2WGlO4VLEpSf"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "resnet_model.compile(optimizer=Adam(learning_rate=0.001),loss='categorical_crossentropy',metrics=['accuracy'])"
+      ],
+      "metadata": {
+        "id": "TPNjjBLqEqwu"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "epochs = 15\n",
+        "checkpoint_filepath = '/tmp/checkpoint'\n",
+        "model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(\n",
+        "    filepath=checkpoint_filepath,\n",
+        "    save_weights_only=True,\n",
+        "    monitor='val_accuracy',\n",
+        "    mode='max',\n",
+        "    save_best_only=True)\n",
+        "history = resnet_model.fit(\n",
+        "    train_ds,\n",
+        "    validation_data = val_ds,\n",
+        "    epochs = epochs,\n",
+        "    callbacks=[model_checkpoint_callback]\n",
+        ")"
+      ],
+      "metadata": {
+        "id": "rc8VaaypEsJX"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "resnet_model.load_weights(checkpoint_filepath)"
+      ],
+      "metadata": {
+        "id": "tUCSqCs3JLBu"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "resnet_model.save('pace_model')"
+      ],
+      "metadata": {
+        "id": "HIsB7PHxfwM-"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "resnet_model.save_weights('pace_model_weights.h5')"
+      ],
+      "metadata": {
+        "id": "yhiPNdnkgU7C"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "resnet_model.load_weights('pace_model_weights.h5')"
+      ],
+      "metadata": {
+        "id": "IuXgP1oJhZz1"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "import cv2\n",
+        "image=cv2.imread('danny.png')\n",
+        "image_resized= cv2.resize(image, (img_height,img_width))\n",
+        "image=np.expand_dims(image_resized,axis=0)\n",
+        "print(image.shape)"
+      ],
+      "metadata": {
+        "id": "YXGqeYevKlpR"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "PIL.Image.open('danny.png')"
+      ],
+      "metadata": {
+        "id": "ZP6ATt0CYe1c"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "pred=resnet_model.predict(image)\n",
+        "print(pred)"
+      ],
+      "metadata": {
+        "id": "i1qeeWVcK3av"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "output_class=class_names[np.argmax(pred)]\n",
+        "print(\"The predicted class is\", output_class)"
+      ],
+      "metadata": {
+        "id": "-FgwW4zDK47O"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "import matplotlib.image as img\n",
+        "import matplotlib.pyplot as plt\n",
+        "from scipy.cluster.vq import whiten\n",
+        "from scipy.cluster.vq import kmeans\n",
+        "import pandas as pd\n",
+        "\n",
+        "batman_image = img.imread('danny.png')\n",
+        "\n",
+        "r = []\n",
+        "g = []\n",
+        "b = []\n",
+        "for row in batman_image:\n",
+        "    for temp_r, temp_g, temp_b, temp in row:\n",
+        "        r.append(temp_r)\n",
+        "        g.append(temp_g)\n",
+        "        b.append(temp_b)\n",
+        "\n",
+        "batman_df = pd.DataFrame({'red': r,\n",
+        "                          'green': g,\n",
+        "                          'blue': b})\n",
+        "\n",
+        "batman_df['scaled_color_red'] = whiten(batman_df['red'])\n",
+        "batman_df['scaled_color_blue'] = whiten(batman_df['blue'])\n",
+        "batman_df['scaled_color_green'] = whiten(batman_df['green'])\n",
+        "\n",
+        "cluster_centers, _ = kmeans(batman_df[['scaled_color_red',\n",
+        "                                       'scaled_color_blue',\n",
+        "                                       'scaled_color_green']], 3)\n",
+        "\n",
+        "dominant_colors = []\n",
+        "\n",
+        "red_std, green_std, blue_std = batman_df[['red',\n",
+        "                                          'green',\n",
+        "                                          'blue']].std()\n",
+        "\n",
+        "for cluster_center in cluster_centers:\n",
+        "    red_scaled, green_scaled, blue_scaled = cluster_center\n",
+        "    dominant_colors.append((\n",
+        "        red_scaled * red_std / 255,\n",
+        "        green_scaled * green_std / 255,\n",
+        "        blue_scaled * blue_std / 255\n",
+        "    ))\n",
+        "\n",
+        "plt.imshow([dominant_colors])\n",
+        "plt.show()"
+      ],
+      "metadata": {
+        "id": "pcUf1oNWpNWt"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "import matplotlib.image as img\n",
+        "\n",
+        "# Read batman image and print dimensions\n",
+        "batman_image = img.imread('for.jpg')\n",
+        "print(batman_image.shape)"
+      ],
+      "metadata": {
+        "id": "r4eAlhkupdlS"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "import pandas as pd\n",
+        "from scipy.cluster.vq import whiten\n",
+        "\n",
+        "# Store RGB values of all pixels in lists r, g and b\n",
+        "r = []\n",
+        "g = []\n",
+        "b = []\n",
+        "for row in batman_image:\n",
+        "    for temp_r, temp_g, temp_b in row:\n",
+        "        r.append(temp_r)\n",
+        "        g.append(temp_g)\n",
+        "        b.append(temp_b)\n",
+        "\n",
+        "# only printing the size of these lists\n",
+        "# as the content is too big\n",
+        "print(len(r))\n",
+        "print(len(g))\n",
+        "print(len(b))\n",
+        "\n",
+        "# Saving as DataFrame\n",
+        "batman_df = pd.DataFrame({'red': r,\n",
+        "                          'green': g,\n",
+        "                          'blue': b})\n",
+        "\n",
+        "# Scaling the values\n",
+        "batman_df['scaled_color_red'] = whiten(batman_df['red'])\n",
+        "batman_df['scaled_color_blue'] = whiten(batman_df['blue'])\n",
+        "batman_df['scaled_color_green'] = whiten(batman_df['green'])"
+      ],
+      "metadata": {
+        "id": "unc3TcVop2qn"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "import seaborn as sns\n",
+        "distortions = []\n",
+        "num_clusters = range(1, 7)  # range of cluster sizes\n",
+        "\n",
+        "# Create a list of distortions from the kmeans function\n",
+        "for i in num_clusters:\n",
+        "    cluster_centers, distortion = kmeans(batman_df[['scaled_color_red',\n",
+        "                                                    'scaled_color_blue',\n",
+        "                                                    'scaled_color_green']], i)\n",
+        "    distortions.append(distortion)\n",
+        "\n",
+        "# Create a data frame with two lists, num_clusters and distortions\n",
+        "elbow_plot = pd.DataFrame({'num_clusters': num_clusters,\n",
+        "                           'distortions': distortions})\n",
+        "\n",
+        "# Create a line plot of num_clusters and distortions\n",
+        "sns.lineplot(x='num_clusters', y='distortions', data=elbow_plot)\n",
+        "plt.xticks(num_clusters)\n",
+        "plt.show()"
+      ],
+      "metadata": {
+        "id": "NE7I1771qAPK"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "cluster_centers, _ = kmeans(batman_df[['scaled_color_red',\n",
+        "                                       'scaled_color_blue',\n",
+        "                                       'scaled_color_green']], 3)\n",
+        "\n",
+        "dominant_colors = []\n",
+        "\n",
+        "# Get standard deviations of each color\n",
+        "red_std, green_std, blue_std = batman_df[['red',\n",
+        "                                          'green',\n",
+        "                                          'blue']].std()\n",
+        "\n",
+        "for cluster_center in cluster_centers:\n",
+        "    red_scaled, green_scaled, blue_scaled = cluster_center\n",
+        "\n",
+        "    # Convert each standardized value to scaled value\n",
+        "    dominant_colors.append((\n",
+        "        red_scaled * red_std / 255,\n",
+        "        green_scaled * green_std / 255,\n",
+        "        blue_scaled * blue_std / 255\n",
+        "    ))\n",
+        "\n",
+        "# Display colors of cluster centers\n",
+        "plt.imshow([dominant_colors])\n",
+        "plt.show()"
+      ],
+      "metadata": {
+        "id": "sfE-qoULqHQM"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "from webcolors import rgb_to_name\n",
+        "\n",
+        "for i in dominant_colors:\n",
+        "    named_color = rgb_to_name(i, spec='css3')\n",
+        "    print(named_color)"
+      ],
+      "metadata": {
+        "id": "1U86FPKkqPPV"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [],
+      "metadata": {
+        "id": "xolPh4bIqsU_"
+      },
+      "execution_count": null,
+      "outputs": []
+    }
+  ]
+}
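
To reuse the trained weights outside Colab, one possible inference wrapper consistent with the notebook: it rebuilds the same Sequential architecture, loads pace_model_weights.h5, and feeds raw 0-255 RGB pixels as the training pipeline did. The function names, the weights path, and the class order are assumptions (image_dataset_from_directory labels classes alphabetically, which the notebook's print(class_names) cell confirms), not the repo's actual PaceModel implementation:

import numpy as np
import PIL.Image
from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.models import Sequential

CLASS_NAMES = ["fast", "medium", "slow"]  # assumed alphabetical directory order

def build_pace_model(weights_path="pace_model_weights.h5"):
    # Rebuild the exact architecture trained above, then load the saved weights.
    base = ResNet50(include_top=False, input_shape=(224, 224, 3),
                    pooling="avg", weights=None)
    model = Sequential([base, Flatten(),
                        Dense(1024, activation="relu"),
                        Dense(256, activation="relu"),
                        Dense(3, activation="softmax")])
    model.load_weights(weights_path)
    return model

def predict_pace(model, image: PIL.Image.Image) -> str:
    # Raw RGB pixels at 224x224, matching the notebook's training input (no preprocess_input).
    arr = np.asarray(image.convert("RGB").resize((224, 224)), dtype=np.float32)
    pred = model.predict(arr[np.newaxis, ...])  # shape (1, 3)
    return CLASS_NAMES[int(np.argmax(pred))]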