import glob
import html
import random
import re
import textwrap
from datetime import datetime

import numpy as np
import streamlit as st
import torch
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoTokenizer, pipeline, set_seed

import meta
from examples import EXAMPLES
from utils import load_image, local_css, pure_comma_separation, remote_css

np.random.seed(42)

# RGB colour used for every piece of text drawn on the recipe card.
TEXT_COLOR = (61, 61, 70)


class TextGeneration:
    """Load the T5 recipe model, generate recipes, and render them onto card images."""

    def __init__(self):
        self.debug = False
        # Canned recipe returned by generate() when self.debug is True; the long
        # entries exercise the text wrapping and two-column layout in prepare_frame().
        self.dummy_output = {
            'directions': [
                'peel the potato and slice thinly.',
                'place in a microwave safe dish.',
                'cover with plastic wrap and microwave on high for 5 minutes.',
                'remove from the microwave and sprinkle with cheese.',
                'return to the microwave for 1 minute or until cheese is melted.',
                'return to the microwave for 1 minute or until cheese is melted. '
                'return to the microwave for 1 minute or until cheese is melted.',
                'return to the microwave for 1 minute or until cheese is melted.',
                'return to the microwave for 1 minute or until cheese is melted.',
                'return to the microwave for 1 minute or until cheese is melted.',
            ],
            'ingredients': [
                '1 potato',
                '1 slice cheese 1 slice cheese',
                '1 potato 1 slice cheese 1 slice cheese',
                '1 slice cheese',
                '1 potato',
                '1 slice cheese',
                '1 slice cheese',
                '1 potato',
                '1 slice cheese',
                '1 potato',
                '1 slice cheese',
            ],
            'title': 'Cheese Potatoes with Some other items'
        }
        self.tokenizer = None
        self.generator = None
        self.task = "text2text-generation"
        self.model_name_or_path = "flax-community/t5-recipe-generation"
        self.frames = list(glob.glob("asset/images/frames/*.jpg"))
        self.fonts = {
            "title": ImageFont.truetype("asset/fonts/Poppins-Bold.ttf", 90),
            "sub_title": ImageFont.truetype("asset/fonts/Poppins-Medium.ttf", 30),
            "body_bold": ImageFont.truetype("asset/fonts/Montserrat-Bold.ttf", 50),
            "body": ImageFont.truetype("asset/fonts/Montserrat-Regular.ttf", 30),
        }
        # Directions beyond this index are drawn in a second column.
        self.list_division = 5
        self.point = "-"
        set_seed(42)

    def _skip_special_tokens_and_prettify(self, text):
        """Turn raw decoded model output into a ``{"title", "ingredients", "directions"}`` dict.

        Tokenizer special tokens are removed first. Then the model's ``<sep>``
        marker becomes a ``--`` item separator and ``<section>`` becomes a
        newline that separates the ``title:``, ``ingredients:`` and
        ``directions:`` sections.
        """
        # Marker tokens used by flax-community/t5-recipe-generation (see its model card).
        recipe_maps = {"<sep>": "--", "<section>": "\n"}
        recipe_map_pattern = "|".join(map(re.escape, recipe_maps))
        # Escape special tokens so any regex metacharacters in them match literally.
        special_tokens_pattern = "|".join(map(re.escape, self.tokenizer.all_special_tokens))

        text = re.sub(special_tokens_pattern, "", text)
        text = re.sub(recipe_map_pattern, lambda m: recipe_maps[m.group()], text)

        data = {"title": "", "ingredients": [], "directions": []}
        for section in text.split("\n"):
            section = section.strip()
            if section.startswith("title:"):
                # Slice off only the prefix; str.replace would also drop later occurrences.
                words = section[len("title:"):].split()
                data["title"] = " ".join(w.capitalize() for w in words)
            elif section.startswith("ingredients:"):
                data["ingredients"] = self._split_items(section[len("ingredients:"):])
            elif section.startswith("directions:"):
                data["directions"] = self._split_items(section[len("directions:"):])
        return data

    @staticmethod
    def _split_items(body):
        """Split a ``--``-separated section body into stripped, non-empty items."""
        return [item.strip() for item in body.split("--") if item.strip()]

    @staticmethod
    def _wrap(item, width=30):
        """Wrap *item* to *width* columns, indenting continuation lines by one space."""
        return textwrap.fill(item, width).replace("\n", "\n ")

    def load(self):
        """Load the tokenizer and generation pipeline (skipped in debug mode)."""
        if not self.debug:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)
            # Reuse the tokenizer loaded above instead of loading it a second time.
            self.generator = pipeline(self.task, model=self.model_name_or_path, tokenizer=self.tokenizer)

    def prepare_frame(self, recipe, frame):
        """Draw *recipe* onto the PIL image *frame* and return the same (mutated) image.

        *recipe* must provide ``"title"``, ``"ingredients"``, ``"directions"``
        and ``"by"``. Directions beyond ``self.list_division`` go into a second
        column whose numbering continues from the first.
        """
        draw = ImageDraw.Draw(frame)

        # Info: timestamp and author.
        draw.text(
            (1570, 180),
            datetime.now().strftime("%Y-%m-%d %H:%M"),
            TEXT_COLOR,
            font=self.fonts["sub_title"],
        )
        draw.text((1570, 170 + 70), "By " + recipe["by"], TEXT_COLOR, font=self.fonts["sub_title"])

        # Title
        draw.text((400, 650), textwrap.fill(recipe["title"], 30), TEXT_COLOR, font=self.fonts["title"])

        # Ingredients
        draw.text((100, 1050), "Ingredients", TEXT_COLOR, font=self.fonts["body_bold"])
        ingredients = [self._wrap(item) for item in recipe["ingredients"]]
        draw.text(
            (100, 1130),
            "\n".join(f"{self.point} {item}" for item in ingredients),
            TEXT_COLOR,
            font=self.fonts["body"],
        )

        # Directions, split over two columns.
        draw.text((700, 1050), "Directions", TEXT_COLOR, font=self.fonts["body_bold"])
        directions = [self._wrap(item) for item in recipe["directions"]]
        first_col = directions[:self.list_division]
        second_col = directions[self.list_division:]
        draw.text(
            (700, 1130),
            "\n".join(f"{i}. {item}" for i, item in enumerate(first_col, start=1)).strip(),
            TEXT_COLOR,
            font=self.fonts["body"],
        )
        draw.text(
            (1300, 1130),
            "\n".join(
                f"{i}. {item}" for i, item in enumerate(second_col, start=self.list_division + 1)
            ).strip(),
            TEXT_COLOR,
            font=self.fonts["body"],
        )
        return frame

    def generate(self, items, generation_kwargs):
        """Generate a recipe for the comma-separated *items* string.

        Returns a dict with ``"title"``, ``"ingredients"`` and ``"directions"``.
        In debug mode a copy of ``self.dummy_output`` is returned instead.
        *generation_kwargs* is not modified.
        """
        if self.debug:
            return dict(self.dummy_output)

        # Copy so the caller's (module-level) preset dict is not mutated.
        generation_kwargs = dict(generation_kwargs)
        generation_kwargs["num_return_sequences"] = 1
        generation_kwargs["return_tensors"] = True
        generation_kwargs["return_text"] = False

        generated_ids = self.generator(items, **generation_kwargs)[0]["generated_token_ids"]
        # Keep special tokens here: the <sep>/<section> markers drive the parsing.
        recipe = self.tokenizer.decode(generated_ids, skip_special_tokens=False)
        return self._skip_special_tokens_and_prettify(recipe)

    def generate_frame(self, recipe, frame_path):
        """Load the image at *frame_path* and draw *recipe* onto it."""
        frame = load_image(frame_path)
        return self.prepare_frame(recipe, frame)


