swjin commited on
Commit
59198bc
1 Parent(s): 57205a9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -20
app.py CHANGED
@@ -6,25 +6,7 @@ import numpy as np
6
  from PIL import Image
7
  import tensorflow as tf
8
  from transformers import SegformerFeatureExtractor, TFSegformerForSemanticSegmentation
9
- from transformers import AutoModelForImageClassification, AutoFeatureExtractor
10
- import requests
11
- from io import BytesIO
12
-
13
- model_name = "facebook/deit-base-distilled-patch16"
14
- model = AutoModelForImageClassification.from_pretrained(model_name)
15
- feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
16
-
17
- def classify_image(image):
18
- # Load and preprocess the image
19
- image = Image.open(BytesIO(image))
20
- inputs = feature_extractor(images=image, return_tensors="pt")
21
-
22
- # Perform image classification
23
- with torch.no_grad():
24
- outputs = model(**inputs)
25
- predicted_class = outputs.logits.argmax().item()
26
-
27
- return model.config.id2label[predicted_class]
28
 
29
 
30
  feature_extractor = SegformerFeatureExtractor.from_pretrained(
@@ -123,7 +105,7 @@ def sepia(input_img):
123
 
124
 
125
 
126
- demo = gr.Interface(fn=classify_image,
127
  inputs=gr.Image(shape=(400, 600)),
128
  outputs=['plot'],
129
  title="SWJIN11 TASK",
 
6
  from PIL import Image
7
  import tensorflow as tf
8
  from transformers import SegformerFeatureExtractor, TFSegformerForSemanticSegmentation
9
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
 
12
  feature_extractor = SegformerFeatureExtractor.from_pretrained(
 
105
 
106
 
107
 
108
+ demo = gr.Interface(fn=draw_plot,
109
  inputs=gr.Image(shape=(400, 600)),
110
  outputs=['plot'],
111
  title="SWJIN11 TASK",