Spaces:
Paused
Paused
import os | |
import torch | |
from flask import Flask, request, jsonify | |
from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig | |
from PIL import Image | |
import io | |
import base64 | |
# Hugging Face access token, read from the HF_TOKEN environment variable.
# os.getenv() returns None when the variable is unset, so calling .strip()
# on it directly raised AttributeError at import time. An unset or blank
# token now becomes None, which is the same as not passing a token at all.
api_token = (os.getenv("HF_TOKEN") or "").strip() or None

app = Flask(__name__)
# bitsandbytes quantization settings for the model load:
#   - weights stored in 4 bits using the NF4 data type,
#   - double quantization enabled (the quantization constants are themselves
#     quantized),
#   - computation carried out in float16.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    load_in_4bit=True,
)
# Hub repository id, shared by the model and tokenizer loads below so the two
# cannot drift apart.
MODEL_ID = "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1"

# Load the model quantized per bnb_config, with weights placed automatically
# across the available devices (device_map="auto"). Flash Attention is
# deliberately not requested: no attn_implementation argument is passed.
# trust_remote_code=True executes Python code shipped in the Hub repository.
model = AutoModel.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True,
    token=api_token,
)

# Tokenizer from the same repository (it also relies on remote code).
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    token=api_token,
)
def decode_base64_image(base64_string):
    """Decode base64-encoded image data into an RGB PIL image.

    Accepts either a bare base64 payload or a ``data:<mime>;base64,<payload>``
    data URI. Without stripping the URI prefix, ``b64decode`` would silently
    discard the non-alphabet characters (':', ';', ',') and fold the prefix
    letters into the decoded bytes, producing an unreadable image.

    Args:
        base64_string: Base64 text, as ``str`` or ``bytes``.

    Returns:
        A new ``PIL.Image.Image`` in RGB mode.

    Raises:
        binascii.Error: If the payload is not valid base64 (a ValueError).
        PIL.UnidentifiedImageError: If the decoded bytes are not a
            recognised image format (an OSError).
    """
    if isinstance(base64_string, str) and base64_string.startswith("data:"):
        # Keep only the payload after the first comma; a comma never occurs
        # in standard base64, so this cannot cut into a real payload.
        _, _, base64_string = base64_string.partition(",")
    image_data = base64.b64decode(base64_string)
    # convert() returns an independent image (a copy even when already RGB),
    # so the source image can be closed as soon as it has been converted.
    with Image.open(io.BytesIO(image_data)) as source:
        return source.convert('RGB')
# The original function was never registered with the app, so the service
# exposed no endpoint at all. The path below is chosen by this fix; confirm it
# matches what clients call.
@app.route("/analyze", methods=["POST"])
def analyze_input():
    """Answer a question, optionally about an image, using the loaded model.

    Expects a JSON object body of the form
    ``{"question": <non-empty str>, "image": <base64 str, optional>}``.

    Returns:
        A Flask response tuple:
        ``{"status": "success", "response": <str>}`` with 200 on success;
        ``{"status": "error", "message": <str>}`` with 400 for malformed input;
        the same error shape with 500 if inference raises.
    """
    # get_json(silent=True) yields None instead of raising for a missing or
    # non-JSON body, so malformed requests get a JSON 400 (not an HTML error
    # page and not an AttributeError on None.get).
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({
            'status': 'error',
            'message': 'Request body must be a JSON object'
        }), 400

    question = data.get('question', '')
    base64_image = data.get('image', None)
    if not isinstance(question, str) or not question.strip():
        return jsonify({
            'status': 'error',
            'message': "'question' must be a non-empty string"
        }), 400

    image = None
    if base64_image:
        try:
            image = decode_base64_image(base64_image)
        except (TypeError, ValueError, OSError) as e:
            # Client error, not a server fault:
            #   - binascii.Error (bad base64) is a ValueError,
            #   - PIL's UnidentifiedImageError is an OSError,
            #   - a non-string payload raises TypeError.
            return jsonify({
                'status': 'error',
                'message': f'Invalid image: {e}'
            }), 400

    # The model loads with trust_remote_code, and its model card drives
    # inference through model.chat() with a message list whose content mixes
    # PIL images and text. prepare_inputs_for_generation() is an internal
    # generate() hook, not a preprocessing API, and the plain generate() call
    # did not move inputs to the model's device.
    # NOTE(review): chat() signature taken from the model card / MiniCPM-V 2.5
    # remote code; confirm against the pinned model revision.
    content = [image, question] if image is not None else question
    msgs = [{'role': 'user', 'content': content}]
    try:
        with torch.inference_mode():
            response = model.chat(
                image=image,
                msgs=msgs,
                tokenizer=tokenizer,
                max_new_tokens=256,
                stream=False,
            )
    except Exception as e:
        # Top-level request boundary: log the traceback server-side and return
        # a JSON error instead of letting Flask emit an HTML 500.
        app.logger.exception("Inference failed")
        return jsonify({
            'status': 'error',
            'message': str(e)
        }), 500

    return jsonify({
        'status': 'success',
        'response': response
    })
if __name__ == '__main__':
    # debug=True is unsafe here for two reasons:
    #   - it exposes the interactive Werkzeug debugger, which allows arbitrary
    #     code execution by anyone who can reach the server;
    #   - its reloader re-runs this script in a child process, so the
    #     quantized model was loaded a second time.
    # Debug is therefore opt-in via FLASK_DEBUG, and the reloader is always
    # off. HOST/PORT default to Flask's own defaults (127.0.0.1:5000).
    app.run(
        host=os.getenv("HOST", "127.0.0.1"),
        port=int(os.getenv("PORT", "5000")),
        debug=os.getenv("FLASK_DEBUG", "").strip().lower() in ("1", "true", "yes"),
        use_reloader=False,
    )