@st.cache(allow_output_mutation=True)
def load_text_generator():
    """Create and load the shared TextGeneration instance (cached by Streamlit)."""
    generator = TextGeneration()
    generator.load()
    return generator


# Sampling preset used by "Chef Scheherazade".
chef_top = {
    "max_length": 512,
    "min_length": 64,
    "no_repeat_ngram_size": 3,
    "do_sample": True,
    "top_k": 60,
    "top_p": 0.95,
    "num_return_sequences": 1
}
# Beam-search preset used by "Chef Giovanni".
chef_beam = {
    "max_length": 512,
    "min_length": 64,
    "no_repeat_ngram_size": 3,
    "early_stopping": True,
    "num_beams": 5,
    "length_penalty": 1.5,
    "num_return_sequences": 1
}


def _recipe_html(title, ingredients, directions):
    """Build the HTML for the text version of a generated recipe.

    All model-generated text is HTML-escaped because the result is rendered
    with ``unsafe_allow_html=True``.
    """
    # NOTE(review): the tags below were reconstructed from a copy whose markup
    # had been stripped; confirm class names against asset/css/style.css.
    esc = html.escape
    return " ".join([
        "<div class='r-text-recipe'>",
        f"<div class='food-title'><h2 class='font-title text-bold'>{esc(title)}</h2></div>",
        "<div class='divider'><div class='divider-mask'></div></div>"
        "<h3 class='ingredients font-body text-bold'>Ingredient</h3>",
        "<ul class='ingredients-list font-body'>",
        " ".join(f"<li>{esc(item)}</li>" for item in ingredients),
        "</ul>",
        "<h3 class='directions font-body text-bold'>Direction</h3>",
        "<ol class='ingredients-list font-body'>",
        " ".join(f"<li>{esc(item)}</li>" for item in directions),
        "</ol>",
        "</div>",
    ])


