mean pooling code in readme
#15
by
claralp
- opened
Isn't there a mistake in the example code for the mean pooling strategy? It currently reads:
outputs = torch.sum(outputs * inputs["attention_mask"][:, :, None], dim=1) / torch.sum(inputs["attention_mask"])
but it should be:
outputs = torch.sum(outputs * inputs["attention_mask"][:, :, None], dim=1) / torch.sum(inputs["attention_mask"], dim=1, keepdim=True)
Otherwise the denominator is the total number of unmasked tokens across all texts in the batch, not just the token count of the current text.
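To make the difference concrete, here is a minimal sketch on a toy batch. The helper name `mean_pool`, the toy tensors, and the `clamp(min=1)` guard against all-padding rows are assumptions for illustration and are not part of the README.

```python
import torch

def mean_pool(token_embeddings, attention_mask):
    """Average token embeddings per text, ignoring padded positions."""
    # (batch, seq_len) -> (batch, seq_len, 1) so it broadcasts over the hidden dim
    mask = attention_mask[:, :, None].to(token_embeddings.dtype)
    summed = torch.sum(token_embeddings * mask, dim=1)
    # Per-text token counts, shape (batch, 1); clamp is an added guard against
    # division by zero for a row that is entirely padding (assumption)
    counts = torch.sum(attention_mask, dim=1, keepdim=True).clamp(min=1)
    return summed / counts

# Toy batch: 2 texts, 3 token positions, hidden size 2, every embedding is 1.0
token_embeddings = torch.ones(2, 3, 2)
attention_mask = torch.tensor([[1, 1, 1],   # 3 real tokens
                               [1, 0, 0]])  # 1 real token, 2 padding

pooled = mean_pool(token_embeddings, attention_mask)

# Original README expression: divides by the token count of the whole batch (4)
buggy = torch.sum(token_embeddings * attention_mask[:, :, None], dim=1) / torch.sum(attention_mask)

print(pooled)  # every entry is 1.0, the true per-text mean
print(buggy)   # 0.75 for the first text, 0.25 for the second
```

With a batch of one text both expressions agree, which is probably why the issue is easy to miss.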
Thank you @claralp! Yes, you are right. I overlooked this, sorry about that. Will fix!
aamirshakir
changed discussion status to
closed