Spaces:
Sleeping
Sleeping
import os

# NumPy reads NUMPY_EXPERIMENTAL_ARRAY_FUNCTION once, at the moment it is first
# imported, so the variable must be set before anything pulls NumPy in.
# gradio depends on NumPy, and the training/prediction modules presumably do
# too (TODO confirm in train_model.py / predict_model.py). Setting it after
# those imports, as the original ordering did, has no effect.
# NOTE(review): newer NumPy releases may ignore this variable entirely —
# confirm it is still needed for the pinned NumPy version.
os.environ['NUMPY_EXPERIMENTAL_ARRAY_FUNCTION'] = '0'

import gradio as gr  # noqa: E402  (must come after the env var above)
from train_model import train  # noqa: E402
from predict_model import predict_all  # noqa: E402
def train_model():
    """Train the classifier and return a status message for the UI textbox.

    Calls ``train()`` synchronously; any exception it raises propagates to
    the caller unchanged. The returned message names the weights file
    ``animal_classifier_resnet.pth`` — NOTE(review): this assumes ``train()``
    writes that exact path; confirm in train_model.py.
    """
    train()
    status = "Model trained and saved as animal_classifier_resnet.pth"
    return status
def download_model():
    """Return the path of the trained weights for the Download tab's File output.

    Returns:
        The relative path ``"animal_classifier_resnet.pth"``, resolved against
        the process's current working directory.

    Raises:
        gr.Error: if that file does not exist yet (for example, training has
            not been run). The UI then shows a readable message instead of
            being handed a path to a missing file.
    """
    model_path = "animal_classifier_resnet.pth"
    # Check before returning: gr.File cannot serve a file that is not there.
    if not os.path.isfile(model_path):
        raise gr.Error("No trained model found. Train the model first.")
    return model_path
def run_predictions():
    """Run ``predict_all()`` and render its results as newline-separated text.

    Returns:
        The items produced by ``predict_all()`` joined with ``"\\n"``, one per
        line. NOTE(review): assumes every item is a ``str`` — ``str.join``
        raises ``TypeError`` otherwise; confirm in predict_model.py.
    """
    newline = "\n"
    return newline.join(predict_all())
# Gradio UI: one tab each for training, running predictions, and downloading
# the trained weights. Components render in the order they are created.
with gr.Blocks() as demo:
    gr.Markdown("# Animal Classifier Model")

    with gr.Tab("Train"):
        btn_train = gr.Button("Train Model")
        box_train = gr.Textbox()
        btn_train.click(fn=train_model, outputs=box_train)

    with gr.Tab("Predict"):
        btn_predict = gr.Button("Run Predictions")
        box_predict = gr.Textbox()
        btn_predict.click(fn=run_predictions, outputs=box_predict)

    with gr.Tab("Download"):
        gr.Markdown("## Download Trained Model")
        btn_download = gr.Button("Download Model")
        # Created explicitly right after the button, which is the same spot the
        # original inline `gr.File()` occupied in the layout.
        file_download = gr.File()
        btn_download.click(fn=download_model, outputs=file_download)

# Launch the server only when this file is executed directly, not on import.
if __name__ == "__main__":
    demo.launch()