abdullahmeda commited on
Commit
a4d40bc
1 Parent(s): b9309ba
README.md CHANGED
@@ -1,12 +1,37 @@
1
  ---
2
- title: Test
3
- emoji: 🐠
4
- colorFrom: indigo
5
- colorTo: blue
6
  sdk: gradio
7
- sdk_version: 3.0.20
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Poster2plot
3
+ emoji: 🎬
4
+ colorFrom: purple
5
+ colorTo: purple
6
  sdk: gradio
 
7
  app_file: app.py
8
  pinned: false
9
  ---
10
 
11
+ # Configuration
12
+
13
+ `title`: _string_
14
+ Display title for the Space
15
+
16
+ `emoji`: _string_
17
+ Space emoji (emoji-only character allowed)
18
+
19
+ `colorFrom`: _string_
20
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
21
+
22
+ `colorTo`: _string_
23
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
24
+
25
+ `sdk`: _string_
26
+ Can be either `gradio` or `streamlit`
27
+
28
+ `sdk_version` : _string_
29
+ Only applicable for `streamlit` SDK.
30
+ See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
31
+
32
+ `app_file`: _string_
33
+ Path to your main application file (which contains either `gradio` or `streamlit` Python code).
34
+ Path is relative to the root of the repository.
35
+
36
+ `pinned`: _boolean_
37
+ Whether the Space stays on top of your list.
app.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import re
4
+ import gradio as gr
5
+ from pathlib import Path
6
+ from transformers import AutoTokenizer, AutoFeatureExtractor, VisionEncoderDecoderModel
7
+
8
+
9
+ # Pattern to ignore all the text after 2 or more full stops
10
+ regex_pattern = "[.]{2,}"
11
+
12
+
13
+ def post_process(text):
14
+ try:
15
+ text = text.strip()
16
+ text = re.split(regex_pattern, text)[0]
17
+ except Exception as e:
18
+ print(e)
19
+ pass
20
+ return text
21
+
22
+
23
+ def set_example_image(example: list) -> dict:
24
+ return gr.Image.update(value=example[0])
25
+
26
+
27
+ def predict(image, max_length=64, num_beams=4):
28
+ pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
29
+ pixel_values = pixel_values.to(device)
30
+
31
+ with torch.no_grad():
32
+ output_ids = model.generate(
33
+ pixel_values,
34
+ max_length=max_length,
35
+ num_beams=num_beams,
36
+ return_dict_in_generate=True,
37
+ ).sequences
38
+
39
+ preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
40
+ pred = post_process(preds[0])
41
+
42
+ return pred
43
+
44
+
45
+ model_name_or_path = "deepklarity/poster2plot"
46
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
47
+
48
+ # Load model.
49
+
50
+ model = VisionEncoderDecoderModel.from_pretrained(model_name_or_path)
51
+ model.to(device)
52
+ print("Loaded model")
53
+
54
+ feature_extractor = AutoFeatureExtractor.from_pretrained(model.encoder.name_or_path)
55
+ print("Loaded feature_extractor")
56
+
57
+ tokenizer = AutoTokenizer.from_pretrained(model.decoder.name_or_path, use_fast=True)
58
+ if model.decoder.name_or_path == "gpt2":
59
+ tokenizer.pad_token = tokenizer.eos_token
60
+
61
+ print("Loaded tokenizer")
62
+
63
+ title = "Poster2Plot: Upload a Movie/T.V show poster to generate a plot"
64
+ description = ""
65
+
66
+ input = gr.inputs.Image(type="pil")
67
+
68
+ example_images = sorted(
69
+ [f.as_posix() for f in Path("examples").glob("*.jpg")]
70
+ )
71
+ print(f"Loaded {len(example_images)} example images")
72
+
73
+ demo = gr.Blocks()
74
+ filenames = next(os.walk('examples'), (None, None, []))[2]
75
+ examples = [[f"examples/{filename}"] for filename in filenames]
76
+ print(examples)
77
+
78
+ with demo:
79
+ with gr.Column():
80
+ with gr.Row():
81
+ with gr.Column():
82
+ input_image = gr.Image()
83
+ with gr.Row():
84
+ clear_button = gr.Button(value="Clear", variant='secondary')
85
+ submit_button = gr.Button(value="Submit", variant='primary')
86
+ with gr.Column():
87
+ plot = gr.Textbox()
88
+ with gr.Row():
89
+ example_images = gr.Dataset(components=[input_image], samples=examples)
90
+
91
+ submit_button.click(fn=predict, inputs=[input_image], outputs=[plot])
92
+ example_images.click(fn=set_example_image, inputs=[example_images], outputs=example_images.components)
93
+
94
+ demo.launch()
95
+
96
+
97
+ interface = gr.Interface(
98
+ fn=predict,
99
+ inputs=input,
100
+ outputs="textbox",
101
+ title=title,
102
+ description=description,
103
+ examples=example_images,
104
+ examples_per_page=20,
105
+ live=True,
106
+ article='<p>Made by: <a href="https://twitter.com/kartik_godawat" target="_blank" rel="noopener noreferrer">dk-crazydiv</a> and <a href="https://twitter.com/dsr_ai" target="_blank" rel="noopener noreferrer">dsr</a></p>'
107
+ )
108
+
109
+ interface.launch()
examples/tt0068646-the-godfather.jpg ADDED
examples/tt0076759-star-wars.jpg ADDED
examples/tt0108778-friends.jpg ADDED
examples/tt0109830-forrest-gump.jpg ADDED
examples/tt0434409-v-for-vendetta.jpg ADDED
examples/tt10062292-never-have-i-ever.jpg ADDED
examples/tt10919420-squid-games.jpg ADDED
examples/tt3521164-moana.jpg ADDED
examples/tt6468322-money-heist.jpg ADDED
examples/tt7991608-red-notice.jpg ADDED
examples/tt8366590-baaghi3.jpg ADDED
flagged/image/0.jpg ADDED
flagged/image/1.jpg ADDED
flagged/log.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ 'image','output','flag','username','timestamp'
2
+ 'image/0.jpg','A young woman is forced to deal with her past when she is accused of murder. She tries to find out what happened to her husband, who is also accused of the crime. Will she be able to solve the case or will she be the one to save her husband''s life? Based on the true story of','','','2022-06-23 18:30:55.658016'
3
+ 'image/1.jpg','A young woman is forced to deal with her past when she is accused of murder. She tries to find out what happened to her husband, who is also accused of the crime. Will she be able to solve the case or will she be the one to save her husband''s life? Based on the true story of','','','2022-06-23 18:30:57.352462'
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
1
+ --find-links https://download.pytorch.org/whl/torch_stable.html
2
+ gradio==2.9.0
3
+ transformers==4.12.5
4
+ torch==1.10.0+cpu
test.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import re
4
+ import gradio as gr
5
+ from pathlib import Path
6
+ from transformers import AutoTokenizer, AutoFeatureExtractor, VisionEncoderDecoderModel
7
+
8
+
9
+ # Pattern to ignore all the text after 2 or more full stops
10
+ regex_pattern = "[.]{2,}"
11
+
12
+
13
+ def post_process(text):
14
+ try:
15
+ text = text.strip()
16
+ text = re.split(regex_pattern, text)[0]
17
+ except Exception as e:
18
+ print(e)
19
+ pass
20
+ return text
21
+
22
+
23
+ def set_example_image(example: list) -> dict:
24
+ return gr.Image.update(value=example[0])
25
+
26
+
27
+ def predict(image, max_length=64, num_beams=4):
28
+ pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
29
+ pixel_values = pixel_values.to(device)
30
+
31
+ with torch.no_grad():
32
+ output_ids = model.generate(
33
+ pixel_values,
34
+ max_length=max_length,
35
+ num_beams=num_beams,
36
+ return_dict_in_generate=True,
37
+ ).sequences
38
+
39
+ preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
40
+ pred = post_process(preds[0])
41
+
42
+ return pred
43
+
44
+
45
+ model_name_or_path = "deepklarity/poster2plot"
46
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
47
+
48
+ # Load model.
49
+
50
+ model = VisionEncoderDecoderModel.from_pretrained(model_name_or_path)
51
+ model.to(device)
52
+ print("Loaded model")
53
+
54
+ feature_extractor = AutoFeatureExtractor.from_pretrained(model.encoder.name_or_path)
55
+ print("Loaded feature_extractor")
56
+
57
+ tokenizer = AutoTokenizer.from_pretrained(model.decoder.name_or_path, use_fast=True)
58
+ if model.decoder.name_or_path == "gpt2":
59
+ tokenizer.pad_token = tokenizer.eos_token
60
+
61
+ print("Loaded tokenizer")
62
+
63
+ title = "Poster2Plot: Upload a Movie/T.V show poster to generate a plot"
64
+ description = ""
65
+
66
+ input = gr.inputs.Image(type="pil")
67
+
68
+ example_images = sorted(
69
+ [f.as_posix() for f in Path("examples").glob("*.jpg")]
70
+ )
71
+ print(f"Loaded {len(example_images)} example images")
72
+
73
+ demo = gr.Blocks()
74
+ filenames = next(os.walk('examples'), (None, None, []))[2]
75
+ examples = [[f"examples/{filename}"] for filename in filenames]
76
+ print(examples)
77
+
78
+ with demo:
79
+ with gr.Column():
80
+ with gr.Row():
81
+ with gr.Column():
82
+ input_image = gr.Image()
83
+ with gr.Row():
84
+ clear_button = gr.Button(value="Clear", variant='secondary')
85
+ submit_button = gr.Button(value="Submit", variant='primary')
86
+ with gr.Column():
87
+ plot = gr.Textbox()
88
+ with gr.Row():
89
+ example_images = gr.Dataset(components=[input_image], samples=examples)
90
+
91
+ submit_button.click(fn=predict, inputs=[input_image], outputs=[plot])
92
+ example_images.click(fn=set_example_image, inputs=[example_images], outputs=example_images.components)
93
+
94
+ demo.launch()
train/README.md ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
1
+ # Train new model
2
+
3
+ - Download and extract the following datasets in a new folder called datasets:
4
+
5
+ 1. [IMDb movies extensive dataset](https://www.kaggle.com/stefanoleone992/imdb-extensive-dataset)
6
+ 2. [48K IMDB Movies With Posters](https://www.kaggle.com/rezaunderfit/48k-imdb-movies-with-posters)
7
+
8
+ - Run `create_dataset.ipynb` to create train.csv and valid.csv
9
+ - Run `train.ipynb` to train the model
train/create_dataset.ipynb ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "0fbed7bc",
7
+ "metadata": {
8
+ "ExecuteTime": {
9
+ "end_time": "2021-12-09T16:46:29.851016Z",
10
+ "start_time": "2021-12-09T16:46:29.841794Z"
11
+ },
12
+ "pycharm": {
13
+ "name": "#%%\n"
14
+ }
15
+ },
16
+ "outputs": [],
17
+ "source": [
18
+ "%reload_ext autoreload\n",
19
+ "%autoreload 2"
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "code",
24
+ "execution_count": null,
25
+ "id": "99d6f14d",
26
+ "metadata": {
27
+ "ExecuteTime": {
28
+ "end_time": "2021-12-09T16:46:30.336104Z",
29
+ "start_time": "2021-12-09T16:46:29.852308Z"
30
+ },
31
+ "pycharm": {
32
+ "name": "#%%\n"
33
+ }
34
+ },
35
+ "outputs": [],
36
+ "source": [
37
+ "from pathlib import Path\n",
38
+ "import pandas as pd\n",
39
+ "import shutil\n",
40
+ "from sklearn.model_selection import train_test_split"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "code",
45
+ "execution_count": null,
46
+ "id": "c8fcf96c",
47
+ "metadata": {
48
+ "ExecuteTime": {
49
+ "end_time": "2021-12-09T16:46:30.349125Z",
50
+ "start_time": "2021-12-09T16:46:30.337223Z"
51
+ },
52
+ "code_folding": [],
53
+ "pycharm": {
54
+ "name": "#%%\n"
55
+ }
56
+ },
57
+ "outputs": [],
58
+ "source": [
59
+ "def copy_images(\n",
60
+ " src_dir: Path,\n",
61
+ " des_dir: Path,\n",
62
+ " ids_with_plots: list,\n",
63
+ " delete_existing_files: bool = False,\n",
64
+ "):\n",
65
+ " \"\"\"This function copies a poster to images folder if it's id is present in the ids_with_plots list\"\"\"\n",
66
+ "\n",
67
+ " images_list = []\n",
68
+ " if delete_existing_files:\n",
69
+ " shutil.rmtree(des_dir)\n",
70
+ "\n",
71
+ " des_dir.mkdir(parents=True, exist_ok=True)\n",
72
+ "\n",
73
+ " for f in src_dir.rglob(\"*\"):\n",
74
+ " try:\n",
75
+ " if f.is_file() and f.suffix in [\".jpg\", \".jpeg\", \".png\"]:\n",
76
+ " img_name = f.name\n",
77
+ " id = Path(img_name).stem\n",
78
+ " if id in ids_with_plots:\n",
79
+ " desc_file = des_dir / img_name\n",
80
+ " shutil.copy(f, desc_file)\n",
81
+ " images_list.append((id, img_name))\n",
82
+ " except Exception as e:\n",
83
+ " print(f, e)\n",
84
+ " return images_list"
85
+ ]
86
+ },
87
+ {
88
+ "cell_type": "code",
89
+ "execution_count": null,
90
+ "id": "a34124b2",
91
+ "metadata": {
92
+ "ExecuteTime": {
93
+ "end_time": "2021-12-09T16:46:30.359361Z",
94
+ "start_time": "2021-12-09T16:46:30.350299Z"
95
+ },
96
+ "pycharm": {
97
+ "name": "#%%\n"
98
+ }
99
+ },
100
+ "outputs": [],
101
+ "source": [
102
+ "data_dir = Path(\"datasets\").resolve()\n",
103
+ "images_dir = data_dir / \"images\""
104
+ ]
105
+ },
106
+ {
107
+ "cell_type": "code",
108
+ "execution_count": null,
109
+ "id": "8714ea01",
110
+ "metadata": {
111
+ "ExecuteTime": {
112
+ "end_time": "2021-12-09T16:46:30.781046Z",
113
+ "start_time": "2021-12-09T16:46:30.360608Z"
114
+ },
115
+ "pycharm": {
116
+ "name": "#%%\n"
117
+ }
118
+ },
119
+ "outputs": [],
120
+ "source": [
121
+ "movies_df = pd.read_csv(\n",
122
+ " data_dir / \"IMDb movies.csv\", usecols=[\"imdb_title_id\", \"description\"]\n",
123
+ ")\n",
124
+ "movies_df = movies_df.rename(columns={\"imdb_title_id\": \"id\", \"description\": \"text\"})\n",
125
+ "movies_df.dropna(subset=[\"text\"], inplace=True) # Drop rows where text is empty\n",
126
+ "movies_df.head()\n"
127
+ ]
128
+ },
129
+ {
130
+ "cell_type": "code",
131
+ "execution_count": null,
132
+ "id": "27f7fd94",
133
+ "metadata": {
134
+ "ExecuteTime": {
135
+ "end_time": "2021-12-09T16:46:30.792761Z",
136
+ "start_time": "2021-12-09T16:46:30.781964Z"
137
+ },
138
+ "pycharm": {
139
+ "name": "#%%\n"
140
+ }
141
+ },
142
+ "outputs": [],
143
+ "source": [
144
+ "ids_with_plots = movies_df.id.tolist()"
145
+ ]
146
+ },
147
+ {
148
+ "cell_type": "code",
149
+ "execution_count": null,
150
+ "id": "ebaa042a",
151
+ "metadata": {
152
+ "ExecuteTime": {
153
+ "end_time": "2021-12-09T16:47:04.704390Z",
154
+ "start_time": "2021-12-09T16:46:30.794094Z"
155
+ },
156
+ "pycharm": {
157
+ "name": "#%%\n"
158
+ }
159
+ },
160
+ "outputs": [],
161
+ "source": [
162
+ "images_list = copy_images(data_dir / \"Poster\", images_dir, ids_with_plots)\n",
163
+ "images_list[0]"
164
+ ]
165
+ },
166
+ {
167
+ "cell_type": "code",
168
+ "execution_count": null,
169
+ "id": "17e0a874",
170
+ "metadata": {
171
+ "ExecuteTime": {
172
+ "end_time": "2021-12-09T16:47:04.724427Z",
173
+ "start_time": "2021-12-09T16:47:04.705540Z"
174
+ },
175
+ "pycharm": {
176
+ "name": "#%%\n"
177
+ }
178
+ },
179
+ "outputs": [],
180
+ "source": [
181
+ "images_df = pd.DataFrame(images_list, columns=[\"id\", \"filename\"])\n",
182
+ "images_df.head()"
183
+ ]
184
+ },
185
+ {
186
+ "cell_type": "code",
187
+ "execution_count": null,
188
+ "id": "bb1114e6",
189
+ "metadata": {
190
+ "ExecuteTime": {
191
+ "end_time": "2021-12-09T16:47:04.772775Z",
192
+ "start_time": "2021-12-09T16:47:04.725707Z"
193
+ },
194
+ "pycharm": {
195
+ "name": "#%%\n"
196
+ }
197
+ },
198
+ "outputs": [],
199
+ "source": [
200
+ "data_df = pd.merge(movies_df, images_df, on=[\"id\"])\n",
201
+ "print(len(data_df))\n",
202
+ "data_df"
203
+ ]
204
+ },
205
+ {
206
+ "cell_type": "code",
207
+ "execution_count": null,
208
+ "id": "6790815b",
209
+ "metadata": {
210
+ "ExecuteTime": {
211
+ "end_time": "2021-12-09T16:47:04.796785Z",
212
+ "start_time": "2021-12-09T16:47:04.774932Z"
213
+ },
214
+ "pycharm": {
215
+ "name": "#%%\n"
216
+ }
217
+ },
218
+ "outputs": [],
219
+ "source": [
220
+ "print(len(data_df))\n",
221
+ "data_df.dropna(subset=[\"filename\"], inplace=True)\n",
222
+ "print(len(data_df))"
223
+ ]
224
+ },
225
+ {
226
+ "cell_type": "code",
227
+ "execution_count": null,
228
+ "id": "40c7205d",
229
+ "metadata": {
230
+ "ExecuteTime": {
231
+ "end_time": "2021-12-09T16:47:04.818522Z",
232
+ "start_time": "2021-12-09T16:47:04.798063Z"
233
+ },
234
+ "pycharm": {
235
+ "name": "#%%\n"
236
+ }
237
+ },
238
+ "outputs": [],
239
+ "source": [
240
+ "print(len(data_df))\n",
241
+ "data_df.dropna(subset=[\"text\"], inplace=True)\n",
242
+ "print(len(data_df))"
243
+ ]
244
+ },
245
+ {
246
+ "cell_type": "code",
247
+ "execution_count": null,
248
+ "id": "9a2d142f",
249
+ "metadata": {
250
+ "ExecuteTime": {
251
+ "end_time": "2021-12-09T16:47:04.838450Z",
252
+ "start_time": "2021-12-09T16:47:04.819726Z"
253
+ },
254
+ "pycharm": {
255
+ "name": "#%%\n"
256
+ }
257
+ },
258
+ "outputs": [],
259
+ "source": [
260
+ "print(len(data_df))\n",
261
+ "data_df.drop_duplicates(subset=[\"id\"], inplace=True)\n",
262
+ "print(len(data_df))"
263
+ ]
264
+ },
265
+ {
266
+ "cell_type": "code",
267
+ "execution_count": null,
268
+ "id": "45f4b970",
269
+ "metadata": {
270
+ "ExecuteTime": {
271
+ "end_time": "2021-12-09T16:47:04.971652Z",
272
+ "start_time": "2021-12-09T16:47:04.839618Z"
273
+ },
274
+ "pycharm": {
275
+ "name": "#%%\n"
276
+ }
277
+ },
278
+ "outputs": [],
279
+ "source": [
280
+ "data_df.to_csv(data_dir / \"data.csv\", index=False)"
281
+ ]
282
+ },
283
+ {
284
+ "cell_type": "code",
285
+ "execution_count": null,
286
+ "id": "f8019a02",
287
+ "metadata": {
288
+ "ExecuteTime": {
289
+ "end_time": "2021-12-09T16:47:05.104710Z",
290
+ "start_time": "2021-12-09T16:47:04.972681Z"
291
+ },
292
+ "pycharm": {
293
+ "name": "#%%\n"
294
+ }
295
+ },
296
+ "outputs": [],
297
+ "source": [
298
+ "train_df, valid_df = train_test_split(data_df, test_size=0.1, shuffle=True)\n",
299
+ "train_df.to_csv(data_dir / \"train.csv\", index=False)\n",
300
+ "valid_df.to_csv(data_dir / \"valid.csv\", index=False)\n",
301
+ "print(len(train_df), len(valid_df))"
302
+ ]
303
+ }
304
+ ],
305
+ "metadata": {
306
+ "kernelspec": {
307
+ "display_name": "huggingface",
308
+ "language": "python",
309
+ "name": "huggingface"
310
+ },
311
+ "language_info": {
312
+ "codemirror_mode": {
313
+ "name": "ipython",
314
+ "version": 3
315
+ },
316
+ "file_extension": ".py",
317
+ "mimetype": "text/x-python",
318
+ "name": "python",
319
+ "nbconvert_exporter": "python",
320
+ "pygments_lexer": "ipython3",
321
+ "version": "3.9.7"
322
+ }
323
+ },
324
+ "nbformat": 4,
325
+ "nbformat_minor": 5
326
+ }
train/requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
1
+ --find-links https://download.pytorch.org/whl/torch_stable.html
2
+ pandas==1.3.4
3
+ scikit-learn==1.0.1
4
+ python-box==5.4.1
5
+ transformers==4.12.5
6
+ torch==1.10.0+cu113
7
+ Pillow==8.4.0
train/train.ipynb ADDED
@@ -0,0 +1,474 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "0fbed7bc",
7
+ "metadata": {
8
+ "ExecuteTime": {
9
+ "end_time": "2021-12-09T15:34:14.921553Z",
10
+ "start_time": "2021-12-09T15:34:14.911112Z"
11
+ }
12
+ },
13
+ "outputs": [],
14
+ "source": [
15
+ "%reload_ext autoreload\n",
16
+ "%autoreload 2"
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": null,
22
+ "id": "c4b60ef3",
23
+ "metadata": {
24
+ "ExecuteTime": {
25
+ "end_time": "2021-12-09T15:34:15.961098Z",
26
+ "start_time": "2021-12-09T15:34:14.922771Z"
27
+ },
28
+ "code_folding": []
29
+ },
30
+ "outputs": [],
31
+ "source": [
32
+ "# imports\n",
33
+ "\n",
34
+ "import pandas as pd\n",
35
+ "import os\n",
36
+ "from pathlib import Path\n",
37
+ "from PIL import Image\n",
38
+ "import shutil\n",
39
+ "from logging import root\n",
40
+ "from PIL import Image\n",
41
+ "from pathlib import Path\n",
42
+ "import pandas as pd\n",
43
+ "import torch\n",
44
+ "from torch.utils.data import Dataset\n",
45
+ "from PIL import Image\n",
46
+ "from transformers import (\n",
47
+ " Seq2SeqTrainer,\n",
48
+ " Seq2SeqTrainingArguments,\n",
49
+ " get_linear_schedule_with_warmup,\n",
50
+ " AutoFeatureExtractor,\n",
51
+ " AutoTokenizer,\n",
52
+ " ViTFeatureExtractor,\n",
53
+ " VisionEncoderDecoderModel,\n",
54
+ " default_data_collator,\n",
55
+ ")\n",
56
+ "from transformers.optimization import AdamW\n",
57
+ "\n",
58
+ "from box import Box\n",
59
+ "import inspect\n"
60
+ ]
61
+ },
62
+ {
63
+ "cell_type": "code",
64
+ "execution_count": null,
65
+ "id": "99d6f14d",
66
+ "metadata": {
67
+ "ExecuteTime": {
68
+ "end_time": "2021-12-09T15:34:15.979191Z",
69
+ "start_time": "2021-12-09T15:34:15.962078Z"
70
+ },
71
+ "code_folding": []
72
+ },
73
+ "outputs": [],
74
+ "source": [
75
+ "# custom functions\n",
76
+ "\n",
77
+ "class ImageCaptionDataset(Dataset):\n",
78
+ " def __init__(\n",
79
+ " self, df, feature_extractor, tokenizer, images_dir, max_target_length=128\n",
80
+ " ):\n",
81
+ " self.df = df\n",
82
+ " self.feature_extractor = feature_extractor\n",
83
+ " self.tokenizer = tokenizer\n",
84
+ " self.images_dir = images_dir\n",
85
+ " self.max_target_length = max_target_length\n",
86
+ "\n",
87
+ " def __len__(self):\n",
88
+ " return len(self.df)\n",
89
+ "\n",
90
+ " def __getitem__(self, idx):\n",
91
+ " filename = self.df[\"filename\"][idx]\n",
92
+ " text = self.df[\"text\"][idx]\n",
93
+ " # prepare image (i.e. resize + normalize)\n",
94
+ " image = Image.open(self.images_dir / filename).convert(\"RGB\")\n",
95
+ " pixel_values = self.feature_extractor(image, return_tensors=\"pt\").pixel_values\n",
96
+ " # add labels (input_ids) by encoding the text\n",
97
+ " labels = self.tokenizer(\n",
98
+ " text,\n",
99
+ " padding=\"max_length\",\n",
100
+ " truncation=True,\n",
101
+ " max_length=self.max_target_length,\n",
102
+ " ).input_ids\n",
103
+ " # important: make sure that PAD tokens are ignored by the loss function\n",
104
+ " labels = [\n",
105
+ " label if label != self.tokenizer.pad_token_id else -100 for label in labels\n",
106
+ " ]\n",
107
+ "\n",
108
+ " encoding = {\n",
109
+ " \"pixel_values\": pixel_values.squeeze(),\n",
110
+ " \"labels\": torch.tensor(labels),\n",
111
+ " }\n",
112
+ " return encoding\n",
113
+ "\n",
114
+ "\n",
115
+ "\n",
116
+ "def predict(image, max_length=64, num_beams=4):\n",
117
+ "\n",
118
+ " pixel_values = feature_extractor(images=image, return_tensors=\"pt\").pixel_values\n",
119
+ " pixel_values = pixel_values.to(device)\n",
120
+ "\n",
121
+ " with torch.no_grad():\n",
122
+ " output_ids = model.generate(\n",
123
+ " pixel_values,\n",
124
+ " max_length=max_length,\n",
125
+ " num_beams=num_beams,\n",
126
+ " return_dict_in_generate=True,\n",
127
+ " ).sequences\n",
128
+ "\n",
129
+ " preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)\n",
130
+ " preds = [pred.strip() for pred in preds]\n",
131
+ "\n",
132
+ " return preds\n"
133
+ ]
134
+ },
135
+ {
136
+ "cell_type": "code",
137
+ "execution_count": null,
138
+ "id": "ea66826b",
139
+ "metadata": {
140
+ "ExecuteTime": {
141
+ "end_time": "2021-12-09T15:34:16.042990Z",
142
+ "start_time": "2021-12-09T15:34:15.980557Z"
143
+ }
144
+ },
145
+ "outputs": [],
146
+ "source": [
147
+ "data_dir = Path(\"datasets\").resolve()\n",
148
+ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
149
+ "print(device)"
150
+ ]
151
+ },
152
+ {
153
+ "cell_type": "code",
154
+ "execution_count": null,
155
+ "id": "17cfb2c2",
156
+ "metadata": {
157
+ "ExecuteTime": {
158
+ "end_time": "2021-12-09T15:34:16.058421Z",
159
+ "start_time": "2021-12-09T15:34:16.044111Z"
160
+ }
161
+ },
162
+ "outputs": [],
163
+ "source": [
164
+ "# arguments pertaining to what data we are going to input our model for training and eval.\n",
165
+ "\n",
166
+ "data_training_args = {\n",
167
+ " # The maximum total sequence length for target text after tokenization. Sequences longer than this will be truncated, sequences shorter will be padded.\n",
168
+ " \"max_target_length\": 64,\n",
169
+ "\n",
170
+ " # Number of beams to use for evaluation. This argument will be passed to model.generate which is used during evaluate and predict.\n",
171
+ " \"num_beams\": 4,\n",
172
+ "\n",
173
+ " # Folder with all the images\n",
174
+ " \"images_dir\": data_dir / \"images\",\n",
175
+ "}\n",
176
+ "\n",
177
+ "data_training_args = Box(data_training_args)"
178
+ ]
179
+ },
180
+ {
181
+ "cell_type": "code",
182
+ "execution_count": null,
183
+ "id": "adc4839a",
184
+ "metadata": {
185
+ "ExecuteTime": {
186
+ "end_time": "2021-12-09T15:34:16.073242Z",
187
+ "start_time": "2021-12-09T15:34:16.059354Z"
188
+ }
189
+ },
190
+ "outputs": [],
191
+ "source": [
192
+ "# arguments pertaining to which model/config/tokenizer we are going to fine-tune from.\n",
193
+ "\n",
194
+ "model_args = {\n",
195
+ "\n",
196
+ " # Path to pretrained model or model identifier from huggingface.co/models\"\n",
197
+ " \"encoder_model_name_or_path\": \"google/vit-base-patch16-224-in21k\",\n",
198
+ "\n",
199
+ " # Path to pretrained model or model identifier from huggingface.co/models\"\n",
200
+ " \"decoder_model_name_or_path\": \"gpt2\",\n",
201
+ "\n",
202
+ " # If set to int > 0, all ngrams of that size can only occur once.\n",
203
+ " \"no_repeat_ngram_size\": 3,\n",
204
+ "\n",
205
+ " # Exponential penalty to the length that will be used by default in the generate method of the model.\n",
206
+ " \"length_penalty\": 2.0,\n",
207
+ "}\n",
208
+ "\n",
209
+ "model_args = Box(model_args)"
210
+ ]
211
+ },
212
+ {
213
+ "cell_type": "code",
214
+ "execution_count": null,
215
+ "id": "22b8c9e3",
216
+ "metadata": {
217
+ "ExecuteTime": {
218
+ "end_time": "2021-12-09T15:34:16.089201Z",
219
+ "start_time": "2021-12-09T15:34:16.074223Z"
220
+ }
221
+ },
222
+ "outputs": [],
223
+ "source": [
224
+ "# arguments pertaining to Trainer class. Refer: https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments\n",
225
+ "\n",
226
+ "training_args = {\n",
227
+ " \"num_train_epochs\": 5,\n",
228
+ " \"per_device_train_batch_size\": 32,\n",
229
+ " \"per_device_eval_batch_size\": 32,\n",
230
+ " \"output_dir\": \"output_dir\",\n",
231
+ " \"do_train\": True,\n",
232
+ " \"do_eval\": True,\n",
233
+ " \"fp16\": True,\n",
234
+ " \"learning_rate\": 1e-5,\n",
235
+ " \"load_best_model_at_end\": True,\n",
236
+ " \"evaluation_strategy\": \"epoch\",\n",
237
+ " \"save_strategy\": \"epoch\",\n",
238
+ " \"report_to\": \"none\"\n",
239
+ "}\n",
240
+ "\n",
241
+ "seq2seq_training_args = Seq2SeqTrainingArguments(**training_args)"
242
+ ]
243
+ },
244
+ {
245
+ "cell_type": "code",
246
+ "execution_count": null,
247
+ "id": "d0023eac",
248
+ "metadata": {
249
+ "ExecuteTime": {
250
+ "end_time": "2021-12-09T15:34:37.844396Z",
251
+ "start_time": "2021-12-09T15:34:16.090085Z"
252
+ }
253
+ },
254
+ "outputs": [],
255
+ "source": [
256
+ "feature_extractor = ViTFeatureExtractor.from_pretrained(\n",
257
+ " model_args.encoder_model_name_or_path\n",
258
+ ")\n",
259
+ "tokenizer = AutoTokenizer.from_pretrained(\n",
260
+ " model_args.decoder_model_name_or_path, use_fast=True\n",
261
+ ")\n",
262
+ "tokenizer.pad_token = tokenizer.eos_token\n",
263
+ "\n",
264
+ "model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(\n",
265
+ " model_args.encoder_model_name_or_path, model_args.decoder_model_name_or_path\n",
266
+ ")\n",
267
+ "\n",
268
+ "# set special tokens used for creating the decoder_input_ids from the labels\n",
269
+ "model.config.decoder_start_token_id = tokenizer.bos_token_id\n",
270
+ "model.config.pad_token_id = tokenizer.pad_token_id\n",
271
+ "# make sure vocab size is set correctly\n",
272
+ "model.config.vocab_size = model.config.decoder.vocab_size\n",
273
+ "\n",
274
+ "# set beam search parameters\n",
275
+ "model.config.eos_token_id = tokenizer.sep_token_id\n",
276
+ "model.config.max_length = data_training_args.max_target_length\n",
277
+ "model.config.no_repeat_ngram_size = model_args.no_repeat_ngram_size\n",
278
+ "model.config.length_penalty = model_args.length_penalty\n",
279
+ "model.config.num_beams = data_training_args.num_beams\n",
280
+ "model.decoder.resize_token_embeddings(len(tokenizer))\n"
281
+ ]
282
+ },
283
+ {
284
+ "cell_type": "code",
285
+ "execution_count": null,
286
+ "id": "6428ea08",
287
+ "metadata": {
288
+ "ExecuteTime": {
289
+ "end_time": "2021-12-09T15:34:37.933804Z",
290
+ "start_time": "2021-12-09T15:34:37.845607Z"
291
+ }
292
+ },
293
+ "outputs": [],
294
+ "source": [
295
+ "train_df = pd.read_csv(data_dir / \"train.csv\")\n",
296
+ "valid_df = pd.read_csv(data_dir / \"valid.csv\")\n",
297
+ "\n",
298
+ "train_dataset = ImageCaptionDataset(\n",
299
+ " df=train_df,\n",
300
+ " feature_extractor=feature_extractor,\n",
301
+ " tokenizer=tokenizer,\n",
302
+ " images_dir=data_training_args.images_dir,\n",
303
+ " max_target_length=data_training_args.max_target_length,\n",
304
+ ")\n",
305
+ "eval_dataset = ImageCaptionDataset(\n",
306
+ " df=valid_df,\n",
307
+ " feature_extractor=feature_extractor,\n",
308
+ " tokenizer=tokenizer,\n",
309
+ " images_dir=data_training_args.images_dir,\n",
310
+ " max_target_length=data_training_args.max_target_length,\n",
311
+ ")\n",
312
+ "\n",
313
+ "print(f\"Number of training examples: {len(train_dataset)}\")\n",
314
+ "print(f\"Number of validation examples: {len(eval_dataset)}\")"
315
+ ]
316
+ },
317
+ {
318
+ "cell_type": "code",
319
+ "execution_count": null,
320
+ "id": "c8e492a1",
321
+ "metadata": {
322
+ "ExecuteTime": {
323
+ "end_time": "2021-12-09T15:34:37.971630Z",
324
+ "start_time": "2021-12-09T15:34:37.935339Z"
325
+ }
326
+ },
327
+ "outputs": [],
328
+ "source": [
329
+ "# Let's verify an example from the training dataset:\n",
330
+ "\n",
331
+ "encoding = train_dataset[0]\n",
332
+ "for k,v in encoding.items():\n",
333
+ " print(k, v.shape)"
334
+ ]
335
+ },
336
+ {
337
+ "cell_type": "code",
338
+ "execution_count": null,
339
+ "id": "edb4e7a6",
340
+ "metadata": {
341
+ "ExecuteTime": {
342
+ "end_time": "2021-12-09T15:34:38.006980Z",
343
+ "start_time": "2021-12-09T15:34:37.972483Z"
344
+ }
345
+ },
346
+ "outputs": [],
347
+ "source": [
348
+ "# We can also check the original image and decode the labels:\n",
349
+ "image = Image.open(data_training_args.images_dir / train_df[\"filename\"][0]).convert(\"RGB\")\n",
350
+ "image"
351
+ ]
352
+ },
353
+ {
354
+ "cell_type": "code",
355
+ "execution_count": null,
356
+ "id": "25f2cae7",
357
+ "metadata": {
358
+ "ExecuteTime": {
359
+ "end_time": "2021-12-09T15:34:38.031745Z",
360
+ "start_time": "2021-12-09T15:34:38.008027Z"
361
+ }
362
+ },
363
+ "outputs": [],
364
+ "source": [
365
+ "labels = encoding[\"labels\"]\n",
366
+ "labels[labels == -100] = tokenizer.pad_token_id\n",
367
+ "label_str = tokenizer.decode(labels, skip_special_tokens=True)\n",
368
+ "print(label_str)\n"
369
+ ]
370
+ },
371
+ {
372
+ "cell_type": "code",
373
+ "execution_count": null,
374
+ "id": "b7a009d3",
375
+ "metadata": {
376
+ "ExecuteTime": {
377
+ "end_time": "2021-12-09T15:34:38.049539Z",
378
+ "start_time": "2021-12-09T15:34:38.032749Z"
379
+ }
380
+ },
381
+ "outputs": [],
382
+ "source": [
383
+ "optimizer = AdamW(model.parameters(), lr=seq2seq_training_args.learning_rate)\n",
384
+ "\n",
385
+ "steps_per_epoch = len(train_dataset) // seq2seq_training_args.per_device_train_batch_size\n",
386
+ "num_training_steps = steps_per_epoch * seq2seq_training_args.num_train_epochs\n",
387
+ "\n",
388
+ "lr_scheduler = get_linear_schedule_with_warmup(\n",
389
+ " optimizer,\n",
390
+ " num_warmup_steps=seq2seq_training_args.warmup_steps,\n",
391
+ " num_training_steps=num_training_steps,\n",
392
+ ")\n",
393
+ "\n",
394
+ "optimizers = (optimizer, lr_scheduler)"
395
+ ]
396
+ },
397
+ {
398
+ "cell_type": "code",
399
+ "execution_count": null,
400
+ "id": "f2f477b2",
401
+ "metadata": {
402
+ "ExecuteTime": {
403
+ "start_time": "2021-12-09T15:34:14.944Z"
404
+ }
405
+ },
406
+ "outputs": [],
407
+ "source": [
408
+ "trainer = Seq2SeqTrainer(\n",
409
+ " model=model,\n",
410
+ " optimizers=optimizers,\n",
411
+ " tokenizer=feature_extractor,\n",
412
+ " args=seq2seq_training_args,\n",
413
+ " train_dataset=train_dataset,\n",
414
+ " eval_dataset=eval_dataset,\n",
415
+ " data_collator=default_data_collator,\n",
416
+ ")\n",
417
+ "\n",
418
+ "trainer.train()"
419
+ ]
420
+ },
421
+ {
422
+ "cell_type": "code",
423
+ "execution_count": null,
424
+ "id": "f08d2b7c",
425
+ "metadata": {
426
+ "ExecuteTime": {
427
+ "end_time": "2021-12-09T16:24:49.096274Z",
428
+ "start_time": "2021-12-09T16:24:49.096246Z"
429
+ }
430
+ },
431
+ "outputs": [],
432
+ "source": [
433
+ "test_img = \"../examples/tt7991608-red-notice.jpg\"\n",
434
+ "with Image.open(test_img) as image:\n",
435
+ " preds = predict(\n",
436
+ " image, max_length=data_training_args.max_target_length, num_beams=data_training_args.num_beams\n",
437
+ " )\n",
438
+ "\n",
439
+ "# Uncomment to display the test image in a jupyter notebook\n",
440
+ "# display(image)\n",
441
+ "print(preds[0])"
442
+ ]
443
+ },
444
+ {
445
+ "cell_type": "code",
446
+ "execution_count": null,
447
+ "id": "ecf21225",
448
+ "metadata": {},
449
+ "outputs": [],
450
+ "source": []
451
+ }
452
+ ],
453
+ "metadata": {
454
+ "kernelspec": {
455
+ "display_name": "huggingface",
456
+ "language": "python",
457
+ "name": "huggingface"
458
+ },
459
+ "language_info": {
460
+ "codemirror_mode": {
461
+ "name": "ipython",
462
+ "version": 3
463
+ },
464
+ "file_extension": ".py",
465
+ "mimetype": "text/x-python",
466
+ "name": "python",
467
+ "nbconvert_exporter": "python",
468
+ "pygments_lexer": "ipython3",
469
+ "version": "3.9.7"
470
+ }
471
+ },
472
+ "nbformat": 4,
473
+ "nbformat_minor": 5
474
+ }