import warnings

# Unsloth patches transformers at import time, so it is imported first.
from unsloth import FastVisionModel

import torch
import gradio as gr
from transformers import TextStreamer

warnings.filterwarnings("ignore", category=UserWarning, module="gradio.helpers")

# from_pretrained returns a (model, tokenizer) tuple, so it cannot be chained
# with .to(); with load_in_4bit the quantized weights are placed on the GPU
# by bitsandbytes (4-bit quantization requires a CUDA device).
model, tokenizer = FastVisionModel.from_pretrained(
    "0llheaven/Llama-3.2-11B-Vision-Radiology-mini",
    load_in_4bit=True,
    use_gradient_checkpointing="unsloth",
)

# Switch the model into inference mode.
FastVisionModel.for_inference(model)
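
# CPU-only fallback (a sketch, not wired in here): drop load_in_4bit above,
# since the bitsandbytes 4-bit kernels are CUDA-only, then pick the device with:
#   device = "cuda" if torch.cuda.is_available() else "cpu"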

# Module-level cache: if the same image is submitted again, the previous
# response is reused instead of re-running generation.
cached_image = None
cached_response = None
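
# A lighter-weight alternative to the tensor comparison used below (a sketch,
# not used here): fingerprint the raw image bytes instead.
#   import hashlib
#   fingerprint = hashlib.sha256(image.tobytes()).hexdigest()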

def predict_radiology_description(image, instruction):
    """Run the vision model on the uploaded image and return its description."""
    global cached_image, cached_response

    try:
        # Flatten the PIL image into a tensor so it can be compared with the
        # previously submitted image.
        current_image_tensor = torch.tensor(image.getdata())

        # Serve the cached response if the image has not changed.
        if cached_image is not None and torch.equal(cached_image, current_image_tensor):
            return cached_response

        # Build a multimodal chat prompt: one image placeholder plus the
        # text instruction.
        messages = [{"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": instruction},
        ]}]
        input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

        # Tokenize the image/text pair and move the tensors to the model's device.
        inputs = tokenizer(
            image,
            input_text,
            add_special_tokens=False,
            return_tensors="pt",
        ).to(model.device)

        # Stream tokens to stdout as they are generated, skipping the prompt.
        text_streamer = TextStreamer(tokenizer, skip_prompt=True)

        output_ids = model.generate(
            **inputs,
            streamer=text_streamer,
            max_new_tokens=256,
            use_cache=True,
            temperature=1.5,
            min_p=0.1,
        )

        generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

        # Cache the result and tidy the role marker before returning.
        cached_image = current_image_tensor
        cached_response = generated_text.replace("assistant", "\n\nAssistant").strip()
        return cached_response
    except Exception as e:
        return f"Error: {str(e)}"
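
# Quick smoke test outside the UI (hypothetical file path):
#   from PIL import Image
#   sample = Image.open("sample_xray.png").convert("RGB")
#   print(predict_radiology_description(sample, "Describe this chest X-ray."))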

def chat_process(image, instruction, history=None):
    """Append one (instruction, response) turn to the chat history."""
    if history is None:
        history = []

    response = predict_radiology_description(image, instruction)

    history.append((instruction, response))
    return history
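
# Note: the (user, assistant) tuple format appended above matches gr.Chatbot's
# default; newer Gradio releases also accept type="messages" dicts.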

with gr.Blocks() as demo:
    gr.Markdown("# 🩻 Radiology Image ChatBot")
    gr.Markdown("Upload a radiology image and provide an instruction for the AI to describe the findings.")
    gr.Markdown("Example instruction: You are an expert radiographer. Describe accurately what you see in this image.")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Radiology Image")
            instruction_input = gr.Textbox(
                label="Instruction",
                value="You are an expert radiographer. Describe accurately what you see in this image.",
                placeholder="Provide specific instructions...",
            )
        with gr.Column():
            chatbot = gr.Chatbot(label="Chat History")

    with gr.Row():
        clear_btn = gr.Button("Clear")
        submit_btn = gr.Button("Submit")

    # On submit: update the chat history, keep the uploaded image in place,
    # and clear the instruction textbox.
    submit_btn.click(
        lambda image, instruction, history: (
            chat_process(image, instruction, history),
            image,
            "",
        ),
        inputs=[image_input, instruction_input, chatbot],
        outputs=[chatbot, image_input, instruction_input],
    )

    # On clear: reset the chat history, the image, and the instruction box.
    clear_btn.click(
        lambda: (None, None, None),
        inputs=[],
        outputs=[chatbot, image_input, instruction_input],
    )

demo.launch(debug=True)
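
# Optional: demo.launch(share=True) serves a temporary public link, and calling
# demo.queue() before launch serialises concurrent requests.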