m3hrdadfi commited on
Commit
3d35750
1 Parent(s): 18b0e4b

Fix config

Browse files
Files changed (1) hide show
  1. server.py +13 -8
server.py CHANGED
@@ -47,7 +47,6 @@ original_keywords = st.multiselect(
47
  ["parmesan cheese", "fresh oregano", "basil", "whole wheat flour"]
48
  )
49
 
50
-
51
  # st.write("Add custom ingredients here:")
52
  # custom_keywords = st_tags(
53
  # label="",
@@ -57,8 +56,8 @@ original_keywords = st.multiselect(
57
  # maxtags=15,
58
  # key='1')
59
 
60
- def custom_keywords_on_change():
61
- pass
62
 
63
 
64
  custom_keywords = st.text_input(
@@ -79,12 +78,18 @@ submit = st.button('Get Recipe!')
79
  if submit:
80
  with st.spinner('Generating recipe...'):
81
  if sampling_mode == "Beam Search":
82
- generated = generator(all_ingredients, return_tensors=True, return_text=False,
83
- **beam_search.generate_kwargs)
 
 
 
84
  outputs = beam_search.post_generator(generated, tokenizer)
85
- elif sampling_mode == "Top-k Sampling":
86
- generated = generator(all_ingredients, return_tensors=True, return_text=False,
87
- **top_sampling.generate_kwargs)
 
 
 
88
  outputs = top_sampling.post_generator(generated, tokenizer)
89
  output = outputs[0]
90
  output['title'] = " ".join([w.capitalize() for w in output['title'].split()])
 
47
  ["parmesan cheese", "fresh oregano", "basil", "whole wheat flour"]
48
  )
49
 
 
50
  # st.write("Add custom ingredients here:")
51
  # custom_keywords = st_tags(
52
  # label="",
 
56
  # maxtags=15,
57
  # key='1')
58
 
59
+ # def custom_keywords_on_change():
60
+ # pass
61
 
62
 
63
  custom_keywords = st.text_input(
 
78
  if submit:
79
  with st.spinner('Generating recipe...'):
80
  if sampling_mode == "Beam Search":
81
+ generated = generator(
82
+ all_ingredients,
83
+ return_tensors=True,
84
+ return_text=False,
85
+ **beam_search.generate_kwargs)
86
  outputs = beam_search.post_generator(generated, tokenizer)
87
+ elif sampling_mode == "Top Sampling":
88
+ generated = generator(
89
+ all_ingredients,
90
+ return_tensors=True,
91
+ return_text=False,
92
+ **top_sampling.generate_kwargs)
93
  outputs = top_sampling.post_generator(generated, tokenizer)
94
  output = outputs[0]
95
  output['title'] = " ".join([w.capitalize() for w in output['title'].split()])