Ubuntu committed on
Commit • e77b318
1 Parent(s): d0702fa
added intent classification using distil bert
Browse files
- data_intent/intent_data.csv +3 -0
- intent_classification_model/checkpoint-324/added_tokens.json +7 -0
- intent_classification_model/checkpoint-324/config.json +39 -0
- intent_classification_model/checkpoint-324/optimizer.pt +3 -0
- intent_classification_model/checkpoint-324/pytorch_model.bin +3 -0
- intent_classification_model/checkpoint-324/rng_state.pth +0 -0
- intent_classification_model/checkpoint-324/scheduler.pt +3 -0
- intent_classification_model/checkpoint-324/special_tokens_map.json +7 -0
- intent_classification_model/checkpoint-324/tokenizer.json +0 -0
- intent_classification_model/checkpoint-324/tokenizer_config.json +56 -0
- intent_classification_model/checkpoint-324/trainer_state.json +73 -0
- intent_classification_model/checkpoint-324/training_args.bin +3 -0
- intent_classification_model/checkpoint-324/vocab.txt +0 -0
- intent_classification_model/runs/Oct13_09-06-59_ip-172-31-95-165/events.out.tfevents.1697188019.ip-172-31-95-165.137562.0 +0 -0
- intent_classification_model/runs/Oct13_09-08-12_ip-172-31-95-165/events.out.tfevents.1697188092.ip-172-31-95-165.137562.1 +0 -0
- intent_classification_model/runs/Oct13_09-08-49_ip-172-31-95-165/events.out.tfevents.1697188130.ip-172-31-95-165.137562.2 +0 -0
- intent_classification_model/runs/Oct13_09-09-35_ip-172-31-95-165/events.out.tfevents.1697188176.ip-172-31-95-165.137562.3 +0 -0
- intent_classification_model/runs/Oct13_09-10-07_ip-172-31-95-165/events.out.tfevents.1697188208.ip-172-31-95-165.138160.0 +0 -0
- research/04_inference.ipynb +217 -0
- research/10_demo_test_data.ipynb +19 -10
- research/11_evaluation.html +0 -0
- research/11_evaluation.ipynb +290 -0
- research/11_intent_classification_using_distilbert.ipynb +898 -0
- utils/__pycache__/get_category.cpython-310.pyc +0 -0
- utils/__pycache__/get_intent.cpython-310.pyc +0 -0
- utils/__pycache__/get_sentence_status.cpython-310.pyc +0 -0
- utils/get_category.py +8 -4
- utils/get_intent.py +69 -0
- utils/get_sentence_status.py +48 -1
data_intent/intent_data.csv
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:24091e2e977d444be178138ac717fa57b8d16534dcf5e66d4084cf3f77e6f6ce
size 39551
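The CSV itself is stored as a Git LFS pointer, so the keyword/intent rows are not visible in this diff. A minimal sketch for inspecting it locally, assuming the LFS object has been pulled; the column names keyword and intent and the 1071-row count come from the training notebook further down:

import pandas as pd

# Assumes "git lfs pull" has materialized data_intent/intent_data.csv locally.
df = pd.read_csv("data_intent/intent_data.csv")   # columns: keyword, intent (1071 rows per the notebook below)
print(df["intent"].value_counts())                # class balance across the five intent labels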
intent_classification_model/checkpoint-324/added_tokens.json
ADDED
@@ -0,0 +1,7 @@
{
  "[CLS]": 101,
  "[MASK]": 103,
  "[PAD]": 0,
  "[SEP]": 102,
  "[UNK]": 100
}
intent_classification_model/checkpoint-324/config.json
ADDED
@@ -0,0 +1,39 @@
{
  "_name_or_path": "distilbert-base-uncased",
  "activation": "gelu",
  "architectures": [
    "DistilBertForSequenceClassification"
  ],
  "attention_dropout": 0.1,
  "dim": 768,
  "dropout": 0.1,
  "hidden_dim": 3072,
  "id2label": {
    "0": "Commercial",
    "1": "Informational",
    "2": "Navigational",
    "3": "Local",
    "4": "Transactional"
  },
  "initializer_range": 0.02,
  "label2id": {
    "Commercial": 0,
    "Informational": 1,
    "Local": 3,
    "Navigational": 2,
    "Transactional": 4
  },
  "max_position_embeddings": 512,
  "model_type": "distilbert",
  "n_heads": 12,
  "n_layers": 6,
  "pad_token_id": 0,
  "problem_type": "single_label_classification",
  "qa_dropout": 0.1,
  "seq_classif_dropout": 0.2,
  "sinusoidal_pos_embds": false,
  "tie_weights_": true,
  "torch_dtype": "float32",
  "transformers_version": "4.34.0",
  "vocab_size": 30522
}
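Since config.json pins the five-label space and problem_type single_label_classification, the checkpoint can be loaded directly for inference. A minimal sketch, assuming only that torch and transformers are installed and that the checkpoint directory added in this commit is present locally:

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

ckpt = "intent_classification_model/checkpoint-324"
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForSequenceClassification.from_pretrained(ckpt).eval()

inputs = tokenizer("buy cat ear headphones", return_tensors="pt", truncation=True)
with torch.no_grad():
    logits = model(**inputs).logits              # shape (1, 5): one logit per intent
pred = model.config.id2label[logits.argmax(dim=-1).item()]
print(pred)                                      # highest-scoring intent for the query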
intent_classification_model/checkpoint-324/optimizer.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a50f88f7a9097ecddb2b3c7e3d38747deec4ca3a386132fac9e0e4efaa82ae0e
size 535745722
intent_classification_model/checkpoint-324/pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b339df5c0d892e025a1749d085ab010e551f4b249eb497812a1a3bd7ebd5fd99
size 267865194
intent_classification_model/checkpoint-324/rng_state.pth
ADDED
Binary file (14.2 kB)
intent_classification_model/checkpoint-324/scheduler.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:73f74582c189fe624f606122980ccb279125588a1db45b4052dc704fa2b51184
size 1064
intent_classification_model/checkpoint-324/special_tokens_map.json
ADDED
@@ -0,0 +1,7 @@
{
  "cls_token": "[CLS]",
  "mask_token": "[MASK]",
  "pad_token": "[PAD]",
  "sep_token": "[SEP]",
  "unk_token": "[UNK]"
}
intent_classification_model/checkpoint-324/tokenizer.json
ADDED
The diff for this file is too large to render.
intent_classification_model/checkpoint-324/tokenizer_config.json
ADDED
@@ -0,0 +1,56 @@
{
  "added_tokens_decoder": {
    "0": {
      "content": "[PAD]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "100": {
      "content": "[UNK]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "101": {
      "content": "[CLS]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "102": {
      "content": "[SEP]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "103": {
      "content": "[MASK]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  },
  "additional_special_tokens": [],
  "clean_up_tokenization_spaces": true,
  "cls_token": "[CLS]",
  "do_lower_case": true,
  "mask_token": "[MASK]",
  "model_max_length": 512,
  "pad_token": "[PAD]",
  "sep_token": "[SEP]",
  "strip_accents": null,
  "tokenize_chinese_chars": true,
  "tokenizer_class": "DistilBertTokenizer",
  "unk_token": "[UNK]"
}
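tokenizer_config.json confirms the stock uncased WordPiece setup (DistilBertTokenizer, do_lower_case true, model_max_length 512). A quick sanity check, assuming the checkpoint directory above exists locally:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("intent_classification_model/checkpoint-324")
print(tok.tokenize("Buy Cat Ear Headphones"))    # lower-cased WordPiece pieces, per do_lower_case
print(tok.model_max_length)                      # 512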
intent_classification_model/checkpoint-324/trainer_state.json
ADDED
@@ -0,0 +1,73 @@
{
  "best_metric": 0.16397738456726074,
  "best_model_checkpoint": "intent_classification_model/checkpoint-270",
  "epoch": 6.0,
  "eval_steps": 500,
  "global_step": 324,
  "is_hyper_param_search": false,
  "is_local_process_zero": true,
  "is_world_process_zero": true,
  "log_history": [
    {
      "epoch": 1.0,
      "eval_accuracy": 0.9488372093023256,
      "eval_loss": 0.4676927328109741,
      "eval_runtime": 0.1185,
      "eval_samples_per_second": 1814.083,
      "eval_steps_per_second": 118.126,
      "step": 54
    },
    {
      "epoch": 2.0,
      "eval_accuracy": 0.9534883720930233,
      "eval_loss": 0.20428764820098877,
      "eval_runtime": 0.0972,
      "eval_samples_per_second": 2210.83,
      "eval_steps_per_second": 143.961,
      "step": 108
    },
    {
      "epoch": 3.0,
      "eval_accuracy": 0.9674418604651163,
      "eval_loss": 0.16401757299900055,
      "eval_runtime": 0.1015,
      "eval_samples_per_second": 2118.828,
      "eval_steps_per_second": 137.97,
      "step": 162
    },
    {
      "epoch": 4.0,
      "eval_accuracy": 0.9674418604651163,
      "eval_loss": 0.16496841609477997,
      "eval_runtime": 0.0941,
      "eval_samples_per_second": 2284.398,
      "eval_steps_per_second": 148.752,
      "step": 216
    },
    {
      "epoch": 5.0,
      "eval_accuracy": 0.9674418604651163,
      "eval_loss": 0.16397738456726074,
      "eval_runtime": 0.0975,
      "eval_samples_per_second": 2204.851,
      "eval_steps_per_second": 143.572,
      "step": 270
    },
    {
      "epoch": 6.0,
      "eval_accuracy": 0.9674418604651163,
      "eval_loss": 0.16553252935409546,
      "eval_runtime": 0.0947,
      "eval_samples_per_second": 2271.063,
      "eval_steps_per_second": 147.883,
      "step": 324
    }
  ],
  "logging_steps": 500,
  "max_steps": 324,
  "num_train_epochs": 6,
  "save_steps": 500,
  "total_flos": 13032177536640.0,
  "trial_name": null,
  "trial_params": null
}
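Note that best_model_checkpoint points at checkpoint-270 (epoch 5, eval_loss 0.1640), while the directory committed here is checkpoint-324 (epoch 6, eval_loss 0.1655); eval accuracy plateaus at 96.7% from epoch 3 onward. A small sketch for summarizing this file, assuming it is read from the committed checkpoint path:

import json

with open("intent_classification_model/checkpoint-324/trainer_state.json") as f:
    state = json.load(f)

print("best:", state["best_model_checkpoint"], "eval_loss", state["best_metric"])
for e in state["log_history"]:                   # one eval entry per epoch
    print(f"epoch {e['epoch']:.0f}  step {e['step']:>3}  acc {e['eval_accuracy']:.4f}  loss {e['eval_loss']:.4f}")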
intent_classification_model/checkpoint-324/training_args.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c27308f0087e544f12e1806abafb33d65745a5791fb1559d9e521f3670215df9
size 4536
intent_classification_model/checkpoint-324/vocab.txt
ADDED
The diff for this file is too large to render.
intent_classification_model/runs/Oct13_09-06-59_ip-172-31-95-165/events.out.tfevents.1697188019.ip-172-31-95-165.137562.0
ADDED
Binary file (5.3 kB)
intent_classification_model/runs/Oct13_09-08-12_ip-172-31-95-165/events.out.tfevents.1697188092.ip-172-31-95-165.137562.1
ADDED
Binary file (6.02 kB)
intent_classification_model/runs/Oct13_09-08-49_ip-172-31-95-165/events.out.tfevents.1697188130.ip-172-31-95-165.137562.2
ADDED
Binary file (5.93 kB)
intent_classification_model/runs/Oct13_09-09-35_ip-172-31-95-165/events.out.tfevents.1697188176.ip-172-31-95-165.137562.3
ADDED
Binary file (4.73 kB)
intent_classification_model/runs/Oct13_09-10-07_ip-172-31-95-165/events.out.tfevents.1697188208.ip-172-31-95-165.138160.0
ADDED
Binary file (6.6 kB)
research/04_inference.ipynb
CHANGED
@@ -673,6 +673,223 @@
 (unchanged context, notebook lines 673-675: closing lines of the preceding cell)
   "There are a few reasons why language modeling people like perplexity instead of just using entropy. One is that, because of the exponent, improvements in perplexity \"feel\" like they are more substantial than the equivalent improvement in entropy. Another is that before they started using perplexity, the complexity of a language model was reported using a simplistic branching factor measurement that is more similar to perplexity than it is to entropy.''')"
   ]
  },

 (added, notebook lines 676-892: one empty code cell plus nine executed cells, shown as source with recorded output)

In [1]:
import os; os.chdir(
    '..'
)

In [2]:
from utils.get_sentence_status import get_top_labels
# stderr: /home/ubuntu/SentenceStructureComparision/venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning:
#         IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
# stderr: Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.

In [4]:
get_top_labels('''12

Yes, the perplexity is always equal to two to the power of the entropy. It doesn't matter what type of model you have, n-gram, unigram, or neural network.

There are a few reasons why language modeling people like perplexity instead of just using entropy. One is that, because of the exponent, improvements in perplexity "feel" like they are more substantial than the equivalent improvement in entropy. Another is that before they started using perplexity, the complexity of a language model was reported using a simplistic branching factor measurement that is more similar to perplexity than it is to entropy.''')
Out[4]: [('Human Written', 0.999), ('AI written', 0.002)]

In [3]:
get_top_labels(
    'My name is deepankar'
)
Out[3]: [('AI written', 1.0), ('Human Written', 0.0)]

In [6]:
get_top_labels(
    '''Hate speech or discriminatory content: Hate speech is speech, conduct, writing, or expressions that discriminate or promote discrimination against individuals or groups based on attributes such as race, religion, nationality, gender, sexual orientation, disability, or other characteristics. It often includes offensive language, stereotypes, or harmful stereotypes and can contribute to a hostile or unsafe environment for affected individuals.

Explicit or adult content: Explicit or adult content typically refers to material that is sexually explicit, pornographic, or contains graphic depictions of sexual acts. This content may not be suitable for all audiences and is subject to age restrictions and content regulations in many jurisdictions.'''
)
Out[6]: [('AI written', 0.999), ('Human Written', 0.001)]

In [8]:
get_top_labels(
    '''Of course, I can provide a more detailed explanation of these topics:

1. **Hate speech or discriminatory content:** Hate speech is speech, conduct, writing, or expressions that discriminate or promote discrimination against individuals or groups based on attributes such as race, religion, nationality, gender, sexual orientation, disability, or other characteristics. It often includes offensive language, stereotypes, or harmful stereotypes and can contribute to a hostile or unsafe environment for affected individuals.

2. **Explicit or adult content:** Explicit or adult content typically refers to material that is sexually explicit, pornographic, or contains graphic depictions of sexual acts. This content may not be suitable for all audiences and is subject to age restrictions and content regulations in many jurisdictions.

9. **Inflammatory or extremist viewpoints:** Inflammatory viewpoints are those that are deliberately provocative, offensive, or designed to incite anger or outrage. Extreme or extremist viewpoints often involve radical ideologies and can contribute to division and hostility in discussions. Engaging in conversations that promote understanding and open dialogue is generally more constructive.

In summary, these topics can be divisive, offensive, and harmful. When discussing or encountering them, it's essential to approach with respect, empathy, and a focus on maintaining a positive and safe environment for everyone involved.'''
)
Out[8]: [('AI written', 0.912), ('Human Written', 0.115)]

In [9]:
get_top_labels(
    '''The situation in Israel remains tense. More than 1,200 people have been killed so far in the terror attacks by Hamas groups, both by their infiltration and rockets. The southern part of Israel which shares borders with the Gaza Strip still remains vulnerable.

Ashkelon, one of the biggest cities in South Israel, has become a ghost town. Life is no longer normal here. Post noon till midnight there are a number of siren alarms, creating a constant atmosphere of panic ever since rockets were pounded into the city.'''
)
Out[9]: [('AI written', 0.998), ('Human Written', 0.003)]

In [10]:
get_top_labels(
    '''Optical illusions are fascinating pictures that trick our eyes, making us doubt what's real. They come in different types and can make us question what we see, think, and understand about the world. Even scientists sometimes struggle to figure out these puzzling illusions.

These illusions have various purposes. They challenge our minds, testing our thinking abilities. But they also provide a special way to delve into our personalities, revealing hidden aspects of who we are.

The task is simple: look at the image and note what you see first. Your initial observation can unveil your deepest insecurity. Most people see either a ditch surrounded by trees or an eye.'''
)
Out[10]: [('AI written', 1.0), ('Human Written', 0.0)]

In [11]:
get_top_labels(
    '''Learn from IIT Faculty & Industry Experts with Guaranteed Job Interviews.
Campus Immersion at IIT Roorkee.
Master machine learning and artificial intelligence skills with this advanced data science and artificial intelligence course from iHub IIT Roorkee. Learn from IIT faculty and industry experts with 1:1 mentorship in this intensive online bootcamp. Top 2 performers from each batch may get a fellowship worth Rs. 80,000, plus the opportunity to showcase their startup ideas and secure incubation support of upto Rs. 50 Lakhs for their startup from iHUB DivyaSampark, IIT Roorkee.'''
)
Out[11]: [('Human Written', 0.941), ('AI written', 0.056)]

 (unchanged context, notebook lines 893-895: start of the notebook's trailing empty cell)
  {
   "cell_type": "code",
   "execution_count": null,
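The cells above exercise get_top_labels from utils/get_sentence_status.py (changed in this commit but not shown in this diff), which returns (label, score) pairs from the AI-written vs. human-written classifier. For reference only, the same all-labels-with-scores output shape can be produced with the transformers pipeline API; the model path below is a placeholder, not the repository's actual detector checkpoint:

from transformers import pipeline

# Placeholder path: the repository's real detector checkpoint is not named in this diff.
clf = pipeline("text-classification", model="path/to/detector-checkpoint", top_k=None)
preds = clf(["My name is deepankar"])[0]         # all labels with scores for the one input
print([(p["label"], round(p["score"], 3)) for p in preds])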
research/10_demo_test_data.ipynb
CHANGED
@@ -768,7 +768,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count":
+  "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
@@ -785,7 +785,13 @@
   "output_type": "stream",
   "text": [
    "/home/ubuntu/SentenceStructureComparision/venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
-   " from .autonotebook import tqdm as notebook_tqdm\n"
+   " from .autonotebook import tqdm as notebook_tqdm\n"
+   ]
+  },
+  {
+   "name": "stderr",
+   "output_type": "stream",
+   "text": [
    "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
   ]
  }
@@ -802,7 +808,10 @@
  {
   "data": {
    "text/plain": [
-    "[('Food_and_Drink', 0.
+    "[('Food_and_Drink', 0.989),\n",
+    " ('Computers_and_Electronics', 0.973),\n",
+    " ('Games', 0.172),\n",
+    " ('Shopping', 0.134)]"
    ]
   },
   "execution_count": 3,
@@ -818,38 +827,38 @@
  },
  {
   "cell_type": "code",
-  "execution_count":
+  "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
-      "[('
+      "[('Computers_and_Electronics', 0.999), ('Shopping', 0.993)]"
      ]
     },
-    "execution_count":
+    "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_top_labels(\n",
-   " '
+   " 'amazon mindkoo headsets with discount'\n"
   ")"
  ]
 },
 {
  "cell_type": "code",
- "execution_count":
+ "execution_count": 5,
  "metadata": {},
  "outputs": [
   {
    "data": {
     "text/plain": [
-     "[('Home_and_Garden',
+     "[('Home_and_Garden', 0.999), ('Computers_and_Electronics', 0.243)]"
     ]
    },
-   "execution_count":
+   "execution_count": 5,
   "metadata": {},
  "output_type": "execute_result"
 }
research/11_evaluation.html
ADDED
The diff for this file is too large to render.
research/11_evaluation.ipynb
ADDED
@@ -0,0 +1,290 @@
 (new notebook: spot checks of the intent classifier via utils.get_intent.get_top_intent; cells shown as source with recorded output)

In [1]: import os; os.chdir('..')

In [2]: from utils.get_intent import get_top_intent

In [3]: get_top_intent("best cat ear headphones")
Out[3]:
[('Commercial', 0.969),
 ('Transactional', 0.673),
 ('Informational', 0.237),
 ('Navigational', 0.215),
 ('Local', 0.155)]

In [4]: get_top_intent("buy cat ear headphones")
Out[4]:
[('Transactional', 0.987),
 ('Navigational', 0.317),
 ('Commercial', 0.27),
 ('Informational', 0.249),
 ('Local', 0.229)]

In [5]: get_top_intent("how to create a facebook account")
Out[5]:
[('Informational', 0.984),
 ('Local', 0.244),
 ('Commercial', 0.237),
 ('Transactional', 0.212),
 ('Navigational', 0.194)]

In [6]: get_top_intent("barber shops in USA")
Out[6]:
[('Local', 0.988),
 ('Informational', 0.3),
 ('Commercial', 0.278),
 ('Navigational', 0.273),
 ('Transactional', 0.234)]

In [7]: get_top_intent("Razer Kraken Headsets")
Out[7]:
[('Informational', 0.763),
 ('Navigational', 0.638),
 ('Transactional', 0.433),
 ('Commercial', 0.286),
 ('Local', 0.236)]

In [8]: get_top_intent("Amazon Great indian festival")
Out[8]:
[('Navigational', 0.861),
 ('Transactional', 0.725),
 ('Local', 0.422),
 ('Commercial', 0.287),
 ('Informational', 0.202)]

In [9]: get_top_intent("facebook")
Out[9]:
[('Navigational', 0.983),
 ('Transactional', 0.27),
 ('Local', 0.23),
 ('Informational', 0.209),
 ('Commercial', 0.192)]

In [10]: get_top_intent("spotify")
Out[10]:
[('Navigational', 0.983),
 ('Transactional', 0.256),
 ('Informational', 0.241),
 ('Local', 0.214),
 ('Commercial', 0.184)]

In [11]: get_top_intent("parlours in dubai")
Out[11]:
[('Local', 0.988),
 ('Informational', 0.294),
 ('Navigational', 0.284),
 ('Commercial', 0.252),
 ('Transactional', 0.235)]

In [12]: get_top_intent("how to wear headphones")
Out[12]:
[('Informational', 0.984),
 ('Local', 0.245),
 ('Commercial', 0.242),
 ('Transactional', 0.226),
 ('Navigational', 0.189)]

 (plus a trailing empty code cell; notebook metadata: venv kernel, Python 3.10.12, nbformat 4)
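The five scores in each result do not sum to 1 (for example 0.969 + 0.673 + 0.237 + 0.215 + 0.155 for the first query), which suggests get_top_intent reports an independent per-label score rather than a softmax distribution. Its implementation lives in utils/get_intent.py (added in this commit, not shown here); the following is only a hedged sketch of a helper with matching behaviour, where the sigmoid scoring and the use of checkpoint-324 are assumptions:

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Hypothetical sketch, NOT the repository's utils/get_intent.py.
_CKPT = "intent_classification_model/checkpoint-324"
_tokenizer = AutoTokenizer.from_pretrained(_CKPT)
_model = AutoModelForSequenceClassification.from_pretrained(_CKPT).eval()

def get_top_intent(query):
    """Return (intent, score) pairs for all five intents, highest score first."""
    inputs = _tokenizer(query, return_tensors="pt", truncation=True)
    with torch.no_grad():
        logits = _model(**inputs).logits[0]
    scores = torch.sigmoid(logits)               # assumed: independent confidence per intent
    pairs = [(_model.config.id2label[i], round(float(s), 3)) for i, s in enumerate(scores)]
    return sorted(pairs, key=lambda p: p[1], reverse=True)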
research/11_intent_classification_using_distilbert.ipynb
ADDED
@@ -0,0 +1,898 @@
 (new notebook; the rendered diff is truncated partway through its 898 lines. Cells shown as source with recorded output; dataframe outputs shown in their text form.)

In [1]: import os; os.chdir('..')

In [2]: import pandas as pd

In [3]:
original_df= pd.read_csv("data_intent/intent_data.csv")
original_df.head()

Out[3]:
                             keyword         intent
0               citalopram vs prozac     Commercial
1  who is the oldest football player  Informational
2                 t mobile town east   Navigational
3                          starbucks   Navigational
4                        tech crunch   Navigational

In [4]: intents= original_df.intent.unique().tolist()

In [5]:
id2label= {}
label2id= {}
for i in range(len(intents)):
    id2label[i]= intents[i]
    label2id[intents[i]]= i

In [6]: id2label
Out[6]:
{0: 'Commercial',
 1: 'Informational',
 2: 'Navigational',
 3: 'Local',
 4: 'Transactional'}

In [7]: label2id
Out[7]:
{'Commercial': 0,
 'Informational': 1,
 'Navigational': 2,
 'Local': 3,
 'Transactional': 4}

In [8]:
def make_label2id(label):
    return label2id[label]

In [9]:
original_df['id']= original_df.intent.map(make_label2id)
original_df

Out[9]:
                                             keyword         intent  id
0                               citalopram vs prozac     Commercial   0
1                  who is the oldest football player  Informational   1
2                                 t mobile town east   Navigational   2
3                                          starbucks   Navigational   2
4                                        tech crunch   Navigational   2
...                                              ...            ...  ..
1066                     How to make a paper flower?  Informational   1
1067                 Why do some animals camouflage?  Informational   1
1068   What is the history of ancient civilizations?  Informational   1
1069                   How to make a simple machine?  Informational   1
1070           Why do we see the phases of the moon?  Informational   1

[1071 rows x 3 columns]

In [10]:
df= original_df[['keyword', 'id']]
df

Out[10]:
                                             keyword  id
0                               citalopram vs prozac   0
1                  who is the oldest football player   1
2                                 t mobile town east   2
3                                          starbucks   2
4                                        tech crunch   2
...                                              ...  ..
1066                     How to make a paper flower?   1
1067                 Why do some animals camouflage?   1
1068   What is the history of ancient civilizations?   1
1069                   How to make a simple machine?   1
1070           Why do we see the phases of the moon?   1

[1071 rows x 2 columns]

In [11]:
from datasets import Dataset, load_dataset
# stderr: .../tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets.

In [12]:
df.rename(columns={
    "keyword": "text", 
    "id": "label"
}, 
   inplace=True
)

df.sample(10)
# stderr: SettingWithCopyWarning: A value is trying to be set on a copy of a slice from a DataFrame
#         (see https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy)

Out[12]:
                                 text  label
706             Purchase DJ equipment      4
24              best headphones quora      2
727          Purchase fitness tracker      4
17                           facebook      2
808  Outdoor activities in Lake Tahoe      3
946          Wine bars in Napa Valley      3
944      Art installations in Chicago      3
899        Snowboarding parks in Utah      3
36                Mission Immpossible      1
129                         Instagram      2

In [13]:
dataset_df= Dataset.from_pandas(df)
dataset_df
# stderr: .../pyarrow/pandas_compat.py:373: FutureWarning: is_sparse is deprecated and will be removed in a future version.

Out[13]:
Dataset({
    features: ['text', 'label'],
    num_rows: 1071
})

In [14]:
new_data= dataset_df.train_test_split(test_size=0.2)
new_data

Out[14]:
DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 856
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 215
    })
})

In [15]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

In [16]:
def preprocess_function(examples):
    return tokenizer(examples["text"], truncation=True)

In [17]:
tokenized_df = new_data.map(preprocess_function, batched=True)
# stderr: Map: 100%|██████████| 856/856 [00:00<00:00, 18779.12 examples/s]
# stderr: Map: 100%|██████████| 215/215 [00:00<00:00, 27520.84 examples/s]

In [18]:
# from transformers import DataCollatorWithPadding

# data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="tf")

from transformers import DataCollatorWithPadding

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
# stderr: tensorflow/core/platform/cpu_feature_guard.cc:182: This TensorFlow binary is optimized to use available
#         CPU instructions in performance-critical operations. ... TF-TRT Warning: Could not find TensorRT

In [19]:
import evaluate

accuracy = evaluate.load("accuracy")

In [20]:
import numpy as np


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=labels)

In [21]:
 (the rendered diff is cut off here, partway through this notebook)
"outputs": [
|
732 |
+
{
|
733 |
+
"name": "stderr",
|
734 |
+
"output_type": "stream",
|
735 |
+
"text": [
|
736 |
+
"Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'pre_classifier.bias', 'classifier.bias', 'pre_classifier.weight']\n",
|
737 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
738 |
+
]
|
739 |
+
}
|
740 |
+
],
|
741 |
+
"source": [
|
742 |
+
"from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer\n",
|
743 |
+
"\n",
|
744 |
+
"model = AutoModelForSequenceClassification.from_pretrained(\n",
|
745 |
+
" \"distilbert-base-uncased\", num_labels=5, id2label=id2label, label2id=label2id\n",
|
746 |
+
")"
|
747 |
+
]
|
748 |
+
},
|
749 |
+
{
|
750 |
+
"cell_type": "code",
|
751 |
+
"execution_count": 22,
|
752 |
+
"metadata": {},
|
753 |
+
"outputs": [
|
754 |
+
{
|
755 |
+
"name": "stderr",
|
756 |
+
"output_type": "stream",
|
757 |
+
"text": [
|
758 |
+
"You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
|
759 |
+
]
|
760 |
+
},
|
761 |
+
{
|
762 |
+
"data": {
|
763 |
+
"text/html": [
|
764 |
+
"\n",
|
765 |
+
" <div>\n",
|
766 |
+
" \n",
|
767 |
+
" <progress value='324' max='324' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
768 |
+
" [324/324 00:39, Epoch 6/6]\n",
|
769 |
+
" </div>\n",
|
770 |
+
" <table border=\"1\" class=\"dataframe\">\n",
|
771 |
+
" <thead>\n",
|
772 |
+
" <tr style=\"text-align: left;\">\n",
|
773 |
+
" <th>Epoch</th>\n",
|
774 |
+
" <th>Training Loss</th>\n",
|
775 |
+
" <th>Validation Loss</th>\n",
|
776 |
+
" <th>Accuracy</th>\n",
|
777 |
+
" </tr>\n",
|
778 |
+
" </thead>\n",
|
779 |
+
" <tbody>\n",
|
780 |
+
" <tr>\n",
|
781 |
+
" <td>1</td>\n",
|
782 |
+
" <td>No log</td>\n",
|
783 |
+
" <td>0.467693</td>\n",
|
784 |
+
" <td>0.948837</td>\n",
|
785 |
+
" </tr>\n",
|
786 |
+
" <tr>\n",
|
787 |
+
" <td>2</td>\n",
|
788 |
+
" <td>No log</td>\n",
|
789 |
+
" <td>0.204288</td>\n",
|
790 |
+
" <td>0.953488</td>\n",
|
791 |
+
" </tr>\n",
|
792 |
+
" <tr>\n",
|
793 |
+
" <td>3</td>\n",
|
794 |
+
" <td>No log</td>\n",
|
795 |
+
" <td>0.164018</td>\n",
|
796 |
+
" <td>0.967442</td>\n",
|
797 |
+
" </tr>\n",
|
798 |
+
" <tr>\n",
|
799 |
+
" <td>4</td>\n",
|
800 |
+
" <td>No log</td>\n",
|
801 |
+
" <td>0.164968</td>\n",
|
802 |
+
" <td>0.967442</td>\n",
|
803 |
+
" </tr>\n",
|
804 |
+
" <tr>\n",
|
805 |
+
" <td>5</td>\n",
|
806 |
+
" <td>No log</td>\n",
|
807 |
+
" <td>0.163977</td>\n",
|
808 |
+
" <td>0.967442</td>\n",
|
809 |
+
" </tr>\n",
|
810 |
+
" <tr>\n",
|
811 |
+
" <td>6</td>\n",
|
812 |
+
" <td>No log</td>\n",
|
813 |
+
" <td>0.165533</td>\n",
|
814 |
+
" <td>0.967442</td>\n",
|
815 |
+
" </tr>\n",
|
816 |
+
" </tbody>\n",
|
817 |
+
"</table><p>"
|
818 |
+
],
|
819 |
+
"text/plain": [
|
820 |
+
"<IPython.core.display.HTML object>"
|
821 |
+
]
|
822 |
+
},
|
823 |
+
"metadata": {},
|
824 |
+
"output_type": "display_data"
|
825 |
+
},
|
826 |
+
{
|
827 |
+
"data": {
|
828 |
+
"text/plain": [
|
829 |
+
"TrainOutput(global_step=324, training_loss=0.2842947171058184, metrics={'train_runtime': 40.8212, 'train_samples_per_second': 125.817, 'train_steps_per_second': 7.937, 'total_flos': 13032177536640.0, 'train_loss': 0.2842947171058184, 'epoch': 6.0})"
|
830 |
+
]
|
831 |
+
},
|
832 |
+
"execution_count": 22,
|
833 |
+
"metadata": {},
|
834 |
+
"output_type": "execute_result"
|
835 |
+
}
|
836 |
+
],
|
837 |
+
"source": [
|
838 |
+
"training_args = TrainingArguments(\n",
|
839 |
+
" output_dir=\"intent_classification_model\",\n",
|
840 |
+
" learning_rate=2e-5,\n",
|
841 |
+
" per_device_train_batch_size=16,\n",
|
842 |
+
" per_device_eval_batch_size=16,\n",
|
843 |
+
" num_train_epochs=6,\n",
|
844 |
+
" weight_decay=0.01,\n",
|
845 |
+
" evaluation_strategy=\"epoch\",\n",
|
846 |
+
" save_strategy=\"epoch\",\n",
|
847 |
+
" load_best_model_at_end=True,\n",
|
848 |
+
" # push_to_hub=True,\n",
|
849 |
+
")\n",
|
850 |
+
"\n",
|
851 |
+
"trainer = Trainer(\n",
|
852 |
+
" model=model,\n",
|
853 |
+
" args=training_args,\n",
|
854 |
+
" train_dataset=tokenized_df[\"train\"],\n",
|
855 |
+
" eval_dataset=tokenized_df[\"test\"],\n",
|
856 |
+
" tokenizer=tokenizer,\n",
|
857 |
+
" data_collator=data_collator,\n",
|
858 |
+
" compute_metrics=compute_metrics,\n",
|
859 |
+
")\n",
|
860 |
+
"\n",
|
861 |
+
"trainer.train()"
|
862 |
+
]
|
863 |
+
},
|
864 |
+
{
|
865 |
+
"cell_type": "code",
|
866 |
+
"execution_count": null,
|
867 |
+
"metadata": {},
|
868 |
+
"outputs": [],
|
869 |
+
"source": []
|
870 |
+
},
|
871 |
+
{
|
872 |
+
"cell_type": "markdown",
|
873 |
+
"metadata": {},
|
874 |
+
"source": []
|
875 |
+
}
|
876 |
+
],
|
877 |
+
"metadata": {
|
878 |
+
"kernelspec": {
|
879 |
+
"display_name": "venv",
|
880 |
+
"language": "python",
|
881 |
+
"name": "python3"
|
882 |
+
},
|
883 |
+
"language_info": {
|
884 |
+
"codemirror_mode": {
|
885 |
+
"name": "ipython",
|
886 |
+
"version": 3
|
887 |
+
},
|
888 |
+
"file_extension": ".py",
|
889 |
+
"mimetype": "text/x-python",
|
890 |
+
"name": "python",
|
891 |
+
"nbconvert_exporter": "python",
|
892 |
+
"pygments_lexer": "ipython3",
|
893 |
+
"version": "3.10.12"
|
894 |
+
}
|
895 |
+
},
|
896 |
+
"nbformat": 4,
|
897 |
+
"nbformat_minor": 2
|
898 |
+
}
|
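For quick reference, the notebook above fine-tunes distilbert-base-uncased on the intent keywords (five classes, 1071 rows, 80/20 split) with the Hugging Face Trainer, and validation accuracy plateaus around 0.967 from the third epoch onward. The sketch below condenses those cells into one runnable script; the CSV path, column names and hyperparameters mirror the notebook, while the `intent_*` label names are placeholders for the project's real id2label/label2id mapping.

# A minimal, self-contained sketch of the notebook's pipeline. The CSV path, column
# names and hyperparameters mirror the notebook; the intent_* label names below are
# placeholders for the project's real id2label/label2id mapping.
import numpy as np
import pandas as pd
import evaluate
from datasets import Dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
)

df = pd.read_csv("data_intent/intent_data.csv")               # columns: keyword, id
df.rename(columns={"keyword": "text", "id": "label"}, inplace=True)

label_ids = sorted(df["label"].unique())                      # the five intent ids (0..4)
id2label = {int(i): f"intent_{i}" for i in label_ids}         # placeholder names
label2id = {name: idx for idx, name in id2label.items()}

data = Dataset.from_pandas(df).train_test_split(test_size=0.2)

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
tokenized = data.map(lambda batch: tokenizer(batch["text"], truncation=True), batched=True)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    # argmax over the 5 logits, then plain accuracy against the gold labels
    predictions, labels = eval_pred
    return accuracy.compute(predictions=np.argmax(predictions, axis=1), references=labels)

model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=5, id2label=id2label, label2id=label2id
)

training_args = TrainingArguments(
    output_dir="intent_classification_model",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=6,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)
trainer.train()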
utils/__pycache__/get_category.cpython-310.pyc
CHANGED
Binary files a/utils/__pycache__/get_category.cpython-310.pyc and b/utils/__pycache__/get_category.cpython-310.pyc differ
|
|
utils/__pycache__/get_intent.cpython-310.pyc
ADDED
Binary file (1.5 kB). View file
|
|
utils/__pycache__/get_sentence_status.cpython-310.pyc
CHANGED
Binary files a/utils/__pycache__/get_sentence_status.cpython-310.pyc and b/utils/__pycache__/get_sentence_status.cpython-310.pyc differ
|
|
utils/get_category.py
CHANGED
@@ -93,16 +93,20 @@ def get_top_labels(keyword: str):
 
     for i in range(27):
         score= individual_probabilities_scores[i]
-        if score>=0.5:
-            score_list.append(
-                (id2label[i], score)
-            )
+        if score>=0.1:
+            score_list.append(
+                (id2label[i], score)
+            )
+        # if score>=0.5:
+        #     score_list.append(
+        #         (id2label[i], score)
+        #     )
 
 
     score_list.sort(
         key= lambda x: x[1], reverse=True
     )
 
-    return score_list
+    return score_list[:5]
 
 
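With the threshold lowered from 0.5 to 0.1 and the added slice, get_top_labels now returns at most the five best-scoring categories instead of every category above 0.5. A hypothetical call (the keyword and category names are illustrative, and the import assumes the repo root is on PYTHONPATH):

from utils.get_category import get_top_labels

scores = get_top_labels("best trail running shoes")
print(scores)
# e.g. [('Shopping', 0.94), ('Sports', 0.37), ('Health', 0.12)]
# -> (label, sigmoid score) pairs with score >= 0.1, sorted descending, at most five entries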
utils/get_intent.py
ADDED
@@ -0,0 +1,69 @@
+from transformers import AutoTokenizer
+from transformers import AutoModelForSequenceClassification
+import torch
+from torch.nn import functional as F
+import numpy as np
+import json
+
+
+
+label2id= json.load(
+    open('data/categories_refined.json', 'r')
+)
+id2label= {}
+for key in label2id.keys():
+    id2label[label2id[key]] = key
+
+
+
+model_name= "intent_classification_model/checkpoint-324"
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+model = AutoModelForSequenceClassification.from_pretrained(model_name).to("cuda")
+
+
+# probabilities = 1 / (1 + np.exp(-logit_score))
+def logit2prob(logit):
+    # odds =np.exp(logit)
+    # prob = odds / (1 + odds)
+    prob= 1/(1+ np.exp(-logit))
+    return np.round(prob, 3)
+
+
+
+
+def get_top_intent(keyword: str):
+    '''
+    Returns score list
+    '''
+    inputs = tokenizer(keyword, return_tensors="pt").to("cuda")
+    with torch.no_grad():
+        logits = model(**inputs).logits
+
+    # print("logits: ", logits)
+    # predicted_class_id = logits.argmax().item()
+
+    # get probabilities using softmax from logit score and convert it to numpy array
+    # probabilities_scores = F.softmax(logits.cpu(), dim = -1).numpy()[0]
+    individual_probabilities_scores = logit2prob(logits.cpu().numpy()[0])
+
+    score_list= []
+
+    for i in range(5):
+        label= model.config.id2label[i]
+
+        score= individual_probabilities_scores[i]
+        score_list.append(
+            (label, score)
+        )
+        # if score>=0.5:
+        #     score_list.append(
+        #         (id2label[i], score)
+        #     )
+
+
+    score_list.sort(
+        key= lambda x: x[1], reverse=True
+    )
+
+    return score_list
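A hypothetical call for the new helper (the query text is illustrative, and the import assumes the repo root is on PYTHONPATH). Note that the module loads the fine-tuned checkpoint onto CUDA at import time, so a GPU is required:

from utils.get_intent import get_top_intent

print(get_top_intent("buy wireless headphones online"))
# -> five (intent_label, sigmoid_score) tuples, one per class from model.config.id2label,
#    sorted by score in descending order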
utils/get_sentence_status.py
CHANGED
@@ -12,6 +12,13 @@ tokenizer_v2 = AutoTokenizer.from_pretrained("gpt2-large")
 model = AutoModelForSequenceClassification.from_pretrained("gpt3_finetuned_model/checkpoint-30048").to("cuda")
 
 
+# probabilities = 1 / (1 + np.exp(-logit_score))
+def logit2prob(logit):
+    # odds =np.exp(logit)
+    # prob = odds / (1 + odds)
+    prob= 1/(1+ np.exp(-logit))
+    return np.round(prob, 3)
+
 def split_sentence(sentence:str):
     # Create a regular expression pattern from the list of separators
     sentence= sentence.replace('\n', '')
@@ -98,4 +105,44 @@ def complete_sentence_analysis(sentence:str):
         "label": label,
         "variance": variance,
         "avg_length": avg_length
-    }
+    }
+
+
+
+
+
+def get_top_labels(keyword: str):
+    '''
+    Returns score list
+    '''
+    inputs = tokenizer(keyword, return_tensors="pt").to("cuda")
+    with torch.no_grad():
+        logits = model(**inputs).logits
+
+    # print("logits: ", logits)
+    # predicted_class_id = logits.argmax().item()
+
+    # get probabilities using softmax from logit score and convert it to numpy array
+    # probabilities_scores = F.softmax(logits.cpu(), dim = -1).numpy()[0]
+    individual_probabilities_scores = logit2prob(logits.cpu().numpy()[0])
+
+    score_list= []
+
+    for i in range(2):
+        label= "Human Written" if model.config.id2label[i]=='NEGATIVE' else 'AI written'
+
+        score= individual_probabilities_scores[i]
+        score_list.append(
+            (label, score)
+        )
+        # if score>=0.5:
+        #     score_list.append(
+        #         (id2label[i], score)
+        #     )
+
+
+    score_list.sort(
+        key= lambda x: x[1], reverse=True
+    )
+
+    return score_list[:5]
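Note that this get_top_labels scores the detector's two logits independently through logit2prob (an element-wise sigmoid) rather than a softmax, so the "Human Written" and "AI written" scores need not sum to 1. A quick standalone check of that behaviour (the logit values are illustrative):

import numpy as np

def logit2prob(logit):
    # element-wise sigmoid, rounded to 3 decimals (same as the helper above)
    return np.round(1 / (1 + np.exp(-logit)), 3)

print(logit2prob(np.array([2.0, -1.3])))  # [0.881 0.214] -- independent scores, not a softmax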