Vishaltiwari2019 commited on
Commit
f70b16e
1 Parent(s): 2aec2ac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -27
app.py CHANGED
@@ -1,32 +1,22 @@
1
  import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
- from PIL import Image
4
- import requests
5
- from io import BytesIO
6
- import numpy as np
7
 
8
- # Load the pre-trained model and tokenizer
9
- model_name = "distilbert/distilbert-base-uncased"
10
- tokenizer = AutoTokenizer.from_pretrained(model_name)
11
- model = AutoModelForSequenceClassification.from_pretrained(model_name)
12
 
13
- # Function to preprocess the image
14
- def preprocess_image(image):
15
- image = Image.open(BytesIO(image))
16
- image = image.resize((256, 256)) # Resize the image to match the model's input size
17
- return np.array(image)
 
 
 
 
18
 
19
- # Function to make predictions
20
- def classify_image(image):
21
- image = preprocess_image(image)
22
- inputs = tokenizer(image, return_tensors="pt", padding=True, truncation=True)
23
- outputs = model(**inputs)
24
- logits = outputs.logits.detach().numpy()[0]
25
- probabilities = np.exp(logits) / np.exp(logits).sum(-1)
26
- predicted_class = np.argmax(probabilities)
27
- return {str(i): float(prob) for i, prob in enumerate(probabilities)}
28
 
29
- # Create a Gradio interface
30
- input_image = gr.inputs.Image(shape=(256, 256))
31
- output_label = gr.outputs.Label(num_top_classes=3)
32
- gr.Interface(classify_image, inputs=input_image, outputs=output_label).launch()
 
1
  import gradio as gr
2
+ from transformers import pipeline
 
 
 
 
3
 
4
+ # Load the DETR object detection pipeline
5
+ object_detection_pipeline = pipeline("object-detection", model="facebook/detr-resnet-50")
 
 
6
 
7
+ # Function to perform object detection on an image
8
+ def detect_objects(image):
9
+ # Perform object detection
10
+ results = object_detection_pipeline(image)
11
+ # Extract bounding boxes and object labels
12
+ bounding_boxes = [obj["bbox"] for obj in results]
13
+ labels = [obj["label"] for obj in results]
14
+ # Return bounding boxes and labels
15
+ return bounding_boxes, labels
16
 
17
+ # Define Gradio interface
18
+ inputs = gr.inputs.Image()
19
+ outputs = gr.outputs.ObjectDetection(labels=["person", "car", "truck", "bicycle", "motorcycle"]) # Customize labels as needed
 
 
 
 
 
 
20
 
21
+ # Create Gradio interface
22
+ gr.Interface(fn=detect_objects, inputs=inputs, outputs=outputs, title="Object Detection", description="Detect objects in images.").launch()