{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inverse Cooking: Recipe Generation from Food Images\n",
    "\n",
    "This notebook loads the vocabularies and a trained model checkpoint from ```data_dir```, then prints a generated title, ingredient list and instructions for each demo image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# standard library\n",
    "import os\n",
    "import pickle\n",
    "import random\n",
    "import time\n",
    "from collections import Counter\n",
    "from io import BytesIO\n",
    "\n",
    "# third-party\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import requests\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from PIL import Image\n",
    "from torchvision import transforms\n",
    "\n",
    "# project-local\n",
    "from args import get_parser\n",
    "from model import get_model\n",
    "from utils.output_utils import prepare_output"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set ```data_dir``` to the directory that contains the vocabulary files (```ingr_vocab.pkl```, ```instr_vocab.pkl```) and the model checkpoint (```modelbest.ckpt```)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_dir = '../data'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# code will run in gpu if available and if the flag is set to True, else it will run on cpu\n",
    "use_gpu = False\n",
    "run_on_gpu = use_gpu and torch.cuda.is_available()\n",
    "device = torch.device('cuda' if run_on_gpu else 'cpu')\n",
    "# passed to torch.load as map_location: 'cpu' remaps the checkpoint tensors when no GPU is used\n",
    "map_loc = None if run_on_gpu else 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# code below was used to save vocab files so that they can be loaded without Vocabulary class\n",
    "#ingrs_vocab = pickle.load(open(os.path.join(data_dir, 'final_recipe1m_vocab_ingrs.pkl'), 'rb'))\n",
    "#ingrs_vocab = [min(w, key=len) if not isinstance(w, str) else w for w in ingrs_vocab.idx2word.values()]\n",
    "#vocab = pickle.load(open(os.path.join(data_dir, 'final_recipe1m_vocab_toks.pkl'), 'rb')).idx2word\n",
    "#pickle.dump(ingrs_vocab, open('../demo/ingr_vocab.pkl', 'wb'))\n",
    "#pickle.dump(vocab, open('../demo/instr_vocab.pkl', 'wb'))\n",
    "\n",
    "# NOTE: pickle can execute arbitrary code while loading; only load vocab files from a trusted source.\n",
    "# The files are opened in `with` blocks so the handles are closed as soon as loading finishes.\n",
    "with open(os.path.join(data_dir, 'ingr_vocab.pkl'), 'rb') as ingr_vocab_file:\n",
    "    ingrs_vocab = pickle.load(ingr_vocab_file)\n",
    "with open(os.path.join(data_dir, 'instr_vocab.pkl'), 'rb') as instr_vocab_file:\n",
    "    vocab = pickle.load(instr_vocab_file)\n",
    "\n",
    "ingr_vocab_size = len(ingrs_vocab)\n",
    "instrs_vocab_size = len(vocab)\n",
    "output_dim = instrs_vocab_size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(instrs_vocab_size, ingr_vocab_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = time.time()\n",
    "# reset argv before building the args (presumably get_parser() parses sys.argv, which in a\n",
    "# notebook holds the kernel's own arguments - confirm in args.py); sys is deleted again afterwards\n",
    "import sys; sys.argv=['']; del sys\n",
    "args = get_parser()\n",
    "# override the values returned by get_parser() for this demo\n",
    "args.maxseqlen = 15\n",
    "args.ingrs_only = False\n",
    "model = get_model(args, ingr_vocab_size, instrs_vocab_size)\n",
    "# Load the trained model parameters\n",
    "model_path = os.path.join(data_dir, 'modelbest.ckpt')\n",
    "model.load_state_dict(torch.load(model_path, map_location=map_loc))\n",
    "model.to(device)\n",
    "model.eval()\n",
    "# NOTE(review): presumably these flags make sample() return both ingredients and recipe text - confirm in model.py\n",
    "model.ingrs_only = False\n",
    "model.recipe_only = False\n",
    "print('loaded model')\n",
    "print(\"Elapsed time:\", time.time() - t)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# PIL image -> normalized tensor\n",
    "# NOTE(review): mean/std look like the usual ImageNet channel statistics - confirm against the training code\n",
    "transf_list_batch = [\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.485, 0.456, 0.406),\n",
    "                         (0.229, 0.224, 0.225)),\n",
    "]\n",
    "to_input_transf = transforms.Compose(transf_list_batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# one entry per generation attempt: greedy[i] and beam[i] are passed to model.sample() for attempt i\n",
    "greedy = [True, False, False, False]\n",
    "beam = [-1, -1, -1, -1]\n",
    "temperature = 1.0\n",
    "numgens = len(greedy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set ```use_urls = True``` to get recipes for images in ```demo_urls```. \n",
    "\n",
    "You can also set ```use_urls = False``` and get recipes for the images found in ```data_dir/demo_imgs```."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "use_urls = False # set to true to load images from demo_urls instead of those in demo_imgs folder\n",
    "show_anyways = False #if True, it will show the recipe even if it's not valid\n",
    "image_folder = os.path.join(data_dir, 'demo_imgs')\n",
    "\n",
    "if not use_urls:\n",
    "    # skip hidden files (e.g. .DS_Store) and sub-directories: Image.open would fail on them\n",
    "    demo_imgs = [name for name in os.listdir(image_folder)\n",
    "                 if not name.startswith('.')\n",
    "                 and os.path.isfile(os.path.join(image_folder, name))]\n",
    "    random.shuffle(demo_imgs)\n",
    "\n",
    "demo_urls = ['https://food.fnr.sndimg.com/content/dam/images/food/fullset/2013/12/9/0/FNK_Cheesecake_s4x3.jpg.rend.hgtvcom.826.620.suffix/1387411272847.jpeg',\n",
    "             'https://www.196flavors.com/wp-content/uploads/2014/10/california-roll-3-FP.jpg']\n",
    "\n",
    "# demo_imgs only exists when use_urls is False; the conditional expression never evaluates it otherwise\n",
    "demo_files = demo_urls if use_urls else demo_imgs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# resize + center crop are the same for every image, so the transform is built once\n",
    "transf_list = []\n",
    "transf_list.append(transforms.Resize(256))\n",
    "transf_list.append(transforms.CenterCrop(224))\n",
    "transform = transforms.Compose(transf_list)\n",
    "\n",
    "# ANSI escape codes used to print the section headers in bold\n",
    "BOLD = '\\033[1m'\n",
    "END = '\\033[0m'\n",
    "\n",
    "for img_file in demo_files:\n",
    "\n",
    "    if use_urls:\n",
    "        # fail fast on network problems instead of hanging or trying to decode an error page\n",
    "        response = requests.get(img_file, timeout=30)\n",
    "        response.raise_for_status()\n",
    "        image = Image.open(BytesIO(response.content))\n",
    "    else:\n",
    "        image_path = os.path.join(image_folder, img_file)\n",
    "        image = Image.open(image_path)\n",
    "    # to_input_transf normalizes 3 channels, so every image (downloaded ones too, which may be\n",
    "    # RGBA, palette or grayscale) is converted to RGB first\n",
    "    image = image.convert('RGB')\n",
    "\n",
    "    image_transf = transform(image)\n",
    "    # unsqueeze(0) adds the leading batch dimension before moving the tensor to the device\n",
    "    image_tensor = to_input_transf(image_transf).unsqueeze(0).to(device)\n",
    "\n",
    "    plt.imshow(image_transf)\n",
    "    plt.axis('off')\n",
    "    plt.show()\n",
    "    plt.close()\n",
    "\n",
    "    # counter used to label the valid recipes printed for this image\n",
    "    num_valid = 1\n",
    "    for i in range(numgens):\n",
    "        with torch.no_grad():\n",
    "            outputs = model.sample(image_tensor, greedy=greedy[i],\n",
    "                                   temperature=temperature, beam=beam[i], true_ingrs=None)\n",
    "\n",
    "        ingr_ids = outputs['ingr_ids'].cpu().numpy()\n",
    "        recipe_ids = outputs['recipe_ids'].cpu().numpy()\n",
    "\n",
    "        # index 0: the batch holds a single image\n",
    "        outs, valid = prepare_output(recipe_ids[0], ingr_ids[0], ingrs_vocab, vocab)\n",
    "\n",
    "        if valid['is_valid'] or show_anyways:\n",
    "\n",
    "            print('RECIPE', num_valid)\n",
    "            num_valid += 1\n",
    "            #print (\"greedy:\", greedy[i], \"beam:\", beam[i])\n",
    "\n",
    "            print(BOLD + '\\nTitle:' + END, outs['title'])\n",
    "\n",
    "            print(BOLD + '\\nIngredients:' + END)\n",
    "            print(', '.join(outs['ingrs']))\n",
    "\n",
    "            print(BOLD + '\\nInstructions:' + END)\n",
    "            print('-' + '\\n-'.join(outs['recipe']))\n",
    "\n",
    "            print('=' * 20)\n",
    "\n",
    "        else:\n",
    "            print(\"Not a valid recipe!\")\n",
    "            print(\"Reason: \", valid['reason'])\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}