| | """ |
| | Evaluation on real-world damaged characters from Jiucheng Palace inscription. |
| | Implements real-world scenario testing from the paper. |
| | """ |
| |
|
| | import torch |
| | from torch.utils.data import Dataset, DataLoader |
| | from transformers import BertTokenizer |
| | from PIL import Image |
| | import numpy as np |
| | import os |
| |
|
| | from config import Config |
| | from models.mmrm import MMRM |
| | from evaluation.metrics import RestorationMetrics |
| |
|
| |
|
class RealWorldDataset(Dataset):
    """
    Dataset of real-world damaged characters from the Jiucheng Palace inscription.

    Loads damaged-character images from ``<real_data_dir>/pic/`` and masked
    contexts from ``<real_data_dir>/restore.txt``. Ground-truth characters come
    from ``<real_data_dir>/true.txt`` — one character per line, aligned with
    the ``[MASK]`` tokens taken in file order across all contexts.
    """

    def __init__(self, config: "Config", tokenizer: "BertTokenizer"):
        """
        Initialize real-world dataset.

        Args:
            config: Configuration object; this class reads ``real_data_dir``,
                ``max_seq_length`` and ``image_size`` from it.
            tokenizer: Tokenizer for text encoding (used in ``__getitem__``).
        """
        self.config = config
        self.tokenizer = tokenizer

        # Ground-truth characters, one per line, consumed in order by the
        # [MASK] tokens of the contexts below.
        true_path = os.path.join(config.real_data_dir, 'true.txt')
        with open(true_path, 'r', encoding='utf-8') as f:
            self.labels = [line.strip() for line in f]

        # Contexts containing zero or more [MASK] placeholders.
        restore_path = os.path.join(config.real_data_dir, 'restore.txt')
        with open(restore_path, 'r', encoding='utf-8') as f:
            self.contexts = [line.strip() for line in f]

        # Damaged-character images are stored as pic/o<k>.png, 1-based k.
        self.image_dir = os.path.join(config.real_data_dir, 'pic')

        # Pair each masked context with its ground-truth labels and the
        # 1-based indices of its damaged-character images.
        self.samples = []
        label_idx = 0  # 0-based cursor into self.labels

        for context in self.contexts:
            num_masks = context.count('[MASK]')
            if num_masks == 0:
                continue

            # Consume up to num_masks labels. If true.txt is shorter than the
            # total mask count, take only what is available.
            start = label_idx
            context_labels = []
            for _ in range(num_masks):
                if label_idx < len(self.labels):
                    context_labels.append(self.labels[label_idx])
                    label_idx += 1

            self.samples.append({
                'context': context,
                'labels': context_labels,
                # One 1-based image index per consumed label. BUG FIX: the
                # range was previously sized by num_masks, which misaligned
                # images and labels whenever true.txt ran out of labels.
                'image_indices': list(range(start + 1, start + 1 + len(context_labels)))
            })

        print(f"Loaded {len(self.samples)} real-world samples")

    def __len__(self):
        """Return the number of contexts containing at least one [MASK]."""
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Get a real-world sample.

        Args:
            idx: Sample index in ``self.samples``.

        Returns:
            Dictionary with tokenized context, mask positions, damaged images
            (stacked as (num_images, 1, H, W)), and label token ids.
        """
        sample = self.samples[idx]

        encoding = self.tokenizer(
            sample['context'],
            max_length=self.config.max_seq_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Locate the [MASK] positions in the tokenized sequence.
        # NOTE(review): if truncation cuts off trailing [MASK] tokens, there
        # may be fewer mask positions than images/labels — verify max_seq_length
        # is large enough for the real-world contexts.
        mask_token_id = self.tokenizer.mask_token_id
        input_ids = encoding['input_ids'].squeeze(0)
        mask_positions = (input_ids == mask_token_id).nonzero(as_tuple=True)[0]

        # Load each damaged image as a grayscale (1, H, W) float tensor in [0, 1].
        damaged_images = []
        for img_idx in sample['image_indices']:
            img_path = os.path.join(self.image_dir, f'o{img_idx}.png')
            img = Image.open(img_path).convert('L')
            img = img.resize((self.config.image_size, self.config.image_size))

            img_array = np.array(img).astype(np.float32) / 255.0
            img_tensor = torch.from_numpy(img_array).unsqueeze(0)
            damaged_images.append(img_tensor)

        # Placeholder tensor when there are no images. BUG FIX: the fallback
        # previously hard-coded 64x64 instead of following config.image_size.
        if damaged_images:
            damaged_images = torch.stack(damaged_images)
        else:
            damaged_images = torch.zeros(
                1, 1, self.config.image_size, self.config.image_size
            )

        # Map each ground-truth character to its vocabulary id.
        label_ids = [
            self.tokenizer.convert_tokens_to_ids(label)
            for label in sample['labels']
        ]
        labels = torch.tensor(label_ids, dtype=torch.long)

        return {
            'input_ids': input_ids,
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'mask_positions': mask_positions,
            'damaged_images': damaged_images,
            'labels': labels
        }
| |
|
| |
|
def evaluate_real_world(config: "Config", checkpoint_path: str) -> str:
    """
    Evaluate on real-world damaged characters.

    Args:
        config: Configuration object (device, model names, data paths,
            top-k values).
        checkpoint_path: Path to a checkpoint holding 'model_state_dict'.

    Returns:
        Formatted results string.
    """
    # BUG FIX: the previous expression selected the CUDA device whenever
    # config.device == "cuda", even with torch.cuda.is_available() False,
    # which crashed on .to(device). Fall back to CPU in that case.
    if config.device == "cuda" and not torch.cuda.is_available():
        device = torch.device("cpu")
    else:
        device = torch.device(config.device)

    # Restore the trained model weights.
    model = MMRM(config).to(device)
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    print(f"Loaded model from {checkpoint_path}")

    tokenizer = BertTokenizer.from_pretrained(config.roberta_model)

    # batch_size=1: each sample carries a variable number of damaged images,
    # so batching across samples would require custom collation.
    real_dataset = RealWorldDataset(config, tokenizer)
    real_loader = DataLoader(
        real_dataset,
        batch_size=1,
        shuffle=False
    )

    metrics = RestorationMetrics(config.top_k_values)

    print("\nEvaluating on real-world data...")

    with torch.no_grad():
        for batch in real_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            mask_positions = batch['mask_positions'].to(device)
            damaged_images = batch['damaged_images'].to(device)
            labels = batch['labels'].to(device)

            # Only the text-side logits are scored here; the second output
            # (image branch) is ignored.
            text_logits, _ = model(input_ids, attention_mask, mask_positions, damaged_images)
            metrics.update(text_logits, labels)

    results = metrics.compute()

    output = f"\nReal-world Evaluation Results (38 characters):\n"
    output += f"{'='*50}\n"
    output += f"Accuracy: {results['accuracy']:.2f}%\n"
    output += f"Hit@5: {results['hit_5']:.2f}%\n"
    output += f"Hit@10: {results['hit_10']:.2f}%\n"
    output += f"Hit@20: {results['hit_20']:.2f}%\n"
    output += f"MRR: {results['mrr']:.2f}\n"
    output += f"{'='*50}\n"
    output += f"\nCompare with paper results:\n"
    output += f" Paper - Accuracy: 55.26%, MRR: 62.28\n"

    return output
| |
|
| |
|
| | if __name__ == "__main__": |
| | import sys |
| | |
| | if len(sys.argv) < 2: |
| | print("Usage: python evaluate_real.py <checkpoint_path>") |
| | sys.exit(1) |
| | |
| | checkpoint_path = sys.argv[1] |
| | |
| | config = Config() |
| | results = evaluate_real_world(config, checkpoint_path) |
| | print(results) |
| |
|