SeyedAli commited on
Commit
c95621e
1 Parent(s): e35723c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -5
app.py CHANGED
@@ -5,8 +5,9 @@ import torch
5
  from torchvision.io import read_image
6
  from transformers import ViTImageProcessor,pipeline
7
 
8
- model = ViTImageProcessor.from_pretrained('SeyedAli/Food-Image-Classification-VIT')
9
-
 
10
  def FoodClassification(image):
11
  with tempfile.NamedTemporaryFile(suffix=".png") as temp_image_file:
12
  # Copy the contents of the uploaded image file to the temporary file
@@ -15,9 +16,13 @@ def FoodClassification(image):
15
  Image.fromarray(image).save(temp_image_file.name)
16
  # Load the image file using torchvision
17
  image = read_image(temp_image_file.name)
18
- pipline = pipeline(task="image-classification", model=model)
19
- output=pipline(image, return_tensors='pt')
20
- return output
 
 
 
 
21
 
22
  iface = gr.Interface(fn=FoodClassification, inputs="image", outputs="text")
23
  iface.launch(share=False)
 
5
  from torchvision.io import read_image
6
  from transformers import ViTImageProcessor,pipeline
7
 
8
+ # model = ViTImageProcessor.from_pretrained('SeyedAli/Food-Image-Classification-VIT')
9
+ model = ViTForImageClassification.from_pretrained('SeyedAli/Food-Image-Classification-VIT')
10
+ feature_extractor = ViTFeatureExtractor.from_pretrained('SeyedAli/Food-Image-Classification-VIT')
11
  def FoodClassification(image):
12
  with tempfile.NamedTemporaryFile(suffix=".png") as temp_image_file:
13
  # Copy the contents of the uploaded image file to the temporary file
 
16
  Image.fromarray(image).save(temp_image_file.name)
17
  # Load the image file using torchvision
18
  image = read_image(temp_image_file.name)
19
+ # Preprocess the image using the ViT feature extractor
20
+ inputs = feature_extractor(images=image, return_tensors="pt")
21
+ # Use the ViT model for image classification
22
+ outputs = model(**inputs)
23
+ predicted_class_idx = torch.argmax(outputs.logits)
24
+ predicted_class = model.config.id2label[predicted_class_idx.item()]
25
+ return predicted_class
26
 
27
  iface = gr.Interface(fn=FoodClassification, inputs="image", outputs="text")
28
  iface.launch(share=False)