real-or-fake / app.py
moonkeyboom's picture
Update app.py
e1474bb verified
from transformers import pipeline
import gradio as gr
from transformers import pipeline
from PIL import Image
# "Artifacts"
# Image-classification pipeline backed by the fine-tuned model stored
# next to this file; built once at import time and reused by predict().
classifier = pipeline(
    task="image-classification",
    model="./my_final_image_model",
)
# Zero-shot object detector shared by all crop() calls. It is created on
# first use so the model is loaded once instead of on every request.
_FACE_DETECTOR = None


def _get_face_detector():
    """Return the shared zero-shot object detector, creating it on first use."""
    global _FACE_DETECTOR
    if _FACE_DETECTOR is None:
        _FACE_DETECTOR = pipeline(
            model="google/owlvit-base-patch32",
            task="zero-shot-object-detection",
        )
    return _FACE_DETECTOR


def _square_crop_box(bbox):
    """Return (left, top, right, bottom) of a square centered on *bbox*.

    The square's half-side is half of the shorter bbox side (integer
    division), so the square never extends beyond the bbox.
    """
    center_x = (bbox["xmin"] + bbox["xmax"]) // 2
    center_y = (bbox["ymin"] + bbox["ymax"]) // 2
    half = min(bbox["xmax"] - bbox["xmin"], bbox["ymax"] - bbox["ymin"]) // 2
    return (center_x - half, center_y - half, center_x + half, center_y + half)


def crop(image):
    """Crop *image* to a square around the most confident "human face" detection.

    Args:
        image: PIL image to search for a face.

    Returns:
        A new PIL image containing the square face region, shrunk to fit
        within 200x200 pixels when it is larger than that.

    Raises:
        ValueError: If the detector returns no detections, or the best
            detection's box is too small to yield a non-empty crop.
    """
    predictions = _get_face_detector()(image, candidate_labels="human face")
    if not predictions:
        # max() on an empty list would raise an opaque ValueError; keep the
        # exception type but say what actually went wrong.
        raise ValueError("no human face detected in the image")

    best_guess = max(predictions, key=lambda pred: pred["score"])
    print(
        f"Label: {best_guess['label']}, "
        f"Score: {best_guess['score']}, Box: {best_guess['box']}"
    )

    left, top, right, bottom = _square_crop_box(best_guess["box"])
    if right <= left:
        # A box narrower than 2 px on its shorter side gives a 0x0 crop.
        raise ValueError("detected face region is too small to crop")

    cropped_image = image.crop((left, top, right, bottom))
    # thumbnail() resizes in place and only ever shrinks the image.
    cropped_image.thumbnail((200, 200), Image.Resampling.LANCZOS)
    return cropped_image
def predict(image):
    """Classify the face found in *image*.

    Args:
        image: PIL image supplied by the Gradio image input, or None when
            the user clicked "Predict" without uploading anything.

    Returns:
        A dict mapping each label returned by the classifier to its score.

    Raises:
        gr.Error: If no image was provided or no face could be cropped
            from it, so the UI shows a readable message instead of a
            generic failure.
    """
    if image is None:
        raise gr.Error("Please upload an image before clicking Predict.")

    try:
        face = crop(image)
    except ValueError as err:
        # crop() raises ValueError when there is no usable face detection.
        raise gr.Error("No human face could be detected in the image.") from err

    return {result["label"]: result["score"] for result in classifier(face)}
# UI layout; see the Gradio guides: https://www.gradio.app/guides
with gr.Blocks() as demo:
    # Usage instructions shown above the controls.
    gr.Markdown("""
    # Real or Fake Image Classifier
    This tool analyzes an image to determine if it matches specific classification criteria.
    1. **Upload an image that clearly shows a human face**: Click the box to select an image from your device, or drag and drop it into the box.
    2. **Click "Predict"**: After uploading, click "Predict" to start processing.
    3. **View the output**: The app will show a classification label and confidence score based on the content of the image.
    """)
    # type="pil" makes Gradio hand predict() a PIL image.
    image = gr.Image(type="pil")
    predict_btn = gr.Button("Predict", variant="primary")
    output = gr.Label(label="Output")
    inputs = [image]
    outputs = [output]
    # Run predict() on the uploaded image and show its label/score dict.
    predict_btn.click(predict, inputs=inputs, outputs=outputs)
if __name__ == "__main__":
    # Start the Gradio app only when this file is run as a script.
    demo.launch()  # Local machine only
    # Alternative launch modes, kept for reference:
    # demo.launch(server_name="0.0.0.0")  # LAN access to local machine
    # demo.launch(share=True)  # Public access to local machine