CheXRay / app.py
Tonic's picture
Update app.py
580925d verified
raw history blame
No virus
5.85 kB
# `spaces` (ZeroGPU) is kept first: it must be imported before CUDA is
# initialised.
import spaces

import functools
import io
import os

import gradio as gr
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
# Markdown rendered at the top of the demo page (passed to gr.Markdown).
title = """# Welcome to 🌟Tonic's CheXRay⚕⚛ !
You can use this ZeroGPU Space to test out the current model [StanfordAIMI/CheXagent-8b](https://huggingface.co/StanfordAIMI/CheXagent-8b). CheXRay⚕⚛ is fine-tuned to analyze chest X-rays, with different and generally better results than other multimodal models.
You can also use CheXRay⚕⚛ by cloning this space. 🧬🔬🔍 Simply click here: <a style="display:inline-block" href="https://huggingface.co/spaces/Tonic/CheXRay?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a>
### How To Use
Upload a medical image and enter a prompt to receive an AI-generated analysis.
Simply upload an image with the right prompt (coming soon!) and analyze your X-ray!
Join us: 🌟TeamTonic🌟 is always making cool demos! Join our active builder's 🛠️ community 👻 [![Join us on Discord](https://img.shields.io/discord/1109943800132010065?label=Discord&logo=discord&style=flat-square)](https://discord.gg/GWpVpekp) On 🤗Hugging Face: [TeamTonic](https://huggingface.co/TeamTonic) & [MultiTransformer](https://huggingface.co/MultiTransformer) On 🌐GitHub: [Tonic-AI](https://github.com/tonic-ai) & contribute to 🌟 [DataTonic](https://github.com/Tonic-AI/DataTonic) 🤗Big thanks to Yuvi Sharma and all the folks at Hugging Face for the community grant 🤗
"""
# All inference runs on the GPU in half precision.
device, dtype = "cuda", torch.float16

# The processor (text + image preprocessing) and the default generation
# settings come from the same checkpoint and are loaded once, at import time.
# The model weights themselves are loaded inside `generate`.
_CHECKPOINT = "StanfordAIMI/CheXagent-8b"
processor = AutoProcessor.from_pretrained(_CHECKPOINT, trust_remote_code=True)
generation_config = GenerationConfig.from_pretrained(_CHECKPOINT)
@functools.lru_cache(maxsize=1)
def _load_model():
    """Load the CheXagent checkpoint in ``dtype`` and move it to ``device``.

    Cached so that repeated requests served by the same process do not re-read
    the checkpoint from disk on every call.
    """
    model = AutoModelForCausalLM.from_pretrained(
        "StanfordAIMI/CheXagent-8b", torch_dtype=dtype, trust_remote_code=True
    )
    return model.to(device)


@spaces.GPU
def _run_model(image, prompt):
    """Run the model on one already-decoded image; executes under ``spaces.GPU``."""
    model = _load_model()
    # The prompt is wrapped in the USER/ASSISTANT format this app queries with.
    inputs = processor(
        images=[image],
        text=f" USER: <s>{prompt} ASSISTANT: <s>",
        return_tensors="pt",
    ).to(device=device, dtype=dtype)
    output = model.generate(**inputs, generation_config=generation_config)[0]
    return processor.tokenizer.decode(output, skip_special_tokens=True)


