CheXRay / app.py
nisten's picture
Update app.py
50ca679 verified
raw
history blame
7.48 kB
import spaces
import io
import os
import torch
from PIL import Image
import gradio as gr
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
# Markdown header rendered at the top of the Gradio app (see `gr.Markdown(title)`).
# NOTE(review): this text names StanfordAIMI/CheXagent-8b; keep it in sync with
# the checkpoint the processor/model are actually loaded from.
title = """# Welcome to 🌟Tonic's CheXRay⚕⚛ !
You can use this ZeroGPU Space to test out the current model [StanfordAIMI/CheXagent-8b](https://huggingface.co/StanfordAIMI/CheXagent-8b). CheXRay⚕⚛ is fine-tuned to analyze chest X-rays, with different and generally better results than other multimodal models.
You can also use CheXRay⚕⚛ by cloning this space. 🧬🔬🔍 Simply click here: <a style="display:inline-block" href="https://huggingface.co/spaces/Tonic/CheXRay?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a>
### How to use
Upload a medical image and enter a prompt to receive an AI-generated analysis.
Simply upload an image with the right prompt (coming soon!) and analyze your X-ray!
Join us: 🌟TeamTonic🌟 is always making cool demos! Join our active builder's 🛠️ community 👻 [![Join us on Discord](https://img.shields.io/discord/1109943800132010065?label=Discord&logo=discord&style=flat-square)](https://discord.gg/GWpVpekp) On 🤗Hugging Face: [TeamTonic](https://huggingface.co/TeamTonic) & [MultiTransformer](https://huggingface.co/MultiTransformer) On 🌐GitHub: [Tonic-AI](https://github.com/tonic-ai) & contribute to 🌟 [DataTonic](https://github.com/Tonic-AI/DataTonic) 🤗Big thanks to Yuvi Sharma and all the folks at Hugging Face for the community grant 🤗
"""
# Inference device and precision. The device is hard-coded to CUDA.
device = "cuda"
dtype = torch.float16

# Sample X-ray file names used by the example galleries in the UI; they are
# resolved relative to this script's directory.
example_images = ["00000174_003.png", "00006596_000.png", "00006663_000.png",
                  "00012976_002.png", "00018401_000.png", "00019799_000.png"]

# One repository id shared by the processor, generation config and model so the
# three can never drift apart.
# NOTE(review): the page title advertises StanfordAIMI/CheXagent-8b, and the
# prompt template used in `generate` appears to follow that model's chat format,
# but this loads LLaVA-1.6-Mistral. Confirm which checkpoint is intended. The
# previous CheXagent setup was:
#   AutoModelForCausalLM.from_pretrained("StanfordAIMI/CheXagent-8b", torch_dtype=dtype, trust_remote_code=True)
model_id = "liuhaotian/llava-v1.6-mistral-7b"

processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
generation_config = GenerationConfig.from_pretrained(model_id)
# Load the weights once at start-up rather than inside `generate`, which used to
# re-run `from_pretrained` on every request. Loading at module level and calling
# `.to("cuda")` there is the pattern documented for ZeroGPU Spaces.
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=dtype, trust_remote_code=True
).to(device)


