nekofura commited on
Commit
14be786
1 Parent(s): 517300d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -11
app.py CHANGED
@@ -3,7 +3,6 @@ from torchvision import models, transforms
3
  from PIL import Image
4
  import gradio as gr
5
 
6
- # Mendefinisikan nama kelas
7
  class_names = [
8
  "calculus",
9
  "caries",
@@ -13,26 +12,21 @@ class_names = [
13
  "tooth_discoloration"
14
  ]
15
 
16
- # Mengatur jumlah kelas
17
  num_classes = len(class_names)
18
 
19
- # Membuat dan mengkonfigurasi model
20
- model = models.resnet50(pretrained=False)
21
  model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
22
 
23
- # Memuat bobot model (sesuaikan path jika diperlukan)
24
  model.load_state_dict(torch.load('best_model.pth', map_location=torch.device('cpu')))
25
  model.eval()
26
 
27
- # Mengatur transformasi preprocessing
28
  preprocess = transforms.Compose([
29
  transforms.Resize((224, 224)),
30
  transforms.ToTensor(),
31
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
32
  ])
33
 
34
- # Fungsi untuk melakukan prediksi
35
- def predict_image(image, model, preprocess, class_names):
36
  processed_image = preprocess(image).unsqueeze(0)
37
 
38
  with torch.no_grad():
@@ -42,14 +36,12 @@ def predict_image(image, model, preprocess, class_names):
42
 
43
  return predicted_class
44
 
45
- # Membuat interface Gradio
46
  iface = gr.Interface(
47
  fn=predict_image,
48
- inputs=[gr.inputs.Image(type='pil')],
49
  outputs=gr.outputs.Label(num_top_classes=1),
50
  title="Klasifikasi Gambar Medis",
51
  description="Upload gambar untuk memprediksi kelasnya."
52
  )
53
 
54
- # Menjalankan aplikasi Gradio
55
  iface.launch()
 
3
  from PIL import Image
4
  import gradio as gr
5
 
 
6
  class_names = [
7
  "calculus",
8
  "caries",
 
12
  "tooth_discoloration"
13
  ]
14
 
 
15
  num_classes = len(class_names)
16
 
17
+ model = models.resnet50(weights=None)
 
18
  model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
19
 
 
20
  model.load_state_dict(torch.load('best_model.pth', map_location=torch.device('cpu')))
21
  model.eval()
22
 
 
23
  preprocess = transforms.Compose([
24
  transforms.Resize((224, 224)),
25
  transforms.ToTensor(),
26
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
27
  ])
28
 
29
+ def predict_image(image, model=model, preprocess=preprocess, class_names=class_names):
 
30
  processed_image = preprocess(image).unsqueeze(0)
31
 
32
  with torch.no_grad():
 
36
 
37
  return predicted_class
38
 
 
39
  iface = gr.Interface(
40
  fn=predict_image,
41
+ inputs=gr.inputs.Image(type='pil'),
42
  outputs=gr.outputs.Label(num_top_classes=1),
43
  title="Klasifikasi Gambar Medis",
44
  description="Upload gambar untuk memprediksi kelasnya."
45
  )
46
 
 
47
  iface.launch()