"""
Module providing a wrapper class for the T5-Gemma model to extract hidden states.

This module defines a simple interface for using the T5-Gemma encoder model
to generate hidden states from text inputs, which can then be used by other
components like the SDXL adapter.
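
Example (illustrative sketch; the checkpoint path matches the __main__ block below):

    states = T5GemmaStates("./t5gemma-2b-2b-ul2", device="cpu")
    out = states(["a photo of a cat"])
    # out["llm_hidden_states"]: (batch, max_length, hidden_size)
    # out["attention_mask"]:    (batch, max_length)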
"""

import os
import torch
from pathlib import Path
from typing import Dict, List, Optional
from transformers import T5GemmaEncoderModel, AutoTokenizer
from safetensors.torch import save_file

class T5GemmaStates:
    """A wrapper class for T5-Gemma model to extract hidden states from text inputs."""

    def __init__(self,
                 model_id: str,
                 device: str = "cpu",
                 dtype: torch.dtype = torch.bfloat16,
                 max_length: int = 512):
        """
        Initialize the T5-Gemma state extractor.

        Args:
            model_id: Path or identifier for the pre-trained model
            device: Device to run the model on (e.g., 'cuda' or 'cpu')
            dtype: Data type for model computations
            max_length: Maximum sequence length for tokenization
        """
        self.max_length = max_length
        self.device = device
        
        # Load the encoder-only T5-Gemma model
        self.model = T5GemmaEncoderModel.from_pretrained(
            model_id,
            device_map=device,
            low_cpu_mem_usage=True,
            torch_dtype=dtype,
            is_encoder_decoder=False,
        )
        
        # Load the tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)

    def __call__(self, texts: List[str]) -> Dict[str, torch.Tensor]:
        """
        Generate hidden states from input text sequences.

        Args:
            texts: List of input text strings

        Returns:
            Dictionary containing:
                - llm_hidden_states: Last hidden states from the model
                - attention_mask: Attention mask indicating valid tokens
        """
        # Append EOS token to each prompt
        prompts = [text + self.tokenizer.eos_token for text in texts]
        
        # Tokenize with padding and truncation
        inputs = self.tokenizer(
            prompts,
            return_tensors="pt",
            padding="max_length",
            max_length=self.max_length,
            truncation=True
        ).to(self.device)
        
        # Forward pass through the model; gradients are not needed for feature extraction
        with torch.no_grad():
            outputs = self.model(**inputs)

        result = {
            "llm_hidden_states": outputs.last_hidden_state,
            "attention_mask": inputs.attention_mask
        }
        return result


def process_caption_directory(
    model_path: str,
    caption_dir: str,
    output_dir: str,
    device: Optional[str] = None,
    dtype: torch.dtype = torch.bfloat16,
) -> None:
    """
    Process all text files in a directory and save their hidden states.
    This is just a simple example; for proper performance, use dataloaders and
    batches (an illustrative batched helper is sketched after this function).

    Args:
        model_path: Path to the pre-trained T5-Gemma model
        caption_dir: Directory containing input text files
        output_dir: Directory to save .safetensors files
        device: Device to run inference on (auto-detected if None)
        dtype: Data type for computation
    """
    # Auto-detect device if not specified
    if device is None:
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
    
    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)
    
    # Initialize the processor
    llm_processor = T5GemmaStates(model_path, device, dtype)
    
    # Process each file in the caption directory
    for filename in os.listdir(caption_dir):
        filepath = os.path.join(caption_dir, filename)
        
        # Skip if not a file
        if not os.path.isfile(filepath):
            continue
            
        # Read text content using context manager
        with open(filepath, 'r', encoding='utf-8') as f:
            text_content = f.read().strip()
            
        texts = [text_content]
        result = llm_processor(texts)
        
        # Move tensors to CPU so the saved file is device-agnostic, then save with
        # the same basename but a .safetensors extension
        result = {k: v.cpu() for k, v in result.items()}
        output_filename = Path(filename).stem + ".safetensors"
        save_file(result, os.path.join(output_dir, output_filename))
        print(f"Processed {filename} -> {output_filename}")

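# Illustrative sketch, not part of the original interface: the docstring above
# suggests batching for performance, so this shows one minimal way to run the
# extractor over a list of captions in fixed-size batches. batch_size and the
# chunking logic are assumptions for demonstration.
def process_captions_batched(llm_processor: T5GemmaStates,
                             captions: List[str],
                             batch_size: int = 8) -> List[Dict[str, torch.Tensor]]:
    """Run the hidden-state extractor over captions in fixed-size batches."""
    results = []
    for start in range(0, len(captions), batch_size):
        batch = captions[start:start + batch_size]
        out = llm_processor(batch)
        # Move tensors to CPU so results do not accumulate on the GPU
        results.append({k: v.cpu() for k, v in out.items()})
    return results
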

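# Illustrative helper, not part of the original interface: shows how a downstream
# component such as the SDXL adapter could read the saved states back. Uses
# safetensors.torch.load_file; the keys match what T5GemmaStates.__call__ returns.
def load_hidden_states(path: str, device: str = "cpu") -> Dict[str, torch.Tensor]:
    """Load previously saved hidden states and attention mask from a .safetensors file."""
    from safetensors.torch import load_file  # local import keeps this sketch self-contained
    return load_file(path, device=device)
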
if __name__ == "__main__":
    """
    Example usage of the T5GemmaStates class.
    This demonstrates how to process a directory of text files and extract their hidden states.
    """
    model_id = "./t5gemma-2b-2b-ul2"
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    caption_dir = "./inputs"
    output_path = "./outputs"

    process_caption_directory(model_id, caption_dir, output_path, str(device))