# Rouwei-T5Gemma-adapter_v0.2 / code / t5gemma_states.py
# Uploaded by Minthy using huggingface_hub (commit 5a8dc28, verified).
"""
Module providing a wrapper class for T5-Gemma model to extract hidden states.
This module defines a simple interface for using the T5-Gemma encoder model
to generate hidden states from text inputs, which can then be used by other
components like the SDXL adapter.
"""
import os
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Union

import torch
from safetensors.torch import save_file
from transformers import AutoTokenizer, T5GemmaEncoderModel
class T5GemmaStates:
    """Wrapper around a T5-Gemma encoder that turns text into hidden states.

    The returned hidden states and attention mask are meant to be consumed by
    downstream components such as the SDXL adapter.
    """

    def __init__(self,
                 model_id: str,
                 device: str = "cpu",
                 dtype: torch.dtype = torch.bfloat16,
                 max_length: int = 512):
        """
        Load the T5-Gemma encoder and its tokenizer.

        Args:
            model_id: Local path or Hub identifier of the pre-trained model.
            device: Device to run the model on (e.g. 'cuda:0' or 'cpu'). It is
                passed to ``from_pretrained`` as ``device_map`` and is also
                where tokenized inputs are moved before the forward pass.
            dtype: Data type the model weights are loaded in.
            max_length: Every prompt is padded/truncated to exactly this many
                tokens.
        """
        self.max_length = max_length
        self.device = device
        # Load the encoder-only T5-Gemma model.
        self.model = T5GemmaEncoderModel.from_pretrained(
            model_id,
            device_map=device,
            low_cpu_mem_usage=True,
            torch_dtype=dtype,
            is_encoder_decoder=False,
        )
        # The tokenizer comes from the same checkpoint as the model.
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)

    def __call__(self, texts: Union[str, List[str]]) -> Dict[str, torch.Tensor]:
        """
        Encode text prompts into encoder hidden states.

        Args:
            texts: A list of prompt strings, or a single prompt string (which
                is treated as a one-element batch).

        Returns:
            Dictionary containing:
              - ``llm_hidden_states``: ``last_hidden_state`` of the encoder,
                computed with autograd disabled (no gradient history).
              - ``attention_mask``: the tokenizer's mask marking real tokens
                versus padding.
        """
        # A bare string would otherwise be iterated character by character,
        # turning every character into its own prompt.
        if isinstance(texts, str):
            texts = [texts]

        # Terminate every prompt with the tokenizer's EOS token.
        # NOTE(review): with right-side truncation (the usual tokenizer
        # default - confirm for this checkpoint), a prompt longer than
        # max_length loses this trailing EOS.
        prompts = [text + self.tokenizer.eos_token for text in texts]

        # Pad every prompt to exactly max_length so batches share one shape.
        inputs = self.tokenizer(
            prompts,
            return_tensors="pt",
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
        ).to(self.device)

        # Pure feature extraction: disable autograd so no graph or activation
        # memory is retained. no_grad (not inference_mode) is used so the
        # returned tensors can still be fed into a trainable adapter.
        with torch.no_grad():
            outputs = self.model(**inputs)

        return {
            "llm_hidden_states": outputs.last_hidden_state,
            "attention_mask": inputs.attention_mask,
        }
def process_caption_directory(
    model_path: str,
    caption_dir: str,
    output_dir: str,
    device: Optional[str] = None,
    dtype: torch.dtype = torch.bfloat16,
    *,
    extensions: Optional[Iterable[str]] = None,
) -> None:
    """
    Encode every caption file in a directory and save its hidden states.

    Each regular file is read as UTF-8 and stripped of surrounding whitespace.
    It is then encoded on its own, and the resulting dict
    (``llm_hidden_states`` and ``attention_mask``) is written to
    ``<output_dir>/<file stem>.safetensors``. Files sharing a stem map to the
    same output file; the later one (in sorted order) overwrites the earlier.
    This is just a simple example; for proper performance use dataloaders and
    batches.

    Args:
        model_path: Path to the pre-trained T5-Gemma model.
        caption_dir: Directory containing input text files.
        output_dir: Directory to save .safetensors files (created if missing).
        device: Device to run inference on (auto-detected if None).
        dtype: Data type for computation.
        extensions: Optional collection of file suffixes (e.g. ``(".txt",)``)
            restricting which files are processed; matching is
            case-insensitive. ``None`` (default) processes every regular file.
    """
    # Auto-detect device if not specified.
    if device is None:
        device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # Normalize the suffix filter before the (expensive) model load.
    # A lone string is treated as one suffix, not as a set of characters.
    if isinstance(extensions, str):
        extensions = [extensions]
    allowed = None if extensions is None else {ext.lower() for ext in extensions}

    os.makedirs(output_dir, exist_ok=True)

    llm_processor = T5GemmaStates(model_path, device, dtype)

    # os.listdir order is arbitrary; sorting makes processing order and the
    # printed log deterministic.
    for filename in sorted(os.listdir(caption_dir)):
        filepath = os.path.join(caption_dir, filename)

        # Skip directories and other non-regular entries.
        if not os.path.isfile(filepath):
            continue
        if allowed is not None and Path(filename).suffix.lower() not in allowed:
            continue

        try:
            with open(filepath, "r", encoding="utf-8") as f:
                text_content = f.read().strip()
        except UnicodeDecodeError:
            # Caption folders often also hold images or other binary files;
            # skip them instead of aborting the whole run midway.
            print(f"Skipping {filename}: not valid UTF-8 text")
            continue

        result = llm_processor([text_content])

        # Same basename as the caption file, with a .safetensors extension.
        output_filename = Path(filename).stem + ".safetensors"
        save_file(result, os.path.join(output_dir, output_filename))
        print(f"Processed {filename} -> {output_filename}")
if __name__ == "__main__":
    # Example usage: encode every caption file in ./inputs with the local
    # T5-Gemma checkpoint and write one .safetensors file per caption into
    # ./outputs. Prefers the first CUDA GPU and falls back to the CPU.
    run_device = "cuda:0" if torch.cuda.is_available() else "cpu"
    process_caption_directory(
        "./t5gemma-2b-2b-ul2",
        "./inputs",
        "./outputs",
        run_device,
    )