Update handler.py
Browse files- handler.py +9 -0
handler.py
CHANGED
|
@@ -113,6 +113,15 @@ class EndpointHandler:
|
|
| 113 |
# Decode the tokens back into samples
|
| 114 |
scenarios = self.tokenizer.decode(tokens)[:, -steps:]
|
| 115 |
print(f"Generated {n_scenarios} scenarios in {time.time() - t_start:.2f} seconds ⏱")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
return {
|
| 117 |
"timestamps": (timestamps[-1] + torch.arange(1, steps+1) * torch.median(torch.diff(timestamps)).item()).tolist(),
|
| 118 |
"open": scenarios[:, :, 0].tolist(),
|
|
|
|
| 113 |
# Decode the tokens back into samples
|
| 114 |
scenarios = self.tokenizer.decode(tokens)[:, -steps:]
|
| 115 |
print(f"Generated {n_scenarios} scenarios in {time.time() - t_start:.2f} seconds ⏱")
|
| 116 |
+
print("Nans:", torch.isnan(scenarios).sum().item())
|
| 117 |
+
print("Infs:", torch.isinf(scenarios).sum().item())
|
| 118 |
+
high_not_highest = (scenarios[..., :4].max(-1).values > scenarios[..., 1])
|
| 119 |
+
low_not_lowest = (scenarios[..., :4].min(-1).values < scenarios[..., 2])
|
| 120 |
+
invalid_candle = high_not_highest | low_not_lowest
|
| 121 |
+
print("Highest not high rate:", high_not_highest.float().mean().item())
|
| 122 |
+
print("Lowest not low rate:", low_not_lowest.float().mean().item())
|
| 123 |
+
print("Invalid candles rate:", invalid_candle.float().mean().item())
|
| 124 |
+
print("Invalid scenario rate:", invalid_candle.any(dim=-1).float().mean().item())
|
| 125 |
return {
|
| 126 |
"timestamps": (timestamps[-1] + torch.arange(1, steps+1) * torch.median(torch.diff(timestamps)).item()).tolist(),
|
| 127 |
"open": scenarios[:, :, 0].tolist(),
|