phuongnv committed on
Commit
e875ebf
1 Parent(s): 1a4e8c6

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +5 -7
main.py CHANGED
@@ -12,7 +12,6 @@ test_model = Llama(model_name)
12
 
13
  class RequestBody(BaseModel):
14
  prompt: str
15
- num_return_sequences: int = 10
16
  temperature: float = 1.0
17
  top_k: int = 50
18
  top_p: float = 1.0
@@ -21,13 +20,12 @@ class RequestBody(BaseModel):
21
  async def generate_text(request: RequestBody):
22
  try:
23
  prompt = sf.encoder(request.prompt)
24
- input_ids = test_tokenizer(prompt, return_tensors='pt', truncation=False).input_ids
25
- outputs = test_model.generate(
26
- input_ids=input_ids,
27
  max_new_tokens=512,
28
  num_beams=10,
29
  early_stopping=True,
30
- num_return_sequences=request.num_return_sequences,
31
  do_sample=True,
32
  top_k = request.top_k,
33
  top_p = request.top_p,
@@ -35,8 +33,8 @@ async def generate_text(request: RequestBody):
35
  )
36
 
37
  result = {'input': prompt}
38
- for i in range(num_return_sequences):
39
- output1 = test_tokenizer.batch_decode(outputs.detach().numpy(), skip_special_tokens=True)[i][len(prompt):]
40
  first_inst_index = output1.find("[/INST]")
41
  second_inst_index = output1.find("[/IN", first_inst_index + len("[/INST]") + 1)
42
  predicted_selfies = output1[first_inst_index + len("[/INST]"):second_inst_index].strip()
 
12
 
13
  class RequestBody(BaseModel):
14
  prompt: str
 
15
  temperature: float = 1.0
16
  top_k: int = 50
17
  top_p: float = 1.0
 
20
  async def generate_text(request: RequestBody):
21
  try:
22
  prompt = sf.encoder(request.prompt)
23
+ outputs = test_model(
24
+ prompt,
 
25
  max_new_tokens=512,
26
  num_beams=10,
27
  early_stopping=True,
28
+ num_return_sequences=10,
29
  do_sample=True,
30
  top_k = request.top_k,
31
  top_p = request.top_p,
 
33
  )
34
 
35
  result = {'input': prompt}
36
+ for i in range(10):
37
+ output1 = outputs[i][len(prompt):]
38
  first_inst_index = output1.find("[/INST]")
39
  second_inst_index = output1.find("[/IN", first_inst_index + len("[/INST]") + 1)
40
  predicted_selfies = output1[first_inst_index + len("[/INST]"):second_inst_index].strip()