#!/usr/bin/env python3
"""
Decoding script for morphological reinflection using TagTransformer.
"""
import argparse
import json
import logging
from typing import Dict, List

import torch

from transformer import TagTransformer, PAD_IDX, DEVICE
from morphological_dataset import build_vocabulary, tokenize_sequence
# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)
def load_model(checkpoint_path: str, config: Dict, src_vocab: Dict[str, int],
               tgt_vocab: Dict[str, int]) -> TagTransformer:
    """Load a trained model from a checkpoint."""
    # Count feature tokens. Note: this pattern also matches special tokens
    # such as <BOS>/<EOS> if the vocabulary stores them in the same style,
    # so nb_attr must be counted the same way it was at training time.
    feature_tokens = [token for token in src_vocab.keys()
                      if token.startswith('<') and token.endswith('>')]
    nb_attr = len(feature_tokens)

    # Create the model with the same architecture as at training time
    model = TagTransformer(
        src_vocab_size=len(src_vocab),
        trg_vocab_size=len(tgt_vocab),
        embed_dim=config['embed_dim'],
        nb_heads=config['nb_heads'],
        src_hid_size=config['src_hid_size'],
        src_nb_layers=config['src_nb_layers'],
        trg_hid_size=config['trg_hid_size'],
        trg_nb_layers=config['trg_nb_layers'],
        dropout_p=0.0,  # no dropout during inference
        tie_trg_embed=config['tie_trg_embed'],
        label_smooth=0.0,  # no label smoothing during inference
        nb_attr=nb_attr,
        src_c2i=src_vocab,
        trg_c2i=tgt_vocab,
        attr_c2i={},
    )

    # Load checkpoint weights
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(DEVICE)
    model.eval()
    logger.info(f"Model loaded from {checkpoint_path}")
    return model
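# For reference, load_model only requires the 'model_state_dict' key, so a
# compatible checkpoint could be produced like this (a sketch; the actual
# training script may save additional fields such as optimizer state):
#   torch.save({'model_state_dict': model.state_dict()}, checkpoint_path)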
def beam_search(model: TagTransformer, src_tokens: List[str], src_vocab: Dict[str, int],
                tgt_vocab: Dict[str, int], beam_width: int = 5, max_length: int = 100) -> List[str]:
    """Perform beam search decoding for a single source sequence."""
    # Tokenize source
    src_indices, _ = tokenize_sequence(src_tokens, src_vocab, max_length, add_bos_eos=True)
    src_tensor = torch.tensor([src_indices], dtype=torch.long).to(DEVICE).t()  # [seq_len, batch_size]

    # Create source mask (a single unpadded sequence, so nothing is masked)
    src_mask = torch.zeros(src_tensor.size(0), src_tensor.size(1), dtype=torch.bool).to(DEVICE)

    # Encode the source once; the encoding is reused for every hypothesis
    with torch.no_grad():
        encoded = model.encode(src_tensor, src_mask)

    # Initialize beam with (sequence, cumulative log-probability)
    beam = [([tgt_vocab['<BOS>']], 0.0)]

    for _ in range(max_length):
        candidates = []
        for sequence, score in beam:
            # Finished hypotheses are carried over unchanged
            if sequence[-1] == tgt_vocab['<EOS>']:
                candidates.append((sequence, score))
                continue

            # Prepare target input
            tgt_tensor = torch.tensor([sequence], dtype=torch.long).to(DEVICE).t()
            tgt_mask = torch.zeros(tgt_tensor.size(0), tgt_tensor.size(1), dtype=torch.bool).to(DEVICE)

            # Decode
            with torch.no_grad():
                output = model.decode(encoded, src_mask, tgt_tensor, tgt_mask)

            # Distribution over the next token at the last position
            # (assumed to be log-probabilities, so scores can be summed)
            next_token_log_probs = output[-1, 0]  # [vocab_size]

            # Expand each hypothesis with its top-k continuations
            top_k_log_probs, top_k_indices = torch.topk(next_token_log_probs, beam_width)
            for log_prob, idx in zip(top_k_log_probs, top_k_indices):
                candidates.append((sequence + [idx.item()], score + log_prob.item()))

        # Keep the beam_width highest-scoring candidates
        candidates.sort(key=lambda x: x[1], reverse=True)
        beam = candidates[:beam_width]

        # Stop early once every hypothesis has produced EOS
        if all(seq[-1] == tgt_vocab['<EOS>'] for seq, _ in beam):
            break

    # Return the best sequence, stripping BOS and (only if decoding actually
    # terminated) the trailing EOS
    best_sequence, _ = beam[0]
    idx_to_token = {idx: token for token, idx in tgt_vocab.items()}
    result = best_sequence[1:]
    if result and result[-1] == tgt_vocab['<EOS>']:
        result = result[:-1]
    return [idx_to_token[idx] for idx in result]
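# For quick sanity checks, a greedy decode (equivalent to beam_width == 1) avoids
# the bookkeeping above. A minimal sketch, assuming the same model API as in
# beam_search (encode/decode signatures and log-probability outputs):
def greedy_decode(model: TagTransformer, src_tokens: List[str], src_vocab: Dict[str, int],
                  tgt_vocab: Dict[str, int], max_length: int = 100) -> List[str]:
    """Greedy decoding: pick the single most likely token at each step."""
    src_indices, _ = tokenize_sequence(src_tokens, src_vocab, max_length, add_bos_eos=True)
    src_tensor = torch.tensor([src_indices], dtype=torch.long).to(DEVICE).t()
    src_mask = torch.zeros(src_tensor.size(0), src_tensor.size(1), dtype=torch.bool).to(DEVICE)
    with torch.no_grad():
        encoded = model.encode(src_tensor, src_mask)
    sequence = [tgt_vocab['<BOS>']]
    for _ in range(max_length):
        tgt_tensor = torch.tensor([sequence], dtype=torch.long).to(DEVICE).t()
        tgt_mask = torch.zeros(tgt_tensor.size(0), tgt_tensor.size(1), dtype=torch.bool).to(DEVICE)
        with torch.no_grad():
            output = model.decode(encoded, src_mask, tgt_tensor, tgt_mask)
        next_idx = output[-1, 0].argmax().item()  # most likely next token
        sequence.append(next_idx)
        if next_idx == tgt_vocab['<EOS>']:
            break
    if sequence[-1] == tgt_vocab['<EOS>']:
        sequence = sequence[:-1]
    idx_to_token = {idx: token for token, idx in tgt_vocab.items()}
    return [idx_to_token[idx] for idx in sequence[1:]]  # drop BOS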
def main():
    parser = argparse.ArgumentParser(description='Decode using a trained TagTransformer')
    parser.add_argument('--checkpoint', type=str, required=True, help='Path to model checkpoint')
    parser.add_argument('--config', type=str, default='./models/config.json', help='Path to config file')
    parser.add_argument('--src_file', type=str, help='Source file for decoding')
    parser.add_argument('--output_file', type=str, help='Output file for predictions')
    parser.add_argument('--beam_width', type=int, default=5, help='Beam width for decoding')
    args = parser.parse_args()

    # Load configuration
    with open(args.config, 'r') as f:
        config = json.load(f)
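    # The config file must supply at least the hyperparameters read in
    # load_model. An illustrative example (the values here are assumptions,
    # not the settings of any particular checkpoint):
    # {
    #   "embed_dim": 256, "nb_heads": 4,
    #   "src_hid_size": 1024, "src_nb_layers": 4,
    #   "trg_hid_size": 1024, "trg_nb_layers": 4,
    #   "tie_trg_embed": true
    # }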
    # Data file paths for vocabulary building (these must be the same files
    # used in training, so that token indices line up with the checkpoint)
    train_src = '10L_90NL/train/run1/train.10L_90NL_1_1.src'
    train_tgt = '10L_90NL/train/run1/train.10L_90NL_1_1.tgt'
    dev_src = '10L_90NL/dev/run1/train.10L_90NL_1_1.src'
    dev_tgt = '10L_90NL/dev/run1/train.10L_90NL_1_1.tgt'
    test_src = '10L_90NL/test/run1/train.10L_90NL_1_1.src'
    test_tgt = '10L_90NL/test/run1/train.10L_90NL_1_1.tgt'

    # Build vocabularies
    logger.info("Building vocabulary...")
    src_vocab = build_vocabulary([train_src, dev_src, test_src])
    tgt_vocab = build_vocabulary([train_tgt, dev_tgt, test_tgt])

    # Load model
    model = load_model(args.checkpoint, config, src_vocab, tgt_vocab)
    # Decode
    if args.src_file and args.output_file:
        # Batch decoding from file
        logger.info(f"Decoding from {args.src_file} to {args.output_file}")
        with open(args.src_file, 'r', encoding='utf-8') as src_f, \
                open(args.output_file, 'w', encoding='utf-8') as out_f:
            for line_num, line in enumerate(src_f, 1):
                src_tokens = line.strip().split()
                if not src_tokens:
                    # Keep predictions line-aligned with the source file
                    out_f.write('\n')
                    continue
                try:
                    result = beam_search(model, src_tokens, src_vocab, tgt_vocab, args.beam_width)
                    out_f.write(' '.join(result) + '\n')
                    if line_num % 100 == 0:
                        logger.info(f"Processed {line_num} lines")
                except Exception as e:
                    logger.error(f"Error processing line {line_num}: {e}")
                    out_f.write('<ERROR>\n')
        logger.info("Decoding completed!")
    else:
        # Interactive decoding
        logger.info("Interactive decoding mode. Type 'quit' to exit.")
        while True:
            try:
                user_input = input("Enter source sequence: ").strip()
                if user_input.lower() == 'quit':
                    break
                if not user_input:
                    continue
                src_tokens = user_input.split()
                result = beam_search(model, src_tokens, src_vocab, tgt_vocab, args.beam_width)
                print(f"Prediction: {' '.join(result)}")
                print()
            except (KeyboardInterrupt, EOFError):
                break
            except Exception as e:
                logger.error(f"Error: {e}")


if __name__ == '__main__':
    main()