# DANTE-GeoLingIT2023 / dual_regression_model.py
# MorenoLaQuatra
# Map location device
# 9ccc55b
import numpy as np
import torch
from torch import nn
from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer
class DualRegressionModel(nn.Module):
    """
    Transformer backbone with two independent linear heads that regress
    latitude and longitude from the last-layer hidden state at position 0.
    """

    def __init__(
        self,
        model_name_or_path: str = "camembert/camembert-base",
        loss_aggreatation: str = "mean",
    ):
        """
        Instantiate the backbone, its tokenizer and the two regression heads.
        :param model_name_or_path: Hugging Face model name or local path of the backbone.
            If the name contains "bart", only the ``encoder`` of the ``AutoModel`` is
            kept; otherwise an ``AutoModelForMaskedLM`` is loaded.
        :param loss_aggreatation: How the latitude and longitude losses are combined:
            "mean" or "sum". The (misspelled) name is kept for backward compatibility.
            The value is only checked when ``forward`` has to compute a loss.
        """
        super().__init__()
        if "bart" in model_name_or_path:
            # Load the full model, then keep only its encoder stack.
            self.model = AutoModel.from_pretrained(
                model_name_or_path, output_hidden_states=True
            )
            self.model = self.model.encoder
        else:
            self.model = AutoModelForMaskedLM.from_pretrained(
                model_name_or_path, output_hidden_states=True
            )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.loss_aggreatation = loss_aggreatation
        # Two separate heads, one per target (latitude and longitude).
        self.lat_regression_head = torch.nn.Linear(self.model.config.hidden_size, 1)
        self.long_regression_head = torch.nn.Linear(self.model.config.hidden_size, 1)
        # (sic) attribute name kept so external code referencing it keeps working.
        self.crierion = torch.nn.MSELoss()

    def _regression_loss(self, predictions, targets):
        """Return the MSE between (batch_size, 1) predictions and per-example targets."""
        # Flatten both sides: ``squeeze()`` would turn a batch of one into a 0-d tensor
        # (MSELoss size-mismatch warning), and (batch_size, 1) targets against
        # (batch_size,) predictions would broadcast to a (batch_size, batch_size) loss.
        # Targets are cast to the prediction dtype so e.g. float64 labels do not clash
        # with float32 predictions.
        return self.crierion(
            predictions.reshape(-1), targets.reshape(-1).to(predictions.dtype)
        )

    def forward(
        self,
        batch,
    ):
        """
        Predict coordinates and, when targets are present, compute the loss.
        :param batch: Mapping with "input_ids" and "attention_mask" and, for
            training/evaluation, both "latitude" and "longitude" target tensors
            (one value per example).
        :return: Dict with "latitude" and "longitude" predictions, each of shape
            (batch_size, 1), plus a scalar "loss" when targets are given.
        :raises KeyError: If only one of "latitude"/"longitude" is present.
        :raises ValueError: If a loss must be computed and ``loss_aggreatation``
            is neither "mean" nor "sum".
        """
        has_lat = "latitude" in batch
        has_long = "longitude" in batch
        if has_lat != has_long:
            missing = "longitude" if has_lat else "latitude"
            raise KeyError(
                f"batch contains only one regression target; missing {missing!r}"
            )
        predict = not has_lat

        # Last-layer hidden state at sequence position 0 -> (batch_size, hidden_size).
        last_hidden_state = self.model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
        ).hidden_states[-1][:, 0, :]
        lat_predictions = self.lat_regression_head(last_hidden_state)
        long_predictions = self.long_regression_head(last_hidden_state)
        result = {"latitude": lat_predictions, "longitude": long_predictions}
        if predict:
            return result

        lat_loss = self._regression_loss(lat_predictions, batch["latitude"])
        long_loss = self._regression_loss(long_predictions, batch["longitude"])
        if self.loss_aggreatation == "mean":
            loss = (lat_loss + long_loss) / 2
        elif self.loss_aggreatation == "sum":
            loss = lat_loss + long_loss
        else:
            raise ValueError("Only mean and sum are supported for loss aggregation")
        result["loss"] = loss
        return result

    def save_model(self, path):
        """
        Save this model's ``state_dict`` to ``path``. E.g. "model.pt"
        :param path: The path where the model is saved.
        """
        torch.save(self.state_dict(), path)

    def load_model(self, path, map_location=None):
        """
        Load a ``state_dict`` previously written by ``save_model`` into this instance.
        :param path: The path where the model is saved. E.g. "model.pt"
        :param map_location: Passed to ``torch.load`` to remap tensor storage;
            defaults to the CPU.
        """
        if map_location is None:
            map_location = torch.device("cpu")
        # NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.
        self.load_state_dict(torch.load(path, map_location=map_location))