# NOTE: "Spaces: Paused" banner below was a Hugging Face Spaces status line
# captured during extraction — it is not part of the program.
import base64
import gc
import hashlib
import io
import logging
import os
import re
import threading
import time
import traceback
from concurrent.futures import ThreadPoolExecutor
from functools import lru_cache

import cv2
import numpy as np
import psutil
import yaml
from flask import Flask, request, jsonify
from flask_cors import CORS
from PIL import Image
from ultralytics import YOLO
# ============================================================================
# SECURITY CONFIGURATION
# ============================================================================
# Secret material is taken from the environment; the literals are placeholder
# fallbacks for local development only.
SECRET_KEY = os.getenv('SECRET_KEY', 'your-secret-key-here-change-this')
API_KEYS = {
    os.getenv('API_KEY_1', 'key1-change-this'),
    os.getenv('API_KEY_2', 'key2-change-this'),
    os.getenv('API_KEY_3', 'key3-change-this')
}


def verify_api_key(request):
    """Check the request's API key against API_KEYS.

    The key is read from the ``X-API-Key`` header first; if absent (or empty),
    the ``api_key`` query parameter is used instead.
    """
    header_key = request.headers.get('X-API-Key')
    query_key = request.args.get('api_key')
    candidate = header_key if header_key else query_key
    return candidate in API_KEYS
# ============================================================================
# PERFORMANCE OPTIMIZATION CONFIGURATION
# ============================================================================
# Shared worker pool for request handlers, capped at 4 threads even on
# many-core hosts.
MAX_WORKERS = min(4, (os.cpu_count() or 1) + 1)
executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
MEMORY_THRESHOLD = 80  # percent system RAM usage that triggers cleanup/warnings
MAX_CACHE_SIZE = 50    # maximum number of entries kept in the prediction cache
# Limit BLAS/OpenMP thread counts so parallel inference doesn't oversubscribe.
# NOTE(review): these env vars are set *after* numpy/cv2 are imported above;
# most BLAS builds read them at import time, so they may have no effect here —
# confirm, or move them before the imports.
os.environ['OMP_NUM_THREADS'] = '2'
os.environ['OPENBLAS_NUM_THREADS'] = '2'
os.environ['MKL_NUM_THREADS'] = '2'
# ============================================================================
# MONITORING & LOGGING
# ============================================================================
# Timestamped INFO-level console logging for the whole process.
logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s] [%(levelname)s] %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
class PerformanceMonitor:
    """Thread-safe accumulator of request count and total processing time."""

    def __init__(self):
        self.request_count = 0
        self.total_processing_time = 0
        self.lock = threading.Lock()

    def log_request(self, processing_time):
        """Record one finished request and how long it took (seconds)."""
        with self.lock:
            self.request_count += 1
            self.total_processing_time += processing_time

    def get_stats(self):
        """Snapshot request totals together with current host memory/CPU usage."""
        with self.lock:
            denominator = self.request_count if self.request_count else 1
            mean_time = self.total_processing_time / denominator
            return {
                'requests': self.request_count,
                'avg_processing_time': mean_time,
                'memory_usage': psutil.virtual_memory().percent,
                'cpu_usage': psutil.cpu_percent(),
            }


# Single process-wide monitor instance used by the endpoints below.
monitor = PerformanceMonitor()
# ============================================================================
# CONFIGURATION
# ============================================================================
# Flask application with CORS enabled for all origins.
app = Flask(__name__)
CORS(app)
# Per-model confidence thresholds and minimum mask coverage, env-overridable.
CONFIDENCE_THRESHOLD_3X3 = float(os.getenv('CONFIDENCE_3X3', '0.45'))
CONFIDENCE_THRESHOLD_4X4 = float(os.getenv('CONFIDENCE_4X4', '0.25'))
MIN_COVERAGE_PERCENTAGE = int(os.getenv('MIN_COVERAGE', '10'))
# ============================================================================
# CACHING SYSTEM
# ============================================================================
def get_image_hash(image_b64):
    """Return the MD5 hex digest of a base64 image string (cache key, not security)."""
    digest = hashlib.md5(image_b64.encode())
    return digest.hexdigest()
