jsflow / fid_custom.py
xiangzai's picture
Add files using upload-large-folder tool
b65e56d verified
#!/usr/bin/env python3
"""
Fréchet Inception Distance(FID),与 evaluator 中 sqrt(Σ₁Σ₂) 路径对齐。
具体哪几行会引入虚部/负值,见 fid() 内按行注释与 --diagnose 输出。
"""
import argparse
import os
import warnings
import numpy as np
from scipy import linalg
import tensorflow.compat.v1 as tf
from evaluator import Evaluator
def _trace_sqrt_symmetric_psd(M):
"""
对“应当对称 PSD”的矩阵 M 计算 Tr(sqrt(M))。
这里不直接用 scipy.linalg.sqrtm(它会对数值上“几乎 PSD”的矩阵也返回复数),
而是:
- 先对称化:M ← (M+M^T)/2
- 做对称特征分解:M = Q diag(λ) Q^T
- 将数值误差造成的负特征值截断到 0(PSD 投影)
- Tr(sqrt(M)) = Σ sqrt(max(λ,0))
"""
M = (M + M.T) / 2.0
if not np.isfinite(M).all():
warnings.warn("mid matrix contains NaN/Inf; returning NaN for Tr(sqrt(M))")
return float("nan")
# eigvalsh 在“极端病态/数值误差较大”的情况下可能不收敛;
# 加一丁点对角 jitter 后重试,通常能恢复稳定。
I = np.eye(M.shape[0], dtype=M.dtype)
trace = float(np.trace(M))
diag_jitter = 1e-6 * (trace / max(M.shape[0], 1))
if not np.isfinite(diag_jitter) or diag_jitter <= 0:
diag_jitter = 1e-6
try:
w = np.linalg.eigvalsh(M)
except np.linalg.LinAlgError:
w = np.linalg.eigvalsh(M + diag_jitter * I)
w_min = float(w.min()) if w.size else 0.0
if w_min < -1e-6:
warnings.warn(
"PSD 中间矩阵出现明显负特征值 λ_min={:.3e};已截断到 0 以稳定 Tr(sqrt(M))".format(
w_min
)
)
w = np.clip(w, 0.0, None)
return float(np.sqrt(w).sum())
def fid_psd_geometric(mn1, cov1, mn2, cov2):
    """Fréchet distance using the cross term Tr(sqrt(Σ₁^{1/2} Σ₂ Σ₁^{1/2})).

    The middle matrix Σ₁^{1/2} Σ₂ Σ₁^{1/2} is symmetric PSD, so this variant
    avoids the complex branch of sqrt(Σ₁Σ₂). In exact arithmetic the result is
    a non-negative squared distance (numerically it can be a tiny negative
    number). It serves as a cross-check: if this value is non-negative while
    the product path of fid() is negative, the problem lies in the
    sqrt(Σ₁Σ₂) / ``.real`` handling there.

    Args:
        mn1, mn2: Mean vectors (array-like) of identical shape.
        cov1, cov2: Covariance matrices (array-like) of identical shape.

    Returns:
        ‖μ₁−μ₂‖² + Tr(Σ₁) + Tr(Σ₂) − 2·Tr(sqrt(Σ₁^{1/2} Σ₂ Σ₁^{1/2})) as a float.

    Raises:
        ValueError: If the mean shapes or the covariance shapes differ.
    """
    mn1 = np.atleast_1d(mn1)
    mn2 = np.atleast_1d(mn2)
    cov1 = np.atleast_2d(cov1)
    cov2 = np.atleast_2d(cov2)
    # Same validation as fid(): reject mismatches instead of silently broadcasting.
    if mn1.shape != mn2.shape:
        raise ValueError(f"mean shape mismatch: {mn1.shape} vs {mn2.shape}")
    if cov1.shape != cov2.shape:
        raise ValueError(f"cov shape mismatch: {cov1.shape} vs {cov2.shape}")

    cov1 = (cov1 + cov1.T) / 2.0
    cov2 = (cov2 + cov2.T) / 2.0
    diff = mn1 - mn2

    # Build Σ₁^{1/2} from a symmetric eigendecomposition (with PSD projection)
    # so that the complex branch of sqrtm is never entered.
    w1, v1 = np.linalg.eigh(cov1)
    w1_min = float(w1.min()) if w1.size else 0.0
    if w1_min < -1e-6:
        # The literal braces of "{1/2}" must be doubled: unescaped, str.format
        # treats "{1/2}" as a replacement field and raises KeyError('1/2').
        warnings.warn(
            "Σ₁ 出现明显负特征值 λ_min={:.3e};已截断到 0 以稳定 Σ₁^{{1/2}}".format(w1_min)
        )
    w1 = np.clip(w1, 0.0, None)
    s1 = (v1 * np.sqrt(w1)) @ v1.T

    mid = s1 @ cov2 @ s1
    tr_cross = _trace_sqrt_symmetric_psd(mid)
    return float(diff.dot(diff) + np.trace(cov1) + np.trace(cov2) - 2.0 * tr_cross)
