chansung committed
Commit 237e09c
1 Parent(s): 7a487c4

upload 1683855462 model

Files changed (4)
  1. README.md +11 -0
  2. app.py +83 -0
  3. labels.txt +3 -0
  4. requirements.txt +3 -0
README.md ADDED
@@ -0,0 +1,11 @@
+ ---
+ title: Leaf Classification
+ emoji: 🍂
+ colorFrom: indigo
+ colorTo: indigo
+ sdk: gradio
+ sdk_version: 3.29.0
+ app_file: app.py
+ pinned: false
+ license: apache-2.0
+ ---
app.py ADDED
@@ -0,0 +1,83 @@
+ import tarfile
+ import wandb
+
+ import gradio as gr
+ import numpy as np
+ from PIL import Image
+ import tensorflow as tf
+ from transformers import ViTFeatureExtractor
+
+ PRETRAIN_CHECKPOINT = "google/vit-base-patch16-224-in21k"
+ feature_extractor = ViTFeatureExtractor.from_pretrained(PRETRAIN_CHECKPOINT)
+
+ MODEL = None
+
+ RESOLUTION = 224
+
+ # Class names, one per line.
+ labels = []
+
+ with open("labels.txt", "r") as fp:
+     for line in fp:
+         labels.append(line.strip())
+
+ def normalize_img(
+     img, mean=feature_extractor.image_mean, std=feature_extractor.image_std
+ ):
+     # Scale to [0, 1], then normalize with the feature extractor's mean/std.
+     img = img / 255
+     mean = tf.constant(mean)
+     std = tf.constant(std)
+     return (img - mean) / std
+
+ def preprocess_input(image):
+     image = np.array(image)
+     image = tf.convert_to_tensor(image)
+
+     image = tf.image.resize(image, (RESOLUTION, RESOLUTION))
+     image = normalize_img(image)
+
+     image = tf.transpose(
+         image, (2, 0, 1)
+     )  # Since HF models are channel-first.
+
+     return {
+         "pixel_values": tf.expand_dims(image, 0)
+     }
+
+ def get_predictions(wb_token, image):
+     global MODEL
+
+     if MODEL is None:
+         # Lazily download and extract the model artifact from W&B on the first request.
+         wandb.login(key=wb_token)
+         wandb.init(project="tfx-vit-pipeline")
+         path = wandb.use_artifact('tfx-vit-pipeline/final_model:1683855462', type='model').download()
+
+         tar = tarfile.open(f"{path}/model.tar.gz")
+         tar.extractall(path=".")
+
+         MODEL = tf.keras.models.load_model("./model")
+
+     preprocessed_image = preprocess_input(image)
+     prediction = MODEL.predict(preprocessed_image)
+     probs = tf.nn.softmax(prediction['logits'], axis=1)
+
+     confidences = {labels[i]: float(probs[0][i]) for i in range(len(labels))}
+     return confidences
+
+ with gr.Blocks() as demo:
+     gr.Markdown("## Simple demo of image classification on the Beans dataset with an HF ViT model")
+
+     wb_token_if = gr.Textbox(interactive=True, label="Your Weights & Biases API Key")
+
+     with gr.Row():
+         image_if = gr.Image()
+         label_if = gr.Label(num_top_classes=3)
+
+     classify_if = gr.Button()
+
+     classify_if.click(
+         get_predictions,
+         [wb_token_if, image_if],
+         label_if
+     )
+
+ demo.launch(debug=True)
labels.txt ADDED
@@ -0,0 +1,3 @@
+ angular_leaf_spot
+ bean_rust
+ healthy
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ tensorflow
+ transformers
+ wandb