nbroad's picture
nbroad HF staff
minor update
d6db8b6
raw history blame
No virus
11.7 kB
import streamlit as st
import streamlit.components.v1 as components
import torch
from transformers import pipeline, set_seed
from transformers import AutoTokenizer
from PIL import Image, ImageFont, ImageDraw
import re
import textwrap
from examples import EXAMPLES
import meta
from utils import remote_css, local_css, load_image, pure_comma_separation
class TextGeneration:
    """Recipe generator around the ``flax-community/t5-recipe-generation`` model.

    In debug mode no model is loaded and :meth:`generate` returns a fixed
    dummy recipe (``self.dummy_output``), which is useful for working on the
    UI and the recipe-image layout without downloading the model.
    """

    # Colour used for every piece of text drawn on the recipe image.
    TEXT_COLOR = (61, 61, 70)

    def __init__(self, debug=True):
        """Set up model configuration, fonts and the dummy recipe.

        Args:
            debug: When True (the default, matching the previously hard-coded
                value) :meth:`load` skips model loading and :meth:`generate`
                returns ``self.dummy_output``.
        """
        self.debug = debug
        self.dummy_output = {
            "directions": [
                "peel the potato and slice thinly.",
                "place in a microwave safe dish.",
                "cover with plastic wrap and microwave on high for 5 minutes.",
                "remove from the microwave and sprinkle with cheese.",
                "return to the microwave for 1 minute or until cheese is melted.",
                # Long entry, presumably to exercise line wrapping in the image.
                "return to the microwave for 1 minute or until cheese is melted. return to the microwave for 1 minute or until cheese is melted.",
                "return to the microwave for 1 minute or until cheese is melted.",
                "return to the microwave for 1 minute or until cheese is melted.",
                "return to the microwave for 1 minute or until cheese is melted.",
            ],
            "ingredients": [
                "1 potato",
                "1 slice cheese",
                "1 potato",
                "1 slice cheese",
                "1 potato",
                "1 slice cheese",
                "1 slice cheese",
                "1 potato",
                "1 slice cheese",
                "1 potato",
                "1 slice cheese",
            ],
            "title": "Cheese Potatoes with Some other items",
        }
        self.tokenizer = None
        self.generator = None
        self.task = "text2text-generation"
        self.model_name_or_path = "flax-community/t5-recipe-generation"
        # Number of ingredients drawn in the left column of the recipe image;
        # the remainder go to the right column.
        self.list_division = 5
        self.point = "•"
        self.h1_font = ImageFont.truetype("asset/fonts/PT_Serif/PTSerif-Bold.ttf", 75)
        self.h2_font = ImageFont.truetype("asset/fonts/PT_Serif/PTSerif-Bold.ttf", 50)
        self.p_font = ImageFont.truetype("asset/fonts/PT_Serif/PTSerif-Regular.ttf", 30)
        set_seed(42)

    @staticmethod
    def _split_items(body):
        """Split a section body on ``--`` into stripped, non-empty items."""
        return [item.strip() for item in body.split("--") if item.strip()]

    def _skip_special_tokens_and_prettify(self, text):
        """Remove tokenizer special tokens from decoded output and parse it.

        The decoded text is split into sections on ``<section>``; a section
        starting with ``title:``, ``ingredients:`` or ``directions:`` fills the
        matching key. Items inside a list section are separated by ``<sep>``.

        Args:
            text: decoded model output (special tokens not skipped).

        Returns:
            dict with ``"title"`` (str, each word capitalised) and
            ``"ingredients"`` / ``"directions"`` (lists of str). Missing
            sections keep their empty defaults; empty items are dropped.
        """
        # Special tokens are literal strings, so escape them before building
        # the alternation (e.g. "[PAD]" would otherwise be a character class).
        special_pattern = "|".join(map(re.escape, self.tokenizer.all_special_tokens))
        if special_pattern:
            text = re.sub(special_pattern, "", text)

        recipe_maps = {"<sep>": "--", "<section>": "\n"}
        recipe_map_pattern = "|".join(map(re.escape, recipe_maps))
        text = re.sub(recipe_map_pattern, lambda m: recipe_maps[m.group()], text)

        data = {"title": "", "ingredients": [], "directions": []}
        for section in text.split("\n"):
            section = section.strip()
            if section.startswith("title:"):
                words = section[len("title:"):].split()
                data["title"] = " ".join(word.capitalize() for word in words)
            elif section.startswith("ingredients:"):
                data["ingredients"] = self._split_items(section[len("ingredients:"):])
            elif section.startswith("directions:"):
                data["directions"] = self._split_items(section[len("directions:"):])
        return data

    def load(self):
        """Load the tokenizer and generation pipeline (no-op in debug mode)."""
        if self.debug:
            return
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)
        # Hand the already-loaded tokenizer to the pipeline instead of
        # loading it a second time from the same checkpoint.
        self.generator = pipeline(
            self.task,
            model=self.model_name_or_path,
            tokenizer=self.tokenizer,
        )

    def prepare_frame(self, recipe, frame):
        """Draw the recipe onto ``frame``.

        The title is drawn first, then the ingredients in two columns, then
        the numbered directions.

        Args:
            recipe: dict with ``"title"``, ``"ingredients"`` and ``"directions"``.
            frame: PIL image to draw on; it is modified in place.

        Returns:
            The same ``frame`` object.
        """
        draw = ImageDraw.Draw(frame)
        color = self.TEXT_COLOR
        ws, hs = 120, 500

        # Title
        draw.text((ws, hs), textwrap.fill(recipe["title"], 25), color, font=self.h1_font)

        # Ingredients: first `list_division` items on the left, the rest on the right.
        hs += 270
        draw.text((ws, hs), "Ingredients", color, font=self.h2_font)
        hs += 80
        ingredients = [textwrap.fill(item, 30) for item in recipe["ingredients"]]
        left_column = ingredients[: self.list_division]
        right_column = ingredients[self.list_division :]
        draw.text(
            (ws + 10, hs),
            "\n".join(f"{self.point} {item}" for item in left_column),
            color,
            font=self.p_font,
        )
        draw.text(
            (ws + 500, hs),
            "\n".join(f"{self.point} {item}" for item in right_column),
            color,
            font=self.p_font,
        )

        # Directions: wrapped lines are indented by one space under their number.
        hs += 240
        draw.text((ws, hs), "Directions", color, font=self.h2_font)
        hs += 80
        directions = [
            textwrap.fill(item, 70).replace("\n", "\n ")
            for item in recipe["directions"]
        ]
        draw.text(
            (ws + 10, hs),
            "\n".join(f"{num}. {d}" for num, d in enumerate(directions, start=1)),
            color,
            font=self.p_font,
        )
        return frame

    def generate(self, items, generation_kwargs):
        """Generate a recipe for the given ingredients.

        Args:
            items: ingredients as a single string.
            generation_kwargs: generation settings forwarded to the pipeline.
                This dict is not modified; a copy is extended with the
                options this method requires.

        Returns:
            The parsed recipe dict, or ``self.dummy_output`` in debug mode.
        """
        if self.debug:
            return self.dummy_output
        # Copy so the caller's (possibly shared, module-level) dict is untouched.
        kwargs = dict(generation_kwargs)
        kwargs["num_return_sequences"] = 1
        # Request token ids instead of text so special tokens can be handled
        # by _skip_special_tokens_and_prettify after decoding.
        kwargs["return_tensors"] = True
        kwargs["return_text"] = False
        generated_ids = self.generator(items, **kwargs)[0]["generated_token_ids"]
        recipe = self.tokenizer.decode(generated_ids, skip_special_tokens=False)
        return self._skip_special_tokens_and_prettify(recipe)

    def generate_frame(self, recipe):
        """Render ``recipe`` onto the recipe-post background image and return it."""
        frame = load_image("asset/images/recipe-post.png")
        return self.prepare_frame(recipe, frame)
