hiddenFront committed on
Commit
95c8334
·
verified ·
1 Parent(s): 1efa28d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -29
app.py CHANGED
@@ -7,15 +7,66 @@ import numpy as np
7
  import os
8
  import sys # ์˜ค๋ฅ˜ ์‹œ ์„œ๋น„์Šค ์ข…๋ฃŒ๋ฅผ ์œ„ํ•ด sys ๋ชจ๋“ˆ ์ž„ํฌํŠธ
9
 
10
- # transformers์˜ AutoTokenizer๋งŒ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.
11
- from transformers import AutoTokenizer # BertModel, BertForSequenceClassification ๋“ฑ์€ ์ด์ œ ์ง์ ‘ ํ•„์š” ์—†์Šต๋‹ˆ๋‹ค.
12
  from torch.utils.data import Dataset, DataLoader
13
  import logging # ๋กœ๊น… ๋ชจ๋“ˆ ์ž„ํฌํŠธ ์œ ์ง€
14
  from huggingface_hub import hf_hub_download # hf_hub_download ์ž„ํฌํŠธ ์œ ์ง€
15
- # collections ๋ชจ๋“ˆ์€ ๋” ์ด์ƒ ํ•„์š” ์—†์„ ์ˆ˜ ์žˆ์ง€๋งŒ, ํ˜น์‹œ ๋ชฐ๋ผ ์œ ์ง€ํ•ฉ๋‹ˆ๋‹ค.
16
- import collections
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- # --- 1. FastAPI ์•ฑ ๋ฐ ์ „์—ญ ๋ณ€์ˆ˜ ์„ค์ • ---
 
 
 
 
 
 
19
  app = FastAPI()
20
  device = torch.device("cpu") # Hugging Face Spaces์˜ ๋ฌด๋ฃŒ ํ‹ฐ์–ด๋Š” ์ฃผ๋กœ CPU๋ฅผ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.
21
 
@@ -42,7 +93,6 @@ tokenizer = AutoTokenizer.from_pretrained('skt/kobert-base-v1')
42
  print("ํ† ํฌ๋‚˜์ด์ € ๋กœ๋“œ ์„ฑ๊ณต.")
43
 
44
  # โœ… ๋ชจ๋ธ ๋กœ๋“œ (Hugging Face Hub์—์„œ ๋‹ค์šด๋กœ๋“œ)
45
- # textClassifierModel.pt ํŒŒ์ผ์€ ์ด๋ฏธ ๊ฒฝ๋Ÿ‰ํ™”๋œ '์™„์ „ํ•œ ๋ชจ๋ธ ๊ฐ์ฒด'๋ผ๊ณ  ๊ฐ€์ •ํ•˜๊ณ  ์ง์ ‘ ๋กœ๋“œํ•ฉ๋‹ˆ๋‹ค.
46
  try:
47
  HF_MODEL_REPO_ID = "hiddenFront/TextClassifier" # ์‚ฌ์šฉ์ž๋‹˜์˜ ์‹ค์ œ Hugging Face ์ €์žฅ์†Œ ID
48
  HF_MODEL_FILENAME = "textClassifierModel.pt" # Hugging Face Hub์— ์—…๋กœ๋“œํ•œ ํŒŒ์ผ ์ด๋ฆ„๊ณผ ์ผ์น˜ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.
@@ -51,11 +101,35 @@ try:
51
  print(f"๋ชจ๋ธ ํŒŒ์ผ์ด '{model_path}'์— ์„ฑ๊ณต์ ์œผ๋กœ ๋‹ค์šด๋กœ๋“œ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.")
52
 
53
  # --- ์ˆ˜์ •๋œ ํ•ต์‹ฌ ๋ถ€๋ถ„ ---