def _fid_product_path(diff, cov1, cov2, eps):
    """Evaluate the FID whose cross term is Tr(sqrt(Σ₁Σ₂)), plus its diagnostics.

    Args:
        diff: μ₁ − μ₂.
        cov1, cov2: Symmetrized covariance matrices of identical shape.
        eps: First non-zero jitter tried when sqrtm returns non-finite values.

    Returns:
        ``(fid_value, info)`` where ``info`` holds the keys
        "prod_symmetry_error", "jitter_used", "had_imaginary_sqrtm",
        "imag_max" and "fid_sqrt_product_path".

    Raises:
        ValueError: If sqrtm stays non-finite for every jitter level.
    """
    dim = cov1.shape[0]
    prod = cov1.dot(cov2)
    prod_sym_err = np.max(np.abs(prod - prod.T))

    covmean = None
    j_used = None
    for j in (0.0, eps, 1e-5, 1e-4, 1e-3):
        # sqrtm is called without ``disp``: the error estimate was never used,
        # and the returned matrix is the same.
        if j == 0.0:
            # Risk point 1: Σ₁Σ₂ is in general neither symmetric nor PSD, so
            # its matrix square root may land in the complex domain.
            cm = linalg.sqrtm(prod)
        else:
            offset = np.eye(dim) * j
            # Risk point 2: sqrtm acts on (Σ₁+δI)(Σ₂+δI) while the trace terms
            # below still use the original Σ₁ and Σ₂.
            cm = linalg.sqrtm((cov1 + offset).dot(cov2 + offset))
        if np.isfinite(cm).all():
            covmean, j_used = cm, j
            break
    if covmean is None:
        raise ValueError("FID sqrtm failed: non-finite covmean even after jitter retries")

    had_imag = bool(np.iscomplexobj(covmean))
    imag_max = float(np.max(np.abs(covmean.imag))) if had_imag else 0.0
    if had_imag:
        # Risk point 3: dropping the imaginary part means Tr(covmean) is no
        # longer a strict expansion of the same cross term.
        if imag_max > 1e-3:
            warnings.warn(
                "Large imaginary component in sqrtm ({:.6f}); taking real part.".format(imag_max)
            )
        covmean = covmean.real

    fid_product = float(
        diff.dot(diff) + np.trace(cov1) + np.trace(cov2) - 2.0 * np.trace(covmean)
    )
    info = {
        "prod_symmetry_error": float(prod_sym_err),
        "jitter_used": float(j_used),
        "had_imaginary_sqrtm": had_imag,
        "imag_max": imag_max,
        "fid_sqrt_product_path": fid_product,
    }
    return fid_product, info


def _fid_diagnostic_warnings(asym1, asym2, product_diag, fid_psd):
    """Build the list of diagnostic messages about product-path numerical risks."""
    wlist = []
    if asym1 > 1e-4 or asym2 > 1e-4:
        wlist.append(
            "对称化前协方差不对称量过大 (max asym1={:.3e}, asym2={:.3e})".format(asym1, asym2)
        )
    if product_diag.get("prod_symmetry_error", 0.0) > 1e-2:
        wlist.append(
            "Σ₁Σ₂ 与对称矩阵偏离较大 (max|A-A^T|={:.3e}),sqrtm 复值风险高".format(
                product_diag["prod_symmetry_error"]
            )
        )
    jitter_used = product_diag.get("jitter_used")
    if jitter_used is not None and jitter_used > 0.0:
        wlist.append(
            "使用了 jitter={}:sqrtm 与 Tr(Σ₁)+Tr(Σ₂) 项不一致,负 FID 风险".format(jitter_used)
        )
    if product_diag.get("had_imaginary_sqrtm"):
        wlist.append(
            "sqrtm(Σ₁Σ₂) 含虚部 (max|Im|={:.3e}),已取 .real,与理论 Fréchet 项可能失配".format(
                product_diag["imag_max"]
            )
        )
    fid_product = product_diag.get("fid_sqrt_product_path")
    if fid_product is not None and fid_product < 0:
        wlist.append("product 路径 FID<0(数值上不应出现)")
    if (
        not np.isnan(fid_psd)
        and fid_psd >= -1e-5
        and product_diag.get("fid_sqrt_product_path", 0.0) < 0
    ):
        wlist.append(
            "对照 fid_psd_geometric≈{:.6f} 非负,负值主要来自 sqrt(Σ₁Σ₂)/.real/jitter 路径".format(fid_psd)
        )
    return wlist


