from transformers import T5TokenizerFast, T5ForConditionalGeneration
import string
from typing import List


SOURCE_MAX_TOKEN_LEN = 512  # maximum number of input tokens fed to the encoder
TARGET_MAX_TOKEN_LEN = 50   # maximum number of tokens generated per sequence
SEP_TOKEN = "[SEP]"         # separator between answer, question, and context in the input
MODEL_NAME = "t5-small"     # base checkpoint used for the tokenizer

# DistractorGenerator: generates distractor (wrong-answer) options for a question
# using a fine-tuned T5 model.
class DistractorGenerator:
    def __init__(self):
        # Base T5 tokenizer, extended with the [SEP] separator token.
        self.tokenizer = T5TokenizerFast.from_pretrained(MODEL_NAME)
        self.tokenizer.add_tokens(SEP_TOKEN)
        self.tokenizer_len = len(self.tokenizer)  # vocabulary size including [SEP]
        # Fine-tuned T5 checkpoint for distractor generation.
        self.model = T5ForConditionalGeneration.from_pretrained("fahmiaziz/QDModel")

    def generate(self, generate_count: int, correct: str, question: str, context: str) -> List[str]:
        model_output = self._model_predict(generate_count, correct, question, context)

        # Drop padding tokens and turn every end-of-sequence marker into a comma so
        # the concatenated sequences can be split back apart.
        cleaned_result = model_output.replace('<pad>', '').replace('</s>', ',')
        cleaned_result = self._replace_all_extra_id(cleaned_result)
        # The final '</s>' produces a trailing empty element; drop it with [:-1].
        distractors = cleaned_result.split(",")[:-1]
        # Strip punctuation and surrounding whitespace from each candidate.
        distractors = [x.translate(str.maketrans('', '', string.punctuation)) for x in distractors]
        distractors = [x.strip() for x in distractors]

        return distractors

    def _model_predict(self, generate_count: int, correct: str, question: str, context: str) -> str:
        # Encode the model input as "answer [SEP] question [SEP] context".
        source_encoding = self.tokenizer(
            '{} {} {} {} {}'.format(correct, SEP_TOKEN, question, SEP_TOKEN, context),
            max_length=SOURCE_MAX_TOKEN_LEN,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors='pt'
        )

        # Beam search with one returned sequence per requested distractor.
        generated_ids = self.model.generate(
            input_ids=source_encoding['input_ids'],
            attention_mask=source_encoding['attention_mask'],
            num_beams=generate_count,
            num_return_sequences=generate_count,
            max_length=TARGET_MAX_TOKEN_LEN,
            repetition_penalty=2.5,
            length_penalty=1.0,
            early_stopping=True,
            use_cache=True
        )

        # Keep special tokens when decoding: generate() relies on the '<pad>', '</s>'
        # and '<extra_id_*>' markers to split and clean the concatenated output.
        decoded = [
            self.tokenizer.decode(generated_id, skip_special_tokens=False, clean_up_tokenization_spaces=True)
            for generated_id in generated_ids
        ]
        # Deduplicate while preserving beam order; a plain set made the order nondeterministic.
        preds = dict.fromkeys(decoded)

        return ''.join(preds)

    def _correct_index_of(self, text: str, substring: str, start_index: int = 0) -> int:
        # Like str.index, but returns -1 instead of raising ValueError when not found.
        try:
            index = text.index(substring, start_index)
        except ValueError:
            index = -1

        return index

    def _replace_all_extra_id(self, text: str) -> str:
        # T5 may emit sentinel tokens such as '<extra_id_0>'; replace each one with
        # the [SEP] separator so the caller can split and clean the text uniformly.
        new_text = text
        start_index_of_extra_id = 0

        while self._correct_index_of(new_text, '<extra_id_') >= 0:
            start_index_of_extra_id = self._correct_index_of(new_text, '<extra_id_', start_index_of_extra_id)
            end_index_of_extra_id = self._correct_index_of(new_text, '>', start_index_of_extra_id)

            new_text = new_text[:start_index_of_extra_id] + SEP_TOKEN + new_text[end_index_of_extra_id + 1:]

        return new_text
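

# Minimal usage sketch (illustrative addition, not part of the original class): the
# answer, question, and context values below are placeholder examples. The first run
# downloads the "t5-small" tokenizer and the "fahmiaziz/QDModel" checkpoint from the
# Hugging Face Hub.
if __name__ == "__main__":
    generator = DistractorGenerator()
    distractors = generator.generate(
        generate_count=3,
        correct="Jakarta",
        question="What is the capital of Indonesia?",
        context="Jakarta is the capital and most populous city of Indonesia.",
    )
    print(distractors)  # a list of up to 3 cleaned distractor strings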