File size: 3,944 Bytes
d655f51
281995c
 
d655f51
 
 
281995c
d655f51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281995c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import spaces
import gradio as gr

# code
import pandas as pd
from datasets import load_dataset

# from sentence_transformers import (
#     SentenceTransformer,
#     SentenceTransformerTrainer,
#     SentenceTransformerTrainingArguments,
#     SentenceTransformerModelCardData
# ) ### we can imporet everhtuing from the main class...

from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from sentence_transformers.training_args import SentenceTransformerTrainingArguments, BatchSamplers




def get_ir_evaluator(eval_ds):
    """Build an InformationRetrievalEvaluator from an (anchor, positive) pair dataset.

    Each row's ``anchor`` becomes a query and its ``positive`` becomes the single
    relevant corpus document; both are keyed by the row index so query *i* maps
    to document *i*.

    Args:
        eval_ds: iterable of dicts with string fields ``'anchor'`` and ``'positive'``
            (e.g. a 🤗 ``datasets`` split).

    Returns:
        InformationRetrievalEvaluator: evaluator named ``"ir-evaluator"`` wired with
        the derived queries, corpus, and relevance mapping.
    """
    queries = {}        # qid -> query text
    corpus = {}         # cid -> document text
    relevant_docs = {}  # qid -> set[cid] of relevant documents
    for idx, example in enumerate(eval_ds):
        queries[idx] = example['anchor']
        corpus[idx] = example['positive']
        # NOTE(review): only the paired positive is marked relevant; if the dataset
        # contains duplicate/near-duplicate positives this undercounts relevance.
        relevant_docs[idx] = {idx}

    return InformationRetrievalEvaluator(
        queries=queries,
        corpus=corpus,
        relevant_docs=relevant_docs,
        name="ir-evaluator",
    )

    


@spaces.GPU(duration=3600)
def train(hf_token, dataset_id, model_id, num_epochs, dev):
    """Fine-tune a SentenceTransformer on an (anchor, positive) pair dataset.

    Loads ``dataset_id`` (train split), fine-tunes ``model_id`` with a cached
    multiple-negatives ranking loss, and evaluates information-retrieval metrics
    before and after training.

    Args:
        hf_token: Hugging Face token used to access the dataset.
        dataset_id: Hub id of a dataset with 'anchor'/'positive' columns.
        model_id: Hub id of the base SentenceTransformer model.
        num_epochs: number of training epochs.
        dev: when truthy, cap the dataset at 1000 rows for a quick smoke test.

    Returns:
        str: a printable DataFrame with base-model and fine-tuned metrics
        side by side (metric name -> [base, fine-tuned]).
    """
    ds = load_dataset(dataset_id, split="train", token=hf_token)
    ds = ds.shuffle(seed=42)
    # Dev mode: cap at 1000 rows. (Original used range(0, 999), which kept
    # only 999 rows — off-by-one against the len(ds) > 1000 guard.)
    if dev and len(ds) > 1000:
        ds = ds.select(range(1000))
    ds = ds.train_test_split(train_size=0.75)
    train_ds, eval_ds = ds['train'], ds['test']
    print('train: ', len(train_ds), 'eval: ', len(eval_ds))

    # model
    model = SentenceTransformer(model_id)

    # loss: in-batch negatives with gradient caching to fit larger batches
    loss = CachedMultipleNegativesRankingLoss(model)

    # training args
    args = SentenceTransformerTrainingArguments(
        output_dir="outputs",                       # required
        num_train_epochs=num_epochs,                # optional...
        per_device_train_batch_size=16,
        warmup_ratio=0.1,
        #fp16=True,                                  # Set to False if your GPU can't handle FP16
        #bf16=False,                                 # Set to True if your GPU supports BF16
        batch_sampler=BatchSamplers.NO_DUPLICATES,  # Losses using "in-batch negatives" benefit from no duplicates
        save_total_limit=2,
        # per_device_eval_batch_size=1,
        # eval_strategy="epoch",
        # save_strategy="epoch",
        # logging_steps=100,
        # Optional tracking/debugging parameters:
        # eval_strategy="steps",
        # eval_steps=100,
        # save_strategy="steps",
        # save_steps=100,
        # logging_steps=100,
        # run_name="jina-code-vechain-pair",  # Used in W&B if `wandb` is installed
    )

    # IR evaluator built from the held-out split
    ir_evaluator = get_ir_evaluator(eval_ds)

    # baseline metrics before any fine-tuning
    base_metrics = ir_evaluator(model)
    print(ir_evaluator.primary_metric)
    print(base_metrics[ir_evaluator.primary_metric])

    # train (evaluation during training intentionally disabled; see commented args)
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_ds,
        # eval_dataset=eval_ds,
        loss=loss,
        # evaluator=ir_evaluator,
    )
    trainer.train()

    # fine-tuned model metrics on the same evaluator for a fair comparison
    ft_metrics = ir_evaluator(model)
    print(ir_evaluator.primary_metric)
    print(ft_metrics[ir_evaluator.primary_metric])

    # one column per run: [base, fine-tuned], rows are metric names
    metrics = pd.DataFrame([base_metrics, ft_metrics]).T
    print(metrics)
    return str(metrics)


## logs to UI
# https://github.com/gradio-app/gradio/issues/2362#issuecomment-1424446778

# Bug fixes: the original passed fn=greet, an undefined name (NameError at
# startup — the entry point defined above is `train`), and "bool", which is
# not a Gradio component shortcut (the boolean input shortcut is "checkbox").
demo = gr.Interface(
    fn=train,
    inputs=["text", "text", "text", "number", "checkbox"],  # token, dataset id, model id, epochs, dev flag
    outputs=["text"],  # could be "dataframe" to render the metrics table
)
demo.launch()