m3hrdadfi's picture
Reprogramm demo
ac6bfa9
raw history blame
No virus
7.55 kB
import streamlit as st
import torch
from transformers import pipeline, set_seed
from transformers import AutoTokenizer
from PIL import (
Image,
ImageFont,
ImageDraw
)
import re
from examples import EXAMPLES
import meta
from utils import (
remote_css,
local_css,
load_image,
pure_comma_separation
)
class TextGeneration:
    """Recipe generator wrapping a ``text2text-generation`` pipeline plus a PIL renderer.

    When ``debug`` is true, ``load`` does not load any model and ``generate``
    returns the fixed ``dummy_output`` recipe instead of running the pipeline.
    """

    def __init__(self, debug=True):
        """Configure model identifiers, placeholder output and rendering fonts.

        Args:
            debug: When true (the default, matching the previous hard-coded
                value), skip model loading and return ``self.dummy_output``
                from ``generate``.
        """
        self.debug = debug
        self.dummy_output = self._make_dummy_output()
        self.tokenizer = None
        self.generator = None
        self.task = "text2text-generation"
        self.model_name_or_path = "flax-community/t5-recipe-generation"
        self.h1_font = ImageFont.truetype("asset/fonts/PT_Serif/PTSerif-Bold.ttf", 100)
        self.h2_font = ImageFont.truetype("asset/fonts/PT_Serif/PTSerif-Bold.ttf", 50)
        self.p_font = ImageFont.truetype("asset/fonts/PT_Serif/PTSerif-Regular.ttf", 30)
        # Seed the random number generators once, at construction time.
        set_seed(42)

    @staticmethod
    def _make_dummy_output():
        """Return a fresh placeholder recipe used when ``debug`` is enabled."""
        return {
            'directions': [
                'peel the potato and slice thinly.',
                'place in a microwave safe dish.',
                'cover with plastic wrap and microwave on high for 5 minutes.',
                'remove from the microwave and sprinkle with cheese.',
                'return to the microwave for 1 minute or until cheese is melted.',
                'return to the microwave for 1 minute or until cheese is melted.',
                'return to the microwave for 1 minute or until cheese is melted.',
            ],
            # Every entry needs its own trailing comma: adjacent string
            # literals would otherwise be concatenated into one item.
            'ingredients': [
                '1 potato',
                '1 slice cheese',
                '1 potato',
                '1 slice cheese',
                '1 potato',
                '1 slice cheese',
                '1 slice cheese',
                '1 potato',
                '1 slice cheese',
                '1 potato',
                '1 slice cheese',
            ],
            'title': 'Cheese Potatoes',
        }

    @staticmethod
    def _split_items(body):
        """Split a ``--``-separated section body into stripped, non-empty items."""
        return [item.strip() for item in body.split('--') if item.strip()]

    def _skip_special_tokens_and_prettify(self, text):
        """Remove tokenizer special tokens from ``text`` and parse it into a recipe dict.

        The parser expects ``<section>`` between sections and ``<sep>`` between
        list items, with sections prefixed by ``title:``, ``ingredients:`` or
        ``directions:``. Other sections are ignored.

        Args:
            text: Decoded model output (decoded without skipping special tokens).

        Returns:
            dict with ``title`` (str), ``ingredients`` and ``directions`` (lists of str).
        """
        # Longest first so no token is shadowed by a shorter alternative.
        special_tokens = sorted(self.tokenizer.all_special_tokens, key=len, reverse=True)
        if special_tokens:
            # Tokens are literal strings (e.g. "[CLS]"), so escape them before
            # building the alternation pattern.
            text = re.sub("|".join(map(re.escape, special_tokens)), "", text)

        recipe_maps = {"<sep>": "--", "<section>": "\n"}
        recipe_map_pattern = "|".join(map(re.escape, recipe_maps))
        text = re.sub(recipe_map_pattern, lambda m: recipe_maps[m.group()], text)

        data = {"title": "", "ingredients": [], "directions": []}
        for section in text.split("\n"):
            section = section.strip()
            # Strip only the leading prefix, not later occurrences of it.
            if section.startswith("title:"):
                words = section[len("title:"):].split()
                data["title"] = " ".join(w.capitalize() for w in words)
            elif section.startswith("ingredients:"):
                data["ingredients"] = self._split_items(section[len("ingredients:"):])
            elif section.startswith("directions:"):
                data["directions"] = self._split_items(section[len("directions:"):])
        return data

    def load(self):
        """Load the tokenizer and generation pipeline unless in debug mode."""
        if not self.debug:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)
            self.generator = pipeline(self.task, model=self.model_name_or_path, tokenizer=self.model_name_or_path)

    def prepare_frame(self, recipe, frame):
        """Draw ``recipe`` onto ``frame`` and return it.

        Args:
            recipe: dict with ``title`` (str), ``ingredients`` and
                ``directions`` (iterables of str).
            frame: Image accepted by ``ImageDraw.Draw``; drawn on in place.

        Returns:
            The same ``frame`` object.
        """
        text_color = (61, 61, 70)
        im_editable = ImageDraw.Draw(frame)
        # Title
        ws, hs = 120, 500
        im_editable.text((ws, hs), recipe["title"], text_color, font=self.h1_font)
        # Ingredients
        hs += 200
        im_editable.text((ws, hs), "Ingredients", text_color, font=self.h2_font)
        hs += 80
        im_editable.text(
            (ws + 10, hs),
            "\n".join([f"- {item}" for item in recipe["ingredients"]]),
            text_color,
            font=self.p_font,
        )
        # Directions start a fixed 400 px below the ingredient list; a long
        # ingredient list may overlap them.
        hs += 400
        im_editable.text((ws, hs), "Directions", text_color, font=self.h2_font)
        hs += 80
        im_editable.text(
            (ws + 10, hs),
            "\n".join([f"- {item}" for item in recipe["directions"]]),
            text_color,
            font=self.p_font,
        )
        return frame

    def generate(self, items, generation_kwargs):
        """Generate a recipe for ``items``.

        Args:
            items: Prompt passed to the pipeline (the app passes a
                comma-separated ingredient string).
            generation_kwargs: Generation settings. Not modified; a copy is
                extended with the output options this method requires.

        Returns:
            dict with ``title``, ``ingredients`` and ``directions`` keys
            (``self.dummy_output`` in debug mode).
        """
        print(generation_kwargs)
        if self.debug:
            return self.dummy_output

        # Copy so shared, module-level settings dicts are never mutated.
        kwargs = dict(generation_kwargs)
        kwargs["num_return_sequences"] = 1
        # Request token ids rather than text so they can be decoded below with
        # skip_special_tokens=False, keeping the markers the parser relies on.
        kwargs["return_tensors"] = True
        kwargs["return_text"] = False
        generated_ids = self.generator(items, **kwargs)[0]["generated_token_ids"]
        recipe = self.tokenizer.decode(generated_ids, skip_special_tokens=False)
        return self._skip_special_tokens_and_prettify(recipe)

    def generate_frame(self, recipe):
        """Render ``recipe`` onto a fresh copy of the recipe-post background image."""
        frame = load_image("asset/images/recipe-post.png")
        return self.prepare_frame(recipe, frame)
