msarmi9 commited on
Commit
3815353
1 Parent(s): 8c7a320

[fix]: lowercase input and end with a period.

Browse files
Files changed (1) hide show
  1. app.py +1 -0
app.py CHANGED
@@ -45,6 +45,7 @@ def attention_heatmap(input_tokens: List[str], output_tokens: List[str], weights
45
  @torch.inference_mode()
46
  def run(input: str) -> Tuple[str, plt.Figure]:
47
  """Run inference on a single sentence. Returns prediction and attention heatmap."""""
 
48
  input_tensor = torch.tensor(source_spm.encode(input), dtype=torch.int64)
49
  output, weights = model.decode(input_tensor, max_decode_length=max(len(input_tensor), 80))
50
  output = target_spm.decode(output.detach().tolist())
 
45
  @torch.inference_mode()
46
  def run(input: str) -> Tuple[str, plt.Figure]:
47
  """Run inference on a single sentence. Returns prediction and attention heatmap."""""
48
+ input = input.lower().strip().rstrip(".") + "."
49
  input_tensor = torch.tensor(source_spm.encode(input), dtype=torch.int64)
50
  output, weights = model.decode(input_tensor, max_decode_length=max(len(input_tensor), 80))
51
  output = target_spm.decode(output.detach().tolist())