Update main.py
Browse files
main.py
CHANGED
@@ -12,7 +12,6 @@ test_model = Llama(model_name)
|
|
12 |
|
13 |
class RequestBody(BaseModel):
|
14 |
prompt: str
|
15 |
-
num_return_sequences: int = 10
|
16 |
temperature: float = 1.0
|
17 |
top_k: int = 50
|
18 |
top_p: float = 1.0
|
@@ -21,13 +20,12 @@ class RequestBody(BaseModel):
|
|
21 |
async def generate_text(request: RequestBody):
|
22 |
try:
|
23 |
prompt = sf.encoder(request.prompt)
|
24 |
-
|
25 |
-
|
26 |
-
input_ids=input_ids,
|
27 |
max_new_tokens=512,
|
28 |
num_beams=10,
|
29 |
early_stopping=True,
|
30 |
-
num_return_sequences=
|
31 |
do_sample=True,
|
32 |
top_k = request.top_k,
|
33 |
top_p = request.top_p,
|
@@ -35,8 +33,8 @@ async def generate_text(request: RequestBody):
|
|
35 |
)
|
36 |
|
37 |
result = {'input': prompt}
|
38 |
-
for i in range(
|
39 |
-
output1 =
|
40 |
first_inst_index = output1.find("[/INST]")
|
41 |
second_inst_index = output1.find("[/IN", first_inst_index + len("[/INST]") + 1)
|
42 |
predicted_selfies = output1[first_inst_index + len("[/INST]"):second_inst_index].strip()
|
|
|
12 |
|
13 |
class RequestBody(BaseModel):
|
14 |
prompt: str
|
|
|
15 |
temperature: float = 1.0
|
16 |
top_k: int = 50
|
17 |
top_p: float = 1.0
|
|
|
20 |
async def generate_text(request: RequestBody):
|
21 |
try:
|
22 |
prompt = sf.encoder(request.prompt)
|
23 |
+
outputs = test_model(
|
24 |
+
prompt,
|
|
|
25 |
max_new_tokens=512,
|
26 |
num_beams=10,
|
27 |
early_stopping=True,
|
28 |
+
num_return_sequences=10,
|
29 |
do_sample=True,
|
30 |
top_k = request.top_k,
|
31 |
top_p = request.top_p,
|
|
|
33 |
)
|
34 |
|
35 |
result = {'input': prompt}
|
36 |
+
for i in range(10):
|
37 |
+
output1 = outputs[i][len(prompt):]
|
38 |
first_inst_index = output1.find("[/INST]")
|
39 |
second_inst_index = output1.find("[/IN", first_inst_index + len("[/INST]") + 1)
|
40 |
predicted_selfies = output1[first_inst_index + len("[/INST]"):second_inst_index].strip()
|