project_02_DS / task /task_04 /step6_visualize.py
griddev's picture
Deploy Streamlit Space app
0710b5c verified
"""
step6_visualize.py
===================
Task 4 — Component 6: Generate publication-quality figures.
Produces three figures from the Task 4 results:
Figure 1: diversity_histogram.png
Histogram of per-image diversity scores (200 images).
"Diverse" (>0.75) and "Repetitive" (<0.40) zones are shaded.
Figure 2: diverse_vs_repetitive.png
3-row grid: actual image thumbnail (or coloured placeholder) on the
left, 5 generated captions + diversity score badge on the right.
Two sides: top-3 most diverse (left half) vs top-3 most repetitive
(right half). Thumbnails loaded from results/images/img_{id}.jpg.
Figure 3: steering_lambda_sweep.png
Dual-axis line chart: λ on x-axis, mean caption length (left y-axis)
and mean unique word count (right y-axis). λ=0 baseline annotated.
Public API
----------
plot_diversity_histogram(records, save_dir) -> str (path)
plot_diverse_vs_repetitive(records, save_dir) -> str
plot_steering_lambda_sweep(results, save_dir) -> str
visualize_all(records, steering_results, save_dir) -> dict[str, str]
Standalone usage
----------------
export PYTHONPATH=.
venv/bin/python task/task_04/step6_visualize.py
"""
import os
import sys
import json
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
from matplotlib.gridspec import GridSpec
from matplotlib.lines import Line2D
# ─────────────────────────────────────────────────────────────────────────────
# Colour palette
# ─────────────────────────────────────────────────────────────────────────────
# Shared hex colours so the three figures stay visually consistent.
C_DIVERSE = "#4C72B0"  # blue   - "diverse" zone, threshold line, badges and header
C_MEDIUM = "#55A868"  # green  - histogram bars
C_REPETITIVE = "#C44E52"  # red    - "repetitive" zone, threshold line, badges and header
C_LAMBDA = "#DD8452"  # orange - mean caption length curve (left axis of the sweep chart)
C_UNIQ = "#8172B2"  # purple - mean unique word count curve (right axis of the sweep chart)
# ─────────────────────────────────────────────────────────────────────────────
# Figure 1 — Diversity histogram
# ─────────────────────────────────────────────────────────────────────────────
def plot_diversity_histogram(records: list,
                             save_dir: str = "task/task_04/results") -> str:
    """
    Plot a histogram of per-image diversity scores and save it as a PNG.

    Zones:
      "Repetitive" < 0.40 -> shaded red
      "Diverse"    > 0.75 -> shaded blue

    Args:
        records: iterable of dicts, each providing a numeric
            ``"diversity_score"`` entry.
        save_dir: directory the figure is written to (created if missing).

    Returns:
        Path of the saved file, ``<save_dir>/diversity_histogram.png``.

    Raises:
        KeyError: if a record has no ``"diversity_score"`` key.
    """
    os.makedirs(save_dir, exist_ok=True)
    scores = [r["diversity_score"] for r in records]

    fig, ax = plt.subplots(figsize=(9, 5))

    # Shaded zones
    ax.axvspan(0.0, 0.40, alpha=0.12, color=C_REPETITIVE, label="Repetitive zone (<0.40)")
    ax.axvspan(0.75, 1.00, alpha=0.12, color=C_DIVERSE, label="Diverse zone (>0.75)")

    # Histogram
    n_bins = 25
    ax.hist(scores, bins=n_bins, color=C_MEDIUM, edgecolor="white",
            linewidth=0.6, alpha=0.9, label="Image count")

    # Threshold lines
    ax.axvline(0.40, color=C_REPETITIVE, linewidth=1.8, linestyle="--")
    ax.axvline(0.75, color=C_DIVERSE, linewidth=1.8, linestyle="--")

    # Mean line. Skipped for empty input: np.mean([]) would emit a
    # RuntimeWarning and put a meaningless "Mean = nan" entry in the legend.
    if scores:
        mean_score = np.mean(scores)
        ax.axvline(mean_score, color="#333", linewidth=1.4, linestyle=":",
                   label=f"Mean = {mean_score:.3f}")

    ax.set_xlabel("Diversity Score (unique n-grams / total n-grams)", fontsize=12)
    ax.set_ylabel("Number of Images", fontsize=12)
    ax.set_title("Caption Diversity Distribution Across COCO Images\n"
                 "(5 nucleus-sampled captions per image, p=0.9)",
                 fontsize=13, fontweight="bold", pad=10)
    ax.legend(fontsize=10)
    ax.yaxis.set_minor_locator(mticker.AutoMinorLocator())
    ax.grid(axis="y", linestyle="--", alpha=0.4)
    ax.set_xlim(0, 1)

    # Annotations: image counts per zone, using the same strict thresholds
    # as the zone labels. Placed at 85% of the current y-axis height.
    n_rep = sum(1 for s in scores if s < 0.40)
    n_div = sum(1 for s in scores if s > 0.75)
    y_text = ax.get_ylim()[1] * 0.85
    ax.text(0.20, y_text, f"{n_rep} images\n(repetitive)",
            ha="center", va="top", color=C_REPETITIVE, fontsize=9)
    ax.text(0.875, y_text, f"{n_div} images\n(diverse)",
            ha="center", va="top", color=C_DIVERSE, fontsize=9)

    fig.tight_layout()
    path = os.path.join(save_dir, "diversity_histogram.png")
    fig.savefig(path, dpi=150, bbox_inches="tight")
    plt.close(fig)  # release the figure; the Agg backend never shows it
    print(f" OK Saved: {path}")
    return path
