import os
import gradio as gr
import torch
from train_get2_8_init import GPT, GPTConfig, generate_text, TrainingConfig
from huggingface_hub import hf_hub_download
from torch.serialization import add_safe_globals

# Allowlist GPTConfig so torch.load can unpickle the checkpoint config
# (required by newer PyTorch versions, which default to weights_only=True)
add_safe_globals([GPTConfig])

def load_trained_model():
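    """Rebuild the GPT architecture from TrainingConfig and load the trained weights from the Hub."""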
    config = TrainingConfig()
    model_config = GPTConfig(
        block_size=config.block_size,
        n_layer=config.n_layer,
        n_head=config.n_head,
        n_embd=config.n_embd,
        dropout=config.dropout
    )
    
    model = GPT(model_config)
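    # Download the compressed checkpoint from the Hub; HF_TOKEN is only needed if the repo is private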
    model_path = hf_hub_download(
        repo_id="padmanabhbosamia/Short_Shakesphere",
        filename="best_model_compressed.pt",
        token=os.getenv('HF_TOKEN')
    )
    checkpoint = torch.load(model_path, map_location=config.device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(config.device)
    model.eval()
    return model

def create_gradio_interface():
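    """Build the Gradio interface that wraps the loaded model for text generation."""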
    model = load_trained_model()
    
    def predict(prompt, max_length, temperature=0.7):
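        """Generate text from `prompt`, producing up to `max_length` tokens sampled at `temperature`."""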
        return generate_text(model, prompt, max_length, temperature)
    
    interface = gr.Interface(
        fn=predict,
        inputs=[
            gr.Textbox(
                lines=3,
                label="Enter your prompt",
                placeholder="Start typing here..."
            ),
            gr.Slider(
                minimum=10,
                maximum=500,
                value=100,
                step=10,
                label="Maximum Length"
            ),
            gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.7,
                step=0.1,
                label="Temperature (Higher = more creative)"
            )
        ],
        outputs=gr.Textbox(lines=5, label="Generated Text"),
        title="Custom GPT Text Generator (124M) based on Shakespeare",
        description="A GPT-style language model trained on custom data by Shakespeare with 124M parameters"
    )
    return interface

# For Hugging Face Spaces
if __name__ == "__main__":
    interface = create_gradio_interface()
    interface.launch()