import torch

class MyToxicityDebiaserPipeline(object):
    """Classify a comment for toxicity and, when toxic, rewrite it with a GPT model."""

    def __init__(self, model, tokenizer, gpt_model, gpt_tokenizer, device=None, **kwargs):
        # Resolve the device first so both models are moved to the same place.
        self.device = device if device is not None else torch.device("cpu")
        self.model = model.to(self.device)
        self.tokenizer = tokenizer
        self.gpt_model = gpt_model.to(self.device)
        self.gpt_tokenizer = gpt_tokenizer

    def _forward(self, inputs):
        text = inputs["text"]
        encoded = self.tokenizer(text, truncation=True, padding=True, return_tensors="pt").to(self.device)
        with torch.no_grad():  # inference only, no gradients needed
            logits = self.model(input_ids=encoded.input_ids, attention_mask=encoded.attention_mask).logits
        probs = torch.softmax(logits, dim=-1)
        # .item() assumes a single input text (batch size 1).
        label = torch.argmax(probs, dim=-1).item()
        return {"label": label, "probabilities": probs.tolist(), "text_input_ids": encoded.input_ids}

    def _sanitize_parameters(self, **kwargs):
        # Hugging Face Pipeline convention: split call-time kwargs into parameter
        # dicts for preprocess, _forward, and postprocess. None are used here yet.
        return {}, {}, {}

    def preprocess(self, inputs):
        # Wrap the raw text in the dict shape expected by _forward.
        return {"text": inputs}

    def postprocess(self, outputs):
        label = outputs["label"]
        if label == 0:
            prompt = "This comment is non-toxic."
        elif label == 1:
            prompt = "This comment is toxic but has been debiased as follows:"
            # Recover the original text, dropping special tokens such as [CLS]/[SEP].
            text = self.tokenizer.decode(outputs["text_input_ids"][0], skip_special_tokens=True)

            debias_prompt = f"Remove the offensive words and biased tone and write the same sentence nicely: {text}"
            encoded_debias_prompt = self.gpt_tokenizer(debias_prompt, return_tensors="pt").to(self.device)

            generated = self.gpt_model.generate(
                input_ids=encoded_debias_prompt["input_ids"],
                attention_mask=encoded_debias_prompt["attention_mask"],
                do_sample=True,
                max_new_tokens=100,  # generation budget, independent of prompt length
                top_p=0.95,
                temperature=0.7,
                # GPT-2 style tokenizers define no pad token; fall back to EOS.
                pad_token_id=self.gpt_tokenizer.pad_token_id or self.gpt_tokenizer.eos_token_id,
                eos_token_id=self.gpt_tokenizer.eos_token_id,
            )
            generated_text = self.gpt_tokenizer.decode(generated[0], skip_special_tokens=True)
            prompt += f"\nOriginal text: {text}\nDebiased text: {generated_text}"
        else:
            raise ValueError(f"Unexpected label {label}; expected 0 (non-toxic) or 1 (toxic).")
        return prompt

    def __call__(self, inputs, **kwargs):
        preprocess_kwargs, forward_kwargs, postprocess_kwargs = self._sanitize_parameters(**kwargs)
        model_inputs = self.preprocess(inputs, **preprocess_kwargs)
        outputs = self._forward(model_inputs, **forward_kwargs)
        return self.postprocess(outputs, **postprocess_kwargs)