jaeyong2 committed on
Commit f167063 · verified · 1 Parent(s): c75b96e

Update README.md

Files changed (1):
  1. README.md +46 -0
README.md CHANGED
@@ -20,6 +20,52 @@ base_model:
 ## Train

+H/W: Colab A100 40GB
+Data: jaeyong2/Ko-emb-PreView
+
+Triplet fine-tuning: each `context` is the anchor, its `Title` the positive, and its `Fake Title` the negative.
+
+```python
+import datasets
+import torch
+from torch.optim import AdamW
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from transformers import AutoModel, AutoTokenizer
+from sentence_transformers.util import batch_to_device
+
+model_name = "Alibaba-NLP/gte-multilingual-base"
+dataset = datasets.load_dataset("jaeyong2/Ko-emb-PreView")
+train_dataloader = DataLoader(dataset["train"], batch_size=8, shuffle=True)
+
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+# gte-multilingual-base ships custom modeling code, hence trust_remote_code=True
+model = AutoModel.from_pretrained(model_name, trust_remote_code=True).to(torch.bfloat16)
+# torch's built-in triplet margin loss stands in for the undefined TripletLoss helper
+triplet_loss = torch.nn.TripletMarginLoss(margin=1.0)
+
+optimizer = AdamW(model.parameters(), lr=5e-5)
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model.to(device)
+
+for epoch in range(3):  # iterate over epochs
+    model.train()
+    total_loss = 0
+    count = 0
+    for batch in tqdm(train_dataloader):
+        optimizer.zero_grad()
+        loss = None
+        for index in range(len(batch["context"])):
+            anchor_encodings = tokenizer([batch["context"][index]], truncation=True, padding="max_length", max_length=4096, return_tensors="pt")
+            positive_encodings = tokenizer([batch["Title"][index]], truncation=True, padding="max_length", max_length=256, return_tensors="pt")
+            negative_encodings = tokenizer([batch["Fake Title"][index]], truncation=True, padding="max_length", max_length=256, return_tensors="pt")
+
+            anchor_encodings = batch_to_device(anchor_encodings, device)
+            positive_encodings = batch_to_device(positive_encodings, device)
+            negative_encodings = batch_to_device(negative_encodings, device)
+
+            # Model outputs: take the [CLS] token vector as the embedding
+            anchor_output = model(**anchor_encodings)[0][:, 0, :]
+            positive_output = model(**positive_encodings)[0][:, 0, :]
+            negative_output = model(**negative_encodings)[0][:, 0, :]
+
+            # Accumulate the triplet loss over the samples in the batch
+            if loss is None:
+                loss = triplet_loss(anchor_output, positive_output, negative_output)
+            else:
+                loss += triplet_loss(anchor_output, positive_output, negative_output)
+        loss /= len(batch["context"])
+        loss.backward()
+        optimizer.step()
+```
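
For reference (not part of this commit): a minimal inference sketch under the assumption that embeddings are extracted with the same [CLS]-token pooling used in training. The base checkpoint id is a placeholder; swap in this repo's fine-tuned weights, and the Korean sample strings are illustrative only.

```python
import torch
from transformers import AutoModel, AutoTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Placeholder checkpoint: replace with the fine-tuned model's repo id
model_name = "Alibaba-NLP/gte-multilingual-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True).to(device).eval()

def embed(texts, max_length=256):
    enc = tokenizer(texts, truncation=True, padding=True, max_length=max_length, return_tensors="pt").to(device)
    with torch.no_grad():
        out = model(**enc)[0][:, 0, :]  # [CLS] token embedding, as in training
    return torch.nn.functional.normalize(out, dim=-1)  # L2-normalize for cosine similarity

query = embed(["한국어 문서 내용 예시"])          # anchor-style input (document text)
titles = embed(["실제 제목", "관련 없는 제목"])   # candidate titles
print(query @ titles.T)  # cosine similarities; the matching title should score higher
```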

 ## Evaluation