# Spaces: Sleeping
import os | |
import time | |
import numpy as np | |
from PIL import Image | |
import torchvision.transforms as transforms | |
from pathlib import Path | |
from ultralytics import YOLO | |
import io | |
import base64 | |
import uuid | |
import glob | |
from tensorflow import keras | |
from flask import Flask, jsonify, request, render_template, send_file | |
import torch | |
from collections import Counter | |
import psutil | |
from gradio_client import Client, handle_file | |
from io import BytesIO | |
# Silence TensorFlow's native (C++) log output.
# NOTE(review): this assignment runs after `from tensorflow import keras` in the
# import block, so TensorFlow may already have read the variable; set it before
# that import for it to take full effect.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# How the YOLO weights are obtained: 'local', 'remote_hub_download' or
# 'remote_hub_from_pretrained'. Defaults to 'local'; override with LOAD_TYPE.
load_type = os.environ.get('LOAD_TYPE', 'local')
MODEL_YOLO = "yolo11_detect_best_241024_1.pt"  # checkpoint file name
MODEL_DIR = "./artifacts/models"               # directory holding local checkpoints
YOLO_DIR = "./artifacts/yolo"                  # scratch directory for YOLO crops
# Hugging Face Hub repository that holds MODEL_YOLO. Only the remote load types
# use it (they referenced REPO_ID without it ever being defined); set via REPO_ID.
REPO_ID = os.environ.get('REPO_ID')
# Base URL of the remote CLIP Gradio app. *.gradio.live share links are
# temporary, so the URL can be overridden with the GRADIO_URL env var.
GRADIO_URL = os.environ.get("GRADIO_URL", "https://a0c594662477a008f4.gradio.live/")
# Load the YOLO detector once at import time; the request handlers reuse `model`.
if load_type == 'local':
    # Local checkpoint stored under MODEL_DIR
    model_path = f'{MODEL_DIR}/{MODEL_YOLO}'
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found at {model_path}")
    model = YOLO(model_path)
    print("***** FLASK API---LOAD YOLO MODEL DONE *****")
elif load_type in ('remote_hub_download', 'remote_hub_from_pretrained'):
    # Fetch MODEL_YOLO from the Hugging Face Hub repository REPO_ID.
    from huggingface_hub import hf_hub_download
    if not REPO_ID:
        raise ValueError(f"REPO_ID must be set when load_type is {load_type!r}")
    if load_type == 'remote_hub_download':
        # Download into the default Hugging Face cache.
        model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_YOLO)
    else:
        # huggingface_hub has no top-level `from_pretrained` (the previous
        # import could not succeed); download the same file into MODEL_DIR.
        os.environ['TRANSFORMERS_CACHE'] = str(Path(MODEL_DIR).absolute())
        model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_YOLO, cache_dir=MODEL_DIR)
    # Wrap with YOLO rather than torch.load: the handlers call model(image),
    # model.names and result.save_crop, which are Ultralytics APIs that a raw
    # torch.load'ed checkpoint does not provide.
    model = YOLO(model_path)
else:
    raise ValueError(f"Unknown load_type: {load_type!r}")
def image_to_base64(image_path):
    """Read the file at *image_path* and return its bytes as a Base64 text string.

    Args:
        image_path: Path of the file to encode (opened in binary mode).

    Returns:
        The standard Base64 encoding of the file's bytes, decoded to ``str``.
    """
    with open(image_path, "rb") as handle:
        payload = handle.read()
    return base64.b64encode(payload).decode('utf-8')
def get_jpg_files(path):
    """Return the paths of all ``*.jpg`` files directly inside *path*.

    Args:
        path: Directory to search (not recursive). Glob metacharacters in it
            (``[``, ``*``, ``?``) are treated literally.

    Returns:
        A sorted list of matching file paths; empty when *path* does not
        exist or contains no ``.jpg`` files.

    Example:
        jpg_files = get_jpg_files('/content/drive/MyDrive/chiikawa')
    """
    # Escape the directory so only the file-name part acts as a pattern, and
    # sort because glob's order depends on the filesystem.
    return sorted(glob.glob(os.path.join(glob.escape(path), "*.jpg")))
