cspocketindia commited on
Commit
46e6114
1 Parent(s): a3f1f28

monitory initial commit

Browse files
Files changed (1) hide show
  1. gradio_app.py +54 -2
gradio_app.py CHANGED
@@ -1,5 +1,9 @@
 
 
1
  import gradio
2
-
 
 
3
  from src.models.bert import BERTClassifier
4
  from src.utils.utilities import Utility
5
 
@@ -7,11 +11,59 @@ model = BERTClassifier(model_name='jeevavijay10/nlp-goemotions-bert')
7
 
8
  classes = Utility().read_emotion_list()
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  def predict(sentence):
 
11
  print(sentence)
 
12
  predictions = model.evaluate([sentence])
 
13
  print(f"Predictions: {predictions}")
14
- return classes[predictions[0]]
 
 
 
 
 
 
15
 
16
  gradio.Interface(
17
  fn=predict,
 
1
+ import os
2
+ import csv
3
  import gradio
4
+ from gradio import utils
5
+ import huggingface_hub
6
+ from pathlib import Path
7
  from src.models.bert import BERTClassifier
8
  from src.utils.utilities import Utility
9
 
 
11
 
12
  classes = Utility().read_emotion_list()
13
 
14
+ hf_token = os.getenv("HF_TOKEN")
15
+
16
+ dataset_dir = "logs"
17
+
18
+ headers = ["input", "output"]
19
+
20
+ path_to_dataset_repo = huggingface_hub.create_repo(
21
+ name="jeevavijay10/senti-pred-gradio",
22
+ token=hf_token,
23
+ private=False,
24
+ repo_type="dataset",
25
+ exist_ok=True,
26
+ )
27
+
28
+ repo = huggingface_hub.Repository(
29
+ local_dir=dataset_dir, clone_from=path_to_dataset_repo, use_auth_token=hf_token
30
+ )
31
+
32
+ def log_record(input, output):
33
+ repo.git_pull(lfs=True)
34
+
35
+ log_file = Path(dataset_dir) / "log.csv"
36
+
37
+ is_new = not Path(log_file).exists()
38
+
39
+ with open(log_file, "a", newline="", encoding="utf-8") as csvfile:
40
+ writer = csv.writer(csvfile)
41
+
42
+ if is_new:
43
+ writer.writerow(utils.sanitize_list_for_csv(headers))
44
+
45
+ writer.writerow(utils.sanitize_list_for_csv([input, output]))
46
+
47
+ with open(log_file, "r", encoding="utf-8") as csvfile:
48
+ line_count = len([None for _ in csv.reader(csvfile)]) - 1
49
+
50
+ repo.push_to_hub(commit_message=f"Flagged sample #{line_count}")
51
+
52
+
53
  def predict(sentence):
54
+
55
  print(sentence)
56
+
57
  predictions = model.evaluate([sentence])
58
+
59
  print(f"Predictions: {predictions}")
60
+
61
+ output = classes[predictions[0]]
62
+
63
+ log_record(sentence, output)
64
+
65
+ return output
66
+
67
 
68
  gradio.Interface(
69
  fn=predict,