moshel committed
Commit cc885f9 · 1 Parent(s): ec5990a
app.py CHANGED
@@ -13,15 +13,7 @@ for key in list(model_weights):
 
 
 def get_model():
-    model = timm.create_model('tf_efficientnet_b1', pretrained=True, num_classes=2, global_pool='catavgmax')
-    num_in_features = model.get_classifier().in_features
-    from torch import nn
-
-    model.fc = nn.Sequential(
-        nn.Linear(in_features=num_in_features, out_features=1024, bias=False),
-        nn.ReLU(),
-        nn.Linear(in_features=1024, out_features=2, bias=False),
-    )
+    model = timm.create_model('convnext_base.fb_in22k_ft_in1k', pretrained=True, num_classes=2)
 
     return model
 
@@ -33,15 +25,26 @@ model.eval()
 import requests
 from PIL import Image
 from torchvision import transforms
+import albumentations as A
+
+CROP = 224
+SIZE = CROP + CROP//8
+
+ho_trans_center = A.Compose([
+    A.Resize(SIZE, SIZE, interpolation=cv2.INTER_AREA),
+    A.CenterCrop(height=CROP, width=CROP, always_apply=True),
+])
+topt = A.Compose([
+    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
+    ToTensorV2(),
+])
 
 # Download human-readable labels for ImageNet.
 labels = ['good', 'ill']
-CROP=384
 
 def predict(inp):
-    img = torchvision.transforms.ToTensor()(inp)
-    img = torchvision.transforms.Resize((800, 800))(img)
-    img = torchvision.transforms.CenterCrop(CROP)(img)
+    img = ho_trans_center(image=inp)['image']
+    img = topt(image=img)['image']
     img = img.unsqueeze(0)
     with torch.no_grad():
         prediction = model(img).softmax(1).numpy()
@@ -51,7 +54,7 @@ def predict(inp):
 import gradio as gr
 
 gr.Interface(fn=predict,
-             inputs=gr.Image(type="pil"),
+             inputs=gr.Image(),
              outputs=gr.Label(num_top_classes=1),
              ).launch()
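Note on the new preprocessing: the hunk above adds import albumentations as A, but the added transforms also reference cv2 and ToTensorV2(), which this diff does not import; they are presumably imported elsewhere in app.py. Dropping type="pil" from gr.Image() means predict now receives an HWC uint8 NumPy array, which is the input albumentations expects. A minimal, self-contained sketch of the same preprocessing path (the dummy array below is only a stand-in for the image Gradio passes in):

# Sketch only: pulls in the cv2 and ToTensorV2 imports the new pipeline relies on.
import cv2
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2

CROP = 224
SIZE = CROP + CROP // 8  # 252: resize slightly larger, then center-crop to 224

ho_trans_center = A.Compose([
    A.Resize(SIZE, SIZE, interpolation=cv2.INTER_AREA),
    A.CenterCrop(height=CROP, width=CROP),
])
topt = A.Compose([
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),  # HWC numpy array -> CHW torch tensor
])

dummy = np.random.randint(0, 256, (600, 800, 3), dtype=np.uint8)  # placeholder RGB image
img = ho_trans_center(image=dummy)['image']   # 224x224x3 numpy array
img = topt(image=img)['image']                # torch tensor, shape (3, 224, 224)
batch = img.unsqueeze(0)                      # shape (1, 3, 224, 224), ready for model(batch)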
 
v5-epoch=19-val_loss=0.1464-val_accuracy=0.9514.ckpt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f34c9f1ef5bf747a84a52eff907ffafb9f37c7a023a4ea9e5b736fbc6e4156be
+size 1051254575
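The added .ckpt is a Git LFS pointer to a roughly 1.05 GB training checkpoint. The filename pattern (epoch=19-val_loss=...) looks like a PyTorch Lightning checkpoint, and the hunk context above (for key in list(model_weights):) suggests app.py remaps its state_dict keys before loading them into the timm model. A hypothetical sketch of that loading step, assuming a Lightning-style checkpoint whose keys carry a "model." prefix; the actual key handling lives in the part of app.py this diff does not show:

# Hypothetical loading sketch -- not the code in this commit.
import torch
import timm

CKPT = 'v5-epoch=19-val_loss=0.1464-val_accuracy=0.9514.ckpt'

model = timm.create_model('convnext_base.fb_in22k_ft_in1k', pretrained=False, num_classes=2)

checkpoint = torch.load(CKPT, map_location='cpu')
model_weights = checkpoint.get('state_dict', checkpoint)  # Lightning wraps weights in 'state_dict'

# Strip a leading "model." prefix, if present, so keys match the bare timm model.
model_weights = {k.removeprefix('model.'): v for k, v in model_weights.items()}

model.load_state_dict(model_weights)
model.eval()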