# Spaces: Sleeping
import gradio as gr | |
import torch | |
from PIL import Image | |
from models.r2gen import R2GenModel | |
from modules.tokenizers import Tokenizer | |
import argparse | |
# Command-line configuration for the model, data loading, sampling, and training.
def _str2bool(value):
    """Convert a command-line token such as ``true``/``False``/``1`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value (including ``"False"``) becomes ``True``. This helper
    parses the text explicitly instead.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value)
    )


def get_model_args(argv=None):
    """Build the argument parser and return the parsed configuration.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) ``argparse`` reads ``sys.argv[1:]``, which is the
            original behaviour of this function.

    Returns:
        argparse.Namespace: One attribute per option declared below, filled
        with the defaults unless overridden on the command line.
    """
    parser = argparse.ArgumentParser()

    # Model loader settings
    parser.add_argument('--load', type=str, default='ckpts/few-shot.pth', help='the path to the model weights.')
    parser.add_argument('--prompt', type=str, default='prompt/prompt.pth', help='the path to the prompt weights.')

    # Data input settings
    parser.add_argument('--image_path', type=str, default='example_figs/example_fig1.jpg', help='the path to the test image.')
    parser.add_argument('--image_dir', type=str, default='data/images/', help='the path to the directory containing the data.')
    parser.add_argument('--ann_path', type=str, default='data/annotation.json', help='the path to the annotation file.')

    # Data loader settings
    parser.add_argument('--dataset_name', type=str, default='mimic_cxr', help='the dataset to be used.')
    parser.add_argument('--max_seq_length', type=int, default=60, help='the maximum sequence length of the reports.')
    parser.add_argument('--threshold', type=int, default=3, help='the cut off frequency for the words.')
    parser.add_argument('--num_workers', type=int, default=2, help='the number of workers for dataloader.')
    parser.add_argument('--batch_size', type=int, default=16, help='the number of samples for a batch')

    # Model settings (for visual extractor)
    parser.add_argument('--visual_extractor', type=str, default='resnet101', help='the visual extractor to be used.')
    # Parsed with _str2bool so that "--visual_extractor_pretrained False" really disables it.
    parser.add_argument('--visual_extractor_pretrained', type=_str2bool, default=True, help='whether to load the pretrained visual extractor')

    # Model settings (for Transformer)
    parser.add_argument('--d_model', type=int, default=512, help='the dimension of Transformer.')
    parser.add_argument('--d_ff', type=int, default=512, help='the dimension of FFN.')
    parser.add_argument('--d_vf', type=int, default=2048, help='the dimension of the patch features.')
    parser.add_argument('--num_heads', type=int, default=8, help='the number of heads in Transformer.')
    parser.add_argument('--num_layers', type=int, default=3, help='the number of layers of Transformer.')
    parser.add_argument('--dropout', type=float, default=0.1, help='the dropout rate of Transformer.')
    parser.add_argument('--logit_layers', type=int, default=1, help='the number of the logit layer.')
    parser.add_argument('--bos_idx', type=int, default=0, help='the index of <bos>.')
    parser.add_argument('--eos_idx', type=int, default=0, help='the index of <eos>.')
    parser.add_argument('--pad_idx', type=int, default=0, help='the index of <pad>.')
    parser.add_argument('--use_bn', type=int, default=0, help='whether to use batch normalization.')
    parser.add_argument('--drop_prob_lm', type=float, default=0.5, help='the dropout rate of the output layer.')

    # for Relational Memory
    parser.add_argument('--rm_num_slots', type=int, default=3, help='the number of memory slots.')
    parser.add_argument('--rm_num_heads', type=int, default=8, help='the number of heads in rm.')
    parser.add_argument('--rm_d_model', type=int, default=512, help='the dimension of rm.')

    # Sample related
    parser.add_argument('--sample_method', type=str, default='beam_search', help='the sample methods to sample a report.')
    parser.add_argument('--beam_size', type=int, default=3, help='the beam size when beam searching.')
    parser.add_argument('--temperature', type=float, default=1.0, help='the temperature when sampling.')
    parser.add_argument('--sample_n', type=int, default=1, help='the sample number per image.')
    parser.add_argument('--group_size', type=int, default=1, help='the group size.')
    parser.add_argument('--output_logsoftmax', type=int, default=1, help='whether to output the probabilities.')
    parser.add_argument('--decoding_constraint', type=int, default=0, help='whether decoding constraint.')
    parser.add_argument('--block_trigrams', type=int, default=1, help='whether to use block trigrams.')

    # Trainer settings
    parser.add_argument('--n_gpu', type=int, default=1, help='the number of gpus to be used.')
    parser.add_argument('--epochs', type=int, default=100, help='the number of training epochs.')
    parser.add_argument('--save_dir', type=str, default='results/iu_xray', help='the path to save the models.')
    parser.add_argument('--record_dir', type=str, default='records/', help='the path to save the results of experiments')
    parser.add_argument('--save_period', type=int, default=1, help='the saving period.')
    parser.add_argument('--monitor_mode', type=str, default='max', choices=['min', 'max'], help='whether to max or min the metric.')
    parser.add_argument('--monitor_metric', type=str, default='BLEU_4', help='the metric to be monitored.')
    parser.add_argument('--early_stop', type=int, default=50, help='the patience of training.')

    # Optimization
    parser.add_argument('--optim', type=str, default='Adam', help='the type of the optimizer.')
    parser.add_argument('--lr_ve', type=float, default=5e-5, help='the learning rate for the visual extractor.')
    parser.add_argument('--lr_ed', type=float, default=1e-4, help='the learning rate for the remaining parameters.')
    parser.add_argument('--weight_decay', type=float, default=5e-5, help='the weight decay.')
    parser.add_argument('--amsgrad', type=_str2bool, default=True, help='.')

    # Learning Rate Scheduler
    parser.add_argument('--lr_scheduler', type=str, default='StepLR', help='the type of the learning rate scheduler.')
    parser.add_argument('--step_size', type=int, default=50, help='the step size of the learning rate scheduler.')
    parser.add_argument('--gamma', type=float, default=0.1, help='the gamma of the learning rate scheduler.')

    # Others
    parser.add_argument('--seed', type=int, default=9233, help='.')
    parser.add_argument('--resume', type=str, help='whether to resume the training from existing checkpoints.')

    # argv=None makes argparse fall back to sys.argv[1:].
    args = parser.parse_args(argv)
    return args
def load_model():
    """Construct the report-generation model and tokenizer with loaded weights.

    The configuration comes from ``get_model_args()``. The checkpoint at
    ``args.load`` is read onto whichever device the model lives on, and the
    model is switched to evaluation mode before being returned.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    config = get_model_args()
    tokenizer = Tokenizer(config)

    # Use the GPU when CUDA is available, otherwise fall back to the CPU.
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
    model = R2GenModel(config, tokenizer).to(device)

    # map_location keeps the loaded tensors on the same device as the model.
    checkpoint = torch.load(config.load, map_location=device)

    # A checkpoint may wrap the weights under a 'state_dict' key or be the
    # weights mapping itself; accept both layouts.
    if 'state_dict' in checkpoint:
        weights = checkpoint['state_dict']
    else:
        weights = checkpoint
    model.load_state_dict(weights)

    model.eval()
    return model, tokenizer