54
- # ๊ฒฝ๋Ÿ‰ํ™”๋œ ๋ชจ๋ธ ๊ฐ์ฒด๋ฅผ ์ง์ ‘ ๋กœ๋“œํ•ฉ๋‹ˆ๋‹ค.
55
- # ์ด ํŒŒ์ผ์€ ์ด๋ฏธ PyTorch ๋ชจ๋ธ ๊ฐ์ฒด(์–‘์žํ™”๋œ ๋ชจ๋ธ ํฌํ•จ)์ด๋ฏ€๋กœ ๋ฐ”๋กœ ๋กœ๋“œํ•˜์—ฌ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.
56
- model = torch.load(model_path, map_location=device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  # --- ์ˆ˜์ •๋œ ํ•ต์‹ฌ ๋ถ€๋ถ„ ๋ ---
58
 
 
59
  model.eval() # ์ถ”๋ก  ๋ชจ๋“œ๋กœ ์„ค์ •
60
  print("๋ชจ๋ธ ๋กœ๋“œ ์„ฑ๊ณต.")
61
 
@@ -64,25 +138,6 @@ except Exception as e:
64
  sys.exit(1) # ๋ชจ๋ธ ๋กœ๋“œ ์‹คํŒจ ์‹œ ์„œ๋น„์Šค ์‹œ์ž‘ํ•˜์ง€ ์•Š์Œ
65
 
66
 
67
- # --- 2. BERTDataset ํด๋ž˜์Šค ์ •์˜ (dataset.py์—์„œ ์˜ฎ๊ฒจ์˜ด) ---
68
- # ์ด ํด๋ž˜์Šค๋Š” ๋ฐ์ดํ„ฐ๋ฅผ ๋ชจ๋ธ ์ž…๋ ฅ ํ˜•์‹์œผ๋กœ ๋ณ€ํ™˜ํ•ฉ๋‹ˆ๋‹ค.
69
- class BERTDataset(Dataset):
70
- def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, vocab, max_len, pad, pair):
71
- # nlp.data.BERTSentenceTransform์€ ํ† ํฌ๋‚˜์ด์ € ํ•จ์ˆ˜๋ฅผ ๋ฐ›์Šต๋‹ˆ๋‹ค.
72
- # AutoTokenizer์˜ tokenize ๋ฉ”์„œ๋“œ๋ฅผ ์ง์ ‘ ์ „๋‹ฌํ•ฉ๋‹ˆ๋‹ค.
73
- transform = nlp.data.BERTSentenceTransform(
74
- bert_tokenizer, max_seq_length=max_len, vocab=vocab, pad=pad, pair=pair
75
- )
76
- self.sentences = [transform([i[sent_idx]]) for i in dataset]
77
- self.labels = [np.int32(i[label_idx]) for i in dataset]
78
-
79
- def __getitem__(self, i):
80
- return (self.sentences[i] + (self.labels[i],))
81
-
82
- def __len__(self):
83
- return len(self.labels)
84
-
85
-
86
  # โœ… ๋ฐ์ดํ„ฐ์…‹ ์ƒ์„ฑ์— ํ•„์š”ํ•œ ํŒŒ๋ผ๋ฏธํ„ฐ
87
  max_len = 64
88
  batch_size = 32
@@ -125,4 +180,3 @@ def root():
125
  async def predict_route(item: InputText):
126
  result = predict(item.text)
127
  return {"text": item.text, "classification": result}
128
-
 
7
  import os
8
  import sys # ์˜ค๋ฅ˜ ์‹œ ์„œ๋น„์Šค ์ข…๋ฃŒ๋ฅผ ์œ„ํ•ด sys ๋ชจ๋“ˆ ์ž„ํฌํŠธ
9
 
10
+ # transformers์˜ AutoTokenizer ๋ฐ BertModel ์ž„ํฌํŠธ
11
+ from transformers import AutoTokenizer, BertModel # BertModel ์ž„ํฌํŠธ ์ถ”๊ฐ€
12
  from torch.utils.data import Dataset, DataLoader
13
  import logging # ๋กœ๊น… ๋ชจ๋“ˆ ์ž„ํฌํŠธ ์œ ์ง€
14
  from huggingface_hub import hf_hub_download # hf_hub_download ์ž„ํฌํŠธ ์œ ์ง€