# ─────────────────────────────────────────────────────────────────────────────
# Figure 2 — Diverse vs Repetitive (image thumbnail + captions)
# ─────────────────────────────────────────────────────────────────────────────
def _load_thumb_array(img_id: int, save_dir: str):
    """
    Load ``<save_dir>/images/img_<img_id>.jpg`` as an RGB numpy array.

    Args:
        img_id: identifier interpolated into the thumbnail file name.
        save_dir: results directory that contains the ``images`` sub-folder.

    Returns:
        The decoded image as a numpy array, or ``None`` when the file does
        not exist or cannot be decoded (callers then draw a placeholder).
    """
    path = os.path.join(save_dir, "images", f"img_{img_id}.jpg")
    if not os.path.exists(path):
        return None

    # Imported lazily, and only once a thumbnail is known to exist, so the
    # placeholder fallback keeps working on machines without Pillow.
    from PIL import Image as PILImage

    try:
        # The context manager closes the underlying file handle promptly.
        with PILImage.open(path) as img:
            return np.array(img.convert("RGB"))
    except Exception as exc:
        # Best effort: a corrupt thumbnail must not abort figure generation,
        # but report it instead of swallowing the error silently.
        print(f" WARN Could not load thumbnail {path}: {exc}")
        return None
def plot_diverse_vs_repetitive(records: list,
                               save_dir: str = "task/task_04/results") -> str:
    """
    Draw a 3-row image/caption grid contrasting caption-style extremes.

    Each row shows [thumbnail | up to 5 captions] for one image.
    Left half = top-3 most diverse, right half = top-3 most repetitive.
    Thumbnails are read from ``<save_dir>/images/img_{id}.jpg``; when a file
    is missing or unreadable a coloured, score-labelled placeholder is drawn.

    Args:
        records: list of dicts providing ``"diversity_score"``,
            ``"captions"`` (list of str), ``"image_id"`` and optionally
            ``"category"``.
        save_dir: directory the figure is written to (created if missing).

    Returns:
        Path of the saved file, ``<save_dir>/diverse_vs_repetitive.png``.
    """
    os.makedirs(save_dir, exist_ok=True)

    n_rows = 3
    diverse = _pick_extreme_records(records, most_diverse=True, n=n_rows)
    repetitive = _pick_extreme_records(records, most_diverse=False, n=n_rows)

    fig = plt.figure(figsize=(17, 12), facecolor="#F4F4F4")
    # 4 columns: [img_div | cap_div | img_rep | cap_rep]
    gs = GridSpec(n_rows, 4, figure=fig,
                  hspace=0.60, wspace=0.10,
                  left=0.03, right=0.97,
                  top=0.90, bottom=0.03,
                  width_ratios=[1, 2.2, 1, 2.2])

    def _render_row(row, rec, img_col, cap_col, badge_color):
        """Draw the thumbnail cell and the caption cell for one record."""
        arr = _load_thumb_array(rec["image_id"], save_dir)
        score = rec["diversity_score"]
        cat = rec.get("category", "")

        # Image cell
        ax_img = fig.add_subplot(gs[row, img_col])
        if arr is not None:
            ax_img.imshow(arr, aspect="auto", interpolation="bilinear")
        else:
            ax_img.set_facecolor(badge_color)
            ax_img.text(0.5, 0.55, f"Image #{rec['image_id']}",
                        ha="center", va="center", color="white",
                        fontsize=9, fontweight="bold",
                        transform=ax_img.transAxes)
            ax_img.text(0.5, 0.35, f"{score:.3f}",
                        ha="center", va="center", color="white",
                        fontsize=9, transform=ax_img.transAxes)
        # Not axis("off"): that would also suppress the background patch and
        # leave the white placeholder text on the light-grey figure.
        _hide_axes_decorations(ax_img)
        ax_img.set_title(f"Score: {score:.3f} [{cat}]",
                         fontsize=8, color=badge_color,
                         pad=3, fontweight="bold")

        # Caption cell
        ax_cap = fig.add_subplot(gs[row, cap_col])
        ax_cap.set_facecolor("#FAFAFA")
        _hide_axes_decorations(ax_cap)

        y = 0.97
        for ci, cap in enumerate(rec["captions"][:5], 1):
            for li, text_line in enumerate(_wrap_caption(cap)):
                # Number only the first line; continuation lines are
                # indented by the width of the "N. " prefix.
                prefix = f"{ci}. " if li == 0 else "   "
                ax_cap.text(0.03, y, prefix + text_line,
                            transform=ax_cap.transAxes,
                            fontsize=7.8, color="#333", va="top",
                            fontfamily="monospace")
                y -= 0.13
            y -= 0.04  # extra gap between captions

    for row, rec in enumerate(diverse):
        _render_row(row, rec, img_col=0, cap_col=1, badge_color=C_DIVERSE)
    for row, rec in enumerate(repetitive):
        _render_row(row, rec, img_col=2, cap_col=3, badge_color=C_REPETITIVE)

    # Column headers
    fig.text(0.28, 0.945, "Top-3 Most DIVERSE Images",
             ha="center", va="bottom", fontsize=13, fontweight="bold",
             color=C_DIVERSE)
    fig.text(0.75, 0.945, "Top-3 Most REPETITIVE Images",
             ha="center", va="bottom", fontsize=13, fontweight="bold",
             color=C_REPETITIVE)

    # Centre divider
    sep = Line2D([0.505, 0.505], [0.02, 0.94],
                 transform=fig.transFigure,
                 color="#BBBBBB", linewidth=1.5, linestyle="--")
    fig.add_artist(sep)

    # The title previously contained a mis-decoded em dash (mojibake).
    fig.suptitle("Caption Style Extremes \u2014 COCO Validation Set\n"
                 "(5 nucleus-sampled captions per image, top_p=0.9)",
                 fontsize=13, fontweight="bold", y=0.995)

    path = os.path.join(save_dir, "diverse_vs_repetitive.png")
    fig.savefig(path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f" OK Saved: {path}")
    return path