# Loaded once at import time; generate_report reads these module-level names.
model, tokenizer = load_model()
def generate_report(image):
    """Generate a text report for a single uploaded image.

    Args:
        image: Image data accepted by ``PIL.Image.fromarray`` (presumably the
            NumPy array Gradio supplies for an image input -- confirm against
            the ``gr.Image`` configuration), or ``None`` when no image was
            provided.

    Returns:
        str: The decoded report with everything up to and including the first
        period removed and surrounding whitespace stripped. If the report
        contains no period it is returned whole (stripped). When ``image`` is
        ``None`` a short message asking for an upload is returned instead.
    """
    if image is None:
        # Image.fromarray cannot handle None; answer with a readable message
        # rather than surfacing a traceback in the UI.
        return "Please upload an image first."

    rgb_image = Image.fromarray(image).convert('RGB')
    # Inference only: no gradients are needed.
    with torch.no_grad():
        output = model([rgb_image], mode='sample')
    reports = tokenizer.decode_batch(output.cpu().numpy())
    # split('.', 1)[-1] keeps the text after the first period, or the full
    # string when there is no period at all.
    return reports[0].split('.', 1)[-1].strip()
# Gradio front end: one image input wired to generate_report, plain-text output.
iface = gr.Interface(
    title="PromptNet",
    description="Upload a medical image for thorax disease reporting.",
    fn=generate_report,
    # Default image component; NOTE(review): generate_report passes the value
    # to Image.fromarray, so this presumably delivers an array -- confirm.
    inputs=gr.Image(),
    outputs="text",
    examples=[
        ["example_figs/0.png"],
        ["example_figs/1.png"],
        ["example_figs/2.png"],
    ],
)


if __name__ == "__main__":
    # Launch the web UI only when executed as a script.
    iface.launch()