Spaces:
Runtime error
Runtime error
import streamlit as st | |
import streamlit.components.v1 as components | |
import torch | |
from transformers import pipeline, set_seed | |
from transformers import AutoTokenizer | |
from PIL import Image, ImageFont, ImageDraw | |
import re | |
import textwrap | |
from examples import EXAMPLES | |
import meta | |
from utils import remote_css, local_css, load_image, pure_comma_separation | |
class TextGeneration:
    """Generate recipes with the ``flax-community/t5-recipe-generation`` model.

    In debug mode (the default) no model is loaded and :meth:`generate`
    returns a fixed placeholder recipe, so the page and the recipe image can
    be worked on without the model.
    """

    def __init__(self, debug=True):
        """Configure the generator, load the PT Serif fonts and seed generation.

        Args:
            debug: When True (the previously hard-coded value), :meth:`load`
                does nothing and :meth:`generate` returns ``self.dummy_output``.
        """
        self.debug = debug
        self.dummy_output = self._dummy_recipe()
        self.tokenizer = None
        self.generator = None
        self.task = "text2text-generation"
        self.model_name_or_path = "flax-community/t5-recipe-generation"
        # The first `list_division` ingredients fill the left image column,
        # the remaining ones go to the right column (see prepare_frame).
        self.list_division = 5
        self.point = "•"
        self.h1_font = ImageFont.truetype("asset/fonts/PT_Serif/PTSerif-Bold.ttf", 75)
        self.h2_font = ImageFont.truetype("asset/fonts/PT_Serif/PTSerif-Bold.ttf", 50)
        self.p_font = ImageFont.truetype("asset/fonts/PT_Serif/PTSerif-Regular.ttf", 30)
        set_seed(42)

    @staticmethod
    def _dummy_recipe():
        """Return the placeholder recipe served in debug mode.

        The repeated entries pad both lists (presumably to exercise text
        wrapping and the two-column layout of the recipe image).
        """
        return {
            "directions": [
                "peel the potato and slice thinly.",
                "place in a microwave safe dish.",
                "cover with plastic wrap and microwave on high for 5 minutes.",
                "remove from the microwave and sprinkle with cheese.",
                "return to the microwave for 1 minute or until cheese is melted.",
                "return to the microwave for 1 minute or until cheese is melted. return to the microwave for 1 minute or until cheese is melted.",
                "return to the microwave for 1 minute or until cheese is melted.",
                "return to the microwave for 1 minute or until cheese is melted.",
                "return to the microwave for 1 minute or until cheese is melted.",
            ],
            "ingredients": [
                "1 potato",
                "1 slice cheese",
                "1 potato",
                "1 slice cheese",
                "1 potato",
                "1 slice cheese",
                "1 slice cheese",
                "1 potato",
                "1 slice cheese",
                "1 potato",
                "1 slice cheese",
            ],
            "title": "Cheese Potatoes with Some other items",
        }

    @staticmethod
    def _split_items(body):
        """Split a ``--``-separated section body into stripped, non-empty items."""
        return [item.strip() for item in body.split("--") if item.strip()]

    def _skip_special_tokens_and_prettify(self, text):
        """Turn raw decoded model output into a structured recipe dict.

        The tokenizer's special tokens are removed, ``<sep>`` becomes the item
        separator ``--`` and ``<section>`` becomes a line break. Lines starting
        with ``title:``, ``ingredients:`` or ``directions:`` fill the matching
        key; any other line is ignored.

        Returns:
            dict with ``title`` (str, each word capitalized) and
            ``ingredients`` / ``directions`` (lists of non-empty strings).
        """
        # Longest tokens first so overlapping tokens are removed whole; escape
        # them because they are literal strings, not regex patterns.
        special_tokens = sorted(
            (tok for tok in self.tokenizer.all_special_tokens if tok),
            key=len,
            reverse=True,
        )
        if special_tokens:
            text = re.sub("|".join(map(re.escape, special_tokens)), "", text)

        recipe_maps = {"<sep>": "--", "<section>": "\n"}
        recipe_map_pattern = "|".join(map(re.escape, recipe_maps))
        text = re.sub(recipe_map_pattern, lambda m: recipe_maps[m.group()], text)

        data = {"title": "", "ingredients": [], "directions": []}
        for section in text.split("\n"):
            section = section.strip()
            if section.startswith("title:"):
                words = section[len("title:"):].split()
                data["title"] = " ".join(word.capitalize() for word in words)
            elif section.startswith("ingredients:"):
                data["ingredients"] = self._split_items(section[len("ingredients:"):])
            elif section.startswith("directions:"):
                data["directions"] = self._split_items(section[len("directions:"):])
        return data

    def load(self):
        """Load the tokenizer and the generation pipeline (no-op in debug mode)."""
        if self.debug:
            return
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)
        self.generator = pipeline(
            self.task,
            model=self.model_name_or_path,
            tokenizer=self.model_name_or_path,
        )

    def prepare_frame(self, recipe, frame):
        """Draw ``recipe`` onto the PIL image ``frame`` and return that image.

        Positions are fixed pixel offsets: the title (wrapped at 25 chars),
        an "Ingredients" heading with two bullet columns, then a "Directions"
        heading with a numbered list (wrapped at 70 chars).
        """
        im_editable = ImageDraw.Draw(frame)
        text_color = (61, 61, 70)

        # Title
        ws, hs = 120, 500
        im_editable.text(
            (ws, hs),
            textwrap.fill(recipe["title"], 25),
            text_color,
            font=self.h1_font,
        )

        # Ingredients
        hs += 270
        im_editable.text((ws, hs), "Ingredients", text_color, font=self.h2_font)
        hs += 80
        ingredients = recipe["ingredients"]
        ingredients_col1 = [
            textwrap.fill(item, 30) for item in ingredients[: self.list_division]
        ]
        ingredients_col2 = [
            textwrap.fill(item, 30) for item in ingredients[self.list_division :]
        ]
        im_editable.text(
            (ws + 10, hs),
            "\n".join(f"{self.point} {item}" for item in ingredients_col1),
            text_color,
            font=self.p_font,
        )
        im_editable.text(
            (ws + 500, hs),
            "\n".join(f"{self.point} {item}" for item in ingredients_col2),
            text_color,
            font=self.p_font,
        )

        # Directions
        hs += 240
        im_editable.text((ws, hs), "Directions", text_color, font=self.h2_font)
        hs += 80
        directions = [
            textwrap.fill(item, 70).replace("\n", "\n ")
            for item in recipe["directions"]
        ]
        im_editable.text(
            (ws + 10, hs),
            "\n".join(f"{num}. {d}" for num, d in enumerate(directions, start=1)),
            text_color,
            font=self.p_font,
        )
        return frame

    def generate(self, items, generation_kwargs):
        """Generate a recipe for the ingredient string ``items``.

        Args:
            items: Ingredient text passed to the text2text pipeline.
            generation_kwargs: Generation settings (e.g. ``chef_top`` or
                ``chef_beam``). The mapping is copied and never modified.

        Returns:
            Dict with ``title``, ``ingredients`` and ``directions``; in debug
            mode this is ``self.dummy_output`` itself.
        """
        print(generation_kwargs)
        if self.debug:
            return self.dummy_output

        # Copy: callers pass shared module-level dicts that must stay untouched.
        kwargs = dict(generation_kwargs)
        kwargs["num_return_sequences"] = 1
        kwargs["return_tensors"] = True
        kwargs["return_text"] = False
        outputs = self.generator(items, **kwargs)
        generated_ids = outputs[0]["generated_token_ids"]
        # Keep special tokens here; _skip_special_tokens_and_prettify removes them.
        recipe = self.tokenizer.decode(generated_ids, skip_special_tokens=False)
        return self._skip_special_tokens_and_prettify(recipe)

    def generate_frame(self, recipe):
        """Render ``recipe`` onto the recipe-post background image."""
        frame = load_image("asset/images/recipe-post.png")
        return self.prepare_frame(recipe, frame)
