Spaces:
Running
Running
import seaborn as sns | |
import matplotlib.pyplot as plt | |
from matplotlib.colors import ListedColormap | |
import torch | |
def plot_matrix(tensor, ax, title, vmin=0, vmax=1, cmap=None):
    """
    Draw a tensor as an annotated seaborn heatmap on the given axes.

    Every cell is labelled with its value formatted to two decimals, the
    colorbar is disabled, and the tick labels on both axes are removed.

    Args:
        tensor: torch tensor to display; it is moved to the CPU and
            converted to a NumPy array before plotting.
        ax: matplotlib axes the heatmap is drawn on.
        title: text used as the axes title.
        vmin: lower bound of the color scale (defaults to 0).
        vmax: upper bound of the color scale (defaults to 1).
        cmap: colormap forwarded unchanged to ``sns.heatmap``.
    """
    # detach() first: calling .numpy() on a tensor that requires grad raises
    # a RuntimeError, and plotting never needs the autograd graph.
    values = tensor.detach().cpu().numpy()
    sns.heatmap(
        values,
        ax=ax,
        vmin=vmin,
        vmax=vmax,
        cmap=cmap,
        annot=True,
        fmt=".2f",
        cbar=False,
    )
    ax.set_title(title)
    # Hide tick labels; the per-cell annotations already carry the values.
    ax.set_yticklabels([])
    ax.set_xticklabels([])
def plot_quantization_errors(original_tensor, quantized_tensor, dequantized_tensor, dtype = torch.int8, n_bits = 8):
    """
    Show a quantization round trip as four side-by-side heatmaps.

    From left to right the panels are: the original tensor, the quantized
    tensor, the de-quantized tensor, and the element-wise absolute
    difference between the original and the de-quantized tensor.

    Args:
        original_tensor: tensor before quantization.
        quantized_tensor: tensor after quantization.
        dequantized_tensor: tensor reconstructed from the quantized one.
        dtype: dtype passed to ``torch.iinfo``; its min/max define the
            color scale of the quantized panel.
        n_bits: bit width, used only in the quantized panel's title.
    """
    # One row, four panels.
    fig, panels = plt.subplots(1, 4, figsize=(15, 4))
    original_ax, quantized_ax, dequantized_ax, error_ax = panels

    # Panel 1: original values on a plain white background.
    plot_matrix(
        original_tensor,
        original_ax,
        'Original Tensor',
        cmap=ListedColormap(['white']),
    )

    # Panel 2: quantized values, colored over the full range of `dtype`.
    dtype_info = torch.iinfo(dtype)
    plot_matrix(
        quantized_tensor,
        quantized_ax,
        f'{n_bits}-bit Linear Quantized Tensor',
        vmin=dtype_info.min,
        vmax=dtype_info.max,
        cmap='coolwarm',
    )

    # Panel 3: values recovered after de-quantization.
    plot_matrix(dequantized_tensor, dequantized_ax, 'Dequantized Tensor', cmap='coolwarm')

    # Panel 4: absolute reconstruction error per element.
    abs_error = (original_tensor - dequantized_tensor).abs()
    plot_matrix(
        abs_error,
        error_ax,
        'Quantization Error Tensor',
        cmap=ListedColormap(['white']),
    )

    fig.tight_layout()
    plt.show()