File size: 981 Bytes
cd3659c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ab3cc0a
cd3659c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26

# Example settings: num_beams = 5, num_seq = 1, max_length = 300
def get_question(sentence, answer, mdl, tknizer, num_seq, num_beams, max_length,
                 max_input_length=256):
    """Generate candidate questions for an answer found in a context sentence.

    The prompt ``"context: <sentence> answer: <answer>"`` is tokenized, passed
    to ``mdl.generate`` with beam search, and every returned sequence is
    decoded back to text.

    Args:
        sentence: Context text the question should be about.
        answer: Answer text the generated question should lead to.
        mdl: Model object exposing ``generate(...)``; presumably a Hugging
            Face seq2seq model such as T5 (confirm against the caller).
        tknizer: Tokenizer matching ``mdl``; must be callable and provide
            ``decode``.
        num_seq: Number of sequences to return; clamped to ``num_beams``.
        num_beams: Beam width used for beam search.
        max_length: Maximum length passed to ``generate`` for the output.
        max_input_length: Maximum number of prompt tokens; longer prompts
            are truncated. Defaults to 256, the previously hard-coded value.

    Returns:
        A list of decoded strings with every ``"question:"`` marker removed
        and surrounding whitespace stripped.
    """
    # Beam search can return at most ``num_beams`` sequences.
    num_seq = min(num_seq, num_beams)

    prompt = "context: {} answer: {}".format(sentence, answer)
    print(prompt)

    # ``padding=False`` replaces the deprecated ``pad_to_max_length=False``;
    # a single prompt needs no padding.
    encoding = tknizer(
        prompt,
        max_length=max_input_length,
        padding=False,
        truncation=True,
        return_tensors="pt",
    )
    input_ids = encoding["input_ids"]
    attention_mask = encoding["attention_mask"]

    # Keep the inputs on the same device as the model so generation also
    # works when the model has been moved to an accelerator.
    device = getattr(mdl, "device", None)
    if device is not None:
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

    outs = mdl.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        early_stopping=True,
        num_beams=num_beams,
        num_return_sequences=num_seq,
        no_repeat_ngram_size=2,
        max_length=max_length,
    )

    decoded = [tknizer.decode(ids, skip_special_tokens=True) for ids in outs]

    # Removing the marker leaves the space that followed it, so strip too.
    return [text.replace("question:", "").strip() for text in decoded]