File size: 7,481 Bytes
ce23029
 
98d13ed
ce23029
 
 
 
 
 
 
 
 
 
 
aa4cde5
ce23029
 
f0f6b28
ce23029
 
 
 
eef83fc
 
ce23029
50ca679
 
580925d
ce23029
 
 
50ca679
7a118e2
 
e36e5a2
 
ce23029
 
 
 
 
 
 
 
aa4cde5
eef83fc
b3cab48
 
 
eef83fc
 
6889a60
 
e36e5a2
 
 
 
 
6889a60
 
eef83fc
 
 
 
 
 
 
 
6889a60
eef83fc
 
 
 
 
 
 
bc84733
eef83fc
b3cab48
 
 
 
 
 
 
 
 
 
eef83fc
 
 
 
 
 
 
 
 
 
 
 
b3cab48
eef83fc
b3cab48
 
 
 
 
 
 
eef83fc
 
 
 
 
 
 
 
 
 
 
 
ce23029
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import spaces
import io
import os
import torch
from PIL import Image
import gradio as gr
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig

# Markdown header rendered at the top of the Gradio app (passed to gr.Markdown).
title = """# Welcome to 🌟Tonic's CheXRay⚕⚛!
You can use this ZeroGPU Space to test out the current model [StanfordAIMI/CheXagent-8b](https://huggingface.co/StanfordAIMI/CheXagent-8b). CheXRay⚕⚛ is fine-tuned to analyze chest X-rays, with different and generally better results than other multimodal models.
You can also use CheXRay⚕⚛ by cloning this space. 🧬🔬🔍 Simply click here: <a style="display:inline-block" href="https://huggingface.co/spaces/Tonic/CheXRay?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a>

### How To Use

Upload a medical image and enter a prompt to receive an AI-generated analysis.
Simply upload an image with the right prompt (coming soon!) and analyze your X-ray!

Join us: 🌟TeamTonic🌟 is always making cool demos! Join our active builders' 🛠️ community 👻  [![Join us on Discord](https://img.shields.io/discord/1109943800132010065?label=Discord&logo=discord&style=flat-square)](https://discord.gg/GWpVpekp) On 🤗Huggingface: [TeamTonic](https://huggingface.co/TeamTonic) & [MultiTransformer](https://huggingface.co/MultiTransformer) On 🌐Github: [Tonic-AI](https://github.com/tonic-ai) & contribute to 🌟 [DataTonic](https://github.com/Tonic-AI/DataTonic) 🤗Big thanks to Yuvi Sharma and all the folks at Hugging Face for the community grant 🤗
"""

# Device and precision used for inference: the model and the processed
# inputs are both moved/cast to these inside generate().
device = "cuda"
dtype = torch.float16

# File names of the bundled example chest X-rays; the UI joins them onto the
# directory containing this script.
example_images = ["00000174_003.png", "00006596_000.png", "00006663_000.png",
                  "00012976_002.png", "00018401_000.png", "00019799_000.png"]

# NOTE(review): the title above advertises StanfordAIMI/CheXagent-8b, but the
# processor and generation config are loaded from
# liuhaotian/llava-v1.6-mistral-7b (the CheXagent load below is commented out).
# Confirm which checkpoint is intended and keep all loads consistent.
processor = AutoProcessor.from_pretrained("liuhaotian/llava-v1.6-mistral-7b", trust_remote_code=True)
generation_config = GenerationConfig.from_pretrained("liuhaotian/llava-v1.6-mistral-7b")
# model = AutoModelForCausalLM.from_pretrained("StanfordAIMI/CheXagent-8b", torch_dtype=dtype, trust_remote_code=True)

@spaces.GPU
def generate(image, prompt):
    """Run the vision-language model on a single image/prompt pair.

    Args:
        image: A PIL image, or a binary file-like object (anything with a
            ``read()`` method) holding encoded image bytes.
        prompt: Instruction text inserted into the
            ``" USER: <s>... ASSISTANT: <s>"`` template given to the processor.

    Returns:
        The first sequence returned by ``model.generate``, decoded with the
        processor's tokenizer with special tokens stripped.

    Raises:
        gr.Error: If no image was supplied (e.g. the button was clicked with an
            empty upload box).
    """
    # Fail fast, before paying for a full model load, and with a readable
    # message instead of a traceback from inside the processor.
    if image is None:
        raise gr.Error("Please upload an image before generating.")

    # NOTE(review): the checkpoint is loaded on every call, which is slow for a
    # 7B model; consider loading it once at module level if the ZeroGPU runtime
    # in use supports that.
    model = AutoModelForCausalLM.from_pretrained(
        "liuhaotian/llava-v1.6-mistral-7b", torch_dtype=dtype, trust_remote_code=True
    ).to(device)

    if hasattr(image, "read"):
        # File-like input (e.g. an open file from the custom-prompt handler):
        # decode the raw bytes into a PIL image.
        image = Image.open(io.BytesIO(image.read()))
    if isinstance(image, Image.Image):
        # Normalise every PIL input (not only decoded files) to RGB so that
        # grayscale or RGBA images reach the processor in a consistent mode.
        image = image.convert("RGB")

    inputs = processor(
        images=[image],
        text=f" USER: <s>{prompt} ASSISTANT: <s>",
        return_tensors="pt",
    ).to(device=device, dtype=dtype)
    output = model.generate(**inputs, generation_config=generation_config)[0]
    return processor.tokenizer.decode(output, skip_special_tokens=True)


