2006elad commited on
Commit
5dcff48
1 Parent(s): 81bcaee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -58
app.py CHANGED
@@ -1,68 +1,25 @@
1
  import streamlit as st
2
- from transformers import T5Tokenizer, FlaxT5ForConditionalGeneration
3
- import numpy as np
4
 
5
- MODEL_NAME_OR_PATH = "t5-base"
6
- tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME_OR_PATH)
7
- model = FlaxT5ForConditionalGeneration.from_pretrained(MODEL_NAME_OR_PATH)
8
 
9
- prefix = "items: "
10
- generation_kwargs = {
11
- "max_length": 512,
12
- "min_length": 64,
13
- "no_repeat_ngram_size": 3,
14
- "do_sample": True,
15
- "top_k": 60,
16
- "top_p": 0.95,
17
- "num_return_sequences": 1
18
- }
19
-
20
- def skip_special_tokens(text, special_tokens):
21
- for token in special_tokens:
22
- text = text.replace(token, "")
23
- return text
24
-
25
- def target_postprocessing(texts, special_tokens):
26
- if not isinstance(texts, list):
27
- texts = [texts]
28
- new_texts = []
29
- for text in texts:
30
- text = skip_special_tokens(text, special_tokens)
31
- new_texts.append(text)
32
- return new_texts
33
-
34
- def generate_recipe(items):
35
- inputs = [prefix + items]
36
- inputs = tokenizer(
37
- inputs,
38
- max_length=256,
39
- padding="max_length",
40
- truncation=True,
41
- return_tensors="jax"
42
- )
43
- input_ids = inputs.input_ids
44
- attention_mask = inputs.attention_mask
45
-
46
- output_ids = model.generate(
47
- input_ids=input_ids,
48
- attention_mask=attention_mask,
49
- **generation_kwargs
50
- )
51
-
52
- # Convert output IDs to numpy array
53
- output_ids = np.array(output_ids)
54
-
55
- generated_recipe = tokenizer.batch_decode(output_ids, skip_special_tokens=False)
56
- generated_recipe = target_postprocessing(generated_recipe, tokenizer.all_special_tokens)
57
- return generated_recipe[0]
58
 
59
  def main():
60
  st.title("Recipe Generation")
61
-
62
- items = st.text_input("Enter food items separated by comma (e.g., apple, cucumber):")
63
  if st.button("Generate Recipe"):
64
- generated_recipe = generate_recipe(items)
65
- st.write(generated_recipe)
 
66
 
67
  if __name__ == "__main__":
68
  main()
 
1
  import streamlit as st
2
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
 
3
 
4
+ MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
5
+ tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
6
+ model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME_OR_PATH)
7
 
8
+ def generate_recipe(input_items):
9
+ prefix = "items: "
10
+ input_text = prefix + input_items
11
+ input_ids = tokenizer.encode(input_text, return_tensors="pt")
12
+ output_ids = model.generate(input_ids)
13
+ generated_recipe = tokenizer.decode(output_ids[0], skip_special_tokens=True)
14
+ return generated_recipe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  def main():
17
  st.title("Recipe Generation")
18
+ input_items = st.text_area("Enter the recipe instructions:")
 
19
  if st.button("Generate Recipe"):
20
+ generated_recipe = generate_recipe(input_items)
21
+ st.subheader("Generated Recipe:")
22
+ st.text(generated_recipe)
23
 
24
  if __name__ == "__main__":
25
  main()