"""Gradio demo for M4CXR: medical report generation from chest X-ray studies."""
import argparse
| import gradio as gr |
| import torch |
| from transformers import AutoModelForCausalLM, AutoProcessor, TextIteratorStreamer |
| from threading import Thread |
|
|
| findings = "enlarged cardiomediastinum, cardiomegaly, lung opacity, lung lesion, edema, consolidation, pneumonia, atelectasis, pneumothorax, pleural Effusion, pleural other, fracture, support devices" |
|
|
# Prompt templates per task; "<image>" placeholders are expanded to one per image.
templates = {
| "single-image": ( |
| "radiology image: <image> Which of the following findings are present in the radiology image? Findings: {findings}", |
| "Based on the previous conversation, provide a description of the findings in the radiology image.", |
| ), |
| "multi-image": ( |
| "radiology images: {images} Which of the following findings are present in the radiology images? Findings: {findings}", |
| "Based on the previous conversation, provide a description of the findings in the radiology images.", |
| ), |
| "multi-study": ( |
| "prior radiology images: {prior_images}, prior radiology report: {prior_report} follow-up images: {images}, The radiology studies are given in chronological order. Which of the following findings are present in the current follow-up radiology images? Findings: {findings}", |
| "Based on the previous conversation, provide a description of the findings in the current follow-up radiology images.", |
| ), |
| "visual-grounding": "Provide the bounding box coordinate of the region this phrase describes: {phrase}", |
| "easy-language": "Explain the description with easy language.", |
| "summarize": "Summarize the description in one concise sentence.", |
| "recommend": "What further diagnosis and treatment do you recommend based on the given x-ray?", |
| } |
|
|
| title_markdown = """ |
| **Usage Instructions**: |
| 1. Add chest x-ray images of a study to the "Study images" section. |
| 2. (Optional) Add "Prior study images" and "Prior study report". |
| 3. Click the "Medical Report Generation" button. |
| 4. You can also have additional conversations. Please refer to the "Examples" for guidance. |
| |
| **Notice**: Enabling "do_sample" in the "Parameters" may introduce some randomness to the output. |
| """ |
|
|
|
|
| def load_model(device, dtype): |
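    """Load the M4CXR processor and model from the Hugging Face Hub.

    The checkpoint ships custom modeling/processing code, so ``trust_remote_code=True``
    is required. Example call (assuming a CUDA device is available):

        processor, model = load_model("cuda", torch.bfloat16)
    """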
| |
| processor = AutoProcessor.from_pretrained("Deepnoid/M4CXR-TNNLS", trust_remote_code=True) |
| model = AutoModelForCausalLM.from_pretrained( |
| "Deepnoid/M4CXR-TNNLS", |
| trust_remote_code=True, |
| torch_dtype=dtype, |
| device_map=device, |
| ) |
| return processor, model |
|
|
|
|
| def medical_report_generation(history, *args): |
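    """Run the two-stage report workflow on a fresh chat session.

    Stage 1 asks which of the predefined findings are present in the study;
    stage 2 asks for a free-text description of those findings. Both turns are
    appended to the chat history, which is returned for both the chatbot display
    and its state.
    """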
| ( |
| study_images, |
| do_sample, |
| temperature, |
| top_k, |
| top_p, |
| length_penalty, |
| num_beams, |
| no_repeat_ngram_size, |
| max_new_tokens, |
| prior_images, |
| prior_report, |
| ) = args |
| if history: |
| raise gr.Error('Please "Clear" the chat history or reload this page.') |
|
|
| if not study_images: |
| raise gr.Error('Please add "Study images".') |
|
|
    # Gallery items are (image, caption) pairs; keep only the images.
    images = [i[0] for i in study_images]
|
|
    # With a prior study, use the multi-study template and prepend the prior images.
    if prior_images:
| images = [i[0] for i in prior_images] + images |
| prior_image_tokens = " ".join("<image>" for _ in prior_images) |
| follow_up_image_tokens = " ".join("<image>" for _ in study_images) |
| questions = list(templates["multi-study"]) |
| questions[0] = questions[0].format( |
| prior_images=prior_image_tokens, |
| prior_report=prior_report, |
| images=follow_up_image_tokens, |
| findings=findings, |
| ) |
    else:
        # No prior study: choose the single- or multi-image template.
        if len(images) == 1:
| questions = list(templates["single-image"]) |
| questions[0] = questions[0].format(findings=findings) |
| else: |
| image_tokens = " ".join("<image>" for _ in images) |
| questions = list(templates["multi-image"]) |
| questions[0] = questions[0].format(images=image_tokens, findings=findings) |
|
|
    # Stage 1: findings-classification prompt.
    generator = predict(
| questions[0], |
| history, |
| study_images, |
| do_sample, |
| temperature, |
| top_k, |
| top_p, |
| length_penalty, |
| num_beams, |
| no_repeat_ngram_size, |
| max_new_tokens, |
| prior_images, |
| prior_report, |
| ) |
    # Drain the streaming generator; keep only the final, complete response.
    for output in generator:
        response = output
|
|
| history.append([questions[0], response]) |
    # Stage 2: report-description prompt, conditioned on the stage-1 answer.
    generator = predict(
| questions[1], |
| history, |
| study_images, |
| do_sample, |
| temperature, |
| top_k, |
| top_p, |
| length_penalty, |
| num_beams, |
| no_repeat_ngram_size, |
| max_new_tokens, |
| prior_images, |
| prior_report, |
| ) |
    # Drain the second generator and record the final report description.
    for output in generator:
        response = output
    history.append([questions[1], response])
|
|
| return history, history |
|
|
|
|
| def predict(message, history, *args): |
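    """Stream a single model response to ``message`` given the chat history.

    The prior turns and the new message are rendered with the processor's chat
    template, study (and optional prior) images are attached, and generation runs
    in a background thread while partial outputs are yielded from the streamer.
    """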
| ( |
| study_images, |
| do_sample, |
| temperature, |
| top_k, |
| top_p, |
| length_penalty, |
| num_beams, |
| no_repeat_ngram_size, |
| max_new_tokens, |
| prior_images, |
| prior_report, |
| ) = args |
|
|
| |
    # Rebuild the conversation: prior turns from history plus the new user message.
    chats = []
|
|
| for question, answer in history: |
| chats.append({"role": "user", "content": question}) |
| chats.append({"role": "assistant", "content": answer}) |
|
|
| chats.append({"role": "user", "content": message}) |
|
|
    # Render the conversation into a single prompt string with the model's chat template.
    prompt = processor.apply_chat_template(chats, tokenize=False)
    prompts = [prompt]
|
|
    # Collect PIL images from the galleries; prior-study images (if any) come first.
    if study_images:
        images = [i[0] for i in study_images]
        if prior_images:
            images = [i[0] for i in prior_images] + images
    else:
        images = None
|
|
| |
    # Tokenize the prompt and preprocess the images into model inputs.
    inputs = processor(texts=prompts, images=images)
|
|
| |
    # Cast floating-point tensors (e.g. pixel values) to the model dtype and move
    # everything to the model device; integer tensors (e.g. input ids) keep their dtype.
    inputs = {
        k: v.to(model.dtype) if v.is_floating_point() else v for k, v in inputs.items()
    }
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
|
|
    # Stream decoded tokens as they are produced, skipping the prompt and special tokens.
    streamer = TextIteratorStreamer(
        processor.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
|
|
| generate_kwargs = dict( |
| inputs, |
| streamer=streamer, |
| max_new_tokens=max_new_tokens, |
| do_sample=do_sample, |
| top_p=top_p, |
| top_k=top_k, |
| temperature=temperature, |
| num_beams=num_beams, |
| length_penalty=length_penalty, |
| no_repeat_ngram_size=no_repeat_ngram_size, |
| ) |
    # Run generation in a background thread so tokens can be consumed from the streamer.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
|
|
| partial_message = "" |
| for new_token in streamer: |
| partial_message += new_token |
| yield partial_message |
|
|
|
|
| def build_demo(model_name: str = "M4CXR"): |
    title_model_name = f"""<h1 align="center">{model_name}</h1>"""
|
|
| with gr.Blocks(title=model_name) as demo: |
| state = gr.State() |
|
|
| gr.Markdown(title_model_name) |
| gr.Markdown(title_markdown) |
|
|
| with gr.Row(): |
| with gr.Column(scale=3): |
|
|
| mrg = gr.Button(value="Medical Report Generation", variant="primary") |
|
|
| with gr.Row(visible=True) as button_row: |
| prior_images = gr.Gallery(label="Prior study images", type="pil") |
| study_images = gr.Gallery(label="Study images", type="pil") |
| prior_report = gr.Textbox(label="Prior study report") |
|
|
| with gr.Accordion( |
| "Parameters", open=False, visible=True |
| ) as generate_config: |
| do_sample = gr.Checkbox( |
| interactive=True, value=False, label="do_sample" |
| ) |
| |
| temperature = gr.Slider( |
| 0, 1, 1, step=0.1, interactive=True, label="Temperature" |
| ) |
| top_k = gr.Slider(1, 5, 3, step=1, interactive=True, label="Top K") |
| top_p = gr.Slider( |
| 0, 1, 0.9, step=0.1, interactive=True, label="Top p" |
| ) |
| length_penalty = gr.Slider( |
| 1, 5, 1, step=0.1, interactive=True, label="length_penalty" |
| ) |
| num_beams = gr.Slider( |
| 1, 5, 1, step=1, interactive=True, label="Beam Size" |
| ) |
| no_repeat_ngram_size = gr.Slider( |
| 1, 5, 2, step=1, interactive=True, label="no_repeat_ngram_size" |
| ) |
| max_new_tokens = gr.Slider( |
| 0, |
| 1024, |
| 512, |
| step=64, |
| interactive=True, |
| label="Max New tokens", |
| ) |
|
|
| with gr.Column(scale=6): |
|
|
| chat_interface = gr.ChatInterface( |
| fn=predict, |
| additional_inputs=[ |
| study_images, |
| do_sample, |
| temperature, |
| top_k, |
| top_p, |
| length_penalty, |
| num_beams, |
| no_repeat_ngram_size, |
| max_new_tokens, |
| prior_images, |
| prior_report, |
| ], |
| examples=[ |
| [templates["summarize"]], |
| [templates["easy-language"]], |
| [templates["recommend"]], |
| [templates["visual-grounding"]], |
| ], |
| ) |
|
|
| |
| mrg.click( |
| medical_report_generation, |
| inputs=[ |
| chat_interface.chatbot_state, |
| study_images, |
| do_sample, |
| temperature, |
| top_k, |
| top_p, |
| length_penalty, |
| num_beams, |
| no_repeat_ngram_size, |
| max_new_tokens, |
| prior_images, |
| prior_report, |
| ], |
| outputs=[ |
| chat_interface.chatbot, |
| chat_interface.chatbot_state, |
| ], |
| ) |
|
|
| return demo |
|
|
|
|
| if __name__ == "__main__": |
| parser = argparse.ArgumentParser() |
| parser.add_argument("--host", type=str, default="0.0.0.0") |
| parser.add_argument("--debug", action="store_true", help="using debug mode") |
| parser.add_argument("--port", type=int) |
| parser.add_argument("--share", action="store_true", help="share") |
| parser.add_argument("--dtype", type=str, default="torch.bfloat16") |
| args = parser.parse_args() |
|
|
| device = torch.device("cuda") |
| dtype = eval(args.dtype) |
| processor, model = load_model(device, dtype) |
|
|
| demo = build_demo("M4CXR") |
| demo.queue(status_update_rate=10, api_open=False).launch( |
| server_name=args.host, debug=args.debug, server_port=args.port, share=args.share |
| ) |
|
|