
Model Summary

This checkpoint corresponds to the pixel-linguist-all setting in the paper. The model is initialized from our monolingual model and then trained by alternating between parallel data (205,000 steps) and AllNLI (2,600 steps) for three rounds. This is the checkpoint from the last round. We recommend running it on an A100 GPU, matching the training setup.
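The alternating schedule can be pictured as an ordered list of training phases. This is a hypothetical illustration only. The names `PHASES` and `training_schedule` are ours. The sketch also assumes that the stated step counts apply to every round and that the parallel-data phase comes first in each round. The paper describes the exact process.

```python
# Hypothetical sketch of the alternating training schedule (not the authors' code).
# Assumptions: the step counts apply to each round, and parallel data comes first.
PHASES = [("parallel", 205_000), ("allnli", 2_600)]
NUM_ROUNDS = 3


def training_schedule(phases=PHASES, num_rounds=NUM_ROUNDS):
    """Return the ordered list of (round, dataset, steps) training phases."""
    return [
        (round_idx, dataset, steps)
        for round_idx in range(1, num_rounds + 1)
        for dataset, steps in phases
    ]


for round_idx, dataset, steps in training_schedule():
    print(f"round {round_idx}: {dataset} for {steps} steps")
```

Under these assumptions, the released checkpoint is the state after the final AllNLI phase of round 3.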

Downstream Use

Semantic Textual Similarity, Information Retrieval, Reasoning Retrieval

Out-of-Scope Use

The model may not be well suited for further fine-tuning on other tasks (such as classification), since it was trained for representation tasks based on similarity matching.

Training Data

Please refer to the paper for the exact training data and process.

Inference

Encoding with our PixelLinguist class is straightforward and works like a SentenceTransformer model: pass a list of strings to `encode`. The class itself is defined further down.

```python
model_name = "AnonymousPage/checkpoint-all"
model = PixelLinguist(model_name)
texts = ["I love you", "I like you"]
embeddings = model.encode(texts)
# A plain dot product is enough because the model class normalizes the embeddings automatically.
print(embeddings[0] @ embeddings[1])
# tensor(0.9217)
```
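The embeddings are L2-normalized, so their dot product equals cosine similarity. Retrieval therefore reduces to sorting candidates by their dot product with the query. The pure-Python sketch below shows both points on toy vectors. The helper names are hypothetical and are not part of the `pixel` package.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def cosine_similarity(a, b):
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))


def rank_by_dot(query, docs, top_k=3):
    """Return (doc_index, score) pairs, highest dot-product score first."""
    scored = [(i, dot(query, d)) for i, d in enumerate(docs)]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:top_k]


# Toy 2-d vectors standing in for model embeddings (not real model output).
a, b = [3.0, 4.0], [4.0, 3.0]
# After normalization, the dot product matches cosine similarity.
print(math.isclose(dot(l2_normalize(a), l2_normalize(b)), cosine_similarity(a, b)))  # True

query = l2_normalize([1.0, 0.0])
docs = [l2_normalize(v) for v in ([0.0, 1.0], [1.0, 1.0], [1.0, 0.1])]
print(rank_by_dot(query, docs, top_k=2))  # doc 2 first, then doc 1
```

With the real model, `encode` returns a 2-D tensor with one row per text. The same ranking can therefore presumably be computed with `scores = query_emb @ doc_embs.T` followed by `scores.topk(k)`.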

To use the PixelLinguist class, first install the `pixel` package by following the instructions in our GitHub repo. Then define the PixelLinguist class as follows:

```python
import torch
from PIL import Image
from tqdm import tqdm
from pixel import (
    AutoConfig,
    PangoCairoTextRenderer,
    PIXELForSequenceClassification,
    PIXELForRepresentation,
    PoolingMode,
    get_attention_mask,
    get_transforms,
    glue_strip_spaces,
    resize_model_embeddings,
)


class PixelLinguist:
    def __init__(self, model_name, batch_size=16, max_seq_length=64,
                 device=None, pooling="mean", keep_mlp=False):
        # Use the given device, else the first GPU if available, else the CPU.
        if device is not None:
            self.device = device
        else:
            self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.config = AutoConfig.from_pretrained(model_name, num_labels=0)
        self.batch_size = batch_size
        # keep_mlp=True loads PIXELForSequenceClassification (keeps the MLP head);
        # the default loads PIXELForRepresentation.
        model_cls = PIXELForSequenceClassification if keep_mlp else PIXELForRepresentation
        self.model = model_cls.from_pretrained(
            model_name,
            config=self.config,
            pooling_mode=PoolingMode.from_string(pooling),
            add_layer_norm=True,
        ).to(self.device)
        # Text renderer producing grayscale images (rgb=False) split into patches.
        self.processor = PangoCairoTextRenderer.from_pretrained(model_name, rgb=False)
        self.processor.max_seq_length = max_seq_length
        # Adjust the model's embeddings to the chosen max_seq_length.
        resize_model_embeddings(self.model, self.processor.max_seq_length)
        patch = self.processor.pixels_per_patch
        self.transforms = get_transforms(
            do_resize=True,
            size=(patch, patch * self.processor.max_seq_length),
        )

    def preprocess(self, texts):
        # Render each text to pixels, then mark which patches contain text.
        encodings = [self.processor(text=glue_strip_spaces(text)) for text in texts]
        pixel_values = torch.stack(
            [self.transforms(Image.fromarray(e.pixel_values)) for e in encodings]
        )
        attention_mask = torch.stack(
            [get_attention_mask(e.num_text_patches, seq_length=self.processor.max_seq_length)
             for e in encodings]
        )
        return {"pixel_values": pixel_values, "attention_mask": attention_mask}

    def encode(self, texts, **kwargs):
        # Extra keyword arguments are accepted for SentenceTransformer-style calls but ignored.
        all_outputs = []
        for i in tqdm(range(0, len(texts), self.batch_size)):
            batch_texts = texts[i:i + self.batch_size]
            inputs = self.preprocess(batch_texts)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = self.model(**inputs).logits.cpu()
            all_outputs.append(outputs)
        # One row per input text.
        return torch.cat(all_outputs, dim=0)
```
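The `pooling="mean"` option suggests that the per-patch outputs are averaged into a single sentence vector. Below is a minimal sketch of masked mean pooling. It assumes that only positions marked 1 in the attention mask (the rendered text patches) contribute to the average, so blank padding patches do not dilute the result. This is a generic illustration of the technique, not the PIXEL implementation, and `masked_mean_pool` is a hypothetical name.

```python
def masked_mean_pool(patch_embeddings, attention_mask):
    """Average the vectors at positions where attention_mask is 1.

    patch_embeddings: list of seq_length vectors (lists of floats)
    attention_mask: list of seq_length 0/1 values
    """
    kept = [emb for emb, m in zip(patch_embeddings, attention_mask) if m]
    if not kept:
        raise ValueError("attention_mask selects no positions")
    dim = len(kept[0])
    return [sum(emb[j] for emb in kept) / len(kept) for j in range(dim)]


# Two text patches followed by one padding patch that the mask excludes.
patches = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]
mask = [1, 1, 0]
print(masked_mean_pool(patches, mask))  # [2.0, 3.0]
```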

Evaluation

For STS evaluation (see the GitHub repo):

```bash
python tools/evaluation_sts_all.py
```

For BEIR information retrieval evaluation (see the GitHub repo):

```bash
python tools/evaluation_retrieval.py
```
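BEIR uses nDCG@10 as its main metric. The sketch below computes nDCG@k for a single query, using graded relevance labels and the linear-gain form `rel / log2(rank + 1)`. It is an illustration only. We have not checked which metrics or settings `tools/evaluation_retrieval.py` uses, and `ndcg_at_k` is a hypothetical name.

```python
import math


def ndcg_at_k(ranked_doc_ids, relevance, k=10):
    """nDCG@k for one query with linear gain.

    ranked_doc_ids: document ids in the order the model ranked them
    relevance: dict mapping doc id -> graded relevance (missing ids count as 0)
    """
    # Discounted gain of the model's ranking; rank is 0-based, hence log2(rank + 2).
    dcg = sum(
        relevance.get(doc_id, 0) / math.log2(rank + 2)
        for rank, doc_id in enumerate(ranked_doc_ids[:k])
    )
    # Best achievable gain: relevant documents sorted by relevance.
    ideal = sorted(relevance.values(), reverse=True)[:k]
    idcg = sum(rel / math.log2(rank + 2) for rank, rel in enumerate(ideal))
    return dcg / idcg if idcg > 0 else 0.0


# Toy query: the model ranks d3, d1, d7; d1 is highly relevant, d7 partially.
print(ndcg_at_k(["d3", "d1", "d7"], {"d1": 2, "d7": 1}))
```

Corpus-level scores are typically the mean of this per-query value over all queries.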