def clip_model(choice="find_similar_words", image=None, word=None, top_k=3):
    """Call the remote CLIP Gradio app's ``/run_function`` endpoint.

    Args:
        choice: Mode name forwarded to the Gradio function.
        image: Optional local image path (or URL), uploaded via ``handle_file``.
        word: Optional text forwarded to the Gradio function.
        top_k: Number of results requested from the service (default 3).

    Returns:
        The value returned by ``Client.predict`` on success. On any failure --
        including an unreachable ``GRADIO_URL`` or a missing image file -- an
        error message ``str`` is returned instead of raising.
    """
    try:
        # Both the client construction (which contacts GRADIO_URL) and
        # handle_file can fail; keep them inside the try so every failure uses
        # the same error-string path as a failed prediction.
        client = Client(GRADIO_URL)
        image_input = handle_file(image) if image is not None else None
        return client.predict(
            choice=choice,
            image=image_input,
            word=word,
            top_k=top_k,
            api_name="/run_function",
        )
    except Exception as e:
        return f"Error occurred while processing the request: {e}"
def check_memory_usage():
    """Print a snapshot of system-wide virtual memory to stdout.

    Total, available and used sizes come from ``psutil.virtual_memory()``;
    they are converted from bytes to MB (1 MB = 1024 * 1024 bytes) and shown
    with two decimals. The usage percentage is printed as psutil reports it.
    Returns None.
    """
    bytes_per_mb = 1024 * 1024
    snapshot = psutil.virtual_memory()
    total_memory, available_memory, used_memory = (
        amount / bytes_per_mb
        for amount in (snapshot.total, snapshot.available, snapshot.used)
    )
    memory_usage_percent = snapshot.percent
    print(f"^^^^^^ Total Memory: {total_memory:.2f} MB ^^^^^^")
    print(f"^^^^^^ Available Memory: {available_memory:.2f} MB ^^^^^^")
    print(f"^^^^^^ Used Memory: {used_memory:.2f} MB ^^^^^^")
    print(f"^^^^^^ Memory Usage (%): {memory_usage_percent}% ^^^^^^")
# Create the Flask application object.
app = Flask(__name__)

