tgieruc's picture
commit files to HF hub
0d6cb4b
raw
history blame contribute delete
No virus
2.52 kB
from transformers import Pipeline
from transformers import AutoTokenizer
import numpy as np
def softmax(outputs, axis=-1):
    """Return the softmax of *outputs* along *axis*.

    Args:
        outputs: Array-like of scores (anything ``np.asarray`` accepts with
            at least one dimension).
        axis: Axis along which the values are normalised. Defaults to the
            last axis.

    Returns:
        A NumPy array of the same shape as *outputs* whose entries along
        *axis* sum to 1.
    """
    outputs = np.asarray(outputs)
    # Subtract the per-slice maximum before exponentiating so the largest
    # exponent is exp(0) == 1; this avoids overflow for large scores and
    # leaves the result unchanged.
    maxes = np.max(outputs, axis=axis, keepdims=True)
    shifted_exp = np.exp(outputs - maxes)
    return shifted_exp / shifted_exp.sum(axis=axis, keepdims=True)
class HeritageDigitalAgePipeline(Pipeline):
    """Classification pipeline over a caption/title pair.

    ``preprocess`` joins the caption and the title with a ``[SEP]`` marker,
    encodes the result with the ``distilbert-base-uncased`` tokenizer, and
    ``postprocess`` turns the model logits into a label, a score and the raw
    logits.
    """

    # Checkpoint whose tokenizer is used to encode the inputs.
    _TOKENIZER_NAME = "distilbert-base-uncased"
    # Marker inserted between the caption and the title before tokenizing.
    _SEP_TOKEN = "[SEP]"

    def _sanitize_parameters(self, **kwargs):
        """Split call-time keyword arguments into per-stage parameter dicts.

        Only ``caption`` and ``title`` are recognised; when present they are
        converted to lower-cased strings and forwarded to ``preprocess``.

        Returns:
            A ``(preprocess_kwargs, forward_kwargs, postprocess_kwargs)``
            tuple; the last two are always empty.
        """
        preprocess_kwargs = {}
        # Look up the same key that is tested for: the previous version
        # tested "expression1"/"expression2" but read "caption"/"title",
        # which raised KeyError whenever the branch was taken.
        for key in ("caption", "title"):
            if key in kwargs:
                preprocess_kwargs[key] = str(kwargs[key]).lower()
        return preprocess_kwargs, {}, {}

    def _get_tokenizer(self):
        """Return the tokenizer, loading it on first use only."""
        tokenizer = getattr(self, "_preprocess_tokenizer", None)
        if tokenizer is None:
            # Loading a tokenizer is expensive; keep it on the instance
            # instead of calling from_pretrained for every input.
            tokenizer = AutoTokenizer.from_pretrained(self._TOKENIZER_NAME, use_fast=True)
            self._preprocess_tokenizer = tokenizer
        return tokenizer

    def preprocess(self, inputs, maybe_arg=2, caption=None, title=None):
        """Encode one caption/title pair into model input ids.

        Args:
            inputs: Mapping with ``'caption'`` and ``'title'`` string entries.
            maybe_arg: Unused; kept so existing callers keep working.
            caption: Optional override for ``inputs['caption']`` (this is
                what ``_sanitize_parameters`` forwards).
            title: Optional override for ``inputs['title']``.

        Returns:
            ``{"model_input": <encoded ids as a PyTorch tensor>}``.
        """
        if caption is None:
            caption = inputs['caption']
        if title is None:
            title = inputs['title']
        tokenizer = self._get_tokenizer()
        model_input = tokenizer.encode(
            caption + self._SEP_TOKEN + title,
            return_tensors='pt',
            add_special_tokens=True,
            truncation=True,
        )
        return {"model_input": model_input}

    def _forward(self, model_inputs):
        """Run the model on the encoded ids produced by ``preprocess``."""
        # model_inputs == {"model_input": model_input}
        return self.model(model_inputs['model_input'])

    def postprocess(self, model_outputs):
        """Convert model outputs into a prediction dict.

        Returns:
            A dict with ``"label"`` (looked up in ``model.config.id2label``),
            ``"score"`` (softmax probability of that label) and ``"logits"``
            (the raw logits as a list).
        """
        # Only the first row of the logits is used.
        # NOTE(review): presumably logits are (batch, num_labels) with a
        # single-item batch — confirm.
        logits = model_outputs.logits[0].numpy()
        probabilities = softmax(logits)
        # Plain int so the id2label lookup does not depend on NumPy scalar
        # hashing.
        best_class = int(np.argmax(probabilities))
        label = self.model.config.id2label[best_class]
        score = probabilities[best_class].item()
        return {"label": label, "score": score, "logits": logits.tolist()}
class ExpressionRankingPipeline(Pipeline):
    """Classification pipeline over a single text input.

    ``preprocess`` encodes the input with the ``distilbert-base-uncased``
    tokenizer (truncated and padded to 256 tokens), and ``postprocess``
    turns the model logits into a label, a score and the raw logits.
    """

    # Checkpoint whose tokenizer is used to encode the inputs.
    _TOKENIZER_NAME = "distilbert-base-uncased"
    # Fixed length every input is truncated/padded to.
    _MAX_LENGTH = 256

    def _sanitize_parameters(self, **kwargs):
        """Return empty parameter dicts; no call-time options are supported.

        Returns:
            A ``(preprocess_kwargs, forward_kwargs, postprocess_kwargs)``
            tuple of three empty dicts.
        """
        preprocess_kwargs = {}
        return preprocess_kwargs, {}, {}

    def _get_tokenizer(self):
        """Return the tokenizer, loading it on first use only."""
        tokenizer = getattr(self, "_preprocess_tokenizer", None)
        if tokenizer is None:
            # Loading a tokenizer is expensive; keep it on the instance
            # instead of calling from_pretrained for every input.
            tokenizer = AutoTokenizer.from_pretrained(self._TOKENIZER_NAME, use_fast=True)
            self._preprocess_tokenizer = tokenizer
        return tokenizer

    def preprocess(self, inputs, maybe_arg=2):
        """Tokenize *inputs* into fixed-length PyTorch tensors.

        Args:
            inputs: Text passed straight to the tokenizer.
            maybe_arg: Unused; kept so existing callers keep working.

        Returns:
            ``{"model_input": <tokenizer output>}``.
        """
        tokenizer = self._get_tokenizer()
        model_input = tokenizer(
            inputs,
            truncation=True,
            padding="max_length",
            max_length=self._MAX_LENGTH,
            return_tensors="pt",
        )
        return {"model_input": model_input}

    def _forward(self, model_inputs):
        """Run the model, unpacking the tokenizer output as keyword arguments."""
        return self.model(**model_inputs['model_input'])

    def postprocess(self, model_outputs):
        """Convert model outputs into a prediction dict.

        Returns:
            A dict with ``"label"`` (looked up in ``model.config.id2label``),
            ``"score"`` (softmax probability of that label) and ``"logits"``
            (the raw logits as a list).
        """
        # Only the first row of the logits is used.
        # NOTE(review): presumably logits are (batch, num_labels) with a
        # single-item batch — confirm.
        logits = model_outputs.logits[0].numpy()
        probabilities = softmax(logits)
        # Plain int so the id2label lookup does not depend on NumPy scalar
        # hashing.
        best_class = int(np.argmax(probabilities))
        label = self.model.config.id2label[best_class]
        score = probabilities[best_class].item()
        return {"label": label, "score": score, "logits": logits.tolist()}