Ubuntu committed on
Commit e77b318
1 Parent(s): d0702fa

added intent classification using DistilBERT

Files changed (29)
  1. data_intent/intent_data.csv +3 -0
  2. intent_classification_model/checkpoint-324/added_tokens.json +7 -0
  3. intent_classification_model/checkpoint-324/config.json +39 -0
  4. intent_classification_model/checkpoint-324/optimizer.pt +3 -0
  5. intent_classification_model/checkpoint-324/pytorch_model.bin +3 -0
  6. intent_classification_model/checkpoint-324/rng_state.pth +0 -0
  7. intent_classification_model/checkpoint-324/scheduler.pt +3 -0
  8. intent_classification_model/checkpoint-324/special_tokens_map.json +7 -0
  9. intent_classification_model/checkpoint-324/tokenizer.json +0 -0
  10. intent_classification_model/checkpoint-324/tokenizer_config.json +56 -0
  11. intent_classification_model/checkpoint-324/trainer_state.json +73 -0
  12. intent_classification_model/checkpoint-324/training_args.bin +3 -0
  13. intent_classification_model/checkpoint-324/vocab.txt +0 -0
  14. intent_classification_model/runs/Oct13_09-06-59_ip-172-31-95-165/events.out.tfevents.1697188019.ip-172-31-95-165.137562.0 +0 -0
  15. intent_classification_model/runs/Oct13_09-08-12_ip-172-31-95-165/events.out.tfevents.1697188092.ip-172-31-95-165.137562.1 +0 -0
  16. intent_classification_model/runs/Oct13_09-08-49_ip-172-31-95-165/events.out.tfevents.1697188130.ip-172-31-95-165.137562.2 +0 -0
  17. intent_classification_model/runs/Oct13_09-09-35_ip-172-31-95-165/events.out.tfevents.1697188176.ip-172-31-95-165.137562.3 +0 -0
  18. intent_classification_model/runs/Oct13_09-10-07_ip-172-31-95-165/events.out.tfevents.1697188208.ip-172-31-95-165.138160.0 +0 -0
  19. research/04_inference.ipynb +217 -0
  20. research/10_demo_test_data.ipynb +19 -10
  21. research/11_evaluation.html +0 -0
  22. research/11_evaluation.ipynb +290 -0
  23. research/11_intent_classification_using_distilbert.ipynb +898 -0
  24. utils/__pycache__/get_category.cpython-310.pyc +0 -0
  25. utils/__pycache__/get_intent.cpython-310.pyc +0 -0
  26. utils/__pycache__/get_sentence_status.cpython-310.pyc +0 -0
  27. utils/get_category.py +8 -4
  28. utils/get_intent.py +69 -0
  29. utils/get_sentence_status.py +48 -1
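
For orientation, a minimal usage sketch of the classifier added in this commit (hedged: it assumes the repository root as the working directory, the checkpoint committed here, and a CUDA device, since utils/get_intent.py moves the model to "cuda"; it mirrors the calls recorded in research/11_evaluation.ipynb):

# Sketch only — mirrors research/11_evaluation.ipynb from this commit.
# Assumes intent_classification_model/checkpoint-324 exists locally and a GPU is available.
from utils.get_intent import get_top_intent

scores = get_top_intent("best cat ear headphones")
# get_top_intent returns (label, sigmoid probability) pairs sorted by score, e.g.
# [('Commercial', 0.969), ('Transactional', 0.673), ('Informational', 0.237),
#  ('Navigational', 0.215), ('Local', 0.155)]
print(scores)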
data_intent/intent_data.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:24091e2e977d444be178138ac717fa57b8d16534dcf5e66d4084cf3f77e6f6ce
+ size 39551
intent_classification_model/checkpoint-324/added_tokens.json ADDED
@@ -0,0 +1,7 @@
+ {
+ "[CLS]": 101,
+ "[MASK]": 103,
+ "[PAD]": 0,
+ "[SEP]": 102,
+ "[UNK]": 100
+ }
intent_classification_model/checkpoint-324/config.json ADDED
@@ -0,0 +1,39 @@
+ {
+ "_name_or_path": "distilbert-base-uncased",
+ "activation": "gelu",
+ "architectures": [
+ "DistilBertForSequenceClassification"
+ ],
+ "attention_dropout": 0.1,
+ "dim": 768,
+ "dropout": 0.1,
+ "hidden_dim": 3072,
+ "id2label": {
+ "0": "Commercial",
+ "1": "Informational",
+ "2": "Navigational",
+ "3": "Local",
+ "4": "Transactional"
+ },
+ "initializer_range": 0.02,
+ "label2id": {
+ "Commercial": 0,
+ "Informational": 1,
+ "Local": 3,
+ "Navigational": 2,
+ "Transactional": 4
+ },
+ "max_position_embeddings": 512,
+ "model_type": "distilbert",
+ "n_heads": 12,
+ "n_layers": 6,
+ "pad_token_id": 0,
+ "problem_type": "single_label_classification",
+ "qa_dropout": 0.1,
+ "seq_classif_dropout": 0.2,
+ "sinusoidal_pos_embds": false,
+ "tie_weights_": true,
+ "torch_dtype": "float32",
+ "transformers_version": "4.34.0",
+ "vocab_size": 30522
+ }
intent_classification_model/checkpoint-324/optimizer.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a50f88f7a9097ecddb2b3c7e3d38747deec4ca3a386132fac9e0e4efaa82ae0e
+ size 535745722
intent_classification_model/checkpoint-324/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b339df5c0d892e025a1749d085ab010e551f4b249eb497812a1a3bd7ebd5fd99
+ size 267865194
intent_classification_model/checkpoint-324/rng_state.pth ADDED
Binary file (14.2 kB).
 
intent_classification_model/checkpoint-324/scheduler.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:73f74582c189fe624f606122980ccb279125588a1db45b4052dc704fa2b51184
+ size 1064
intent_classification_model/checkpoint-324/special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
+ {
+ "cls_token": "[CLS]",
+ "mask_token": "[MASK]",
+ "pad_token": "[PAD]",
+ "sep_token": "[SEP]",
+ "unk_token": "[UNK]"
+ }
intent_classification_model/checkpoint-324/tokenizer.json ADDED
The diff for this file is too large to render.
 
intent_classification_model/checkpoint-324/tokenizer_config.json ADDED
@@ -0,0 +1,56 @@
+ {
+ "added_tokens_decoder": {
+ "0": {
+ "content": "[PAD]",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "100": {
+ "content": "[UNK]",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "101": {
+ "content": "[CLS]",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "102": {
+ "content": "[SEP]",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "103": {
+ "content": "[MASK]",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ }
+ },
+ "additional_special_tokens": [],
+ "clean_up_tokenization_spaces": true,
+ "cls_token": "[CLS]",
+ "do_lower_case": true,
+ "mask_token": "[MASK]",
+ "model_max_length": 512,
+ "pad_token": "[PAD]",
+ "sep_token": "[SEP]",
+ "strip_accents": null,
+ "tokenize_chinese_chars": true,
+ "tokenizer_class": "DistilBertTokenizer",
+ "unk_token": "[UNK]"
+ }
intent_classification_model/checkpoint-324/trainer_state.json ADDED
@@ -0,0 +1,73 @@
+ {
+ "best_metric": 0.16397738456726074,
+ "best_model_checkpoint": "intent_classification_model/checkpoint-270",
+ "epoch": 6.0,
+ "eval_steps": 500,
+ "global_step": 324,
+ "is_hyper_param_search": false,
+ "is_local_process_zero": true,
+ "is_world_process_zero": true,
+ "log_history": [
+ {
+ "epoch": 1.0,
+ "eval_accuracy": 0.9488372093023256,
+ "eval_loss": 0.4676927328109741,
+ "eval_runtime": 0.1185,
+ "eval_samples_per_second": 1814.083,
+ "eval_steps_per_second": 118.126,
+ "step": 54
+ },
+ {
+ "epoch": 2.0,
+ "eval_accuracy": 0.9534883720930233,
+ "eval_loss": 0.20428764820098877,
+ "eval_runtime": 0.0972,
+ "eval_samples_per_second": 2210.83,
+ "eval_steps_per_second": 143.961,
+ "step": 108
+ },
+ {
+ "epoch": 3.0,
+ "eval_accuracy": 0.9674418604651163,
+ "eval_loss": 0.16401757299900055,
+ "eval_runtime": 0.1015,
+ "eval_samples_per_second": 2118.828,
+ "eval_steps_per_second": 137.97,
+ "step": 162
+ },
+ {
+ "epoch": 4.0,
+ "eval_accuracy": 0.9674418604651163,
+ "eval_loss": 0.16496841609477997,
+ "eval_runtime": 0.0941,
+ "eval_samples_per_second": 2284.398,
+ "eval_steps_per_second": 148.752,
+ "step": 216
+ },
+ {
+ "epoch": 5.0,
+ "eval_accuracy": 0.9674418604651163,
+ "eval_loss": 0.16397738456726074,
+ "eval_runtime": 0.0975,
+ "eval_samples_per_second": 2204.851,
+ "eval_steps_per_second": 143.572,
+ "step": 270
+ },
+ {
+ "epoch": 6.0,
+ "eval_accuracy": 0.9674418604651163,
+ "eval_loss": 0.16553252935409546,
+ "eval_runtime": 0.0947,
+ "eval_samples_per_second": 2271.063,
+ "eval_steps_per_second": 147.883,
+ "step": 324
+ }
+ ],
+ "logging_steps": 500,
+ "max_steps": 324,
+ "num_train_epochs": 6,
+ "save_steps": 500,
+ "total_flos": 13032177536640.0,
+ "trial_name": null,
+ "trial_params": null
+ }
intent_classification_model/checkpoint-324/training_args.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c27308f0087e544f12e1806abafb33d65745a5791fb1559d9e521f3670215df9
+ size 4536
intent_classification_model/checkpoint-324/vocab.txt ADDED
The diff for this file is too large to render.
 
intent_classification_model/runs/Oct13_09-06-59_ip-172-31-95-165/events.out.tfevents.1697188019.ip-172-31-95-165.137562.0 ADDED
Binary file (5.3 kB).
 
intent_classification_model/runs/Oct13_09-08-12_ip-172-31-95-165/events.out.tfevents.1697188092.ip-172-31-95-165.137562.1 ADDED
Binary file (6.02 kB).
 
intent_classification_model/runs/Oct13_09-08-49_ip-172-31-95-165/events.out.tfevents.1697188130.ip-172-31-95-165.137562.2 ADDED
Binary file (5.93 kB).
 
intent_classification_model/runs/Oct13_09-09-35_ip-172-31-95-165/events.out.tfevents.1697188176.ip-172-31-95-165.137562.3 ADDED
Binary file (4.73 kB).
 
intent_classification_model/runs/Oct13_09-10-07_ip-172-31-95-165/events.out.tfevents.1697188208.ip-172-31-95-165.138160.0 ADDED
Binary file (6.6 kB).
 
research/04_inference.ipynb CHANGED
@@ -673,6 +673,223 @@
673
  "There are a few reasons why language modeling people like perplexity instead of just using entropy. One is that, because of the exponent, improvements in perplexity \"feel\" like they are more substantial than the equivalent improvement in entropy. Another is that before they started using perplexity, the complexity of a language model was reported using a simplistic branching factor measurement that is more similar to perplexity than it is to entropy.''')"
674
  ]
675
  },
676
+ {
677
+ "cell_type": "code",
678
+ "execution_count": null,
679
+ "metadata": {},
680
+ "outputs": [],
681
+ "source": []
682
+ },
683
+ {
684
+ "cell_type": "code",
685
+ "execution_count": 1,
686
+ "metadata": {},
687
+ "outputs": [],
688
+ "source": [
689
+ "import os; os.chdir(\n",
690
+ " '..'\n",
691
+ ")"
692
+ ]
693
+ },
694
+ {
695
+ "cell_type": "code",
696
+ "execution_count": 2,
697
+ "metadata": {},
698
+ "outputs": [
699
+ {
700
+ "name": "stderr",
701
+ "output_type": "stream",
702
+ "text": [
703
+ "/home/ubuntu/SentenceStructureComparision/venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
704
+ " from .autonotebook import tqdm as notebook_tqdm\n"
705
+ ]
706
+ },
707
+ {
708
+ "name": "stderr",
709
+ "output_type": "stream",
710
+ "text": [
711
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
712
+ ]
713
+ }
714
+ ],
715
+ "source": [
716
+ "from utils.get_sentence_status import get_top_labels"
717
+ ]
718
+ },
719
+ {
720
+ "cell_type": "code",
721
+ "execution_count": 4,
722
+ "metadata": {},
723
+ "outputs": [
724
+ {
725
+ "data": {
726
+ "text/plain": [
727
+ "[('Human Written', 0.999), ('AI written', 0.002)]"
728
+ ]
729
+ },
730
+ "execution_count": 4,
731
+ "metadata": {},
732
+ "output_type": "execute_result"
733
+ }
734
+ ],
735
+ "source": [
736
+ "get_top_labels('''12\n",
737
+ "\n",
738
+ "Yes, the perplexity is always equal to two to the power of the entropy. It doesn't matter what type of model you have, n-gram, unigram, or neural network.\n",
739
+ "\n",
740
+ "There are a few reasons why language modeling people like perplexity instead of just using entropy. One is that, because of the exponent, improvements in perplexity \"feel\" like they are more substantial than the equivalent improvement in entropy. Another is that before they started using perplexity, the complexity of a language model was reported using a simplistic branching factor measurement that is more similar to perplexity than it is to entropy.''')"
741
+ ]
742
+ },
743
+ {
744
+ "cell_type": "code",
745
+ "execution_count": 3,
746
+ "metadata": {},
747
+ "outputs": [
748
+ {
749
+ "data": {
750
+ "text/plain": [
751
+ "[('AI written', 1.0), ('Human Written', 0.0)]"
752
+ ]
753
+ },
754
+ "execution_count": 3,
755
+ "metadata": {},
756
+ "output_type": "execute_result"
757
+ }
758
+ ],
759
+ "source": [
760
+ "get_top_labels(\n",
761
+ " 'My name is deepankar'\n",
762
+ ")"
763
+ ]
764
+ },
765
+ {
766
+ "cell_type": "code",
767
+ "execution_count": 6,
768
+ "metadata": {},
769
+ "outputs": [
770
+ {
771
+ "data": {
772
+ "text/plain": [
773
+ "[('AI written', 0.999), ('Human Written', 0.001)]"
774
+ ]
775
+ },
776
+ "execution_count": 6,
777
+ "metadata": {},
778
+ "output_type": "execute_result"
779
+ }
780
+ ],
781
+ "source": [
782
+ "get_top_labels(\n",
783
+ " '''Hate speech or discriminatory content: Hate speech is speech, conduct, writing, or expressions that discriminate or promote discrimination against individuals or groups based on attributes such as race, religion, nationality, gender, sexual orientation, disability, or other characteristics. It often includes offensive language, stereotypes, or harmful stereotypes and can contribute to a hostile or unsafe environment for affected individuals.\n",
784
+ "\n",
785
+ "Explicit or adult content: Explicit or adult content typically refers to material that is sexually explicit, pornographic, or contains graphic depictions of sexual acts. This content may not be suitable for all audiences and is subject to age restrictions and content regulations in many jurisdictions.'''\n",
786
+ ")"
787
+ ]
788
+ },
789
+ {
790
+ "cell_type": "code",
791
+ "execution_count": 8,
792
+ "metadata": {},
793
+ "outputs": [
794
+ {
795
+ "data": {
796
+ "text/plain": [
797
+ "[('AI written', 0.912), ('Human Written', 0.115)]"
798
+ ]
799
+ },
800
+ "execution_count": 8,
801
+ "metadata": {},
802
+ "output_type": "execute_result"
803
+ }
804
+ ],
805
+ "source": [
806
+ "get_top_labels(\n",
807
+ " '''Of course, I can provide a more detailed explanation of these topics:\n",
808
+ "\n",
809
+ "1. **Hate speech or discriminatory content:** Hate speech is speech, conduct, writing, or expressions that discriminate or promote discrimination against individuals or groups based on attributes such as race, religion, nationality, gender, sexual orientation, disability, or other characteristics. It often includes offensive language, stereotypes, or harmful stereotypes and can contribute to a hostile or unsafe environment for affected individuals.\n",
810
+ "\n",
811
+ "2. **Explicit or adult content:** Explicit or adult content typically refers to material that is sexually explicit, pornographic, or contains graphic depictions of sexual acts. This content may not be suitable for all audiences and is subject to age restrictions and content regulations in many jurisdictions.\n",
812
+ "\n",
813
+ "9. **Inflammatory or extremist viewpoints:** Inflammatory viewpoints are those that are deliberately provocative, offensive, or designed to incite anger or outrage. Extreme or extremist viewpoints often involve radical ideologies and can contribute to division and hostility in discussions. Engaging in conversations that promote understanding and open dialogue is generally more constructive.\n",
814
+ "\n",
815
+ "In summary, these topics can be divisive, offensive, and harmful. When discussing or encountering them, it's essential to approach with respect, empathy, and a focus on maintaining a positive and safe environment for everyone involved.'''\n",
816
+ ")"
817
+ ]
818
+ },
819
+ {
820
+ "cell_type": "code",
821
+ "execution_count": 9,
822
+ "metadata": {},
823
+ "outputs": [
824
+ {
825
+ "data": {
826
+ "text/plain": [
827
+ "[('AI written', 0.998), ('Human Written', 0.003)]"
828
+ ]
829
+ },
830
+ "execution_count": 9,
831
+ "metadata": {},
832
+ "output_type": "execute_result"
833
+ }
834
+ ],
835
+ "source": [
836
+ "get_top_labels(\n",
837
+ " '''The situation in Israel remains tense. More than 1,200 people have been killed so far in the terror attacks by Hamas groups, both by their infiltration and rockets. The southern part of Israel which shares borders with the Gaza Strip still remains vulnerable. \n",
838
+ "\n",
839
+ "Ashkelon, one of the biggest cities in South Israel, has become a ghost town. Life is no longer normal here. Post noon till midnight there are a number of siren alarms, creating a constant atmosphere of panic ever since rockets were pounded into the city.'''\n",
840
+ ")"
841
+ ]
842
+ },
843
+ {
844
+ "cell_type": "code",
845
+ "execution_count": 10,
846
+ "metadata": {},
847
+ "outputs": [
848
+ {
849
+ "data": {
850
+ "text/plain": [
851
+ "[('AI written', 1.0), ('Human Written', 0.0)]"
852
+ ]
853
+ },
854
+ "execution_count": 10,
855
+ "metadata": {},
856
+ "output_type": "execute_result"
857
+ }
858
+ ],
859
+ "source": [
860
+ "get_top_labels(\n",
861
+ " '''Optical illusions are fascinating pictures that trick our eyes, making us doubt what's real. They come in different types and can make us question what we see, think, and understand about the world. Even scientists sometimes struggle to figure out these puzzling illusions.\n",
862
+ "\n",
863
+ "These illusions have various purposes. They challenge our minds, testing our thinking abilities. But they also provide a special way to delve into our personalities, revealing hidden aspects of who we are.\n",
864
+ "\n",
865
+ "The task is simple: look at the image and note what you see first. Your initial observation can unveil your deepest insecurity. Most people see either a ditch surrounded by trees or an eye.'''\n",
866
+ ")"
867
+ ]
868
+ },
869
+ {
870
+ "cell_type": "code",
871
+ "execution_count": 11,
872
+ "metadata": {},
873
+ "outputs": [
874
+ {
875
+ "data": {
876
+ "text/plain": [
877
+ "[('Human Written', 0.941), ('AI written', 0.056)]"
878
+ ]
879
+ },
880
+ "execution_count": 11,
881
+ "metadata": {},
882
+ "output_type": "execute_result"
883
+ }
884
+ ],
885
+ "source": [
886
+ "get_top_labels(\n",
887
+ " '''Learn from IIT Faculty & Industry Experts with Guaranteed Job Interviews.\n",
888
+ "Campus Immersion at IIT Roorkee.\n",
889
+ "Master machine learning and artificial intelligence skills with this advanced data science and artificial intelligence course from iHub IIT Roorkee. Learn from IIT faculty and industry experts with 1:1 mentorship in this intensive online bootcamp. Top 2 performers from each batch may get a fellowship worth Rs. 80,000, plus the opportunity to showcase their startup ideas and secure incubation support of upto Rs. 50 Lakhs for their startup from iHUB DivyaSampark, IIT Roorkee.'''\n",
890
+ ")"
891
+ ]
892
+ },
893
  {
894
  "cell_type": "code",
895
  "execution_count": null,
research/10_demo_test_data.ipynb CHANGED
@@ -768,7 +768,7 @@
 },
 {
 "cell_type": "code",
- "execution_count": 1,
+ "execution_count": 2,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -785,7 +785,13 @@
 "output_type": "stream",
 "text": [
 "/home/ubuntu/SentenceStructureComparision/venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
- " from .autonotebook import tqdm as notebook_tqdm\n",
+ " from .autonotebook import tqdm as notebook_tqdm\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
 "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
 ]
 }
@@ -802,7 +808,10 @@
 {
 "data": {
 "text/plain": [
- "[('Food_and_Drink', 0.99), ('Computers_and_Electronics', 0.973)]"
+ "[('Food_and_Drink', 0.989),\n",
+ " ('Computers_and_Electronics', 0.973),\n",
+ " ('Games', 0.172),\n",
+ " ('Shopping', 0.134)]"
 ]
 },
 "execution_count": 3,
@@ -818,38 +827,38 @@
 },
 {
 "cell_type": "code",
- "execution_count": 6,
+ "execution_count": 4,
 "metadata": {},
 "outputs": [
 {
 "data": {
 "text/plain": [
- "[('Pets_and_Animals', 0.583)]"
+ "[('Computers_and_Electronics', 0.999), ('Shopping', 0.993)]"
 ]
 },
- "execution_count": 6,
+ "execution_count": 4,
 "metadata": {},
 "output_type": "execute_result"
 }
 ],
 "source": [
 "get_top_labels(\n",
- " 'turtle beach shaped headset guide'\n",
+ " 'amazon mindkoo headsets with discount'\n",
 ")"
 ]
 },
 {
 "cell_type": "code",
- "execution_count": 17,
+ "execution_count": 5,
 "metadata": {},
 "outputs": [
 {
 "data": {
 "text/plain": [
- "[('Home_and_Garden', 1.0)]"
+ "[('Home_and_Garden', 0.999), ('Computers_and_Electronics', 0.243)]"
 ]
 },
- "execution_count": 17,
+ "execution_count": 5,
 "metadata": {},
 "output_type": "execute_result"
 }
research/11_evaluation.html ADDED
The diff for this file is too large to render.
 
research/11_evaluation.ipynb ADDED
@@ -0,0 +1,290 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import os; os.chdir('..')"
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": 2,
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "from utils.get_intent import get_top_intent"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": 3,
24
+ "metadata": {},
25
+ "outputs": [
26
+ {
27
+ "data": {
28
+ "text/plain": [
29
+ "[('Commercial', 0.969),\n",
30
+ " ('Transactional', 0.673),\n",
31
+ " ('Informational', 0.237),\n",
32
+ " ('Navigational', 0.215),\n",
33
+ " ('Local', 0.155)]"
34
+ ]
35
+ },
36
+ "execution_count": 3,
37
+ "metadata": {},
38
+ "output_type": "execute_result"
39
+ }
40
+ ],
41
+ "source": [
42
+ "get_top_intent(\"best cat ear headphones\")"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "execution_count": 4,
48
+ "metadata": {},
49
+ "outputs": [
50
+ {
51
+ "data": {
52
+ "text/plain": [
53
+ "[('Transactional', 0.987),\n",
54
+ " ('Navigational', 0.317),\n",
55
+ " ('Commercial', 0.27),\n",
56
+ " ('Informational', 0.249),\n",
57
+ " ('Local', 0.229)]"
58
+ ]
59
+ },
60
+ "execution_count": 4,
61
+ "metadata": {},
62
+ "output_type": "execute_result"
63
+ }
64
+ ],
65
+ "source": [
66
+ "get_top_intent(\"buy cat ear headphones\")"
67
+ ]
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "execution_count": 5,
72
+ "metadata": {},
73
+ "outputs": [
74
+ {
75
+ "data": {
76
+ "text/plain": [
77
+ "[('Informational', 0.984),\n",
78
+ " ('Local', 0.244),\n",
79
+ " ('Commercial', 0.237),\n",
80
+ " ('Transactional', 0.212),\n",
81
+ " ('Navigational', 0.194)]"
82
+ ]
83
+ },
84
+ "execution_count": 5,
85
+ "metadata": {},
86
+ "output_type": "execute_result"
87
+ }
88
+ ],
89
+ "source": [
90
+ "get_top_intent(\"how to create a facebook account\")"
91
+ ]
92
+ },
93
+ {
94
+ "cell_type": "code",
95
+ "execution_count": 6,
96
+ "metadata": {},
97
+ "outputs": [
98
+ {
99
+ "data": {
100
+ "text/plain": [
101
+ "[('Local', 0.988),\n",
102
+ " ('Informational', 0.3),\n",
103
+ " ('Commercial', 0.278),\n",
104
+ " ('Navigational', 0.273),\n",
105
+ " ('Transactional', 0.234)]"
106
+ ]
107
+ },
108
+ "execution_count": 6,
109
+ "metadata": {},
110
+ "output_type": "execute_result"
111
+ }
112
+ ],
113
+ "source": [
114
+ "get_top_intent(\"barber shops in USA\")"
115
+ ]
116
+ },
117
+ {
118
+ "cell_type": "code",
119
+ "execution_count": 7,
120
+ "metadata": {},
121
+ "outputs": [
122
+ {
123
+ "data": {
124
+ "text/plain": [
125
+ "[('Informational', 0.763),\n",
126
+ " ('Navigational', 0.638),\n",
127
+ " ('Transactional', 0.433),\n",
128
+ " ('Commercial', 0.286),\n",
129
+ " ('Local', 0.236)]"
130
+ ]
131
+ },
132
+ "execution_count": 7,
133
+ "metadata": {},
134
+ "output_type": "execute_result"
135
+ }
136
+ ],
137
+ "source": [
138
+ "get_top_intent(\"Razer Kraken Headsets\")"
139
+ ]
140
+ },
141
+ {
142
+ "cell_type": "code",
143
+ "execution_count": 8,
144
+ "metadata": {},
145
+ "outputs": [
146
+ {
147
+ "data": {
148
+ "text/plain": [
149
+ "[('Navigational', 0.861),\n",
150
+ " ('Transactional', 0.725),\n",
151
+ " ('Local', 0.422),\n",
152
+ " ('Commercial', 0.287),\n",
153
+ " ('Informational', 0.202)]"
154
+ ]
155
+ },
156
+ "execution_count": 8,
157
+ "metadata": {},
158
+ "output_type": "execute_result"
159
+ }
160
+ ],
161
+ "source": [
162
+ "get_top_intent(\"Amazon Great indian festival\")"
163
+ ]
164
+ },
165
+ {
166
+ "cell_type": "code",
167
+ "execution_count": 9,
168
+ "metadata": {},
169
+ "outputs": [
170
+ {
171
+ "data": {
172
+ "text/plain": [
173
+ "[('Navigational', 0.983),\n",
174
+ " ('Transactional', 0.27),\n",
175
+ " ('Local', 0.23),\n",
176
+ " ('Informational', 0.209),\n",
177
+ " ('Commercial', 0.192)]"
178
+ ]
179
+ },
180
+ "execution_count": 9,
181
+ "metadata": {},
182
+ "output_type": "execute_result"
183
+ }
184
+ ],
185
+ "source": [
186
+ "get_top_intent(\"facebook\")"
187
+ ]
188
+ },
189
+ {
190
+ "cell_type": "code",
191
+ "execution_count": 10,
192
+ "metadata": {},
193
+ "outputs": [
194
+ {
195
+ "data": {
196
+ "text/plain": [
197
+ "[('Navigational', 0.983),\n",
198
+ " ('Transactional', 0.256),\n",
199
+ " ('Informational', 0.241),\n",
200
+ " ('Local', 0.214),\n",
201
+ " ('Commercial', 0.184)]"
202
+ ]
203
+ },
204
+ "execution_count": 10,
205
+ "metadata": {},
206
+ "output_type": "execute_result"
207
+ }
208
+ ],
209
+ "source": [
210
+ "get_top_intent(\"spotify\")"
211
+ ]
212
+ },
213
+ {
214
+ "cell_type": "code",
215
+ "execution_count": 11,
216
+ "metadata": {},
217
+ "outputs": [
218
+ {
219
+ "data": {
220
+ "text/plain": [
221
+ "[('Local', 0.988),\n",
222
+ " ('Informational', 0.294),\n",
223
+ " ('Navigational', 0.284),\n",
224
+ " ('Commercial', 0.252),\n",
225
+ " ('Transactional', 0.235)]"
226
+ ]
227
+ },
228
+ "execution_count": 11,
229
+ "metadata": {},
230
+ "output_type": "execute_result"
231
+ }
232
+ ],
233
+ "source": [
234
+ "get_top_intent(\"parlours in dubai\")"
235
+ ]
236
+ },
237
+ {
238
+ "cell_type": "code",
239
+ "execution_count": 12,
240
+ "metadata": {},
241
+ "outputs": [
242
+ {
243
+ "data": {
244
+ "text/plain": [
245
+ "[('Informational', 0.984),\n",
246
+ " ('Local', 0.245),\n",
247
+ " ('Commercial', 0.242),\n",
248
+ " ('Transactional', 0.226),\n",
249
+ " ('Navigational', 0.189)]"
250
+ ]
251
+ },
252
+ "execution_count": 12,
253
+ "metadata": {},
254
+ "output_type": "execute_result"
255
+ }
256
+ ],
257
+ "source": [
258
+ "get_top_intent(\"how to wear headphones\")"
259
+ ]
260
+ },
261
+ {
262
+ "cell_type": "code",
263
+ "execution_count": null,
264
+ "metadata": {},
265
+ "outputs": [],
266
+ "source": []
267
+ }
268
+ ],
269
+ "metadata": {
270
+ "kernelspec": {
271
+ "display_name": "venv",
272
+ "language": "python",
273
+ "name": "python3"
274
+ },
275
+ "language_info": {
276
+ "codemirror_mode": {
277
+ "name": "ipython",
278
+ "version": 3
279
+ },
280
+ "file_extension": ".py",
281
+ "mimetype": "text/x-python",
282
+ "name": "python",
283
+ "nbconvert_exporter": "python",
284
+ "pygments_lexer": "ipython3",
285
+ "version": "3.10.12"
286
+ }
287
+ },
288
+ "nbformat": 4,
289
+ "nbformat_minor": 2
290
+ }
research/11_intent_classification_using_distilbert.ipynb ADDED
@@ -0,0 +1,898 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import os; os.chdir('..')"
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": 2,
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "import pandas as pd"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": 3,
24
+ "metadata": {},
25
+ "outputs": [
26
+ {
27
+ "data": {
28
+ "text/html": [
29
+ "<div>\n",
30
+ "<style scoped>\n",
31
+ " .dataframe tbody tr th:only-of-type {\n",
32
+ " vertical-align: middle;\n",
33
+ " }\n",
34
+ "\n",
35
+ " .dataframe tbody tr th {\n",
36
+ " vertical-align: top;\n",
37
+ " }\n",
38
+ "\n",
39
+ " .dataframe thead th {\n",
40
+ " text-align: right;\n",
41
+ " }\n",
42
+ "</style>\n",
43
+ "<table border=\"1\" class=\"dataframe\">\n",
44
+ " <thead>\n",
45
+ " <tr style=\"text-align: right;\">\n",
46
+ " <th></th>\n",
47
+ " <th>keyword</th>\n",
48
+ " <th>intent</th>\n",
49
+ " </tr>\n",
50
+ " </thead>\n",
51
+ " <tbody>\n",
52
+ " <tr>\n",
53
+ " <th>0</th>\n",
54
+ " <td>citalopram vs prozac</td>\n",
55
+ " <td>Commercial</td>\n",
56
+ " </tr>\n",
57
+ " <tr>\n",
58
+ " <th>1</th>\n",
59
+ " <td>who is the oldest football player</td>\n",
60
+ " <td>Informational</td>\n",
61
+ " </tr>\n",
62
+ " <tr>\n",
63
+ " <th>2</th>\n",
64
+ " <td>t mobile town east</td>\n",
65
+ " <td>Navigational</td>\n",
66
+ " </tr>\n",
67
+ " <tr>\n",
68
+ " <th>3</th>\n",
69
+ " <td>starbucks</td>\n",
70
+ " <td>Navigational</td>\n",
71
+ " </tr>\n",
72
+ " <tr>\n",
73
+ " <th>4</th>\n",
74
+ " <td>tech crunch</td>\n",
75
+ " <td>Navigational</td>\n",
76
+ " </tr>\n",
77
+ " </tbody>\n",
78
+ "</table>\n",
79
+ "</div>"
80
+ ],
81
+ "text/plain": [
82
+ " keyword intent\n",
83
+ "0 citalopram vs prozac Commercial\n",
84
+ "1 who is the oldest football player Informational\n",
85
+ "2 t mobile town east Navigational\n",
86
+ "3 starbucks Navigational\n",
87
+ "4 tech crunch Navigational"
88
+ ]
89
+ },
90
+ "execution_count": 3,
91
+ "metadata": {},
92
+ "output_type": "execute_result"
93
+ }
94
+ ],
95
+ "source": [
96
+ "original_df= pd.read_csv(\"data_intent/intent_data.csv\")\n",
97
+ "original_df.head()"
98
+ ]
99
+ },
100
+ {
101
+ "cell_type": "code",
102
+ "execution_count": 4,
103
+ "metadata": {},
104
+ "outputs": [],
105
+ "source": [
106
+ "intents= original_df.intent.unique().tolist()"
107
+ ]
108
+ },
109
+ {
110
+ "cell_type": "code",
111
+ "execution_count": 5,
112
+ "metadata": {},
113
+ "outputs": [],
114
+ "source": [
115
+ "id2label= {}\n",
116
+ "label2id= {}\n",
117
+ "for i in range(len(intents)):\n",
118
+ " id2label[i]= intents[i]\n",
119
+ " label2id[intents[i]]= i"
120
+ ]
121
+ },
122
+ {
123
+ "cell_type": "code",
124
+ "execution_count": 6,
125
+ "metadata": {},
126
+ "outputs": [
127
+ {
128
+ "data": {
129
+ "text/plain": [
130
+ "{0: 'Commercial',\n",
131
+ " 1: 'Informational',\n",
132
+ " 2: 'Navigational',\n",
133
+ " 3: 'Local',\n",
134
+ " 4: 'Transactional'}"
135
+ ]
136
+ },
137
+ "execution_count": 6,
138
+ "metadata": {},
139
+ "output_type": "execute_result"
140
+ }
141
+ ],
142
+ "source": [
143
+ "id2label"
144
+ ]
145
+ },
146
+ {
147
+ "cell_type": "code",
148
+ "execution_count": 7,
149
+ "metadata": {},
150
+ "outputs": [
151
+ {
152
+ "data": {
153
+ "text/plain": [
154
+ "{'Commercial': 0,\n",
155
+ " 'Informational': 1,\n",
156
+ " 'Navigational': 2,\n",
157
+ " 'Local': 3,\n",
158
+ " 'Transactional': 4}"
159
+ ]
160
+ },
161
+ "execution_count": 7,
162
+ "metadata": {},
163
+ "output_type": "execute_result"
164
+ }
165
+ ],
166
+ "source": [
167
+ "label2id"
168
+ ]
169
+ },
170
+ {
171
+ "cell_type": "code",
172
+ "execution_count": 8,
173
+ "metadata": {},
174
+ "outputs": [],
175
+ "source": [
176
+ "def make_label2id(label):\n",
177
+ " return label2id[label]"
178
+ ]
179
+ },
180
+ {
181
+ "cell_type": "code",
182
+ "execution_count": 9,
183
+ "metadata": {},
184
+ "outputs": [
185
+ {
186
+ "data": {
187
+ "text/html": [
188
+ "<div>\n",
189
+ "<style scoped>\n",
190
+ " .dataframe tbody tr th:only-of-type {\n",
191
+ " vertical-align: middle;\n",
192
+ " }\n",
193
+ "\n",
194
+ " .dataframe tbody tr th {\n",
195
+ " vertical-align: top;\n",
196
+ " }\n",
197
+ "\n",
198
+ " .dataframe thead th {\n",
199
+ " text-align: right;\n",
200
+ " }\n",
201
+ "</style>\n",
202
+ "<table border=\"1\" class=\"dataframe\">\n",
203
+ " <thead>\n",
204
+ " <tr style=\"text-align: right;\">\n",
205
+ " <th></th>\n",
206
+ " <th>keyword</th>\n",
207
+ " <th>intent</th>\n",
208
+ " <th>id</th>\n",
209
+ " </tr>\n",
210
+ " </thead>\n",
211
+ " <tbody>\n",
212
+ " <tr>\n",
213
+ " <th>0</th>\n",
214
+ " <td>citalopram vs prozac</td>\n",
215
+ " <td>Commercial</td>\n",
216
+ " <td>0</td>\n",
217
+ " </tr>\n",
218
+ " <tr>\n",
219
+ " <th>1</th>\n",
220
+ " <td>who is the oldest football player</td>\n",
221
+ " <td>Informational</td>\n",
222
+ " <td>1</td>\n",
223
+ " </tr>\n",
224
+ " <tr>\n",
225
+ " <th>2</th>\n",
226
+ " <td>t mobile town east</td>\n",
227
+ " <td>Navigational</td>\n",
228
+ " <td>2</td>\n",
229
+ " </tr>\n",
230
+ " <tr>\n",
231
+ " <th>3</th>\n",
232
+ " <td>starbucks</td>\n",
233
+ " <td>Navigational</td>\n",
234
+ " <td>2</td>\n",
235
+ " </tr>\n",
236
+ " <tr>\n",
237
+ " <th>4</th>\n",
238
+ " <td>tech crunch</td>\n",
239
+ " <td>Navigational</td>\n",
240
+ " <td>2</td>\n",
241
+ " </tr>\n",
242
+ " <tr>\n",
243
+ " <th>...</th>\n",
244
+ " <td>...</td>\n",
245
+ " <td>...</td>\n",
246
+ " <td>...</td>\n",
247
+ " </tr>\n",
248
+ " <tr>\n",
249
+ " <th>1066</th>\n",
250
+ " <td>How to make a paper flower?</td>\n",
251
+ " <td>Informational</td>\n",
252
+ " <td>1</td>\n",
253
+ " </tr>\n",
254
+ " <tr>\n",
255
+ " <th>1067</th>\n",
256
+ " <td>Why do some animals camouflage?</td>\n",
257
+ " <td>Informational</td>\n",
258
+ " <td>1</td>\n",
259
+ " </tr>\n",
260
+ " <tr>\n",
261
+ " <th>1068</th>\n",
262
+ " <td>What is the history of ancient civilizations?</td>\n",
263
+ " <td>Informational</td>\n",
264
+ " <td>1</td>\n",
265
+ " </tr>\n",
266
+ " <tr>\n",
267
+ " <th>1069</th>\n",
268
+ " <td>How to make a simple machine?</td>\n",
269
+ " <td>Informational</td>\n",
270
+ " <td>1</td>\n",
271
+ " </tr>\n",
272
+ " <tr>\n",
273
+ " <th>1070</th>\n",
274
+ " <td>Why do we see the phases of the moon?</td>\n",
275
+ " <td>Informational</td>\n",
276
+ " <td>1</td>\n",
277
+ " </tr>\n",
278
+ " </tbody>\n",
279
+ "</table>\n",
280
+ "<p>1071 rows × 3 columns</p>\n",
281
+ "</div>"
282
+ ],
283
+ "text/plain": [
284
+ " keyword intent id\n",
285
+ "0 citalopram vs prozac Commercial 0\n",
286
+ "1 who is the oldest football player Informational 1\n",
287
+ "2 t mobile town east Navigational 2\n",
288
+ "3 starbucks Navigational 2\n",
289
+ "4 tech crunch Navigational 2\n",
290
+ "... ... ... ..\n",
291
+ "1066 How to make a paper flower? Informational 1\n",
292
+ "1067 Why do some animals camouflage? Informational 1\n",
293
+ "1068 What is the history of ancient civilizations? Informational 1\n",
294
+ "1069 How to make a simple machine? Informational 1\n",
295
+ "1070 Why do we see the phases of the moon? Informational 1\n",
296
+ "\n",
297
+ "[1071 rows x 3 columns]"
298
+ ]
299
+ },
300
+ "execution_count": 9,
301
+ "metadata": {},
302
+ "output_type": "execute_result"
303
+ }
304
+ ],
305
+ "source": [
306
+ "original_df['id']= original_df.intent.map(make_label2id)\n",
307
+ "original_df"
308
+ ]
309
+ },
310
+ {
311
+ "cell_type": "code",
312
+ "execution_count": 10,
313
+ "metadata": {},
314
+ "outputs": [
315
+ {
316
+ "data": {
317
+ "text/html": [
318
+ "<div>\n",
319
+ "<style scoped>\n",
320
+ " .dataframe tbody tr th:only-of-type {\n",
321
+ " vertical-align: middle;\n",
322
+ " }\n",
323
+ "\n",
324
+ " .dataframe tbody tr th {\n",
325
+ " vertical-align: top;\n",
326
+ " }\n",
327
+ "\n",
328
+ " .dataframe thead th {\n",
329
+ " text-align: right;\n",
330
+ " }\n",
331
+ "</style>\n",
332
+ "<table border=\"1\" class=\"dataframe\">\n",
333
+ " <thead>\n",
334
+ " <tr style=\"text-align: right;\">\n",
335
+ " <th></th>\n",
336
+ " <th>keyword</th>\n",
337
+ " <th>id</th>\n",
338
+ " </tr>\n",
339
+ " </thead>\n",
340
+ " <tbody>\n",
341
+ " <tr>\n",
342
+ " <th>0</th>\n",
343
+ " <td>citalopram vs prozac</td>\n",
344
+ " <td>0</td>\n",
345
+ " </tr>\n",
346
+ " <tr>\n",
347
+ " <th>1</th>\n",
348
+ " <td>who is the oldest football player</td>\n",
349
+ " <td>1</td>\n",
350
+ " </tr>\n",
351
+ " <tr>\n",
352
+ " <th>2</th>\n",
353
+ " <td>t mobile town east</td>\n",
354
+ " <td>2</td>\n",
355
+ " </tr>\n",
356
+ " <tr>\n",
357
+ " <th>3</th>\n",
358
+ " <td>starbucks</td>\n",
359
+ " <td>2</td>\n",
360
+ " </tr>\n",
361
+ " <tr>\n",
362
+ " <th>4</th>\n",
363
+ " <td>tech crunch</td>\n",
364
+ " <td>2</td>\n",
365
+ " </tr>\n",
366
+ " <tr>\n",
367
+ " <th>...</th>\n",
368
+ " <td>...</td>\n",
369
+ " <td>...</td>\n",
370
+ " </tr>\n",
371
+ " <tr>\n",
372
+ " <th>1066</th>\n",
373
+ " <td>How to make a paper flower?</td>\n",
374
+ " <td>1</td>\n",
375
+ " </tr>\n",
376
+ " <tr>\n",
377
+ " <th>1067</th>\n",
378
+ " <td>Why do some animals camouflage?</td>\n",
379
+ " <td>1</td>\n",
380
+ " </tr>\n",
381
+ " <tr>\n",
382
+ " <th>1068</th>\n",
383
+ " <td>What is the history of ancient civilizations?</td>\n",
384
+ " <td>1</td>\n",
385
+ " </tr>\n",
386
+ " <tr>\n",
387
+ " <th>1069</th>\n",
388
+ " <td>How to make a simple machine?</td>\n",
389
+ " <td>1</td>\n",
390
+ " </tr>\n",
391
+ " <tr>\n",
392
+ " <th>1070</th>\n",
393
+ " <td>Why do we see the phases of the moon?</td>\n",
394
+ " <td>1</td>\n",
395
+ " </tr>\n",
396
+ " </tbody>\n",
397
+ "</table>\n",
398
+ "<p>1071 rows × 2 columns</p>\n",
399
+ "</div>"
400
+ ],
401
+ "text/plain": [
402
+ " keyword id\n",
403
+ "0 citalopram vs prozac 0\n",
404
+ "1 who is the oldest football player 1\n",
405
+ "2 t mobile town east 2\n",
406
+ "3 starbucks 2\n",
407
+ "4 tech crunch 2\n",
408
+ "... ... ..\n",
409
+ "1066 How to make a paper flower? 1\n",
410
+ "1067 Why do some animals camouflage? 1\n",
411
+ "1068 What is the history of ancient civilizations? 1\n",
412
+ "1069 How to make a simple machine? 1\n",
413
+ "1070 Why do we see the phases of the moon? 1\n",
414
+ "\n",
415
+ "[1071 rows x 2 columns]"
416
+ ]
417
+ },
418
+ "execution_count": 10,
419
+ "metadata": {},
420
+ "output_type": "execute_result"
421
+ }
422
+ ],
423
+ "source": [
424
+ "df= original_df[['keyword', 'id']]\n",
425
+ "df"
426
+ ]
427
+ },
428
+ {
429
+ "cell_type": "code",
430
+ "execution_count": 11,
431
+ "metadata": {},
432
+ "outputs": [
433
+ {
434
+ "name": "stderr",
435
+ "output_type": "stream",
436
+ "text": [
437
+ "/home/ubuntu/SentenceStructureComparision/venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
438
+ " from .autonotebook import tqdm as notebook_tqdm\n"
439
+ ]
440
+ }
441
+ ],
442
+ "source": [
443
+ "from datasets import Dataset, load_dataset\n"
444
+ ]
445
+ },
446
+ {
447
+ "cell_type": "code",
448
+ "execution_count": 12,
449
+ "metadata": {},
450
+ "outputs": [
451
+ {
452
+ "name": "stderr",
453
+ "output_type": "stream",
454
+ "text": [
455
+ "/tmp/ipykernel_138160/1635098052.py:1: SettingWithCopyWarning: \n",
456
+ "A value is trying to be set on a copy of a slice from a DataFrame\n",
457
+ "\n",
458
+ "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
459
+ " df.rename(columns={\n"
460
+ ]
461
+ },
462
+ {
463
+ "data": {
464
+ "text/html": [
465
+ "<div>\n",
466
+ "<style scoped>\n",
467
+ " .dataframe tbody tr th:only-of-type {\n",
468
+ " vertical-align: middle;\n",
469
+ " }\n",
470
+ "\n",
471
+ " .dataframe tbody tr th {\n",
472
+ " vertical-align: top;\n",
473
+ " }\n",
474
+ "\n",
475
+ " .dataframe thead th {\n",
476
+ " text-align: right;\n",
477
+ " }\n",
478
+ "</style>\n",
479
+ "<table border=\"1\" class=\"dataframe\">\n",
480
+ " <thead>\n",
481
+ " <tr style=\"text-align: right;\">\n",
482
+ " <th></th>\n",
483
+ " <th>text</th>\n",
484
+ " <th>label</th>\n",
485
+ " </tr>\n",
486
+ " </thead>\n",
487
+ " <tbody>\n",
488
+ " <tr>\n",
489
+ " <th>706</th>\n",
490
+ " <td>Purchase DJ equipment</td>\n",
491
+ " <td>4</td>\n",
492
+ " </tr>\n",
493
+ " <tr>\n",
494
+ " <th>24</th>\n",
495
+ " <td>best headphones quora</td>\n",
496
+ " <td>2</td>\n",
497
+ " </tr>\n",
498
+ " <tr>\n",
499
+ " <th>727</th>\n",
500
+ " <td>Purchase fitness tracker</td>\n",
501
+ " <td>4</td>\n",
502
+ " </tr>\n",
503
+ " <tr>\n",
504
+ " <th>17</th>\n",
505
+ " <td>facebook</td>\n",
506
+ " <td>2</td>\n",
507
+ " </tr>\n",
508
+ " <tr>\n",
509
+ " <th>808</th>\n",
510
+ " <td>Outdoor activities in Lake Tahoe</td>\n",
511
+ " <td>3</td>\n",
512
+ " </tr>\n",
513
+ " <tr>\n",
514
+ " <th>946</th>\n",
515
+ " <td>Wine bars in Napa Valley</td>\n",
516
+ " <td>3</td>\n",
517
+ " </tr>\n",
518
+ " <tr>\n",
519
+ " <th>944</th>\n",
520
+ " <td>Art installations in Chicago</td>\n",
521
+ " <td>3</td>\n",
522
+ " </tr>\n",
523
+ " <tr>\n",
524
+ " <th>899</th>\n",
525
+ " <td>Snowboarding parks in Utah</td>\n",
526
+ " <td>3</td>\n",
527
+ " </tr>\n",
528
+ " <tr>\n",
529
+ " <th>36</th>\n",
530
+ " <td>Mission Immpossible</td>\n",
531
+ " <td>1</td>\n",
532
+ " </tr>\n",
533
+ " <tr>\n",
534
+ " <th>129</th>\n",
535
+ " <td>Instagram</td>\n",
536
+ " <td>2</td>\n",
537
+ " </tr>\n",
538
+ " </tbody>\n",
539
+ "</table>\n",
540
+ "</div>"
541
+ ],
542
+ "text/plain": [
543
+ " text label\n",
544
+ "706 Purchase DJ equipment 4\n",
545
+ "24 best headphones quora 2\n",
546
+ "727 Purchase fitness tracker 4\n",
547
+ "17 facebook 2\n",
548
+ "808 Outdoor activities in Lake Tahoe 3\n",
549
+ "946 Wine bars in Napa Valley 3\n",
550
+ "944 Art installations in Chicago 3\n",
551
+ "899 Snowboarding parks in Utah 3\n",
552
+ "36 Mission Immpossible 1\n",
553
+ "129 Instagram 2"
554
+ ]
555
+ },
556
+ "execution_count": 12,
557
+ "metadata": {},
558
+ "output_type": "execute_result"
559
+ }
560
+ ],
561
+ "source": [
562
+ "df.rename(columns={\n",
563
+ " \"keyword\": \"text\", \n",
564
+ " \"id\": \"label\"\n",
565
+ "}, \n",
566
+ " inplace=True\n",
567
+ ")\n",
568
+ "\n",
569
+ "df.sample(10)"
570
+ ]
571
+ },
572
+ {
573
+ "cell_type": "code",
574
+ "execution_count": 13,
575
+ "metadata": {},
576
+ "outputs": [
577
+ {
578
+ "name": "stderr",
579
+ "output_type": "stream",
580
+ "text": [
581
+ "/home/ubuntu/SentenceStructureComparision/venv/lib/python3.10/site-packages/pyarrow/pandas_compat.py:373: FutureWarning: is_sparse is deprecated and will be removed in a future version. Check `isinstance(dtype, pd.SparseDtype)` instead.\n",
582
+ " if _pandas_api.is_sparse(col):\n"
583
+ ]
584
+ },
585
+ {
586
+ "data": {
587
+ "text/plain": [
588
+ "Dataset({\n",
589
+ " features: ['text', 'label'],\n",
590
+ " num_rows: 1071\n",
591
+ "})"
592
+ ]
593
+ },
594
+ "execution_count": 13,
595
+ "metadata": {},
596
+ "output_type": "execute_result"
597
+ }
598
+ ],
599
+ "source": [
600
+ "dataset_df= Dataset.from_pandas(df)\n",
601
+ "dataset_df"
602
+ ]
603
+ },
604
+ {
605
+ "cell_type": "code",
606
+ "execution_count": 14,
607
+ "metadata": {},
608
+ "outputs": [
609
+ {
610
+ "data": {
611
+ "text/plain": [
612
+ "DatasetDict({\n",
613
+ " train: Dataset({\n",
614
+ " features: ['text', 'label'],\n",
615
+ " num_rows: 856\n",
616
+ " })\n",
617
+ " test: Dataset({\n",
618
+ " features: ['text', 'label'],\n",
619
+ " num_rows: 215\n",
620
+ " })\n",
621
+ "})"
622
+ ]
623
+ },
624
+ "execution_count": 14,
625
+ "metadata": {},
626
+ "output_type": "execute_result"
627
+ }
628
+ ],
629
+ "source": [
630
+ "new_data= dataset_df.train_test_split(test_size=0.2)\n",
631
+ "new_data"
632
+ ]
633
+ },
634
+ {
635
+ "cell_type": "code",
636
+ "execution_count": 15,
637
+ "metadata": {},
638
+ "outputs": [],
639
+ "source": [
640
+ "from transformers import AutoTokenizer\n",
641
+ "\n",
642
+ "tokenizer = AutoTokenizer.from_pretrained(\"distilbert-base-uncased\")"
643
+ ]
644
+ },
645
+ {
646
+ "cell_type": "code",
647
+ "execution_count": 16,
648
+ "metadata": {},
649
+ "outputs": [],
650
+ "source": [
651
+ "def preprocess_function(examples):\n",
652
+ " return tokenizer(examples[\"text\"], truncation=True)"
653
+ ]
654
+ },
655
+ {
656
+ "cell_type": "code",
657
+ "execution_count": 17,
658
+ "metadata": {},
659
+ "outputs": [
660
+ {
661
+ "name": "stderr",
662
+ "output_type": "stream",
663
+ "text": [
664
+ "Map: 100%|██████████| 856/856 [00:00<00:00, 18779.12 examples/s]\n",
665
+ "Map: 100%|██████████| 215/215 [00:00<00:00, 27520.84 examples/s]\n"
666
+ ]
667
+ }
668
+ ],
669
+ "source": [
670
+ "tokenized_df = new_data.map(preprocess_function, batched=True)\n"
671
+ ]
672
+ },
673
+ {
674
+ "cell_type": "code",
675
+ "execution_count": 18,
676
+ "metadata": {},
677
+ "outputs": [
678
+ {
679
+ "name": "stderr",
680
+ "output_type": "stream",
681
+ "text": [
682
+ "2023-10-13 09:10:00.122326: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
683
+ "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
684
+ "2023-10-13 09:10:01.611782: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
685
+ ]
686
+ }
687
+ ],
688
+ "source": [
689
+ "# from transformers import DataCollatorWithPadding\n",
690
+ "\n",
691
+ "# data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors=\"tf\")\n",
692
+ "\n",
693
+ "\n",
694
+ "\n",
695
+ "\n",
696
+ "from transformers import DataCollatorWithPadding\n",
697
+ "\n",
698
+ "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)"
699
+ ]
700
+ },
701
+ {
702
+ "cell_type": "code",
703
+ "execution_count": 19,
704
+ "metadata": {},
705
+ "outputs": [],
706
+ "source": [
707
+ "import evaluate\n",
708
+ "\n",
709
+ "accuracy = evaluate.load(\"accuracy\")"
710
+ ]
711
+ },
712
+ {
713
+ "cell_type": "code",
714
+ "execution_count": 20,
715
+ "metadata": {},
716
+ "outputs": [],
717
+ "source": [
718
+ "import numpy as np\n",
719
+ "\n",
720
+ "\n",
721
+ "def compute_metrics(eval_pred):\n",
722
+ " predictions, labels = eval_pred\n",
723
+ " predictions = np.argmax(predictions, axis=1)\n",
724
+ " return accuracy.compute(predictions=predictions, references=labels)"
725
+ ]
726
+ },
727
+ {
728
+ "cell_type": "code",
729
+ "execution_count": 21,
730
+ "metadata": {},
731
+ "outputs": [
732
+ {
733
+ "name": "stderr",
734
+ "output_type": "stream",
735
+ "text": [
736
+ "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'pre_classifier.bias', 'classifier.bias', 'pre_classifier.weight']\n",
737
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
738
+ ]
739
+ }
740
+ ],
741
+ "source": [
742
+ "from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer\n",
743
+ "\n",
744
+ "model = AutoModelForSequenceClassification.from_pretrained(\n",
745
+ " \"distilbert-base-uncased\", num_labels=5, id2label=id2label, label2id=label2id\n",
746
+ ")"
747
+ ]
748
+ },
749
+ {
750
+ "cell_type": "code",
751
+ "execution_count": 22,
752
+ "metadata": {},
753
+ "outputs": [
754
+ {
755
+ "name": "stderr",
756
+ "output_type": "stream",
757
+ "text": [
758
+ "You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
759
+ ]
760
+ },
761
+ {
762
+ "data": {
763
+ "text/html": [
764
+ "\n",
765
+ " <div>\n",
766
+ " \n",
767
+ " <progress value='324' max='324' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
768
+ " [324/324 00:39, Epoch 6/6]\n",
769
+ " </div>\n",
770
+ " <table border=\"1\" class=\"dataframe\">\n",
771
+ " <thead>\n",
772
+ " <tr style=\"text-align: left;\">\n",
773
+ " <th>Epoch</th>\n",
774
+ " <th>Training Loss</th>\n",
775
+ " <th>Validation Loss</th>\n",
776
+ " <th>Accuracy</th>\n",
777
+ " </tr>\n",
778
+ " </thead>\n",
779
+ " <tbody>\n",
780
+ " <tr>\n",
781
+ " <td>1</td>\n",
782
+ " <td>No log</td>\n",
783
+ " <td>0.467693</td>\n",
784
+ " <td>0.948837</td>\n",
785
+ " </tr>\n",
786
+ " <tr>\n",
787
+ " <td>2</td>\n",
788
+ " <td>No log</td>\n",
789
+ " <td>0.204288</td>\n",
790
+ " <td>0.953488</td>\n",
791
+ " </tr>\n",
792
+ " <tr>\n",
793
+ " <td>3</td>\n",
794
+ " <td>No log</td>\n",
795
+ " <td>0.164018</td>\n",
796
+ " <td>0.967442</td>\n",
797
+ " </tr>\n",
798
+ " <tr>\n",
799
+ " <td>4</td>\n",
800
+ " <td>No log</td>\n",
801
+ " <td>0.164968</td>\n",
802
+ " <td>0.967442</td>\n",
803
+ " </tr>\n",
804
+ " <tr>\n",
805
+ " <td>5</td>\n",
806
+ " <td>No log</td>\n",
807
+ " <td>0.163977</td>\n",
808
+ " <td>0.967442</td>\n",
809
+ " </tr>\n",
810
+ " <tr>\n",
811
+ " <td>6</td>\n",
812
+ " <td>No log</td>\n",
813
+ " <td>0.165533</td>\n",
814
+ " <td>0.967442</td>\n",
815
+ " </tr>\n",
816
+ " </tbody>\n",
817
+ "</table><p>"
818
+ ],
819
+ "text/plain": [
820
+ "<IPython.core.display.HTML object>"
821
+ ]
822
+ },
823
+ "metadata": {},
824
+ "output_type": "display_data"
825
+ },
826
+ {
827
+ "data": {
828
+ "text/plain": [
829
+ "TrainOutput(global_step=324, training_loss=0.2842947171058184, metrics={'train_runtime': 40.8212, 'train_samples_per_second': 125.817, 'train_steps_per_second': 7.937, 'total_flos': 13032177536640.0, 'train_loss': 0.2842947171058184, 'epoch': 6.0})"
830
+ ]
831
+ },
832
+ "execution_count": 22,
833
+ "metadata": {},
834
+ "output_type": "execute_result"
835
+ }
836
+ ],
837
+ "source": [
838
+ "training_args = TrainingArguments(\n",
839
+ " output_dir=\"intent_classification_model\",\n",
840
+ " learning_rate=2e-5,\n",
841
+ " per_device_train_batch_size=16,\n",
842
+ " per_device_eval_batch_size=16,\n",
843
+ " num_train_epochs=6,\n",
844
+ " weight_decay=0.01,\n",
845
+ " evaluation_strategy=\"epoch\",\n",
846
+ " save_strategy=\"epoch\",\n",
847
+ " load_best_model_at_end=True,\n",
848
+ " # push_to_hub=True,\n",
849
+ ")\n",
850
+ "\n",
851
+ "trainer = Trainer(\n",
852
+ " model=model,\n",
853
+ " args=training_args,\n",
854
+ " train_dataset=tokenized_df[\"train\"],\n",
855
+ " eval_dataset=tokenized_df[\"test\"],\n",
856
+ " tokenizer=tokenizer,\n",
857
+ " data_collator=data_collator,\n",
858
+ " compute_metrics=compute_metrics,\n",
859
+ ")\n",
860
+ "\n",
861
+ "trainer.train()"
862
+ ]
863
+ },
864
+ {
865
+ "cell_type": "code",
866
+ "execution_count": null,
867
+ "metadata": {},
868
+ "outputs": [],
869
+ "source": []
870
+ },
871
+ {
872
+ "cell_type": "markdown",
873
+ "metadata": {},
874
+ "source": []
875
+ }
876
+ ],
877
+ "metadata": {
878
+ "kernelspec": {
879
+ "display_name": "venv",
880
+ "language": "python",
881
+ "name": "python3"
882
+ },
883
+ "language_info": {
884
+ "codemirror_mode": {
885
+ "name": "ipython",
886
+ "version": 3
887
+ },
888
+ "file_extension": ".py",
889
+ "mimetype": "text/x-python",
890
+ "name": "python",
891
+ "nbconvert_exporter": "python",
892
+ "pygments_lexer": "ipython3",
893
+ "version": "3.10.12"
894
+ }
895
+ },
896
+ "nbformat": 4,
897
+ "nbformat_minor": 2
898
+ }
utils/__pycache__/get_category.cpython-310.pyc CHANGED
Binary files a/utils/__pycache__/get_category.cpython-310.pyc and b/utils/__pycache__/get_category.cpython-310.pyc differ
 
utils/__pycache__/get_intent.cpython-310.pyc ADDED
Binary file (1.5 kB).
 
utils/__pycache__/get_sentence_status.cpython-310.pyc CHANGED
Binary files a/utils/__pycache__/get_sentence_status.cpython-310.pyc and b/utils/__pycache__/get_sentence_status.cpython-310.pyc differ
 
utils/get_category.py CHANGED
@@ -93,16 +93,20 @@ def get_top_labels(keyword: str):
 
 for i in range(27):
 score= individual_probabilities_scores[i]
- if score>=0.5:
+ if score>=0.1:
 score_list.append(
- (id2label[i], score)
- )
+ (id2label[i], score)
+ )
+ # if score>=0.5:
+ # score_list.append(
+ # (id2label[i], score)
+ # )
 
 
 score_list.sort(
 key= lambda x: x[1], reverse=True
 )
 
- return score_list
+ return score_list[:5]
 
 
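
In effect, get_top_labels() in utils/get_category.py now keeps every category whose sigmoid score is at least 0.1 and returns at most the five highest-scoring labels, where it previously returned only labels scoring 0.5 or higher. A minimal sketch of the new behaviour (hedged; the query string below is a made-up example, and the output format mirrors the updated cells in research/10_demo_test_data.ipynb):

# Sketch only — hypothetical query; output shape mirrors research/10_demo_test_data.ipynb.
from utils.get_category import get_top_labels

print(get_top_labels("slow cooker dinner recipes"))
# Up to five (label, score) pairs with score >= 0.1, e.g.
# [('Food_and_Drink', 0.989), ('Computers_and_Electronics', 0.973), ('Games', 0.172), ('Shopping', 0.134)]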
utils/get_intent.py ADDED
@@ -0,0 +1,69 @@
+ from transformers import AutoTokenizer
+ from transformers import AutoModelForSequenceClassification
+ import torch
+ from torch.nn import functional as F
+ import numpy as np
+ import json
+
+
+
+ label2id= json.load(
+     open('data/categories_refined.json', 'r')
+ )
+ id2label= {}
+ for key in label2id.keys():
+     id2label[label2id[key]] = key
+
+
+
+ model_name= "intent_classification_model/checkpoint-324"
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+ model = AutoModelForSequenceClassification.from_pretrained(model_name).to("cuda")
+
+
+ # probabilities = 1 / (1 + np.exp(-logit_score))
+ def logit2prob(logit):
+     # odds =np.exp(logit)
+     # prob = odds / (1 + odds)
+     prob= 1/(1+ np.exp(-logit))
+     return np.round(prob, 3)
+
+
+
+
+ def get_top_intent(keyword: str):
+     '''
+     Returns score list
+     '''
+     inputs = tokenizer(keyword, return_tensors="pt").to("cuda")
+     with torch.no_grad():
+         logits = model(**inputs).logits
+
+     # print("logits: ", logits)
+     # predicted_class_id = logits.argmax().item()
+
+     # get probabilities using softmax from logit score and convert it to numpy array
+     # probabilities_scores = F.softmax(logits.cpu(), dim = -1).numpy()[0]
+     individual_probabilities_scores = logit2prob(logits.cpu().numpy()[0])
+
+     score_list= []
+
+     for i in range(5):
+         label= model.config.id2label[i]
+
+         score= individual_probabilities_scores[i]
+         score_list.append(
+             (label, score)
+         )
+         # if score>=0.5:
+         # score_list.append(
+         # (id2label[i], score)
+         # )
+
+
+     score_list.sort(
+         key= lambda x: x[1], reverse=True
+     )
+
+     return score_list
utils/get_sentence_status.py CHANGED
@@ -12,6 +12,13 @@ tokenizer_v2 = AutoTokenizer.from_pretrained("gpt2-large")
 model = AutoModelForSequenceClassification.from_pretrained("gpt3_finetuned_model/checkpoint-30048").to("cuda")
 
 
+ # probabilities = 1 / (1 + np.exp(-logit_score))
+ def logit2prob(logit):
+     # odds =np.exp(logit)
+     # prob = odds / (1 + odds)
+     prob= 1/(1+ np.exp(-logit))
+     return np.round(prob, 3)
+
 def split_sentence(sentence:str):
     # Create a regular expression pattern from the list of separators
     sentence= sentence.replace('\n', '')
@@ -98,4 +105,44 @@ def complete_sentence_analysis(sentence:str):
         "label": label,
         "variance": variance,
         "avg_length": avg_length
-     }
+     }
+
+
+
+
+
+ def get_top_labels(keyword: str):
+     '''
+     Returns score list
+     '''
+     inputs = tokenizer(keyword, return_tensors="pt").to("cuda")
+     with torch.no_grad():
+         logits = model(**inputs).logits
+
+     # print("logits: ", logits)
+     # predicted_class_id = logits.argmax().item()
+
+     # get probabilities using softmax from logit score and convert it to numpy array
+     # probabilities_scores = F.softmax(logits.cpu(), dim = -1).numpy()[0]
+     individual_probabilities_scores = logit2prob(logits.cpu().numpy()[0])
+
+     score_list= []
+
+     for i in range(2):
+         label= "Human Written" if model.config.id2label[i]=='NEGATIVE' else 'AI written'
+
+         score= individual_probabilities_scores[i]
+         score_list.append(
+             (label, score)
+         )
+         # if score>=0.5:
+         # score_list.append(
+         # (id2label[i], score)
+         # )
+
+
+     score_list.sort(
+         key= lambda x: x[1], reverse=True
+     )
+
+     return score_list[:5]
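
Usage note for the helper added above (a minimal sketch, assuming the gpt3_finetuned_model/checkpoint-30048 detector checkpoint referenced in this file and a CUDA device; it mirrors research/04_inference.ipynb from this commit):

# Sketch only — mirrors research/04_inference.ipynb.
from utils.get_sentence_status import get_top_labels

human_text = "..."   # e.g. the human-written perplexity explanation used in the notebook
print(get_top_labels(human_text))
# Returns ('Human Written' | 'AI written', probability) pairs sorted by score;
# the notebook records [('Human Written', 0.999), ('AI written', 0.002)] for that passage.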