zhijian12345 commited on
Commit
136aa0e
1 Parent(s): 16c601e

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -0
app.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from transformers import pipeline
3
+ from tqdm import tqdm
4
+ from PIL import Image
5
+ import matplotlib.pyplot as plt
6
+ from math import sqrt
7
+ import gradio as gr
8
+ import numpy as np
9
+
10
+ model_info = """
11
+ **模型名称**: Google/vit-base-patch16-224
12
+ **模型介绍**: 本程序根据huggingface上Google开源模型vit,在猫狗图片数据上进行微调,上传一张图片,将会预测其类别并显示结果。模型官网:https://huggingface.co/google/vit-base-patch16-224
13
+ **程序作者**: 计科三班 王志建、计科三班 罗楷轩
14
+ **特别支持**: 计科三班 黄成栋
15
+ """
16
+
17
+ # 加载图像分类模型
18
+ checkpoint_dir = "./checkpoint/checkpoint-181" # 模型检查点目录
19
+ classifier = pipeline("image-classification", model=checkpoint_dir) # 创建图像分类器模型
20
+ vitclassifier = pipeline("image-classification",model="google/vit-base-patch16-224")
21
+
22
+
23
+ demo = gr.Blocks()
24
+
25
+ # 定义推理函数
26
+ def flip_myvit(image):
27
+ # 图像预处理
28
+ image = Image.fromarray(image.astype('uint8'), 'RGB')
29
+ # 进行图像分类
30
+ result = classifier(image)
31
+ # 返回分类结果
32
+ text = "{:.3f}%".format(result[0]['score'] * 100)
33
+ return result[0]['label'],text
34
+
35
+
36
+ def flip_vit(image):
37
+ # 图像预处理
38
+ image = Image.fromarray(image.astype('uint8'), 'RGB')
39
+ # 进行图像分类
40
+ result = vitclassifier(image)
41
+ # 返回分类结果
42
+ text = "{:.3f}%".format(result[0]['score'] * 100)
43
+ return result[0]['label'],text
44
+
45
+ with demo:
46
+ gr.Markdown(model_info)
47
+ with gr.Tabs():
48
+ with gr.TabItem("myvit"):
49
+ myvit_input = gr.Image()
50
+ myvit_output1 = gr.Textbox(label="预测结果")
51
+ myvit_output2 = gr.Textbox(label="准确度")
52
+ myvit_button = gr.Button("开始")
53
+ with gr.TabItem("vit"):
54
+ vit_input = gr.Image()
55
+ vit_output1 = gr.Textbox(label="预测结果")
56
+ vit_output2 = gr.Textbox(label="准确度")
57
+ vit_button = gr.Button("开始")
58
+
59
+ myvit_button.click(flip_myvit, inputs=myvit_input, outputs=[myvit_output1,myvit_output2])
60
+ vit_button.click(flip_vit, inputs=vit_input, outputs=[vit_output1,vit_output2])
61
+
62
+ demo.title="猫狗分类器"
63
+ demo.launch()
64
+