| import argparse |
| import os |
| import torch |
| from safetensors.torch import load_file |
| from rich.console import Console |
| from rich.table import Table |
|
|
| from transformers import AutoModelForMaskedLM, AutoConfig, AutoModel |
|
|
| from e1_fastplms.modeling_e1 import E1ForMaskedLM, E1Config, E1Model |
|
|
|
|
def _unwrap_checkpoint(obj):
    """Return the state dict nested inside a checkpoint wrapper, if there is one.

    A dict that stores another dict under ``"state_dict"`` (checked first) or
    ``"model"`` is treated as a wrapper and the nested dict is returned. Any
    other object is returned unchanged; in particular a flat state dict that
    merely has a *tensor* stored under a key named ``"model"`` is left alone.
    """
    if isinstance(obj, dict):
        for wrapper_key in ("state_dict", "model"):
            if wrapper_key in obj and isinstance(obj[wrapper_key], dict):
                return obj[wrapper_key]
    return obj


def load_weights(path, cast_fp32=True):
    """Load a state dict from a safetensors or PyTorch checkpoint file.

    Args:
        path: Checkpoint location. ``.safetensors`` files are read with
            ``safetensors.torch.load_file``; ``.pth``/``.pt`` files with
            ``torch.load`` (CPU, ``weights_only=True``). Any other extension
            is tried as safetensors first and then as a PyTorch checkpoint.
        cast_fp32: When true, every tensor value is converted with
            ``Tensor.float()``; non-tensor values are passed through as-is.

    Returns:
        A mapping from parameter name to value. With ``cast_fp32`` a new plain
        ``dict`` is returned, otherwise the loaded object itself.

    Raises:
        AssertionError: If ``path`` does not exist.
    """
    if not os.path.exists(path):
        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped under ``python -O``; the exception type is unchanged.
        raise AssertionError(f"File {path} not found.")

    if path.endswith(".safetensors"):
        sd = load_file(path)
    elif path.endswith((".pth", ".pt")):
        sd = _unwrap_checkpoint(
            torch.load(path, map_location="cpu", weights_only=True)
        )
    else:
        try:
            sd = load_file(path)
        except Exception:
            # Deliberate best effort: anything that is not readable as
            # safetensors is retried as a PyTorch checkpoint, with the same
            # wrapper handling as the ``.pth``/``.pt`` branch.
            sd = _unwrap_checkpoint(
                torch.load(path, map_location="cpu", weights_only=True)
            )

    if cast_fp32:
        return {
            name: value.float() if isinstance(value, torch.Tensor) else value
            for name, value in sd.items()
        }
    return sd
|
|
|
|
def _export_hub_state_dicts():
    """Load the three hub checkpoints and save their state dicts as ``.pth`` files.

    Writes ``official.pth`` (``E1ForMaskedLM`` from ``Profluent-Bio/E1-150m``),
    ``load_from_pretrained_1.pth`` (``AutoModel``) and
    ``load_from_pretrained_2.pth`` (``AutoModelForMaskedLM``), the latter two
    from ``Synthyra/Profluent-E1-150M`` with ``trust_remote_code=True``. The
    relative paths resolve against the current working directory.
    """
    official = E1ForMaskedLM.from_pretrained('Profluent-Bio/E1-150m', dtype=torch.float32).eval()
    torch.save(official.state_dict(), 'official.pth')

    base = AutoModel.from_pretrained(
        'Synthyra/Profluent-E1-150M', dtype=torch.float32, trust_remote_code=True
    ).eval()
    torch.save(base.state_dict(), 'load_from_pretrained_1.pth')

    masked_lm = AutoModelForMaskedLM.from_pretrained(
        'Synthyra/Profluent-E1-150M', dtype=torch.float32, trust_remote_code=True
    ).eval()
    torch.save(masked_lm.state_dict(), 'load_from_pretrained_2.pth')


