|
|
import requests |
|
|
import torch |
|
|
from torchvision import transforms, models |
|
|
import torch.nn as nn |
|
|
from PIL import Image |
|
|
import io |
|
|
|
|
|
# Class mapping used when the downloaded checkpoint does not carry its own.
_DEFAULT_LABEL_TO_INDEX = {
    'bath': 0,
    'bed': 1,
    'din': 2,
    'kitchen': 3,
    'living': 4,
}

# (connect, read) timeouts in seconds. requests applies no timeout by
# default, so without this a stalled connection would block forever.
_DOWNLOAD_TIMEOUT = (10, 120)


def _select_device():
    """Return the preferred torch device: CUDA first, then MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # torch.backends.mps does not exist on older PyTorch builds; treat a
    # missing backend as "not available" instead of raising AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def _split_checkpoint(checkpoint):
    """Return ``(model_state, label_to_index)`` from a wrapped or raw checkpoint."""
    if 'model_state_dict' in checkpoint:
        labels = checkpoint.get('label_to_index', dict(_DEFAULT_LABEL_TO_INDEX))
        return checkpoint['model_state_dict'], labels
    # The file is a bare state dict: fall back to the built-in class mapping.
    return checkpoint, dict(_DEFAULT_LABEL_TO_INDEX)


def classify_image_huggingface(image_path, repo_id="mertincesu/property-room-classifier"):
    """Classify an image with a ResNet-50 checkpoint downloaded from Hugging Face.

    The checkpoint ``pytorch_model.bin`` is fetched from the ``main`` revision
    of ``repo_id``, loaded into a ``torchvision`` ResNet-50 whose final layer
    is resized to the number of classes found in the checkpoint, and applied
    to the image at ``image_path`` (resized to 224x224 and normalized).

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.
        repo_id: Hugging Face repository id hosting ``pytorch_model.bin``.

    Returns:
        A dict ``{'class': <label>, 'confidence': <percentage 0-100>}`` on
        success, or ``None`` if any step fails (the error is printed).
    """
    device = _select_device()
    print(f"Using device: {device}")

    model_url = f"https://huggingface.co/{repo_id}/resolve/main/pytorch_model.bin"

    try:
        print(f"Downloading model from {model_url}")
        response = requests.get(model_url, timeout=_DOWNLOAD_TIMEOUT)
        response.raise_for_status()

        # NOTE(security): torch.load unpickles the downloaded bytes, which can
        # execute arbitrary code. Only use repositories you trust; consider
        # torch.load(..., weights_only=True) where the installed PyTorch
        # supports it and the checkpoint contents allow it.
        checkpoint = torch.load(io.BytesIO(response.content), map_location=device)
        model_state, label_to_index = _split_checkpoint(checkpoint)

        # The classifier head size is dictated by the checkpoint itself.
        num_classes = model_state['fc.weight'].shape[0]

        model = models.resnet50(weights=None)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        model.load_state_dict(model_state)
        model.to(device)
        model.eval()

        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

        # Context manager closes the underlying file handle once the pixel
        # data has been converted, instead of leaking it.
        with Image.open(image_path) as img:
            img_tensor = transform(img.convert('RGB')).unsqueeze(0).to(device)

        with torch.no_grad():
            outputs = model(img_tensor)
            probs = torch.nn.functional.softmax(outputs, dim=1)
            _, predicted = torch.max(outputs, 1)

        predicted_index = predicted.item()
        index_to_label = {v: k for k, v in label_to_index.items()}
        if predicted_index not in index_to_label:
            # Surface a readable message instead of a bare KeyError such as "3".
            raise ValueError(
                f"No label known for predicted class index {predicted_index}"
            )

        result = {
            'class': index_to_label[predicted_index],
            'confidence': probs[0][predicted_index].item() * 100,
        }

        print(f"Class: {result['class']}")
        print(f"Confidence: {result['confidence']:.2f}%")
        return result

    except Exception as e:
        # Top-level boundary: callers rely on None (not an exception) on failure.
        print(f"Error: {e}")
        return None
|
|
|
|
|
if __name__ == "__main__":
    import sys

    # Take the image path from the first command-line argument when one is
    # given; otherwise ask for it interactively.
    image_path = sys.argv[1] if len(sys.argv) >= 2 else input("Enter path to image: ")

    classify_image_huggingface(image_path)