def main():
    """Render the Chef Transformer Streamlit page."""
    st.set_page_config(
        page_title="Chef Transformer",
        page_icon="🍲",
        layout="wide",
        initial_sidebar_state="expanded"
    )
    generator = load_text_generator()

    # Older Streamlit releases have no session_state; fall back to a local.
    use_session_state = hasattr(st, "session_state")
    if use_session_state:
        if 'get_random_frame' not in st.session_state:
            st.session_state.get_random_frame = generator.frames[0]
    else:
        get_random_frame = generator.frames[0]

    local_css("asset/css/style.css")

    col1, col2 = st.beta_columns([5, 3])
    with col2:
        st.image(load_image("asset/images/chef-transformer-transparent.png"), width=300)
        st.markdown(meta.SIDEBAR_INFO, unsafe_allow_html=True)

        with st.beta_expander("Where did this story start?"):
            st.markdown(meta.STORY, unsafe_allow_html=True)

    with col1:
        st.markdown(meta.HEADER_INFO, unsafe_allow_html=True)

        st.markdown(meta.CHEF_INFO, unsafe_allow_html=True)
        chef = st.selectbox("Choose your chef", index=0, options=["Chef Scheherazade", "Chef Giovanni"])

        prompts = list(EXAMPLES.keys()) + ["Custom"]
        prompt = st.selectbox('Examples (select from this list)', prompts, index=0)
        prompt_box = "" if prompt == "Custom" else EXAMPLES[prompt]

        items = st.text_area(
            'Insert your ingredients here (separated by `,`): ',
            pure_comma_separation(prompt_box, return_list=False),
        )
        items = pure_comma_separation(items, return_list=False)
        entered_items = st.empty()

    recipe_button = st.button('Get Recipe!')

    st.markdown("<hr />", unsafe_allow_html=True)

    if recipe_button:
        # random.choice is uniform; randint(0, n) - 1 favoured the last frame.
        new_frame = random.choice(generator.frames)
        if use_session_state:
            st.session_state.get_random_frame = new_frame
        else:
            get_random_frame = new_frame

        entered_items.markdown("**Generate recipe for:** " + items)
        with st.spinner("Generating recipe..."):
            if not isinstance(items, str) or not len(items) > 1:
                entered_items.markdown(
                    f"**{chef}** would like to know what ingredients do you like to use in "
                    f"your food? "
                )
            else:
                gen_kw = chef_top if chef == "Chef Scheherazade" else chef_beam
                generated_recipe = generator.generate(items, gen_kw)

                title = generated_recipe["title"]
                ingredients = generated_recipe["ingredients"]
                directions = [
                    textwrap.fill(item, 70).replace("\n", "\n ")
                    for item in generated_recipe["directions"]
                ]
                generated_recipe["by"] = chef

                frame_path = st.session_state.get_random_frame if use_session_state else get_random_frame

                r1, r2 = st.beta_columns([3, 5])
                with r1:
                    recipe_post = generator.generate_frame(generated_recipe, frame_path)
                    st.image(
                        recipe_post,
                        caption="Click 🔎 to enlarge",
                        use_column_width="auto",
                        output_format="PNG"
                    )

                with r2:
                    st.markdown(_recipe_html(title, ingredients, directions), unsafe_allow_html=True)


if __name__ == '__main__':
    main()