def _pick_extreme_records(records, most_diverse=True, n=3):
    """Return up to ``n`` records from one end of the diversity ranking,
    skipping records whose caption list duplicates one already selected."""
    ranked = sorted(records, key=lambda r: r["diversity_score"],
                    reverse=most_diverse)
    picked = []
    seen = set()
    for rec in ranked:
        key = tuple(rec["captions"])
        if key in seen:
            continue
        seen.add(key)
        picked.append(rec)
        if len(picked) == n:
            break
    return picked


def _wrap_caption(caption, width=52):
    """Word-wrap a caption to at most ``width`` characters per line.

    Only whitespace is used as a break point; an over-long single word is
    kept intact on its own line rather than preceded by an empty line.
    """
    import textwrap  # local import: the module header does not provide it

    return textwrap.wrap(caption, width=width,
                         break_long_words=False, break_on_hyphens=False)


def _hide_axes_decorations(ax):
    """Hide ticks and spines while keeping the Axes background patch drawn."""
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)
# ─────────────────────────────────────────────────────────────────────────────
# Figure 3 — Lambda sweep chart
# ─────────────────────────────────────────────────────────────────────────────
def plot_steering_lambda_sweep(steering_results: list,
                               save_dir: str = "task/task_04/results") -> str:
    """
    Dual-axis line chart: lambda (x) vs mean caption length (left y) and
    mean unique word count (right y). lambda=0 baseline marked.

    Args:
        steering_results: non-empty list of dicts providing ``"lambda"``,
            ``"mean_length"`` and ``"mean_unique_words"``. The entries may
            be in any order; they are plotted sorted by ``"lambda"``.
        save_dir: directory the figure is written to (created if missing).

    Returns:
        Path of the saved file, ``<save_dir>/steering_lambda_sweep.png``.

    Raises:
        ValueError: if ``steering_results`` is empty.
        KeyError: if an entry lacks one of the required keys.
    """
    # Validate before touching the filesystem or creating a figure, so empty
    # input fails with a clear message and no side effects.
    if not steering_results:
        raise ValueError("steering_results is empty; nothing to plot")

    os.makedirs(save_dir, exist_ok=True)

    # Sort by lambda so the curves run left-to-right and the span arrow really
    # connects the smallest and largest lambda. Already-sorted input is
    # unaffected because sorted() is stable.
    ordered = sorted(steering_results, key=lambda r: r["lambda"])
    lambdas = [r["lambda"] for r in ordered]
    lengths = [r["mean_length"] for r in ordered]
    uniq = [r["mean_unique_words"] for r in ordered]

    fig, ax1 = plt.subplots(figsize=(9, 5.5))

    # Left axis: mean length
    ax1.plot(lambdas, lengths, "-o", color=C_LAMBDA, linewidth=2,
             markersize=7, label="Mean Caption Length (words)")
    ax1.set_xlabel("Steering Strength (lambda)", fontsize=12)
    ax1.set_ylabel("Mean Caption Length (words)", color=C_LAMBDA, fontsize=12)
    ax1.tick_params(axis="y", labelcolor=C_LAMBDA)

    # Right axis: unique words
    ax2 = ax1.twinx()
    ax2.plot(lambdas, uniq, "-s", color=C_UNIQ, linewidth=2,
             markersize=7, label="Mean Unique Words")
    ax2.set_ylabel("Mean Unique Word Count", color=C_UNIQ, fontsize=12)
    ax2.tick_params(axis="y", labelcolor=C_UNIQ)

    # Baseline
    ax1.axvline(0.0, color="#999", linewidth=1.4, linestyle="--")
    ax1.text(0.02, max(lengths) * 0.97, "lambda=0\nbaseline",
             color="#777", fontsize=8.5)

    # Double-headed arrow annotation spanning the first and last sweep point
    ax1.annotate("", xy=(lambdas[-1], lengths[-1] + 0.4),
                 xytext=(lambdas[0], lengths[0] + 0.4),
                 arrowprops=dict(arrowstyle="<->", color="#555", lw=1.2))
    mid_x = (lambdas[0] + lambdas[-1]) / 2
    ax1.text(mid_x, lengths[-1] + 0.6, "steering effect on length",
             ha="center", fontsize=8.5, color="#555")

    # Combined legend: merge the handles of both y-axes into one box
    h1, l1 = ax1.get_legend_handles_labels()
    h2, l2 = ax2.get_legend_handles_labels()
    ax1.legend(h1 + h2, l1 + l2, fontsize=9, loc="upper left")

    ax1.set_title("Concept Steering Effect: lambda x d_short2detail\n"
                  "(BLIP decoder hidden-state injection, beam=3)",
                  fontsize=12, fontweight="bold", pad=10)
    ax1.grid(linestyle="--", alpha=0.35)
    fig.tight_layout()

    path = os.path.join(save_dir, "steering_lambda_sweep.png")
    fig.savefig(path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f" OK Saved: {path}")
    return path
