import torch


class MyToxicityDebiaserPipeline:
    """Classify a comment for toxicity and, when toxic, rewrite it with a GPT model."""

    def __init__(self, model, tokenizer, gpt_model, gpt_tokenizer, device=None, **kwargs):
        self.device = device if device is not None else torch.device("cpu")
        self.model = model.to(self.device)
        self.tokenizer = tokenizer
        self.gpt_model = gpt_model.to(self.device)
        self.gpt_tokenizer = gpt_tokenizer

    def _sanitize_parameters(self, **kwargs):
        # Split call-time kwargs into preprocess / forward / postprocess kwargs.
        return {}, kwargs, {}

    def preprocess(self, inputs, **kwargs):
        return {"text": inputs}

    def _forward(self, inputs, **forward_kwargs):
        text = inputs["text"]
        encoded = self.tokenizer(
            text, truncation=True, padding=True, return_tensors="pt"
        ).to(self.device)
        with torch.no_grad():
            logits = self.model(
                input_ids=encoded.input_ids, attention_mask=encoded.attention_mask
            ).logits
        probs = torch.softmax(logits, dim=-1)
        label = torch.argmax(probs, dim=-1).item()
        return {
            "label": label,
            "probabilities": probs.tolist(),
            "text_input_ids": encoded.input_ids,
        }

    def postprocess(self, outputs, **kwargs):
        if outputs["label"] == 0:
            return "This comment is non-toxic."

        prompt = "This comment is toxic but has been debiased as follows:"
        text = self.tokenizer.decode(outputs["text_input_ids"][0], skip_special_tokens=True)
        debias_prompt = f"Remove the offensive words and biased tone and write the same sentence nicely: {text}"
        encoded_debias_prompt = self.gpt_tokenizer(debias_prompt, return_tensors="pt").to(self.device)
        generated = self.gpt_model.generate(
            input_ids=encoded_debias_prompt["input_ids"],
            attention_mask=encoded_debias_prompt["attention_mask"],
            do_sample=True,
            max_length=100,  # counts the prompt tokens as well as the continuation
            top_p=0.95,
            temperature=0.7,
            pad_token_id=self.gpt_tokenizer.pad_token_id,  # assumes a pad token is set (e.g. to EOS)
            eos_token_id=self.gpt_tokenizer.eos_token_id,
        )
        generated_text = self.gpt_tokenizer.decode(generated[0], skip_special_tokens=True)
        prompt += f"\nOriginal text: {text}\nDebiased text: {generated_text}"
        return prompt

    def __call__(self, inputs, **kwargs):
        preprocess_kwargs, forward_kwargs, postprocess_kwargs = self._sanitize_parameters(**kwargs)
        model_inputs = self.preprocess(inputs, **preprocess_kwargs)
        model_outputs = self._forward(model_inputs, **forward_kwargs)
        return self.postprocess(model_outputs, **postprocess_kwargs)
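To show how the class is meant to be wired up, here is a minimal usage sketch. The classifier checkpoint name is a hypothetical placeholder for any two-label (0 = non-toxic, 1 = toxic) sequence-classification model, and "gpt2" simply stands in for whichever causal language model you use for the rewrite.

# Minimal usage sketch. "path/to/toxicity-classifier" is a hypothetical placeholder
# for a two-label sequence-classification checkpoint; "gpt2" stands in for the
# generative model used for debiasing.
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

clf_name = "path/to/toxicity-classifier"  # hypothetical checkpoint name
clf_tokenizer = AutoTokenizer.from_pretrained(clf_name)
clf_model = AutoModelForSequenceClassification.from_pretrained(clf_name)

gpt_tokenizer = AutoTokenizer.from_pretrained("gpt2")
gpt_tokenizer.pad_token = gpt_tokenizer.eos_token  # GPT-2 has no pad token by default
gpt_model = AutoModelForCausalLM.from_pretrained("gpt2")

debiaser = MyToxicityDebiaserPipeline(
    model=clf_model,
    tokenizer=clf_tokenizer,
    gpt_model=gpt_model,
    gpt_tokenizer=gpt_tokenizer,
    device=device,
)

print(debiaser("Your comment goes here."))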