{ "cells": [ { "cell_type": "code", "execution_count": 15, "id": "02edd069-0381-4537-902e-03ffd273349c", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "C:\\Users\\Haider Ali\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\keras\\src\\layers\\convolutional\\base_conv.py:99: UserWarning: Do not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an `Input(shape)` object as the first layer in the model instead.\n", " super().__init__(\n", "WARNING:absl:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n", "WARNING:absl:Error in loading the saved optimizer state. As a result, your model is starting with a freshly initialized optimizer.\n" ] } ], "source": [ "from keras.models import load_model\n", "import numpy as np\n", "\n", "# Load the saved model\n", "model = load_model('model.h5')\n", "\n" ] }, { "cell_type": "code", "execution_count": 33, "id": "78c9d169-adbe-4588-9a78-fb02b90e3781", "metadata": {}, "outputs": [], "source": [ "from torchvision import transforms" ] }, { "cell_type": "code", "execution_count": 16, "id": "f57b1e4e-c171-4233-addf-a5bbdd91896f", "metadata": {}, "outputs": [ { "ename": "NameError", "evalue": "name 'input_image' is not defined", "output_type": "error", "traceback": [ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)", "Cell \u001b[1;32mIn[16], line 3\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[38;5;66;03m# Perform inference on the input image\u001b[39;00m\n\u001b[0;32m 2\u001b[0m \u001b[38;5;66;03m# Make sure your input shape matches the input shape of the model\u001b[39;00m\n\u001b[1;32m----> 3\u001b[0m predicted_image \u001b[38;5;241m=\u001b[39m model\u001b[38;5;241m.\u001b[39mpredict(np\u001b[38;5;241m.\u001b[39mexpand_dims(\u001b[43minput_image\u001b[49m, axis\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m))\n\u001b[0;32m 5\u001b[0m \u001b[38;5;66;03m# The output 'predicted_image' will be the deblurred image generated by the model\u001b[39;00m\n\u001b[0;32m 6\u001b[0m \u001b[38;5;66;03m# You can further process or save the output image as needed\u001b[39;00m\n", "\u001b[1;31mNameError\u001b[0m: name 'input_image' is not defined" ] } ], "source": [ "# Perform inference on the input image\n", "# Make sure your input shape matches the input shape of the model\n", "predicted_image = model.predict(np.expand_dims(input_image, axis=0))\n", "\n", "# The output 'predicted_image' will be the deblurred image generated by the model\n", "# You can further process or save the output image as needed\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "87814748-7c0b-41e2-998b-d3a3eb6d7bbd", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'2.16.1'" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import tensorflow as tf\n", "tf.__version__" ] }, { "cell_type": "code", "execution_count": 5, "id": "9fafecb4-54e9-43a5-ac19-143205069848", "metadata": {}, "outputs": [], "source": [ "import tensorflow as tf\n", "from tensorflow.keras.datasets import cifar10\n", "from tensorflow.keras.models import Sequential\n", "from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout\n", "from tensorflow.keras.utils import to_categorical\n", "from tensorflow.keras.optimizers import Adam\n", 
"from tensorflow.keras.preprocessing.image import ImageDataGenerator\n", "\n", "# Load CIFAR-10 dataset\n", "(x_train, y_train), (x_test, y_test) = cifar10.load_data()\n", "\n", "# Normalize pixel values to be between 0 and 1\n", "x_train = x_train.astype('float32') / 255.0\n", "x_test = x_test.astype('float32') / 255.0\n", "\n", "# One-hot encode the labels\n", "y_train = to_categorical(y_train, num_classes=10)\n", "y_test = to_categorical(y_test, num_classes=10)" ] }, { "cell_type": "code", "execution_count": 47, "id": "b8e744fc-d509-49a9-a1d1-9be7f37a6c21", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 1.],\n", " [0., 0., 0., ..., 0., 0., 1.],\n", " ...,\n", " [0., 0., 0., ..., 0., 0., 1.],\n", " [0., 1., 0., ..., 0., 0., 0.],\n", " [0., 1., 0., ..., 0., 0., 0.]])" ] }, "execution_count": 47, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_train" ] }, { "cell_type": "code", "execution_count": 7, "id": "0ae22ae2-c40d-4a4a-9c25-00d2c423b53c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m313/313\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m23s\u001b[0m 72ms/step\n" ] } ], "source": [ "predicted = model.predict(x_test)" ] }, { "cell_type": "code", "execution_count": 9, "id": "e80370e5-6753-4b04-afde-ddcfcdb3c148", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(10000, 32, 32, 3)" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x_test.shape" ] }, { "cell_type": "code", "execution_count": 13, "id": "9bd01caf-8ce5-465a-b734-a5480b96521d", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([3, 8, 8, ..., 5, 1, 7], dtype=int64)" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.argmax(predicted, axis = 1)" ] }, { "cell_type": "code", "execution_count": null, "id": "070f8ef8-6522-4fe7-9d35-67f1861de531", "metadata": {}, "outputs": [], "source": [ "\n", "\n", "# Define AlexNet architecture\n", "model = Sequential([\n", " # First convolutional layer\n", " Conv2D(96, (11, 11), strides=(1, 1), activation='relu', input_shape=(32, 32, 3)),\n", " MaxPooling2D(pool_size=(3, 3), strides=(2, 2)),\n", " # Second convolutional layer\n", " Conv2D(256, (5, 5), padding='same', activation='relu'),\n", " MaxPooling2D(pool_size=(3, 3), strides=(2, 2)),\n", " # Third convolutional layer\n", " Conv2D(384, (3, 3), padding='same', activation='relu'),\n", " # Fourth convolutional layer\n", " Conv2D(384, (3, 3), padding='same', activation='relu'),\n", " # Fifth convolutional layer\n", " Conv2D(256, (3, 3), padding='same', activation='relu'),\n", " MaxPooling2D(pool_size=(3, 3), strides=(2, 2)),\n", " # Flatten the convolutional layers output for fully connected layers\n", " Flatten(),\n", " # First fully connected layer\n", " Dense(4096, activation='relu'),\n", " Dropout(0.5),\n", " # Second fully connected layer\n", " Dense(4096, activation='relu'),\n", " Dropout(0.5),\n", " # Output layer\n", " Dense(10, activation='softmax')\n", "])\n", "\n", "# Compile the model with a lower learning rate\n", "optimizer = Adam(learning_rate=0.0001)\n", "model.compile(optimizer=optimizer,\n", " loss='categorical_crossentropy',\n", " metrics=['accuracy'])\n", "\n", "# Data augmentation\n", "datagen = ImageDataGenerator(\n", " rotation_range=15,\n", " width_shift_range=0.1,\n", " height_shift_range=0.1,\n", " horizontal_flip=True,\n", ")\n", 
"\n", "datagen.fit(x_train)\n", "\n", "# Train the model with data augmentation\n", "model.fit(datagen.flow(x_train, y_train, batch_size=128), epochs=25, validation_data=(x_test, y_test))\n", "\n", "# Evaluate the model on the test set\n", "test_loss, test_accuracy = model.evaluate(x_test, y_test, verbose=2)\n", "\n", "print(\"\\nTest Accuracy:\", test_accuracy)\n", "print(\"Test Loss:\", test_loss)" ] }, { "cell_type": "code", "execution_count": 23, "id": "82b6c768-2b9d-4633-8e17-96d26d814421", "metadata": {}, "outputs": [], "source": [ "from PIL import Image\n", "import numpy as np" ] }, { "cell_type": "code", "execution_count": 40, "id": "692a13a1-3483-4dd9-9364-c27e909b89d6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Running on local URL: http://127.0.0.1:7896\n", "\n", "To create a public link, set `share=True` in `launch()`.\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" }, { "name": "stdout", "output_type": "stream", "text": [ "(1280, 717, 3)\n", "(1, 32, 32, 3)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Traceback (most recent call last):\n", " File \"C:\\Users\\Haider Ali\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\gradio\\queueing.py\", line 501, in call_prediction\n", " output = await route_utils.call_process_api(\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"C:\\Users\\Haider Ali\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\gradio\\route_utils.py\", line 253, in call_process_api\n", " output = await app.get_blocks().process_api(\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"C:\\Users\\Haider Ali\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\gradio\\blocks.py\", line 1695, in process_api\n", " result = await self.call_function(\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"C:\\Users\\Haider Ali\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\gradio\\blocks.py\", line 1235, in call_function\n", " prediction = await anyio.to_thread.run_sync(\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"C:\\Users\\Haider Ali\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\anyio\\to_thread.py\", line 33, in run_sync\n", " return await get_asynclib().run_sync_in_worker_thread(\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"C:\\Users\\Haider Ali\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\anyio\\_backends\\_asyncio.py\", line 877, in run_sync_in_worker_thread\n", " return await future\n", " ^^^^^^^^^^^^\n", " File \"C:\\Users\\Haider Ali\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\anyio\\_backends\\_asyncio.py\", line 807, in run\n", " result = context.run(func, *args)\n", " ^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"C:\\Users\\Haider Ali\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\gradio\\utils.py\", line 692, in wrapper\n", " response = f(*args, **kwargs)\n", " ^^^^^^^^^^^^^^^^^^\n", " File \"C:\\Users\\Haider Ali\\AppData\\Local\\Temp\\ipykernel_29808\\1451871443.py\", line 16, in prediction\n", " output = model.predict(transformed_image)\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"C:\\Users\\Haider Ali\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\keras\\src\\utils\\traceback_utils.py\", line 122, in error_handler\n", " raise e.with_traceback(filtered_tb) from None\n", " File \"C:\\Users\\Haider Ali\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\keras\\src\\models\\functional.py\", line 280, in _adjust_input_rank\n", " raise ValueError(\n", "ValueError: Exception encountered when calling Sequential.call().\n", "\n", "\u001b[1mInvalid input shape for input Tensor(\"data:0\", shape=(32, 32, 3), dtype=float32). 
Expected shape (None, 32, 32, 3), but input has incompatible shape (32, 32, 3)\u001b[0m\n", "\n", "Arguments received by Sequential.call():\n", " • inputs=tf.Tensor(shape=(32, 32, 3), dtype=float32)\n", " • training=False\n", " • mask=None\n" ] } ], "source": [ "def prediction(input_img):\n", "    # image = Image.open(\"img1.jpg\")\n", "    print(input_img.shape)\n", "    # Define the transformation\n", "    transform = transforms.Compose([\n", "        transforms.Resize(32),\n", "        transforms.CenterCrop(32),\n", "        transforms.ToTensor(),\n", "        # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n", "    ])\n", "    pil_image = Image.fromarray(input_img.astype('uint8'))\n",
"    # Apply the transformation\n", "    transformed_image = np.array(transform(pil_image).T)\n", "    input_image = np.expand_dims(transformed_image, axis=0)\n", "    print(input_image.shape)\n", "    # BUG: this passes the unbatched (32, 32, 3) array, but the model expects a batch of\n", "    # shape (None, 32, 32, 3) - hence the ValueError in this cell's output.\n", "    # The corrected cell further below predicts on input_image instead.\n", "    output = model.predict(transformed_image)\n", "    print(output)\n", "    # print(transformed_image.shape)\n", "    # print(transformed_image)\n", "    # plt.imshow(transformed_image)\n", "    # plt.show()\n", "    # return transformed_image\n", "demo = gr.Interface(prediction, gr.Image(), \"image\")\n", "demo.launch()" ] },
{ "cell_type": "code", "execution_count": 49, "id": "b2b8aaaf-ce9c-4c25-875e-fce01fbf3832", "metadata": {}, "outputs": [], "source": [ "classes = {\n", "    0 : 'Airplane',\n", "    1 : 'Automobile',\n", "    2 : 'Bird',\n", "    3 : 'Cat',\n", "    4 : 'Deer',\n", "    5 : 'Dog',\n", "    6 : 'Frog',\n", "    7 : 'Horse',\n", "    8 : 'Ship',\n", "    9 : 'Truck'\n", "}" ] },
{ "cell_type": "code", "execution_count": 53, "id": "d3da2b58-6b86-4e6d-9b7f-9f8ce1fe2339", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Running on local URL: http://127.0.0.1:7904\n", "\n", "To create a public link, set `share=True` in `launch()`.\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [] }, "execution_count": 53, "metadata": {}, "output_type": "execute_result" }, { "name": "stdout", "output_type": "stream", "text": [ "(1600, 1204, 3)\n", "(1, 32, 32, 3)\n", "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 52ms/step\n", "[[0.00640055 0.17760815 0.04763744 0.09621317 0.05900569 0.09116109\n", " 0.02236336 0.09745935 0.01388952 0.38826168]]\n" ] } ], "source": [ "def prediction(input_img):\n", " # Define the transformation\n", " transform = transforms.Compose([\n", " transforms.Resize(32),\n", " transforms.CenterCrop(32),\n", " transforms.ToTensor(),\n", " # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n", " ])\n", " pil_image = Image.fromarray(input_img.astype('uint8'))\n", " # Apply the transformation\n", " transformed_image = np.array(transform(pil_image).T)\n", " input_image = np.expand_dims(transformed_image, axis=0)\n", " output = model.predict(input_image)\n", " # print(transformed_image.shape)\n", " # print(transformed_image)\n", " # plt.imshow(transformed_image)\n", " # plt.show()\n", " return classes[np.argmax(output)]\n", "demo = gr.Interface(prediction, gr.Image(), \"text\")\n", "demo.launch()" ] }, { "cell_type": "code", "execution_count": null, "id": "2070d29e-b593-490e-b867-6a10ab1b02ff", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.0" } }, "nbformat": 4, "nbformat_minor": 5 }