# A multi-task learning model with two prediction heads
* One prediction head classifies keyword queries vs. statements/questions
* The other prediction head classifies statements vs. questions
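
The two-head layout described above can be sketched roughly as follows. This is a minimal illustration, not the code in `multitask_model.py`: the `TinyMultiTaskClassifier` class, the embedding-plus-mean-pooling "encoder" standing in for BERT, and the sizes are all assumptions. Only the task names and the shape of `task_labels_map` come from this README.

```python
import torch
from torch import nn


class TinyMultiTaskClassifier(nn.Module):
    """Hypothetical sketch: one shared encoder, one linear head per task."""

    def __init__(self, task_labels_map, vocab_size=100, hidden_size=16):
        super().__init__()
        # Stand-in for the shared BERT encoder (assumption: a real model
        # would use a transformer here, not a plain embedding table).
        self.encoder = nn.Embedding(vocab_size, hidden_size)
        # One classification head per task, keyed by task name.
        self.heads = nn.ModuleDict(
            {task: nn.Linear(hidden_size, n_labels)
             for task, n_labels in task_labels_map.items()}
        )

    def forward(self, input_ids, task_name):
        # Mean-pool token embeddings into one vector per sequence.
        pooled = self.encoder(input_ids).mean(dim=1)
        # Route the shared representation through the requested head.
        return self.heads[task_name](pooled)


toy_model = TinyMultiTaskClassifier(
    {"quora_keyword_pairs": 2, "spaadia_squad_pairs": 2}
)
toy_ids = torch.randint(0, 100, (3, 5))  # 3 sequences of 5 token ids
toy_logits = toy_model(toy_ids, task_name="spaadia_squad_pairs")
print(toy_logits.shape)  # torch.Size([3, 2])
```

Both heads read the same pooled representation, so the encoder weights are shared across tasks while each head keeps its own output layer.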

## Scores
##### Spaadia SQuaD Test acc: **0.9891**
##### Quora Keyword Pairs Test acc: **0.98048**

## Datasets:
* Quora Keyword Pairs: https://www.kaggle.com/stefanondisponibile/quora-question-keyword-pairs
* Spaadia SQuaD pairs: https://www.kaggle.com/shahrukhkhan/questions-vs-statementsclassificationdataset

## Article
[Medium article](https://medium.com/@shahrukhx01/multi-task-learning-with-transformers-part-1-multi-prediction-heads-b7001cf014bf)
## Demo Notebook
[Colab Notebook Multi-task Query classifiers](https://colab.research.google.com/drive/1R7WcLHxDsVvZXPhr5HBgIWa3BlSZKY6p?usp=sharing)
## Clone the model repo
```bash 
git clone https://huggingface.co/shahrukhx01/bert-multitask-query-classifiers
```
```python
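# IPython/Colab magic; in a plain shell, run `cd bert-multitask-query-classifiers/` instead.
# The working directory must be the cloned repo so `multitask_model` can be imported.
%cd bert-multitask-query-classifiers/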
```
## Load model
```python
from multitask_model import BertForSequenceClassification
from transformers import AutoTokenizer
import torch
model = BertForSequenceClassification.from_pretrained(
        "shahrukhx01/bert-multitask-query-classifiers",
        task_labels_map={"quora_keyword_pairs": 2, "spaadia_squad_pairs": 2},
    )
tokenizer = AutoTokenizer.from_pretrained("shahrukhx01/bert-multitask-query-classifiers")
```
## Run inference on both tasks
```python
from multitask_model import BertForSequenceClassification
from transformers import AutoTokenizer
import torch
model = BertForSequenceClassification.from_pretrained(
        "shahrukhx01/bert-multitask-query-classifiers",
        task_labels_map={"quora_keyword_pairs": 2, "spaadia_squad_pairs": 2},
    )
tokenizer = AutoTokenizer.from_pretrained("shahrukhx01/bert-multitask-query-classifiers")

# Keyword vs. statement/question classifier
inputs = ["keyword query", "is this a keyword query?"]
task_name = "quora_keyword_pairs"
# Tokenize the batch and pass only the input ids to the model.
sequence = tokenizer(inputs, padding=True, return_tensors="pt")["input_ids"]
with torch.no_grad():  # inference only, no gradients needed
    logits = model(sequence, task_name=task_name)[0]
# argmax over the class dimension gives the predicted class index
predictions = torch.argmax(torch.softmax(logits, dim=1).cpu(), dim=1)
for text, prediction in zip(inputs, predictions):
    print(f"task: {task_name}, input: {text} \n prediction=> {prediction}")
    print()

# Statement vs. question classifier
inputs = ["where is berlin?", "is this a keyword query?", "Berlin is in Germany."]
task_name = "spaadia_squad_pairs"
sequence = tokenizer(inputs, padding=True, return_tensors="pt")["input_ids"]
with torch.no_grad():
    logits = model(sequence, task_name=task_name)[0]
predictions = torch.argmax(torch.softmax(logits, dim=1).cpu(), dim=1)
for text, prediction in zip(inputs, predictions):
    print(f"task: {task_name}, input: {text} \n prediction=> {prediction}")
    print()
```
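
The loops above print raw class indices. This README does not state which index corresponds to which class, so a lookup like the one below is only a sketch: the `LABEL_NAMES` entries are placeholders (assumptions) that would need to be checked against the training data before use, and `describe_predictions` is a hypothetical helper, not part of the model repo.

```python
# Hypothetical index-to-label lookup; the label strings are placeholders,
# NOT confirmed by the model card.
LABEL_NAMES = {
    "quora_keyword_pairs": ["class_0", "class_1"],
    "spaadia_squad_pairs": ["class_0", "class_1"],
}


def describe_predictions(task_name, texts, class_indices):
    """Pair each input text with the label name for its predicted index."""
    names = LABEL_NAMES[task_name]
    return [(text, names[int(idx)]) for text, idx in zip(texts, class_indices)]


# Example with made-up indices standing in for `predictions.tolist()`.
rows = describe_predictions(
    "spaadia_squad_pairs",
    ["where is berlin?", "Berlin is in Germany."],
    [1, 0],
)
for text, label in rows:
    print(f"{text!r} -> {label}")
```

Once the real label order is known, replacing the placeholder strings is enough to get readable output from either head.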