from typing import Any, Dict, List
from tqdm import tqdm
from colorama import Fore, Style
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList
import torch
from .min_new_tokens import MinNewTokensLogitsProcessor


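class ResponseGenerator:
    """Generate responses to input texts with a pretrained causal language model."""

    # Use the GPU when one is available; this attribute is shared by all instances.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")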

    def __init__(
        self,
        pretrained_model_name_or_path: str,
        decoding_config: Dict[str, Any],
        seed: int = 42,
        verbose: bool = True,
    ):
        self.model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
        if "pad_token" not in self.tokenizer.special_tokens_map:
            self.tokenizer.pad_token = self.tokenizer.eos_token  # A pad token needs to be set for batch decoding
        self.decoding_config = decoding_config
        self.verbose = verbose
        torch.manual_seed(seed)

    def generate_responses(self, inputs: List[str], batch_size: int = 1) -> List[str]:
        """Generate one response per input, processing the inputs in batches of `batch_size`."""
        responses = []
        for i in tqdm(range(0, len(inputs), batch_size), disable=not self.verbose):
            batch_inputs = inputs[i : i + batch_size]
            batch_responses = self.generate_responses_for_batch(batch_inputs)
            responses.extend(batch_responses)
        return responses

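    def generate_responses_for_batch(self, inputs: List[str]) -> List[str]:
        """Generate one response for each input in a single batch."""
        # Terminate every input with the EOS token before tokenizing.
        inputs = [input_text + self.tokenizer.eos_token for input_text in inputs]

        # Pad on the left so that all prompts end at the same position
        # and the generated tokens of every sequence start at index `input_len`.
        self.tokenizer.padding_side = "left"
        tokenized_inputs = self.tokenizer(inputs, return_tensors="pt", padding=True).to(self.device)
        input_len = tokenized_inputs["input_ids"].shape[-1]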

        params_for_generate = self._params_for_generate(input_len)
        output_ids = self.model.generate(
            **tokenized_inputs, **params_for_generate, pad_token_id=self.tokenizer.pad_token_id
        )

        # Keep only the newly generated tokens by dropping the (padded) prompt.
        response_ids = output_ids[:, input_len:]
        responses = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True)

        return responses

    def _params_for_generate(self, input_length: int) -> Dict[str, Any]:
        params_for_generate = self.decoding_config.copy()

        # Always remove `min_new_tokens` so that it is never forwarded to `generate`, even when it is None.
        min_new_tokens = params_for_generate.pop("min_new_tokens", None)
        if min_new_tokens is not None:
            # The HuggingFace `generate` function accepts a `logits_processor` argument, not a `min_new_tokens`,
            # so we replace `min_new_tokens` from the `decoding_config` with our custom logits processor
            # that enforces a minimum response length.
            min_new_tokens_logits_processor = MinNewTokensLogitsProcessor(
                min_new_tokens, self.tokenizer.eos_token_id, input_length
            )
            params_for_generate["logits_processor"] = LogitsProcessorList([min_new_tokens_logits_processor])

        return params_for_generate

    def respond(self, input_text: str) -> str:
        """Respond to a single hate speech input."""
        return self.generate_responses([input_text])[0]

    def interact(self) -> None:
        """Read inputs from stdin and print a response to each until an empty line is entered."""
        prompt = Fore.RED + "Hate speech: " + Style.RESET_ALL
        input_text = input(prompt)
        while input_text != "":
            print(Fore.GREEN + "Response: " + Style.RESET_ALL, end="")
            response = self.respond(input_text)
            print(response)
            input_text = input(prompt)
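

# --- Illustrative sketch (not part of the original module) -------------------
# A minimal, dependency-free illustration of why `generate_responses_for_batch`
# pads on the left and then slices `output_ids[:, input_len:]`. The helper names
# below (`_left_pad`, `_strip_prompt`) and the toy token ids are hypothetical;
# plain Python lists stand in for the tensors used above.
def _left_pad(sequences, pad_id):
    """Pad every sequence on the left to the length of the longest one."""
    max_len = max(len(seq) for seq in sequences)
    return [[pad_id] * (max_len - len(seq)) + list(seq) for seq in sequences]


def _strip_prompt(output_rows, input_len):
    """Drop the first `input_len` positions of each row, keeping only new tokens."""
    return [row[input_len:] for row in output_rows]


# With left padding, every prompt ends at the same column, so a single slice
# offset separates prompt from continuation for the whole batch:
#   padded = _left_pad([[5, 6, 7], [8]], pad_id=0)   # [[5, 6, 7], [0, 0, 8]]
#   outputs = [row + [1, 2] for row in padded]       # pretend 2 new tokens per row
#   _strip_prompt(outputs, input_len=3)              # [[1, 2], [1, 2]]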