def fid(mn1, cov1, mn2, cov2, eps=1e-6, diagnose=False, cross_type="psd_geometric"):
    """Compute ‖μ₁−μ₂‖² + Tr(Σ₁) + Tr(Σ₂) − 2·Tr(cross term).

    Args:
        mn1, mn2: Mean vectors (array-like) of identical shape.
        cov1, cov2: Covariance matrices (array-like) of identical shape.
        eps: First non-zero jitter tried by the product path when sqrtm
            returns non-finite values.
        diagnose: When True, both paths are evaluated and the function returns
            ``(fid_value, diag_dict)`` to help locate numerical mismatches.
        cross_type: ``"product"`` uses Tr(sqrt(Σ₁Σ₂)) (aligned with the
            evaluator's sqrt(σ1·σ2) path); ``"psd_geometric"`` uses
            Tr(sqrt(Σ₁^{1/2} Σ₂ Σ₁^{1/2})) (symmetric PSD middle matrix,
            numerically more stable).

    Returns:
        The FID as a float, or ``(fid, diag)`` when ``diagnose`` is True.

    Raises:
        ValueError: On an unknown ``cross_type``, mismatching mean/covariance
            shapes, or when the product path cannot obtain a finite sqrtm.
    """
    if cross_type not in ("product", "psd_geometric"):
        # Previously an unknown value silently fell back to the PSD path.
        raise ValueError(
            f"unknown cross_type: {cross_type!r}; expected 'product' or 'psd_geometric'"
        )

    mn1 = np.atleast_1d(mn1)
    mn2 = np.atleast_1d(mn2)
    cov1 = np.atleast_2d(cov1)
    cov2 = np.atleast_2d(cov2)
    if mn1.shape != mn2.shape:
        raise ValueError(f"mean shape mismatch: {mn1.shape} vs {mn2.shape}")
    if cov1.shape != cov2.shape:
        raise ValueError(f"cov shape mismatch: {cov1.shape} vs {cov2.shape}")
    diff = mn1 - mn2

    # Symmetrize to reduce numerical asymmetry of the sqrtm input; this does
    # not change the mathematical definition. Record the asymmetry first.
    asym1 = np.max(np.abs(cov1 - cov1.T))
    asym2 = np.max(np.abs(cov2 - cov2.T))
    cov1 = (cov1 + cov1.T) / 2.0
    cov2 = (cov2 + cov2.T) / 2.0

    if cross_type == "psd_geometric" and not diagnose:
        # Fast path: only the stable PSD geometric-mean cross term is needed.
        return fid_psd_geometric(mn1, cov1, mn2, cov2)

    fid_product, product_info = _fid_product_path(diff, cov1, cov2, eps)
    if not diagnose:
        # cross_type == "product": the PSD path is not needed, so skip its
        # eigendecompositions entirely.
        return fid_product

    product_diag = {
        "asym_before_symmetrize_max": (float(asym1), float(asym2)),
        **product_info,
    }

    # Diagnosis is best-effort: a failure of the reference path is reported in
    # the diag dict instead of aborting the whole diagnosis.
    fid_psd_err = None
    try:
        fid_psd = fid_psd_geometric(mn1, cov1, mn2, cov2)
    except Exception as e:
        fid_psd = float("nan")
        fid_psd_err = repr(e)

    fid_val = fid_product if cross_type == "product" else fid_psd
    diag = {
        **product_diag,
        "fid_psd_geometric": fid_psd,
        "fid_psd_geometric_error": fid_psd_err,
        "warnings": _fid_diagnostic_warnings(asym1, asym2, product_diag, fid_psd),
    }
    return float(fid_val), diag
