Spaces:
Sleeping
Sleeping
| import torch | |
| import torchvision | |
| from torchvision.models.detection import FasterRCNN | |
| from torchvision.models.detection.faster_rcnn import FastRCNNPredictor | |
| from torchvision.transforms import functional as F | |
| from PIL import Image, ImageDraw, ImageFont | |
| import gradio as gr | |
# Inference is pinned to the CPU; no CUDA availability check is performed.
device = torch.device("cpu")
# Detector label id -> display name. Id 0 is the background class (the
# torchvision detection convention); the order of this tuple must match the
# label ids the model emits, since predict() looks names up by id.
COCO_CLASSES = dict(enumerate(("Background", "Stand", "Sit", "Ruku", "Sijdah")))
# Model construction
def get_model(num_classes):
    """Build a Faster R-CNN (ResNet-50 FPN) whose box head predicts ``num_classes`` classes.

    The stock torchvision detector is created, then only its final box
    predictor is swapped for a ``FastRCNNPredictor`` with the requested class
    count. In this file ``num_classes`` includes the background class.

    Args:
        num_classes: Number of output classes for the box predictor.

    Returns:
        The (untrained-head) ``torch.nn.Module`` detector.
    """
    # NOTE(review): on torchvision >= 0.13, ``pretrained=False`` is a deprecated
    # alias for ``weights=None``. The backbone's own default weights may still be
    # fetched; confirm. Do not switch to ``weights_backbone=None`` or
    # ``pretrained_backbone=False`` without checking that it does not change the
    # backbone's normalization layers and break loading the saved checkpoint.
    base = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False)
    heads = base.roi_heads
    # Reuse the feature width the stock predictor consumed; only the class count changes.
    feature_dim = heads.box_predictor.cls_score.in_features
    heads.box_predictor = FastRCNNPredictor(feature_dim, num_classes)
    return base
# Build the architecture with one output class per label-map entry (background
# included), so the head size cannot drift from COCO_CLASSES.
model = get_model(num_classes=len(COCO_CLASSES))
# NOTE(review): torch.load unpickles the checkpoint, so only load files you
# trust. On torch >= 1.13, passing weights_only=True restricts unpickling to
# tensors; it is not added here because the deployed torch version is unknown.
model.load_state_dict(torch.load("Salatfasterrcnn_resnet50_epoch_3.pth", map_location=device))
model.to(device)
# Inference mode for layers such as dropout/batch-norm and for the detector's
# post-processing path.
model.eval()
# Prediction function
def predict(image, score_threshold=0.5):
    """Run the detector on *image* and return an annotated copy.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without providing an image.
        score_threshold: Only detections with a score strictly greater than
            this value are drawn. Defaults to 0.5 (the previous fixed value).

    Returns:
        An RGB copy of *image* with a red box and a ``"<class> (<score>)"``
        caption per kept detection, or ``None`` if *image* is ``None``.
    """
    # Gradio passes None for an empty input; without this guard the
    # .convert() call below raises AttributeError.
    if image is None:
        return None

    # convert() returns a converted copy, so drawing never mutates the caller's image.
    image = image.convert("RGB")
    image_tensor = F.to_tensor(image).unsqueeze(0).to(device)
    with torch.no_grad():
        # A batch of one image goes in; take that image's result dict.
        prediction = model(image_tensor)[0]

    # Plain Python numbers: int labels for the dict lookup, floats for PIL.
    boxes = prediction["boxes"].cpu().tolist()
    labels = prediction["labels"].cpu().tolist()
    scores = prediction["scores"].cpu().tolist()

    draw = ImageDraw.Draw(image)
    for box, label, score in zip(boxes, labels, scores):
        if score <= score_threshold:
            continue
        x_min, y_min, x_max, y_max = box
        class_name = COCO_CLASSES.get(label, "Unknown")
        draw.rectangle([x_min, y_min, x_max, y_max], outline="red", width=3)
        draw.text((x_min, y_min), f"{class_name} ({score:.2f})", fill="red")
    return image
# Gradio UI: a single PIL image in, the annotated PIL image out.
image_input = gr.Image(type="pil")
image_output = gr.Image(type="pil")
demo = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=image_output,
    title="Salat Posture Detection",
    description="Upload an image to detect salat postures (stand, sit, ruku, sijdah).",
)
# NOTE(review): this starts the server whenever the module is executed or
# imported; consider an `if __name__ == "__main__":` guard if it is ever imported.
demo.launch()