mean pooling code in readme

#15
by claralp - opened

Isn't there a mistake in the example code for the mean pooling strategy? It currently reads:

```python
outputs = torch.sum(outputs * inputs["attention_mask"][:, :, None], dim=1) / torch.sum(inputs["attention_mask"])
```

but shouldn't it be:

```python
outputs = torch.sum(outputs * inputs["attention_mask"][:, :, None], dim=1) / torch.sum(inputs["attention_mask"], dim=1, keepdim=True)
```

Otherwise each sequence is divided by the total token count of the whole batch, not by its own token count, so the average depends on all the embedded texts and not just the current one.
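To make the difference concrete, here is a small self-contained sketch with made-up embeddings and an attention mask (the tensor values are purely illustrative, not from the README). With the buggy denominator, every sequence is divided by the batch-wide token count; with `dim=1, keepdim=True`, each sequence is averaged over its own unpadded tokens:

```python
import torch

# Toy batch: 2 sequences, 2 tokens each, hidden size 3.
# The second sequence has one padding token (mask = 0).
token_embeddings = torch.tensor(
    [[[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]],
     [[2.0, 2.0, 2.0], [9.0, 9.0, 9.0]]]  # second token is padding
)
attention_mask = torch.tensor([[1, 1], [1, 0]])

# Zero out padding tokens, then sum over the sequence dimension.
summed = torch.sum(token_embeddings * attention_mask[:, :, None], dim=1)

# Buggy: divides every sequence by the total token count of the batch (3).
buggy = summed / torch.sum(attention_mask)

# Fixed: divides each sequence by its own token count (2 and 1).
fixed = summed / torch.sum(attention_mask, dim=1, keepdim=True)

print(buggy)  # first row is 4/3, second row is 2/3 -- both wrong
print(fixed)  # both rows are the true per-sequence means: [2., 2., 2.]
```

The `keepdim=True` keeps the per-sequence counts as a column of shape `(batch, 1)`, so they broadcast correctly against the `(batch, hidden)` sums.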

Mixedbread org

Thank you @claralp! Yes, you are right; I overlooked this, sorry about that. Will fix!

aamirshakir changed discussion status to closed
