Joshua Lochner committed on
Commit
366d154
•
1 Parent(s): ee58f38

Use exact same structure as example

.DS_Store ADDED
Binary file (8.2 kB).
.gitattributes CHANGED
@@ -17,12 +17,15 @@
17
  *.pt filter=lfs diff=lfs merge=lfs -text
18
  *.pth filter=lfs diff=lfs merge=lfs -text
19
  *.rar filter=lfs diff=lfs merge=lfs -text
20
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
21
  *.tar.* filter=lfs diff=lfs merge=lfs -text
22
  *.tflite filter=lfs diff=lfs merge=lfs -text
23
  *.tgz filter=lfs diff=lfs merge=lfs -text
24
- *.wasm filter=lfs diff=lfs merge=lfs -text
25
  *.xz filter=lfs diff=lfs merge=lfs -text
26
  *.zip filter=lfs diff=lfs merge=lfs -text
27
  *.zstandard filter=lfs diff=lfs merge=lfs -text
28
  *tfevents* filter=lfs diff=lfs merge=lfs -text
17
  *.pt filter=lfs diff=lfs merge=lfs -text
18
  *.pth filter=lfs diff=lfs merge=lfs -text
19
  *.rar filter=lfs diff=lfs merge=lfs -text
20
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
21
  *.tar.* filter=lfs diff=lfs merge=lfs -text
22
  *.tflite filter=lfs diff=lfs merge=lfs -text
23
  *.tgz filter=lfs diff=lfs merge=lfs -text
 
24
  *.xz filter=lfs diff=lfs merge=lfs -text
25
  *.zip filter=lfs diff=lfs merge=lfs -text
26
  *.zstandard filter=lfs diff=lfs merge=lfs -text
27
  *tfevents* filter=lfs diff=lfs merge=lfs -text
28
+ variables.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
29
+ variables filter=lfs diff=lfs merge=lfs -text
30
+ /Users/mervenoyan/Desktop/seg/pet-segmentation/variables/variables.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
31
+ variables.index filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,5 +1,32 @@
1
  ---
2
  tags:
3
- - text-classification
4
  ---
5
- Test
1
  ---
2
  tags:
3
+ - image-segmentation
4
+ - generic
5
+ library_name: generic
6
+ dataset:
7
+ - oxford-iiit pets
8
+ widget:
9
+ - src: https://huggingface.co/datasets/mishig/sample_images/resolve/main/cat-1.jpg
10
+ example_title: Kedis
11
+ - src: https://huggingface.co/datasets/mishig/sample_images/resolve/main/cat-2.jpg
12
+ example_title: Cat in a Crate
13
+ - src: https://huggingface.co/datasets/mishig/sample_images/resolve/main/cat-3.jpg
14
+ example_title: Two Cats Chilling
15
+ license: cc0-1.0
16
  ---
17
+ ## Keras semantic segmentation models on the 🤗 Hub! 🐶 🐕 🐩
18
+ Full credits go to [François Chollet](https://twitter.com/fchollet).
19
+
20
+ This repository contains the model from [this notebook on segmenting pets using a U-Net-like architecture](https://keras.io/examples/vision/oxford_pets_image_segmentation/). We've changed the inference part to enable the segmentation widget on the Hub (see ```pipeline.py```).
21
+
22
+ ## Background Information
23
+
24
+ The image classification task tells us which class is assigned to an image, and the object detection task draws a bounding box around an object in an image. But what if we want to know about the shapes in an image? Segmentation models help us segment images and reveal their shapes. Segmentation has several variants, including panoptic segmentation, instance segmentation and semantic segmentation. This post is about hosting your Keras semantic segmentation models on the Hub.
25
+ Semantic segmentation models classify pixels, meaning they assign a class (e.g. cat or dog) to each pixel. The output of a model looks like the following.
26
+ ![Raw Output](./raw_output.jpg)
27
+ We need to get the best prediction for every pixel.
28
+ ![Mask](./mask.jpg)
29
+ This is still not readable. We have to split it into separate binary masks, one per class, and convert each mask to a readable format by encoding it as base64. We return a list of dicts, and for each dictionary we have the label itself, the base64 string and a score (semantic segmentation models don't return a score, so we return 1.0 in this case). You can find the full implementation in ```pipeline.py```.
30
+ ![Binary Mask](./binary_mask.jpg)
31
+ Now that you know the output expected from the model, you can host your Keras segmentation models (and other semantic segmentation models) in a similar fashion. Try it yourself and host your segmentation models!
32
+ ![Segmented Cat](./hircin_the_cat.png)
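As a rough illustration of the mask conversion the README describes, here is a minimal sketch assuming a `(height, width, num_classes)` prediction array; the function name is illustrative, and the repository's actual implementation is the one in ```pipeline.py```:

```python
import base64
import io

import numpy as np
from PIL import Image


def masks_to_widget_output(pred_mask: np.ndarray) -> list:
    """Turn a (height, width, num_classes) prediction into the widget format:
    one dict per class with a label, a base64-encoded PNG mask and a score of 1.0."""
    best_class = pred_mask.argmax(axis=-1)  # best prediction for every pixel
    results = []
    for cls in range(pred_mask.shape[-1]):
        # binary mask for this class, scaled to 0/255 for an 8-bit grayscale PNG
        binary = (best_class == cls).astype(np.uint8) * 255
        with io.BytesIO() as out:
            Image.fromarray(binary, mode="L").save(out, format="PNG")
            mask_b64 = base64.b64encode(out.getvalue()).decode("utf-8")
        results.append({"label": f"LABEL_{cls}", "mask": mask_b64, "score": 1.0})
    return results
```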
added_tokens.json DELETED
@@ -1 +0,0 @@
1
- {"NUMBER_PERCENTAGE_TOKEN": 30525, "HYPHENATED_URL_TOKEN": 30524, "START_SELFPROMO_TOKEN": 30536, "START_SPONSOR_TOKEN": 30534, "PROFANITY_TOKEN": 30532, "[Laughter]": 30531, "BETWEEN_SEGMENTS_TOKEN": 30540, "NUMBER_TOKEN": 30526, "SHORT_HYPHENATED_TOKEN": 30527, "END_SPONSOR_TOKEN": 30535, "LONG_WORD_TOKEN": 30528, "EXTRACT_SEGMENTS: ": 30522, "END_INTERACTION_TOKEN": 30539, "[Applause]": 30530, "START_INTERACTION_TOKEN": 30538, "END_SELFPROMO_TOKEN": 30537, "URL_TOKEN": 30523, "NO_SEGMENT_TOKEN": 30533, "[Music]": 30529}
 
binary_mask.jpg ADDED
config.json CHANGED
@@ -2,15 +2,6 @@
2
  "id2label": {
3
  "0": 0,
4
  "1": 1,
5
- "2": 2,
6
- "3": 3
7
- },
8
- "label2id": {
9
- "0": 0,
10
- "1": 1,
11
- "2": 2,
12
- "3": 3
13
- },
14
- "model_type": "bert",
15
- "vocab_size": 30541
16
- }
2
  "id2label": {
3
  "0": 0,
4
  "1": 1,
5
+ "2": 2
6
+ }
7
+ }
custom_pipeline.py ADDED
@@ -0,0 +1,333 @@
1
+ import youtube_transcript_api2
2
+ import json
3
+ import re
4
+ import requests
5
+ from transformers import (
6
+ AutoModelForSequenceClassification,
7
+ AutoTokenizer,
8
+ TextClassificationPipeline,
9
+ )
10
+ from typing import Any, Dict, List
11
+
12
+ CATEGORIES = [None, 'SPONSOR', 'SELFPROMO', 'INTERACTION']
13
+
14
+ PROFANITY_RAW = '[ __ ]' # How YouTube transcribes profanity
15
+ PROFANITY_CONVERTED = '*****' # Safer version for tokenizing
16
+
17
+ NUM_DECIMALS = 3
18
+
19
+ # https://www.fincher.org/Utilities/CountryLanguageList.shtml
20
+ # https://lingohub.com/developers/supported-locales/language-designators-with-regions
21
+ LANGUAGE_PREFERENCE_LIST = ['en-GB', 'en-US', 'en-CA', 'en-AU', 'en-NZ', 'en-ZA',
22
+ 'en-IE', 'en-IN', 'en-JM', 'en-BZ', 'en-TT', 'en-PH', 'en-ZW',
23
+ 'en']
24
+
25
+
26
+ def parse_transcript_json(json_data, granularity):
27
+ assert json_data['wireMagic'] == 'pb3'
28
+
29
+ assert granularity in ('word', 'chunk')
30
+
31
+ # TODO remove bracketed words?
32
+ # (kiss smacks)
33
+ # (upbeat music)
34
+ # [text goes here]
35
+
36
+ # Some manual transcripts aren't that well formatted... but do have punctuation
37
+ # https://www.youtube.com/watch?v=LR9FtWVjk2c
38
+
39
+ parsed_transcript = []
40
+
41
+ events = json_data['events']
42
+
43
+ for event_index, event in enumerate(events):
44
+ segments = event.get('segs')
45
+ if not segments:
46
+ continue
47
+
48
+ # This value is known (when phrase appears on screen)
49
+ start_ms = event['tStartMs']
50
+ total_characters = 0
51
+
52
+ new_segments = []
53
+ for seg in segments:
54
+ # Replace \n, \t, etc. with space
55
+ text = ' '.join(seg['utf8'].split())
56
+
57
+ # Remove zero-width spaces and strip trailing and leading whitespace
58
+ text = text.replace('\u200b', '').replace('\u200c', '').replace(
59
+ '\u200d', '').replace('\ufeff', '').strip()
60
+
61
+ # Alternatively,
62
+ # text = text.encode('ascii', 'ignore').decode()
63
+
64
+ # Needed for auto-generated transcripts
65
+ text = text.replace(PROFANITY_RAW, PROFANITY_CONVERTED)
66
+
67
+ if not text:
68
+ continue
69
+
70
+ offset_ms = seg.get('tOffsetMs', 0)
71
+
72
+ new_segments.append({
73
+ 'text': text,
74
+ 'start': round((start_ms + offset_ms)/1000, NUM_DECIMALS)
75
+ })
76
+
77
+ total_characters += len(text)
78
+
79
+ if not new_segments:
80
+ continue
81
+
82
+ if event_index < len(events) - 1:
83
+ next_start_ms = events[event_index + 1]['tStartMs']
84
+ total_event_duration_ms = min(
85
+ event.get('dDurationMs', float('inf')), next_start_ms - start_ms)
86
+ else:
87
+ total_event_duration_ms = event.get('dDurationMs', 0)
88
+
89
+ # Ensure duration is non-negative
90
+ total_event_duration_ms = max(total_event_duration_ms, 0)
91
+
92
+ avg_seconds_per_character = (
93
+ total_event_duration_ms/total_characters)/1000
94
+
95
+ num_char_count = 0
96
+ for seg_index, seg in enumerate(new_segments):
97
+ num_char_count += len(seg['text'])
98
+
99
+ # Estimate segment end
100
+ seg_end = seg['start'] + \
101
+ (num_char_count * avg_seconds_per_character)
102
+
103
+ if seg_index < len(new_segments) - 1:
104
+ # Do not allow longer than next
105
+ seg_end = min(seg_end, new_segments[seg_index+1]['start'])
106
+
107
+ seg['end'] = round(seg_end, NUM_DECIMALS)
108
+ parsed_transcript.append(seg)
109
+
110
+ final_parsed_transcript = []
111
+ for i in range(len(parsed_transcript)):
112
+
113
+ word_level = granularity == 'word'
114
+ if word_level:
115
+ split_text = parsed_transcript[i]['text'].split()
116
+ elif granularity == 'chunk':
117
+ # Split on space after punctuation
118
+ split_text = re.split(
119
+ r'(?<=[.!?,-;])\s+', parsed_transcript[i]['text'])
120
+ if len(split_text) == 1:
121
+ split_on_whitespace = parsed_transcript[i]['text'].split()
122
+
123
+ if len(split_on_whitespace) >= 8: # Too many words
124
+ # Rather split on whitespace instead of punctuation
125
+ split_text = split_on_whitespace
126
+ else:
127
+ word_level = True
128
+ else:
129
+ raise ValueError('Unknown granularity')
130
+
131
+ segment_end = parsed_transcript[i]['end']
132
+ if i < len(parsed_transcript) - 1:
133
+ segment_end = min(segment_end, parsed_transcript[i+1]['start'])
134
+
135
+ segment_duration = segment_end - parsed_transcript[i]['start']
136
+
137
+ num_chars_in_text = sum(map(len, split_text))
138
+
139
+ num_char_count = 0
140
+ current_offset = 0
141
+ for s in split_text:
142
+ num_char_count += len(s)
143
+
144
+ next_offset = (num_char_count/num_chars_in_text) * segment_duration
145
+
146
+ word_start = round(
147
+ parsed_transcript[i]['start'] + current_offset, NUM_DECIMALS)
148
+ word_end = round(
149
+ parsed_transcript[i]['start'] + next_offset, NUM_DECIMALS)
150
+
151
+ # Make the reasonable assumption that min wps is 1.5
152
+ final_parsed_transcript.append({
153
+ 'text': s,
154
+ 'start': word_start,
155
+ 'end': min(word_end, word_start + 1.5) if word_level else word_end
156
+ })
157
+ current_offset = next_offset
158
+
159
+ return final_parsed_transcript
160
+
161
+
162
+ def list_transcripts(video_id):
163
+ try:
164
+ return youtube_transcript_api2.YouTubeTranscriptApi.list_transcripts(video_id)
165
+ except json.decoder.JSONDecodeError:
166
+ return None
167
+
168
+
169
+ WORDS_TO_REMOVE = [
170
+ '[Music]',
171
+ '[Applause]',
172
+ '[Laughter]'
173
+ ]
174
+
175
+
176
+ def get_words(video_id, transcript_type='auto', fallback='manual', filter_words_to_remove=True, granularity='word'):
177
+ """Get parsed video transcript with caching system
178
+ returns None if not processed yet and process is False
179
+ """
180
+
181
+ raw_transcript_json = None
182
+ try:
183
+ transcript_list = list_transcripts(video_id)
184
+
185
+ if transcript_list is not None:
186
+ if transcript_type == 'manual':
187
+ ts = transcript_list.find_manually_created_transcript(
188
+ LANGUAGE_PREFERENCE_LIST)
189
+ else:
190
+ ts = transcript_list.find_generated_transcript(
191
+ LANGUAGE_PREFERENCE_LIST)
192
+ raw_transcript = ts._http_client.get(
193
+ f'{ts._url}&fmt=json3').content
194
+ if raw_transcript:
195
+ raw_transcript_json = json.loads(raw_transcript)
196
+
197
+ except (youtube_transcript_api2.TooManyRequests, youtube_transcript_api2.YouTubeRequestFailed):
198
+ raise # Cannot recover from these errors and do not mark as empty transcript
199
+
200
+ except requests.exceptions.RequestException: # Can recover
201
+ return get_words(video_id, transcript_type, fallback, granularity=granularity)
202
+
203
+ except youtube_transcript_api2.CouldNotRetrieveTranscript: # Retrying won't solve
204
+ pass # Mark as empty transcript
205
+
206
+ except json.decoder.JSONDecodeError:
207
+ return get_words(video_id, transcript_type, fallback, granularity=granularity)
208
+
209
+ if not raw_transcript_json and fallback is not None:
210
+ return get_words(video_id, transcript_type=fallback, fallback=None, granularity=granularity)
211
+
212
+ if raw_transcript_json:
213
+ processed_transcript = parse_transcript_json(
214
+ raw_transcript_json, granularity)
215
+ if filter_words_to_remove:
216
+ processed_transcript = list(
217
+ filter(lambda x: x['text'] not in WORDS_TO_REMOVE, processed_transcript))
218
+ else:
219
+ processed_transcript = raw_transcript_json # Either None or []
220
+
221
+ return processed_transcript
222
+
223
+
224
+ def word_start(word):
225
+ return word['start']
226
+
227
+
228
+ def word_end(word):
229
+ return word.get('end', word['start'])
230
+
231
+
232
+ def extract_segment(words, start, end, map_function=None):
233
+ """Extracts all words with time in [start, end]"""
234
+
235
+ a = max(binary_search_below(words, 0, len(words), start), 0)
236
+ b = min(binary_search_above(words, -1, len(words) - 1, end) + 1, len(words))
237
+
238
+ to_transform = map_function is not None and callable(map_function)
239
+
240
+ return [
241
+ map_function(words[i]) if to_transform else words[i] for i in range(a, b)
242
+ ]
243
+
244
+
245
+ def avg(*items):
246
+ return sum(items)/len(items)
247
+
248
+
249
+ def binary_search_below(transcript, start_index, end_index, time):
250
+ if start_index >= end_index:
251
+ return end_index
252
+
253
+ middle_index = (start_index + end_index) // 2
254
+ middle = transcript[middle_index]
255
+ middle_time = avg(word_start(middle), word_end(middle))
256
+
257
+ if time <= middle_time:
258
+ return binary_search_below(transcript, start_index, middle_index, time)
259
+ else:
260
+ return binary_search_below(transcript, middle_index + 1, end_index, time)
261
+
262
+
263
+ def binary_search_above(transcript, start_index, end_index, time):
264
+ if start_index >= end_index:
265
+ return end_index
266
+
267
+ middle_index = (start_index + end_index + 1) // 2
268
+ middle = transcript[middle_index]
269
+ middle_time = avg(word_start(middle), word_end(middle))
270
+
271
+ if time >= middle_time:
272
+ return binary_search_above(transcript, middle_index, end_index, time)
273
+ else:
274
+ return binary_search_above(transcript, start_index, middle_index - 1, time)
275
+
276
+
277
+ class SponsorBlockClassificationPipeline(TextClassificationPipeline):
278
+ def __init__(self, model, tokenizer):
279
+ super().__init__(model=model, tokenizer=tokenizer, return_all_scores=True)
280
+
281
+ def preprocess(self, video, **tokenizer_kwargs):
282
+
283
+ words = get_words(video['video_id'])
284
+ segment_words = extract_segment(words, video['start'], video['end'])
285
+ text = ' '.join(x['text'] for x in segment_words)
286
+
287
+ model_inputs = self.tokenizer(
288
+ text, return_tensors=self.framework, **tokenizer_kwargs)
289
+ return {'video': video, 'model_inputs': model_inputs}
290
+
291
+ def _forward(self, data):
292
+ model_outputs = self.model(**data['model_inputs'])
293
+ return {'video': data['video'], 'model_outputs': model_outputs}
294
+
295
+ def postprocess(self, data, function_to_apply=None, return_all_scores=False):
296
+ model_outputs = data['model_outputs']
297
+
298
+ results = super().postprocess(model_outputs, function_to_apply, return_all_scores)
299
+
300
+ for result in results:
301
+ result['label_text'] = CATEGORIES[result['label']]
302
+
303
+ return results # {**data['video'], 'result': results}
304
+
305
+
306
+ # model_id = "Xenova/sponsorblock-classifier-v2"
307
+ # model = AutoModelForSequenceClassification.from_pretrained(model_id)
308
+ # tokenizer = AutoTokenizer.from_pretrained(model_id)
309
+
310
+ # pl = SponsorBlockClassificationPipeline(model=model, tokenizer=tokenizer)
311
+ data = [{
312
+ 'video_id': 'pqh4LfPeCYs',
313
+ 'start': 835.933,
314
+ 'end': 927.581,
315
+ 'category': 'sponsor'
316
+ }]
317
+ # print(pl(data))
318
+
319
+ # MODEL_ID = "Xenova/sponsorblock-classifier-v2"
320
+ class PreTrainedPipeline():
321
+ def __init__(self, path: str):
322
+ # load the model
323
+ self.model = AutoModelForSequenceClassification.from_pretrained(path)
324
+ self.tokenizer = AutoTokenizer.from_pretrained(path)
325
+ self.pipeline = SponsorBlockClassificationPipeline(
326
+ model=self.model, tokenizer=self.tokenizer)
327
+
328
+ def __call__(self, inputs: str) -> List[Dict[str, Any]]:
329
+ json_data = json.loads(inputs)
330
+ return self.pipeline(json_data)
331
+
332
+ # a = PreTrainedPipeline('Xenova/sponsorblock-classifier-v2')(json.dumps(data))
333
+ # print(a)
hircin_the_cat.png ADDED
image/.DS_Store ADDED
Binary file (6.15 kB).
image/binary_mask.jpg ADDED
image/mask.jpg ADDED
image/raw_output.jpg ADDED
rng_state.pth → keras_metadata.pb RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:341b9fac5e4ef18cdbcc0f24ef9702f00339086845194584ef728d46e5bb3aac
3
- size 14439
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e7844f7cdd94bacbd2dceb2000b172c07e3a3db345cd29ea0d51d66107ff28e9
3
+ size 563033
mask.jpg ADDED
pipeline.py CHANGED
@@ -1,14 +1,78 @@
 
1
  from typing import Any, Dict, List
2
3
  class PreTrainedPipeline():
4
  def __init__(self, path: str):
5
  # load the model
6
- self.model = None
7
 
8
- def __call__(self, inputs: str) -> List[Dict[str, Any]]:
 
9
 
10
- return [{
11
- "label": 0,
12
- "mask": 1,
 
13
  "score": 1.0,
14
- }]
 
1
+ import json
2
  from typing import Any, Dict, List
3
 
4
+ import tensorflow as tf
5
+ from tensorflow import keras
6
+ import base64
7
+ import io
8
+ import os
9
+ import numpy as np
10
+ from PIL import Image
11
+
12
+
13
+
14
  class PreTrainedPipeline():
15
  def __init__(self, path: str):
16
  # load the model
17
+ self.model = keras.models.load_model(os.path.join(path, "tf_model.h5"))
18
+
19
+ def __call__(self, inputs: "Image.Image")-> List[Dict[str, Any]]:
20
+
21
+ # convert img to numpy array, resize and normalize to make the prediction
22
+ img = np.array(inputs)
23
+
24
+ im = tf.image.resize(img, (128, 128))
25
+ im = tf.cast(im, tf.float32) / 255.0
26
+ pred_mask = self.model.predict(im[tf.newaxis, ...])
27
+
28
+ # take the best performing class for each pixel
29
+ # the output of argmax looks like this [[1, 2, 0], ...]
30
+ pred_mask_arg = tf.argmax(pred_mask, axis=-1)
31
+
32
+ labels = []
33
+
34
+ # convert the prediction mask into binary masks for each class
35
+ binary_masks = {}
36
+ mask_codes = {}
37
+
38
+ # when we take tf.argmax() over pred_mask, it becomes a tensor object
39
+ # the shape becomes TensorShape object, looking like this TensorShape([128])
40
+ # we need to take get shape, convert to list and take the best one
41
+
42
+ rows = pred_mask_arg[0][1].get_shape().as_list()[0]
43
+ cols = pred_mask_arg[0][2].get_shape().as_list()[0]
44
+
45
+ for cls in range(pred_mask.shape[-1]):
46
+
47
+ binary_masks[f"mask_{cls}"] = np.zeros(shape = (pred_mask.shape[1], pred_mask.shape[2])) #create masks for each class
48
+
49
+ for row in range(rows):
50
+
51
+ for col in range(cols):
52
+
53
+ if pred_mask_arg[0][row][col] == cls:
54
+
55
+ binary_masks[f"mask_{cls}"][row][col] = 1
56
+ else:
57
+ binary_masks[f"mask_{cls}"][row][col] = 0
58
+
59
+ mask = binary_masks[f"mask_{cls}"]
60
+ mask *= 255
61
+ img = Image.fromarray(mask.astype(np.uint8), mode="L")
62
+
63
+ # we need to make it readable for the widget
64
+ with io.BytesIO() as out:
65
+ img.save(out, format="PNG")
66
+ png_string = out.getvalue()
67
+ mask = base64.b64encode(png_string).decode("utf-8")
68
 
69
+ mask_codes[f"mask_{cls}"] = mask
70
+
71
 
72
+ # widget needs the below format, for each class we return label and mask string
73
+ labels.append({
74
+ "label": f"LABEL_{cls}",
75
+ "mask": mask_codes[f"mask_{cls}"],
76
  "score": 1.0,
77
+ })
78
+ return labels
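For reference, a minimal sketch of how the pipeline above could be exercised locally; the clone path and the `cat.jpg` test image are assumptions, and on the Hub the image-segmentation widget performs this call for you:

```python
# Hypothetical local usage: run from a clone of this repository so that
# pipeline.py and tf_model.h5 sit in the current directory; cat.jpg is any
# RGB test image you provide.
from PIL import Image

from pipeline import PreTrainedPipeline

pipe = PreTrainedPipeline(".")  # directory containing tf_model.h5
results = pipe(Image.open("cat.jpg"))  # list of {"label", "mask", "score"} dicts

for result in results:
    print(result["label"], result["score"], f'{len(result["mask"])} base64 chars')
```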
raw_output.jpg ADDED
special_tokens_map.json DELETED
@@ -1 +0,0 @@
1
- {"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
 
scheduler.pt → tf_model.h5 RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:950d6c7290e92072e8c7281edabc31c607439497fc3ee6bbfd903b99ea82e72f
3
- size 623
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0258ea75c11d977fae78f747902e48541c5e6996d3d5c700175454ffeb42aa0f
3
+ size 63661584
tokenizer.json DELETED
The diff for this file is too large to render. See raw diff
tokenizer_config.json DELETED
@@ -1 +0,0 @@
1
- {"do_lower_case": true, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenize_chinese_chars": true, "strip_accents": null, "model_max_length": 512, "special_tokens_map_file": null, "name_or_path": "models/classifier-85000", "tokenizer_class": "BertTokenizer"}
 
trainer_state.json DELETED
@@ -1,379 +0,0 @@
1
- {
2
- "best_metric": null,
3
- "best_model_checkpoint": null,
4
- "epoch": 1.8509766855702583,
5
- "global_step": 235000,
6
- "is_hyper_param_search": false,
7
- "is_local_process_zero": true,
8
- "is_world_process_zero": true,
9
- "log_history": [
10
- {
11
- "epoch": 0.04,
12
- "learning_rate": 1.921235034656585e-05,
13
- "loss": 0.3334,
14
- "step": 5000
15
- },
16
- {
17
- "epoch": 0.08,
18
- "learning_rate": 1.8424700693131696e-05,
19
- "loss": 0.3387,
20
- "step": 10000
21
- },
22
- {
23
- "epoch": 0.12,
24
- "learning_rate": 1.7637051039697544e-05,
25
- "loss": 0.3327,
26
- "step": 15000
27
- },
28
- {
29
- "epoch": 0.16,
30
- "learning_rate": 1.684940138626339e-05,
31
- "loss": 0.3492,
32
- "step": 20000
33
- },
34
- {
35
- "epoch": 0.2,
36
- "learning_rate": 1.606175173282924e-05,
37
- "loss": 0.3349,
38
- "step": 25000
39
- },
40
- {
41
- "epoch": 0.2,
42
- "eval_accuracy": 0.9155995845794678,
43
- "eval_loss": 0.389266699552536,
44
- "eval_runtime": 551.1932,
45
- "eval_samples_per_second": 51.289,
46
- "eval_steps_per_second": 12.823,
47
- "step": 25000
48
- },
49
- {
50
- "epoch": 0.24,
51
- "learning_rate": 1.5274102079395087e-05,
52
- "loss": 0.3279,
53
- "step": 30000
54
- },
55
- {
56
- "epoch": 0.28,
57
- "learning_rate": 1.4486452425960932e-05,
58
- "loss": 0.3301,
59
- "step": 35000
60
- },
61
- {
62
- "epoch": 0.32,
63
- "learning_rate": 1.369880277252678e-05,
64
- "loss": 0.3243,
65
- "step": 40000
66
- },
67
- {
68
- "epoch": 0.35,
69
- "learning_rate": 1.2911153119092628e-05,
70
- "loss": 0.293,
71
- "step": 45000
72
- },
73
- {
74
- "epoch": 0.39,
75
- "learning_rate": 1.2123503465658477e-05,
76
- "loss": 0.3053,
77
- "step": 50000
78
- },
79
- {
80
- "epoch": 0.39,
81
- "eval_accuracy": 0.9235231876373291,
82
- "eval_loss": 0.3810465931892395,
83
- "eval_runtime": 542.4272,
84
- "eval_samples_per_second": 52.118,
85
- "eval_steps_per_second": 13.03,
86
- "step": 50000
87
- },
88
- {
89
- "epoch": 0.43,
90
- "learning_rate": 1.1335853812224324e-05,
91
- "loss": 0.3126,
92
- "step": 55000
93
- },
94
- {
95
- "epoch": 0.47,
96
- "learning_rate": 1.0548204158790173e-05,
97
- "loss": 0.3072,
98
- "step": 60000
99
- },
100
- {
101
- "epoch": 0.51,
102
- "learning_rate": 9.760554505356018e-06,
103
- "loss": 0.2957,
104
- "step": 65000
105
- },
106
- {
107
- "epoch": 0.55,
108
- "learning_rate": 8.972904851921865e-06,
109
- "loss": 0.2968,
110
- "step": 70000
111
- },
112
- {
113
- "epoch": 0.59,
114
- "learning_rate": 8.185255198487714e-06,
115
- "loss": 0.2882,
116
- "step": 75000
117
- },
118
- {
119
- "epoch": 0.59,
120
- "eval_accuracy": 0.9224973320960999,
121
- "eval_loss": 0.37537074089050293,
122
- "eval_runtime": 521.9317,
123
- "eval_samples_per_second": 54.164,
124
- "eval_steps_per_second": 13.542,
125
- "step": 75000
126
- },
127
- {
128
- "epoch": 0.63,
129
- "learning_rate": 7.3976055450535615e-06,
130
- "loss": 0.2754,
131
- "step": 80000
132
- },
133
- {
134
- "epoch": 0.67,
135
- "learning_rate": 6.6099558916194085e-06,
136
- "loss": 0.2607,
137
- "step": 85000
138
- },
139
- {
140
- "epoch": 0.71,
141
- "learning_rate": 5.8223062381852555e-06,
142
- "loss": 0.2818,
143
- "step": 90000
144
- },
145
- {
146
- "epoch": 0.75,
147
- "learning_rate": 5.034656584751103e-06,
148
- "loss": 0.2736,
149
- "step": 95000
150
- },
151
- {
152
- "epoch": 0.79,
153
- "learning_rate": 4.24700693131695e-06,
154
- "loss": 0.2644,
155
- "step": 100000
156
- },
157
- {
158
- "epoch": 0.79,
159
- "eval_accuracy": 0.9297842383384705,
160
- "eval_loss": 0.3645715117454529,
161
- "eval_runtime": 521.9055,
162
- "eval_samples_per_second": 54.167,
163
- "eval_steps_per_second": 13.543,
164
- "step": 100000
165
- },
166
- {
167
- "epoch": 0.83,
168
- "learning_rate": 3.459357277882798e-06,
169
- "loss": 0.2552,
170
- "step": 105000
171
- },
172
- {
173
- "epoch": 0.87,
174
- "learning_rate": 2.6717076244486457e-06,
175
- "loss": 0.266,
176
- "step": 110000
177
- },
178
- {
179
- "epoch": 0.91,
180
- "learning_rate": 1.884057971014493e-06,
181
- "loss": 0.2684,
182
- "step": 115000
183
- },
184
- {
185
- "epoch": 0.95,
186
- "learning_rate": 1.0964083175803404e-06,
187
- "loss": 0.2501,
188
- "step": 120000
189
- },
190
- {
191
- "epoch": 0.98,
192
- "learning_rate": 3.087586641461878e-07,
193
- "loss": 0.273,
194
- "step": 125000
195
- },
196
- {
197
- "epoch": 0.98,
198
- "eval_accuracy": 0.9299964904785156,
199
- "eval_loss": 0.3369257152080536,
200
- "eval_runtime": 522.7551,
201
- "eval_samples_per_second": 54.079,
202
- "eval_steps_per_second": 13.521,
203
- "step": 125000
204
- },
205
- {
206
- "epoch": 1.02,
207
- "learning_rate": 1.7952110901071204e-05,
208
- "loss": 0.2834,
209
- "step": 130000
210
- },
211
- {
212
- "epoch": 1.06,
213
- "learning_rate": 1.787334593572779e-05,
214
- "loss": 0.3047,
215
- "step": 135000
216
- },
217
- {
218
- "epoch": 1.1,
219
- "learning_rate": 1.7794580970384373e-05,
220
- "loss": 0.2963,
221
- "step": 140000
222
- },
223
- {
224
- "epoch": 1.14,
225
- "learning_rate": 1.771581600504096e-05,
226
- "loss": 0.3031,
227
- "step": 145000
228
- },
229
- {
230
- "epoch": 1.18,
231
- "learning_rate": 1.7637051039697544e-05,
232
- "loss": 0.3033,
233
- "step": 150000
234
- },
235
- {
236
- "epoch": 1.18,
237
- "eval_accuracy": 0.9257162809371948,
238
- "eval_loss": 0.4006378650665283,
239
- "eval_runtime": 519.4649,
240
- "eval_samples_per_second": 54.421,
241
- "eval_steps_per_second": 13.606,
242
- "step": 150000
243
- },
244
- {
245
- "epoch": 1.22,
246
- "learning_rate": 1.755828607435413e-05,
247
- "loss": 0.3024,
248
- "step": 155000
249
- },
250
- {
251
- "epoch": 1.26,
252
- "learning_rate": 1.7479521109010713e-05,
253
- "loss": 0.3135,
254
- "step": 160000
255
- },
256
- {
257
- "epoch": 1.3,
258
- "learning_rate": 1.74007561436673e-05,
259
- "loss": 0.3137,
260
- "step": 165000
261
- },
262
- {
263
- "epoch": 1.34,
264
- "learning_rate": 1.732199117832388e-05,
265
- "loss": 0.3227,
266
- "step": 170000
267
- },
268
- {
269
- "epoch": 1.38,
270
- "learning_rate": 1.7243226212980467e-05,
271
- "loss": 0.3246,
272
- "step": 175000
273
- },
274
- {
275
- "epoch": 1.38,
276
- "eval_accuracy": 0.924018383026123,
277
- "eval_loss": 0.3924681842327118,
278
- "eval_runtime": 518.8244,
279
- "eval_samples_per_second": 54.489,
280
- "eval_steps_per_second": 13.623,
281
- "step": 175000
282
- },
283
- {
284
- "epoch": 1.42,
285
- "learning_rate": 1.7164461247637053e-05,
286
- "loss": 0.3281,
287
- "step": 180000
288
- },
289
- {
290
- "epoch": 1.46,
291
- "learning_rate": 1.708569628229364e-05,
292
- "loss": 0.3256,
293
- "step": 185000
294
- },
295
- {
296
- "epoch": 1.5,
297
- "learning_rate": 1.700693131695022e-05,
298
- "loss": 0.313,
299
- "step": 190000
300
- },
301
- {
302
- "epoch": 1.54,
303
- "learning_rate": 1.6928166351606807e-05,
304
- "loss": 0.3313,
305
- "step": 195000
306
- },
307
- {
308
- "epoch": 1.58,
309
- "learning_rate": 1.684940138626339e-05,
310
- "loss": 0.2953,
311
- "step": 200000
312
- },
313
- {
314
- "epoch": 1.58,
315
- "eval_accuracy": 0.9212592840194702,
316
- "eval_loss": 0.3895967900753021,
317
- "eval_runtime": 526.2623,
318
- "eval_samples_per_second": 53.718,
319
- "eval_steps_per_second": 13.431,
320
- "step": 200000
321
- },
322
- {
323
- "epoch": 1.61,
324
- "learning_rate": 1.6770636420919976e-05,
325
- "loss": 0.3103,
326
- "step": 205000
327
- },
328
- {
329
- "epoch": 1.65,
330
- "learning_rate": 1.669187145557656e-05,
331
- "loss": 0.3089,
332
- "step": 210000
333
- },
334
- {
335
- "epoch": 1.69,
336
- "learning_rate": 1.6613106490233147e-05,
337
- "loss": 0.3095,
338
- "step": 215000
339
- },
340
- {
341
- "epoch": 1.73,
342
- "learning_rate": 1.653434152488973e-05,
343
- "loss": 0.3288,
344
- "step": 220000
345
- },
346
- {
347
- "epoch": 1.77,
348
- "learning_rate": 1.6455576559546316e-05,
349
- "loss": 0.3199,
350
- "step": 225000
351
- },
352
- {
353
- "epoch": 1.77,
354
- "eval_accuracy": 0.9203749299049377,
355
- "eval_loss": 0.3942428529262543,
356
- "eval_runtime": 520.6801,
357
- "eval_samples_per_second": 54.294,
358
- "eval_steps_per_second": 13.575,
359
- "step": 225000
360
- },
361
- {
362
- "epoch": 1.81,
363
- "learning_rate": 1.6376811594202898e-05,
364
- "loss": 0.306,
365
- "step": 230000
366
- },
367
- {
368
- "epoch": 1.85,
369
- "learning_rate": 1.6298046628859484e-05,
370
- "loss": 0.3104,
371
- "step": 235000
372
- }
373
- ],
374
- "max_steps": 1269600,
375
- "num_train_epochs": 10,
376
- "total_flos": 2.473283070586798e+17,
377
- "trial_name": null,
378
- "trial_params": null
379
- }
pytorch_model.bin → variables/variables.data-00000-of-00001 RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:5adefcec339329c7c3d2ef2c94a00bf7f19a361c6808d54a7a987169f372f491
3
- size 438084653
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e3c27336aaecafd070749d833780e6603b9c25caf01a037c0ef9a93ff3b0c36c
3
+ size 63405929
training_args.bin → variables/variables.index RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:561b187687b39a679c0f6eebbf20fc3ece81123ada2b8481d039e8076a99fef3
3
- size 2991
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c109c3a99fb1bfbd8298c3f63cfee52e1bd8e8f10f138796eddbac360eaa0de1
3
+ size 17873
vocab.txt DELETED
The diff for this file is too large to render. See raw diff