hylee719 commited on
Commit
f4e0c16
1 Parent(s): cf8f521

only run reasoning for non-teacher utterances

Browse files
Files changed (1) hide show
  1. handler.py +3 -3
handler.py CHANGED
@@ -256,11 +256,11 @@ class ReasoningModel:
256
  self.model = BertForSequenceClassification.from_pretrained(path)
257
  self.model.to(self.device)
258
 
259
- def run_inference(self, transcript, min_num_words=8):
260
  self.model.eval()
261
  with torch.no_grad():
262
  for i, utt in enumerate(transcript.utterances):
263
- if utt.get_num_words() >= min_num_words:
264
  instance = self.input_builder.build_inputs([], utt.text,
265
  max_length=self.max_length,
266
  input_str=True)
@@ -430,7 +430,7 @@ class EndpointHandler():
430
  # Reasoning
431
  reasoning_model = ReasoningModel(
432
  self.device, self.tokenizer, self.input_builder)
433
- reasoning_model.run_inference(transcript)
434
 
435
  # Question
436
  question_model = QuestionModel(
 
256
  self.model = BertForSequenceClassification.from_pretrained(path)
257
  self.model.to(self.device)
258
 
259
+ def run_inference(self, transcript, min_num_words=8, uptake_speaker=None):
260
  self.model.eval()
261
  with torch.no_grad():
262
  for i, utt in enumerate(transcript.utterances):
263
+ if utt.get_num_words() >= min_num_words and utt.speaker != uptake_speaker:
264
  instance = self.input_builder.build_inputs([], utt.text,
265
  max_length=self.max_length,
266
  input_str=True)
 
430
  # Reasoning
431
  reasoning_model = ReasoningModel(
432
  self.device, self.tokenizer, self.input_builder)
433
+ reasoning_model.run_inference(transcript, uptake_speaker=uptake_speaker)
434
 
435
  # Question
436
  question_model = QuestionModel(