15
+ import collections # collections ๋ชจ๋“ˆ ์ž„ํฌํŠธ ์œ ์ง€
16
+
17
# --- 1. BERTClassifier model class definition ---
# Wraps a pre-trained BERT encoder with an optional dropout and a linear head.
class BERTClassifier(torch.nn.Module):
    """KoBERT-based sequence classifier.

    Runs token ids through the wrapped BERT encoder and maps the pooled
    output to ``num_classes`` logits.
    """

    def __init__(self,
                 bert,
                 hidden_size = 768,
                 num_classes=5, # number of classes to predict (must match the category mapping size)
                 dr_rate=None,
                 params=None):
        super(BERTClassifier, self).__init__()
        self.bert = bert
        self.dr_rate = dr_rate

        # Projection from the pooled hidden state to class logits.
        self.classifier = torch.nn.Linear(hidden_size , num_classes)
        if dr_rate:
            self.dropout = torch.nn.Dropout(p=dr_rate)

    def gen_attention_mask(self, token_ids, valid_length):
        # Build a 0/1 mask per row: position j is 1 iff j < valid_length[row].
        mask = torch.zeros_like(token_ids)
        for row, length in enumerate(valid_length):
            mask[row][:length] = 1
        return mask.float()

    def forward(self, token_ids, valid_length, segment_ids):
        attention_mask = self.gen_attention_mask(token_ids, valid_length)

        # return_dict=False makes the encoder return a tuple; the second
        # element is the pooled output used for classification.
        _, pooler = self.bert(
            input_ids=token_ids,
            token_type_ids=segment_ids.long(),
            attention_mask=attention_mask.float().to(token_ids.device),
            return_dict=False,
        )

        out = self.dropout(pooler) if self.dr_rate else pooler
        return self.classifier(out)
50
+
51
# --- 2. BERTDataset class definition ---
# Converts raw (sentence, label) rows into model-ready input tuples.
class BERTDataset(Dataset):
    """Dataset adapter that tokenizes sentences via BERTSentenceTransform.

    Each item is the transformed sentence tuple with its int32 label
    appended as the final element.
    """

    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, vocab, max_len, pad, pair):
        # BERTSentenceTransform accepts the tokenizer function directly and
        # handles tokenization, padding and truncation to max_len.
        to_features = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len, vocab=vocab, pad=pad, pair=pair
        )
        self.sentences = []
        self.labels = []
        for row in dataset:
            self.sentences.append(to_features([row[sent_idx]]))
            self.labels.append(np.int32(row[label_idx]))

    def __getitem__(self, i):
        # Append the label to the transformed-sentence tuple.
        return self.sentences[i] + (self.labels[i],)

    def __len__(self):
        return len(self.labels)
68
+
69
+ # --- 3. FastAPI ์•ฑ ๋ฐ ์ „์—ญ ๋ณ€์ˆ˜ ์„ค์ • ---
70
  app = FastAPI()
71
  device = torch.device("cpu") # Hugging Face Spaces์˜ ๋ฌด๋ฃŒ ํ‹ฐ์–ด๋Š” ์ฃผ๋กœ CPU๋ฅผ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.
72
 
 
93
  print("ํ† ํฌ๋‚˜์ด์ € ๋กœ๋“œ ์„ฑ๊ณต.")
94
 
95
  # โœ… ๋ชจ๋ธ ๋กœ๋“œ (Hugging Face Hub์—์„œ ๋‹ค์šด๋กœ๋“œ)
 
96
  try:
97
  HF_MODEL_REPO_ID = "hiddenFront/TextClassifier" # ์‚ฌ์šฉ์ž๋‹˜์˜ ์‹ค์ œ Hugging Face ์ €์žฅ์†Œ ID
98
  HF_MODEL_FILENAME = "textClassifierModel.pt" # Hugging Face Hub์— ์—…๋กœ๋“œํ•œ ํŒŒ์ผ ์ด๋ฆ„๊ณผ ์ผ์น˜ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.
 
101
  print(f"๋ชจ๋ธ ํŒŒ์ผ์ด '{model_path}'์— ์„ฑ๊ณต์ ์œผ๋กœ ๋‹ค์šด๋กœ๋“œ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.")
102
 
103
  # --- ์ˆ˜์ •๋œ ํ•ต์‹ฌ ๋ถ€๋ถ„ ---
