# rrrc / server.py
# Hugging Face Space upload by doniramdani820 (commit 3717eaf, verified).
import io
import base64
import os
import logging
import traceback
import hashlib
from concurrent.futures import ThreadPoolExecutor
from functools import lru_cache
import threading
import gc
import psutil
from flask import Flask, request, jsonify
from flask_cors import CORS
from PIL import Image
from ultralytics import YOLO
import numpy as np
import cv2
import re
import yaml
# ============================================================================
# 🔐 SECURITY CONFIGURATION
# ============================================================================
# NOTE(review): SECRET_KEY is defined but never referenced anywhere else in
# this file — confirm whether it is still needed.
SECRET_KEY = os.getenv('SECRET_KEY', 'your-secret-key-here-change-this')
# Set of accepted API keys. The defaults are insecure placeholders and MUST
# be overridden via environment variables in any real deployment.
API_KEYS = {
    os.getenv('API_KEY_1', 'key1-change-this'),
    os.getenv('API_KEY_2', 'key2-change-this'),
    os.getenv('API_KEY_3', 'key3-change-this')
}
def verify_api_key(request):
    """Return True when the request carries a known API key.

    The key may be supplied either in the ``X-API-Key`` header or in the
    ``api_key`` query parameter.
    NOTE(review): query-string keys tend to end up in access logs —
    consider header-only authentication.
    """
    header_key = request.headers.get('X-API-Key')
    query_key = request.args.get('api_key')
    return (header_key or query_key) in API_KEYS
# ============================================================================
# 🚀 PERFORMANCE OPTIMIZATION CONFIGURATION
# ============================================================================
# Cap worker threads at 4, or CPU count + 1, whichever is smaller.
MAX_WORKERS = min(4, (os.cpu_count() or 1) + 1)
# NOTE(review): this module-level executor is never used below — the /predict
# handler builds its own per-request pool. Confirm whether it can be removed.
executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
MEMORY_THRESHOLD = 80  # RAM usage percentage that triggers memory_cleanup()
MAX_CACHE_SIZE = 50    # maximum entries held in prediction_cache
# Limit BLAS/OpenMP thread fan-out so numeric libraries don't oversubscribe CPUs.
os.environ['OMP_NUM_THREADS'] = '2'
os.environ['OPENBLAS_NUM_THREADS'] = '2'
os.environ['MKL_NUM_THREADS'] = '2'
# ============================================================================
# 📊 MONITORING & LOGGING
# ============================================================================
# Root-logger setup: INFO level with timestamped, levelled messages.
logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s] [%(levelname)s] %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
class PerformanceMonitor:
    """Thread-safe accumulator of request count and total processing time."""

    def __init__(self):
        self.request_count = 0
        self.total_processing_time = 0
        self.lock = threading.Lock()

    def log_request(self, processing_time):
        """Record one completed request and its wall-clock duration (seconds)."""
        with self.lock:
            self.request_count += 1
            self.total_processing_time += processing_time

    def get_stats(self):
        """Return aggregate request metrics plus current host memory/CPU usage."""
        with self.lock:
            mean_time = self.total_processing_time / max(self.request_count, 1)
            return {
                'requests': self.request_count,
                'avg_processing_time': mean_time,
                'memory_usage': psutil.virtual_memory().percent,
                'cpu_usage': psutil.cpu_percent()
            }
monitor = PerformanceMonitor()  # process-wide stats singleton
# ============================================================================
# 🎯 CONFIGURATION
# ============================================================================
app = Flask(__name__)
CORS(app)  # allow cross-origin requests from browser clients
# Per-model confidence thresholds, overridable via environment variables.
CONFIDENCE_THRESHOLD_3X3 = float(os.getenv('CONFIDENCE_3X3', '0.45'))
CONFIDENCE_THRESHOLD_4X4 = float(os.getenv('CONFIDENCE_4X4', '0.25'))
# Minimum percentage of a 4x4 grid cell a mask must cover to count as a hit.
MIN_COVERAGE_PERCENTAGE = int(os.getenv('MIN_COVERAGE', '10'))
# ============================================================================
# πŸ’Ύ CACHING SYSTEM
# ============================================================================
@lru_cache(maxsize=128)
def get_image_hash(image_b64):
    """Return the MD5 hex digest of a base64 image string (memoized)."""
    digest = hashlib.md5(image_b64.encode())
    return digest.hexdigest()
