zergswim commited on
Commit
2411be5
1 Parent(s): 68e6e09

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -14
app.py CHANGED
@@ -1,24 +1,24 @@
1
  from transformers import AutoFeatureExtractor, ResNetForImageClassification
2
  import torch
3
- from datasets import load_dataset
4
 
5
- dataset = load_dataset("huggingface/cats-image")
6
- image = dataset["test"]["image"][0]
7
 
8
  feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/resnet-50")
9
  model = ResNetForImageClassification.from_pretrained("microsoft/resnet-50")
10
 
11
- inputs = feature_extractor(image, return_tensors="pt")
12
-
13
- with torch.no_grad():
14
- logits = model(**inputs).logits
15
-
16
- # model predicts one of the 1000 ImageNet classes
17
- predicted_label = logits.argmax(-1).item()
18
- print(model.config.id2label[predicted_label])
19
-
20
  import gradio as gr
21
  def segment(image):
22
- pass # Implement your image segmentation model here...
 
 
 
 
 
 
 
 
 
23
 
24
- gr.Interface(fn=segment, inputs="image", outputs="image").launch()
 
1
  from transformers import AutoFeatureExtractor, ResNetForImageClassification
2
  import torch
3
+ # from datasets import load_dataset
4
 
5
+ # dataset = load_dataset("huggingface/cats-image")
6
+ # image = dataset["test"]["image"][0]
7
 
8
  feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/resnet-50")
9
  model = ResNetForImageClassification.from_pretrained("microsoft/resnet-50")
10
 
 
 
 
 
 
 
 
 
 
11
  import gradio as gr
12
  def segment(image):
13
+ inputs = feature_extractor(image, return_tensors="pt")
14
+
15
+ with torch.no_grad():
16
+ logits = model(**inputs).logits
17
+
18
+ # model predicts one of the 1000 ImageNet classes
19
+ predicted_label = logits.argmax(-1).item()
20
+ # print(model.config.id2label[predicted_label])
21
+
22
+ return model.config.id2label[predicted_label]
23
 
24
+ gr.Interface(fn=segment, inputs="image", outputs="label").launch()