File size: 7,505 Bytes
6e32a75 6c83c5b 6e32a75 36a9fd5 6e32a75 db3243e 6e32a75 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
import gradio as gr
import torch
from PIL import Image
from models.r2gen import R2GenModel
from modules.tokenizers import Tokenizer
import argparse
# Command-line configuration consumed by the tokenizer and the R2Gen model.
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    argparse's ``type=bool`` simply calls ``bool(text)``, so every non-empty
    string (including ``"False"``) becomes ``True``. This converter parses
    the usual spellings instead and rejects anything it does not recognise.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in {'true', 't', 'yes', 'y', '1', 'on'}:
        return True
    # The empty string is kept falsy, as it was under ``bool("")``.
    if token in {'false', 'f', 'no', 'n', '0', 'off', ''}:
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def get_model_args():
    """Parse the model/runtime options from the command line.

    Every option has a default, so calling the script without arguments
    yields a fully populated configuration (``--resume`` defaults to None).

    Returns:
        argparse.Namespace: the parsed options.
    """
    parser = argparse.ArgumentParser()

    # Model loader settings
    parser.add_argument('--load', type=str, default='ckpts/few-shot.pth', help='the path to the model weights.')
    parser.add_argument('--prompt', type=str, default='prompt/prompt.pth', help='the path to the prompt weights.')

    # Data input settings
    parser.add_argument('--image_path', type=str, default='example_figs/example_fig1.jpg', help='the path to the test image.')
    parser.add_argument('--image_dir', type=str, default='data/images/', help='the path to the directory containing the data.')
    parser.add_argument('--ann_path', type=str, default='data/annotation.json', help='the path to the directory containing the data.')

    # Data loader settings
    parser.add_argument('--dataset_name', type=str, default='mimic_cxr', help='the dataset to be used.')
    parser.add_argument('--max_seq_length', type=int, default=60, help='the maximum sequence length of the reports.')
    parser.add_argument('--threshold', type=int, default=3, help='the cut off frequency for the words.')
    parser.add_argument('--num_workers', type=int, default=2, help='the number of workers for dataloader.')
    parser.add_argument('--batch_size', type=int, default=16, help='the number of samples for a batch')

    # Model settings (for visual extractor)
    parser.add_argument('--visual_extractor', type=str, default='resnet101', help='the visual extractor to be used.')
    # Parsed with _str2bool so that "--visual_extractor_pretrained False" really disables it.
    parser.add_argument('--visual_extractor_pretrained', type=_str2bool, default=True, help='whether to load the pretrained visual extractor')

    # Model settings (for Transformer)
    parser.add_argument('--d_model', type=int, default=512, help='the dimension of Transformer.')
    parser.add_argument('--d_ff', type=int, default=512, help='the dimension of FFN.')
    parser.add_argument('--d_vf', type=int, default=2048, help='the dimension of the patch features.')
    parser.add_argument('--num_heads', type=int, default=8, help='the number of heads in Transformer.')
    parser.add_argument('--num_layers', type=int, default=3, help='the number of layers of Transformer.')
    parser.add_argument('--dropout', type=float, default=0.1, help='the dropout rate of Transformer.')
    parser.add_argument('--logit_layers', type=int, default=1, help='the number of the logit layer.')
    parser.add_argument('--bos_idx', type=int, default=0, help='the index of <bos>.')
    parser.add_argument('--eos_idx', type=int, default=0, help='the index of <eos>.')
    parser.add_argument('--pad_idx', type=int, default=0, help='the index of <pad>.')
    parser.add_argument('--use_bn', type=int, default=0, help='whether to use batch normalization.')
    parser.add_argument('--drop_prob_lm', type=float, default=0.5, help='the dropout rate of the output layer.')

    # for Relational Memory
    parser.add_argument('--rm_num_slots', type=int, default=3, help='the number of memory slots.')
    parser.add_argument('--rm_num_heads', type=int, default=8, help='the numebr of heads in rm.')
    parser.add_argument('--rm_d_model', type=int, default=512, help='the dimension of rm.')

    # Sample related
    parser.add_argument('--sample_method', type=str, default='beam_search', help='the sample methods to sample a report.')
    parser.add_argument('--beam_size', type=int, default=3, help='the beam size when beam searching.')
    parser.add_argument('--temperature', type=float, default=1.0, help='the temperature when sampling.')
    parser.add_argument('--sample_n', type=int, default=1, help='the sample number per image.')
    parser.add_argument('--group_size', type=int, default=1, help='the group size.')
    parser.add_argument('--output_logsoftmax', type=int, default=1, help='whether to output the probabilities.')
    parser.add_argument('--decoding_constraint', type=int, default=0, help='whether decoding constraint.')
    parser.add_argument('--block_trigrams', type=int, default=1, help='whether to use block trigrams.')

    # Trainer settings
    parser.add_argument('--n_gpu', type=int, default=1, help='the number of gpus to be used.')
    parser.add_argument('--epochs', type=int, default=100, help='the number of training epochs.')
    parser.add_argument('--save_dir', type=str, default='results/iu_xray', help='the patch to save the models.')
    parser.add_argument('--record_dir', type=str, default='records/', help='the patch to save the results of experiments')
    parser.add_argument('--save_period', type=int, default=1, help='the saving period.')
    parser.add_argument('--monitor_mode', type=str, default='max', choices=['min', 'max'], help='whether to max or min the metric.')
    parser.add_argument('--monitor_metric', type=str, default='BLEU_4', help='the metric to be monitored.')
    parser.add_argument('--early_stop', type=int, default=50, help='the patience of training.')

    # Optimization
    parser.add_argument('--optim', type=str, default='Adam', help='the type of the optimizer.')
    parser.add_argument('--lr_ve', type=float, default=5e-5, help='the learning rate for the visual extractor.')
    parser.add_argument('--lr_ed', type=float, default=1e-4, help='the learning rate for the remaining parameters.')
    parser.add_argument('--weight_decay', type=float, default=5e-5, help='the weight decay.')
    # Parsed with _str2bool for the same reason as --visual_extractor_pretrained.
    parser.add_argument('--amsgrad', type=_str2bool, default=True, help='.')

    # Learning Rate Scheduler
    parser.add_argument('--lr_scheduler', type=str, default='StepLR', help='the type of the learning rate scheduler.')
    parser.add_argument('--step_size', type=int, default=50, help='the step size of the learning rate scheduler.')
    parser.add_argument('--gamma', type=float, default=0.1, help='the gamma of the learning rate scheduler.')

    # Others
    parser.add_argument('--seed', type=int, default=9233, help='.')
    parser.add_argument('--resume', type=str, help='whether to resume the training from existing checkpoints.')

    # parse_args() with no arguments reads the options from sys.argv[1:].
    args = parser.parse_args()
    return args
