CircleStar commited on
Commit
ca9c54c
·
verified ·
1 Parent(s): a5f8c02

Create predict_utils.py

Browse files
Files changed (1) hide show
  1. predict_utils.py +78 -0
predict_utils.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ import torch
4
+ from PIL import Image
5
+
6
+ from config import IMAGE_SIZE
7
+ from data_utils import get_transform, load_charcoal_dataset
8
+ from train_utils import load_model, get_runtime_device
9
+
10
+
11
+ def predict_uploaded_image(model_name: str, image: Image.Image):
12
+ if not model_name:
13
+ return "Veuillez sélectionner un modèle.", None
14
+
15
+ if image is None:
16
+ return "Veuillez importer une image.", None
17
+
18
+ device = get_runtime_device()
19
+ model, meta = load_model(model_name, device)
20
+
21
+ class_names = meta["config"]["class_names"]
22
+ transform = get_transform()
23
+
24
+ image = image.convert("RGB")
25
+ tensor = transform(image).unsqueeze(0).to(device)
26
+
27
+ with torch.no_grad():
28
+ logits = model(tensor)
29
+ probs = torch.softmax(logits, dim=1).squeeze(0).cpu().tolist()
30
+ pred_idx = int(torch.argmax(logits, dim=1).item())
31
+
32
+ result_text = (
33
+ f"Prédiction : {class_names[pred_idx]}\n"
34
+ f"Confiance : {max(probs):.4f}\n\n"
35
+ f"Modèle : {model_name}\n"
36
+ f"Jeu de données : {meta['config']['dataset_name']}\n"
37
+ f"Appareil utilisé : {device}"
38
+ )
39
+
40
+ prob_dict = {class_names[i]: float(probs[i]) for i in range(len(class_names))}
41
+ return result_text, prob_dict
42
+
43
+
44
+ def test_random_sample(model_name: str):
45
+ if not model_name:
46
+ return None, "Veuillez sélectionner un modèle.", None
47
+
48
+ device = get_runtime_device()
49
+ model, meta = load_model(model_name, device)
50
+
51
+ raw, class_names = load_charcoal_dataset()
52
+ test_dataset = raw["test"]
53
+
54
+ idx = random.randint(0, len(test_dataset) - 1)
55
+ item = test_dataset[idx]
56
+
57
+ image = item["image"].convert("RGB")
58
+ label = int(item["label"])
59
+ label_name = class_names[label]
60
+
61
+ transform = get_transform()
62
+ tensor = transform(image).unsqueeze(0).to(device)
63
+
64
+ with torch.no_grad():
65
+ logits = model(tensor)
66
+ probs = torch.softmax(logits, dim=1).squeeze(0).cpu().tolist()
67
+ pred_idx = int(torch.argmax(logits, dim=1).item())
68
+
69
+ result_text = (
70
+ f"Échantillon test aléatoire\n"
71
+ f"Vérité terrain : {label_name}\n"
72
+ f"Prédiction : {class_names[pred_idx]}\n"
73
+ f"Confiance : {max(probs):.4f}\n"
74
+ f"Appareil utilisé : {device}"
75
+ )
76
+
77
+ prob_dict = {class_names[i]: float(probs[i]) for i in range(len(class_names))}
78
+ return image, result_text, prob_dict