import streamlit as st
import torch
import tiktoken

from model import GPT, GPTConfig


@st.cache_resource
def get_model():
    """Load the trained GPT model (cached across Streamlit reruns)."""
    model = GPT(GPTConfig())
    # Load the checkpoint from the Hugging Face Hub instead of a local file
    model_path = 'mathminakshi/custom_gpt2'
    state_dict = torch.hub.load_state_dict_from_url(
        f'https://huggingface.co/{model_path}/resolve/main/model.pth',
        map_location='cpu',
    )
    model.load_state_dict(state_dict)
    model.eval()
    return model
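
# Alternative loading path (a sketch, not used above): download the checkpoint
# explicitly with huggingface_hub, which caches the file locally between runs.
# This assumes the 'mathminakshi/custom_gpt2' repo hosts model.pth at its root,
# as the URL in get_model implies.
def _load_state_dict_via_hub():
    from huggingface_hub import hf_hub_download
    ckpt_path = hf_hub_download(repo_id='mathminakshi/custom_gpt2',
                                filename='model.pth')
    return torch.load(ckpt_path, map_location='cpu')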


def generate_text(prompt, max_tokens=500, temperature=0.3, top_k=40):
    """Generate text autoregressively from the prompt."""
    # Get the cached model
    model = get_model()
    device = next(model.parameters()).device

    # Tokenize the prompt, allowing the special end-of-text token
    enc = tiktoken.get_encoding("gpt2")
    input_ids = torch.tensor(
        enc.encode(prompt, allowed_special={'<|endoftext|>'})
    ).unsqueeze(0).to(device)

    # Id of the end-of-text token (50256 in the GPT-2 vocabulary)
    end_token = enc.encode('<|endoftext|>', allowed_special={'<|endoftext|>'})[0]

    with torch.no_grad():
        output_sequence = []
        progress_bar = st.progress(0)
        for i in range(max_tokens):
            progress_bar.progress(i / max_tokens)

            # Crop the context if it exceeds the model's window (this assumes a
            # nanoGPT-style config exposing a block_size attribute)
            if input_ids.size(1) > model.config.block_size:
                input_ids = input_ids[:, -model.config.block_size:]

            # Forward pass; the model returns (logits, loss)
            logits, _ = model(input_ids)
            # Keep only the logits for the last position and apply temperature
            logits = logits[:, -1, :] / temperature

            # Apply top-k filtering: mask everything below the k-th largest logit
            if top_k > 0:
                indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
                logits[indices_to_remove] = float('-inf')

            # Sample the next token from the filtered distribution
            probs = torch.nn.functional.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            # Append to the output and extend the context
            output_sequence.append(next_token.item())
            input_ids = torch.cat([input_ids, next_token], dim=1)

            # Stop once the end-of-text token is generated
            if next_token.item() == end_token:
                break

        progress_bar.progress(1.0)

    generated_text = enc.decode(output_sequence)
    return prompt + generated_text
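
# Illustration of the top-k masking trick used above (a toy sketch, not called
# by the app): comparing each logit against the k-th largest value per row sends
# everything outside the top k to -inf, so it gets zero mass after the softmax.
def _top_k_demo():
    logits = torch.tensor([[1.0, 3.0, 0.5, 2.0]])
    cutoff = torch.topk(logits, 2)[0][..., -1, None]       # k-th largest: 2.0
    masked = logits.masked_fill(logits < cutoff, float('-inf'))
    return torch.nn.functional.softmax(masked, dim=-1)     # mass only on 3.0 and 2.0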


def main():
    st.title("GPT Text Generator")
    st.write("Enter a prompt to generate text using GPT-2.")

    # Sidebar for generation parameters
    st.sidebar.header("Generation Parameters")
    max_tokens = st.sidebar.slider(
        "Max Tokens",
        min_value=1,
        max_value=1000,
        value=100,
        help="Maximum number of tokens to generate"
    )
    temperature = st.sidebar.slider(
        "Temperature",
        min_value=0.1,
        max_value=2.0,
        value=0.8,
        help="Higher values make the output more random"
    )
    top_k = st.sidebar.slider(
        "Top-K",
        min_value=1,
        max_value=100,
        value=40,
        help="Limits the number of tokens to choose from"
    )

    prompt = st.text_area(
        "Enter your prompt:",
        height=100,
        placeholder="Once upon a time..."
    )

    if st.button("Generate"):
        if prompt:
            with st.spinner("Generating text..."):
                generated_text = generate_text(
                    prompt=prompt,
                    max_tokens=max_tokens,
                    temperature=temperature,
                    top_k=top_k
                )
            st.write("### Generated Text:")
            st.write(generated_text)
        else:
            st.warning("Please enter a prompt first!")


if __name__ == "__main__":
    main()
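
# To run this app locally (assuming the file is saved as app.py and that
# streamlit, torch, and tiktoken are installed):
#
#   streamlit run app.py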