def generate(image, prompt):
    """Analyze ``image`` with ``prompt`` and return the model's text answer.

    Args:
        image: A PIL image, or a binary file-like object (anything with a
            ``read()`` method) holding encoded image bytes.
        prompt: Instruction or question for the model.

    Returns:
        The decoded generation, with special tokens removed.

    Raises:
        gr.Error: If ``image`` is ``None`` (e.g. a button was clicked before an
            image was uploaded), so the user sees a clear message.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    if hasattr(image, "read"):
        image = Image.open(io.BytesIO(image.read()))
    if isinstance(image, Image.Image):
        # Grayscale ("L") or RGBA uploads are normalized to the same RGB mode
        # that file-object inputs are converted to.
        image = image.convert("RGB")
    # Decoding/validation happens before the GPU-decorated call, so
    # `spaces.GPU` is only invoked for usable input.
    return _run_model(image, prompt)
def custom_generate(image, prompt):
    """Analyze an image with a free-form prompt.

    Accepts either what ``gr.Image(type="pil")`` delivers or a filesystem path
    string (presumably the example image path supplied via ``gr.Examples``).
    A path is decoded to an RGB PIL image here, so no open file handle has to
    be passed into ``generate``.
    """
    if isinstance(image, str) and os.path.exists(image):
        with Image.open(image) as opened:
            image = opened.convert("RGB")
    return generate(image, prompt)


with gr.Blocks() as demo:
    gr.Markdown(title)

    # Free-form prompt against an uploaded image, plus one cached example.
    with gr.Accordion("Custom Prompt Analysis"):
        with gr.Row():
            image_input_custom = gr.Image(type="pil")
            prompt_input_custom = gr.Textbox(label="Enter your custom prompt")
        generate_button_custom = gr.Button("Generate")
        output_text_custom = gr.Textbox(label="Response")
        generate_button_custom.click(
            fn=custom_generate,
            inputs=[image_input_custom, prompt_input_custom],
            outputs=output_text_custom,
        )

        example_prompt = "65 y/m Chronic cough and weight loss x 6 months. Chest X-rays normal. Consulted multiple pulmonologists with not much benefit. One wise pulmonologist thinks of GERD and sends him to the Gastro department. Can you name the classical finding here?"
        example_image_path = os.path.join(os.path.dirname(__file__), "hegde.jpg")
        gr.Examples(
            examples=[[example_image_path, example_prompt]],
            inputs=[image_input_custom, prompt_input_custom],
            outputs=[output_text_custom],
            fn=custom_generate,
            # Gradio precomputes and stores the example's output.
            cache_examples=True,
        )

    # Fixed prompt template: Describe "<feature>".
    with gr.Accordion("Anatomical Feature Analysis"):
        anatomies = [
            "Airway", "Breathing", "Cardiac", "Diaphragm",
            "Everything else (e.g., mediastinal contours, bones, soft tissues, tubes, valves, and pacemakers)",
        ]
        with gr.Row():
            image_input_feature = gr.Image(type="pil")
            # Pre-select the first choice: with no selection the dropdown
            # yields None and the prompt would become 'Describe "None"'.
            prompt_select = gr.Dropdown(
                label="Select an anatomical feature",
                choices=anatomies,
                value=anatomies[0],
            )
        generate_button_feature = gr.Button("Analyze Feature")
        output_text_feature = gr.Textbox(label="Response")
        generate_button_feature.click(
            fn=lambda image, feature: generate(image, f'Describe "{feature}"'),
            inputs=[image_input_feature, prompt_select],
            outputs=output_text_feature,
        )

    # Fixed prompt template: Analyze for "<abnormality>".
    with gr.Accordion("Common Abnormalities Analysis"):
        common_abnormalities = ["Lung Nodule", "Pleural Effusion", "Pneumonia"]
        with gr.Row():
            image_input_abnormality = gr.Image(type="pil")
            # Pre-select the first choice so the prompt never contains "None".
            abnormality_select = gr.Dropdown(
                label="Select a common abnormality",
                choices=common_abnormalities,
                value=common_abnormalities[0],
            )
        generate_button_abnormality = gr.Button("Analyze Abnormality")
        output_text_abnormality = gr.Textbox(label="Response")
        generate_button_abnormality.click(
            fn=lambda image, abnormality: generate(image, f'Analyze for "{abnormality}"'),
            inputs=[image_input_abnormality, abnormality_select],
            outputs=output_text_abnormality,
        )
# Only start the server when this file is run as a script (e.g.
# `python app.py`), not when the module is merely imported.
if __name__ == "__main__":
    demo.launch()