sunwaee's picture
removed headers
4d679c8
raw history blame
No virus
3.59 kB
from typing import List
import torch
from datasets import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import PerceiverTokenizer
# Label names in model output order: column j of a prediction row maps to _LABELS[j].
# NOTE(review): these appear to be the 27 GoEmotions emotions plus "neutral" —
# confirm against the model's training configuration.
_LABELS = (
    "admiration",
    "amusement",
    "anger",
    "annoyance",
    "approval",
    "caring",
    "confusion",
    "curiosity",
    "desire",
    "disappointment",
    "disapproval",
    "disgust",
    "embarrassment",
    "excitement",
    "fear",
    "gratitude",
    "grief",
    "joy",
    "love",
    "nervousness",
    "optimism",
    "pride",
    "realization",
    "relief",
    "remorse",
    "sadness",
    "surprise",
    "neutral",
)


def _map_outputs(predictions, threshold: float = 0.5):
    """
    Map model outputs to classes.

    :param predictions: batch of model outputs, iterated row by row; each row holds
        one score per label, in ``_LABELS`` order.
    :param threshold: a label is kept when its score is strictly greater than this
        value. NOTE(review): ``MultiLabelPipeline`` passes raw ``outputs.logits``
        here, so if the scores are meant for a sigmoid, the default 0.5 corresponds
        to a probability of about 0.62 — confirm this is intended.
    :return: one list of label names per row, in input order (empty list when no
        score exceeds ``threshold``).
    :raises IndexError: if a row has more than ``len(_LABELS)`` columns and one of
        the extra columns exceeds ``threshold``.
    """
    return [
        [_LABELS[j] for j, score in enumerate(row) if score > threshold]
        for row in predictions
    ]
class MultiLabelPipeline:
    """
    Multi-label classification pipeline.

    Tokenizes a dataset's ``text`` column with the ``deepmind/language-perceiver``
    tokenizer, runs the loaded model batch by batch and maps each row of logits to
    a list of label names.
    """

    def __init__(self, model_path):
        """
        Init MLC pipeline.

        :param model_path: path to a model serialized with ``torch.save`` as a whole
            module (``.eval()`` is called on the loaded object, so a bare state dict
            will not work).
        """
        # Run on GPU when one is available, CPU otherwise.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # map_location loads the tensors directly onto the selected device, which
        # covers both the CPU and the CUDA case with one call.
        # NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
        self.model = torch.load(model_path, map_location=self.device).eval().to(self.device)
        self.tokenizer = PerceiverTokenizer.from_pretrained('deepmind/language-perceiver')

    def __call__(self, dataset, batch_size: int = 4):
        """
        Processing pipeline.

        :param dataset: dataset exposing a ``text`` column and ``map``/``set_format``
            (e.g. the output of ``inputs_to_dataset``).
        :param batch_size: number of rows per forward pass.
        :return: one list of predicted label names per input row, in input order.
        """
        # Tokenize inputs, padding/truncating every row to the tokenizer's max length.
        dataset = dataset.map(lambda row: self.tokenizer(row['text'], padding="max_length", truncation=True),
                              batched=True, remove_columns=['text'], desc='Tokenizing')
        dataset.set_format('torch', columns=['input_ids', 'attention_mask'])
        dataloader = DataLoader(dataset, batch_size=batch_size)

        classes = []
        mem_logs = []
        # Inference only: no_grad avoids building an autograd graph for every batch.
        with torch.no_grad(), tqdm(dataloader, unit='batches') as progression:
            progression.set_description('Inference')
            for batch in progression:
                # Forward
                outputs = self.model(inputs=batch['input_ids'].to(self.device),
                                     attention_mask=batch['attention_mask'].to(self.device))
                predictions = outputs.logits.detach().cpu().numpy()
                classes.extend(_map_outputs(predictions))
                # Reserved CUDA memory in GB; there is nothing to report on CPU.
                if self.device.type == 'cuda':
                    memory = round(torch.cuda.memory_reserved(self.device) / 1e9, 2)
                else:
                    memory = 0.0
                mem_logs.append(memory)
                # Show the running average of reserved memory in the progress bar.
                progression.set_postfix(memory=f"{round(sum(mem_logs) / len(mem_logs), 2)}Go")
        return classes
def inputs_to_dataset(inputs: List[str]):
    """
    Convert a list of strings to a dataset object.

    :param inputs: texts to classify (any iterable of strings is accepted).
    :return: ``datasets.Dataset`` with a single ``text`` column, one row per input,
        in input order.
    """
    # list() copies the inputs without the builtin-shadowing comprehension variable.
    return Dataset.from_dict({'text': list(inputs)})