104
+ # 1. BertModel.from_pretrained๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๊ธฐ๋ณธ BERT ๋ชจ๋ธ์„ ๋กœ๋“œํ•ฉ๋‹ˆ๋‹ค.
105
+ # ์ด๋ ‡๊ฒŒ ํ•˜๋ฉด ๋ชจ๋ธ์˜ ์•„ํ‚คํ…์ฒ˜์™€ ์‚ฌ์ „ ํ•™์Šต๋œ ๊ฐ€์ค‘์น˜๊ฐ€ ๋กœ๋“œ๋ฉ๋‹ˆ๋‹ค.
106
+ bert_base_model = BertModel.from_pretrained('skt/kobert-base-v1')
107
+
108
+ # 2. BERTClassifier ์ธ์Šคํ„ด์Šค๋ฅผ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.
109
+ # ์—ฌ๊ธฐ์„œ num_classes๋Š” category ๋”•์…”๋„ˆ๋ฆฌ์˜ ํฌ๊ธฐ์™€ ์ผ์น˜ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.
110
+ model = BERTClassifier(
111
+ bert_base_model,
112
+ dr_rate=0.5, # ํ•™์Šต ์‹œ ์‚ฌ์šฉ๋œ dr_rate ๊ฐ’์œผ๋กœ ๋ณ€๊ฒฝํ•˜์„ธ์š”.
113
+ num_classes=len(category)
114
+ )
115
+
116
+ # 3. ๋‹ค์šด๋กœ๋“œ๋œ ํŒŒ์ผ์—์„œ state_dict๋ฅผ ๋กœ๋“œํ•ฉ๋‹ˆ๋‹ค.
117
+ # ์ด ํŒŒ์ผ์€ ์‚ฌ์šฉ์ž๋‹˜์˜ ๊ฒฝ๋Ÿ‰ํ™”๋œ ๋ชจ๋ธ์˜ ๊ฐ€์ค‘์น˜๋งŒ ํฌํ•จํ•˜๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค.
118
+ loaded_state_dict = torch.load(model_path, map_location=device)
119
+
120
+ # 4. ๋กœ๋“œ๋œ state_dict์˜ ํ‚ค๋ฅผ ์กฐ์ •ํ•˜๊ณ  ๋ชจ๋ธ์— ์ ์šฉํ•ฉ๋‹ˆ๋‹ค.
121
+ # 'module.' ์ ‘๋‘์‚ฌ๊ฐ€ ๋ถ™์–ด์žˆ๋Š” ๊ฒฝ์šฐ ์ œ๊ฑฐํ•˜๋Š” ๋กœ์ง์„ ํฌํ•จํ•ฉ๋‹ˆ๋‹ค.
122
+ new_state_dict = collections.OrderedDict()
123
+ for k, v in loaded_state_dict.items():
124
+ name = k
125
+ if name.startswith('module.'):
126
+ name = name[7:]
127
+ new_state_dict[name] = v
128
+
129
+ model.load_state_dict(new_state_dict)
130
  # --- ์ˆ˜์ •๋œ ํ•ต์‹ฌ ๋ถ€๋ถ„ ๋ ---
131
 
132
+ model.to(device) # ๋ชจ๋ธ์„ ๋””๋ฐ”์ด์Šค๋กœ ์ด๋™
133
  model.eval() # ์ถ”๋ก  ๋ชจ๋“œ๋กœ ์„ค์ •
134
  print("๋ชจ๋ธ ๋กœ๋“œ ์„ฑ๊ณต.")
135
 
 
138
  sys.exit(1) # ๋ชจ๋ธ ๋กœ๋“œ ์‹คํŒจ ์‹œ ์„œ๋น„์Šค ์‹œ์ž‘ํ•˜์ง€ ์•Š์Œ
139
 
140
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  # โœ… ๋ฐ์ดํ„ฐ์…‹ ์ƒ์„ฑ์— ํ•„์š”ํ•œ ํŒŒ๋ผ๋ฏธํ„ฐ
142
  max_len = 64
143
  batch_size = 32
 
180
  async def predict_route(item: InputText):
181
  result = predict(item.text)
182
  return {"text": item.text, "classification": result}