# Hugging Face Spaces page header captured with the file (Space status: Sleeping).
"""
Gradio web app for Shakespeare-style text generation using the trained GPT model.

This app provides an interactive interface for users to generate Shakespeare-style
text with customizable sampling parameters.
"""
| import os | |
| import torch | |
| import gradio as gr | |
| from model import GPT, GPTConfig | |
| import tiktoken | |
# Make CPU the default device for tensors that are created without an explicit
# device argument. The generator class still moves its model and inputs to CUDA
# when torch.cuda.is_available() reports a GPU.
torch.set_default_device('cpu')
class ShakespeareTextGenerator:
    """Load a trained GPT checkpoint and sample text continuations from it.

    Attributes set by ``__init__``:
        device: ``'cuda'`` when ``torch.cuda.is_available()``, else ``'cpu'``.
        config: ``GPTConfig`` rebuilt from ``checkpoint['config']``.
        model: ``GPT`` instance in eval mode, moved to ``device``.
        tokenizer: the tiktoken ``gpt2`` encoding.
        end_token: token id of ``<|endoftext|>``; sampling stops once it is drawn.
    """

    def __init__(self, model_path='compressed_model_cpu_compatible.pt'):
        """Initialize the text generator with the trained model.

        Args:
            model_path: Path to a checkpoint dict holding the keys ``'config'``,
                ``'dtype'`` and ``'model_state_dict'``.
        """
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        checkpoint = torch.load(model_path, map_location=self.device)

        # Rebuild the architecture from the configuration stored alongside the weights.
        self.config = GPTConfig(**checkpoint['config'])
        self.model = GPT(self.config)

        # Match the parameter dtype to the checkpoint before loading. Half
        # precision is only selected on CUDA; on CPU the model keeps its
        # default dtype.
        if checkpoint['dtype'] == 'float16' and self.device == 'cuda':
            self.model.half()
        elif checkpoint['dtype'] == 'float32':
            self.model.float()

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.to(self.device)
        self.model.eval()

        # The end-of-text id is looked up once so generate() can compare against it.
        self.tokenizer = tiktoken.get_encoding('gpt2')
        self.end_token = self.tokenizer.encode(
            '<|endoftext|>', allowed_special={'<|endoftext|>'}
        )[0]

    @staticmethod
    def _filter_logits(logits, top_k, top_p):
        """Mask logits outside the top-k set and the top-p nucleus with ``-inf``."""
        if top_k > 0:
            # torch.topk raises when k exceeds the last dimension, so clamp it.
            k = min(top_k, logits.size(-1))
            kth_value = torch.topk(logits, k).values[:, -1].unsqueeze(-1)
            logits = logits.masked_fill(logits < kth_value, float('-inf'))

        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
            remove_sorted = cumulative_probs > top_p
            # Shift right so the first token that crosses the threshold is kept,
            # and never drop the single most likely token.
            remove_sorted[..., 1:] = remove_sorted[..., :-1].clone()
            remove_sorted[..., 0] = False
            # Map the mask from sorted order back to vocabulary order.
            remove = remove_sorted.scatter(1, sorted_indices, remove_sorted)
            logits = logits.masked_fill(remove, float('-inf'))

        return logits

    def generate(self,
                 prompt,
                 max_length=100,
                 temperature=0.7,
                 top_k=50,
                 top_p=0.9,
                 num_return_sequences=1):
        """Generate Shakespeare-style text based on the prompt.

        Args:
            prompt: Text the continuation starts from; must encode to at least
                one token.
            max_length: Maximum number of new tokens sampled per sequence.
            temperature: Divisor applied to the logits; must be > 0.
            top_k: Keep only the k highest logits (0 disables the filter).
            top_p: Nucleus threshold (values >= 1.0 disable the filter).
            num_return_sequences: Number of independent samples to draw.

        Returns:
            A list of decoded strings, each containing the prompt followed by
            the sampled continuation.

        Raises:
            ValueError: If the prompt encodes to no tokens or ``temperature``
                is not positive.
        """
        # UI sliders may deliver whole numbers as floats; range() and
        # torch.topk() need real ints.
        max_length = int(max_length)
        top_k = int(top_k)
        num_return_sequences = int(num_return_sequences)
        if temperature <= 0:
            raise ValueError("temperature must be greater than 0")

        # disallowed_special=() makes tiktoken encode special-token text typed
        # by the user as ordinary text instead of raising.
        token_ids = self.tokenizer.encode(
            prompt, allowed_special=set(), disallowed_special=()
        )
        if not token_ids:
            raise ValueError("Prompt must contain at least one token")
        input_ids = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(self.device)

        # NOTE(review): assumes GPTConfig exposes the context window as
        # `block_size` (nanoGPT convention) — confirm against model.py. When
        # the attribute is missing, no cropping is applied.
        context_len = getattr(getattr(self, 'config', None), 'block_size', None)

        generated_sequences = []
        with torch.no_grad():
            for _ in range(num_return_sequences):
                cur_ids = input_ids.clone()
                for _ in range(max_length):
                    # Feed only the most recent tokens that fit in the context
                    # window; the full sequence is still kept for decoding.
                    model_input = cur_ids[:, -context_len:] if context_len else cur_ids
                    outputs, _ = self.model(model_input)
                    next_token_logits = outputs[:, -1, :] / temperature
                    next_token_logits = self._filter_logits(next_token_logits, top_k, top_p)

                    probs = torch.softmax(next_token_logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                    cur_ids = torch.cat([cur_ids, next_token], dim=1)

                    # Stop once the end-of-text token has been sampled.
                    if next_token.item() == self.end_token:
                        break

                generated_sequences.append(self.tokenizer.decode(cur_ids[0].tolist()))
        return generated_sequences
# Build the shared generator once at import time. The constructor loads the
# checkpoint from its default model_path, so importing this module reads that file.
generator = ShakespeareTextGenerator()
def generate_text(prompt, max_length, temperature, top_k, top_p, num_sequences):
    """Gradio interface function.

    Forwards the UI values to the module-level ``generator`` and joins the
    resulting sequences into one string for the output textbox.

    Args:
        prompt: Text to continue.
        max_length: Maximum number of new tokens per sequence.
        temperature: Sampling temperature.
        top_k: Top-k cutoff (0 disables it).
        top_p: Nucleus sampling threshold.
        num_sequences: Number of variations to generate.

    Returns:
        The generated sequences separated by a ``---`` divider, or a string
        starting with ``"Error: "`` when generation fails.
    """
    try:
        # Sliders can deliver whole numbers as floats; the count-like
        # parameters must be real ints for range() and torch.topk().
        sequences = generator.generate(
            prompt=prompt,
            max_length=int(max_length),
            temperature=float(temperature),
            top_k=int(top_k),
            top_p=float(top_p),
            num_return_sequences=int(num_sequences),
        )
        return "\n\n---\n\n".join(sequences)
    except Exception as e:
        # UI boundary: show the failure in the output box instead of crashing the app.
        return f"Error: {str(e)}"
# Input widgets, listed in the same order as generate_text's parameters.
_input_components = [
    gr.Textbox(
        lines=3,
        label="Prompt",
        placeholder="Enter your prompt here...",
        value="To be, or not to be,",
    ),
    gr.Slider(minimum=10, maximum=500, value=100, step=10, label="Maximum Length"),
    gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature (randomness)"),
    gr.Slider(minimum=0, maximum=100, value=50, step=5, label="Top-k"),
    gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)"),
    gr.Slider(minimum=1, maximum=5, value=1, step=1, label="Number of Sequences"),
]

# Single textbox that receives the joined sequences (or an error message).
_output_component = gr.Textbox(lines=10, label="Generated Text")

_description = """Generate Shakespeare-style text using a fine-tuned GPT model. Training repository: [https://github.com/dhairyag/ShakespeareGPT-Forge](https://github.com/dhairyag/ShakespeareGPT-Forge)
Adjust the parameters to control the generation:
- Temperature: Higher values make the output more random
- Top-k: Limits the vocabulary to the k most likely tokens
- Top-p: Limits the cumulative probability of tokens considered
- Number of Sequences: Generate multiple variations"""

# Each example row supplies one value per input widget, in widget order.
_example_rows = [
    ["To be, or not to be,", 100, 0.7, 50, 0.9, 1],
    ["O Romeo, Romeo,", 150, 0.8, 40, 0.85, 2],
    ["All the world's a stage,", 200, 0.6, 60, 0.95, 1],
]

# Assemble the Gradio interface from the pieces above.
iface = gr.Interface(
    fn=generate_text,
    inputs=_input_components,
    outputs=_output_component,
    title="Shakespeare-Style Text Generator",
    description=_description,
    examples=_example_rows,
)

# Start the web server only when the file is executed as a script.
if __name__ == "__main__":
    iface.launch()