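# chesstest / app.py
# Author: IndianServers. Last commit 6e7d4b9 ("Update app.py"), 829 bytes.
# (Header recovered from a web file-viewer page; the "picture", "verified",
# "raw" and "history blame" lines were viewer UI labels, not file content.)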
import tensorflow as tf
import gradio as gr
import numpy as np
import os
# Load the exported TensorFlow SavedModel once, at import time, so every
# request reuses the same object. The directory is resolved relative to this
# file rather than the current working directory ('.'), so starting the app
# from a different directory still finds the model files.
MODEL_DIR = os.path.dirname(os.path.abspath(__file__))
model = tf.saved_model.load(MODEL_DIR)
# Define the prediction function
def predict(image):
    """Run the model on one uploaded image and return scores for ``gr.Label``.

    Args:
        image: PIL image delivered by ``gr.Image(type="pil")``, or ``None``
            when the user submits without uploading anything.

    Returns:
        ``None`` if no image was supplied. Otherwise, a dict mapping each
        output index (as a string) to its float score. ``gr.Label`` accepts a
        ``{label: confidence}`` dict but rejects the nested list that
        ``.numpy().tolist()`` used to produce.
    """
    if image is None:
        # Nothing to run on; return None instead of crashing in
        # np.array(None).astype(np.float32).
        return None

    # Force 3 channels: uploads such as PNGs may be RGBA or grayscale.
    # NOTE(review): assumes the model expects RGB input - TODO confirm.
    img = np.asarray(image.convert("RGB"), dtype=np.float32)
    img = np.expand_dims(img, axis=0)  # Add batch dimension -> (1, H, W, 3)
    # NOTE(review): pixel values stay in [0, 255]; confirm whether the model
    # expects normalized input.
    img = tf.image.resize(img, (640, 640))

    # Perform inference
    predictions = model(img)

    # NOTE(review): assumes the model returns a single tensor of per-class
    # scores with a leading batch axis, e.g. (1, num_classes). Class names
    # are not available here, so the output index is used as the label.
    scores = np.asarray(predictions)[0].ravel()
    return {str(i): float(score) for i, score in enumerate(scores)}
# Set up the Gradio interface: one PIL image in; a Label showing the top 3
# entries of the {label: confidence} dict returned by `predict` out.
image_input = gr.Image(type="pil")
label_output = gr.Label(num_top_classes=3)
interface = gr.Interface(fn=predict, inputs=image_input, outputs=label_output)

# os.getenv returns a *string* whenever the variable is set, but the server
# port must be an int, so convert explicitly (falling back to 7860).
interface.launch(server_port=int(os.getenv("GRADIO_SERVER_PORT", "7860")))