File size: 1,680 Bytes
40ed35c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import io

import requests
import torch
from flask import Flask, request, jsonify, render_template
from PIL import Image
from transformers import AutoFeatureExtractor, AutoModelForImageClassification

# Flask application object; the routes below are registered on it.
app = Flask(__name__)

# Hugging Face checkpoint id shared by the preprocessor and the classifier,
# kept in one place so the two can never drift apart.
_CHECKPOINT = 'karan99300/ConvNext-finetuned-CIFAR100'

# Load both once at import time so every request reuses the same objects
# (from_pretrained may download the checkpoint from the Hub on first run).
feature_extractor = AutoFeatureExtractor.from_pretrained(_CHECKPOINT)
model = AutoModelForImageClassification.from_pretrained(_CHECKPOINT)

# Home page: shows the URL form and, after a submission, the prediction.
@app.route('/', methods=['GET', 'POST'])
def index():
    """Render the form, classifying the submitted image URL on POST.

    Any non-POST request renders ``index.html`` with no extra context.
    A POST reads the ``image_url`` form field, passes it to
    ``classify_image`` and re-renders ``index.html`` with the result
    (a label or an error string) plus the URL echoed back.
    """
    # Guard clause: nothing was submitted, just show the empty form.
    if request.method != 'POST':
        return render_template('index.html')

    # Indexing request.form raises (Flask answers 400) if the field is absent.
    submitted_url = request.form['image_url']
    label = classify_image(submitted_url)
    return render_template(
        'index.html',
        predicted_class=label,
        image_url=submitted_url,
    )

# Function to classify image
def classify_image(image_url, timeout=10):
    """Download the image at *image_url* and return the predicted label.

    Args:
        image_url: URL of the image to fetch with ``requests``.
        timeout: Seconds passed to ``requests.get`` as its timeout, so a
            slow or unresponsive host cannot block the request handler
            indefinitely.

    Returns:
        The label from ``model.config.id2label`` for the highest-scoring
        class, or a string beginning with ``'Error fetching image: '`` when
        the image could not be downloaded or decoded. Download/decode errors
        are returned rather than raised so the caller can render them as-is.
    """
    # NOTE(review): image_url comes straight from the web form, so the server
    # will fetch any URL a user supplies (including internal addresses) and
    # reads the full body with no size cap. Consider a scheme/host allowlist
    # and a byte limit if this is exposed beyond localhost.
    try:
        # .content (unlike .raw) undoes gzip/deflate Content-Encoding; the
        # context manager releases the connection once the body is read.
        with requests.get(image_url, timeout=timeout) as response:
            # Fail on 4xx/5xx instead of handing an HTML error page to PIL.
            response.raise_for_status()
            payload = response.content
        # Image.open is lazy; convert() forces a full decode here so corrupt
        # or truncated data is reported instead of crashing later.
        image = Image.open(io.BytesIO(payload)).convert('RGB')
    except (requests.RequestException, OSError, ValueError,
            Image.DecompressionBombError) as e:
        return f'Error fetching image: {str(e)}'

    # Preprocess image and run inference without tracking gradients.
    pixel_values = feature_extractor(image, return_tensors='pt').pixel_values
    with torch.no_grad():
        logits = model(pixel_values).logits
    predicted_class_idx = logits.argmax(-1).item()

    # Map the winning class index back to its label.
    return model.config.id2label[predicted_class_idx]

# Run the Flask development server when this file is executed directly.
if __name__ == '__main__':
    # Debug mode is no longer forced on: the Werkzeug interactive debugger
    # lets anyone who can reach it execute arbitrary Python. Flask's app.run
    # still honours the FLASK_DEBUG environment variable, so set
    # FLASK_DEBUG=1 for local development with the debugger and reloader.
    app.run(port=5000)