kmewhort committed
Commit dff82eb
1 Parent(s): 7a3df5e

Update README.md

Files changed (1)
  1. README.md +52 -9
README.md CHANGED
@@ -14,22 +14,65 @@ should probably proofread and complete it, then remove this comment. -->
  # beit-sketch-classifier

- This model is a fine-tuned version of [microsoft/beit-base-patch16-224-pt22k-ft22k](https://huggingface.co/microsoft/beit-base-patch16-224-pt22k-ft22k) on the None dataset.
+ This model is a version of [microsoft/beit-base-patch16-224-pt22k-ft22k](https://huggingface.co/microsoft/beit-base-patch16-224-pt22k-ft22k) fine-tuned on a dataset of Quick!Draw! sketches ([1 percent of the 50M sketches](https://huggingface.co/datasets/kmewhort/quickdraw-bins-1pct-sample)).
  It achieves the following results on the evaluation set:
  - Loss: 1.6083
  - Accuracy: 0.7480

- ## Model description
-
- More information needed
-
  ## Intended uses & limitations

- More information needed
-
- ## Training and evaluation data
-
- More information needed
+ It's intended to be used to classify sketches supplied in a line-segment (vector) input format. There is no data augmentation in the fine-tuning, so input raster images ideally need to be rendered from the line-vector format in the same way as the training images.
+
+ You can generate the requisite PIL images from the Quickdraw `bin` format with the following:
+
+ ```python
+ import io
+ from struct import unpack
+
+ import cv2
+ import numpy as np
+ from PIL import Image
+
+ # packed bytes -> dict (from https://github.com/googlecreativelab/quickdraw-dataset/blob/master/examples/binary_file_parser.py)
+ def unpack_drawing(file_handle):
+     key_id, = unpack('Q', file_handle.read(8))
+     country_code, = unpack('2s', file_handle.read(2))
+     recognized, = unpack('b', file_handle.read(1))
+     timestamp, = unpack('I', file_handle.read(4))
+     n_strokes, = unpack('H', file_handle.read(2))
+     image = []
+     n_bytes = 17
+     for i in range(n_strokes):
+         n_points, = unpack('H', file_handle.read(2))
+         fmt = str(n_points) + 'B'
+         x = unpack(fmt, file_handle.read(n_points))
+         y = unpack(fmt, file_handle.read(n_points))
+         image.append((x, y))
+         n_bytes += 2 + 2*n_points
+     result = {
+         'key_id': key_id,
+         'country_code': country_code,
+         'recognized': recognized,
+         'timestamp': timestamp,
+         'image': image,
+     }
+     return result
+
+ # packed bin record -> 224x224 RGB PIL image (strokes drawn in black on a white canvas)
+ def binToPIL(packed_drawing):
+     padding = 8
+     radius = 7
+     scale = (224.0 - (2 * padding)) / 256
+
+     unpacked = unpack_drawing(io.BytesIO(packed_drawing))
+     image = np.full((224, 224), 255, np.uint8)
+     for stroke in unpacked['image']:
+         prevX = round(stroke[0][0] * scale)
+         prevY = round(stroke[1][0] * scale)
+         for i in range(1, len(stroke[0])):
+             x = round(stroke[0][i] * scale)
+             y = round(stroke[1][i] * scale)
+             cv2.line(image, (padding + prevX, padding + prevY), (padding + x, padding + y), 0, radius, -1)
+             prevX = x
+             prevY = y
+     pilImage = Image.fromarray(image).convert("RGB")
+     return pilImage
+ ```
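For reference, here is a minimal sketch of how a converted image might be fed to the classifier using the `transformers` image-classification pipeline. The repo id (`kmewhort/beit-sketch-classifier`), the local `.bin` path, and the use of the pipeline API are illustrative assumptions rather than instructions from the model card:

```python
from transformers import pipeline

# Assumed Hub repo id for this model (matches the card title; adjust if needed).
classifier = pipeline("image-classification", model="kmewhort/beit-sketch-classifier")

# Placeholder path to a Quick, Draw! binary file; each .bin concatenates packed drawings.
# unpack_drawing() (called inside binToPIL) only consumes the first record, so this
# renders the first sketch in the file and ignores the rest.
with open("cat.bin", "rb") as f:
    pil_img = binToPIL(f.read())

# The pipeline applies the model's own image processor (resize + normalization),
# so the stroke rendering above is the only preprocessing you need to reproduce.
print(classifier(pil_img, top_k=5))
```

Each prediction returned by the pipeline is a dict with a `label` and a `score`.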

  ## Training procedure
78