File size: 1,902 Bytes
8d29b29 587e383 8d29b29 de9b5a0 587e383 de9b5a0 1cb6077 de9b5a0 1cb6077 de9b5a0 587e383 de9b5a0 587e383 de9b5a0 8d29b29 587e383 1cb6077 de9b5a0 8d29b29 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
import streamlit as st
from transformers import T5Tokenizer, FlaxT5ForConditionalGeneration
import numpy as np
# NOTE(review): "t5-base" is a general-purpose checkpoint; the "items: " prompt
# prefix used below suggests a recipe-fine-tuned model was intended — confirm.
MODEL_NAME_OR_PATH = "t5-base"


def _load_model_and_tokenizer(model_name_or_path):
    """Load and return ``(tokenizer, model)`` for *model_name_or_path*."""
    return (
        T5Tokenizer.from_pretrained(model_name_or_path),
        FlaxT5ForConditionalGeneration.from_pretrained(model_name_or_path),
    )


# Streamlit re-executes this script on every widget interaction, so loading at
# module level would reload the weights on each "Generate Recipe" click.
# st.cache_resource (newer Streamlit releases) keeps one shared instance per
# process; on versions without it, fall back to the original uncached load.
_cache_resource = getattr(st, "cache_resource", None)
if _cache_resource is not None:
    _load_model_and_tokenizer = _cache_resource(_load_model_and_tokenizer)

tokenizer, model = _load_model_and_tokenizer(MODEL_NAME_OR_PATH)
# Text prepended to the user's ingredient list before tokenization.
prefix = "items: "

# Keyword arguments unpacked into ``model.generate``: sampled decoding
# (top-k plus nucleus/top-p), a single returned sequence, length bounds,
# and no repeated 3-grams.
generation_kwargs = dict(
    max_length=512,
    min_length=64,
    no_repeat_ngram_size=3,
    do_sample=True,
    top_k=60,
    top_p=0.95,
    num_return_sequences=1,
)
def skip_special_tokens(text, special_tokens):
    """Return *text* with every occurrence of each special token deleted.

    Tokens are removed one after another in iteration order, so deleting an
    earlier token can expose a new match for a later one.
    """
    stripped = text
    for marker in special_tokens:
        stripped = stripped.replace(marker, "")
    return stripped
def target_postprocessing(texts, special_tokens):
    """Strip *special_tokens* from one string or a list of strings.

    A non-list argument is treated as a single text. The result is always a
    new list with one cleaned string per input text.
    """
    batch = texts if isinstance(texts, list) else [texts]

    def _clean(text):
        # Remove each special token in turn from this one text.
        for marker in special_tokens:
            text = text.replace(marker, "")
        return text

    return [_clean(text) for text in batch]
def generate_recipe(items):
    """Generate recipe text for a comma-separated string of food *items*.

    The prompt is ``prefix + items``, tokenized (padded/truncated to 256
    tokens) and passed to ``model.generate`` with ``generation_kwargs``.
    Special tokens are removed from the decoded output. Returns the first
    generated text.
    """
    encoded = tokenizer(
        [prefix + items],
        max_length=256,
        padding="max_length",
        truncation=True,
        return_tensors="jax",
    )
    # NOTE(review): Flax sampling accepts a `prng_key`; none is passed here, so
    # repeated calls with the same input may yield the same sample — confirm
    # whether variety across clicks is intended.
    outputs = model.generate(
        input_ids=encoded.input_ids,
        attention_mask=encoded.attention_mask,
        **generation_kwargs,
    )
    # Flax `generate` returns an output container (e.g. FlaxSampleOutput) whose
    # token ids live in `.sequences`; converting the container itself with
    # np.array does not yield a 2-D id array. Fall back to the raw value if a
    # plain array is ever returned.
    sequences = getattr(outputs, "sequences", outputs)
    output_ids = np.asarray(sequences)
    decoded = tokenizer.batch_decode(output_ids, skip_special_tokens=False)
    return target_postprocessing(decoded, tokenizer.all_special_tokens)[0]
def main():
    """Render the recipe-generation page and generate a recipe on request.

    If the ingredient field is empty or whitespace-only, a warning is shown
    instead of running the model on the bare prompt prefix.
    """
    st.title("Recipe Generation")
    items = st.text_input("Enter food items separated by comma (e.g., apple, cucumber):")
    if not st.button("Generate Recipe"):
        return
    if not items or not items.strip():
        st.warning("Please enter at least one food item.")
        return
    # Generation can take a while; show progress instead of a frozen page.
    with st.spinner("Generating recipe..."):
        generated_recipe = generate_recipe(items)
    st.write(generated_recipe)
# Build the UI only when this file is executed as the main script
# (presumably via `streamlit run`), not when it is imported.
if __name__ == "__main__":
    main()
|