swjin commited on
Commit
57205a9
1 Parent(s): 5b3c649

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -1
app.py CHANGED
@@ -6,6 +6,26 @@ import numpy as np
6
  from PIL import Image
7
  import tensorflow as tf
8
  from transformers import SegformerFeatureExtractor, TFSegformerForSemanticSegmentation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  feature_extractor = SegformerFeatureExtractor.from_pretrained(
11
  "nvidia/segformer-b0-finetuned-cityscapes-512-1024"
@@ -102,7 +122,8 @@ def sepia(input_img):
102
  return fig
103
 
104
 
105
- demo = gr.Interface(fn=sepia,
 
106
  inputs=gr.Image(shape=(400, 600)),
107
  outputs=['plot'],
108
  title="SWJIN11 TASK",
 
6
  from PIL import Image
7
  import tensorflow as tf
8
  from transformers import SegformerFeatureExtractor, TFSegformerForSemanticSegmentation
9
+ from transformers import AutoModelForImageClassification, AutoFeatureExtractor
10
+ import requests
11
+ from io import BytesIO
12
+
13
+ model_name = "facebook/deit-base-distilled-patch16"
14
+ model = AutoModelForImageClassification.from_pretrained(model_name)
15
+ feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
16
+
17
+ def classify_image(image):
18
+ # Load and preprocess the image
19
+ image = Image.open(BytesIO(image))
20
+ inputs = feature_extractor(images=image, return_tensors="pt")
21
+
22
+ # Perform image classification
23
+ with torch.no_grad():
24
+ outputs = model(**inputs)
25
+ predicted_class = outputs.logits.argmax().item()
26
+
27
+ return model.config.id2label[predicted_class]
28
+
29
 
30
  feature_extractor = SegformerFeatureExtractor.from_pretrained(
31
  "nvidia/segformer-b0-finetuned-cityscapes-512-1024"
 
122
  return fig
123
 
124
 
125
+
126
+ demo = gr.Interface(fn=classify_image,
127
  inputs=gr.Image(shape=(400, 600)),
128
  outputs=['plot'],
129
  title="SWJIN11 TASK",