---
license: apache-2.0
tags:
- generated_from_trainer
metrics:
- accuracy
model-index:
- name: beit-sketch-classifier
  results: []
---

# beit-sketch-classifier

This model is a fine-tuned version of [microsoft/beit-base-patch16-224-pt22k-ft22k](https://huggingface.co/microsoft/beit-base-patch16-224-pt22k-ft22k), trained on a dataset of Quick, Draw! sketches (roughly 10% of the [50M-sketch QuickDraw dataset](https://huggingface.co/datasets/kmewhort/quickdraw-bins-50M)).
It achieves the following results on the evaluation set:
- Loss: 0.7372
- Accuracy: 0.8098

## Intended uses & limitations

It's intended to classify sketches supplied in a line-segment (vector) input format. No data augmentation was applied during fine-tuning, so input raster images should be rendered from the line-vector format in the same way as the training images.

You can generate the requisite PIL images from the QuickDraw `bin` format as follows:

```python
import io
from struct import unpack

import cv2
import numpy as np
from PIL import Image


# packed bytes -> dict (from https://github.com/googlecreativelab/quickdraw-dataset/blob/master/examples/binary_file_parser.py)
def unpack_drawing(file_handle):
    key_id, = unpack('Q', file_handle.read(8))
    country_code, = unpack('2s', file_handle.read(2))
    recognized, = unpack('b', file_handle.read(1))
    timestamp, = unpack('I', file_handle.read(4))
    n_strokes, = unpack('H', file_handle.read(2))
    image = []
    n_bytes = 17
    for i in range(n_strokes):
        n_points, = unpack('H', file_handle.read(2))
        fmt = str(n_points) + 'B'
        x = unpack(fmt, file_handle.read(n_points))
        y = unpack(fmt, file_handle.read(n_points))
        image.append((x, y))
        n_bytes += 2 + 2 * n_points
    result = {
        'key_id': key_id,
        'country_code': country_code,
        'recognized': recognized,
        'timestamp': timestamp,
        'image': image,
    }
    return result


# packed bin -> 224x224 RGB PIL image, matching the training rasterization
def binToPIL(packed_drawing):
    padding = 8
    radius = 7
    # QuickDraw coordinates span 0-255; scale them into the padded 224x224 canvas
    scale = (224.0 - (2 * padding)) / 256

    unpacked = unpack_drawing(io.BytesIO(packed_drawing))
    image = np.full((224, 224), 255, np.uint8)
    for stroke in unpacked['image']:  # each stroke is a pair of (x-coords, y-coords)
        prevX = round(stroke[0][0] * scale)
        prevY = round(stroke[1][0] * scale)
        for i in range(1, len(stroke[0])):
            x = round(stroke[0][i] * scale)
            y = round(stroke[1][i] * scale)
            cv2.line(image, (padding + prevX, padding + prevY), (padding + x, padding + y), 0, radius, -1)
            prevX = x
            prevY = y
    pilImage = Image.fromarray(image).convert("RGB")
    return pilImage
```
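
The rendered image can then be fed straight to an image-classification pipeline. Below is a minimal inference sketch; the model id `kmewhort/beit-sketch-classifier` and the file path `quickdraw.bin` are illustrative assumptions:

```python
from transformers import pipeline

# Model id and file path are illustrative assumptions.
classifier = pipeline("image-classification", model="kmewhort/beit-sketch-classifier")

# binToPIL parses the first packed drawing from the byte stream.
with open("quickdraw.bin", "rb") as f:
    image = binToPIL(f.read())

print(classifier(image))  # list of {'label': ..., 'score': ...} dicts
```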

## Training procedure

### Training hyperparameters

The following hyperparameters were used during training:
- learning_rate: 5e-05
- train_batch_size: 64
- eval_batch_size: 64
- seed: 42
- gradient_accumulation_steps: 4
- total_train_batch_size: 256
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
- lr_scheduler_type: linear
- lr_scheduler_warmup_ratio: 0.1
- num_epochs: 3
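
As a rough sketch, these settings map onto the standard `transformers` `TrainingArguments` as follows (the original training script is not included in this card; `output_dir` is a placeholder, and the Adam betas/epsilon listed above are the `Trainer` defaults):

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="beit-sketch-classifier",  # placeholder
    learning_rate=5e-5,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    gradient_accumulation_steps=4,  # 64 * 4 = 256 total train batch size
    lr_scheduler_type="linear",
    warmup_ratio=0.1,
    num_train_epochs=3,
    seed=42,
)
```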

### Training results

| Training Loss | Epoch | Step  | Validation Loss | Accuracy |
|:-------------:|:-----:|:-----:|:---------------:|:--------:|
| 0.939         | 1.0   | 12606 | 0.7853   | 0.8275          |
| 0.7312        | 2.0   | 25212 | 0.7587   | 0.8027          |
| 0.6174        | 3.0   | 37818 | 0.7372   | 0.8098          |


### Framework versions

- Transformers 4.25.1
- Pytorch 1.13.1+cu117
- Datasets 2.7.1
- Tokenizers 0.13.2