@spaces.GPU
def generate(image, prompt):
    """Run the vision-language model on one image and return its text reply.

    Args:
        image: a PIL image, or a binary file-like object (anything with a
            ``read()`` method) containing encoded image bytes.
        prompt: the instruction inserted into the chat template.

    Returns:
        The decoded generation for the single input, with special tokens
        stripped by the tokenizer.

    Raises:
        gr.Error: if no image was provided (e.g. the button was clicked before
            uploading), so the UI shows a readable message instead of a
            processor traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    if hasattr(image, "read"):
        # Read the whole stream up front so decoding does not depend on the
        # caller keeping its file handle open.
        image = Image.open(io.BytesIO(image.read()))
    if isinstance(image, Image.Image):
        # Normalize every PIL input to 3-channel RGB (previously only the
        # file-object path did this).
        image = image.convert("RGB")

    inputs = processor(
        images=[image],
        text=f" USER: <s>{prompt} ASSISTANT: <s>",
        return_tensors="pt",
    ).to(device=device, dtype=dtype)
    output_ids = model.generate(**inputs, generation_config=generation_config)[0]
    return processor.tokenizer.decode(output_ids, skip_special_tokens=True)
def custom_generate(image, prompt):
    """Answer a free-form prompt about an X-ray.

    If ``image`` is a string naming an existing file, that file is opened in
    binary mode and handed to `generate` as a file object. Any other value
    (normally the PIL image from the upload widget) is passed through unchanged.
    """
    if isinstance(image, str) and os.path.exists(image):
        with open(image, "rb") as file:
            return generate(file, prompt)
    return generate(image, prompt)


def analyze_feature(image, feature):
    """Ask the model to describe one anatomical feature of the X-ray."""
    if not feature:
        # Without this guard an unselected dropdown yields the prompt 'Describe "None"'.
        raise gr.Error("Please select an anatomical feature.")
    return generate(image, f'Describe "{feature}"')


def analyze_abnormality(image, abnormality):
    """Ask the model to assess the X-ray for one common abnormality."""
    if not abnormality:
        # Without this guard an unselected dropdown yields 'Analyze for "None"'.
        raise gr.Error("Please select a common abnormality.")
    return generate(image, f'Analyze for "{abnormality}"')


# Absolute directory of this script. Using abspath keeps the example paths valid
# even if the process is started from a different working directory.
examples_dir = os.path.dirname(os.path.abspath(__file__))


def example_rows(second_input):
    """Pair each bundled example image path with the same second input value."""
    return [[os.path.join(examples_dir, img), second_input] for img in example_images]


anatomies = [
    "Airway", "Breathing", "Cardiac", "Diaphragm",
    "Everything else (e.g., mediastinal contours, bones, soft tissues, tubes, valves, and pacemakers)"
]
common_abnormalities = ["Lung Nodule", "Pleural Effusion", "Pneumonia"]

custom_prompt_examples = example_rows(
    "You are an expert X-Ray Analyst, describe this chest x-ray in detail focussing on the lung condition:"
)
anatomical_feature_examples = example_rows("Airway")
common_abnormalities_examples = example_rows("Lung Nodule")

with gr.Blocks() as demo:
    gr.Markdown(title)

    with gr.Accordion("Custom Prompt Analysis", open=False):
        with gr.Row():
            image_input_custom = gr.Image(type="pil")
            prompt_input_custom = gr.Textbox(label="Enter your custom prompt")
        generate_button_custom = gr.Button("Generate")
        output_text_custom = gr.Textbox(label="Response")
        generate_button_custom.click(
            fn=custom_generate,
            inputs=[image_input_custom, prompt_input_custom],
            outputs=output_text_custom,
        )
        with gr.Accordion("Examples", open=False):
            gr.Examples(
                examples=custom_prompt_examples,
                inputs=[image_input_custom, prompt_input_custom],
                outputs=[output_text_custom],
                fn=custom_generate,
                cache_examples=True,
            )

    with gr.Accordion("Anatomical Feature Analysis", open=False):
        with gr.Row():
            image_input_feature = gr.Image(type="pil")
            prompt_select = gr.Dropdown(label="Select an anatomical feature", choices=anatomies)
        generate_button_feature = gr.Button("Analyze Feature")
        output_text_feature = gr.Textbox(label="Response")
        # The button and the cached examples share one callback, so they always
        # build the same prompt.
        generate_button_feature.click(
            fn=analyze_feature,
            inputs=[image_input_feature, prompt_select],
            outputs=output_text_feature,
        )
        with gr.Accordion("Examples", open=False):
            gr.Examples(
                examples=anatomical_feature_examples,
                inputs=[image_input_feature, prompt_select],
                outputs=[output_text_feature],
                fn=analyze_feature,
                cache_examples=True,
            )

    with gr.Accordion("Common Abnormalities Analysis", open=False):
        with gr.Row():
            image_input_abnormality = gr.Image(type="pil")
            abnormality_select = gr.Dropdown(label="Select a common abnormality", choices=common_abnormalities)
        generate_button_abnormality = gr.Button("Analyze Abnormality")
        output_text_abnormality = gr.Textbox(label="Response")
        generate_button_abnormality.click(
            fn=analyze_abnormality,
            inputs=[image_input_abnormality, abnormality_select],
            outputs=output_text_abnormality,
        )
        with gr.Accordion("Examples", open=False):
            gr.Examples(
                examples=common_abnormalities_examples,
                inputs=[image_input_abnormality, abnormality_select],
                outputs=[output_text_abnormality],
                fn=analyze_abnormality,
                cache_examples=True,
            )

demo.launch()