FT_Llama / app.py
0llheaven's picture
Update app.py
7afd75d verified
import os
from unsloth import FastVisionModel
import torch
from PIL import Image
from datasets import load_dataset
from transformers import TextStreamer
import matplotlib.pyplot as plt
import gradio as gr
# Load the model
model, tokenizer = FastVisionModel.from_pretrained(
"0llheaven/Llama-3.2-11B-Vision-Radiology-mini",
load_in_4bit=True,
use_gradient_checkpointing="unsloth",
).to("cpu")
# เปลี่ยนโหมดของโมเดลเป็นสำหรับ inference
FastVisionModel.for_inference(model)
# ตัวแปรสำหรับแคช
cached_image = None
cached_response = None
# ฟังก์ชันประมวลผลภาพและสร้างคำอธิบาย
def predict_radiology_description(image, instruction):
global cached_image, cached_response
try:
current_image_tensor = torch.tensor(image.getdata())
# ตรวจสอบว่าภาพเหมือนเดิมและข้อความเหมือนเดิมหรือไม่
if cached_image is not None and torch.equal(cached_image, current_image_tensor):
# ใช้ cached_response กับ text ใหม่
return cached_response
# เตรียมข้อความในรูปแบบที่โมเดลรองรับ
messages = [{"role": "user", "content": [
{"type": "image"},
{"type": "text", "text": instruction}
]}]
input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
# เตรียม input สำหรับโมเดล
inputs = tokenizer(
image,
input_text,
add_special_tokens=False,
return_tensors="pt",
).to("cpu")
# ใช้ TextStreamer สำหรับการพยากรณ์
text_streamer = TextStreamer(tokenizer, skip_prompt=True)
# ทำนายข้อความ
output_ids = model.generate(
**inputs,
streamer=text_streamer,
max_new_tokens=256,
use_cache=True,
temperature=1.5,
min_p=0.1
)
# แปลงข้อความที่สร้างเป็นผลลัพธ์
generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
cached_image = current_image_tensor # แคชภาพเป็น Tensor
cached_response = generated_text.replace("assistant", "\n\nAssistant").strip()
return cached_response
except Exception as e:
return f"Error: {str(e)}"
# ฟังก์ชัน ChatBot
def chat_process(image, instruction, history=None):
if history is None:
history = []
# ประมวลผลภาพและคำสั่ง
response = predict_radiology_description(image, instruction)
# อัปเดตประวัติ
history.append((instruction, response))
return history, history
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="gradio.helpers")
# UI ของ Gradio
with gr.Blocks() as demo:
gr.Markdown("# 🩻 Radiology Image ChatBot")
gr.Markdown("Upload a radiology image and provide an instruction for the AI to describe the findings.")
gr.Markdown("Example instruction : You are an expert radiographer. Describe accurately what you see in this image.")
with gr.Row():
with gr.Column():
# อัปโหลดรูปภาพ
image_input = gr.Image(type="pil", label="Upload Radiology Image")
# ป้อนคำสั่ง (instruction)
instruction_input = gr.Textbox(
label="Instruction",
value="You are an expert radiographer. Describe accurately what you see in this image.",
placeholder="Provide specific instructions..."
)
with gr.Column():
# แสดงประวัติ Chat
chatbot = gr.Chatbot(label="Chat History")
with gr.Row():
clear_btn = gr.Button("Clear")
submit_btn = gr.Button("Submit")
# การทำงานของปุ่ม Submit พร้อมล้างเฉพาะข้อความใน instruction_input
submit_btn.click(
lambda image, instruction, history: (
*chat_process(image, instruction, history),
image, # รีเซ็ตค่า image_input
""
),
inputs=[image_input, instruction_input, chatbot],
outputs=[chatbot, chatbot, image_input, instruction_input]
)
# การทำงานของปุ่ม Clear
clear_btn.click(
lambda: (None, None, None, None),
inputs=[],
outputs=[chatbot, chatbot, image_input, instruction_input]
)
# รันแอป
demo.launch(debug=True)