[fix]: lowercase input and end with a period.
Browse files
app.py
CHANGED
@@ -45,6 +45,7 @@ def attention_heatmap(input_tokens: List[str], output_tokens: List[str], weights
|
|
45 |
@torch.inference_mode()
|
46 |
def run(input: str) -> Tuple[str, plt.Figure]:
|
47 |
"""Run inference on a single sentence. Returns prediction and attention heatmap."""""
|
|
|
48 |
input_tensor = torch.tensor(source_spm.encode(input), dtype=torch.int64)
|
49 |
output, weights = model.decode(input_tensor, max_decode_length=max(len(input_tensor), 80))
|
50 |
output = target_spm.decode(output.detach().tolist())
|
|
|
45 |
@torch.inference_mode()
|
46 |
def run(input: str) -> Tuple[str, plt.Figure]:
|
47 |
"""Run inference on a single sentence. Returns prediction and attention heatmap."""""
|
48 |
+
input = input.lower().strip().rstrip(".") + "."
|
49 |
input_tensor = torch.tensor(source_spm.encode(input), dtype=torch.int64)
|
50 |
output, weights = model.decode(input_tensor, max_decode_length=max(len(input_tensor), 80))
|
51 |
output = target_spm.decode(output.detach().tolist())
|