cahya committed
Commit e84c607
Parent: faf39cc

update the generate param

Files changed (1):
  1. app/api.py +20 -7
app/api.py CHANGED
@@ -70,20 +70,33 @@ async def websocket_endpoint(websocket: WebSocket):
 @app.post("/api/indochat/v1")
 async def indochat(
         text: str = Form(default="", description="The Prompt"),
+        decoding_method: str = Form(default="Sampling", description="Decoding method"),
+        min_length: int = Form(default=50, description="Minimal length of the generated text"),
         max_length: int = Form(default=250, description="Maximal length of the generated text"),
-        do_sample: bool = Form(default=True, description="Whether to use sampling; use greedy decoding otherwise"),
+        num_beams: int = Form(default=5, description="Beams number"),
         top_k: int = Form(default=30, description="The number of highest probability vocabulary tokens to keep "
                                                   "for top-k-filtering"),
         top_p: float = Form(default=0.95, description="If set to float < 1, only the most probable tokens with "
                                                       "probabilities that add up to top_p or higher are kept "
                                                       "for generation"),
         temperature: float = Form(default=0.5, description="The Temperature of the softmax distribution"),
-        penalty_alpha: float = Form(default=0.0, description="Penalty alpha"),
+        penalty_alpha: float = Form(default=0.5, description="Penalty alpha"),
         repetition_penalty: float = Form(default=1.2, description="Repetition penalty"),
-        seed: int = Form(default=42, description="Random Seed"),
+        seed: int = Form(default=-1, description="Random Seed"),
         max_time: float = Form(default=60.0, description="Maximal time in seconds to generate the text")
 ):
-    set_seed(seed)
+    if seed >= 0:
+        set_seed(seed)
+    if decoding_method == "Beam Search":
+        do_sample = False
+        penalty_alpha = 0
+    elif decoding_method == "Sampling":
+        do_sample = True
+        penalty_alpha = 0
+        num_beams = 1
+    else:
+        do_sample = False
+        num_beams = 1
     if repetition_penalty == 0.0:
         min_penalty = 1.05
         max_penalty = 1.5
@@ -98,7 +111,8 @@ async def indochat(
     sample_outputs = model.generate(input_ids,
                                     penalty_alpha=penalty_alpha,
                                     do_sample=do_sample,
-                                    min_length=200,
+                                    num_beams=num_beams,
+                                    min_length=min_length,
                                     max_length=max_length,
                                     top_k=top_k,
                                     top_p=top_p,
@@ -108,11 +122,10 @@ async def indochat(
                                     max_time=max_time
                                     )
     result = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
-    # result = result[len(prompt) + 1:]
     time_end = time.time()
     time_diff = time_end - time_start
     print(f"result:\n{result}")
-    generated_text = result[len(prompt):]
+    generated_text = result[len(prompt)+1:]
     return {"generated_text": generated_text, "processing_time": time_diff}
 
 
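The net effect of the commit: do_sample is no longer a form field but is derived from the new decoding_method field, and a negative seed (the new default) skips set_seed() so generation stays non-deterministic unless the caller pins a seed. A minimal client sketch against the updated endpoint follows; the host and port are assumptions, and the form fields simply mirror the parameters above:

import requests

# Host/port are assumptions; the form fields mirror the endpoint's parameters.
response = requests.post(
    "http://localhost:8000/api/indochat/v1",
    data={
        "text": "Siapa presiden pertama Indonesia?",
        "decoding_method": "Sampling",  # "Beam Search" disables sampling; any other value falls through
        "min_length": 50,
        "max_length": 250,
        "top_k": 30,
        "top_p": 0.95,
        "temperature": 0.5,
        "repetition_penalty": 1.2,
        "seed": -1,        # negative seed skips set_seed(), so output varies between calls
        "max_time": 60.0,
    },
)
print(response.json()["generated_text"])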
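On the fall-through branch (any decoding_method other than "Beam Search" or "Sampling"), penalty_alpha keeps its submitted value, and in Hugging Face transformers (v4.24+) generate() switches to contrastive search whenever penalty_alpha > 0 and top_k > 1. A standalone sketch of that mode, where the "gpt2" checkpoint is only a placeholder and not the model this API serves:

from transformers import AutoModelForCausalLM, AutoTokenizer

# "gpt2" is a placeholder checkpoint; the API's actual model is loaded elsewhere in app/api.py.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer("The weather today is", return_tensors="pt").input_ids
outputs = model.generate(
    input_ids,
    penalty_alpha=0.5,  # degeneration penalty; 0.0 disables contrastive search
    top_k=4,            # size of the candidate set re-ranked by the contrastive objective
    max_length=60,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))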