def load_model():
    """Build the R2Gen model and tokenizer and restore the checkpoint weights.

    The configuration comes from ``get_model_args()``; the checkpoint path is
    its ``load`` option. The model is placed on CUDA when available, otherwise
    on the CPU, and is returned in evaluation mode.

    Returns:
        tuple: ``(model, tokenizer)``.

    Raises:
        Whatever ``torch.load`` raises for an unreadable checkpoint, and
        whatever ``load_state_dict`` raises when the checkpoint's keys do not
        match the model.
    """
    args = get_model_args()
    tokenizer = Tokenizer(args)

    # Determine the device dynamically.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = R2GenModel(args, tokenizer).to(device)

    # map_location keeps the loaded tensors on the same device as the model.
    checkpoint = torch.load(args.load, map_location=device)

    # Accept both a wrapped checkpoint ({'state_dict': ...}) and a bare state dict.
    model_state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint

    # Actually apply the weights: without this call the freshly constructed
    # (untrained) parameters would be returned.
    model.load_state_dict(model_state_dict)

    # Inference only: switch off dropout and other training-time behaviour.
    model.eval()
    return model, tokenizer
# Load the model and tokenizer once, at import time; generate_report reads
# these module-level objects on every call.
model, tokenizer = load_model()
def generate_report(image):
    """Generate a text report for one uploaded image.

    Args:
        image: the uploaded picture as an array accepted by
            ``PIL.Image.fromarray``, or ``None`` when nothing was uploaded.

    Returns:
        str: the decoded report with everything up to and including the first
        period removed (the whole report when it contains no period), stripped
        of surrounding whitespace. When ``image`` is ``None`` a short prompt
        asking for an upload is returned instead.
    """
    # The handler can be invoked with no image; without this guard
    # Image.fromarray(None) would raise.
    if image is None:
        return "Please upload an image first."

    image = Image.fromarray(image).convert('RGB')

    # NOTE(review): the PIL image is handed to the model as-is (no resize,
    # normalisation or tensor conversion happens here) -- confirm that
    # R2GenModel performs its own preprocessing.
    with torch.no_grad():
        output = model([image], mode='sample')

    reports = tokenizer.decode_batch(output.cpu().numpy())

    # Keep only the text after the first '.'; split(...)[-1] falls back to the
    # full string when there is no period.
    outputs = reports[0].split('.', 1)[-1].strip()
    return outputs
# Define Gradio interface: generate_report receives the uploaded image and its
# returned string is shown as text.
iface = gr.Interface(
    fn=generate_report,  # required: the callable the interface wraps
    inputs=gr.Image(),  # default gr.Image() passes the upload to fn as a NumPy array
    outputs="text",  # required: generate_report returns a plain string
    description="Upload a medical image for thorax disease reporting.",
    examples=[["example_figs/0.png"], ["example_figs/1.png"], ["example_figs/2.png"]],
)
if __name__ == "__main__":