File size: 3,360 Bytes
894c286 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/jarvis/.local/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"from transformers import ViTImageProcessor, ViTForImageClassification,FlaxViTForImageClassification\n",
"from PIL import Image\n",
"import requests\n",
"from matplotlib import pyplot as plt "
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['tiger cat', 'tabby, tabby cat', 'Egyptian cat'] [282 281 285]\n"
]
}
],
"source": [
"# Download the sample image (COCO val2017 URL) that will be classified.\n",
"url = 'http://images.cocodataset.org/val2017/000000039769.jpg'\n",
"# The timeout stops the cell from hanging forever on a stalled connection;\n",
"# raise_for_status() reports HTTP errors before PIL tries to decode the body.\n",
"response = requests.get(url, stream=True, timeout=30)\n",
"response.raise_for_status()\n",
"image = Image.open(response.raw)\n",
"\n",
"# Image preprocessor and classifier are loaded from the same checkpoint.\n",
"processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')\n",
"model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')\n",
"\n",
"inputs = processor(images=image, return_tensors=\"pt\")\n",
"outputs = model(**inputs)\n",
"logits = outputs.logits\n",
"\n",
"# Detach from autograd and move to host memory before converting to NumPy.\n",
"logits_np = logits.detach().cpu().numpy()\n",
"# argsort sorts ascending, so the last three entries of row 0 are the class\n",
"# indices with the three largest logits, ordered from lowest to highest.\n",
"logits_args = logits_np.argsort()[0][-3:]\n",
"\n",
"prediction_classes = [model.config.id2label[predicted_class_idx] for predicted_class_idx in logits_args]\n",
"print(prediction_classes, logits_args)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Pair each predicted label with the logit of its own class index.\n",
"# logits_args holds the class indices that prediction_classes was built from,\n",
"# so zip keeps label and score aligned; indexing by enumerate position would\n",
"# instead read the logits of classes 0, 1 and 2.\n",
"result = {\n",
"    label: logits_np[0][class_idx]\n",
"    for label, class_idx in zip(prediction_classes, logits_args)\n",
"}\n",
"\n",
"result"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['tiger cat', 'tabby, tabby cat', 'Egyptian cat']"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# The model predicts one of the 1000 ImageNet classes; translate the selected\n",
"# class indices in logits_args into their human-readable labels.\n",
"\n",
"prediction_classes = [\n",
"    model.config.id2label[class_id]\n",
"    for class_id in logits_args\n",
"]\n",
"\n",
"prediction_classes\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "py_llm",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
|