"""Text encoder wrapper.

Loads one or more pretrained text models (CLIP, T5, or a causal LM,
"llm14b") as described by a config, and returns their embeddings,
attention masks, and pooled outputs in a single TextEncoderOutput.
"""

import os
from dataclasses import dataclass
from typing import List, Optional, Union

import torch
from omegaconf import DictConfig, OmegaConf
from torch import nn
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    CLIPTextModel,
    CLIPTokenizerFast,
    T5EncoderModel,
    T5TokenizerFast,
)
from transformers.tokenization_utils_base import BatchEncoding

from common.fs import download_and_extract
from common.logger import get_logger

logger = get_logger(__name__)
MODEL_TYPES = {
    "clip": (CLIPTokenizerFast, CLIPTextModel),
    "t5": (T5TokenizerFast, T5EncoderModel),
    "llm14b": (AutoTokenizer, AutoModelForCausalLM),
}
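
# Illustrative config shape (an assumption inferred from how `config` is read
# below; field values are deployment-specific, not part of this module):
#
#   models:
#     - type: clip
#       path: <archive containing a CLIP text encoder>
#       max_length: 77
#       layer: penultimate   # optional, defaults to "last"
#       mask: true           # optional, defaults to true
#     - type: t5
#       path: <archive containing a T5 encoder>
#       max_length: 256
#   output:
#     embedding_and_mask: channel_concat
#     pooled: first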
@dataclass
class TextEncoderOutput:
    embeddings: Union[torch.FloatTensor, List[torch.FloatTensor]]
    masks: Union[torch.BoolTensor, List[torch.BoolTensor]]
    pooled: Optional[Union[torch.FloatTensor, List[torch.FloatTensor]]]
class TextEncoder(nn.Module):
    def __init__(self, config: DictConfig):
        super().__init__()
        self.config = config
        self.tokenizers = []
        self.models = nn.ModuleList([])
        # Disable tokenizer parallelism since we already use distributed training.
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        for model_config in config.models:
            tokenizer_cls, model_cls = MODEL_TYPES[model_config.type]
            path = download_and_extract(model_config.path)
            max_length = model_config.max_length
            if model_config.type == "llm14b":
                tokenizer = tokenizer_cls.from_pretrained(
                    path,
                    model_max_length=max_length,
                    use_fast=False,
                    trust_remote_code=True,
                    padding_side="right",
                    truncation_side="right",
                    add_eod_token=True,
                )
                tokenizer.add_special_tokens({"pad_token": "<|endoftext|>"})
                model = model_cls.from_pretrained(path, trust_remote_code=True, bf16=True)
            else:
                tokenizer = tokenizer_cls.from_pretrained(path, model_max_length=max_length)
                model = model_cls.from_pretrained(path, torch_dtype=torch.bfloat16)
            self.tokenizers.append(tokenizer)
            self.models.append(model)
    def forward(self, text: Union[str, List[str]]) -> TextEncoderOutput:
        embeddings, masks, pooled = [], [], []
        for encoder_config, tokenizer, model in zip(
            self.config.models, self.tokenizers, self.models
        ):
            if encoder_config.type == "llm14b":
                use_mask = encoder_config.get("mask", True)
                tokens = tokenizer(
                    text,
                    return_tensors="pt",
                    padding="max_length",
                    truncation=True,
                ).to(model.device)
                token_ids = tokens["input_ids"]
                attention_mask = tokens["attention_mask"]
                # Write an EOD/pad token directly after the last real token of
                # each sequence (or overwrite the final position if the
                # sequence already fills max_length) and mark it as attended.
                num_tokens = attention_mask.sum(dim=1)
                range_ids = torch.arange(len(token_ids), device=token_ids.device, dtype=torch.long)
                token_ids[range_ids, num_tokens.clamp(max=token_ids.size(1) - 1)] = (
                    tokenizer.pad_token_id
                )
                attention_mask[range_ids, num_tokens.clamp(max=token_ids.size(1) - 1)] = 1
                tokens = BatchEncoding({"input_ids": token_ids, "attention_mask": attention_mask})
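                # Worked toy example (illustrative numbers, not from a real
                # tokenizer): with max_length=6 and a prompt tokenized to 3
                # tokens, attention_mask = [1, 1, 1, 0, 0, 0], so num_tokens
                # is 3 and the pad/EOD id lands at index 3, giving
                # attention_mask = [1, 1, 1, 1, 0, 0]. For a full 6-token
                # prompt, clamp(max=5) overwrites index 5 instead.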
                output = model.transformer(
                    input_ids=tokens.input_ids,
                    attention_mask=attention_mask if use_mask else None,
                    output_hidden_states=False,
                    use_cache=False,
                )
                emb = output.last_hidden_state  # batch_size, num_tokens, feat_dim
                # emb *= tokens.attention_mask.unsqueeze(-1)
                embeddings.append(emb)
                masks.append(
                    tokens.attention_mask.bool() if use_mask else tokens.attention_mask > -1
                )
            else:
                # Tokenizer
                tokens = tokenizer(
                    text=text,
                    truncation=True,
                    padding="max_length",
                    return_tensors="pt",
                )
                # Encoder
                use_mask = encoder_config.get("mask", True)
                input_ids = tokens.input_ids.to(model.device)
                attention_mask = tokens.attention_mask.to(model.device)
                output = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask if use_mask else None,
                    output_hidden_states=True,
                )
                # Save embeddings from the configured layer. "penultimate"
                # applies the final layer norm to the second-to-last hidden
                # state (the usual CLIP-skip convention for CLIP encoders);
                # "penultimate_nonorm" returns it raw.
                layer = encoder_config.get("layer", "last")
                if layer == "last":
                    embeddings.append(output.last_hidden_state)
                elif layer == "penultimate":
                    embeddings.append(model.text_model.final_layer_norm(output.hidden_states[-2]))
                elif layer == "penultimate_nonorm":
                    embeddings.append(output.hidden_states[-2])
                else:
                    raise NotImplementedError(f"Unknown layer type: {layer}.")
                # Save masks (all-True when masking is disabled).
                masks.append(attention_mask.bool() if use_mask else attention_mask > -1)
                # Save pooled output if available (CLIP provides one, T5 does not).
                if hasattr(output, "pooler_output"):
                    pooled.append(output.pooler_output)
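        # Example shapes after the loop (assuming two encoders with hidden
        # sizes d1 and d2, batch size B, and a shared max_length T):
        # embeddings = [FloatTensor[B, T, d1], FloatTensor[B, T, d2]], so
        # "channel_concat" below yields FloatTensor[B, T, d1 + d2] and the
        # merged mask is the elementwise union of the per-encoder masks.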
        output_config = self.config.get("output") or OmegaConf.create()
        embedding_output_type = output_config.get("embedding_and_mask", "undefined")
        pooled_output_type = output_config.get("pooled", "undefined")
        # Select or merge embeddings and masks if needed.
        if embedding_output_type == "undefined" and len(self.models) == 1:
            embeddings = embeddings[0]
            masks = masks[0]
        elif embedding_output_type == "channel_concat":
            embeddings = torch.cat(embeddings, dim=-1)
            masks = sum(masks).bool()
        elif embedding_output_type == "last":
            embeddings = embeddings[-1]
            masks = masks[-1]
        else:
            raise NotImplementedError(f"output.embedding_and_mask: {embedding_output_type}")
        # Select or merge pooled outputs if needed.
        if pooled_output_type == "undefined":
            pooled = None
        elif pooled_output_type == "channel_concat":
            pooled = torch.cat(pooled, dim=-1)
        elif pooled_output_type == "first":
            pooled = pooled[0]
        elif pooled_output_type == "last":
            pooled = pooled[-1]
        else:
            raise NotImplementedError(f"output.pooled: {pooled_output_type}")
        # Return final results.
        return TextEncoderOutput(embeddings, masks, pooled)
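

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only: the checkpoint path and config
    # values below are assumptions, not shipped with this module).
    example_config = OmegaConf.create(
        {
            "models": [
                {"type": "clip", "path": "/path/to/clip_text_encoder", "max_length": 77}
            ]
        }
    )
    encoder = TextEncoder(example_config).eval()
    with torch.no_grad():
        out = encoder(["a photo of a cat", "a photo of a dog"])
    # With a single model and no `output` section, `out.embeddings` is one
    # [batch, max_length, hidden] tensor, `out.masks` is its boolean attention
    # mask, and `out.pooled` is None (`output.pooled` defaults to "undefined").
    logger.info(f"embeddings: {out.embeddings.shape}, masks: {out.masks.shape}")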