nbroad's picture
nbroad HF staff
resolve conflicts
0b546b7
raw history blame
No virus
12.5 kB
import streamlit as st
import torch
from transformers import pipeline, set_seed
from transformers import AutoTokenizer
from PIL import (
Image,
ImageFont,
ImageDraw
)
import numpy as np
from datetime import datetime
import glob
import re
import random
import textwrap
from examples import EXAMPLES
import meta
from utils import (
remote_css,
local_css,
load_image,
pure_comma_separation
)
# np.random.seed(42)
class TextGeneration:
    """Wrap the T5 recipe-generation pipeline and render recipes onto image frames.

    Frames and fonts are loaded from ``asset/`` paths relative to the current
    working directory. When ``self.debug`` is True, ``load`` does nothing and
    ``generate`` returns a copy of the canned ``dummy_output`` recipe instead
    of calling the model.
    """

    def __init__(self):
        self.debug = False
        # Placeholder recipe returned by generate() when debug is enabled.
        self.dummy_output = {
            'directions': [
                "for the dough",
                "in a small bowl, combine the warm water and yeast.",
                "let it sit for 5 minutes.",
                "add the flour, salt, and olive oil.",
                "mix well and knead until the dough is smooth and elastic.",
                "cover the dough with a damp towel and let it rise for about 1 hour.",
                "for the filling",
                "heat a large skillet over medium high heat.",
                "cook the beef and onion until the beef is browned and the onion is translucent. browned and the onion is translucent.",
                "drain off any excess grease.",
                "stir in the pepper and salt and black pepper to taste.",
                "remove from the heat and set aside.",
                "preheat the oven to 425 degrees f.",
                "roll out the dough on a lightly floured surface into a 12 inch circle.",
                "spread the beef mixture over the dough, leaving a 1 inch border.",
                "top with the feta, parsley, and lemon juice.",
                "bake for 20 minutes or until the crust is golden brown.",
                "cut into wedges and serve.",
            ],
            # Every entry is comma-terminated: a missing comma would make Python
            # silently concatenate two neighbouring string literals.
            'ingredients': [
                '1 potato',
                '1 slice cheese 1 slice cheese',
                '1 potato 1 slice cheese 1 slice cheese',
                '1 slice cheese',
                '1 potato',
                '1 slice cheese',
                '1 slice cheese',
                '1 potato',
                '1 slice cheese',
                '1 potato',
                '1 slice cheese',
            ],
            'title': 'Cheese Potatoes with Some other items'
        }
        self.tokenizer = None
        self.generator = None
        self.task = "text2text-generation"
        self.model_name_or_path = "flax-community/t5-recipe-generation"
        # Sorted so that frames[0] (used as the default frame by the app) is
        # deterministic; glob's own ordering depends on the filesystem.
        self.frames = sorted(glob.glob("asset/images/frames/*.jpg"))
        self.fonts = {
            "title": ImageFont.truetype("asset/fonts/Poppins-Bold.ttf", 90),
            "sub_title": ImageFont.truetype("asset/fonts/Poppins-Medium.ttf", 30),
            "body_bold": ImageFont.truetype("asset/fonts/Montserrat-Bold.ttf", 50),
            "body": ImageFont.truetype("asset/fonts/Montserrat-Regular.ttf", 30),
        }
        # Split point for the two-column directions layout that is currently
        # commented out in prepare_frame().
        self.list_division = 5
        self.point = "•"
        set_seed(42)

    def _skip_special_tokens_and_prettify(self, text):
        """Turn raw decoded model output into a structured recipe dict.

        Tokenizer special tokens are removed first. Then the model's
        ``<section>`` markers become newlines and ``<sep>`` markers become
        ``--``, so the text can be split into ``title:``, ``ingredients:`` and
        ``directions:`` sections.

        Args:
            text: Generator output decoded with special tokens kept.

        Returns:
            dict with ``"title"`` (str, each word capitalized),
            ``"ingredients"`` and ``"directions"`` (lists of non-empty str).
            Sections absent from ``text`` keep their empty defaults.
        """
        recipe_maps = {"<sep>": "--", "<section>": "\n"}
        recipe_map_pattern = "|".join(map(re.escape, recipe_maps.keys()))

        # Special tokens are literal strings. Escape them so a token such as
        # "[PAD]" is not read as a regex character class that deletes letters.
        special_pattern = "|".join(map(re.escape, self.tokenizer.all_special_tokens))
        if special_pattern:
            text = re.sub(special_pattern, "", text)
        text = re.sub(recipe_map_pattern, lambda m: recipe_maps[m.group()], text)

        def _split_items(body):
            # Split a "--"-separated section body, dropping empty fragments
            # (e.g. from a trailing <sep>).
            return [s.strip() for s in body.split("--") if s.strip()]

        data = {"title": "", "ingredients": [], "directions": []}
        for section in text.split("\n"):
            section = section.strip()
            if section.startswith("title:"):
                words = section[len("title:"):].split()
                data["title"] = " ".join(w.capitalize() for w in words)
            elif section.startswith("ingredients:"):
                data["ingredients"] = _split_items(section[len("ingredients:"):])
            elif section.startswith("directions:"):
                data["directions"] = _split_items(section[len("directions:"):])
            # Any other line is ignored.
        return data

    def load(self):
        """Load the tokenizer and the generation pipeline (no-op in debug mode)."""
        if not self.debug:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)
            # Hand the already-loaded tokenizer to the pipeline rather than
            # loading it a second time from the same checkpoint.
            self.generator = pipeline(self.task, model=self.model_name_or_path, tokenizer=self.tokenizer)

    def prepare_frame(self, recipe, frame):
        """Draw ``recipe`` onto ``frame`` and return that same image.

        Args:
            recipe: dict with ``"title"``, ``"by"``, ``"ingredients"`` and
                ``"directions"`` keys.
            frame: image accepted by ``PIL.ImageDraw.Draw``; drawn on in place.

        Returns:
            ``frame`` with the recipe text rendered at fixed pixel positions.
        """
        text_color = (61, 61, 70)
        im_editable = ImageDraw.Draw(frame)

        # Info
        im_editable.text(
            (1570, 180),
            "Created " + datetime.now().strftime("%Y-%m-%d %H:%M"),
            text_color,
            font=self.fonts["sub_title"],
        )
        im_editable.text(
            (1570, 170 + 70),
            "By " + recipe["by"],
            text_color,
            font=self.fonts["sub_title"],
        )

        # Title
        im_editable.text(
            (400, 650),
            textwrap.fill(recipe["title"], 30),
            text_color,
            font=self.fonts["title"],
        )

        # Ingredients: wrapped lines are indented to sit under the bullet text.
        im_editable.text(
            (100, 1000),
            "Ingredients",
            text_color,
            font=self.fonts["body_bold"],
        )
        ingredients = [textwrap.fill(item, 30).replace("\n", "\n ") for item in recipe["ingredients"]]
        im_editable.text(
            (100, 1080),
            "\n".join([f"{self.point} {item}" for item in ingredients]),
            text_color,
            font=self.fonts["body"],
        )

        # Directions: numbered, wrapped at 80 characters.
        im_editable.text(
            (700, 1000),
            "Directions",
            text_color,
            font=self.fonts["body_bold"],
        )
        directions = [textwrap.fill(item, 80).replace("\n", "\n ") for item in recipe["directions"]]
        im_editable.text(
            (700, 1080),
            "\n".join([f"{i + 1}. {item}" for i, item in enumerate(directions)]).strip(),
            text_color,
            font=self.fonts["body"],
        )

        # Disabled alternative: two-column directions split at self.list_division.
        # directions_col1 = [textwrap.fill(item, 30).replace("\n", "\n ") for item in directions[:self.list_division]]
        # directions_col2 = [textwrap.fill(item, 30).replace("\n", "\n ") for item in directions[self.list_division:]]
        # im_editable.text(
        #     (700, 1130),
        #     "\n".join([f"{i + 1}. {item}" for i, item in enumerate(directions_col1)]).strip(),
        #     (61, 61, 70),
        #     font=self.fonts["body"],
        # )
        # im_editable.text(
        #     (1300, 1130),
        #     "\n".join([f"{i + 1 + self.list_division}. {item}" for i, item in enumerate(directions_col2)]).strip(),
        #     (61, 61, 70),
        #     font=self.fonts["body"],
        # )
        return frame

    def generate(self, items, generation_kwargs):
        """Generate a recipe for the given ingredient prompt.

        Args:
            items: Ingredient prompt passed to the text2text pipeline.
            generation_kwargs: Generation options forwarded to the pipeline.
                The mapping is copied, never mutated, so a caller's dict
                (e.g. a shared preset) is left untouched.

        Returns:
            Recipe dict as produced by ``_skip_special_tokens_and_prettify``.
            In debug mode, returns a shallow copy of ``self.dummy_output``.
        """
        if self.debug:
            # Copy so callers that add keys (e.g. "by") don't alter the placeholder.
            return dict(self.dummy_output)

        kwargs = dict(generation_kwargs)
        kwargs["num_return_sequences"] = 1
        # Request raw token ids so they can be decoded below with
        # skip_special_tokens=False and parsed by our own post-processing.
        kwargs["return_tensors"] = True
        kwargs["return_text"] = False
        generated_ids = self.generator(items, **kwargs)[0]["generated_token_ids"]
        recipe = self.tokenizer.decode(generated_ids, skip_special_tokens=False)
        return self._skip_special_tokens_and_prettify(recipe)

    def generate_frame(self, recipe, frame_path):
        """Load the image at ``frame_path`` and draw ``recipe`` onto it."""
        frame = load_image(frame_path)
        return self.prepare_frame(recipe, frame)
