import base64
import io
import os

import torch
from flask import Flask, request, jsonify
from PIL import Image
from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig

# Read the Hugging Face API token from the environment; fail early if it is missing
api_token = os.getenv("HF_TOKEN")
if api_token is None:
    raise RuntimeError("HF_TOKEN environment variable is not set")
api_token = api_token.strip()

app = Flask(__name__)

# 4-bit NF4 quantization configuration so the 8B model fits on a single GPU
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

# Load model without Flash Attention
model = AutoModel.from_pretrained(
    "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1",
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True,
    token=api_token,
)
tokenizer = AutoTokenizer.from_pretrained(
    "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1",
    trust_remote_code=True,
    token=api_token,
)


def decode_base64_image(base64_string):
    """Decode a base64-encoded image into an RGB PIL image."""
    image_data = base64.b64decode(base64_string)
    return Image.open(io.BytesIO(image_data)).convert("RGB")


@app.route("/analyze", methods=["POST"])
def analyze_input():
    data = request.get_json()
    question = data.get("question", "")
    base64_image = data.get("image", None)

    try:
        if base64_image:
            # Multimodal path. The original call to
            # model.prepare_inputs_for_generation(..., images=[image]) is not a
            # valid generation entry point; this sketch assumes the model's
            # trust_remote_code implementation exposes a chat() helper that
            # takes a PIL image and a message list, as shown on the model card.
            image = decode_base64_image(base64_image)
            msgs = [{"role": "user", "content": question}]
            response = model.chat(
                image=image,
                msgs=msgs,
                tokenizer=tokenizer,
                max_new_tokens=256,
            )
        else:
            # Text-only path: tokenize on the model's device and decode the
            # generated continuation
            inputs = tokenizer(question, return_tensors="pt").to(model.device)
            outputs = model.generate(**inputs, max_new_tokens=256)
            response = tokenizer.decode(outputs[0], skip_special_tokens=True)

        return jsonify({"status": "success", "response": response})
    except Exception as e:
        return jsonify({"status": "error", "message": str(e)}), 500


if __name__ == "__main__":
    app.run(debug=True)
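# --- Example client (sketch) ---
# A minimal way to exercise the /analyze endpoint above, assuming the server is
# running locally on Flask's default port 5000. The file name "scan.png" and the
# question text are hypothetical placeholders; requires the `requests` package.
#
#   import base64
#   import requests
#
#   with open("scan.png", "rb") as f:  # hypothetical local image file
#       img_b64 = base64.b64encode(f.read()).decode("utf-8")
#
#   resp = requests.post(
#       "http://127.0.0.1:5000/analyze",
#       json={"question": "Describe any abnormality in this image.", "image": img_b64},
#   )
#   print(resp.json())
#
# Omitting the "image" key exercises the text-only path instead.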