karan99300 commited on
Commit
40ed35c
1 Parent(s): 014bfb8

Upload 3 files

Browse files
Files changed (3) hide show
  1. Dockerfile +11 -0
  2. app.py +50 -0
  3. requirements.txt +7 -0
Dockerfile ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.9
2
+
3
+ WORKDIR /code
4
+
5
+ COPY ./requirements.txt /code/requirements.txt
6
+
7
+ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
8
+
9
+ COPY . .
10
+
11
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify, render_template
2
+ from transformers import AutoFeatureExtractor, AutoModelForImageClassification
3
+ from PIL import Image
4
+ import requests
5
+ import torch
6
+
7
+ # Initialize Flask app
8
+ app = Flask(__name__)
9
+
10
+ # Load pre-trained model and feature extractor
11
+ feature_extractor = AutoFeatureExtractor.from_pretrained('karan99300/ConvNext-finetuned-CIFAR100')
12
+ model = AutoModelForImageClassification.from_pretrained('karan99300/ConvNext-finetuned-CIFAR100')
13
+
14
+ # Define route for home page with form
15
+ @app.route('/', methods=['GET', 'POST'])
16
+ def index():
17
+ if request.method == 'POST':
18
+ # Get image URL from form submission
19
+ image_url = request.form['image_url']
20
+
21
+ # Classify image
22
+ predicted_class = classify_image(image_url)
23
+
24
+ return render_template('index.html', predicted_class=predicted_class, image_url=image_url)
25
+
26
+ return render_template('index.html')
27
+
28
+ # Function to classify image
29
+ def classify_image(image_url):
30
+ # Fetch image from URL
31
+ try:
32
+ image = Image.open(requests.get(image_url, stream=True).raw)
33
+ except Exception as e:
34
+ return f'Error fetching image: {str(e)}'
35
+
36
+ # Preprocess image and perform inference
37
+ pixel_values = feature_extractor(image.convert('RGB'), return_tensors='pt').pixel_values
38
+ with torch.no_grad():
39
+ outputs = model(pixel_values)
40
+ logits = outputs.logits
41
+ predicted_class_idx = logits.argmax(-1).item()
42
+
43
+ # Get predicted label
44
+ predicted_label = model.config.id2label[predicted_class_idx]
45
+
46
+ return predicted_label
47
+
48
+ # Run Flask app
49
+ if __name__ == '__main__':
50
+ app.run(debug=True,port=5000)
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ flask
2
+ numpy
3
+ Pillow
4
+ torch
5
+ transformers
6
+ requests
7
+ uvicorn