# Print a one-off memory snapshot at import time, after the model has loaded.
check_memory_usage()
# API route for prediction: YOLO detection, then a CLIP description of each crop
@app.route('/predict', methods=['POST'])
def predict():
    """Detect objects in an uploaded image and describe each crop with CLIP.

    Form fields (multipart/form-data):
        image: Image file to analyse (required).
        message_id: Client identifier, echoed back unchanged in the response.
        choice: CLIP mode forwarded to ``clip_model``.
        word: Optional text forwarded to ``clip_model``.

    Returns:
        ``(json, 200)`` with ``{'message_id', 'objects': [...]}``; each object
        has ``element`` (YOLO class name) and ``images`` holding
        ``encoded_image`` (base64 crop) and ``description_list`` (first item of
        the CLIP result). ``(json, 400)`` when the image is missing or
        unreadable or YOLO returns nothing; ``(json, 502)`` when the CLIP
        service call fails.
    """
    import shutil  # only this view needs it

    # Validate the upload *before* indexing request.files; indexing first made
    # this branch unreachable (a missing field raised instead).
    if 'image' not in request.files:
        # Handle if no file is selected
        return jsonify({"error": "No image part"}), 400
    file = request.files['image']
    message_id = request.form.get('message_id')
    choice = request.form.get('choice')
    word = request.form.get('word')

    # Read the image
    try:
        image_data = Image.open(file)
    except Exception as e:
        return jsonify({'error': str(e)}), 400

    print("***** FLASK API---/predict Start YOLO predict *****")
    results = model(image_data)
    print("===== FLASK API---/predict YOLO predict result:", results, "=====")
    print("***** FLASK API---/predict YOLO predict DONE *****")
    check_memory_usage()
    # Check that YOLO returned usable results
    if results is None or len(results) == 0:
        return jsonify({'error': 'No results from YOLO model'}), 400

    # Crops go to a fresh per-request directory instead of YOLO_DIR/<message_id>:
    # message_id is client-supplied, so it may be None, shared by concurrent
    # requests (mixing and deleting each other's crops), or contain "../".
    work_dir = os.path.join(YOLO_DIR, str(uuid.uuid4()))
    # Accumulate across *all* results (previously reset per result, so only the
    # last result reached the response).
    element_list, encoded_images, top_k_words = [], [], []
    try:
        for result in results:
            result.save_crop(work_dir)
            label_names = [model.names[int(label)] for label in result.boxes.cls]
            print(f"====== FLASK API---/predict 3. YOLO label_names: {label_names}======")
            # Counter preserves first-seen order, so each class is visited once.
            for element in Counter(label_names):
                yolo_path = f"{work_dir}/{element}"
                yolo_file = get_jpg_files(yolo_path)
                print(f"***** FLASK API---/predict 處理:{yolo_path} *****")
                if len(yolo_file) == 0:
                    print(f" FLASK API---/predict 警告:{element} 沒有找到相關的 JPG 檔案")
                    continue
                for yolo_img in yolo_file:  # one cropped image per detection
                    print("***** FLASK API---/predict 4. START CLIP *****")
                    clip_result = clip_model(choice, yolo_img, word)
                    if isinstance(clip_result, str):
                        # clip_model signals failure with a message string;
                        # indexing it would return its first character as the
                        # "description", so report the upstream failure.
                        return jsonify({'error': clip_result}), 502
                    top_k_words.append(clip_result[0])  # CLIP top-k predictions
                    encoded_images.append(image_to_base64(yolo_img))
                    element_list.append(element)
                    print(f"===== FLASK API---/predict CLIP RESULT:{top_k_words} =====\n")
                    print(f"===== FLASK API---/predict DELETE yolo_img:{yolo_img} =====\n")
                    # Delete each processed crop so a later result saved into
                    # the same class directory does not re-read it.
                    os.remove(yolo_img)
    finally:
        # Remove the per-request directory, including crops left by an early return.
        shutil.rmtree(work_dir, ignore_errors=True)

    # Build the response payload
    response_data = {
        'message_id': message_id,
        'objects': [
            {
                'element': element,
                'images': {
                    'encoded_image': encoded_image,
                    'description_list': description_list,
                },
            }
            for element, encoded_image, description_list in zip(
                element_list, encoded_images, top_k_words
            )
        ],
    }
    return jsonify(response_data), 200
# API route for text-to-image lookup via the remote CLIP app
@app.route('/text2img', methods=['POST'])
def text2img():
    """Ask the CLIP service for an image matching ``word`` and return it as JSON.

    Form fields: ``message_id`` (echoed back), ``choice`` and ``word``
    (forwarded to ``clip_model`` with no image).

    Returns:
        ``(json, 200)`` with ``message_id``, ``encoded_image`` (item [2] of the
        CLIP result, already base64-encoded by the service) and
        ``description`` (item [0]); ``(json, 502)`` if the CLIP call failed.
    """
    message_id = request.form.get('message_id')
    choice = request.form.get('choice')
    word = request.form.get('word')
    clip_result = clip_model(choice, None, word)
    print(f"===== FLASK API---/text2img 文字轉圖片result:{clip_result} =====")
    if isinstance(clip_result, str):
        # clip_model returns an error message string on failure; indexing it
        # would send single characters back as the image and description.
        return jsonify({'error': clip_result}), 502
    # Build the response payload
    response_data = {
        'message_id': message_id,
        'encoded_image': clip_result[2],  # already base64-encoded
        'description': clip_result[0],
    }
    return jsonify(response_data), 200
# API route for version
@app.route('/version')
def version():
    """Return the application version string.

    Demo usage: ``curl http://127.0.0.1:5000/version``
    """
    return '1.0'
@app.route('/')
def hello_world():
    """Serve the landing page rendered from the ``index.html`` template."""
    return render_template("index.html")
# Start the Flask development server when this file is run as a script.
if __name__ == '__main__':
    # Debug mode stays on by default (as before) but can be disabled with
    # FLASK_DEBUG=0/false/no. The Werkzeug debugger allows arbitrary code
    # execution, so it must never be reachable from untrusted clients.
    debug_mode = os.environ.get('FLASK_DEBUG', '1').strip().lower() not in ('0', 'false', 'no')
    app.run(debug=debug_mode)