# Uploaded by artfan123 in commit 539f5c1 ("Upload 11 files").
import gradio as gr
from transformers import ImageClassificationPipeline, AutoImageProcessor, AutoModelForImageClassification, ResNetForImageClassification
#
#
import torch
from transformers import pipeline
# Hugging Face Hub checkpoint that provides both the preprocessor and the
# fine-tuned classifier weights.
MODEL_ID = "artfan123/resnet-18-finetuned-ai-art"

# Load the image preprocessor and the model from the same checkpoint.
feature_extractor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)

# One callable that preprocesses an image, runs the model and decodes the
# labels; classify_image calls it once per request.
image_pipe = ImageClassificationPipeline(
    model=model,
    feature_extractor=feature_extractor,
)
def classify_image(image):
    """Classify an image as AI-generated or human-made.

    Args:
        image: The image supplied by the Gradio ``Image(type="pil")`` input,
            i.e. a PIL image.

    Returns:
        A dict mapping each predicted label to its confidence score, which
        is the format a Gradio ``Label`` output expects.
    """
    # image_pipe yields a list of {'label': ..., 'score': ...} dicts; collapse
    # it into the single {label: score} mapping Gradio's Label wants.
    return {
        prediction['label']: prediction['score']
        for prediction in image_pipe(image)
    }
# UI components: a PIL image input and a label output that shows the top-2
# class confidences returned by classify_image. gr.Image / gr.Label replace
# the deprecated gr.inputs / gr.outputs namespaces (removed in Gradio 4).
image = gr.Image(type="pil")
label = gr.Label(num_top_classes=2)
# Example images offered under the interface (relative file paths).
examples = [['50.jpg'], ['344.jpg'], ['24.jpg'], ['339.jpg'], ['105.jpg']]
title = "AI Art Detector"
description = "A deep learning model that detects whether an image is AI generated or human made. Upload image or use the example images below."

demo = gr.Interface(
    fn=classify_image,
    inputs=image,
    outputs=label,
    title=title,
    description=description,
    examples=examples,
)
# Queueing is enabled with .queue() instead of the deprecated
# Interface(enable_queue=True) argument; launching happens at import time,
# as before.
demo.queue().launch(debug=True)
# if __name__ == "__main__":
# with gr.Blocks() as demo:
# with gr.Row():
# with gr.Column(scale=4.5):
# with gr.Group():
# image_prompt = gr.Image(type='pil', shape=[512,512],label="Input Image")
# gr.Examples(inputs=image_prompt,examples=[['50.jpg'], ['344.jpg'],['24.jpg'], ['339.jpg'], ['105.jpg']])
# with gr.Row():
# clear_button = gr.Button('Clear')
# run_button = gr.Button('Predict')
# with gr.Column(scale=5.5):
# image_output = gr.Image(type='pil', shape=[512,512], label="Prediction")
# clear_button.click(lambda: None, None, image_prompt, queue=False)
# clear_button.click(lambda: None, None, image_output, queue=False)
# run_button.click(fn=segment,inputs=[image_prompt],
# outputs=[image_output])
# demo.queue().launch(share=True)