@st.cache(allow_output_mutation=True)
def load_text_generator():
    """Construct and load a TextGeneration instance, cached by ``st.cache``.

    ``allow_output_mutation=True`` tells Streamlit not to hash-check the
    returned object between reruns.
    """
    text_generator = TextGeneration()
    text_generator.load()
    return text_generator
# Length and repetition limits shared by every chef's generation settings.
_CHEF_COMMON = {
    "max_length": 512,
    "min_length": 64,
    "no_repeat_ngram_size": 3,
}

# Settings used for "Chef Scheherazade": top-k / top-p sampling.
chef_top = {
    **_CHEF_COMMON,
    "do_sample": True,
    "top_k": 60,
    "top_p": 0.95,
    "num_return_sequences": 1,
}

# Settings used for "Chef Giovanni": beam search with 5 beams.
chef_beam = {
    **_CHEF_COMMON,
    "early_stopping": True,
    "num_beams": 5,
    "length_penalty": 1.5,
    "num_return_sequences": 1,
}
def main():
    """Render the Chef Transformer page and generate a recipe on request.

    Empty ingredient input now shows a warning instead of running the
    generator on an empty prompt.
    """
    # set_page_config is called before any other Streamlit command.
    st.set_page_config(
        page_title="Chef Transformer",
        page_icon="🍲",
        layout="wide",
        initial_sidebar_state="expanded"
    )
    generator = load_text_generator()
    local_css("asset/css/style.css")

    # Sidebar: logo and chef selection.
    st.sidebar.image(load_image("asset/images/chef-transformer-transparent.png"), width=310)
    st.sidebar.title("Choose your own chef")
    chef = st.sidebar.selectbox("Chef", index=0, options=["Chef Scheherazade", "Chef Giovanni"])

    st.markdown(meta.HEADER_INFO)
    prompts = list(EXAMPLES.keys()) + ["Custom"]
    prompt = st.selectbox('Examples', prompts, index=len(prompts) - 1)
    prompt_box = "" if prompt == "Custom" else EXAMPLES[prompt]
    items = st.text_input(
        'Add custom ingredients here (separated by `,`): ',
        pure_comma_separation(prompt_box, return_list=False),
        key="custom_keywords",
        max_chars=1000)
    items = pure_comma_separation(items, return_list=False)
    entered_items = st.empty()

    if not st.button('Get Recipe!'):
        return

    # Guard against generating a recipe for an empty prompt.
    if not (items and items.strip()):
        entered_items.warning("Please enter at least one ingredient.")
        return

    entered_items.markdown("**Generate recipe for:** " + items)
    with st.spinner("Generating recipe..."):
        gen_kw = chef_top if chef == "Chef Scheherazade" else chef_beam
        generated_recipe = generator.generate(items, gen_kw)
        recipe_post = generator.generate_frame(generated_recipe)
        # Centre the rendered recipe image in a wide middle column.
        col1, col2, col3 = st.beta_columns([1, 6, 1])
        with col1:
            st.write("")
        with col2:
            st.image(
                recipe_post,
                caption="Your recipe",
                use_column_width="auto",
                output_format="PNG"
            )
        with col3:
            st.write("")
if __name__ == "__main__":
    # Start the app only when this module is executed directly, not on import.
    main()