---
license: mit
language:
  - ko
base_model:
  - BAAI/bge-m3
---

Model Card for jaeyong2/bge-m3-Ko

Model Details

Train

  • Hardware: Google Colab, A100 40GB
  • Data: jaeyong2/Ko-emb-PreView (step: 18,000)
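The training command reads `./train.jsonl`, but the card does not show how that file was built. A minimal sketch, assuming FlagEmbedding's JSONL training format (`{"query": str, "pos": [str], "neg": [str]}`) and assuming the same field mapping as the evaluation script below (`context` as the query, `Title` as the positive, `Fake Title` as the negative). The helper name `write_flagembedding_jsonl` is hypothetical.

```python
import json
from typing import Iterable, Mapping


def write_flagembedding_jsonl(rows: Iterable[Mapping[str, str]], path: str) -> int:
    """Write rows with "context", "Title" and "Fake Title" fields as
    FlagEmbedding-style training lines and return the number written.

    Assumed mapping (not confirmed by the card): context -> query,
    Title -> positive passage, Fake Title -> hard negative.
    """
    count = 0
    with open(path, "w", encoding="utf-8") as f:
        for row in rows:
            record = {
                "query": row["context"],
                "pos": [row["Title"]],
                "neg": [row["Fake Title"]],
            }
            # ensure_ascii=False keeps Korean text readable in the file
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
            count += 1
    return count
```

For example, `write_flagembedding_jsonl(datasets.load_dataset("jaeyong2/Ko-emb-PreView")["train"], "./train.jsonl")`, where the `"train"` split name is itself an assumption. With `--train_group_size 2`, one negative per query is enough.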
```shell
# Colab cell: the leading "!" runs a shell command; drop it in a regular terminal.
!torchrun --nproc_per_node 1 \
-m FlagEmbedding.finetune.embedder.encoder_only.m3 \
--output_dir "/content/drive/My Drive/bge_ko.pth" \
--model_name_or_path BAAI/bge-m3 \
--train_data ./train.jsonl \
--learning_rate 1e-5 \
--bf16 \
--num_train_epochs 1 \
--per_device_train_batch_size 1 \
--dataloader_drop_last True \
--temperature 0.02 \
--query_max_len 2048 \
--passage_max_len 512 \
--train_group_size 2 \
--negatives_cross_device \
--logging_steps 10 \
--save_steps 1000 \
--query_instruction_for_retrieval ""
```
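The `--temperature 0.02` and `--train_group_size 2` flags configure a contrastive objective. The sketch below is a simplified InfoNCE loss, not FlagEmbedding's actual code. It assumes that each query's group holds its positive followed by `train_group_size - 1` negatives and that scores are cosine similarities divided by the temperature. The function name `info_nce_loss` is hypothetical.

```python
import numpy as np


def info_nce_loss(q: np.ndarray, p: np.ndarray, temperature: float = 0.02) -> float:
    """Simplified InfoNCE loss over dense embeddings.

    q: (B, d) query embeddings.
    p: (B * G, d) passage embeddings; rows i*G .. i*G+G-1 form query i's
       group, positive first, then G - 1 negatives (G = train_group_size).
    In this sketch every query is scored against every passage in the batch,
    so other queries' passages also act as negatives.
    """
    batch, n_passages = q.shape[0], p.shape[0]
    assert n_passages % batch == 0, "passages must split evenly into groups"
    group_size = n_passages // batch

    # Cosine similarity via L2-normalised dot products, scaled by 1 / temperature
    q = q / np.linalg.norm(q, axis=1, keepdims=True)
    p = p / np.linalg.norm(p, axis=1, keepdims=True)
    scores = (q @ p.T) / temperature  # shape (B, B * G)

    # Column index of each query's positive passage
    targets = np.arange(batch) * group_size

    # Numerically stable log-softmax over all passages
    scores = scores - scores.max(axis=1, keepdims=True)
    log_probs = scores - np.log(np.exp(scores).sum(axis=1, keepdims=True))
    return float(-log_probs[np.arange(batch), targets].mean())
```

With the flags above (one query per device, group size 2), each step in this sketch scores one context against its title and one negative. A query identical to its positive and orthogonal to its negative gives a loss near 0, while identical positive and negative give ln 2. The real M3 trainer may additionally combine sparse and multi-vector scores. With `--nproc_per_node 1` there is only one device, so `--negatives_cross_device` likely has nothing to share.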

Evaluation

Code:

```python
import datasets
import torch
from FlagEmbedding import BGEM3FlagModel
from sklearn.metrics import pairwise_distances
from tqdm import tqdm

# Model under evaluation. Loading this repo with BGEM3FlagModel is assumed;
# swap in "BAAI/bge-m3" to reproduce the baseline score.
fine_tuned_model = BGEM3FlagModel("jaeyong2/bge-m3-Ko")


def get_embedding(text, model):
    with torch.no_grad():
        # encode() returns a dict; "dense_vecs" holds the dense embedding
        embedding = model.encode(text)["dense_vecs"]
    return embedding


# First 1,000 examples of the test split
dataset = datasets.load_dataset("jaeyong2/Ko-emb-PreView")
validation_dataset = dataset["test"].select(range(1000))


def evaluate(validation_dataset, model):
    correct_count = 0

    for item in tqdm(validation_dataset):
        # The context acts as the query; "Title" is the matching document
        # and "Fake Title" the negative
        query_embedding = get_embedding(item["context"], model)
        document_embedding = get_embedding(item["Title"], model)
        negative_embedding = get_embedding(item["Fake Title"], model)

        # Cosine distance from the query to each document (lower = more similar)
        positive_distance = pairwise_distances(
            query_embedding.reshape(1, -1), document_embedding.reshape(1, -1), metric="cosine"
        )[0, 0]
        negative_distance = pairwise_distances(
            query_embedding.reshape(1, -1), negative_embedding.reshape(1, -1), metric="cosine"
        )[0, 0]

        # Correct if the query is closer to the true title than to the fake one
        if positive_distance < negative_distance:
            correct_count += 1

    return correct_count / len(validation_dataset)


accuracy = evaluate(validation_dataset, fine_tuned_model)
print(f"Validation accuracy: {accuracy}")
```
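The evaluation loop encodes one text per call. A vectorized sketch of the same accuracy criterion follows. It assumes the three columns have already been encoded into `(N, d)` arrays, for instance with `model.encode(list_of_texts)["dense_vecs"]`. The helper names `cosine_distance` and `pairwise_accuracy` are hypothetical.

```python
import numpy as np


def cosine_distance(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Row-wise cosine distance 1 - cos(a_i, b_i) for two (N, d) arrays."""
    a = a / np.linalg.norm(a, axis=1, keepdims=True)
    b = b / np.linalg.norm(b, axis=1, keepdims=True)
    return 1.0 - np.sum(a * b, axis=1)


def pairwise_accuracy(queries: np.ndarray, positives: np.ndarray, negatives: np.ndarray) -> float:
    """Fraction of rows whose query is strictly closer (in cosine distance)
    to its positive than to its negative, the criterion used in the loop."""
    d_pos = cosine_distance(queries, positives)
    d_neg = cosine_distance(queries, negatives)
    return float(np.mean(d_pos < d_neg))
```

Computing the distances row by row avoids building full N×N distance matrices. Passing lists of texts to `encode` lets the model batch them on the GPU instead of running one forward pass per text.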

Accuracy (first 1,000 examples of the jaeyong2/Ko-emb-PreView test split)

  • BAAI/bge-m3: 0.971
  • jaeyong2/bge-m3-Ko: 0.992
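Assuming both scores come from the evaluation script run on the same 1,000 test examples, they translate into counts as follows:

$$
\begin{aligned}
\text{BAAI/bge-m3:}\quad & 0.971 \times 1000 = 971 \text{ correct},\; 29 \text{ errors} \\
\text{jaeyong2/bge-m3-Ko:}\quad & 0.992 \times 1000 = 992 \text{ correct},\; 8 \text{ errors} \\
\text{relative error reduction:}\quad & \frac{29 - 8}{29} = \frac{21}{29} \approx 0.724
\end{aligned}
$$

On this slice, the fine-tune removes roughly 72% of the baseline's errors.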

License