# ─────────────────────────────────────────────────────────────────────────────
# Master
# ─────────────────────────────────────────────────────────────────────────────
def visualize_all(records: list, steering_results: list,
                  save_dir: str = "task/task_04/results") -> dict:
    """
    Render the three Task 4 figures in sequence and report where they went.

    Args:
        records: per-image diversity records, forwarded to the histogram and
            the diverse-vs-repetitive figure.
        steering_results: lambda sweep results, forwarded to the sweep chart.
        save_dir: output directory handed to every plot function.

    Returns:
        dict mapping 'histogram', 'extremes' and 'lambda_sweep' to the path
        returned by the corresponding plot function (``save_dir`` joined
        with the figure's file name).
    """
    banner = "=" * 62
    print(banner)
    print(" Task 4 -- Step 6: Generate Visualizations")
    print(banner)

    # Figures are produced in this fixed order: histogram, extremes, sweep.
    paths = {}
    paths["histogram"] = plot_diversity_histogram(records, save_dir)
    paths["extremes"] = plot_diverse_vs_repetitive(records, save_dir)
    paths["lambda_sweep"] = plot_steering_lambda_sweep(steering_results, save_dir)

    print(f"\n 3 figures saved to: {save_dir}")
    return paths
# ─────────────────────────────────────────────────────────────────────────────
# Standalone entrypoint
# ─────────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Standalone run: read cached results from the sibling "results" folder,
    # falling back to the precomputed data shipped with the earlier steps.
    SAVE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "results")

    div_cache = os.path.join(SAVE_DIR, "diversity_results.json")
    if os.path.exists(div_cache):
        # Explicit UTF-8: open() otherwise uses the locale's default
        # encoding, which can mis-decode non-ASCII caption text.
        with open(div_cache, encoding="utf-8") as f:
            records = json.load(f)
        print(f" Loaded diversity results from {div_cache}")
    else:
        # Imported lazily so the step3 module is only needed without a cache.
        from step3_diversity_analysis import _make_precomputed
        records = _make_precomputed()

    steer_cache = os.path.join(SAVE_DIR, "steering_results.json")
    if os.path.exists(steer_cache):
        with open(steer_cache, encoding="utf-8") as f:
            steering_results = json.load(f)
        print(f" Loaded steering results from {steer_cache}")
    else:
        # Imported lazily so the step5 module is only needed without a cache.
        from step5_steer_and_eval import PRECOMPUTED_STEERING
        steering_results = PRECOMPUTED_STEERING

    paths = visualize_all(records, steering_results, SAVE_DIR)
    print("\n All figures generated.")
    for name, p in paths.items():
        print(f" {name:14}: {p}")