---
license: apache-2.0
tags:
- transformers
- pytorch
datasets:
- conv_ai_2
model-index:
- name: distillbert_conv_quality_score
  results: []
language:
- en
---


# distillbert_conv_quality_score

This model is a fine-tuned version of [distilbert-base-uncased](https://huggingface.co/distilbert-base-uncased) on the conv_ai_2 dataset.
It was trained to assign a quality score (in the [0, 1] range) to a conversation.


It achieves the following results at the end of training:
- training/loss: 0.0165
- validation/loss: 0.0149


## Usage

```python
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

model_name = "alespalla/distillbert_conv_quality_score"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model.eval()

conversation = '''
Q: Begin
A: lol ! do you think it is strange to feel like you have been through life before ?
Q: Hellow
A: I don't understand you 🙈. Also, try to guess: i like to ...
Q: How are you?
A: make time stop, funny you :)
Q: What is your name?
A: jessie. hows your day going ? 😃
'''

# The model has a single regression head, so the raw logit is the score
with torch.no_grad():
    score = model(**tokenizer(conversation, return_tensors='pt')).logits.item()
print(f"Score: {score}")
```
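
To score several conversations at once, the tokenizer can batch with padding. A minimal sketch, reusing the objects above (the second conversation here is made up for illustration):

```python
conversations = [conversation, "Q: Hi\nA: hey! how is it going?\n"]
batch = tokenizer(conversations, padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
    scores = model(**batch).logits.squeeze(-1).tolist()
print(scores)  # one score per conversation
```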

## Training and evaluation data

The training data was generated from `conv_ai_2` using the following function:

```python
from datasets import load_dataset

def get_dataset(regression=False):
    db = load_dataset("conv_ai_2")

    def generate_conversation(elem):
        # Flatten the dialog into alternating Q:/A: turns
        text = ""
        for idx, txt in enumerate(elem["dialog"]):
            if idx % 2:
                text += f"A: {txt['text']}\n"
            else:
                text += f"Q: {txt['text']}\n"
        if regression:
            # Rescale the 1-5 eval_score to a [0, 1] target
            return {'text': text, "labels": (elem['eval_score'] - 1) / 4}
        return {'text': text, "labels": elem['eval_score'] - 1}

    # Keep only conversations that actually received an evaluation score
    db = db.filter(lambda example: example["eval_score"] > 0)
    db = db.map(generate_conversation, remove_columns=db['train'].column_names)
    db = db['train'].train_test_split(test_size=0.2).shuffle(42)

    return db
```
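
For the regression objective this model uses, `get_dataset` would be called with `regression=True`; each example then carries the flattened dialog and a label in [0, 1]. For instance (assuming the function above):

```python
db = get_dataset(regression=True)
example = db["train"][0]
print(example["labels"])  # e.g. 0.75 for a raw eval_score of 4
```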

## Training procedure

### Training hyperparameters

The following hyperparameters were used during training:
- epochs: 40
- batch_size: 16
- learning_rate: 0.0002
- eval_steps: 82
- log_steps: 82
- save_steps: 41
- gradient_accumulation_steps: 1
- warmup_steps: 0
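
The training script itself is not included in this card; the sketch below is a hypothetical mapping of these hyperparameters onto `TrainingArguments`/`Trainer` (Transformers 4.26), using the regression dataset from `get_dataset` above.

```python
from transformers import (AutoTokenizer, AutoModelForSequenceClassification,
                          TrainingArguments, Trainer)

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
# num_labels=1 gives a single regression head trained with MSE loss
model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=1)

db = get_dataset(regression=True)
db = db.map(lambda e: tokenizer(e["text"], truncation=True), batched=True)

args = TrainingArguments(
    output_dir="distillbert_conv_quality_score",
    num_train_epochs=40,
    per_device_train_batch_size=16,
    learning_rate=2e-4,
    evaluation_strategy="steps",
    eval_steps=82,
    logging_steps=82,
    save_steps=41,
    gradient_accumulation_steps=1,
    warmup_steps=0,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=db["train"],
    eval_dataset=db["test"],
    tokenizer=tokenizer,
)
trainer.train()
```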

### Training results

| step | training/loss | validation/loss |
|:----:|:-------------:|:---------------:|
| 81   | 0.1020        | 0.0794          |
| 163  | 0.0800        | 0.0713          |
| 245  | 0.0553        | 0.0491          |
| 327  | 0.0362        | 0.0440          |
| 409  | 0.0282        | 0.0352          |
| 491  | 0.0282        | 0.0412          |
| 573  | 0.0256        | 0.0293          |
| 655  | 0.0238        | 0.0252          |
| 737  | 0.0175        | 0.0226          |
| 819  | 0.0154        | 0.0228          |
| 901  | 0.0116        | 0.0205          |
| 983  | 0.0160        | 0.0202          |
| 1065 | 0.0146        | 0.0240          |
| 1147 | 0.0182        | 0.0180          |
| 1229 | 0.0171        | 0.0192          |
| 1311 | 0.0091        | 0.0174          |
| 1393 | 0.0171        | 0.0158          |
| 1475 | 0.0137        | 0.0158          |
| 1557 | 0.0158        | 0.0148          |
| 1639 | 0.0165        | 0.0149          |


### Framework versions

- Transformers 4.26.1
- Datasets 2.10.1
- Tokenizers 0.13.2