Jorge Fioranelli committed on
Commit
6f0a968
•
1 Parent(s): 839e6e2

Added zero-shot and distilbert models

.gitignore ADDED
@@ -0,0 +1,162 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
+ .DS_Store

all-models.py ADDED
@@ -0,0 +1,33 @@
+ import gradio as gr
+ from transformers import pipeline
+ import concurrent.futures
+ import ktrain
+
+ zero_shot = pipeline("zero-shot-classification")
+ distilbert = ktrain.load_predictor("models/distilbert-base-uncased-finetuned-internet-provider")
+
+ def zero_shot_predict(text):
+     labels = ["Slow Connectivity", "Billing", "Setup", "No Connectivity"]
+     preds = zero_shot(text, candidate_labels=labels)
+     return {label: float(pred) for label, pred in zip(preds["labels"], preds["scores"])}
+
+ def distilbert_predict(text):
+     labels = distilbert.get_classes()
+     preds = distilbert.predict_proba(text)
+     return {label: float(pred) for label, pred in zip(labels, preds)}
+
+ def predict(text):
+     with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
+         zero_shot_future = executor.submit(zero_shot_predict, text)
+         distilbert_future = executor.submit(distilbert_predict, text)
+         concurrent.futures.wait([zero_shot_future, distilbert_future])
+         zero_shot_preds = zero_shot_future.result()
+         distilbert_preds = distilbert_future.result()
+     return zero_shot_preds, distilbert_preds
+
+ input = gr.inputs.Textbox(label="Customer Sentence")
+ outputs = [gr.outputs.Label(num_top_classes=4, label="Zero-Shot-Classification"), gr.outputs.Label(num_top_classes=4, label="DistilBERT")]
+ title = "Case Classification"
+ description = "Comparison of Zero-Shot-Classification and a fine-tuned DistilBERT."
+ gr.Interface(predict, input, outputs, live=True, title=title, analytics_enabled=False,
+              description=description, capture_session=True).launch()

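Note on all-models.py: it targets the legacy gr.inputs.Textbox / gr.outputs.Label API and the capture_session argument, which later Gradio releases removed, and its zero-shot label "Slow Connectivity" differs slightly from the trained category "Slow Connection" used during fine-tuning. Below is a minimal sketch (not part of the commit) of the same two-output interface on the Gradio 3.x+ component API; it assumes the predict() function defined above is in scope.

# Sketch only: Gradio 3.x+ equivalent of the interface above.
# Assumes predict() from all-models.py is defined/importable in this module.
import gradio as gr

demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(label="Customer Sentence"),
    outputs=[
        gr.Label(num_top_classes=4, label="Zero-Shot-Classification"),
        gr.Label(num_top_classes=4, label="DistilBERT"),
    ],
    title="Case Classification",
    description="Comparison of Zero-Shot-Classification and a fine-tuned DistilBERT.",
    live=True,
    analytics_enabled=False,
)
# capture_session no longer exists in Gradio 3.x+, so it is simply dropped here.
demo.launch()
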
app.py DELETED
@@ -1,29 +0,0 @@
- import gradio as gr
- import numpy as np
- import tensorflow as tf
- import urllib.request
-
- # mlp_model = tf.keras.models.load_model(
- #     "models/sketch_recognition/mlp.h5")
- cnn_model = tf.keras.models.load_model(
-     "models/sketch_recognition/cnn.h5")
-
- labels = urllib.request.urlopen("https://raw.githubusercontent.com/googlecreativelab/quickdraw-dataset/master/categories.txt")
- labels = labels.read()
- labels = labels.decode('utf-8').split("\n")[:-1]
-
-
- def predict(img):
-     img = tf.math.divide(img, 255)
-     preds = cnn_model.predict(img.numpy().reshape(-1, 28, 28, 1))[0]
-     return {label: float(pred) for label, pred in zip(labels, preds)}
-
- output = gr.outputs.Label(num_top_classes=3)
-
- title="Sketch Recognition"
- description="This Convolution Neural Network was trained on Google's " \
-     "QuickDraw dataset with 345 classes. Try it by drawing a " \
-     "lightbulb, radio, or anything you can think of!"
- thumbnail="https://github.com/gradio-app/machine-learning-experiments/raw/master/lightbulb.png?raw=true"
- gr.Interface(predict, "sketchpad", output, live=True, title=title, analytics_enabled=False,
-     description=description, thumbnail=thumbnail, capture_session=True).launch()

data/internet_provider.csv ADDED
The diff for this file is too large to render. See raw diff
 
distillbert-classification-finetuning.py ADDED
@@ -0,0 +1,54 @@
+ import ktrain
+ from ktrain import text
+ import pandas as pd
+ from sklearn.model_selection import train_test_split
+ import os
+
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+
+ data = pd.read_csv('data/internet_provider.csv')  # Replace 'data.csv' with your actual file name
+ categories = ['Slow Connection', 'Billing', 'Setup', 'No Connectivity']
+
+ train_data, temp_data = train_test_split(data, test_size=0.2, random_state=42, shuffle=True)
+ val_data, test_data = train_test_split(temp_data, test_size=0.5, random_state=42, shuffle=True)
+
+ model_name = "distilbert-base-uncased"
+
+ model = text.Transformer(model_name=model_name, maxlen=512, class_names=categories)
+
+ train_data = model.preprocess_train(train_data["Text"].tolist(), train_data["Category"].tolist())
+ val_data = model.preprocess_train(val_data["Text"].tolist(), val_data["Category"].tolist())
+ test_data = model.preprocess_train(test_data["Text"].tolist(), test_data["Category"].tolist())
+
+ classifier = model.get_classifier()
+
+ learner = ktrain.get_learner(classifier, train_data=train_data, val_data=val_data, batch_size=16)
+
+ learner.lr_find(show_plot=True, max_epochs=20)
+
+ learner.fit_onecycle(0.0001, 1)
+ learner.validate(class_names=categories)
+ learner.view_top_losses(n=5, preproc=model)
+
+ print(train_data.iloc[100])
+
+ predictor = ktrain.get_predictor(learner.model, preproc=model)
+
+ x = "I have issues with my internet connection"
+
+ prediction = predictor.predict(x)
+
+ print(f"prediction: {prediction}")
+ print(predictor.explain(x))
+
+ predictor.save("distilbest-model")
+
+ predictor = ktrain.load_predictor("distilbest-model")
+
+ x = "I have issues with my internet connection"
+
+ prediction = predictor.predict(x)
+
+ print(f"prediction: {prediction}")
+ print(predictor.explain(x))

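Note on distillbert-classification-finetuning.py: the held-out split is preprocessed with preprocess_train but never scored, and the raw DataFrames are overwritten by the preprocessed datasets, so the later print(train_data.iloc[100]) no longer refers to the original CSV rows. A minimal evaluation sketch (not part of the commit), assuming the same "Text"/"Category" columns and the "distilbest-model" directory written by predictor.save() above:

# Sketch only: scoring the saved predictor on the raw held-out split.
# Rebuilds the same split (random_state=42) as the training script above.
import ktrain
import pandas as pd
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split

data = pd.read_csv("data/internet_provider.csv")
_, temp_data = train_test_split(data, test_size=0.2, random_state=42, shuffle=True)
_, test_data = train_test_split(temp_data, test_size=0.5, random_state=42, shuffle=True)

predictor = ktrain.load_predictor("distilbest-model")  # path used by predictor.save() above
preds = predictor.predict(test_data["Text"].tolist())  # list of texts in -> list of class names out

print(accuracy_score(test_data["Category"], preds))
print(classification_report(test_data["Category"], preds))
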
distillbert-classification-run.py ADDED
@@ -0,0 +1,21 @@
+ import ktrain
+ from ktrain import text
+ import pandas as pd
+ from sklearn.model_selection import train_test_split
+ import os
+
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+
+ predictor = ktrain.load_predictor("models/distilbert-base-uncased-finetuned-internet-provider")
+
+ x = "I have issues with my internet connection"
+
+ prediction = predictor.predict(x)
+
+ print(f"prediction: {prediction}")
+
+ labels = predictor.get_classes()
+ probs = predictor.predict_proba(x)
+ for i, label in enumerate(labels):
+     print(label, ":", probs[i])

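Note on distillbert-classification-run.py: ktrain.text, pandas, and train_test_split are imported but unused here; only the saved predictor is needed. For reference, the same class/probability lookup can be packaged as a label-to-score dict, which is the shape all-models.py feeds to the Gradio Label output (sketch, not part of the commit):

# Sketch only: wrap predict_proba output as {label: probability}.
import ktrain

predictor = ktrain.load_predictor("models/distilbert-base-uncased-finetuned-internet-provider")

def distilbert_scores(text):
    labels = predictor.get_classes()
    probs = predictor.predict_proba(text)
    return {label: float(p) for label, p in zip(labels, probs)}

print(distilbert_scores("I have issues with my internet connection"))
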
flagged/DistilBERT/tmpxkobdgs7.json ADDED
@@ -0,0 +1 @@
+ {}

flagged/Zero-Shot-Classification/tmp0kb5h3dj.json ADDED
@@ -0,0 +1 @@
+ {}

flagged/log.csv ADDED
@@ -0,0 +1,2 @@
+ Customer Sentence,Zero-Shot-Classification,DistilBERT,flag,username,timestamp
+ ,/Users/fioranel/Projects/Case-Classification/flagged/Zero-Shot-Classification/tmp0kb5h3dj.json,/Users/fioranel/Projects/Case-Classification/flagged/DistilBERT/tmpxkobdgs7.json,,,2023-07-19 01:56:06.921642

models/distilbert-base-uncased-finetuned-internet-provider/config.json ADDED
@@ -0,0 +1,35 @@
+ {
+   "_name_or_path": "/tmp/tmp1o3vy28s",
+   "activation": "gelu",
+   "architectures": [
+     "DistilBertForSequenceClassification"
+   ],
+   "attention_dropout": 0.1,
+   "dim": 768,
+   "dropout": 0.1,
+   "hidden_dim": 3072,
+   "id2label": {
+     "0": "LABEL_0",
+     "1": "LABEL_1",
+     "2": "LABEL_2",
+     "3": "LABEL_3"
+   },
+   "initializer_range": 0.02,
+   "label2id": {
+     "LABEL_0": 0,
+     "LABEL_1": 1,
+     "LABEL_2": 2,
+     "LABEL_3": 3
+   },
+   "max_position_embeddings": 512,
+   "model_type": "distilbert",
+   "n_heads": 12,
+   "n_layers": 6,
+   "pad_token_id": 0,
+   "qa_dropout": 0.1,
+   "seq_classif_dropout": 0.2,
+   "sinusoidal_pos_embds": false,
+   "tie_weights_": true,
+   "transformers_version": "4.31.0",
+   "vocab_size": 30522
+ }

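Note on config.json: id2label/label2id still carry the generic LABEL_0 through LABEL_3 names; the human-readable categories only come from the ktrain preprocessor via predictor.get_classes(). If the bare transformers checkpoint should report category names too, the mapping could be patched as sketched below (not part of the commit; it assumes class index i follows the class_names order passed to text.Transformer in the training script).

# Sketch only: write the category names into id2label/label2id.
# Assumes label index i corresponds to categories[i] from the training script.
import json

categories = ["Slow Connection", "Billing", "Setup", "No Connectivity"]
path = "models/distilbert-base-uncased-finetuned-internet-provider/config.json"

with open(path) as f:
    config = json.load(f)

config["id2label"] = {str(i): name for i, name in enumerate(categories)}
config["label2id"] = {name: i for i, name in enumerate(categories)}

with open(path, "w") as f:
    json.dump(config, f, indent=2)
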
models/distilbert-base-uncased-finetuned-internet-provider/special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "cls_token": "[CLS]",
+   "mask_token": "[MASK]",
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "unk_token": "[UNK]"
+ }

models/{sketch_recognition/cnn.h5 → distilbert-base-uncased-finetuned-internet-provider/tf_model.h5} RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:25b551a31a9ca980637231fef74428a4f7b2dcea199cbe3226fed89629741075
- size 2029392
+ oid sha256:6375d12296e96db5c916233c62535f7b20fbe8b85bb62260801c14450231ec80
+ size 267961288

models/distilbert-base-uncased-finetuned-internet-provider/tf_model.preproc ADDED
Binary file (2.76 kB).
 
models/distilbert-base-uncased-finetuned-internet-provider/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
models/distilbert-base-uncased-finetuned-internet-provider/tokenizer_config.json ADDED
@@ -0,0 +1,13 @@
+ {
+   "clean_up_tokenization_spaces": true,
+   "cls_token": "[CLS]",
+   "do_lower_case": true,
+   "mask_token": "[MASK]",
+   "model_max_length": 512,
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "strip_accents": null,
+   "tokenize_chinese_chars": true,
+   "tokenizer_class": "DistilBertTokenizer",
+   "unk_token": "[UNK]"
+ }

models/distilbert-base-uncased-finetuned-internet-provider/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt CHANGED
@@ -1 +1,5 @@
- tensorflow
+ tensorflow
+ gradio
+ transformers
+ openai
+ ktrain

zero-shot-classification.py ADDED
@@ -0,0 +1,16 @@
+ import gradio as gr
+ from transformers import pipeline
+
+ classifier = pipeline("zero-shot-classification")
+ labels = ["Connectivity", "Billing", "Setup"]
+
+ def predict(text):
+     preds = classifier(text, candidate_labels=labels)
+     return {label: float(pred) for label, pred in zip(preds["labels"], preds["scores"])}
+
+ input = gr.inputs.Textbox(label="Customer Sentence")
+ output = gr.outputs.Label(num_top_classes=4, label="Zero-Shot-Classification")
+ title = "Case Classification"
+ description = "Zero-Shot-Classification test"
+ gr.Interface(predict, input, output, live=True, title=title, analytics_enabled=False,
+              description=description).launch()
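
Note on zero-shot-classification.py: like all-models.py, it uses the legacy gr.inputs/gr.outputs API, and num_top_classes=4 with only three candidate labels simply shows all three. For reference, a standalone sketch (not part of the commit) of what the zero-shot pipeline itself returns; multi_label=True is assumed to be available (recent transformers releases; older ones used multi_class).

# Sketch only: raw zero-shot pipeline output for one sentence.
from transformers import pipeline

classifier = pipeline("zero-shot-classification")
result = classifier(
    "My invoice is higher than expected and the connection keeps dropping",
    candidate_labels=["Connectivity", "Billing", "Setup"],
    multi_label=True,  # score each label independently
)
# result has "sequence", "labels" (sorted by score, descending), and "scores".
for label, score in zip(result["labels"], result["scores"]):
    print(f"{label}: {score:.3f}")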