# In-memory prediction cache guarded by a lock; evicts the oldest half
# (dict insertion order) once it reaches MAX_CACHE_SIZE entries.
prediction_cache = {}
cache_lock = threading.Lock()


def cache_prediction(key, result):
    """Store *result* under *key*, evicting the oldest half when full."""
    with cache_lock:
        if len(prediction_cache) >= MAX_CACHE_SIZE:
            stale_keys = list(prediction_cache)[:MAX_CACHE_SIZE // 2]
            for stale in stale_keys:
                del prediction_cache[stale]
        prediction_cache[key] = result


def get_cached_prediction(key):
    """Return the cached result for *key*, or None when absent."""
    with cache_lock:
        return prediction_cache.get(key)
# ============================================================================
# πŸ”€ CLASS ALIASES (OPTIMIZED)
# ============================================================================
# Maps each canonical challenge class name to the set of accepted aliases
# (English and Indonesian spellings). All aliases are stored lowercase and
# are matched against normalized input in find_class_match().
CLASS_ALIASES = {
    'cars': {'car', 'cars', 'mobil', 'mobil-mobil', 'kendaraan', 'auto', 'vehicle', 'vehicles', 'sedan', 'hatchback', 'suv'},
    'buses': {'bus', 'buses', 'bis', 'autobus', 'bus umum', 'transjakarta'},
    'bicycles': {'bicycle', 'bicycles', 'sepeda', 'bike', 'bikes'},
    'motorcycles': {'motorcycle', 'motorcycles', 'sepeda motor', 'motor', 'motorbike'},
    'taxis': {'taxi', 'taxis', 'taksi', 'cab', 'grab', 'gojek'},
    'bridge': {'bridge', 'bridges', 'jembatan', 'flyover', 'overpass'},
    'traffic lights': {'traffic light', 'traffic lights', 'lampu lalu lintas', 'lampu merah'},
    'a fire hydrant': {'fire hydrant', 'hydrant', 'hidran', 'hidran kebakaran'},
    'chimneys': {'chimney', 'chimneys', 'cerobong asap', 'cerobong'},
    'stairs': {'stair', 'stairs', 'tangga', 'steps', 'escalator'},
    'crosswalks': {'crosswalk', 'crosswalks', 'zebra cross', 'penyeberangan'}
}
# ============================================================================
# πŸ” OPTIMIZED TEXT PROCESSING
# ============================================================================
@lru_cache(maxsize=256)
def normalize_text(text):
    """Lowercase *text*, replace punctuation with spaces, collapse whitespace."""
    if not text:
        return ""
    cleaned = re.sub(r'[^\w\s]', ' ', text.lower().strip())
    return re.sub(r'\s+', ' ', cleaned).strip()
@lru_cache(maxsize=256)
def find_class_match(input_text):
    """Resolve *input_text* to a canonical CLASS_ALIASES key, or None.

    Pass 1 looks for an exact (normalized) alias match; pass 2 falls back
    to any alias appearing as a substring of the normalized input.
    """
    if not input_text:
        return None
    query = normalize_text(input_text)
    # Pass 1: exact alias equality.
    for canonical, aliases in CLASS_ALIASES.items():
        if any(query == alias.lower() for alias in aliases):
            return canonical
    # Pass 2: alias contained anywhere in the query.
    for canonical, aliases in CLASS_ALIASES.items():
        if any(alias.lower() in query for alias in aliases):
            return canonical
    return None
# ============================================================================
# πŸ“ MODEL LOADING & INITIALIZATION (STRATEGI: LAZY PER-WORKER)
# ============================================================================
# Per-worker model cache: models are loaded lazily, at most once per process.
_worker_models = {}
_model_lock = threading.Lock()
# Maps model type ('3x3' / '4x4') -> {canonical class name -> class id}.
model_class_maps = {}
def load_yaml_classes(yaml_path):
    """Read a YOLO dataset YAML and return ``{class_id: class_name}``.

    Logs the error and returns an empty dict when the file is missing or
    malformed.
    """
    try:
        with open(yaml_path, 'r', encoding='utf-8') as fh:
            spec = yaml.safe_load(fh)
        return dict(enumerate(spec.get('names', [])))
    except Exception as e:
        logging.error(f"Error loading {yaml_path}: {e}")
        return {}
# Build the canonical-name -> class-id map for each model at import time.
# Any failure here is fatal: the server cannot answer without class maps.
try:
    for model_type, yaml_file in [('3x3', 'data.yaml'), ('4x4', 'data4x4.yaml')]:
        class_map = {}
        yaml_classes = load_yaml_classes(yaml_file)
        for class_id, class_name in yaml_classes.items():
            # Prefer the canonical alias key; fall back to the raw YAML name.
            canonical_match = find_class_match(class_name.lower())
            if canonical_match:
                class_map[canonical_match] = class_id
            else:
                class_map[class_name.lower()] = class_id
        model_class_maps[model_type] = class_map
    logging.info("✅ Class maps loaded successfully!")
except Exception as e:
    logging.error(f"❌ FATAL: Failed to initialize class maps: {e}")
    raise
def get_model(model_type: str):
    """Return the YOLO model for *model_type*, loading it at most once.

    Lazy per-worker initialization with double-checked locking, so each
    Gunicorn worker process loads a model only on first use. Returns None
    for unknown model types or load failures.
    """
    cached = _worker_models.get(model_type)
    if cached is not None:
        return cached
    with _model_lock:
        cached = _worker_models.get(model_type)
        if cached is not None:
            return cached
        logging.info(f"WORKER_INIT: Loading model '{model_type}' for worker PID: {os.getpid()}...")
        known = {
            '3x3': ('best.onnx', 'classify'),
            '4x4': ('best4x4.onnx', 'segment'),
        }
        if model_type not in known:
            logging.error(f"Attempted to load unknown model type: {model_type}")
            return None
        model_path, task_type = known[model_type]
        try:
            model = YOLO(model_path, task=task_type)
        except Exception as e:
            logging.error(f"WORKER_INIT: Failed to load model '{model_path}' for worker PID: {os.getpid()}: {e}")
            return None
        _worker_models[model_type] = model
        logging.info(f"WORKER_INIT: Model '{model_type}' loaded successfully for worker PID: {os.getpid()}.")
        return model
# ============================================================================
# πŸ–ΌοΈ OPTIMIZED IMAGE PROCESSING
# ============================================================================
def decode_image_optimized(base64_string):
    """Decode a base64-encoded image into an RGB PIL image.

    Accepts both data-URI strings ("data:image/png;base64,....") and raw
    base64 payloads. Returns None (after logging) on any decode failure.

    BUG FIX: the original always took split(',')[1], so a raw base64 string
    without a data-URI prefix raised IndexError and was rejected as invalid.
    """
    try:
        parts = base64_string.split(',')
        # Use the payload after the data-URI prefix when present,
        # otherwise treat the whole string as base64.
        payload = parts[1] if len(parts) > 1 else parts[0]
        image_data = base64.b64decode(payload)
        return Image.open(io.BytesIO(image_data)).convert("RGB")
    except Exception as e:
        logging.error(f"Image decode error: {e}")
        return None
def divide_image_into_4x4_grid(image_cv2):
    """Split an image array into 16 tiles (row-major) plus their coordinates.

    Returns ``(tiles, coords)`` where ``coords[i] = (x1, y1, x2, y2)`` in
    full-image pixels. The last row/column absorb any remainder pixels
    left over from the integer division by 4.
    """
    height, width = image_cv2.shape[:2]
    cell_h, cell_w = height // 4, width // 4
    tiles, coords = [], []
    for row in range(4):
        y1 = row * cell_h
        y2 = height if row == 3 else y1 + cell_h
        for col in range(4):
            x1 = col * cell_w
            x2 = width if col == 3 else x1 + cell_w
            tiles.append(image_cv2[y1:y2, x1:x2])
            coords.append((x1, y1, x2, y2))
    return tiles, coords
def is_object_in_grid_cell(mask_contour, grid_coords, min_coverage_percentage=MIN_COVERAGE_PERCENTAGE):
    """Decide whether a segmentation contour covers enough of one grid cell.

    Args:
        mask_contour: integer polygon points in full-image pixel coordinates
            (assumed shape (N, 2), as produced by ``mask.xy[0]`` — TODO confirm).
        grid_coords: (x1, y1, x2, y2) cell bounds in full-image coordinates.
        min_coverage_percentage: minimum % of the cell area the polygon must
            fill for the cell to count as selected.

    Returns:
        Tuple ``(selected, coverage_percentage)``.
    """
    x1, y1, x2, y2 = grid_coords
    grid_width, grid_height = x2 - x1, y2 - y1
    grid_area = grid_width * grid_height
    # Fast rejection: skip cells whose bounds don't overlap the contour's
    # bounding rectangle at all.
    contour_bounds = cv2.boundingRect(mask_contour)
    cb_x, cb_y, cb_w, cb_h = contour_bounds
    if (cb_x > x2 or cb_x + cb_w < x1 or cb_y > y2 or cb_y + cb_h < y1): return False, 0.0
    # Rasterize the contour into a cell-local binary mask to measure overlap.
    grid_mask = np.zeros((grid_height, grid_width), dtype=np.uint8)
    # Shift the contour into cell-local coordinates, then clamp every vertex
    # onto the cell. NOTE(review): clamping out-of-cell vertices to the cell
    # border only approximates the true polygon/cell intersection — confirm
    # the tolerance is acceptable for objects straddling cell boundaries.
    adjusted_contour = mask_contour - [x1, y1]
    clipped_contour = np.clip(adjusted_contour, [0, 0], [grid_width-1, grid_height-1])
    if len(clipped_contour) < 3: return False, 0.0
    cv2.fillPoly(grid_mask, [clipped_contour.astype(np.int32)], 255)
    object_area = np.sum(grid_mask > 0)
    coverage_percentage = (object_area / grid_area) * 100
    return coverage_percentage >= min_coverage_percentage, coverage_percentage
# ============================================================================
# πŸ”§ UTILITY FUNCTIONS
# ============================================================================
def get_target_class_index(input_title, model_type):
    """Map a challenge title to the class index of the given model, or None.

    Tries the canonical alias resolution first, then a direct lookup of the
    normalized title in the model's class map.
    """
    class_map = model_class_maps.get(model_type, {})
    if not input_title or not class_map:
        return None
    canonical = find_class_match(input_title)
    if canonical and canonical in class_map:
        return class_map[canonical]
    return class_map.get(normalize_text(input_title))
def memory_cleanup():
    """Force a garbage-collection pass and warn if RAM usage stays high."""
    gc.collect()
    usage = psutil.virtual_memory().percent
    if usage > MEMORY_THRESHOLD:
        logging.warning(f"High memory usage: {usage}%")
# ============================================================================
# πŸ›‘οΈ MIDDLEWARE
# ============================================================================
@app.before_request
def check_api_key():
    """Reject any request lacking a valid API key; /health and /stats are exempt."""
    if request.endpoint in ['health', 'stats']:
        return None
    if verify_api_key(request):
        return None
    return jsonify({"error": "Invalid or missing API key"}), 401
# ============================================================================
# πŸ“‘ API ENDPOINTS
# ============================================================================
@app.route('/health', methods=['GET'])
def health():
    """Liveness probe: reports worker model-load state and host resource usage."""
    payload = {
        "status": "healthy",
        "models_loaded_in_worker": len(_worker_models),
        "memory_usage": psutil.virtual_memory().percent,
        "cpu_usage": psutil.cpu_percent()
    }
    return jsonify(payload)
@app.route('/stats', methods=['GET'])
def stats():
    """Expose aggregate request metrics from the performance monitor."""
    current = monitor.get_stats()
    return jsonify(current)
@app.route('/predict', methods=['POST'])
def predict():
    """Classify each 3x3 tile image and return the tile indices to click.

    Expects JSON ``{"title": str, "images": [{"index": int, "base64": str}]}``.
    Returns ``{"indices_to_click": [...], "detected_class": ..., "total_detected": ...}``,
    a class-not-found payload, or an error with an appropriate HTTP status.
    """
    import time
    start_time = time.time()
    try:
        data = request.get_json(silent=True)
        if not data:
            return jsonify({"error": "Invalid request body"}), 400
        model = get_model('3x3')
        if not model:
            return jsonify({"error": "3x3 model not loaded"}), 500
        input_title = data.get('title', '')
        target_class_index = get_target_class_index(input_title, '3x3')
        if target_class_index is None:
            return jsonify({
                "indices_to_click": [],
                "message": f"Class '{input_title}' not found",
                "available_classes": list(model_class_maps['3x3'].keys())
            })
        images = data.get('images', [])
        # Cache on (title, images) so repeated identical challenges are free.
        images_hash = hashlib.md5(str(images).encode()).hexdigest()
        cache_key = f"3x3_{input_title}_{images_hash}"
        cached_result = get_cached_prediction(cache_key)
        if cached_result:
            logging.info(f"Cache hit for {cache_key}")
            return jsonify(cached_result)

        def process_image(item):
            # Classify one tile; returns None on any per-tile failure so a
            # single bad image never fails the whole request.
            try:
                image = decode_image_optimized(item['base64'])
                if image is None:
                    return None
                results = model(image, verbose=False)
                if not results:
                    return None
                res = results[0]
                if res.probs is None or res.probs.data is None:
                    return None
                confidence = res.probs.data[target_class_index].item()
                return {
                    'index': item['index'],
                    'confidence': confidence,
                    'selected': confidence >= CONFIDENCE_THRESHOLD_3X3
                }
            except Exception as e:
                logging.error(f"Error processing image {item.get('index', 'N/A')}: {e}", exc_info=False)
                return None

        # BUG FIX: ThreadPoolExecutor raises ValueError when max_workers is 0,
        # so an empty image list previously crashed into the generic 500
        # handler. Guard it and return an empty result instead.
        if images:
            with ThreadPoolExecutor(max_workers=min(len(images), MAX_WORKERS)) as pool:
                results = list(pool.map(process_image, images))
        else:
            results = []
        results_to_click = [r['index'] for r in results if r and r['selected']]
        response = {
            "indices_to_click": results_to_click,
            "detected_class": find_class_match(input_title) or input_title,
            "total_detected": len(results_to_click)
        }
        cache_prediction(cache_key, response)
        monitor.log_request(time.time() - start_time)
        if psutil.virtual_memory().percent > MEMORY_THRESHOLD:
            memory_cleanup()
        return jsonify(response)
    except Exception as e:
        logging.error(f"Error in /predict: {e}", exc_info=True)
        return jsonify({"error": "Internal server error"}), 500
@app.route('/predict_4x4', methods=['POST'])
def predict_4x4():
    """Segment the full 4x4 image and return grid cells covered by the target.

    Expects JSON ``{"title": str, "image_b64": str}``. A cell is selected
    when a matching mask covers at least MIN_COVERAGE_PERCENTAGE of it.
    """
    import time
    start_time = time.time()
    try:
        data = request.get_json(silent=True)
        if not data:
            return jsonify({"error": "Invalid request body"}), 400
        model = get_model('4x4')
        if not model:
            return jsonify({"error": "4x4 model not loaded"}), 500
        input_title = data.get('title', '')
        target_class_index = get_target_class_index(input_title, '4x4')
        if target_class_index is None:
            return jsonify({
                "indices_to_click": [],
                "message": f"Class '{input_title}' not found",
                "available_classes": list(model_class_maps['4x4'].keys())
            })
        # BUG FIX: data['image_b64'] raised KeyError when the field was
        # missing and surfaced as a generic 500; report a 400 instead.
        image_b64 = data.get('image_b64')
        if not image_b64:
            return jsonify({"error": "Missing 'image_b64' field"}), 400
        image_hash = get_image_hash(image_b64)
        cache_key = f"4x4_{input_title}_{image_hash}"
        cached_result = get_cached_prediction(cache_key)
        if cached_result:
            logging.info(f"Cache hit for {cache_key}")
            return jsonify(cached_result)
        image_pil = decode_image_optimized(image_b64)
        if image_pil is None:
            return jsonify({"error": "Invalid image data"}), 400
        image_cv2 = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
        # Only the cell coordinates are used below; the tile crops are not.
        grid_images, grid_coordinates = divide_image_into_4x4_grid(image_cv2)
        results = model(image_cv2, verbose=False)
        indices_to_click = []
        if results and results[0].masks is not None and results[0].boxes is not None:
            for mask, box in zip(results[0].masks, results[0].boxes):
                class_id = int(box.cls.item())
                confidence = box.conf.item()
                if class_id == target_class_index and confidence >= CONFIDENCE_THRESHOLD_4X4:
                    contour = mask.xy[0].astype(np.int32)
                    for grid_idx, grid_coords in enumerate(grid_coordinates):
                        is_selected, coverage = is_object_in_grid_cell(contour, grid_coords)
                        if is_selected and grid_idx not in indices_to_click:
                            indices_to_click.append(grid_idx)
        response = {
            "indices_to_click": sorted(indices_to_click),
            "detected_class": find_class_match(input_title) or input_title,
            "total_detected": len(indices_to_click)
        }
        cache_prediction(cache_key, response)
        monitor.log_request(time.time() - start_time)
        if psutil.virtual_memory().percent > MEMORY_THRESHOLD:
            memory_cleanup()
        return jsonify(response)
    except Exception as e:
        logging.error(f"Error in /predict_4x4: {e}", exc_info=True)
        return jsonify({"error": "Internal server error"}), 500
@app.route('/classes', methods=['GET'])
def get_available_classes():
    """List the class names each model knows plus every accepted alias."""
    payload = {
        "3x3_classes": list(model_class_maps.get('3x3', {}).keys()),
        "4x4_classes": list(model_class_maps.get('4x4', {}).keys()),
        "supported_aliases": {name: list(aliases) for name, aliases in CLASS_ALIASES.items()}
    }
    return jsonify(payload)
# ============================================================================
# 🚀 APPLICATION STARTUP (FOR LOCAL TESTING)
# ============================================================================
# Development entry point only; production is expected to run under a
# multi-process server (see the lazy per-worker loading in get_model).
if __name__ == '__main__':
    logging.info("🚀 Starting Flask server for LOCAL DEVELOPMENT...")
    logging.info(f"📊 Thresholds: 3x3={CONFIDENCE_THRESHOLD_3X3}, 4x4={CONFIDENCE_THRESHOLD_4X4}")
    logging.info(f"🔧 Max Workers: {MAX_WORKERS}")
    logging.info(f"💾 Cache Size: {MAX_CACHE_SIZE}")
    logging.info(f"🔐 API Keys: {len(API_KEYS)} configured")
    # PORT is injected by the hosting platform (e.g. HF Spaces); 7860 default.
    port = int(os.environ.get('PORT', 7860))
    app.run(host='0.0.0.0', port=port, debug=False, threaded=True)