import torch
import torch.nn as nn
import pickle
# --- Part 1: Re-define the Model Architecture ---
# This class definition must be EXACTLY the same as in your training script.
class ResidualLSTMModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_units, dropout_prob):
        super(ResidualLSTMModel, self).__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
            padding_idx=0
        )
        self.lstm1 = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_units,
            num_layers=1,
            batch_first=True
        )
        self.lstm2 = nn.LSTM(
            input_size=hidden_units,
            hidden_size=hidden_units,
            num_layers=1,
            batch_first=True
        )
        self.dropout = nn.Dropout(dropout_prob)
        self.fc = nn.Linear(hidden_units, vocab_size)

    def forward(self, x):
        embedded = self.embedding(x)
        out1, _ = self.lstm1(embedded)
        out2, _ = self.lstm2(out1)
        # Residual connection: add the first LSTM's output to the second's.
        residual_sum = out1 + out2
        dropped_out = self.dropout(residual_sum)
        logits = self.fc(dropped_out)
        return logits
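
# Quick shape check (a sketch with hypothetical hyperparameters): the residual
# sum requires lstm1 and lstm2 to share hidden_units, which is why lstm2's
# input_size equals lstm1's hidden_size; logits come out as
# (batch, seq_len, vocab_size).
#   m = ResidualLSTMModel(vocab_size=100, embedding_dim=32,
#                         hidden_units=64, dropout_prob=0.2)
#   dummy_ids = torch.randint(0, 100, (1, 10))
#   assert m(dummy_ids).shape == (1, 10, 100)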

# --- Part 2: Helper Functions for Processing Text ---
def text_to_sequence(text, vocab, max_length):
    """Converts a string of code into a padded tensor of token IDs."""
    tokens = text.split()
    numericalized = [vocab.get(token, vocab['<UNK>']) for token in tokens]
    # Keep the most recent tokens if the input exceeds the context window,
    # since the prediction is read from the last real position.
    if len(numericalized) > max_length:
        numericalized = numericalized[-max_length:]
    pad_id = vocab['<PAD>']
    padding_needed = max_length - len(numericalized)
    padded = numericalized + [pad_id] * padding_needed
    return torch.tensor([padded], dtype=torch.long)
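
# Example (a sketch with a toy vocabulary; the real vocab.pkl must contain the
# same '<PAD>' and '<UNK>' entries these helpers expect):
#   toy_vocab = {'<PAD>': 0, '<UNK>': 1, 'import': 2, 'numpy': 3, 'as': 4}
#   text_to_sequence('import numpy as np', toy_vocab, max_length=6)
#   # 'np' is out-of-vocabulary -> <UNK>: tensor([[2, 3, 4, 1, 0, 0]])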

def sequence_to_text(sequence, vocab):
    """Converts a sequence of token IDs back to a space-joined string."""
    id_to_token = {id_val: token for token, id_val in vocab.items()}
    tokens = [id_to_token.get(id_val.item(), '<UNK>')
              for id_val in sequence
              if id_val.item() != vocab['<PAD>']]
    return " ".join(tokens)

# --- Part 3: Main Prediction Logic ---
def predict_next_tokens(model, text, vocab, device, max_length=1000, top_k=5):
    """Predicts the top_k most likely next tokens for a given text input."""
    model.eval()
    with torch.no_grad():
        input_tensor = text_to_sequence(text, vocab, max_length).to(device)
        logits = model(input_tensor)
        # Read the logits at the last real (non-pad) input position; the min()
        # guards against an out-of-range index when the input was truncated.
        num_input_tokens = min(len(text.split()), max_length)
        last_token_logits = logits[0, num_input_tokens - 1, :]
        _, top_k_ids = torch.topk(last_token_logits, top_k, dim=-1)
        top_k_tokens = [sequence_to_text([token_id], vocab) for token_id in top_k_ids]
        return top_k_tokens
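
# Optional variant (a sketch, not part of the original script): rank the
# suggestions by probability instead of raw logits. Softmax is monotonic, so
# topk returns the same tokens, but the scores read as confidences.
#   probs = torch.softmax(last_token_logits, dim=-1)
#   top_probs, top_ids = torch.topk(probs, k=top_k)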

if __name__ == '__main__':
    # --- Configuration ---
    MODEL_PATH = 'model.pt'
    VOCAB_PATH = 'vocab.pkl'  # <-- Updated to use .pkl
    MAX_LENGTH = 1000

    # --- Load everything ---
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load vocabulary using pickle
    with open(VOCAB_PATH, 'rb') as f:  # <-- Use 'rb' for reading bytes
        vocab = pickle.load(f)
    print("Vocabulary loaded.")

    # Load the full pickled model (requires the class definition above).
    # weights_only=False matters on PyTorch >= 2.6, where it defaults to True.
    model = torch.load(MODEL_PATH, map_location=device, weights_only=False)
    print("Model loaded.")

    # --- Make a Prediction ---
    input_code = "import numpy as"  # Example input
    print(f"\nInput code: '{input_code}'")
    suggestions = predict_next_tokens(model, input_code, vocab, device, max_length=MAX_LENGTH)

    print("\nTop 5 suggestions:")
    for i, suggestion in enumerate(suggestions):
        print(f"{i + 1}. {suggestion}")