@st.cache(allow_output_mutation=True)
def load_text_generator():
    """Build a TextGeneration instance and load its model resources.

    Decorated with ``st.cache`` so Streamlit reruns can reuse the cached
    instance instead of constructing and loading a new one each time.
    """
    text_generator = TextGeneration()
    text_generator.load()
    return text_generator
# Generation settings per chef, chosen in main(): "Chef Scheherazade" uses
# chef_top (top-k / nucleus sampling), any other chef uses chef_beam.
chef_top = dict(
    max_length=512,
    min_length=64,
    no_repeat_ngram_size=3,
    do_sample=True,
    top_k=60,
    top_p=0.95,
    num_return_sequences=1,
)
chef_beam = dict(
    max_length=512,
    min_length=64,
    no_repeat_ngram_size=3,
    early_stopping=True,
    num_beams=5,
    length_penalty=1.5,
    num_return_sequences=1,
)
def true_false_format_func(choose_yes):
    """Return the display label for a boolean radio option: "Yes" if truthy, else "No"."""
    return "Yes" if choose_yes else "No"
def examples_custom_format_func(use_examples):
    """Return the display label for the input-mode radio: examples vs. custom ingredients."""
    return "Examples" if use_examples else "Custom ingredients"
def _render_recipe_text(recipe):
    """Show ``recipe`` as copyable plain text inside a code box."""
    st.markdown(
        '<div>To copy the recipe to your clipboard, hover the mouse in the top-right corner of the section below and then click the button that looks like this: <svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><rect x="9" y="9" width="13" height="13" rx="2" ry="2"></rect><path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1"></path></svg> </div>',
        unsafe_allow_html=True,
    )
    title = recipe["title"]
    ingredients = "\n• ".join(recipe["ingredients"])
    # Wrap long directions; continuation lines get a one-space hanging indent.
    directions = [
        textwrap.fill(item, 100).replace("\n", "\n ")
        for item in recipe["directions"]
    ]
    directions = "\n".join(f"{n}. {d}" for n, d in enumerate(directions, start=1))
    st.code(
        f"""
{title}\n\n\nIngredients\n\n• {ingredients}\n\nDirections\n\n{directions}
""".strip(),
        language="text",
    )


