# Hugging Face Spaces — status: Paused
import os | |
import torch | |
from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig | |
import gradio as gr | |
from PIL import Image | |
from torchvision.transforms import ToTensor | |
def read_hf_token(env_var="HF_TOKEN"):
    """Return the Hugging Face access token stored in ``env_var``.

    Surrounding whitespace (e.g. a trailing newline pasted into a secret) is
    stripped. Returns ``None`` when the variable is unset or blank, instead
    of raising ``AttributeError`` on ``None.strip()`` at import time; a
    ``None`` token lets the Hub client fall back to its default credential
    handling.
    """
    raw = os.getenv(env_var)
    if raw is None:
        return None
    return raw.strip() or None


# Get API token from environment variable (None when not configured)
api_token = read_hf_token()
# Quantization configuration: 4-bit NF4 weights with double quantization
# enabled, computing in float16 (matching the model's torch_dtype).
_quant_options = {
    "load_in_4bit": True,
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_use_double_quant": True,
    "bnb_4bit_compute_dtype": torch.float16,
}
bnb_config = BitsAndBytesConfig(**_quant_options)
# Initialize model and tokenizer.
# Both are loaded from the same Hub repository; keeping the id in a single
# constant prevents the two loads from silently drifting apart.
MODEL_ID = "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1"

# trust_remote_code=True runs Python code shipped in the model repository
# (presumably where the `chat` method used by analyze_input comes from).
# NOTE(review): attn_implementation="flash_attention_2" requires the
# flash-attn package and a supported GPU at load time — confirm both are
# present in the deployment image.
model = AutoModel.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
    token=api_token,
)
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    token=api_token,
)
def analyze_input(image, question):
    """Run the multimodal model on an optional image plus a text question.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user did not upload one.
        question: The user's prompt text.

    Returns:
        ``{"status": "success", "response": <generated text>}`` on success, or
        ``{"status": "error", "message": <exception text>}`` if anything
        raised (the full traceback is printed to stdout in that case).
    """
    try:
        if image is not None:
            # Normalise the image mode (e.g. RGBA, grayscale) to RGB.
            image = image.convert('RGB')
            content = [image, question]
        else:
            # Text-only query: don't put a None placeholder into the content.
            content = [question]
        # Prepare messages in the format expected by the model
        msgs = [{'role': 'user', 'content': content}]
        # Generate response using the chat method, streaming text chunks
        response_stream = model.chat(
            image=image,
            msgs=msgs,
            tokenizer=tokenizer,
            sampling=True,
            temperature=0.95,
            stream=True
        )
        # Collect the streamed chunks, echoing each to stdout as it arrives.
        chunks = []
        for new_text in response_stream:
            chunks.append(new_text)
            print(new_text, flush=True, end='')
        return {"status": "success", "response": "".join(chunks)}
    except Exception as e:
        # Catch-all UI boundary: log the traceback, return an error payload.
        import traceback
        error_trace = traceback.format_exc()
        print(f"Error occurred: {error_trace}")
        return {"status": "error", "message": str(e)}
# Create Gradio interface.
# The same text serves as both the prefilled value and the placeholder.
_DEFAULT_PROMPT = "Give the modality, organ, analysis, abnormalities (if any), treatment (if abnormalities are present)?"

# Input components, in the positional order of analyze_input(image, question).
_image_input = gr.Image(type="pil", label="Upload Medical Image")
_question_input = gr.Textbox(
    label="Medical Question",
    placeholder=_DEFAULT_PROMPT,
    value=_DEFAULT_PROMPT,
)

demo = gr.Interface(
    fn=analyze_input,
    inputs=[_image_input, _question_input],
    outputs=gr.JSON(label="Analysis"),
    title="Medical Image Analysis Assistant",
    description=(
        "Upload a medical image and ask questions about it. "
        "The AI will analyze the image and provide detailed responses."
    ),
)
# Launch the Gradio app only when executed as a script, not on import.
if __name__ == "__main__":
    launch_options = dict(
        share=True,             # also request a public share link
        server_name="0.0.0.0",  # listen on all network interfaces
        server_port=7860,
    )
    demo.launch(**launch_options)