Tejasn commited on
Commit
25ae2ed
1 Parent(s): 9df472f
Files changed (1) hide show
  1. utils.py +19 -1
utils.py CHANGED
@@ -10,7 +10,7 @@ import torch.nn.functional as F
10
  import PIL
11
  import PIL.Image as Image
12
  import numpy as np
13
-
14
 
15
 
16
  classes_outside_india = ['apple pie', 'baby back ribs', 'baklava', 'beef carpaccio', 'beef tartare',
@@ -40,6 +40,24 @@ classes_india = ['burger','butter_naan', 'chai', 'chapati', 'chole_bhature', 'da
40
  'pizza', 'samosa']
41
 
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  def make_pred_outside_india(input_img, model, device, user_location):
44
  input_img = input_img.unsqueeze(0)
45
  model.eval()
 
10
  import PIL
11
  import PIL.Image as Image
12
  import numpy as np
13
+ from transformers import CLIPProcessor, CLIPModel
14
 
15
 
16
  classes_outside_india = ['apple pie', 'baby back ribs', 'baklava', 'beef carpaccio', 'beef tartare',
 
40
  'pizza', 'samosa']
41
 
42
 
43
+ def food_nofood_pred(input_image):
44
+ # input labels for clip model
45
+ labels = ['food', 'not food']
46
+
47
+ # CLIP Model for classification
48
+ food_nofood_model = CLIPModel.from_pretrained("flax-community/clip-rsicd-v2")
49
+ processor = CLIPProcessor.from_pretrained("flax-community/clip-rsicd-v2")
50
+
51
+ # image = Image.open(requests.get(uploaded_file, stream=True).raw)
52
+ inputs = processor(text=[f"a photo of a {l}" for l in labels], images=input_image, return_tensors="pt", padding=True)
53
+ outputs = food_nofood_model(**inputs)
54
+ logits_per_image = outputs.logits_per_image
55
+ probs = logits_per_image.softmax(dim=1)
56
+ print(probs)
57
+ pred = probs.detach().cpu().numpy().argmax(axis=1)
58
+ pred_class = labels[pred[0]]
59
+ return pred_class
60
+
61
  def make_pred_outside_india(input_img, model, device, user_location):
62
  input_img = input_img.unsqueeze(0)
63
  model.eval()