def _render_recipe_image(generator, recipe):
    """Render ``recipe`` as an image and show it in a centred column."""
    recipe_post = generator.generate_frame(recipe)
    col1, col2, col3 = st.beta_columns([1, 6, 1])
    with col1:
        st.write("")
    with col2:
        st.image(
            recipe_post,
            caption="Your recipe",
            use_column_width="auto",
            output_format="PNG",
        )
    with col3:
        st.write("")


def main():
    """Build the Chef Transformer page and generate a recipe on request.

    The sidebar selects the chef (decoding settings) and the output formats.
    The main area takes ingredients from a predefined example or free text.
    Pressing the button generates a recipe and shows it as text and/or as an
    image.
    """
    st.set_page_config(
        page_title="Chef Transformer",
        page_icon="🍲",
        layout="wide",
        initial_sidebar_state="expanded",
    )
    generator = load_text_generator()
    local_css("asset/css/style.css")

    # Sidebar: branding, chef choice and output options.
    st.sidebar.image(
        load_image("asset/images/chef-transformer-transparent.png"), width=310
    )
    st.sidebar.title("Welcome to our lovely restaurant, how may I serve you?")
    chef = st.sidebar.selectbox(
        "Choose your chef", index=0, options=["Chef Scheherazade", "Chef Giovanni"]
    )
    recipe_as_text = st.sidebar.radio(
        label="Show Recipe Text",
        options=(True, False),
        help="A text version of the generated recipe will be displayed (can copy-paste)",
        format_func=true_false_format_func,
    )
    recipe_as_image = st.sidebar.radio(
        label="Show Recipe Image",
        options=(True, False),
        help="An image of the generated recipe will be displayed (useful for social media)",
        format_func=true_false_format_func,
    )

    st.markdown(meta.HEADER_INFO)
    use_examples = st.radio(
        label="Choose from predefined examples or use custom ingredients",
        options=[True, False],
        format_func=examples_custom_format_func,
    )

    input_container = st.empty()
    if use_examples:
        example_name = input_container.selectbox("Examples", EXAMPLES.keys())
        items = pure_comma_separation(EXAMPLES[example_name], return_list=False)
    else:
        items = input_container.text_input(
            "Add custom ingredients here (separated by `,`): ",
            key="custom_keywords",
            max_chars=1000,
        )

    entered_items = st.empty()
    if not st.button("Generate Recipe!"):
        entered_items.markdown("Enter your items...")
        return

    # Whitespace-only input must not count as ingredients; previously the app
    # showed "Generating..." and then silently produced nothing.
    if not isinstance(items, str) or len(items.strip()) <= 1:
        st.warning("Please enter at least one ingredient.")
        return

    # Nothing would be displayed, so don't spend time running the model.
    if not recipe_as_text and not recipe_as_image:
        st.write(
            "Please select 'Yes' for either 'Show Recipe Text' or 'Show Recipe Image' in the left sidebar."
        )
        return

    entered_items.markdown(
        "**Generating recipe using the following ingredients:** " + items
    )
    with st.spinner("Generating recipe..."):
        gen_kw = chef_top if chef == "Chef Scheherazade" else chef_beam
        # Pass a copy so the shared module-level settings can never be mutated.
        generated_recipe = generator.generate(items, dict(gen_kw))
        if recipe_as_text:
            _render_recipe_text(generated_recipe)
        if recipe_as_image:
            _render_recipe_image(generator, generated_recipe)
# Entry point: build the page when this file is executed as the main script.
if __name__ == "__main__":
    main()