ddemszky commited on
Commit
ba5a35c
1 Parent(s): 1fbd45a

update handler

Browse files
Files changed (3) hide show
  1. __pycache__/handler.cpython-39.pyc +0 -0
  2. handler.py +14 -19
  3. test.py +11 -6
__pycache__/handler.cpython-39.pyc CHANGED
Binary files a/__pycache__/handler.cpython-39.pyc and b/__pycache__/handler.cpython-39.pyc differ
 
handler.py CHANGED
@@ -35,10 +35,10 @@ class EndpointHandler():
35
  return_pooler_output=False)
36
  return output
37
 
38
- def get_uptake_score(self, utterances, speakerA, speakerB):
39
 
40
- textA = self.get_clean_text(utterances[speakerA], remove_punct=False)
41
- textB = self.get_clean_text(utterances[speakerB], remove_punct=False)
42
 
43
  instance = self.input_builder.build_inputs([textA], textB,
44
  max_length=self.max_length,
@@ -50,34 +50,29 @@ class EndpointHandler():
50
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
51
  """
52
  data args:
53
- inputs (:obj: `list`)
54
  parameters (:obj: `dict`)
55
  Return:
56
  A :obj:`list` | `dict`: will be serialized and returned
57
  """
58
  # get inputs
59
- inputs = data.pop("inputs", data)
60
  params = data.pop("parameters", None)
61
 
62
- utterances = inputs
63
  print("EXAMPLES")
64
- for utt_pair in utterances[:3]:
65
- print("speaker A: %s" % utt_pair[params["speaker_A"]])
66
- print("speaker B: %s" % utt_pair[params["speaker_B"]])
67
- print("----")
68
 
69
  print("Running inference on %d examples..." % len(utterances))
70
  self.model.eval()
71
- uptake_scores = []
 
 
72
  with torch.no_grad():
73
  for i, utt in enumerate(utterances):
74
- prev_num_words = get_num_words(utt[params["speaker_A"]])
75
- if prev_num_words < params["student_min_words"]:
76
- uptake_scores.append(None)
77
- continue
78
- uptake_score = self.get_uptake_score(utterances=utt,
79
- speakerA=params["speaker_A"],
80
- speakerB=params["speaker_B"])
81
- uptake_scores.append(uptake_score)
82
 
83
  return uptake_scores
 
35
  return_pooler_output=False)
36
  return output
37
 
38
+ def get_uptake_score(self, textA, textB):
39
 
40
+ textA = self.get_clean_text(textA, remove_punct=False)
41
+ textB = self.get_clean_text(textB, remove_punct=False)
42
 
43
  instance = self.input_builder.build_inputs([textA], textB,
44
  max_length=self.max_length,
 
50
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
51
  """
52
  data args:
53
+ utterances (:obj: `list`)
54
  parameters (:obj: `dict`)
55
  Return:
56
  A :obj:`list` | `dict`: will be serialized and returned
57
  """
58
  # get inputs
59
+ utterances = data.pop("utterances", data)
60
  params = data.pop("parameters", None)
61
 
 
62
  print("EXAMPLES")
63
+ for utt in utterances[:3]:
64
+ print("speaker %s: %s" % (utt["speaker"], utt["text"]))
 
 
65
 
66
  print("Running inference on %d examples..." % len(utterances))
67
  self.model.eval()
68
+ prev_num_words = 0
69
+ prev_text = ""
70
+ uptake_scores = {}
71
  with torch.no_grad():
72
  for i, utt in enumerate(utterances):
73
+ if utt["speaker"] == params["speaker_2"] and (prev_num_words >= params["speaker_1_min_num_words"]):
74
+ uptake_scores[utt["id"]] = self.get_uptake_score(textA=prev_text, textB=utt["text"])
75
+ prev_num_words = get_num_words(utt["text"])
76
+ prev_text = utt["text"]
 
 
 
 
77
 
78
  return uptake_scores
test.py CHANGED
@@ -4,12 +4,17 @@ from handler import EndpointHandler
4
  my_handler = EndpointHandler(path=".")
5
 
6
  # prepare sample payload
7
- example = {"inputs": [{"speaker_A": "I am quite excited how this will turn out",
8
- "speaker_B": "I'm excited, too"}],
9
- "parameters": {"speaker_A": "speaker_A",
10
- "speaker_B": "speaker_B",
11
- "student_min_words": 5
12
- }}
 
 
 
 
 
13
 
14
  # test the handler
15
  print(my_handler(example))
 
4
  my_handler = EndpointHandler(path=".")
5
 
6
  # prepare sample payload
7
+ example = {
8
+ "utterances": [
9
+ {"id": 1, "speaker": "Alice", "text": "How much is the fish?" },
10
+ {"id": 2, "speaker": "Bob", "text": "I have no idea, ask Alice" }
11
+ ],
12
+ "parameters": {
13
+ "speaker_1_min_num_words": 5,
14
+ "speaker_2": "Bob"
15
+ }
16
+ }
17
+
18
 
19
  # test the handler
20
  print(my_handler(example))