"""
Module providing a wrapper class for the T5-Gemma model to extract hidden states.

This module defines a simple interface for using the T5-Gemma encoder model
to generate hidden states from text inputs, which can then be used by other
components such as the SDXL adapter.
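
Example (model path is illustrative; see the __main__ block below):
    extractor = T5GemmaStates("./t5gemma-2b-2b-ul2", device="cuda:0")
    states = extractor(["a photo of a cat"])
    hidden = states["llm_hidden_states"]  # (batch, max_length, hidden_size)
    mask = states["attention_mask"]       # (batch, max_length)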
"""

import os
import torch
from pathlib import Path
from typing import List, Dict, Optional
from transformers import T5GemmaEncoderModel, AutoTokenizer
from safetensors.torch import save_file


class T5GemmaStates:
    """A wrapper class for the T5-Gemma model to extract hidden states from text inputs."""

    def __init__(self,
                 model_id: str,
                 device: str = "cpu",
                 dtype: torch.dtype = torch.bfloat16,
                 max_length: int = 512):
        """
        Initialize the T5-Gemma state extractor.

        Args:
            model_id: Path or identifier for the pre-trained model
            device: Device to run the model on (e.g., 'cuda' or 'cpu')
            dtype: Data type for model computations
            max_length: Maximum sequence length for tokenization
        """
        self.max_length = max_length
        self.device = device

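        # Load the encoder-only T5-Gemma model; the decoder is not needed for extracting hidden states.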
        self.model = T5GemmaEncoderModel.from_pretrained(
            model_id,
            device_map=device,
            low_cpu_mem_usage=True,
            torch_dtype=dtype,
            is_encoder_decoder=False,
        )

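        # Tokenizer from the same checkpoint so the vocabulary and special tokens match the encoder.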
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)

    def __call__(self, texts: List[str]) -> Dict[str, torch.Tensor]:
        """
        Generate hidden states from input text sequences.

        Args:
            texts: List of input text strings

        Returns:
            Dictionary containing:
            - llm_hidden_states: Last hidden states from the model
            - attention_mask: Attention mask indicating valid tokens
        """
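        # Append the end-of-sequence token so every prompt terminates with an explicit EOS marker.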
        prompts = [text + self.tokenizer.eos_token for text in texts]

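        # Tokenize with fixed-length padding so the returned hidden states always have max_length tokens.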
        inputs = self.tokenizer(
            prompts,
            return_tensors="pt",
            padding="max_length",
            max_length=self.max_length,
            truncation=True
        ).to(self.device)

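        # Inference only: run the forward pass without building an autograd graph.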
        with torch.no_grad():
            outputs = self.model(**inputs)

        result = {
            "llm_hidden_states": outputs.last_hidden_state,
            "attention_mask": inputs.attention_mask
        }
        return result


def process_caption_directory(
    model_path: str,
    caption_dir: str,
    output_dir: str,
    device: Optional[str] = None,
    dtype: torch.dtype = torch.bfloat16
) -> None:
    """
    Process all text files in a directory and save their hidden states.
    This is just a simple example; for proper performance use dataloaders and batches.

    Args:
        model_path: Path to the pre-trained T5-Gemma model
        caption_dir: Directory containing input text files
        output_dir: Directory to save .safetensors files
        device: Device to run inference on (auto-detected if None)
        dtype: Data type for computation
    """
    if device is None:
        device = "cuda:0" if torch.cuda.is_available() else "cpu"

    os.makedirs(output_dir, exist_ok=True)

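    # Instantiate the wrapper once and reuse it for every caption file.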
    llm_processor = T5GemmaStates(model_path, device, dtype)

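    # Walk the caption directory; each text file becomes one .safetensors output.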
    for filename in os.listdir(caption_dir):
        filepath = os.path.join(caption_dir, filename)

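        # Skip subdirectories and anything else that is not a regular file.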
        if not os.path.isfile(filepath):
            continue

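        # Read the whole caption file as a single prompt.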
        with open(filepath, 'r', encoding='utf-8') as f:
            text_content = f.read().strip()

        texts = [text_content]
        result = llm_processor(texts)

        output_filename = Path(filename).stem + ".safetensors"
        # Detach and move the tensors to CPU so they serialize cleanly regardless of the inference device.
        tensors = {k: v.detach().contiguous().cpu() for k, v in result.items()}
        save_file(tensors, os.path.join(output_dir, output_filename))
        print(f"Processed {filename} -> {output_filename}")


if __name__ == "__main__":
    # Example usage of the T5GemmaStates class: process a directory of text files
    # and extract their hidden states.
    model_id = "./t5gemma-2b-2b-ul2"
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    caption_dir = "./inputs"
    output_path = "./outputs"

    process_caption_directory(model_id, caption_dir, output_path, str(device))