def _compare_tensors(key, ref_w, other_w, strict):
    """Compare one tensor pair; return ``(rich_markup_cell, is_mismatch)``."""
    # Explicit raises rather than ``assert`` so the checks survive ``python -O``.
    if not isinstance(ref_w, torch.Tensor):
        raise AssertionError(f"Weight {key} in reference is not a tensor.")
    if not isinstance(other_w, torch.Tensor):
        raise AssertionError(f"Weight {key} in comparison file is not a tensor.")

    if ref_w.shape != other_w.shape:
        return "[red]✗ (Shape)[/red]", True

    # Equality is decided by torch.equal rather than by ``mse == 0``: the mean
    # of an empty tensor is NaN (identical empty tensors would be flagged), and
    # squared differences can underflow to 0 for values that are not equal.
    if torch.equal(ref_w, other_w):
        return "[green]✓[/green]", False

    mse = torch.mean((ref_w.float() - other_w.float()) ** 2).item()
    label = "Strict, MSE" if strict else "MSE"
    return f"[red]✗ ({label}: {mse:.2e})[/red]", True


def main():
    """Compare checkpoint files tensor-by-tensor against a reference and print a table.

    The hub checkpoints are always re-exported first (see
    ``_export_hub_state_dicts``). Without ``--file1``/``--files`` the reference
    is ``official.pth`` and it is compared against the two exported
    ``load_from_pretrained_*.pth`` files plus ``old.safetensors``, which is not
    created here and must already exist.

    Each table row is a tensor name and each column a comparison file:
    ``✓`` equal, ``✗`` different (with MSE or ``Shape``) or present in only one
    of the two files, ``-`` absent from both.

    Raises:
        AssertionError: With ``--strict --assert_exact`` when any tensor present
            in both the reference and a comparison file differs in shape or
            value; also when a compared entry is not a tensor.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--file1", type=str, default=None,
                        help="Reference checkpoint (default: official.pth).")
    parser.add_argument("--files", type=str, nargs="+", default=None,
                        help="Checkpoints to compare against the reference.")
    parser.add_argument("--strict", action="store_true",
                        help="Compare tensors as stored instead of casting to float32 first.")
    parser.add_argument("--assert_exact", action="store_true",
                        help="With --strict, fail if any shared tensor differs.")
    args = parser.parse_args()

    _export_hub_state_dicts()

    if args.file1 is None:
        args.file1 = 'official.pth'
    if args.files is None:
        args.files = ['load_from_pretrained_1.pth', 'load_from_pretrained_2.pth', 'old.safetensors']

    paths = [args.file1] + args.files
    sds = [load_weights(p, cast_fp32=not args.strict) for p in paths]
    all_keys = sorted(set().union(*(sd.keys() for sd in sds)))
    strict_mismatches = []

    console = Console()
    table = Table(title=f"Weights Comparison (Reference: {os.path.basename(paths[0])})")
    table.add_column("Tensor Name", style="cyan", no_wrap=True)
    for p in paths[1:]:
        table.add_column(f"{os.path.basename(p)} == Ref", justify="center")

    reference = sds[0]
    for k in all_keys:
        row = [k]
        has_ref = k in reference
        for sd in sds[1:]:
            has_other = k in sd
            if has_ref and has_other:
                cell, mismatch = _compare_tensors(k, reference[k], sd[k], args.strict)
                if mismatch and args.strict:
                    strict_mismatches.append(k)
            elif has_ref or has_other:
                # Present on exactly one side. Not counted as a strict
                # mismatch: the compared models may expose different key sets.
                cell = "[red]✗[/red]"
            else:
                # Key only exists in some other comparison file.
                cell = "[dim]-[/dim]"
            row.append(cell)
        table.add_row(*row)

    console.print(table)
    if args.strict and args.assert_exact and strict_mismatches:
        # Explicit raise so the gate still fires under ``python -O``.
        raise AssertionError(
            f"Found {len(strict_mismatches)} strict mismatches. "
            f"First mismatches: {strict_mismatches[:10]}"
        )
|
|
|
|
# Run the comparison only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()