# aibankingai / app.py
import os
from flask import Flask, request, jsonify
from transformers import pipeline
from PIL import Image
import io
import base64
# Import for image generation
from diffusers import AutoPipelineForText2Image
app = Flask(__name__)
# --- Configuration ---
GEMMA_MODEL_ID = "google/gemma-3n-E2B-it"  # Gemma 3n multimodal checkpoint ("gemma-4-E2B-it" is not a published model ID)
IMAGE_GEN_MODEL_ID = "stabilityai/sd-turbo"  # A fast, small Stable Diffusion model for demonstration
MAX_NEW_TOKENS = 200  # Adjust as needed for Gemma response length
IMAGE_SIZE = (512, 512)  # Width/height for generated images (sd-turbo's native resolution)
# Determine device for models
# For a CPU-focused Dockerfile, this will default to CPU (-1 or "cpu")
if os.environ.get("USE_GPU", "false").lower() == "true" and os.getenv("CUDA_VISIBLE_DEVICES", "") != "":
    device = 0  # Use the first GPU
    torch_device_name = "cuda"
else:
    device = -1  # Use CPU
    torch_device_name = "cpu"
# --- Model Loading ---
gemma_pipeline = None
image_gen_pipeline = None
try:
    print(f"Loading Gemma 3n multimodal model: {GEMMA_MODEL_ID} on device {torch_device_name} (pipeline device {device})...")
    # Gemma 3n is served through the "image-text-to-text" pipeline task;
    # "any-to-any" is a Hub tag, not a transformers pipeline task.
    gemma_pipeline = pipeline("image-text-to-text", model=GEMMA_MODEL_ID, device=device)
    print("Gemma 3n model loaded successfully.")
except Exception as e:
    print(f"Error loading Gemma 3n model: {e}")
try:
    print(f"Loading Image Generation model: {IMAGE_GEN_MODEL_ID} on device {torch_device_name}...")
    image_gen_pipeline = AutoPipelineForText2Image.from_pretrained(IMAGE_GEN_MODEL_ID).to(torch_device_name)
    # Only enable xFormers if on GPU
    if torch_device_name == "cuda":
        try:
            # Note: xFormers might require a specific CUDA version or manual installation.
            # If this line causes issues, comment it out.
            image_gen_pipeline.enable_xformers_memory_efficient_attention()  # Optional: for memory efficiency on GPU
            print("xFormers enabled for image generation.")
        except ImportError:
            print("xFormers not installed or not available, skipping memory-efficient attention.")
    print("Image Generation model loaded successfully.")
except Exception as e:
    print(f"Error loading Image Generation model: {e}")
# --- Helper Functions ---
def encode_image_to_base64(image: Image.Image) -> str:
    """Serialize a PIL image to a base64-encoded PNG string."""
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode('utf-8')
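# A mirror-image helper, sketched here for symmetry with the encoder above.
# It is not part of the original file (the /gemma-predict endpoint inlines the
# same decoding), so treat the name and the RGB conversion as illustrative.
def decode_base64_to_image(image_base64: str) -> Image.Image:
    """Deserialize a base64-encoded image string back into a PIL image."""
    image_bytes = base64.b64decode(image_base64)
    return Image.open(io.BytesIO(image_bytes)).convert("RGB")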
# --- API Endpoints ---
@app.route('/')
def home():
    return "Multimodal AI (Gemma 3n) and Image Generation API is running. Use /gemma-predict or /generate-image."
@app.route('/gemma-predict', methods=['POST'])
def gemma_predict():
    """
    Endpoint for Gemma 3n multimodal text generation (image + text -> text).
    """
    if gemma_pipeline is None:
        return jsonify({"error": "Gemma 3n model not loaded. Please check server logs."}), 503
    try:
        # get_json(silent=True) returns None instead of raising on a missing or non-JSON body.
        data = request.get_json(silent=True)
        if not data:
            return jsonify({"error": "No JSON data provided"}), 400
        image_base64 = data.get('image_base64')
        text_prompt = data.get('text_prompt', '')
        if not image_base64 and not text_prompt:
            return jsonify({"error": "At least 'image_base64' or 'text_prompt' must be provided"}), 400
        messages = []
        if image_base64:
            try:
                image_bytes = base64.b64decode(image_base64)
                image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
                messages.append({
                    "type": "image",
                    "image": image,
                })
            except Exception as e:
                return jsonify({"error": f"Invalid image_base64 provided: {e}"}), 400
        if text_prompt:
            messages.append({
                "type": "text",
                "text": text_prompt,
            })
        if not messages:
            return jsonify({"error": "No valid input (image or text) provided for Gemma."}), 400
        full_messages = [
            {
                "role": "user",
                "content": messages,
            }
        ]
        # Chat-style messages are passed via the `text` argument of the pipeline.
        output = gemma_pipeline(text=full_messages, max_new_tokens=MAX_NEW_TOKENS, return_full_text=False)
        if output and len(output) > 0 and "generated_text" in output[0]:
            return jsonify({"prediction": output[0]["generated_text"]})
        else:
            return jsonify({"error": "Gemma 3n model did not return generated text."}), 500
    except Exception as e:
        print(f"Error during Gemma 3n prediction: {e}")
        return jsonify({"error": f"An error occurred during Gemma 3n prediction: {str(e)}"}), 500
@app.route('/generate-image', methods=['POST'])
def generate_image():
    """
    Endpoint for text-to-image generation.
    """
    if image_gen_pipeline is None:
        return jsonify({"error": "Image generation model not loaded. Please check server logs."}), 503
    try:
        data = request.get_json(silent=True)
        if not data:
            return jsonify({"error": "No JSON data provided"}), 400
        prompt = data.get('prompt')
        if not prompt:
            return jsonify({"error": "Missing 'prompt' for image generation."}), 400
        # sd-turbo is distilled for single-step sampling without classifier-free
        # guidance, so its model card recommends num_inference_steps=1 and guidance_scale=0.0.
        generated_image = image_gen_pipeline(
            prompt,
            num_inference_steps=1,
            guidance_scale=0.0,
            width=IMAGE_SIZE[0],
            height=IMAGE_SIZE[1],
        ).images[0]
        # Encode the generated image to base64
        image_base64 = encode_image_to_base64(generated_image)
        return jsonify({"image_base64": image_base64, "prompt": prompt})
    except Exception as e:
        print(f"Error during image generation: {e}")
        return jsonify({"error": f"An error occurred during image generation: {str(e)}"}), 500
@app.route('/status', methods=['GET'])
def status():
    """
    Checks the status of both AI models.
    """
    gemma_status = "ready" if gemma_pipeline else "not_loaded"
    image_gen_status = "ready" if image_gen_pipeline else "not_loaded"
    return jsonify({
        "gemma_model_id": GEMMA_MODEL_ID,
        "gemma_status": gemma_status,
        "image_gen_model_id": IMAGE_GEN_MODEL_ID,
        "image_gen_status": image_gen_status,
        "device_used": torch_device_name
    })
# --- Main Execution ---
if __name__ == '__main__':
    # debug=True is convenient for development; disable it for production deployments.
    app.run(host='0.0.0.0', port=5000, debug=True)
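
# Example client calls, sketched as a comment. It assumes the server is
# reachable at http://localhost:5000 and that a local file "cat.png" exists;
# both are illustrative assumptions, not part of the original app.
#
#   import base64, requests
#
#   with open("cat.png", "rb") as f:
#       img_b64 = base64.b64encode(f.read()).decode("utf-8")
#
#   r = requests.post("http://localhost:5000/gemma-predict",
#                     json={"image_base64": img_b64,
#                           "text_prompt": "Describe this image."})
#   print(r.json())
#
#   r = requests.post("http://localhost:5000/generate-image",
#                     json={"prompt": "a watercolor painting of a bank vault"})
#   print(r.json()["image_base64"][:60], "...")
#
#   print(requests.get("http://localhost:5000/status").json())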