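"""Gradio demo for the Leap0 language model.

Loads a small Transformer decoder (paired with the GPT-2 tokenizer) from
local config.json / model.safetensors files and serves an interactive
text-generation UI.
"""
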
import gradio as gr
import torch
import json
from transformers import GPT2Tokenizer
from safetensors.torch import load_file
import torch.nn as nn
from torch.nn import functional as F

# Minimal model config; from_dict drops any keys the constructor does not accept
class GPTConfig:
    def __init__(self, n_embd, n_head, n_layer, vocab_size):
        self.n_embd = n_embd
        self.n_head = n_head
        self.n_layer = n_layer
        self.vocab_size = vocab_size

    @classmethod
    def from_dict(cls, config_dict):
        # Define the expected keys
        expected_keys = {'n_embd', 'n_head', 'n_layer', 'vocab_size'}
        # Filter out unexpected keys
        filtered_dict = {key: value for key, value in config_dict.items() if key in expected_keys}
        return cls(**filtered_dict)
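
# Example: from_dict tolerates a full GPT-2-style config file by discarding the
# extra keys (the values below are illustrative, not the actual Leap0 config):
#
#   cfg = GPTConfig.from_dict({"n_embd": 768, "n_head": 12, "n_layer": 12,
#                              "vocab_size": 50257, "n_positions": 1024})
#   # cfg keeps only n_embd, n_head, n_layer, vocab_size; "n_positions" is ignored.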

# Define the GPT class
class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Initialize the embedding layer
        self.embedding = nn.Embedding(config.vocab_size, config.n_embd)
        # Initialize the Transformer decoder
        decoder_layer = nn.TransformerDecoderLayer(d_model=config.n_embd, nhead=config.n_head, dim_feedforward=config.n_embd, dropout=0.1)
        self.transformer = nn.TransformerDecoder(decoder_layer, num_layers=config.n_layer)
        # Initialize the language model head
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size)

    def forward(self, input_ids):
        # Embed the input tokens
        input_embeddings = self.embedding(input_ids)
        # Transpose the input to match the expected shape for TransformerDecoder
        input_embeddings = input_embeddings.transpose(0, 1)
        # Pass through the Transformer decoder
        transformer_output = self.transformer(input_embeddings, input_embeddings)
        # Transpose back to the original shape
        transformer_output = transformer_output.transpose(0, 1)
        # Get the logits from the language model head
        logits = self.lm_head(transformer_output)
        return logits

    def generate(self, input_ids, max_new_tokens, temperature, top_k):
        """Autoregressively sample up to max_new_tokens tokens."""
        output_ids = input_ids
        for _ in range(max_new_tokens):
            # Re-run the full sequence each step (this model keeps no KV cache)
            # and take the logits at the final position only.
            logits = self.forward(output_ids)[:, -1, :]
            logits = logits / temperature
            probs = F.softmax(logits, dim=-1)  # shape: (batch, vocab_size)

            # Restrict sampling to the top_k most likely tokens, then map the
            # sampled index back into the full vocabulary.
            top_k_probs, top_k_indices = torch.topk(probs, k=top_k, dim=-1)
            next_token = torch.multinomial(top_k_probs, num_samples=1)
            next_token = top_k_indices.gather(-1, next_token)

            output_ids = torch.cat([output_ids, next_token], dim=1)
        return output_ids
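
# Note: forward() above passes the same embeddings as target and memory with no
# attention mask, so every position can attend to future tokens. If the Leap0
# checkpoint was trained with causal masking (an assumption this file does not
# confirm), a mask could be supplied inside forward() like so:
#
#   seq_len = input_embeddings.size(0)
#   mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(input_embeddings.device)
#   out = self.transformer(input_embeddings, input_embeddings, tgt_mask=mask)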

# Initialize global variables
model = None
tokenizer = None
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_model():
    """Load the Leap0 model and tokenizer."""
    global model, tokenizer
    
    try:
        # Paths to config and model files
        config_path = "config.json"
        model_path = "model.safetensors"
        
        print(f"Loading configuration from {config_path}...")
        # Load the configuration
        with open(config_path, "r") as f:
            config_dict = json.load(f)
            
        print("Configuration loaded. Creating model config...")
        config = GPTConfig.from_dict(config_dict)
        print(f"Model config created: {config}")
        
        print(f"Loading model weights from {model_path}...")
        # Load the model weights
        tensors = load_file(model_path)
        
        print("Instantiating model...")
        # Instantiate the model with the loaded config
        model = GPT(config)
        
        print("Loading weights into model...")
        # strict=False tolerates missing/unexpected keys; any weights whose
        # names do not match the checkpoint are silently left at initialization.
        model.load_state_dict(tensors, strict=False)
        model.to(device)
        model.eval()
        
        print("Loading tokenizer...")
        # Load the tokenizer
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        
        print("Model and tokenizer loaded successfully")
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        raise

def generate_text(prompt, max_length=50, temperature=0.7, top_k=40):
    """Generate text based on the provided prompt."""
    if model is None or tokenizer is None:
        load_model()
    
    # Tokenize the input text
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    
    # Generate text
    with torch.no_grad():
        output_ids = model.generate(
            input_ids, 
            max_new_tokens=max_length,
            temperature=temperature,
            top_k=top_k
        )
    
    # Decode the output
    output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    
    return output_text
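
# Example direct call (assumes config.json and model.safetensors sit next to
# this script, as load_model() expects):
#
#   text = generate_text("once upon a time in the village of",
#                        max_length=30, temperature=0.7, top_k=40)
#   print(text)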

# Create the Gradio interface
def create_interface():
    with gr.Blocks(css="footer {visibility: hidden}") as demo:
        gr.Markdown("# Leap0 Language Model")
        gr.Markdown("A GPT-2 based model trained on the Tiny Stories dataset")
        
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(
                    label="Enter your prompt",
                    placeholder="once upon a time in the village of",
                    lines=3
                )
                
                with gr.Row():
                    max_length = gr.Slider(
                        minimum=1, 
                        maximum=200, 
                        value=50, 
                        step=1, 
                        label="Max Length"
                    )
                    temperature = gr.Slider(
                        minimum=0.1, 
                        maximum=2.0, 
                        value=0.7, 
                        step=0.1, 
                        label="Temperature"
                    )
                    top_k = gr.Slider(
                        minimum=1, 
                        maximum=100, 
                        value=40, 
                        step=1, 
                        label="Top K"
                    )
                
                generate_btn = gr.Button("Generate Text")
                
            with gr.Column():
                output = gr.Textbox(
                    label="Generated Output",
                    lines=10,
                    placeholder="Your generated text will appear here..."
                )
        
        generate_btn.click(
            fn=generate_text,
            inputs=[prompt, max_length, temperature, top_k],
            outputs=output
        )
    
    return demo

# Load the model eagerly at import time so the interface is ready immediately
load_model()

# Create and launch the interface
demo = create_interface()

if __name__ == "__main__":
    demo.launch()