leewatson committed on
Commit c977434 · verified · 1 Parent(s): 192503c

Upload model_loader.py

Files changed (1)
  1. models/model_loader.py +42 -0
models/model_loader.py ADDED
@@ -0,0 +1,42 @@
+ # models/model_loader.py
+ import torch
+ import torch.nn as nn
+ from transformers import ElectraModel, AutoTokenizer
+
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+ class KOTEtagger(nn.Module):
+     """
+     KcELECTRA + Linear Head, multi-label emotion classifier (44 labels).
+     - Weight file: kote_pytorch_lightning.bin (loaded with strict=False)
+     """
+     def __init__(self, model_name="beomi/KcELECTRA-base", revision="v2021", num_labels=44):
+         super().__init__()
+         self.electra = ElectraModel.from_pretrained(model_name, revision=revision)
+         self.tokenizer = AutoTokenizer.from_pretrained(model_name, revision=revision)
+         self.classifier = nn.Linear(self.electra.config.hidden_size, num_labels)
+
+     def forward(self, text: str):
+         encoding = self.tokenizer.encode_plus(
+             text,
+             add_special_tokens=True,
+             max_length=128,
+             padding="max_length",
+             truncation=True,
+             return_attention_mask=True,
+             return_tensors="pt",
+         )
+         input_ids = encoding["input_ids"].to(DEVICE)
+         attention_mask = encoding["attention_mask"].to(DEVICE)
+         outputs = self.electra(input_ids, attention_mask=attention_mask)
+         cls = outputs.last_hidden_state[:, 0, :]
+         logits = self.classifier(cls)
+         return torch.sigmoid(logits)
+
+ def load_kote_model(weight_path="kote_pytorch_lightning.bin"):
+     model = KOTEtagger()
+     model.to(DEVICE)
+     state = torch.load(weight_path, map_location=DEVICE)
+     model.load_state_dict(state, strict=False)
+     model.eval()
+     return model
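
For context, a minimal usage sketch (not part of the commit): it assumes models/ is importable as a package and that kote_pytorch_lightning.bin sits in the working directory; the Korean example sentence and the top-5 printout are illustrative only.

# Usage sketch: load the weights and score one sentence.
import torch
from models.model_loader import load_kote_model

model = load_kote_model("kote_pytorch_lightning.bin")
with torch.no_grad():
    # forward() takes a raw string and returns sigmoid probabilities of shape (1, 44)
    probs = model("이 영화 정말 재미있었어요!")
print(probs.topk(5, dim=-1))  # five highest-scoring label indices and probabilities

Because forward() applies an independent sigmoid per label, each of the 44 outputs can be thresholded separately for multi-label prediction.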