Baskar2005's picture
Update app.py
4958652 verified
raw
history blame contribute delete
No virus
4.48 kB
import os
import random
import logging
import gradio as gr
from PIL import Image
from zipfile import ZipFile
from typing import Any, Dict,List
from transformers import pipeline
class Image_classification:
    """Gradio demo that classifies an image with a Hugging Face pipeline.

    The class unpacks a bundled ``image_dataset.zip`` to offer example
    images, runs an ``image-classification`` pipeline for the model chosen
    in the UI, and shows the best score per label.
    """

    # Lower-case file suffixes that qualify a file as an example image.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp')

    # Custom page styling, kept byte-for-byte from the original app.
    # NOTE(review): the second rule looks malformed (selector
    # ".block svelte-90oupt padded" and colour "314755" without "#"), so it
    # presumably never applies. Left untouched because fixing it would
    # change the rendered layout - confirm the intent before editing.
    _CSS = """
.gradio-container {background: #314755;
background: -webkit-linear-gradient(to right, #26a0da, #314755);
background: linear-gradient(to right, #26a0da, #314755);}
.block svelte-90oupt padded{background:314755;
margin:0;
padding:0;}"""

    def __init__(self):
        # Pipelines already built, keyed by the ``model`` argument, so a
        # checkpoint is loaded once instead of on every button click.
        # Trade-off: each model that has been used stays in memory.
        self._classifiers: Dict[Any, Any] = {}

    def unzip_image_data(self) -> str:
        """
        Unzips ``image_dataset.zip`` into the ``dataset`` directory.

        The target directory is created if needed and reused if it already
        exists, so calling this on every start-up is safe.

        Returns:
            str: The path to the directory containing the extracted files,
            or an empty string if extraction failed (the error is logged).
        """
        directory_path = "dataset"
        try:
            # Open the archive first: a missing archive must not leave an
            # empty directory behind.
            with ZipFile("image_dataset.zip", "r") as archive:
                os.makedirs(directory_path, exist_ok=True)
                archive.extractall(directory_path)
        except Exception as e:
            logging.error(f"An error occurred during extraction: {e}")
            return ""
        return directory_path

    def example_images(self) -> List[str]:
        """
        Unzips the image dataset and lists the image files to show as examples.

        Only regular files directly inside the dataset directory whose
        extension is in ``IMAGE_EXTENSIONS`` (case-insensitive) are
        returned, each exactly once, sorted by path.

        Returns:
            List[str]: Paths to the example images; an empty list when the
            dataset could not be extracted or read (the error is logged).
        """
        try:
            folder = self.unzip_image_data()
            if not folder:
                # Extraction failed and was already logged.
                return []
            paths = []
            for name in os.listdir(folder):
                path = os.path.join(folder, name)
                extension = os.path.splitext(name)[1].lower()
                if extension in self.IMAGE_EXTENSIONS and os.path.isfile(path):
                    paths.append(path)
            # os.listdir order is arbitrary; sort for a stable gallery.
            return sorted(paths)
        except Exception as e:
            logging.error(f"An error occurred in example images: {e}")
            return []

    def classify(self, image: "Image.Image", model: Any) -> List[Dict[str, Any]]:
        """
        Classifies an image using a specified model.

        Args:
            image: The image to classify. The UI built by ``interface``
                passes a file path, because its image component uses
                ``type="filepath"``.
            model: Value forwarded to ``pipeline(..., model=model)``; the
                UI passes the model id selected in the dropdown.

        Returns:
            List[Dict[str, Any]]: The pipeline output. ``format_the_result``
            reads the ``'label'`` and ``'score'`` keys of each entry.

        Raises:
            Exception: Any error raised while building or running the
            pipeline is logged and re-raised.
        """
        try:
            classifier = self._classifiers.get(model)
            if classifier is None:
                classifier = pipeline("image-classification", model=model)
                self._classifiers[model] = classifier
            return classifier(image)
        except Exception as e:
            logging.error(f"An error occurred during image classification: {e}")
            raise

    def format_the_result(self, image: "Image.Image", model: Any) -> Dict[str, float]:
        """
        Classifies the image and keeps the highest score for each label.

        Args:
            image: The image to classify (see ``classify``).
            model: The model used for classification (see ``classify``).

        Returns:
            Dict[str, float]: Unique labels mapped to their highest score.

        Raises:
            Exception: Any error from classification or from reading the
            result entries is logged and re-raised.
        """
        try:
            best_scores: Dict[str, float] = {}
            for item in self.classify(image, model):
                label = item['label']
                score = item['score']
                if label not in best_scores or best_scores[label] < score:
                    best_scores[label] = score
            return best_scores
        except Exception as e:
            logging.error(f"An error occurred while formatting the results: {e}")
            raise

    def interface(self):
        """
        Builds the Gradio Blocks UI and launches it with ``debug=True``.

        The page has a model dropdown, an upload-only image input, a label
        output, a button that runs ``format_the_result`` and, when example
        images are available, an examples gallery.
        """
        with gr.Blocks(css=self._CSS) as demo:
            gr.HTML('\n<center><h1 style="color:#fff">Image Classification</h1></center>')
            exam_img = self.example_images()
            with gr.Row():
                # The original list ended with an empty-string entry, which
                # is not a loadable model id; it is omitted here.
                model = gr.Dropdown(
                    ["facebook/regnet-x-040", "google/vit-large-patch16-384", "microsoft/resnet-50"],
                    label="Choose a model",
                )
            with gr.Row():
                image = gr.Image(type="filepath", sources=["upload"])
                with gr.Column():
                    output = gr.Label()
            with gr.Row():
                button = gr.Button()
            button.click(self.format_the_result, [image, model], output)
            # Skip the gallery when extraction produced no example images.
            if exam_img:
                # NOTE(review): fn takes (image, model) but only the image
                # is listed as input. With cache_examples=False the function
                # is presumably never invoked from here - confirm before
                # enabling caching or run_on_click.
                gr.Examples(
                    examples=exam_img,
                    inputs=[image],
                    outputs=output,
                    fn=self.format_the_result,
                    cache_examples=False,
                )
        demo.launch(debug=True)
if __name__ == "__main__":
    # Script entry point: create the app object and start the Gradio UI.
    app = Image_classification()
    result = app.interface()