# GyroScope / benchmark.py
# Author: LH-Tech-AI — "Create benchmark.py" (commit e699431, verified)
print("[*] Setting up...")

# Standard library
import random
from collections import Counter
from io import BytesIO

# Third-party
import numpy as np
import requests
import torch
from PIL import Image
from torchvision import transforms
from transformers import ResNetForImageClassification

# --- 1. CONFIGURATION & SETUP ---
ANGLES = [0, 90, 180, 270]                              # rotation classes (degrees)
NUM_IMAGES = 500                                        # benchmark sample size
MODEL_NAME = "LH-Tech-AI/GyroScope"
IMG_SOURCE_URL = "https://loremflickr.com/400/400/all"  # random-image provider

# Prefer GPU when one is visible to PyTorch.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[*] Using device: {device}")

# Load the rotation classifier and put it into inference mode on the device.
print(f"[*] Loading model {MODEL_NAME}...")
model = ResNetForImageClassification.from_pretrained(MODEL_NAME)
model.eval()
model.to(device)
# Preprocessing: standard ImageNet resize/crop/normalize pipeline.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# One record per evaluated image: {"true": angle, "pred": angle, "correct": bool}.
results = []
# --- 2. EVALUATION LOOP ---
# Download each image into RAM, rotate it by a random angle from ANGLES,
# and check whether the model recovers that rotation class.
print(f"[*] Starting download and evaluation of {NUM_IMAGES} images (In-Memory)...")
for i in range(1, NUM_IMAGES + 1):
    try:
        # Load image into RAM; the ?random=i query defeats URL-level caching.
        response = requests.get(f"{IMG_SOURCE_URL}?random={i}", timeout=10)
        # Fix: surface HTTP errors (4xx/5xx) explicitly instead of letting
        # PIL fail with a cryptic decode error on an HTML error page.
        response.raise_for_status()
        img = Image.open(BytesIO(response.content)).convert("RGB")

        # Pick the ground-truth rotation; label index = position in ANGLES.
        true_angle = random.choice(ANGLES)
        label_idx = ANGLES.index(true_angle)
        rotated_img = img.rotate(true_angle, expand=True)

        # Prediction (inference only — no gradients needed).
        tensor = preprocess(rotated_img).unsqueeze(0).to(device)
        with torch.no_grad():
            logits = model(pixel_values=tensor).logits
        pred_idx = logits.argmax().item()

        is_correct = (pred_idx == label_idx)
        results.append({
            "true": true_angle,
            "pred": ANGLES[pred_idx],
            "correct": is_correct,
        })

        # In-place progress bar. Fix: the original computed `status` twice
        # (duplicate line); it is computed exactly once here.
        status = "✓" if is_correct else "✗"
        percent = (i / NUM_IMAGES) * 100
        bar_length = 20
        filled_length = int(bar_length * i // NUM_IMAGES)
        bar = '#' * filled_length + ' ' * (bar_length - filled_length)
        print(f"\rProgress: [{bar}] {percent:.1f}% ({i}/{NUM_IMAGES}) | Last result: {status}", end="")
    except Exception as e:
        # Deliberate best-effort: log the failure and keep benchmarking the rest.
        print(f"\n[!] Error processing image {i}: {e}")
# --- 3. RESULTS ---
# Print overall accuracy plus a per-rotation-class breakdown.
print("\n\n" + "="*15)
print("   RESULTS")
print("="*15)
if not results:
    # Fix: if every download failed, len(results) == 0 and the accuracy
    # computation below would raise ZeroDivisionError.
    print("No images were evaluated — check network connectivity.")
else:
    total_correct = sum(1 for r in results if r['correct'])
    accuracy = (total_correct / len(results)) * 100
    print(f"Overall result: {total_correct}/{len(results)} correct")
    print(f"Hit rate: {accuracy:.2f} %")
    print("-" * 30)
    print("Details per rotation class:")
    for angle in ANGLES:
        class_results = [r for r in results if r['true'] == angle]
        if class_results:  # a class may be empty: angles are drawn at random
            correct_in_class = sum(1 for r in class_results if r['correct'])
            class_acc = (correct_in_class / len(class_results)) * 100
            print(f"  {angle:>3}° : {correct_in_class:>2}/{len(class_results):>2} correct ({class_acc:>6.2f}%)")
print("="*30)
# Result of our benchmark:
# ===============
# RESULTS
# ===============
# Overall result: 411/500 correct
# Hit rate: 82.20 %
# ------------------------------
# Details per rotation class:
# 0° : 96/124 correct ( 77.42%)
# 90° : 103/119 correct ( 86.55%)
# 180° : 112/129 correct ( 86.82%)
# 270° : 100/128 correct ( 78.12%)
# ==============================