Spaces:
Runtime error
Runtime error
cspocketindia
commited on
Commit
•
46e6114
1
Parent(s):
a3f1f28
monitory initial commit
Browse files- gradio_app.py +54 -2
gradio_app.py
CHANGED
@@ -1,5 +1,9 @@
|
|
|
|
|
|
1 |
import gradio
|
2 |
-
|
|
|
|
|
3 |
from src.models.bert import BERTClassifier
|
4 |
from src.utils.utilities import Utility
|
5 |
|
@@ -7,11 +11,59 @@ model = BERTClassifier(model_name='jeevavijay10/nlp-goemotions-bert')
|
|
7 |
|
8 |
classes = Utility().read_emotion_list()
|
9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
def predict(sentence):
|
|
|
11 |
print(sentence)
|
|
|
12 |
predictions = model.evaluate([sentence])
|
|
|
13 |
print(f"Predictions: {predictions}")
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
gradio.Interface(
|
17 |
fn=predict,
|
|
|
1 |
+
import os
|
2 |
+
import csv
|
3 |
import gradio
|
4 |
+
from gradio import utils
|
5 |
+
import huggingface_hub
|
6 |
+
from pathlib import Path
|
7 |
from src.models.bert import BERTClassifier
|
8 |
from src.utils.utilities import Utility
|
9 |
|
|
|
11 |
|
12 |
classes = Utility().read_emotion_list()
|
13 |
|
14 |
+
hf_token = os.getenv("HF_TOKEN")
|
15 |
+
|
16 |
+
dataset_dir = "logs"
|
17 |
+
|
18 |
+
headers = ["input", "output"]
|
19 |
+
|
20 |
+
path_to_dataset_repo = huggingface_hub.create_repo(
|
21 |
+
name="jeevavijay10/senti-pred-gradio",
|
22 |
+
token=hf_token,
|
23 |
+
private=False,
|
24 |
+
repo_type="dataset",
|
25 |
+
exist_ok=True,
|
26 |
+
)
|
27 |
+
|
28 |
+
repo = huggingface_hub.Repository(
|
29 |
+
local_dir=dataset_dir, clone_from=path_to_dataset_repo, use_auth_token=hf_token
|
30 |
+
)
|
31 |
+
|
32 |
+
def log_record(input, output):
|
33 |
+
repo.git_pull(lfs=True)
|
34 |
+
|
35 |
+
log_file = Path(dataset_dir) / "log.csv"
|
36 |
+
|
37 |
+
is_new = not Path(log_file).exists()
|
38 |
+
|
39 |
+
with open(log_file, "a", newline="", encoding="utf-8") as csvfile:
|
40 |
+
writer = csv.writer(csvfile)
|
41 |
+
|
42 |
+
if is_new:
|
43 |
+
writer.writerow(utils.sanitize_list_for_csv(headers))
|
44 |
+
|
45 |
+
writer.writerow(utils.sanitize_list_for_csv([input, output]))
|
46 |
+
|
47 |
+
with open(log_file, "r", encoding="utf-8") as csvfile:
|
48 |
+
line_count = len([None for _ in csv.reader(csvfile)]) - 1
|
49 |
+
|
50 |
+
repo.push_to_hub(commit_message=f"Flagged sample #{line_count}")
|
51 |
+
|
52 |
+
|
53 |
def predict(sentence):
|
54 |
+
|
55 |
print(sentence)
|
56 |
+
|
57 |
predictions = model.evaluate([sentence])
|
58 |
+
|
59 |
print(f"Predictions: {predictions}")
|
60 |
+
|
61 |
+
output = classes[predictions[0]]
|
62 |
+
|
63 |
+
log_record(sentence, output)
|
64 |
+
|
65 |
+
return output
|
66 |
+
|
67 |
|
68 |
gradio.Interface(
|
69 |
fn=predict,
|