"""FastAPI image proxy that swaps out images in which a woman is detected.

``GET /?url=...`` downloads the image at ``url``, runs YOLOv8 detection on it,
classifies every detected *person* crop with a gender classifier, and returns
a fixed "safe" replacement image instead of the original when any crop is
classified as "female" with probability above ``FEMALE_THRESHOLD``.
SVG images are passed through untouched.
"""

import base64
import logging
import traceback  # noqa: F401  (kept for compatibility; logging is used instead)
from io import BytesIO

import requests
import torch
from fastapi import FastAPI, Response
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image, ImageDraw  # noqa: F401  (ImageDraw currently unused)
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
from ultralyticsplus import YOLO

logger = logging.getLogger(__name__)

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# --- Gender classifier -------------------------------------------------------
extractor = AutoFeatureExtractor.from_pretrained("rizvandwiki/gender-classification")
model_gender = AutoModelForImageClassification.from_pretrained(
    "rizvandwiki/gender-classification"
)
model_gender = model_gender.to(device)

# Minimum softmax probability of the "female" class that triggers replacement.
FEMALE_THRESHOLD = 0.79

# --- Replacement ("safe") image ----------------------------------------------
# NOTE(review): placeholder -- must contain a real base64-encoded image. With an
# empty string, Image.open() below raises and the module fails to import.
safe_img_base64 = ""
safe_img_bytes = BytesIO(base64.b64decode(safe_img_base64))
safe_img = Image.open(safe_img_bytes)
# Serve the safe image under its real MIME type; fall back to the old default.
safe_img_media_type = Image.MIME.get(safe_img.format, "image/jpeg")

# --- Object detector -----------------------------------------------------------
model_yolo = YOLO("kadirnar/yolov8n-v8.0")
image_size = 640
YOLO_CONF = 0.25  # detection confidence threshold
YOLO_IOU = 0.45  # NMS IoU threshold
# Plain attributes kept for any external reader. predict() does not read them,
# so yolov8() passes the thresholds to predict() explicitly.
model_yolo.conf = YOLO_CONF
model_yolo.iou = YOLO_IOU

# (connect, read) timeout in seconds for fetching the source image, so a slow
# upstream cannot hang a worker forever.
REQUEST_TIMEOUT = (5, 30)

app = FastAPI()

origins = ["*"]
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site make credentialed requests; tighten if the service is not fully public.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def yolov8(img, class_names=None):
    """Run YOLOv8 object detection on ``img``.

    Args:
        img: Image passed to ``YOLO.predict`` as ``source`` (this module passes
            a ``PIL.Image``).
        class_names: Optional collection of detector labels (e.g.
            ``{"person"}``). When given, only detections whose label is in it
            are returned. ``None`` (the default) returns every detection.

    Returns:
        A list of ``[[x1, y1, x2, y2], score]`` entries: integer pixel
        coordinates (truncated) and a float confidence.
    """
    results = model_yolo.predict(
        source=img, imgsz=image_size, conf=YOLO_CONF, iou=YOLO_IOU
    )
    object_prediction_list = []
    for image_results in results:
        for box in image_results.boxes:
            if class_names is not None:
                label = image_results.names[int(box.cls)]
                if label not in class_names:
                    continue
            x1, y1, x2, y2 = (int(coord) for coord in box.xyxy[0])
            object_prediction_list.append([[x1, y1, x2, y2], float(box.conf)])
    return object_prediction_list


def _classify_gender(crop):
    """Return ``(label, probability)`` of the classifier's top class for ``crop``."""
    inputs = extractor(crop, return_tensors="pt").to(device)
    with torch.no_grad():
        logits = model_gender(**inputs).logits
    probs = logits.softmax(-1)
    predicted_label = probs.argmax(-1).item()
    label = model_gender.config.id2label[predicted_label]
    return label, probs[0][predicted_label].item()


def _contains_female(img):
    """Return True if any detected person in ``img`` is classified as female."""
    # The classifier expects 3-channel input; PNG/GIF sources may be RGBA/P/L.
    rgb = img.convert("RGB")
    for (left, top, right, bottom), _score in yolov8(rgb, class_names={"person"}):
        if right <= left or bottom <= top:
            # Zero-area box after int truncation: nothing to classify, and an
            # empty crop would make the extractor raise.
            continue
        label, prob = _classify_gender(rgb.crop((left, top, right, bottom)))
        if label == "female" and prob > FEMALE_THRESHOLD:
            return True
    return False


@app.get(
    "/",
    responses={200: {"content": {"image/png": {}}}},
    response_class=Response,
)
def main(url: str):
    """Fetch the image at ``url`` and return it, or the safe image if a woman is found.

    Returns 502 with an empty body when the image cannot be fetched. SVGs are
    returned unchanged. If analysis of a fetched image fails, the original
    bytes are returned (fail-open, as before).
    """
    # SECURITY NOTE(review): fetches an arbitrary caller-supplied URL (SSRF
    # risk); consider restricting schemes/hosts if exposed publicly.
    try:
        response = requests.get(url, timeout=REQUEST_TIMEOUT)
        response.raise_for_status()
    except requests.RequestException:
        logger.exception("Failed to fetch image from %s", url)
        return Response(status_code=502)

    content_type = response.headers.get("Content-Type", "")
    if ".svg" in url or "image/svg+xml" in content_type:
        return Response(content=response.content, media_type="image/svg+xml")

    img_bytes = response.content
    media_type = "image/jpeg"
    try:
        img = Image.open(BytesIO(img_bytes))
        media_type = Image.MIME.get(img.format, media_type)
        if _contains_female(img):
            return Response(
                content=safe_img_bytes.getvalue(), media_type=safe_img_media_type
            )
    except Exception:
        # Top-level boundary: log and fail open with the original image.
        logger.exception("Failed to analyse image from %s", url)
    return Response(content=img_bytes, media_type=media_type)