from typing import Any, Dict, List

import torch
from colorama import Fore, Style
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList

from .min_new_tokens import MinNewTokensLogitsProcessor


class ResponseGenerator:
    """Generates responses to hate speech inputs with a causal language model."""

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def __init__(self, pretrained_model_name_or_path: str, decoding_config: Dict[str, Any], seed=42, verbose=True):
        self.model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
        if "pad_token" not in self.tokenizer.special_tokens_map:
            self.tokenizer.pad_token = self.tokenizer.eos_token  # A pad token needs to be set for batch decoding
        self.decoding_config = decoding_config
        self.verbose = verbose
        torch.manual_seed(seed)
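
    # `decoding_config` holds keyword arguments that are forwarded to `model.generate`, plus an
    # optional `min_new_tokens` entry that `_params_for_generate` handles separately below.
    # An illustrative (assumed, not repository-specific) configuration:
    #     {"max_new_tokens": 64, "min_new_tokens": 8, "do_sample": True, "top_p": 0.9}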

    def generate_responses(self, inputs: List[str], batch_size=1) -> List[str]:
        """Generate one response per input, processing the inputs in batches of `batch_size`."""
        responses = []
        for i in tqdm(range(0, len(inputs), batch_size), disable=not self.verbose):
            batch_inputs = inputs[i : i + batch_size]
            batch_responses = self.generate_responses_for_batch(batch_inputs)
            responses.extend(batch_responses)
        return responses

    def generate_responses_for_batch(self, inputs: List[str]) -> List[str]:
        # Append the EOS token to mark the end of each input turn
        inputs = [input_text + self.tokenizer.eos_token for input_text in inputs]
        # Pad on the left so each sequence's generation continues from its last real input token
        self.tokenizer.padding_side = "left"
        tokenized_inputs = self.tokenizer(inputs, return_tensors="pt", padding=True).to(self.device)
        input_len = tokenized_inputs["input_ids"].shape[-1]
        params_for_generate = self._params_for_generate(input_len)
        output_ids = self.model.generate(
            **tokenized_inputs, **params_for_generate, pad_token_id=self.tokenizer.pad_token_id
        )
        # `generate` returns the input and response tokens together; keep only the new tokens
        response_ids = output_ids[:, input_len:]
        responses = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True)
        return responses

    def _params_for_generate(self, input_length: int) -> Dict[str, Any]:
        params_for_generate = self.decoding_config.copy()
        if params_for_generate.get("min_new_tokens") is not None:
            # The HuggingFace `generate` function accepts a `logits_processor` argument, not a
            # `min_new_tokens` one, so we replace `min_new_tokens` from the `decoding_config`
            # with our custom logits processor that enforces a minimum response length
            min_new_tokens = params_for_generate.pop("min_new_tokens")
            min_new_tokens_logits_processor = MinNewTokensLogitsProcessor(
                min_new_tokens, self.tokenizer.eos_token_id, input_length
            )
            params_for_generate["logits_processor"] = LogitsProcessorList([min_new_tokens_logits_processor])
        return params_for_generate
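
    # `MinNewTokensLogitsProcessor` is defined in min_new_tokens.py (not shown here). A minimal
    # sketch of what such a processor could look like, assuming it suppresses the EOS logit until
    # `min_new_tokens` tokens have been generated past `input_length`; the actual implementation
    # in this repository may differ:
    #
    #     class MinNewTokensLogitsProcessor(LogitsProcessor):
    #         def __init__(self, min_new_tokens: int, eos_token_id: int, input_length: int):
    #             self.min_new_tokens = min_new_tokens
    #             self.eos_token_id = eos_token_id
    #             self.input_length = input_length
    #
    #         def __call__(self, input_ids, scores):
    #             # Forbid EOS while fewer than `min_new_tokens` new tokens have been generated
    #             if input_ids.shape[-1] - self.input_length < self.min_new_tokens:
    #                 scores[:, self.eos_token_id] = float("-inf")
    #             return scores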

    def respond(self, input_text: str) -> str:
        """Respond to a single hate speech input."""
        return self.generate_responses([input_text])[0]

    def interact(self):
        """Read hate speech inputs from stdin and print responses until an empty line is entered."""
        prompt = Fore.RED + "Hate speech: " + Style.RESET_ALL
        input_text = input(prompt)
        while input_text != "":
            print(Fore.GREEN + "Response: " + Style.RESET_ALL, end="")
            response = self.respond(input_text)
            print(response)
            input_text = input(prompt)
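

# Example usage; the model name and decoding settings are illustrative assumptions, not values
# taken from this repository. Because of the relative import above, run this as a module
# (`python -m <package>.<module>`) rather than as a standalone script.
if __name__ == "__main__":
    generator = ResponseGenerator(
        "microsoft/DialoGPT-medium",  # assumed model; any causal LM on the Hub should work
        decoding_config={"max_new_tokens": 64, "min_new_tokens": 8, "do_sample": True, "top_p": 0.9},
    )
    print(generator.respond("Example hate speech input"))  # single input
    responses = generator.generate_responses(["input one", "input two"], batch_size=2)  # batched inputs
    generator.interact()  # or chat until an empty line is entered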