sounar committed on
Commit
5be90eb
1 Parent(s): 84feed9

Update app.py

Files changed (1)
  1. app.py +49 -74
app.py CHANGED
@@ -1,91 +1,66 @@
-import os
+from flask import Flask, request, jsonify
+import os
 import torch
-from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
-import gradio as gr
 from PIL import Image
-from torchvision.transforms import ToTensor
+from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
 
 # Get API token from environment variable
 api_token = os.getenv("HF_TOKEN").strip()
 
-# Quantization configuration
+# Model configuration
 bnb_config = BitsAndBytesConfig(
-    load_in_4bit=True,
-    bnb_4bit_quant_type="nf4",
-    bnb_4bit_use_double_quant=True,
-    bnb_4bit_compute_dtype=torch.float16
+    load_in_4bit=True,
+    bnb_4bit_quant_type="nf4",
+    bnb_4bit_use_double_quant=True,
+    bnb_4bit_compute_dtype=torch.float16,
 )
 
-# Initialize model and tokenizer
+# Model and tokenizer loading
 model = AutoModel.from_pretrained(
-    "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1",
-    quantization_config=bnb_config,
-    device_map="auto",
-    torch_dtype=torch.float16,
-    trust_remote_code=True,
+    "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1",
+    quantization_config=bnb_config,
+    device_map="auto",
+    torch_dtype=torch.float16,
+    trust_remote_code=True,
     attn_implementation="flash_attention_2",
-    token=api_token
 )
-
 tokenizer = AutoTokenizer.from_pretrained(
-    "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1",
-    trust_remote_code=True,
-    token=api_token
+    "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1",
+    trust_remote_code=True
 )
 
-def analyze_input(image, question):
-    try:
-        if image is not None:
-            # Convert to RGB if image is provided
-            image = image.convert('RGB')
-
-        # Prepare messages in the format expected by the model
-        msgs = [{'role': 'user', 'content': [image, question]}]
-
-        # Generate response using the chat method
-        response_stream = model.chat(
-            image=image,
-            msgs=msgs,
-            tokenizer=tokenizer,
-            sampling=True,
-            temperature=0.95,
-            stream=True
-        )
-
-        # Collect the streamed response
-        generated_text = ""
-        for new_text in response_stream:
-            generated_text += new_text
-            print(new_text, flush=True, end='')
-
-        return {"status": "success", "response": generated_text}
-
-        except Exception as e:
-        import traceback
-        error_trace = traceback.format_exc()
-        print(f"Error occurred: {error_trace}")
-        return {"status": "error", "message": str(e)}
 
-# Create Gradio interface
-demo = gr.Interface(
-    fn=analyze_input,
-    inputs=[
-        gr.Image(type="pil", label="Upload Medical Image"),
-        gr.Textbox(
-            label="Medical Question",
-            placeholder="Give the modality, organ, analysis, abnormalities (if any), treatment (if abnormalities are present)?",
-            value="Give the modality, organ, analysis, abnormalities (if any), treatment (if abnormalities are present)?"
-        )
-    ],
-    outputs=gr.JSON(label="Analysis"),
-    title="Medical Image Analysis Assistant",
-    description="Upload a medical image and ask questions about it. The AI will analyze the image and provide detailed responses."
-)
 
-# Launch the Gradio app
-if __name__ == "__main__":
-    demo.launch(
-        share=True,
-        server_name="0.0.0.0",
-        server_port=7860
-    )
+app = Flask(__name__)
+
+# HTTP endpoint: accept an uploaded image and a question, return the model's answer
+@app.route('/analyze', methods=['POST'])
+def analyze():
+    image = request.files['image']
+    question = request.form['question']
+
+    # Preprocess image
+    image = Image.open(image).convert('RGB')
+
+    # Prepare input in the message format expected by the model
+    msgs = [{'role': 'user', 'content': [image, question]}]
+
+    # Generate response
+    res = model.chat(
+        image=image,
+        msgs=msgs,
+        tokenizer=tokenizer,
+        sampling=True,
+        temperature=0.95,
+        stream=True
+    )
+
+    # Collect the streamed response
+    generated_text = ""
+    for new_text in res:
+        generated_text += new_text
+
+    return jsonify({'response': generated_text})
+
+if __name__ == '__main__':
+    app.run(debug=True)
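
For reference, a minimal client-side sketch for exercising the new /analyze route. This is not part of the commit: it assumes the server was started with "python app.py" (so Flask's app.run() default of http://127.0.0.1:5000 applies) and uses a placeholder image filename; the question string reuses the default prompt from the removed Gradio textbox.

import requests

# Hypothetical host/port (Flask's app.run() default) and example file name.
URL = "http://127.0.0.1:5000/analyze"

with open("example_scan.png", "rb") as f:  # placeholder input image
    resp = requests.post(
        URL,
        files={"image": f},  # read server-side via request.files['image']
        data={"question": "Give the modality, organ, analysis, abnormalities (if any), treatment (if abnormalities are present)?"},
    )

# The endpoint returns a JSON body of the form {'response': ...}
print(resp.json()["response"])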