jejun committed on
Commit
44c581d
·
1 Parent(s): 26246db

Update src/prediction.py

Browse files
Files changed (1) hide show
  1. src/prediction.py +29 -13
src/prediction.py CHANGED
@@ -22,6 +22,8 @@ generation_kwargs = {
22
  "early_stopping": True,
23
  "num_beams": 5,
24
  "length_penalty": 1.5,
 
 
25
  }
26
 
27
  special_tokens = tokenizer.all_special_tokens
@@ -50,7 +52,7 @@ def target_postprocessing(texts, special_tokens):
50
 
51
  return new_texts
52
 
53
- def generation_function(texts):
54
  _inputs = texts if isinstance(texts, list) else [texts]
55
  inputs = [prefix + inp for inp in _inputs]
56
  inputs = tokenizer(
@@ -58,23 +60,37 @@ def generation_function(texts):
58
  max_length=256,
59
  padding="max_length",
60
  truncation=True,
61
- return_tensors='jax'
62
  )
63
 
64
  input_ids = inputs.input_ids
65
  attention_mask = inputs.attention_mask
66
 
67
- output_ids = model.generate(
68
- input_ids=input_ids,
69
- attention_mask=attention_mask,
70
- **generation_kwargs
71
- )
72
- generated = output_ids.sequences
73
- generated_recipe = target_postprocessing(
74
- tokenizer.batch_decode(generated, skip_special_tokens=False),
75
- special_tokens
76
- )
77
- return generated_recipe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
 
80
  items = [
 
22
  "early_stopping": True,
23
  "num_beams": 5,
24
  "length_penalty": 1.5,
25
+ "num_return_sequences": 3, # Generate 3 unique sequences
26
+ "temperature": 0.8
27
  }
28
 
29
  special_tokens = tokenizer.all_special_tokens
 
52
 
53
  return new_texts
54
 
55
+ def generation_function(texts, num_recipes=1):
56
  _inputs = texts if isinstance(texts, list) else [texts]
57
  inputs = [prefix + inp for inp in _inputs]
58
  inputs = tokenizer(
 
60
  max_length=256,
61
  padding="max_length",
62
  truncation=True,
63
+ return_tensors="pt"
64
  )
65
 
66
  input_ids = inputs.input_ids
67
  attention_mask = inputs.attention_mask
68
 
69
+ generated_recipes = []
70
+ while len(generated_recipes) < num_recipes:
71
+ output_ids = model.generate(
72
+ input_ids=input_ids,
73
+ attention_mask=attention_mask,
74
+ **generation_kwargs
75
+ )
76
+ generated = output_ids.detach().cpu().numpy()
77
+ generated_recipe = target_postprocessing(
78
+ tokenizer.batch_decode(generated, skip_special_tokens=False),
79
+ special_tokens
80
+ )
81
+
82
+ # Check if generated_recipe is unique and contains only inputted ingredients
83
+ unique = True
84
+ for recipe in generated_recipes:
85
+ if generated_recipe == recipe or not all(ingredient in generated_recipe[0] for ingredient in texts[0].split(',')):
86
+ unique = False
87
+ break
88
+
89
+ if unique:
90
+ generated_recipes.append(generated_recipe)
91
+
92
+ return generated_recipes[0] if num_recipes == 1 else generated_recipes
93
+
94
 
95
 
96
  items = [