File size: 5,008 Bytes
bd9df5d e094577 4b589f5 bd9df5d e094577 bd9df5d e094577 bd9df5d e094577 bd9df5d 7afd75d 70ce441 bd9df5d e094577 bd9df5d e094577 bd9df5d e094577 bd9df5d 16400a1 bd9df5d e094577 bd9df5d 7afd75d bd9df5d ddfe551 e094577 bd9df5d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
import os
from unsloth import FastVisionModel
import torch
from PIL import Image
from datasets import load_dataset
from transformers import TextStreamer
import matplotlib.pyplot as plt
import gradio as gr
# Load the model
model, tokenizer = FastVisionModel.from_pretrained(
"0llheaven/Llama-3.2-11B-Vision-Radiology-mini",
load_in_4bit=True,
use_gradient_checkpointing="unsloth",
).to("cpu")
# เปลี่ยนโหมดของโมเดลเป็นสำหรับ inference
FastVisionModel.for_inference(model)
# ตัวแปรสำหรับแคช
cached_image = None
cached_response = None
# ฟังก์ชันประมวลผลภาพและสร้างคำอธิบาย
def predict_radiology_description(image, instruction):
global cached_image, cached_response
try:
current_image_tensor = torch.tensor(image.getdata())
# ตรวจสอบว่าภาพเหมือนเดิมและข้อความเหมือนเดิมหรือไม่
if cached_image is not None and torch.equal(cached_image, current_image_tensor):
# ใช้ cached_response กับ text ใหม่
return cached_response
# เตรียมข้อความในรูปแบบที่โมเดลรองรับ
messages = [{"role": "user", "content": [
{"type": "image"},
{"type": "text", "text": instruction}
]}]
input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
# เตรียม input สำหรับโมเดล
inputs = tokenizer(
image,
input_text,
add_special_tokens=False,
return_tensors="pt",
).to("cpu")
# ใช้ TextStreamer สำหรับการพยากรณ์
text_streamer = TextStreamer(tokenizer, skip_prompt=True)
# ทำนายข้อความ
output_ids = model.generate(
**inputs,
streamer=text_streamer,
max_new_tokens=256,
use_cache=True,
temperature=1.5,
min_p=0.1
)
# แปลงข้อความที่สร้างเป็นผลลัพธ์
generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
cached_image = current_image_tensor # แคชภาพเป็น Tensor
cached_response = generated_text.replace("assistant", "\n\nAssistant").strip()
return cached_response
except Exception as e:
return f"Error: {str(e)}"
# ฟังก์ชัน ChatBot
def chat_process(image, instruction, history=None):
if history is None:
history = []
# ประมวลผลภาพและคำสั่ง
response = predict_radiology_description(image, instruction)
# อัปเดตประวัติ
history.append((instruction, response))
return history, history
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="gradio.helpers")
# UI ของ Gradio
with gr.Blocks() as demo:
gr.Markdown("# 🩻 Radiology Image ChatBot")
gr.Markdown("Upload a radiology image and provide an instruction for the AI to describe the findings.")
gr.Markdown("Example instruction : You are an expert radiographer. Describe accurately what you see in this image.")
with gr.Row():
with gr.Column():
# อัปโหลดรูปภาพ
image_input = gr.Image(type="pil", label="Upload Radiology Image")
# ป้อนคำสั่ง (instruction)
instruction_input = gr.Textbox(
label="Instruction",
value="You are an expert radiographer. Describe accurately what you see in this image.",
placeholder="Provide specific instructions..."
)
with gr.Column():
# แสดงประวัติ Chat
chatbot = gr.Chatbot(label="Chat History")
with gr.Row():
clear_btn = gr.Button("Clear")
submit_btn = gr.Button("Submit")
# การทำงานของปุ่ม Submit พร้อมล้างเฉพาะข้อความใน instruction_input
submit_btn.click(
lambda image, instruction, history: (
*chat_process(image, instruction, history),
image, # รีเซ็ตค่า image_input
""
),
inputs=[image_input, instruction_input, chatbot],
outputs=[chatbot, chatbot, image_input, instruction_input]
)
# การทำงานของปุ่ม Clear
clear_btn.click(
lambda: (None, None, None, None),
inputs=[],
outputs=[chatbot, chatbot, image_input, instruction_input]
)
# รันแอป
demo.launch(debug=True)
|