
Model Card for gpt-sw3-6.7b-v2-translator

The gpt-sw3-6.7b-v2-translator is a finetuned version of gpt-sw3-6.7b-v2-instruct, trained on a carefully selected dataset of translation pairs gathered by AI Sweden.

Intended usage:

Translate text data from English to Swedish, or Swedish to English.

How to use:

import torch
from transformers import pipeline, StoppingCriteriaList, StoppingCriteria

device = "cuda" if torch.cuda.is_available() else "cpu"


# (Optional) Define a stopping criterion.
# Ideally the model stops generating as soon as the Bot's response is complete.
class StopOnTokenCriteria(StoppingCriteria):
    def __init__(self, stop_token_id):
        self.stop_token_id = stop_token_id

    def __call__(self, input_ids, scores, **kwargs):
        return input_ids[0, -1] == self.stop_token_id


pipe = pipeline(
    task="text-generation",
    model="AI-Sweden-Models/gpt-sw3-6.7b-v2-translator",
    device=device
)

stop_on_token_criteria = StopOnTokenCriteria(stop_token_id=pipe.tokenizer.bos_token_id)
text = "I like to eat ice cream in the summer."

# This will translate English to Swedish
# To translate from Swedish to English the prompt would be:
# prompt = f"<|endoftext|><s>User: Översätt till Engelska från Svenska\n{text}<s>Bot:"

prompt = f"<|endoftext|><s>User: Översätt till Svenska från Engelska\n{text}<s>Bot:"

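# Count the prompt tokens so that prompt + generation fits in the 2048-token limit
input_tokens = pipe.tokenizer(prompt, return_tensors="pt").input_ids.to(device)
max_model_length = 2048
dynamic_max_length = max_model_length - input_tokens.shape[1]

response = pipe(
    prompt,
    # max_new_tokens limits only the generated tokens; max_length would also count the prompt
    max_new_tokens=dynamic_max_length,
    truncation=True,
    stopping_criteria=StoppingCriteriaList([stop_on_token_criteria])
)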

print(response[0]["generated_text"].split("<s>Bot: ")[-1])
# Output: "Jag tycker om att äta glass på sommaren."
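
The two prompt variants above differ only in the instruction line, so the prompt can be built by a small helper. The following is a minimal sketch, not part of the model's or the library's API: `INSTRUCTIONS`, `build_prompt` and `extract_translation` are hypothetical names introduced here for illustration, and the language codes `"en"` and `"sv"` are an assumed convention. Unlike the snippet above, `extract_translation` splits on `<s>Bot:` without the trailing space and then strips whitespace.

```python
# Hypothetical helpers: they only wrap the prompt format shown above.
INSTRUCTIONS = {
    ("en", "sv"): "Översätt till Svenska från Engelska",
    ("sv", "en"): "Översätt till Engelska från Svenska",
}


def build_prompt(text: str, source: str = "en", target: str = "sv") -> str:
    """Return the chat-style prompt for one translation direction."""
    try:
        instruction = INSTRUCTIONS[(source, target)]
    except KeyError:
        raise ValueError(f"unsupported direction: {source} -> {target}") from None
    return f"<|endoftext|><s>User: {instruction}\n{text}<s>Bot:"


def extract_translation(generated_text: str) -> str:
    """Keep only the text after the last '<s>Bot:' marker, trimmed of whitespace."""
    return generated_text.split("<s>Bot:")[-1].strip()
```

With these helpers, `build_prompt(text)` can be passed to `pipe(...)` in place of the hand-written f-string, and `extract_translation(response[0]["generated_text"])` can replace the inline `split`.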

Training & Data:

Training ran on one NVIDIA DGX using DeepSpeed ZeRO 3 for three epochs on roughly 4 GB of carefully selected translation data. It is a full finetune: all model parameters were updated.

| Epoch | Training Loss | Evaluation Loss |
|-------|---------------|-----------------|
| 1     | 1.309         | 1.281           |
| 2     | 1.161         | 1.242           |
| 3     | 1.053         | 1.219           |
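
To make the evaluation losses easier to compare, they can be converted to perplexities. This is a sketch under one assumption that the card does not state: that each reported loss is a mean per-token cross-entropy in nats, in which case perplexity is `exp(loss)`. The variable names are illustrative only.

```python
import math

# Evaluation losses copied from the table above, keyed by epoch.
eval_loss = {1: 1.281, 2: 1.242, 3: 1.219}

# Assumption: each loss is a mean per-token cross-entropy in nats,
# so perplexity = exp(loss).
perplexity = {epoch: math.exp(loss) for epoch, loss in eval_loss.items()}

for epoch, ppl in perplexity.items():
    print(f"epoch {epoch}: eval loss {eval_loss[epoch]:.3f} -> perplexity {ppl:.2f}")
```

Under that assumption the perplexity falls with each epoch, mirroring the evaluation loss.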

Model size: 6.71B params. Tensor type: BF16. Weights format: Safetensors.