TrishanuDas committed
Commit 0463385
1 Parent(s): 27c2589

minor fixes
Files changed (7)
  1. README.md +31 -3
  2. api_endpoint.py +36 -0
  3. app.py +25 -0
  4. app_with_fastapi.py +25 -0
  5. check.ipynb +205 -0
  6. model.py +16 -0
  7. requirements.txt +8 -0
README.md CHANGED
@@ -1,8 +1,8 @@
  ---
  title: Cifar10 Classification
- emoji: 👀
- colorFrom: yellow
- colorTo: yellow
+ emoji: 🤗
+ colorFrom: red
+ colorTo: pink
  sdk: streamlit
  sdk_version: 1.33.0
  app_file: app.py
@@ -11,3 +11,31 @@ license: apache-2.0
  ---
 
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+
+ ### To access the app, use one of the following options:
+
+
+ Option 1: Open the hosted Space directly (this does not use the FastAPI endpoint):
+ [https://huggingface.co/spaces/TrishanuDas/cifar10_classification](https://huggingface.co/spaces/TrishanuDas/cifar10_classification)
+
+ Option 2: Run the Streamlit app against the FastAPI endpoint (a request sketch follows this diff).
+ - Start the FastAPI server on any instance with:
+ ```
+ uvicorn api_endpoint:app --reload --host <host_name>
+ ```
+ - Point the `HOST` variable in `app_with_fastapi.py` at that server.
+ - Launch the app with:
+ ```
+ streamlit run app_with_fastapi.py
+ ```
+
+ ### Files
+
+ The following files are present in this repository:
+
+ - `app.py`: The main Streamlit app, run directly against the local model.
+ - `requirements.txt`: The Python dependencies required by the app.
+ - `model.py`: Loads the pre-trained model and exposes `predict()`.
+ - `api_endpoint.py`: The FastAPI endpoint for prediction.
+ - `app_with_fastapi.py`: The Streamlit app that calls the FastAPI endpoint.
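A quick end-to-end check of Option 2 is to POST an image at the endpoint from Python. A minimal sketch, assuming the server is running on localhost:8000 and a local file named `test.jpg` exists (both the host and the filename are assumptions, not part of the commit):

```
import requests

# Match the host you passed to uvicorn; localhost:8000 is the assumed default.
HOST = "http://localhost:8000"

# POST a local image to the inference endpoint defined in api_endpoint.py.
with open("test.jpg", "rb") as f:
    response = requests.post(f"{HOST}/upload_image_for_inference", files={"file": f})

response.raise_for_status()
print(response.json()["predicted_class"])  # e.g. "cat"
```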
api_endpoint.py ADDED
@@ -0,0 +1,36 @@
+ import os
+ from fastapi import FastAPI, HTTPException, Depends
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi import File, UploadFile
+ from PIL import Image
+ import model
+
+ app = FastAPI()
+ # Add the CORSMiddleware to enable Cross-Origin Resource Sharing
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],  # Allow all origins
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ @app.post("/upload_image_for_inference")  # Endpoint that accepts an uploaded image and returns its predicted class
+ async def upload_image(file: UploadFile = File(...)):
+     # Save the uploaded image to a file
+     with open('image.jpg', 'wb') as image:
+         contents = await file.read()  # Read the content of the uploaded file
+         image.write(contents)  # Write the content to the image file
+
+     # Process the image
+     image_pil = Image.open('image.jpg')  # Open the image using PIL
+
+     # Predict the image class
+     predicted_class = model.predict(image_pil)
+     # print(f"Predicted label: {predicted_class}")
+
+     image_pil.close()
+     # Delete the image file after processing
+     os.remove("image.jpg")
+
+     return {'predicted_class': predicted_class}
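Worth noting: the handler round-trips every upload through a shared `image.jpg` on disk, so two concurrent requests can overwrite each other's files. A minimal in-memory sketch of the same logic, offered as an illustrative alternative rather than the committed code:

```
import io

from fastapi import FastAPI, File, UploadFile
from PIL import Image
import model

app = FastAPI()

@app.post("/upload_image_for_inference")
async def upload_image(file: UploadFile = File(...)):
    # Decode the upload directly from bytes; no shared temp file,
    # so concurrent requests cannot clobber one another.
    contents = await file.read()
    image_pil = Image.open(io.BytesIO(contents))
    return {'predicted_class': model.predict(image_pil)}
```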
app.py ADDED
@@ -0,0 +1,25 @@
+ from PIL import Image
+ import streamlit as st
+ import requests
+ import model
+
+ # Streamlit layout
+ st.title("CIFAR10 Prediction")
+
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
+
+ if uploaded_file is not None:
+     # Convert uploaded file to PIL image
+     image = Image.open(uploaded_file)
+
+     # Display the uploaded image
+     with st.container(height=300):
+         st.image(image, caption='Uploaded Image', use_column_width=True)
+
+     if st.button('Predict'):
+         predicted_class = model.predict(image)
+
+         if predicted_class is not None:
+             st.header(f"Predicted Label: {predicted_class}")
+         else:
+             st.error("Error processing image. Please try again!")
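One caveat for this direct-inference path: PNG uploads may carry an alpha channel, and the processor's three-channel normalization assumes RGB input. A small, hedged guard (the helper name `predict_file` is illustrative, not part of the commit):

```
from PIL import Image

import model

def predict_file(path):
    # PNG files can be RGBA; converting to RGB keeps the input
    # three-channel before it reaches the image processor.
    image = Image.open(path).convert("RGB")
    return model.predict(image)
```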
app_with_fastapi.py ADDED
@@ -0,0 +1,25 @@
+ import requests
+ import streamlit as st
+
+ # Streamlit layout
+ st.title("CIFAR10 Prediction")
+
+ HOST = "http://localhost:8000"
+
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
+
+ if uploaded_file is not None:
+     with st.container(height=300):
+         st.image(uploaded_file, caption='Uploaded Image', use_column_width=True)
+
+     if st.button('Predict'):
+         # Send image to FastAPI endpoint
+         files = {'file': uploaded_file}
+         response = requests.post(f"{HOST}/upload_image_for_inference", files=files)
+
+         if response.status_code == 200:
+             result = response.json()
+             st.header(f"Predicted Label: {result['predicted_class']}")
+         else:
+             st.error("Error processing image. Please try again.")
+
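Hard-coding `HOST` means editing this file whenever the server moves (Option 2 in the README). A short sketch reading it from the environment instead; the variable name `FASTAPI_HOST` is an assumption, not part of the commit:

```
import os

# Fall back to the local default when FASTAPI_HOST is unset.
HOST = os.environ.get("FASTAPI_HOST", "http://localhost:8000")
```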
check.ipynb ADDED
@@ -0,0 +1,205 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": 73,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "import torch\n",
+     "import pickle\n",
+     "import matplotlib.pyplot as plt\n",
+     "import numpy as np\n",
+     "import time"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 75,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "IMAGE_SIZE = 224  # ResNet-style backbones expect inputs of at least 224x224\n",
+     "\n",
+     "mean, std = [0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]\n",
+     "classes = ('airplane',\n",
+     "           'automobile',\n",
+     "           'bird',\n",
+     "           'cat',\n",
+     "           'deer',\n",
+     "           'dog',\n",
+     "           'frog',\n",
+     "           'horse',\n",
+     "           'ship',\n",
+     "           'truck')\n",
+     "\n",
+     "if torch.cuda.is_available():\n",
+     "    torch.set_default_device('cuda')\n",
+     "\n",
+     "def show_data(img):\n",
+     "    try:\n",
+     "        plt.imshow(img[0])\n",
+     "    except Exception as e:\n",
+     "        print(e)\n",
+     "        print(img[0].shape, img[0].permute(1, 2, 0).shape)\n",
+     "        plt.imshow(img[0].permute(1, 2, 0))\n",
+     "    plt.title('y = ' + str(img[1]))\n",
+     "    plt.show()\n",
+     "\n",
+     "# Convert tensors to NumPy arrays, since matplotlib cannot plot torch tensors directly.\n",
+     "def im_convert(tensor):\n",
+     "    # Move the tensor to the CPU and detach it from the graph\n",
+     "    img = tensor.cpu().clone().detach().numpy()\n",
+     "    img = img.transpose(1, 2, 0)  # CHW -> HWC\n",
+     "    img = img * np.array(std) + np.array(mean)  # Undo the normalization\n",
+     "    img = img.clip(0, 1)  # Clip to the valid [0, 1] range for display\n",
+     "    return img"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 64,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "def unpickle(file):\n",
+     "    with open(file, 'rb') as fo:\n",
+     "        data_dict = pickle.load(fo, encoding='bytes')\n",
+     "\n",
+     "    # Decode keys from bytes to strings\n",
+     "    decoded_dict = {}\n",
+     "    for key, value in data_dict.items():\n",
+     "        decoded_key = key.decode('utf-8')  # Assuming UTF-8 encoding\n",
+     "        decoded_dict[decoded_key] = value\n",
+     "\n",
+     "    return decoded_dict\n"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 76,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "decoded_dict = unpickle('./test_batch')  # CIFAR-10 test batch\n",
+     "decoded_dict\n",
+     "data = torch.tensor(decoded_dict['data']).reshape([10000, 3, 32, 32])\n",
+     "dataset = {\"image\": data, \"target\": torch.tensor(decoded_dict[\"labels\"])}"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 77,
+    "metadata": {},
+    "outputs": [
+     {
+      "data": {
+       "text/plain": [
+        "dict_keys(['batch_label', 'labels', 'data', 'filenames'])"
+       ]
+      },
+      "execution_count": 77,
+      "metadata": {},
+      "output_type": "execute_result"
+     }
+    ],
+    "source": [
+     "decoded_dict.keys()"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 78,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "idx = 0\n",
+     "image = dataset['image'][idx]\n",
+     "label = dataset[\"target\"][idx].item()"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 79,
+    "metadata": {},
+    "outputs": [
+     {
+      "data": {
+       "text/plain": [
+        "'cat'"
+       ]
+      },
+      "execution_count": 79,
+      "metadata": {},
+      "output_type": "execute_result"
+     }
+    ],
+    "source": [
+     "classes[label]"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 82,
+    "metadata": {},
+    "outputs": [
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       "cat\n",
+       "Time taken: 0.013 s\n"
+      ]
+     }
+    ],
+    "source": [
+     "# Load model directly\n",
+     "from transformers import AutoImageProcessor, AutoModelForImageClassification\n",
+     "\n",
+     "processor = AutoImageProcessor.from_pretrained(\"heyitskim1912/AML_A2_Q4\")\n",
+     "model = AutoModelForImageClassification.from_pretrained(\"heyitskim1912/AML_A2_Q4\")\n",
+     "\n",
+     "inputs = processor(image, return_tensors=\"pt\")\n",
+     "\n",
+     "start_time = time.time()\n",
+     "with torch.no_grad():\n",
+     "    logits = model(**inputs).logits\n",
+     "\n",
+     "# The model predicts one of the 10 CIFAR-10 classes\n",
+     "predicted_label = logits.argmax(-1).item()\n",
+     "print(model.config.id2label[predicted_label])\n",
+     "end_time = time.time()\n",
+     "time_taken = round(end_time - start_time, 3)\n",
+     "print(f\"Time taken: {time_taken} s\")"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "PyTorchenv",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.10.9"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 2
+ }
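The final timed cell measures a single image; the processor also accepts a list of images, which gives a fairer throughput number. A sketch along those lines, reusing the notebook's `dataset` dict (the batch size of 8 is arbitrary):

```
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification

processor = AutoImageProcessor.from_pretrained("heyitskim1912/AML_A2_Q4")
model = AutoModelForImageClassification.from_pretrained("heyitskim1912/AML_A2_Q4")

# Assumes the `dataset` dict built earlier in the notebook is in scope.
batch = [dataset["image"][i] for i in range(8)]
inputs = processor(batch, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits

print([model.config.id2label[i] for i in logits.argmax(-1).tolist()])
```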
model.py ADDED
@@ -0,0 +1,16 @@
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
+ import torch
+
+ processor = AutoImageProcessor.from_pretrained("heyitskim1912/AML_A2_Q4")
+ model = AutoModelForImageClassification.from_pretrained("heyitskim1912/AML_A2_Q4")
+
+ def predict(image_pil):
+     inputs = processor(image_pil, return_tensors="pt")
+
+     with torch.no_grad():
+         logits = model(**inputs).logits
+
+     # Get predicted label
+     predicted_label = logits.argmax(-1).item()
+     predicted_class = model.config.id2label[predicted_label]
+     return predicted_class
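For reference, `model.predict` also works standalone; a minimal usage sketch, assuming a local image file `cat.jpg` (the filename is an assumption):

```
from PIL import Image

import model  # Downloads the checkpoint from the Hub on first import

image = Image.open("cat.jpg")
print(model.predict(image))  # e.g. "cat"
```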
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ transformers
+ pillow
+ streamlit
+ requests
+ fastapi
+ torch
+ uvicorn
+ gunicorn