# In-process prediction cache guarded by a lock (workers each keep their own).
prediction_cache = {}
cache_lock = threading.Lock()


def cache_prediction(key, result):
    """Store *result* under *key*, evicting the oldest half when the cache is full."""
    with cache_lock:
        if len(prediction_cache) >= MAX_CACHE_SIZE:
            # dicts preserve insertion order, so the leading keys are the oldest
            stale = list(prediction_cache)[:MAX_CACHE_SIZE // 2]
            for k in stale:
                prediction_cache.pop(k)
        prediction_cache[key] = result


def get_cached_prediction(key):
    """Return the cached result for *key*, or None when absent."""
    with cache_lock:
        return prediction_cache.get(key)
# ============================================================================
# CLASS ALIASES (OPTIMIZED)
# ============================================================================
# Canonical model class name -> accepted spellings (English + Indonesian).
CLASS_ALIASES = {
    'cars': {'car', 'cars', 'mobil', 'mobil-mobil', 'kendaraan', 'auto', 'vehicle', 'vehicles', 'sedan', 'hatchback', 'suv'},
    'buses': {'bus', 'buses', 'bis', 'autobus', 'bus umum', 'transjakarta'},
    'bicycles': {'bicycle', 'bicycles', 'sepeda', 'bike', 'bikes'},
    'motorcycles': {'motorcycle', 'motorcycles', 'sepeda motor', 'motor', 'motorbike'},
    'taxis': {'taxi', 'taxis', 'taksi', 'cab', 'grab', 'gojek'},
    'bridge': {'bridge', 'bridges', 'jembatan', 'flyover', 'overpass'},
    'traffic lights': {'traffic light', 'traffic lights', 'lampu lalu lintas', 'lampu merah'},
    'a fire hydrant': {'fire hydrant', 'hydrant', 'hidran', 'hidran kebakaran'},
    'chimneys': {'chimney', 'chimneys', 'cerobong asap', 'cerobong'},
    'stairs': {'stair', 'stairs', 'tangga', 'steps', 'escalator'},
    'crosswalks': {'crosswalk', 'crosswalks', 'zebra cross', 'penyeberangan'}
}

# Flat alias -> canonical lookup built once at import. The original rebuilt a
# lowered alias list for every class on every call; this makes the exact-match
# pass a single O(1) dict probe.
_ALIAS_TO_CANONICAL = {
    alias.lower(): canonical
    for canonical, aliases in CLASS_ALIASES.items()
    for alias in aliases
}


def normalize_text(text):
    """Lowercase *text*, replace punctuation with spaces, collapse whitespace.

    Returns "" for falsy input.
    """
    if not text:
        return ""
    text = text.lower().strip()
    text = re.sub(r'[^\w\s]', ' ', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text


def find_class_match(input_text):
    """Map free-form challenge text to a canonical class name, or None.

    An exact alias match is tried first; otherwise the first canonical class
    (in CLASS_ALIASES declaration order) with an alias appearing as a
    substring of the normalized input wins.
    """
    if not input_text:
        return None
    normalized_input = normalize_text(input_text)
    exact = _ALIAS_TO_CANONICAL.get(normalized_input)
    if exact is not None:
        return exact
    for canonical_name, aliases in CLASS_ALIASES.items():
        if any(alias.lower() in normalized_input for alias in aliases):
            return canonical_name
    return None
# ============================================================================
# MODEL LOADING & INITIALIZATION (STRATEGY: LAZY PER-WORKER)
# ============================================================================
# Per-process model cache: under a multi-process server (e.g. Gunicorn) each
# worker loads its own YOLO models lazily via get_model().
_worker_models = {}
_model_lock = threading.Lock()
# model type ('3x3'/'4x4') -> {class name: class index} built at import time
model_class_maps = {}
def load_yaml_classes(yaml_path):
    """Parse a dataset YAML and return {index: class_name} from its 'names' list.

    Logs and returns an empty dict when the file is missing or malformed.
    """
    try:
        with open(yaml_path, 'r', encoding='utf-8') as handle:
            parsed = yaml.safe_load(handle)
        return dict(enumerate(parsed.get('names', [])))
    except Exception as e:
        logging.error(f"Error loading {yaml_path}: {e}")
        return {}
# Build the class-name -> class-index maps for both models at import time.
# Each class from the dataset YAML is keyed by its canonical alias name when
# one exists, otherwise by its raw (lowercased) name. Any failure here is
# fatal: the service cannot resolve challenge titles without these maps.
try:
    for model_type, yaml_file in [('3x3', 'data.yaml'), ('4x4', 'data4x4.yaml')]:
        class_map = {}
        yaml_classes = load_yaml_classes(yaml_file)
        for class_id, class_name in yaml_classes.items():
            canonical_match = find_class_match(class_name.lower())
            if canonical_match:
                class_map[canonical_match] = class_id
            else:
                class_map[class_name.lower()] = class_id
        model_class_maps[model_type] = class_map
    logging.info("β Class maps loaded successfully!")
except Exception as e:
    logging.error(f"β FATAL: Failed to initialize class maps: {e}")
    raise
def get_model(model_type: str):
    """Return the YOLO model for *model_type* ('3x3' or '4x4'), loading it lazily.

    Each worker process loads a given model at most once (double-checked
    locking around ``_worker_models``). Returns None for an unknown type or
    when loading fails.
    """
    cached = _worker_models.get(model_type)
    if cached is not None:
        return cached
    with _model_lock:
        # Re-check under the lock: another thread may have loaded it meanwhile.
        if model_type in _worker_models:
            return _worker_models[model_type]
        logging.info(f"WORKER_INIT: Loading model '{model_type}' for worker PID: {os.getpid()}...")
        spec = {
            '3x3': ('best.onnx', 'classify'),
            '4x4': ('best4x4.onnx', 'segment'),
        }.get(model_type)
        if spec is None:
            logging.error(f"Attempted to load unknown model type: {model_type}")
            return None
        model_path, task_type = spec
        try:
            model = YOLO(model_path, task=task_type)
        except Exception as e:
            logging.error(f"WORKER_INIT: Failed to load model '{model_path}' for worker PID: {os.getpid()}: {e}")
            return None
        _worker_models[model_type] = model
        logging.info(f"WORKER_INIT: Model '{model_type}' loaded successfully for worker PID: {os.getpid()}.")
        return model
# ============================================================================
# OPTIMIZED IMAGE PROCESSING
# ============================================================================
def decode_image_optimized(base64_string):
    """Decode a base64-encoded image to an RGB PIL Image.

    Accepts both data-URI strings ('data:image/...;base64,<payload>') and bare
    base64 payloads. Returns None (and logs the error) on any decode failure.
    """
    try:
        # Take everything after the first comma; for bare base64 (no data-URI
        # header, hence no comma) this is the whole string. The original
        # indexed split(',')[1] unconditionally, which raised IndexError (and
        # so returned None) for prefix-less input.
        payload = base64_string.split(',', 1)[-1]
        image_data = base64.b64decode(payload)
        return Image.open(io.BytesIO(image_data)).convert("RGB")
    except Exception as e:
        logging.error(f"Image decode error: {e}")
        return None
def divide_image_into_4x4_grid(image_cv2):
    """Split an HxW(xC) image array into 16 tiles plus their pixel boxes.

    Tiles are returned in row-major order together with their
    (x1, y1, x2, y2) coordinates; the last row/column absorbs any remainder
    when the dimensions are not divisible by four.
    """
    height, width = image_cv2.shape[:2]
    cell_h = height // 4
    cell_w = width // 4
    tiles = []
    boxes = []
    for row in range(4):
        top = row * cell_h
        bottom = height if row == 3 else top + cell_h
        for col in range(4):
            left = col * cell_w
            right = width if col == 3 else left + cell_w
            tiles.append(image_cv2[top:bottom, left:right])
            boxes.append((left, top, right, bottom))
    return tiles, boxes
def is_object_in_grid_cell(mask_contour, grid_coords, min_coverage_percentage=MIN_COVERAGE_PERCENTAGE):
    """Decide whether a segmentation contour covers enough of one grid cell.

    Args:
        mask_contour: integer (N, 2) array of (x, y) contour points in
            full-image coordinates — assumes ultralytics ``mask.xy`` layout,
            TODO confirm.
        grid_coords: (x1, y1, x2, y2) cell bounds in full-image coordinates.
        min_coverage_percentage: minimum percentage of the cell's area the
            object must cover to count as present.

    Returns:
        Tuple ``(selected, coverage_percentage)``.
    """
    x1, y1, x2, y2 = grid_coords
    grid_width, grid_height = x2 - x1, y2 - y1
    grid_area = grid_width * grid_height
    # Cheap reject: skip rasterization when the contour's bounding box does
    # not intersect the cell at all.
    contour_bounds = cv2.boundingRect(mask_contour)
    cb_x, cb_y, cb_w, cb_h = contour_bounds
    if (cb_x > x2 or cb_x + cb_w < x1 or cb_y > y2 or cb_y + cb_h < y1): return False, 0.0
    # Rasterize the contour into a cell-local binary mask and measure area.
    grid_mask = np.zeros((grid_height, grid_width), dtype=np.uint8)
    adjusted_contour = mask_contour - [x1, y1]  # shift to cell-local coordinates
    # NOTE(review): clipping projects outside points onto the cell border,
    # which approximates (and can overestimate) the true clipped polygon.
    clipped_contour = np.clip(adjusted_contour, [0, 0], [grid_width-1, grid_height-1])
    if len(clipped_contour) < 3: return False, 0.0
    cv2.fillPoly(grid_mask, [clipped_contour.astype(np.int32)], 255)
    object_area = np.sum(grid_mask > 0)
    coverage_percentage = (object_area / grid_area) * 100
    return coverage_percentage >= min_coverage_percentage, coverage_percentage
# ============================================================================
# UTILITY FUNCTIONS
# ============================================================================
def get_target_class_index(input_title, model_type):
    """Resolve a challenge title to a class index for the given model type.

    Tries the canonical alias match first, then a normalized-text lookup.
    Returns None when the title or the model's class map is missing/unknown.
    """
    class_map = model_class_maps.get(model_type, {})
    if not input_title or not class_map:
        return None
    canonical = find_class_match(input_title)
    if canonical and canonical in class_map:
        return class_map[canonical]
    return class_map.get(normalize_text(input_title))
def memory_cleanup():
    """Force a GC pass; warn when system memory usage stays above the threshold."""
    gc.collect()
    current_memory = psutil.virtual_memory().percent
    if current_memory <= MEMORY_THRESHOLD:
        return
    logging.warning(f"High memory usage: {current_memory}%")
# ============================================================================
# MIDDLEWARE
# ============================================================================
def check_api_key():
    """Reject any request lacking a valid API key; health/stats stay public.

    NOTE(review): this reads like a Flask before-request hook, but no
    ``@app.before_request`` decorator is visible in this file (it may have
    been lost during extraction). Without it, this function is never invoked
    — confirm registration.
    """
    if request.endpoint in ['health', 'stats']: return
    if not verify_api_key(request):
        return jsonify({"error": "Invalid or missing API key"}), 401
# ============================================================================
# API ENDPOINTS
# ============================================================================
def health():
    """Liveness endpoint: report worker model count and host memory/CPU usage.

    NOTE(review): no ``@app.route`` decorator is visible on the endpoint
    functions in this file — possibly stripped during extraction; confirm
    routes are actually registered.
    """
    payload = {
        "status": "healthy",
        "models_loaded_in_worker": len(_worker_models),
        "memory_usage": psutil.virtual_memory().percent,
        "cpu_usage": psutil.cpu_percent(),
    }
    return jsonify(payload)
def stats():
    """Return aggregate request statistics from the global PerformanceMonitor."""
    current_stats = monitor.get_stats()
    return jsonify(current_stats)
def predict():
    """3x3 challenge endpoint: classify each tile and return indices to click.

    Expects JSON ``{"title": str, "images": [{"index": int, "base64": str}]}``.
    Responses are cached per (title, images) pair. NOTE(review): no
    ``@app.route`` decorator is visible (log text suggests the route is
    /predict) — confirm routing was not lost in extraction.
    """
    start_time = time.time()
    try:
        data = request.get_json(silent=True)
        if not data:
            return jsonify({"error": "Invalid request body"}), 400
        model = get_model('3x3')
        if not model:
            return jsonify({"error": "3x3 model not loaded"}), 500
        input_title = data.get('title', '')
        target_class_index = get_target_class_index(input_title, '3x3')
        if target_class_index is None:
            return jsonify({
                "indices_to_click": [],
                "message": f"Class '{input_title}' not found",
                "available_classes": list(model_class_maps['3x3'].keys())
            })
        images = data.get('images', [])
        # BUG FIX: the original passed min(len(images), MAX_WORKERS) straight
        # to ThreadPoolExecutor; an empty image list made max_workers 0, which
        # raises ValueError and turned a benign request into a 500.
        if not images:
            return jsonify({
                "indices_to_click": [],
                "detected_class": find_class_match(input_title) or input_title,
                "total_detected": 0
            })
        images_hash = hashlib.md5(str(images).encode()).hexdigest()
        cache_key = f"3x3_{input_title}_{images_hash}"
        cached_result = get_cached_prediction(cache_key)
        if cached_result:
            logging.info(f"Cache hit for {cache_key}")
            return jsonify(cached_result)

        def process_image(item):
            """Classify one tile; return index/confidence dict or None on failure."""
            try:
                image = decode_image_optimized(item['base64'])
                if image is None:
                    return None
                results = model(image, verbose=False)
                if not results:
                    return None
                res = results[0]
                if res.probs is None or res.probs.data is None:
                    return None
                confidence = res.probs.data[target_class_index].item()
                return {'index': item['index'], 'confidence': confidence,
                        'selected': confidence >= CONFIDENCE_THRESHOLD_3X3}
            except Exception as e:
                logging.error(f"Error processing image {item.get('index', 'N/A')}: {e}", exc_info=False)
                return None

        with ThreadPoolExecutor(max_workers=min(len(images), MAX_WORKERS)) as pool:
            results = list(pool.map(process_image, images))
        results_to_click = [r['index'] for r in results if r and r['selected']]
        response = {
            "indices_to_click": results_to_click,
            "detected_class": find_class_match(input_title) or input_title,
            "total_detected": len(results_to_click)
        }
        cache_prediction(cache_key, response)
        monitor.log_request(time.time() - start_time)
        if psutil.virtual_memory().percent > MEMORY_THRESHOLD:
            memory_cleanup()
        return jsonify(response)
    except Exception as e:
        logging.error(f"Error in /predict: {e}", exc_info=True)
        return jsonify({"error": "Internal server error"}), 500
def predict_4x4():
    """4x4 challenge endpoint: segment the image and map masks to grid cells.

    Expects JSON ``{"title": str, "image_b64": str}``. A mask contributes a
    grid index when its class matches the title, its confidence clears
    CONFIDENCE_THRESHOLD_4X4, and it covers enough of the cell. NOTE(review):
    no ``@app.route`` decorator is visible (log text suggests /predict_4x4) —
    confirm routing was not lost in extraction.
    """
    start_time = time.time()
    try:
        data = request.get_json(silent=True)
        if not data:
            return jsonify({"error": "Invalid request body"}), 400
        model = get_model('4x4')
        if not model:
            return jsonify({"error": "4x4 model not loaded"}), 500
        input_title = data.get('title', '')
        target_class_index = get_target_class_index(input_title, '4x4')
        if target_class_index is None:
            return jsonify({"indices_to_click": [], "message": f"Class '{input_title}' not found", "available_classes": list(model_class_maps['4x4'].keys())})
        # BUG FIX: the original indexed data['image_b64'] directly; a missing
        # key raised KeyError and surfaced as a 500 instead of a 400.
        image_b64 = data.get('image_b64')
        if not image_b64:
            return jsonify({"error": "Missing 'image_b64' field"}), 400
        image_hash = get_image_hash(image_b64)
        cache_key = f"4x4_{input_title}_{image_hash}"
        cached_result = get_cached_prediction(cache_key)
        if cached_result:
            logging.info(f"Cache hit for {cache_key}")
            return jsonify(cached_result)
        image_pil = decode_image_optimized(image_b64)
        if image_pil is None:
            return jsonify({"error": "Invalid image data"}), 400
        image_cv2 = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
        grid_images, grid_coordinates = divide_image_into_4x4_grid(image_cv2)
        results = model(image_cv2, verbose=False)
        indices_to_click = []
        if results and results[0].masks is not None and results[0].boxes is not None:
            for mask, box in zip(results[0].masks, results[0].boxes):
                class_id = int(box.cls.item())
                confidence = box.conf.item()
                if class_id == target_class_index and confidence >= CONFIDENCE_THRESHOLD_4X4:
                    contour = mask.xy[0].astype(np.int32)
                    # Mark every cell this mask sufficiently covers (dedup'd).
                    for grid_idx, grid_coords in enumerate(grid_coordinates):
                        is_selected, coverage = is_object_in_grid_cell(contour, grid_coords)
                        if is_selected and grid_idx not in indices_to_click:
                            indices_to_click.append(grid_idx)
        response = {"indices_to_click": sorted(indices_to_click), "detected_class": find_class_match(input_title) or input_title, "total_detected": len(indices_to_click)}
        cache_prediction(cache_key, response)
        monitor.log_request(time.time() - start_time)
        if psutil.virtual_memory().percent > MEMORY_THRESHOLD:
            memory_cleanup()
        return jsonify(response)
    except Exception as e:
        logging.error(f"Error in /predict_4x4: {e}", exc_info=True)
        return jsonify({"error": "Internal server error"}), 500
def get_available_classes():
    """List the class names each model knows, plus every supported alias."""
    classes_3x3 = list(model_class_maps.get('3x3', {}).keys())
    classes_4x4 = list(model_class_maps.get('4x4', {}).keys())
    aliases = {canonical: list(variants) for canonical, variants in CLASS_ALIASES.items()}
    return jsonify({
        "3x3_classes": classes_3x3,
        "4x4_classes": classes_4x4,
        "supported_aliases": aliases
    })
# ============================================================================
# APPLICATION STARTUP (FOR LOCAL TESTING)
# ============================================================================
# Local development entry point only; in production this app is expected to be
# served by a WSGI server (see the per-worker lazy model loading above).
# NOTE: the garbled leading characters in the log strings below are mojibake
# from an earlier encoding round-trip (originally emoji); left untouched here
# since they are runtime strings.
if __name__ == '__main__':
    logging.info("π Starting Flask server for LOCAL DEVELOPMENT...")
    logging.info(f"π Thresholds: 3x3={CONFIDENCE_THRESHOLD_3X3}, 4x4={CONFIDENCE_THRESHOLD_4X4}")
    logging.info(f"π§ Max Workers: {MAX_WORKERS}")
    logging.info(f"πΎ Cache Size: {MAX_CACHE_SIZE}")
    logging.info(f"π API Keys: {len(API_KEYS)} configured")
    # PORT is provided by the hosting platform; 7860 is the common default.
    port = int(os.environ.get('PORT', 7860))
    app.run(host='0.0.0.0', port=port, debug=False, threaded=True)