Shane Weisz committed
Commit: f648ebc
Parent: c4aa462

Add app using dialoGPT-finetuned beam10 no-minlen

.gitignore ADDED
@@ -0,0 +1,2 @@
+.venv
+__pycache__
app.py ADDED
@@ -0,0 +1,17 @@
+from response_generation import ResponseGenerator
+import gradio as gr
+
+
+DEFAULT_MODEL = "shaneweisz/DialoGPT-finetuned-multiCONAN"
+DECODING_CONFIG = {"max_new_tokens": 100, "no_repeat_ngram_size": 3, "num_beams": 10}
+
+model = ResponseGenerator(DEFAULT_MODEL, DECODING_CONFIG)
+
+
+def respond(hate_speech_input_text):
+    return model.respond(hate_speech_input_text)
+
+
+demo = gr.Interface(fn=respond, inputs="text", outputs="text")
+
+demo.launch()
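Running `python app.py` serves the interface on Gradio's default local port (7860). The launch call also accepts standard gradio options for other deployments; a minimal sketch (these arguments are illustrative and not part of this commit):

    demo.launch(server_name="0.0.0.0", share=True)  # listen on all interfaces and print a temporary public URL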
requirements.txt ADDED
@@ -0,0 +1,5 @@
+torch==1.11.0
+transformers==4.19.2
+tqdm==4.64.0
+colorama==0.4.4
+gradio==3.0.20
response_generation/__init__.py ADDED
@@ -0,0 +1 @@
+from .response_generator import ResponseGenerator
response_generation/min_new_tokens.py ADDED
@@ -0,0 +1,35 @@
+import torch
+from transformers import LogitsProcessor
+
+
+# HuggingFace's `generate` function does not yet support a `min_new_tokens` argument, so we add the
+# functionality ourselves with a custom logits processor. Adapted from:
+# https://huggingface.co/transformers/v4.1.1/_modules/transformers/generation_logits_process.html#MinLengthLogitsProcessor
+class MinNewTokensLogitsProcessor(LogitsProcessor):
+    r"""
+    A [`LogitsProcessor`] enforcing a minimum response length by setting the `EOS` probability to 0 until
+    `min_new_tokens` new tokens have been generated beyond `input_length`.
+    """
+    def __init__(self, min_new_tokens: int, eos_token_id: int, input_length: int):
+        if not isinstance(min_new_tokens, int) or min_new_tokens < 0:
+            raise ValueError(f"`min_new_tokens` has to be a non-negative integer, but is {min_new_tokens}")
+
+        if not isinstance(eos_token_id, int) or eos_token_id < 0:
+            raise ValueError(f"`eos_token_id` has to be a non-negative integer, but is {eos_token_id}")
+
+        if not isinstance(input_length, int) or input_length < 0:
+            raise ValueError(f"`input_length` has to be a non-negative integer, but is {input_length}")
+
+        self.min_new_tokens = min_new_tokens
+        self.eos_token_id = eos_token_id
+        self.input_length = input_length
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        total_length = input_ids.shape[-1]
+        response_len = total_length - self.input_length
+
+        # Suppress EOS until at least `min_new_tokens` tokens have been generated beyond the input.
+        if response_len < self.min_new_tokens:
+            scores[:, self.eos_token_id] = -float("inf")
+
+        return scores
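To see how this processor plugs into HuggingFace's `generate`, here is a minimal standalone sketch (the base model `microsoft/DialoGPT-small` and the prompt are assumptions for illustration, not part of this commit):

    from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList
    from response_generation.min_new_tokens import MinNewTokensLogitsProcessor

    tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")  # assumed base model
    model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")

    input_ids = tokenizer("Hello there" + tokenizer.eos_token, return_tensors="pt").input_ids
    processor = MinNewTokensLogitsProcessor(
        min_new_tokens=20, eos_token_id=tokenizer.eos_token_id, input_length=input_ids.shape[-1]
    )
    # EOS is suppressed until at least 20 tokens have been generated beyond the input.
    output_ids = model.generate(
        input_ids,
        logits_processor=LogitsProcessorList([processor]),
        max_new_tokens=100,
        pad_token_id=tokenizer.eos_token_id,
    )
    print(tokenizer.decode(output_ids[0, input_ids.shape[-1]:], skip_special_tokens=True))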
response_generation/response_generator.py ADDED
@@ -0,0 +1,74 @@
+from typing import Any, Dict, List
+from tqdm import tqdm
+from colorama import Fore, Style
+from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList
+import torch
+from .min_new_tokens import MinNewTokensLogitsProcessor
+
+
+class ResponseGenerator:
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    def __init__(self, pretrained_model_name_or_path: str, decoding_config: Dict[str, Any], seed=42, verbose=True):
+        self.model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path).to(self.device)
+        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
+        if "pad_token" not in self.tokenizer.special_tokens_map:
+            self.tokenizer.pad_token = self.tokenizer.eos_token  # A pad token needs to be set for batched generation
+        self.decoding_config = decoding_config
+        self.verbose = verbose
+        torch.manual_seed(seed)
+
+    def generate_responses(self, inputs: List[str], batch_size=1) -> List[str]:
+        responses = []
+        for i in tqdm(range(0, len(inputs), batch_size), disable=not self.verbose):
+            batch_inputs = inputs[i : i + batch_size]
+            batch_responses = self.generate_responses_for_batch(batch_inputs)
+            responses.extend(batch_responses)
+        return responses
+
+    def generate_responses_for_batch(self, inputs: List[str]) -> List[str]:
+        inputs = [input_text + self.tokenizer.eos_token for input_text in inputs]
+
+        self.tokenizer.padding_side = "left"
+        tokenized_inputs = self.tokenizer(inputs, return_tensors="pt", padding=True).to(self.device)
+        input_len = tokenized_inputs["input_ids"].shape[-1]
+
+        params_for_generate = self._params_for_generate(input_len)
+        output_ids = self.model.generate(
+            **tokenized_inputs, **params_for_generate, pad_token_id=self.tokenizer.pad_token_id
+        )
+
+        response_ids = output_ids[:, input_len:]
+        responses = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True)
+
+        return responses
+
+    def _params_for_generate(self, input_length: int) -> Dict[str, Any]:
+        params_for_generate = self.decoding_config.copy()
+
+        if "min_new_tokens" in params_for_generate and params_for_generate["min_new_tokens"] is not None:
+            # The HuggingFace `generate` function accepts a `logits_processor` argument, not a
+            # `min_new_tokens` argument, so we replace `min_new_tokens` from the `decoding_config`
+            # with our custom logits processor that enforces a minimum response length.
+            min_new_tokens = params_for_generate["min_new_tokens"]
+            min_new_tokens_logits_processor = MinNewTokensLogitsProcessor(
+                min_new_tokens, self.tokenizer.eos_token_id, input_length
+            )
+            params_for_generate["logits_processor"] = LogitsProcessorList([min_new_tokens_logits_processor])
+            params_for_generate.pop("min_new_tokens")
+
+        return params_for_generate
+
+    def respond(self, input_text: str) -> str:
+        """Respond to a single hate speech input."""
+        return self.generate_responses([input_text])[0]
+
+    def interact(self):
+        """Respond to hate speech inputs typed at the console until an empty line is entered."""
+        prompt = Fore.RED + "Hate speech: " + Style.RESET_ALL
+        input_text = input(prompt)
+        while input_text != "":
+            print(Fore.GREEN + "Response: " + Style.RESET_ALL, end="")
+            response = self.respond(input_text)
+            print(response)
+            input_text = input(prompt)
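End-to-end, the class supports batched generation and an interactive console loop; a usage sketch (the inputs are illustrative, and `min_new_tokens` here exercises the custom logits processor above):

    from response_generation import ResponseGenerator

    generator = ResponseGenerator(
        "shaneweisz/DialoGPT-finetuned-multiCONAN",
        {"max_new_tokens": 100, "num_beams": 10, "min_new_tokens": 20},
    )
    # Batched generation over a list of inputs.
    responses = generator.generate_responses(
        ["Refugees are a burden on society.", "Women do not belong in politics."],
        batch_size=2,
    )
    # Interactive console session; an empty input line exits the loop.
    generator.interact()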