def _cache_resource(func):
    """Cache ``func``'s return value across Streamlit reruns.

    Uses ``st.cache_resource`` when this Streamlit version provides it and
    falls back to the legacy ``st.cache(allow_output_mutation=True)`` otherwise.
    """
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is not None:
        return cache_resource(func)
    return st.cache(allow_output_mutation=True)(func)


@_cache_resource
def load_text_generator():
    """Create and load the shared :class:`TextGeneration` instance.

    Cached so the fonts (and, outside debug mode, the tokenizer and model)
    are loaded once instead of on every Streamlit script rerun.
    """
    generator = TextGeneration()
    generator.load()
    return generator
# Generation settings shared by both chefs.
_COMMON_GENERATION_KWARGS = {
    "max_length": 512,
    "min_length": 64,
    "no_repeat_ngram_size": 3,
}

# Used for "Chef Scheherazade": sampling with top-k / top-p filtering.
chef_top = {
    **_COMMON_GENERATION_KWARGS,
    "do_sample": True,
    "top_k": 60,
    "top_p": 0.95,
    "num_return_sequences": 1,
}

# Used for every other chef ("Chef Giovanni"): beam search with 5 beams.
chef_beam = dict(
    _COMMON_GENERATION_KWARGS,
    early_stopping=True,
    num_beams=5,
    length_penalty=1.5,
    num_return_sequences=1,
)
def true_false_format_func(choose_yes):
    """Display label for a boolean radio option: "Yes" when truthy, else "No"."""
    return "Yes" if choose_yes else "No"