@st.cache(allow_output_mutation=True)
def load_text_generator():
    """Build a ``TextGeneration`` and load its model, cached by Streamlit.

    ``allow_output_mutation=True`` lets Streamlit hand back the same cached
    object on reruns without checking it for mutation.
    """
    text_generator = TextGeneration()
    text_generator.load()
    return text_generator
# Decoding options shared by both chef presets.
_CHEF_COMMON_KWARGS = {
    "max_length": 512,
    "min_length": 64,
    "no_repeat_ngram_size": 3,
}

# Preset used for "Chef Scheherazade": top-k / nucleus sampling.
chef_top = {
    **_CHEF_COMMON_KWARGS,
    "do_sample": True,
    "top_k": 60,
    "top_p": 0.95,
    "num_return_sequences": 1,
}

# Preset used for "Chef Giovanni": beam search with 5 beams.
chef_beam = {
    **_CHEF_COMMON_KWARGS,
    "early_stopping": True,
    "num_beams": 5,
    "length_penalty": 1.5,
    "num_return_sequences": 1,
}
def main():
    """Build and run the Chef Transformer Streamlit page.

    Shows the header and sidebar and lets the user pick a chef (a decoding
    preset) and ingredients. When "Get Recipe!" is pressed, it generates a
    recipe and displays it both as a rendered image frame and as HTML.
    """
    # Local import: only needed here, to escape model output before it is
    # injected as raw HTML.
    import html

    st.set_page_config(
        page_title="Chef Transformer",
        page_icon="🍲",
        layout="wide",
        initial_sidebar_state="expanded"
    )
    generator = load_text_generator()

    # Streamlit versions without session_state fall back to a plain local.
    has_session_state = hasattr(st, "session_state")
    if has_session_state:
        if 'get_random_frame' not in st.session_state:
            st.session_state.get_random_frame = generator.frames[0]
    else:
        get_random_frame = generator.frames[0]

    local_css("asset/css/style.css")

    col1, col2 = st.beta_columns([5, 3])
    with col2:
        st.image(load_image("asset/images/chef-transformer-transparent.png"), width=300)
        st.markdown(meta.SIDEBAR_INFO, unsafe_allow_html=True)

        with st.beta_expander("Where did this story start?"):
            st.markdown(meta.STORY, unsafe_allow_html=True)

    with col1:
        st.markdown(meta.HEADER_INFO, unsafe_allow_html=True)

        st.markdown(meta.CHEF_INFO, unsafe_allow_html=True)
        chef = st.selectbox("Choose your chef", index=0, options=["Chef Scheherazade", "Chef Giovanni"])

        prompts = list(EXAMPLES.keys()) + ["Custom"]
        prompt = st.selectbox(
            'Examples (select from this list)',
            prompts,
            index=0
        )

        if prompt == "Custom":
            prompt_box = ""
        else:
            prompt_box = EXAMPLES[prompt]

        items = st.text_area(
            'Insert your ingredients here (separated by `,`): ',
            pure_comma_separation(prompt_box, return_list=False),
        )
        items = pure_comma_separation(items, return_list=False)
        entered_items = st.empty()

    recipe_button = st.button('Get Recipe!')

    st.markdown(
        "<hr />",
        unsafe_allow_html=True
    )

    if recipe_button:
        # Pick a new background frame uniformly at random. random.choice gives
        # every frame equal weight; randint(0, n) - 1 hit the last one twice.
        new_frame = random.choice(generator.frames)
        if has_session_state:
            st.session_state.get_random_frame = new_frame
        else:
            get_random_frame = new_frame

        entered_items.markdown("**Generate recipe for:** " + items)
        with st.spinner("Generating recipe..."):

            if not isinstance(items, str) or len(items) <= 1:
                entered_items.markdown(
                    f"**{chef}** would like to know what ingredients do you like to use in "
                    f"your food? "
                )
            else:
                gen_kw = chef_top if chef == "Chef Scheherazade" else chef_beam
                # Pass a copy so the shared module-level presets are never mutated.
                generated_recipe = generator.generate(items, dict(gen_kw))

                # These values are embedded in raw HTML below, so escape model output.
                title = html.escape(generated_recipe["title"])
                ingredients = [html.escape(item) for item in generated_recipe["ingredients"]]
                directions = [
                    html.escape(textwrap.fill(item, 70).replace("\n", "\n "))
                    for item in generated_recipe["directions"]
                ]

                generated_recipe["by"] = chef

                r1, r2 = st.beta_columns([3, 5])

                with r1:
                    if has_session_state:
                        frame_path = st.session_state.get_random_frame
                    else:
                        frame_path = get_random_frame
                    # The image gets the unescaped recipe text.
                    recipe_post = generator.generate_frame(generated_recipe, frame_path)

                    st.image(
                        recipe_post,
                        caption="Click 🔎 to enlarge",
                        use_column_width="auto",
                        output_format="PNG"
                    )

                with r2:
                    st.markdown(
                        " ".join([
                            "<div class='r-text-recipe'>",
                            f"<h2>{title}</h2>",
                            "<h3>Ingredients</h3>",
                            "<ul class='ingredients-list'>",
                            " ".join([f'<li>{item}</li>' for item in ingredients]),
                            "</ul>",
                            "<h3>Directions</h3>",
                            "<ol class='ingredients-list'>",
                            " ".join([f'<li>{item}</li>' for item in directions]),
                            "</ol>",
                            "</div>"
                        ]),
                        unsafe_allow_html=True
                    )
if __name__ == '__main__':
    # Script entry point: build and render the app page.
    main()