DoctorSlimm
committed on
Commit
•
d655f51
1
Parent(s):
281995c
add train code and requirements text file...
Browse files- app.py +119 -3
- requirements.txt +3 -0
app.py
CHANGED
@@ -1,7 +1,123 @@
|
|
|
|
1 |
import gradio as gr
|
2 |
|
3 |
-
|
4 |
-
|
|
|
5 |
|
6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
demo.launch()
|
|
|
1 |
+
import spaces
|
2 |
import gradio as gr
|
3 |
|
4 |
+
# code
|
5 |
+
import pandas as pd
|
6 |
+
from datasets import load_dataset
|
7 |
|
8 |
+
# from sentence_transformers import (
|
9 |
+
# SentenceTransformer,
|
10 |
+
# SentenceTransformerTrainer,
|
11 |
+
# SentenceTransformerTrainingArguments,
|
12 |
+
# SentenceTransformerModelCardData
|
13 |
+
# ) ### we can import everything from the main class...
|
14 |
+
|
15 |
+
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer
|
16 |
+
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
|
17 |
+
from sentence_transformers.evaluation import InformationRetrievalEvaluator
|
18 |
+
from sentence_transformers.training_args import SentenceTransformerTrainingArguments, BatchSamplers
|
19 |
+
|
20 |
+
|
21 |
+
|
22 |
+
|
23 |
+
def get_ir_evaluator(eval_ds):
    """Build an InformationRetrievalEvaluator from an (anchor, positive) pair dataset.

    Each example's 'anchor' becomes a query and its 'positive' becomes a
    corpus document; the only document marked relevant for query i is
    document i (the paired positive).

    Args:
        eval_ds: iterable of dicts with 'anchor' and 'positive' string fields.

    Returns:
        InformationRetrievalEvaluator ready to score a SentenceTransformer.
    """
    corpus = {}         # cid -> document text
    queries = {}        # qid -> query text
    relevant_docs = {}  # qid -> set[cid]
    for idx, example in enumerate(eval_ds):
        queries[idx] = example['anchor']
        corpus[idx] = example['positive']
        # NOTE(review): only the paired positive is marked relevant; a richer
        # relevance mapping (e.g. LLM-generated judgments) would give truer
        # IR metrics.
        relevant_docs[idx] = {idx}

    ir_evaluator = InformationRetrievalEvaluator(
        queries=queries,
        corpus=corpus,
        relevant_docs=relevant_docs,
        name="ir-evaluator",
    )
    return ir_evaluator
|
45 |
+
|
46 |
+
|
47 |
+
|
48 |
+
|
49 |
+
@spaces.GPU(duration=3600)
def train(hf_token, dataset_id, model_id, num_epochs, dev):
    """Fine-tune a SentenceTransformer on an (anchor, positive) pair dataset.

    Loads the dataset from the Hub, splits 75/25 into train/eval, evaluates
    the base model, trains with an in-batch-negatives loss, re-evaluates,
    and returns a printable comparison of the two metric sets.

    Args:
        hf_token: HF access token used to read the (possibly private) dataset.
        dataset_id: Hub id of a dataset with 'anchor'/'positive' columns.
        model_id: Hub id of the base SentenceTransformer to fine-tune.
        num_epochs: number of training epochs.
        dev: when truthy, cap the shuffled dataset at 1000 examples for a
            quick development run.

    Returns:
        str: stringified DataFrame with base vs fine-tuned IR metrics.
    """
    # data: deterministic shuffle, optional dev cap, 75/25 split
    ds = load_dataset(dataset_id, split="train", token=hf_token)
    ds = ds.shuffle(seed=42)
    if len(ds) > 1000 and dev:
        # was range(0, 999), which silently dropped one example of the
        # intended 1000-example dev cap
        ds = ds.select(range(1000))
    ds = ds.train_test_split(train_size=0.75)
    train_ds, eval_ds = ds['train'], ds['test']
    print('train: ', len(train_ds), 'eval: ', len(eval_ds))

    # model
    model = SentenceTransformer(model_id)

    # loss: in-batch negatives with gradient caching (memory-friendly)
    loss = CachedMultipleNegativesRankingLoss(model)

    # training args
    args = SentenceTransformerTrainingArguments(
        output_dir="outputs",  # required
        num_train_epochs=num_epochs,
        per_device_train_batch_size=16,
        warmup_ratio=0.1,
        # fp16=True,   # Set to False if your GPU can't handle FP16
        # bf16=False,  # Set to True if your GPU supports BF16
        batch_sampler=BatchSamplers.NO_DUPLICATES,  # in-batch negatives benefit from no duplicates
        save_total_limit=2,
    )

    # baseline: score the un-tuned model so the improvement is measurable
    ir_evaluator = get_ir_evaluator(eval_ds)
    base_metrics = ir_evaluator(model)
    print(ir_evaluator.primary_metric)
    print(base_metrics[ir_evaluator.primary_metric])

    # train (model is updated in place)
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_ds,
        loss=loss,
    )
    trainer.train()

    # fine-tuned model metrics
    ft_metrics = ir_evaluator(model)
    print(ir_evaluator.primary_metric)
    print(ft_metrics[ir_evaluator.primary_metric])

    # one column per run (base, fine-tuned), one row per metric
    metrics = pd.DataFrame([base_metrics, ft_metrics]).T
    print(metrics)
    return str(metrics)
|
117 |
+
|
118 |
+
|
119 |
+
## logs to UI
# https://github.com/gradio-app/gradio/issues/2362#issuecomment-1424446778

# Wire the UI to the training entry point. The original passed fn=greet,
# which is undefined in this file (NameError on first request) -- the
# intended handler is train(). Also, "bool" is not a valid Gradio input
# component shortcut; "checkbox" is the boolean input.
demo = gr.Interface(
    fn=train,
    inputs=["text", "text", "text", "number", "checkbox"],
    outputs=["text"],  # "dataframe"
)
demo.launch()
|
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
datasets
|
2 |
+
accelerate
|
3 |
+
sentence-transformers
|