import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import matplotlib.pyplot as plt

device = torch.device('cpu')  # Set to 'cuda' if running on a GPU
dtype = torch.float32         # Data type used for the model inputs
model_name_or_path = 'GoodBaiBai88/M3D-LaMed-Phi-3-4B'
proj_out_num = 256            # Number of image patch tokens prepended to the prompt
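# Note: the 256 <im_patch> placeholders added to the prompt below are assumed to
# match the number of visual tokens produced by this checkpoint's image projector;
# the value follows the upstream M3D-LaMed example code.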
# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    torch_dtype=torch.float32,
    device_map='cpu',
    trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    model_max_length=512,
    padding_side="right",
    use_fast=False,
    trust_remote_code=True
)
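# A minimal sketch for running on a GPU instead (assumes a CUDA device with enough
# memory; untested here):
#
#   device = torch.device('cuda')
#   dtype = torch.float16
#   model = AutoModelForCausalLM.from_pretrained(
#       model_name_or_path,
#       torch_dtype=torch.float16,
#       device_map='cuda',
#       trust_remote_code=True,
#   )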
def process_image(image_path, question):
    # Load the .npy image
    image_np = np.load(image_path)

    # Build the prompt: image patch placeholders followed by the question
    image_tokens = "<im_patch>" * proj_out_num
    input_txt = image_tokens + question
    input_id = tokenizer(input_txt, return_tensors="pt")['input_ids'].to(device=device)

    # Prepare the image tensor for the model (add a batch dimension)
    image_pt = torch.from_numpy(image_np).unsqueeze(0).to(dtype=dtype, device=device)

    # Generate the model response
    generation = model.generate(image_pt, input_id, max_new_tokens=256,
                                do_sample=True, top_p=0.9, temperature=1.0)
    generated_texts = tokenizer.batch_decode(generation, skip_special_tokens=True)
    return generated_texts[0], image_np
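# Optional helper (a rough sketch, not called by the app): convert a volume readable
# by SimpleITK into the .npy format the app expects. The target shape (32, 256, 256),
# the [0, 1] normalisation, and the added channel axis are assumptions based on
# typical M3D-LaMed inputs; adjust them to this checkpoint's actual preprocessing.
#
#   import SimpleITK as sitk
#   from scipy import ndimage
#
#   def volume_to_npy(volume_path, out_path='image.npy', target_shape=(32, 256, 256)):
#       volume = sitk.GetArrayFromImage(sitk.ReadImage(volume_path)).astype(np.float32)  # (D, H, W)
#       zoom = [t / s for t, s in zip(target_shape, volume.shape)]
#       volume = ndimage.zoom(volume, zoom, order=1)                              # resample to target shape
#       volume = (volume - volume.min()) / (volume.max() - volume.min() + 1e-8)   # scale to [0, 1]
#       np.save(out_path, volume[np.newaxis])                                     # add channel axis -> (1, D, H, W)
#       return out_path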
# Gradio interface
def gradio_interface(image, question):
    # gr.File(type="filepath") passes the uploaded file as a path string
    response, image_np = process_image(image, question)

    # Drop any leading singleton (batch/channel) axes so a plain 3D volume remains
    volume = np.squeeze(image_np)

    # Extract each slice along the first axis
    slices = [volume[i, :, :] for i in range(volume.shape[0])]

    # Plot the slices and save them as a single PNG image
    fig, axes = plt.subplots(1, len(slices), figsize=(15, 5))
    if len(slices) == 1:
        axes = [axes]
    for ax, slice_data in zip(axes, slices):
        ax.imshow(slice_data, cmap='gray')
        ax.axis('off')
    plt.tight_layout()
    fig.savefig('slices.png')
    plt.close(fig)

    return response, 'slices.png'
# Gradio app
gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.File(label="Upload .npy Image", type="filepath"),  # For uploading the .npy image
        gr.Textbox(label="Enter your question", placeholder="Ask something about the image..."),
    ],
    outputs=[
        gr.Textbox(label="Model Response"),
        gr.Image(label="Image Slices", type="filepath", image_mode='L'),
    ],
    title="Medical Image Analysis",
    description="Upload a .npy image and ask a question to analyze it using the model. The image slices will be displayed."
).launch()
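# Launching with the defaults above is usually enough for local use or a Hugging Face
# Space. In a container or on a remote host you may instead need something like
# .launch(server_name="0.0.0.0", server_port=7860) (assumed setup; adjust as needed).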