def examples_custom_format_func(use_examples):
    """Display label for the input-mode radio: predefined examples vs. custom text."""
    return "Examples" if use_examples else "Custom ingredients"
def _columns(spec):
    """Create Streamlit columns on both old and new Streamlit versions.

    ``st.beta_columns`` is missing from newer Streamlit releases, so prefer
    ``st.columns`` and only fall back to ``beta_columns`` when it is absent.
    """
    columns = getattr(st, "columns", None) or st.beta_columns
    return columns(spec)


def _render_recipe_text(generated_recipe):
    """Show the recipe as copyable plain text in an ``st.code`` block."""
    st.markdown(
        '<div>To copy the recipe to your clipboard, hover the mouse in the top-right corner of the section below and then click the button that looks like this: <svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><rect x="9" y="9" width="13" height="13" rx="2" ry="2"></rect><path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1"></path></svg> </div>',
        unsafe_allow_html=True,
    )
    title = generated_recipe["title"]
    ingredients = "\n• ".join(generated_recipe["ingredients"])
    wrapped_directions = [
        textwrap.fill(item, 100).replace("\n", "\n ")
        for item in generated_recipe["directions"]
    ]
    directions = "\n".join(
        f"{n}. {d}" for n, d in enumerate(wrapped_directions, start=1)
    )
    st.code(
        f"{title}\n\n\nIngredients\n\n• {ingredients}\n\nDirections\n\n{directions}".strip(),
        language="text",
    )


def _render_recipe_image(generator, generated_recipe):
    """Render the recipe image and show it centred in a 1:6:1 column layout."""
    recipe_post = generator.generate_frame(generated_recipe)
    col1, col2, col3 = _columns([1, 6, 1])
    with col1:
        st.write("")
    with col2:
        st.image(
            recipe_post,
            caption="Your recipe",
            use_column_width="auto",
            output_format="PNG",
        )
    with col3:
        st.write("")


def main():
    """Build the Chef Transformer page and generate a recipe on button press."""
    st.set_page_config(
        page_title="Chef Transformer",
        page_icon="🍲",
        layout="wide",
        initial_sidebar_state="expanded",
    )
    generator = load_text_generator()
    local_css("asset/css/style.css")

    # Sidebar: branding, chef choice and output options.
    st.sidebar.image(
        load_image("asset/images/chef-transformer-transparent.png"), width=310
    )
    st.sidebar.title("Welcome to our lovely restaurant, how may I serve you?")
    chef = st.sidebar.selectbox(
        "Choose your chef", index=0, options=["Chef Scheherazade", "Chef Giovanni"]
    )
    recipe_as_text = st.sidebar.radio(
        label="Show Recipe Text",
        options=(True, False),
        help="A text version of the generated recipe will be displayed (can copy-paste)",
        format_func=true_false_format_func,
    )
    recipe_as_image = st.sidebar.radio(
        label="Show Recipe Image",
        options=(True, False),
        help="An image of the generated recipe will be displayed (useful for social media)",
        format_func=true_false_format_func,
    )

    # Main area: ingredient input, either from examples or free text.
    st.markdown(meta.HEADER_INFO)
    use_examples = st.radio(
        label="Choose from predefined examples or use custom ingredients",
        options=[True, False],
        format_func=examples_custom_format_func,
    )
    input_container = st.empty()
    if use_examples:
        example_name = input_container.selectbox("Examples", EXAMPLES.keys())
        items = pure_comma_separation(EXAMPLES[example_name], return_list=False)
    else:
        items = input_container.text_input(
            "Add custom ingredients here (separated by `,`): ",
            key="custom_keywords",
            max_chars=1000,
        )

    entered_items = st.empty()
    if st.button("Generate Recipe!"):
        entered_items.markdown(
            f"**Generating recipe using the following ingredients:** {items}"
        )
        with st.spinner("Generating recipe..."):
            if not (isinstance(items, str) and len(items) > 1):
                # Empty or single-character input: ask for ingredients instead.
                entered_items.markdown("Enter your items...")
            elif not recipe_as_text and not recipe_as_image:
                # Nothing would be displayed, so skip running the generator.
                st.write(
                    "Please select 'Yes' for either 'Show Recipe Text' or 'Show Recipe Image' in the left sidebar."
                )
            else:
                gen_kw = chef_top if chef == "Chef Scheherazade" else chef_beam
                generated_recipe = generator.generate(items, gen_kw)
                if recipe_as_text:
                    _render_recipe_text(generated_recipe)
                if recipe_as_image:
                    _render_recipe_image(generator, generated_recipe)
if __name__ == "__main__":
    # Build the page only when this file is executed, not when imported.
    main()