def custom_generate(image, prompt):
    """Run ``generate`` for the custom-prompt section.

    A string naming an existing file is opened in binary mode and passed to
    ``generate`` as a file object. Any other value (normally the PIL image
    produced by ``gr.Image(type="pil")``) is passed through unchanged.
    """
    if isinstance(image, str) and os.path.exists(image):
        with open(image, "rb") as file:
            return generate(file, prompt)
    return generate(image, prompt)


def analyze_feature(image, feature):
    """Ask the model to describe one anatomical *feature* of the image."""
    return generate(image, f'Describe "{feature}"')


def analyze_abnormality(image, abnormality):
    """Ask the model to analyze the image for one common *abnormality*."""
    return generate(image, f'Analyze for "{abnormality}"')


def _example_rows(second_input):
    """Build ``gr.Examples`` rows pairing each bundled image with *second_input*.

    Image paths are resolved relative to the directory containing this script.
    """
    base_dir = os.path.dirname(__file__)
    return [[os.path.join(base_dir, img), second_input] for img in example_images]


# Choices offered by the two dropdown-driven sections.
anatomies = [
    "Airway", "Breathing", "Cardiac", "Diaphragm",
    "Everything else (e.g., mediastinal contours, bones, soft tissues, tubes, valves, and pacemakers)",
]
common_abnormalities = ["Lung Nodule", "Pleural Effusion", "Pneumonia"]

with gr.Blocks() as demo:
    gr.Markdown(title)

    with gr.Accordion("Custom Prompt Analysis", open=False):
        with gr.Row():
            image_input_custom = gr.Image(type="pil")
            prompt_input_custom = gr.Textbox(label="Enter your custom prompt")
            generate_button_custom = gr.Button("Generate")
            output_text_custom = gr.Textbox(label="Response")
        generate_button_custom.click(
            fn=custom_generate,
            inputs=[image_input_custom, prompt_input_custom],
            outputs=output_text_custom,
        )
        custom_prompt_examples = _example_rows(
            "You are an expert X-Ray Analyst, describe this chest x-ray in detail focussing on the lung condition:"
        )
        with gr.Accordion("Examples", open=False):
            gr.Examples(
                examples=custom_prompt_examples,
                inputs=[image_input_custom, prompt_input_custom],
                outputs=[output_text_custom],
                fn=custom_generate,
                cache_examples=True,
            )

    with gr.Accordion("Anatomical Feature Analysis", open=False):
        with gr.Row():
            image_input_feature = gr.Image(type="pil")
            # Preselect the first choice so an untouched dropdown does not send
            # the prompt 'Describe "None"' to the model.
            prompt_select = gr.Dropdown(
                label="Select an anatomical feature",
                choices=anatomies,
                value=anatomies[0],
            )
        generate_button_feature = gr.Button("Analyze Feature")
        output_text_feature = gr.Textbox(label="Response")
        generate_button_feature.click(
            fn=analyze_feature,
            inputs=[image_input_feature, prompt_select],
            outputs=output_text_feature,
        )
        anatomical_feature_examples = _example_rows("Airway")
        with gr.Accordion("Examples", open=False):
            gr.Examples(
                examples=anatomical_feature_examples,
                inputs=[image_input_feature, prompt_select],
                outputs=[output_text_feature],
                fn=analyze_feature,
                cache_examples=True,
            )

    with gr.Accordion("Common Abnormalities Analysis", open=False):
        with gr.Row():
            image_input_abnormality = gr.Image(type="pil")
            # Preselect the first choice so an untouched dropdown does not send
            # the prompt 'Analyze for "None"' to the model.
            abnormality_select = gr.Dropdown(
                label="Select a common abnormality",
                choices=common_abnormalities,
                value=common_abnormalities[0],
            )
        generate_button_abnormality = gr.Button("Analyze Abnormality")
        output_text_abnormality = gr.Textbox(label="Response")
        generate_button_abnormality.click(
            fn=analyze_abnormality,
            inputs=[image_input_abnormality, abnormality_select],
            outputs=output_text_abnormality,
        )
        common_abnormalities_examples = _example_rows("Lung Nodule")
        with gr.Accordion("Examples", open=False):
            gr.Examples(
                examples=common_abnormalities_examples,
                inputs=[image_input_abnormality, abnormality_select],
                outputs=[output_text_abnormality],
                fn=analyze_abnormality,
                cache_examples=True,
            )

if __name__ == "__main__":
    demo.launch()