| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
| from functools import lru_cache |
| from typing import Iterable, List |
|
|
| import torch |
|
|
|
|
@dataclass
class Blip2TextStateEncoderConfig:
    """Settings for :class:`Blip2TextStateEncoder`.

    Groups everything needed to load and run the BLIP2 text model:
    which checkpoint to load, where to put it, in what dtype, and how
    long tokenized inputs may be.
    """

    # HuggingFace checkpoint id or local path of the BLIP2 ITM model.
    model_name_or_path: str = "Salesforce/blip2-itm-vit-g"
    # Torch device string the model is moved to (e.g. "cpu", "cuda").
    device: str = "cpu"
    # Dtype requested when loading the checkpoint (fp16 halves memory).
    torch_dtype: torch.dtype = torch.float16
    # Tokenizer truncation cap; state strings are short, so 32 suffices.
    max_length: int = 32
|
|
|
|
class Blip2TextStateEncoder:
    """Encode state strings into BLIP2 text-projection embeddings.

    Uses BLIP2's ``Blip2TextModelWithProjection`` to map each state text
    to a single ``text_embeds`` vector.

    Design goals:
    - datasets store states as readable strings (e.g. "raw", "cooked"),
    - training/inference turns them into ``state_features: (B, N, D_text)``,
    - a downstream ``InstanceFeatureExtractor`` then projects ``D_text``
      to the DiT hidden dim.

    The tokenizer and model are loaded lazily on first use and kept
    frozen (inference only).
    """

    def __init__(self, cfg: Blip2TextStateEncoderConfig):
        self.cfg = cfg
        # Both filled in by _lazy_init() on the first encode_texts() call,
        # so constructing the encoder stays cheap.
        self._tokenizer = None
        self._model = None

    def _lazy_init(self):
        """Load the tokenizer and frozen model once; later calls are no-ops."""
        if self._model is not None:
            return
        # Imported here so `transformers` is only required when encoding.
        from transformers import AutoTokenizer, Blip2TextModelWithProjection

        self._tokenizer = AutoTokenizer.from_pretrained(self.cfg.model_name_or_path)
        self._model = Blip2TextModelWithProjection.from_pretrained(
            self.cfg.model_name_or_path,
            torch_dtype=self.cfg.torch_dtype,
        )
        self._model.eval()
        # Freeze every parameter: this encoder is never trained.
        for param in self._model.parameters():
            param.requires_grad_(False)
        self._model.to(device=self.cfg.device)

    @torch.inference_mode()
    def encode_texts(self, texts: List[str]) -> torch.Tensor:
        """Return the model's ``text_embeds`` for *texts* as float32 on CPU.

        One embedding row per input string (first dimension == len(texts)).
        """
        self._lazy_init()
        batch = self._tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=self.cfg.max_length,
            return_tensors="pt",
        )
        # Move every tokenizer output tensor onto the model's device.
        batch = {name: tensor.to(self.cfg.device) for name, tensor in batch.items()}
        outputs = self._model(**batch)
        # Downstream code expects float32 CPU features regardless of model dtype.
        return outputs.text_embeds.to(dtype=torch.float32, device="cpu")
|
|
|
|
def encode_state_text_tensor(
    state_texts: list,
    model_name_or_path: str = "Salesforce/blip2-itm-vit-g",
    device: str = "cpu",
    torch_dtype: torch.dtype = torch.float16,
    max_length: int = 32,
) -> torch.Tensor:
    """Encode a nested list of state texts (B, N) into a tensor (B, N, D_text).

    Each distinct string is encoded exactly once via
    :class:`Blip2TextStateEncoder` and the embeddings are gathered back
    into place, so repeated states cost a single forward pass.

    Args:
        state_texts: Nested list of shape (B, N). Every row must be a list
            of strings and all rows must share the same non-zero length.
        model_name_or_path: Checkpoint passed to the encoder config.
        device: Device the BLIP2 model runs on.
        torch_dtype: Dtype the BLIP2 model is loaded with.
        max_length: Tokenizer truncation length.

    Returns:
        Float32 CPU tensor of shape (B, N, D_text).

    Raises:
        ValueError: If ``state_texts`` is not a non-empty, rectangular
            nested list of strings.
    """
    if not isinstance(state_texts, list) or not state_texts:
        raise ValueError("state_texts must be a non-empty nested list (B,N)")

    # Validate every row up front (the original code only checked row 0), so
    # ragged batches, string rows, or empty rows fail with a clear ValueError
    # instead of an opaque torch.stack error later.
    n = None
    for i, row in enumerate(state_texts):
        if not isinstance(row, list):
            raise ValueError("state_texts must be nested list like [[...], [...]]")
        if n is None:
            n = len(row)
            if n == 0:
                raise ValueError("state_texts rows must be non-empty")
        elif len(row) != n:
            raise ValueError(
                f"state_texts rows must all have length {n}, row {i} has {len(row)}"
            )
        for t in row:
            if not isinstance(t, str):
                raise ValueError(f"state_text must be str, got: {type(t)}")

    encoder = Blip2TextStateEncoder(
        Blip2TextStateEncoderConfig(
            model_name_or_path=model_name_or_path,
            device=device,
            torch_dtype=torch_dtype,
            max_length=max_length,
        )
    )

    # Encode each unique string once; sorted for a deterministic batch order.
    uniq = sorted({t for row in state_texts for t in row})
    emb = encoder.encode_texts(uniq)
    table = {t: emb[i] for i, t in enumerate(uniq)}

    out = torch.stack(
        [torch.stack([table[t] for t in row], dim=0) for row in state_texts],
        dim=0,
    )
    # Defensive: with the validation above this should never trip.
    if out.shape[:2] != (len(state_texts), n):
        raise RuntimeError(f"unexpected encoded shape: {tuple(out.shape)}")
    return out
|
|
|
|