# cat_classifiter / app.py
# zhijian12345's picture
# Update app.py
# 1ae03fe
# raw
# history blame contribute delete
# No virus
# 1.89 kB
# pip install --upgrade pip
# pip install --no-cache-dir -r requirements.txt
import os
from transformers import pipeline
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt
from math import sqrt
import gradio as gr
import numpy as np
# Load the image-classification models.
# The fine-tuned checkpoint directory is resolved next to this script rather
# than against the current working directory, so launching the app from
# another directory (e.g. `python some/dir/app.py`) still finds it.
try:
    _app_dir = os.path.dirname(os.path.abspath(__file__))
except NameError:  # __file__ is undefined when the code is run interactively
    _app_dir = os.getcwd()
checkpoint_dir = os.path.join(_app_dir, "checkpoint-905")  # fine-tuned model checkpoint directory
# Image classifier built from the local fine-tuned checkpoint.
classifier = pipeline("image-classification", model=checkpoint_dir)
# Second image classifier using the pretrained Hub model google/vit-base-patch16-224.
vitclassifier = pipeline("image-classification", model="google/vit-base-patch16-224")
# Root Gradio container; the UI is laid out inside it later in this file.
demo = gr.Blocks()
# Inference function for the "myvit" tab (fine-tuned checkpoint model).
def flip_myvit(image):
    """Classify *image* with the fine-tuned checkpoint classifier.

    Args:
        image: The picture from the ``gr.Image`` input (a NumPy array with the
            default ``type="numpy"``), or ``None`` when the button is clicked
            before any image has been uploaded.

    Returns:
        A ``(label, score_text)`` tuple: the first prediction's label and its
        score as a percentage with three decimals (e.g. ``"98.765%"``).
        When no image was provided, a prompt message and an empty string.
    """
    # Gradio passes None when nothing is uploaded; without this guard the
    # conversion below would raise AttributeError on None.astype.
    if image is None:
        return "请先上传图片", ""
    # Let Pillow infer the mode from the array shape, then normalise to RGB, so
    # grayscale (H, W) and RGBA (H, W, 4) inputs are converted instead of being
    # misread as 3-channel data by a forced 'RGB' mode.
    pil_image = Image.fromarray(np.asarray(image).astype("uint8")).convert("RGB")
    result = classifier(pil_image)
    # The pipeline returns predictions ordered by descending score; use the first.
    top = result[0]
    return top["label"], "{:.3f}%".format(top["score"] * 100)
def flip_vit(image):
    """Classify *image* with the pretrained google/vit-base-patch16-224 classifier.

    Args:
        image: The picture from the ``gr.Image`` input (a NumPy array with the
            default ``type="numpy"``), or ``None`` when the button is clicked
            before any image has been uploaded.

    Returns:
        A ``(label, score_text)`` tuple: the first prediction's label and its
        score as a percentage with three decimals (e.g. ``"98.765%"``).
        When no image was provided, a prompt message and an empty string.
    """
    # Gradio passes None when nothing is uploaded; without this guard the
    # conversion below would raise AttributeError on None.astype.
    if image is None:
        return "请先上传图片", ""
    # Let Pillow infer the mode from the array shape, then normalise to RGB, so
    # grayscale (H, W) and RGBA (H, W, 4) inputs are converted instead of being
    # misread as 3-channel data by a forced 'RGB' mode.
    pil_image = Image.fromarray(np.asarray(image).astype("uint8")).convert("RGB")
    result = vitclassifier(pil_image)
    # The pipeline returns predictions ordered by descending score; use the first.
    top = result[0]
    return top["label"], "{:.3f}%".format(top["score"] * 100)
# Browser-tab title. Gradio documents this as the `gr.Blocks(title=...)`
# argument; because `demo` is created earlier, set the attribute before the
# layout is built rather than after, so it is already in place when the app
# configuration is generated from the Blocks object.
demo.title = "猫狗分类器"


def _build_tab(tab_name):
    """Add one classifier tab; return its (image input, label box, score box, button)."""
    with gr.TabItem(tab_name):
        image_input = gr.Image()
        label_box = gr.Textbox(label="预测结果")
        # The value shown is the model's score for the predicted label, i.e.
        # its confidence ("置信度"), not an accuracy ("准确度") measurement.
        score_box = gr.Textbox(label="置信度")
        button = gr.Button("开始")
    return image_input, label_box, score_box, button


with demo:
    with gr.Tabs():
        # Module-level component names are kept so existing references still work.
        myvit_input, myvit_output1, myvit_output2, myvit_button = _build_tab("myvit")
        vit_input, vit_output1, vit_output2, vit_button = _build_tab("vit")
    # Each button runs its classifier and fills the (label, score) textboxes.
    myvit_button.click(
        flip_myvit, inputs=myvit_input, outputs=[myvit_output1, myvit_output2]
    )
    vit_button.click(flip_vit, inputs=vit_input, outputs=[vit_output1, vit_output2])
demo.launch()