Spaces:
Runtime error
Runtime error
File size: 1,680 Bytes
40ed35c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
import io
import os

import requests
import torch
from flask import Flask, request, jsonify, render_template
from PIL import Image
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
# Flask application instance; the routes defined below are registered on it.
app = Flask(__name__)

# Hugging Face Hub id of the image-classification checkpoint (per its name,
# a ConvNeXt fine-tuned on CIFAR-100). The preprocessor and the model are
# loaded once at import time and shared by every request.
_MODEL_ID = 'karan99300/ConvNext-finetuned-CIFAR100'
feature_extractor = AutoFeatureExtractor.from_pretrained(_MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(_MODEL_ID)
# Define route for home page with form
@app.route('/', methods=['GET', 'POST'])
def index():
if request.method == 'POST':
# Get image URL from form submission
image_url = request.form['image_url']
# Classify image
predicted_class = classify_image(image_url)
return render_template('index.html', predicted_class=predicted_class, image_url=image_url)
return render_template('index.html')
# Download helper used by classify_image.
def _fetch_image(image_url, timeout):
    """Download *image_url* and return it as a fully decoded RGB PIL image."""
    # The context manager closes the response and releases the connection.
    with requests.get(image_url, timeout=timeout) as response:
        # Fail on 4xx/5xx instead of handing an HTML error page to PIL.
        response.raise_for_status()
        # .content is already decoded for gzip/deflate responses, unlike
        # the raw socket stream the original code read from.
        image = Image.open(io.BytesIO(response.content))
        # Image.open is lazy; convert() forces the full decode here so corrupt
        # or truncated data is reported by the caller's error handling.
        return image.convert('RGB')


# Function to classify image
def classify_image(image_url, timeout=10):
    """Download the image at *image_url* and return the model's predicted label.

    Args:
        image_url: URL of the image to classify. It is fetched server-side
            with ``requests``.
        timeout: Seconds to wait for the remote server to connect or send
            data before giving up. Defaults to 10.

    Returns:
        The ``model.config.id2label`` entry for the highest-scoring class.
        If the image cannot be downloaded or decoded, returns a string
        starting with ``'Error fetching image: '``. The error is returned
        rather than raised because the caller renders either value as-is.
    """
    # NOTE(review): image_url is user-supplied and fetched by the server, so
    # this can be used to reach internal hosts (SSRF). Consider a host
    # allow-list or rejecting loopback/private addresses.
    try:
        image = _fetch_image(image_url, timeout)
    except (requests.RequestException, OSError, ValueError,
            Image.DecompressionBombError) as e:
        # OSError covers PIL.UnidentifiedImageError and truncated-file errors.
        return f'Error fetching image: {str(e)}'

    # Preprocess and run inference without building an autograd graph.
    pixel_values = feature_extractor(image, return_tensors='pt').pixel_values
    with torch.no_grad():
        logits = model(pixel_values).logits
    predicted_class_idx = logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]
# Run Flask app
if __name__ == '__main__':
    # Werkzeug's interactive debugger permits arbitrary code execution from the
    # browser, so it is opt-in (FLASK_DEBUG=1) instead of hard-coded on.
    # HOST defaults to 0.0.0.0 so the server is reachable from outside a
    # container; set HOST=127.0.0.1 to restrict it to local access.
    app.run(
        host=os.environ.get('HOST', '0.0.0.0'),
        port=int(os.environ.get('PORT', '5000')),
        debug=os.environ.get('FLASK_DEBUG') == '1',
    )
|