"""Demo: generate a chest X-ray findings report with the RaDialog LLaVA-BioViL
model, then ask the model to rewrite that report in patient-friendly language.

The checkpoint path, base model, predicted findings and sample image URL are
hard-coded for this demo.
"""
from skimage import io as io_img  # noqa: F401 - unused in this script; kept from the original
import io

import numpy as np
import requests
import torch
from PIL import Image

from LLAVA_Biovil.llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria, remap_to_uint8
from LLAVA_Biovil.llava.model.builder import load_pretrained_model
from LLAVA_Biovil.llava.conversation import SeparatorStyle, conv_vicuna_v1
from LLAVA_Biovil.llava.constants import IMAGE_TOKEN_INDEX
from utils import create_chest_xray_transform_for_inference

# Placeholder that LLaVA's tokenizer_image_token replaces with IMAGE_TOKEN_INDEX.
# Without it in the prompt there is no image slot, so the image features are never used.
IMAGE_PLACEHOLDER = "<image>"

# End-of-sequence marker emitted by the model; stripped from decoded replies.
EOS_MARKER = "</s>"


def _load_image(url, timeout=30):
    """Download an image and return it as a grayscale (PIL mode "L") image.

    Args:
        url: HTTP(S) URL of the image.
        timeout: Seconds to wait for the server before giving up.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    response = requests.get(url, timeout=timeout)
    # Fail loudly instead of handing an HTML error page to PIL.
    response.raise_for_status()
    image = Image.open(io.BytesIO(response.content))
    image = remap_to_uint8(np.array(image))
    return Image.fromarray(image).convert("L")


def _decode_reply(tokenizer, output_ids, prompt_len):
    """Decode the tokens generated after the prompt and drop the EOS marker."""
    reply = tokenizer.decode(output_ids[0, prompt_len:])
    # Remove the marker before stripping so no whitespace is left dangling in front of it.
    return reply.replace(EOS_MARKER, "").strip()


def _generate(model, tokenizer, conv, image_tensor, max_new_tokens=300):
    """Greedily generate the assistant reply for the conversation's current prompt.

    Args:
        model: The loaded LLaVA model.
        tokenizer: The matching tokenizer.
        conv: Conversation whose last message is an empty ASSISTANT turn.
        image_tensor: Preprocessed image batch passed to ``model.generate``.
        max_new_tokens: Upper bound on generated tokens.

    Returns:
        The decoded reply text.
    """
    prompt = conv.get_prompt()
    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    # Build a fresh criterion per call: it is constructed from this call's input_ids,
    # so reusing one across prompts of different length would carry stale state.
    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=False,
            use_cache=True,
            max_new_tokens=max_new_tokens,
            stopping_criteria=[stopping_criteria],
            pad_token_id=tokenizer.pad_token_id,
        )
    return _decode_reply(tokenizer, output_ids, input_ids.shape[1])


def main():
    """Run the two-turn demo: report generation, then a patient-friendly rewrite."""
    model_path = "/home/guests/chantal_pellegrini/RaDialog_LLaVA/LLAVA/checkpoints/llava-v1.5-7b-task-lora_radialog_instruct_llava_biovil_unfrozen_2e-5_5epochs_v5/checkpoint-21000"  # TODO hardcoded in huggingface repo probably
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, _image_processor, _context_len = load_pretrained_model(
        model_path,
        model_base='liuhaotian/llava-v1.5-7b',
        model_name=model_name,
        load_8bit=False,
        load_4bit=False,
    )
    model.config.tokenizer_padding_side = "left"

    findings = "edema, pleural effusion"  # TODO should these come from chexpert classifier? Or not needed for this demo/test?
    conv = conv_vicuna_v1.copy()
    report_gen_prompt = f"{IMAGE_PLACEHOLDER}. Predicted Findings: {findings}. You are to act as a radiologist and write the finding section of a chest x-ray radiology report for this X-ray image and the given predicted findings. Write in the style of a radiologist, write one fluent text without enumeration, be concise and don't provide explanations or reasons."
    print("USER: ", report_gen_prompt)
    conv.append_message("USER", report_gen_prompt)
    conv.append_message("ASSISTANT", None)

    # Load and preprocess the image.
    vis_transforms_biovil = create_chest_xray_transform_for_inference(512, center_crop_size=448)
    sample_img_path = "https://openi.nlm.nih.gov/imgs/512/10/10/CXR10_IM-0002-2001.png?keywords=Calcified%20Granuloma"  # TODO find good image
    image = _load_image(sample_img_path)
    image_tensor = vis_transforms_biovil(image).unsqueeze(0)
    image_tensor = image_tensor.to(model.device, dtype=torch.bfloat16)

    # Generate the report.
    report = _generate(model, tokenizer, conv, image_tensor)
    print("ASSISTANT: ", report)

    # Fill the open ASSISTANT slot with the generated report, then ask the follow-up question.
    conv.messages.pop()
    conv.append_message("ASSISTANT", report)
    follow_up = "Translate this report to easy language for a patient to understand."
    conv.append_message("USER", follow_up)
    conv.append_message("ASSISTANT", None)
    print("USER: ", follow_up)

    # Generate the easy-language report.
    easy_report = _generate(model, tokenizer, conv, image_tensor)
    print("ASSISTANT: ", easy_report)


if __name__ == '__main__':
    main()