Spaces:
Runtime error
Runtime error
File size: 981 Bytes
cd3659c ab3cc0a cd3659c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
# Typical call values: num_beams = 5, num_seq = 1, max_length = 300
def get_question(sentence, answer, mdl, tknizer, num_seq, num_beams, max_length):
    """Generate candidate questions for an answer found in a context sentence.

    The prompt ``"context: <sentence> answer: <answer>"`` is printed, tokenized
    (truncated to 256 tokens), passed to ``mdl.generate`` with beam search, and
    every returned sequence is decoded back to text.

    Args:
        sentence: Context text inserted after ``context:`` in the prompt.
        answer: Answer text inserted after ``answer:`` in the prompt.
        mdl: Object exposing ``generate(...)``; presumably a Hugging Face
            seq2seq model -- confirm against the caller.
        tknizer: Object exposing ``encode_plus(...)`` and ``decode(...)``;
            presumably the tokenizer matching ``mdl``.
        num_seq: Number of questions requested; capped at ``num_beams``.
        num_beams: Beam width forwarded to ``mdl.generate``.
        max_length: Maximum generated length forwarded to ``mdl.generate``.

    Returns:
        A list of decoded strings, one per generated sequence, with the
        ``"question:"`` marker removed and surrounding whitespace stripped.
    """
    # Beam search cannot return more sequences than it keeps beams, so cap
    # the request instead of letting generate() reject it.
    num_seq = min(num_seq, num_beams)

    prompt = "context: {} answer: {}".format(sentence, answer)
    print(prompt)

    max_len = 256  # token budget for the encoded prompt; longer input is truncated
    encoding = tknizer.encode_plus(
        prompt,
        max_length=max_len,
        # `padding=False` is the current spelling of the deprecated
        # `pad_to_max_length=False`; a single prompt needs no padding.
        padding=False,
        truncation=True,
        return_tensors="pt",
    )
    input_ids = encoding["input_ids"]
    attention_mask = encoding["attention_mask"]

    outs = mdl.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        early_stopping=True,
        num_beams=num_beams,
        num_return_sequences=num_seq,
        no_repeat_ngram_size=2,
        max_length=max_length,
    )

    dec = [tknizer.decode(ids, skip_special_tokens=True) for ids in outs]
    # Drop the "question:" marker and the whitespace it leaves behind.
    return [text.replace("question:", "").strip() for text in dec]
|