def print_diagnosis(diag):
    """Print a human-readable summary of a diagnostics dict to stdout.

    Args:
        diag: Mapping in the layout produced by ``fid(..., diagnose=True)``.
            Every key is treated as optional; missing entries are skipped or
            printed as NaN / ``None`` instead of raising ``KeyError``.
    """
    nan = float("nan")
    print("--- FID 诊断(按代码风险点)---")

    # The tuple is (asymmetry of Σ₁, asymmetry of Σ₂) in fid()'s argument
    # order. Neutral labels are used because this function cannot know which
    # argument was the reference and which the sample.
    a0, a1 = diag.get("asym_before_symmetrize_max", (None, None))
    if a0 is not None:
        print(f" 对称化前 |Σ-Σ^T|_max: Σ₁={a0:.3e}, Σ₂={a1:.3e}")
    print(f" |Σ₁Σ₂ - (Σ₁Σ₂)^T|_max: {diag.get('prod_symmetry_error', nan):.3e}")
    print(f" sqrtm 使用的 jitter: {diag.get('jitter_used', None)}")

    if diag.get("had_imaginary_sqrtm"):
        print(
            f" sqrtm(Σ₁Σ₂) 含虚部: {diag['had_imaginary_sqrtm']}, max|Im|: {diag.get('imag_max', nan):.3e}"
        )
    else:
        print(f" sqrtm(Σ₁Σ₂) 含虚部: {diag.get('had_imaginary_sqrtm', False)}")

    if diag.get("fid_sqrt_product_path") is not None:
        print(f" FID(sqrt(Σ₁Σ₂) 路径): {diag['fid_sqrt_product_path']:.6f}")
    if diag.get("fid_psd_geometric_error"):
        print(f" FID(PSD 几何平均路径) 未算: {diag['fid_psd_geometric_error']}")
    else:
        print(f" FID(PSD 几何平均路径): {diag.get('fid_psd_geometric', nan):.6f}")

    for s in diag.get("warnings", []):
        print(f" [检查] {s}")
    print("---")
def main():
    """Command-line entry point: compute the custom FID of a sample batch.

    Steps, in order:
      1. Parse the CLI options (reference/sample npz paths, optional report
         path, cross-term type, diagnose flag).
      2. Hide GPUs from TensorFlow and wrap a CPU-only session in ``Evaluator``.
      3. Obtain activations and statistics for the reference and sample batches
         through the evaluator (their exact semantics live in the evaluator
         module, which is not part of this file).
      4. Call ``fid`` with the sample statistics first and the reference second,
         print the result and, if requested, write a small text report.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--ref_batch",
        default="/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/VIRTUAL_imagenet256_labeled/imagenet/VIRTUAL_imagenet256_labeled.npz",
        help="reference npz (arr_0 images or stats npz)",
    )
    cli.add_argument("--sample_batch", required=True, help="sample npz path")
    cli.add_argument("--save_txt", type=str, default=None, help="optional txt output path")
    cli.add_argument(
        "--cross-type",
        type=str,
        default="psd_geometric",
        choices=["product", "psd_geometric"],
        help="交叉项计算方式:product=Tr(sqrt(Σ₁Σ₂))(对齐 evaluator),psd_geometric=Tr(sqrt(Σ₁^{1/2}Σ₂Σ₁^{1/2}))(更稳)",
    )
    cli.add_argument(
        "--diagnose",
        action="store_true",
        help="打印逐行风险检测:jitter/虚部/Σ₁Σ₂ 对称性,并对照 PSD 几何平均路径",
    )
    args = cli.parse_args()

    # Same behaviour as evaluator.py: force the TF evaluation onto the CPU.
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    session = tf.Session(
        config=tf.ConfigProto(allow_soft_placement=True, device_count={"GPU": 0})
    )
    evaluator = Evaluator(session)

    print("warming up TensorFlow...")
    evaluator.warmup()

    print("computing reference batch activations...")
    ref_acts = evaluator.read_activations(args.ref_batch)
    print("computing/reading reference batch statistics...")
    ref_stats, _ = evaluator.read_statistics(args.ref_batch, ref_acts)

    print("computing sample batch activations...")
    sample_acts = evaluator.read_activations(args.sample_batch)
    print("computing/reading sample batch statistics...")
    sample_stats, _ = evaluator.read_statistics(args.sample_batch, sample_acts)

    print("Computing custom FID...")
    # The Fréchet distance is symmetric in the two distributions; the argument
    # order (sample, ref) mirrors frechet_distance(sample, ref) in the evaluator.
    outcome = fid(
        sample_stats.mu,
        sample_stats.sigma,
        ref_stats.mu,
        ref_stats.sigma,
        diagnose=args.diagnose,
        cross_type=args.cross_type,
    )
    if args.diagnose:
        # With diagnose=True fid() returns a (value, diagnostics) pair.
        fid_value, diag = outcome
        print_diagnosis(diag)
    else:
        fid_value = outcome
    print(f"FID(custom): {fid_value}")

    if not args.save_txt:
        return

    with open(args.save_txt, "w", encoding="utf-8") as report:
        report.writelines(
            [
                f"ref_batch: {args.ref_batch}\n",
                f"sample_batch: {args.sample_batch}\n",
                f"FID(custom): {fid_value}\n",
            ]
        )
        if args.diagnose:
            report.write(
                f"diagnose: jitter_used={diag['jitter_used']}, imag_max={diag['imag_max']}\n"
            )
            report.writelines(f" {s}\n" for s in diag["warnings"])
    print(f"Saved report to {args.save_txt}")


# Run the CLI only when this file is executed as a script.
if __name__ == "__main__":
    main()