Ellie5757575757 committed on
Commit
01de4e1
·
verified ·
1 Parent(s): 1fa5046

Upload 15 files

Browse files
Cha_Json.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ cha2json.py ── 將單一 CLAN .cha 轉成 JSON(強化 %mor/%wor 對齊)
5
+
6
+ 只要:
7
+ $ python3 cha2json.py
8
+ """
9
+
10
+ # ────────── 這兩行改成你的固定路徑 ──────────
11
+ INPUT_CHA = "/workspace/SH001/website/ACWT01a(4).cha"
12
+ OUTPUT_JSON = "/workspace/SH001/website/Output.json"
13
+ # ──────────────────────────────────────────
14
+
15
+ import re, json, sys
16
+ from pathlib import Path
17
+ from collections import defaultdict
18
+
19
TAG_PREFIXES = ("*PAR:", "*INV:", "%mor:", "%gra:", "%wor:", "@")
WORD_RE = re.compile(r"[A-Za-z0-9]+")

# ---------- Synonym sets (speed up alignment) ----------
SYN_SETS = [
    {"be", "am", "is", "are", "was", "were"},
    {"have", "has", "had"},
    {"do", "does", "did"},
    {"go", "going", "went", "gone"},
]


def same_syn(a, b):
    """Return True when *a* and *b* are inflected forms of one SYN_SETS lemma."""
    for group in SYN_SETS:
        if a in group and b in group:
            return True
    return False


def canonical(txt):
    """Reduce a %mor token or %wor word to a lowercase comparison key.

    Everything after the first separator character (~ - & |) is dropped,
    then the first alphanumeric run of the remainder is returned in lower
    case; "" when no alphanumeric run exists.
    """
    head = re.split(r"[~\-\&|]", txt, 1)[0]
    match = WORD_RE.search(head)
    return match.group(0).lower() if match else ""
+
37
+ def merge_multiline(block): # 合併跨行 %mor/%wor/%gra
38
+ merged, buf = [], None
39
+ for raw in block:
40
+ ln = raw.rstrip("\n").replace("\x15", "")
41
+ if ln.lstrip().startswith("%") and ":" in ln:
42
+ if buf: merged.append(buf)
43
+ buf = ln
44
+ else:
45
+ if buf and ln.strip(): buf += " " + ln.strip()
46
+ else: merged.append(ln)
47
+ if buf: merged.append(buf)
48
+ return "\n".join(merged)
49
+
50
+ # ────────── 主轉換 ──────────
51
+ def cha_to_json(lines):
52
+ pos_map = defaultdict(lambda: len(pos_map) + 1)
53
+ gra_map = defaultdict(lambda: len(gra_map) + 1)
54
+ aphasia_map = defaultdict(lambda: len(aphasia_map))
55
+
56
+ data, sent, i = [], None, 0
57
+ while i < len(lines):
58
+ line = lines[i]
59
+
60
+ # --- 標頭 / 結尾 ---
61
+ if line.startswith("@UTF8"):
62
+ sent = {"sentence_id": f"S{len(data)+1}",
63
+ "sentence_pid": None,
64
+ "aphasia_type": None,
65
+ "dialogues": []}
66
+ i += 1; continue
67
+ if line.startswith("@End"):
68
+ if sent and sent["aphasia_type"] and sent["dialogues"]:
69
+ data.append(sent)
70
+ sent = None; i += 1; continue
71
+
72
+ # --- 句子屬性 ---
73
+ if sent and line.startswith("@PID:"):
74
+ parts = line.split("\t")
75
+ if len(parts) > 1:
76
+ sent["sentence_pid"] = parts[1].strip()
77
+ i += 1; continue
78
+ if sent and line.startswith("@ID:") and "|PAR|" in line:
79
+ aph = line.split("|")[5].strip().upper()
80
+ aphasia_map[aph]
81
+ sent["aphasia_type"] = aph
82
+ i += 1; continue
83
+
84
+ # --- 對話行 ---
85
+ if sent and (line.startswith("*INV:") or line.startswith("*PAR:")):
86
+ role = "INV" if line.startswith("*INV:") else "PAR"
87
+ if not sent["dialogues"]:
88
+ sent["dialogues"].append({"INV": [], "PAR": []})
89
+ if role == "INV" and sent["dialogues"][-1]["PAR"]:
90
+ sent["dialogues"].append({"INV": [], "PAR": []})
91
+ sent["dialogues"][-1][role].append(
92
+ {"tokens": [], "word_pos_ids": [], "word_grammar_ids": [], "word_durations": []})
93
+ i += 1; continue
94
+
95
+ # --- %mor ---
96
+ if sent and line.startswith("%mor:"):
97
+ blk = [line]; i += 1
98
+ while i < len(lines) and not lines[i].lstrip().startswith(TAG_PREFIXES):
99
+ blk.append(lines[i]); i += 1
100
+ units = merge_multiline(blk).replace("%mor:", "").strip().split()
101
+
102
+ toks, pos_ids = [], []
103
+ for u in units:
104
+ if "|" in u:
105
+ pos, rest = u.split("|", 1)
106
+ toks.append(rest.split("|", 1)[0])
107
+ pos_ids.append(pos_map[pos])
108
+
109
+ dlg = sent["dialogues"][-1]
110
+ tgt = dlg["PAR"][-1] if dlg["PAR"] else dlg["INV"][-1]
111
+ tgt["tokens"], tgt["word_pos_ids"] = toks, pos_ids
112
+ continue
113
+
114
+ # --- %wor ---
115
+ if sent and line.startswith("%wor:"):
116
+ blk = [line]; i += 1
117
+ while i < len(lines) and not lines[i].lstrip().startswith(TAG_PREFIXES):
118
+ blk.append(lines[i]); i += 1
119
+ merged = merge_multiline(blk).replace("%wor:", "").strip()
120
+ raw = re.findall(r"(\S+)\s+(\d+)\D+(\d+)", merged)
121
+ wor = [(w, int(e)-int(s)) for w,s,e in raw]
122
+
123
+ dlg = sent["dialogues"][-1]
124
+ tgt = dlg["PAR"][-1] if dlg["PAR"] else dlg["INV"][-1]
125
+
126
+ aligned, j = [], 0
127
+ for tok in tgt["tokens"]:
128
+ c_tok = canonical(tok); match = None
129
+ for k in range(j, len(wor)):
130
+ c_w = canonical(wor[k][0])
131
+ if (c_tok == c_w or c_w.startswith(c_tok) or c_tok.startswith(c_w)
132
+ or same_syn(c_tok, c_w)):
133
+ match = wor[k]; j = k+1; break
134
+ aligned.append([tok, match[1] if match else 0])
135
+ tgt["word_durations"] = aligned
136
+ continue
137
+
138
+ # --- %gra ---
139
+ if sent and line.startswith("%gra:"):
140
+ blk = [line]; i += 1
141
+ while i < len(lines) and not lines[i].lstrip().startswith(TAG_PREFIXES):
142
+ blk.append(lines[i]); i += 1
143
+ units = merge_multiline(blk).replace("%gra:", "").strip().split()
144
+
145
+ triples = []
146
+ for u in units:
147
+ a,b,r = u.split("|")
148
+ if a.isdigit() and b.isdigit():
149
+ triples.append([int(a), int(b), gra_map[r]])
150
+
151
+ dlg = sent["dialogues"][-1]
152
+ (dlg["PAR"][-1] if dlg["PAR"] else dlg["INV"][-1])["word_grammar_ids"] = triples
153
+ continue
154
+
155
+ i += 1 # 其他行
156
+
157
+ return {"sentences": data,
158
+ "pos_mapping": dict(pos_map),
159
+ "grammar_mapping": dict(gra_map),
160
+ "aphasia_types": dict(aphasia_map)}
161
+
162
+ # ────────── 執行 ──────────
163
+ def main():
164
+ in_path = Path(INPUT_CHA)
165
+ out_path = Path(OUTPUT_JSON)
166
+
167
+ if not in_path.exists():
168
+ sys.exit(f"❌ 找不到檔案: {in_path}")
169
+
170
+ with in_path.open("r", encoding="utf-8") as fh:
171
+ lines = fh.readlines()
172
+
173
+ dataset = cha_to_json(lines)
174
+ out_path.parent.mkdir(parents=True, exist_ok=True)
175
+ with out_path.open("w", encoding="utf-8") as fh:
176
+ json.dump(dataset, fh, ensure_ascii=False, indent=4)
177
+
178
+ print(f"✅ 轉換完成 → {out_path}")
179
+
180
+ if __name__ == "__main__":
181
+ main()
Json__Output.py ADDED
@@ -0,0 +1,896 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ 失語症分類推理系統
4
+ 用於載入訓練好的模型並對新的語音數據進行分類預測
5
+ """
6
+
7
+ import json
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import numpy as np
12
+ import os
13
+ import math
14
+ from typing import Dict, List, Optional, Tuple
15
+ from dataclasses import dataclass
16
+ import pandas as pd
17
+ from transformers import AutoTokenizer, AutoModel
18
+ from collections import defaultdict
19
+
20
+ # 重新定義模型結構(與訓練程式碼一致)
21
@dataclass
class ModelConfig:
    """Hyper-parameters for StableAphasiaClassifier.

    Values must match those used during training, otherwise the saved
    state dict will not load.
    """
    model_name: str = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
    max_length: int = 512            # BERT input length (tokens)
    hidden_size: int = 768
    pos_vocab_size: int = 150        # size of the POS-id embedding table
    pos_emb_dim: int = 64
    grammar_dim: int = 3             # each grammar item is a (head, dep, rel) triple
    grammar_hidden_dim: int = 64
    duration_hidden_dim: int = 128
    prosody_dim: int = 32
    num_attention_heads: int = 8
    attention_dropout: float = 0.3
    classifier_hidden_dims: List[int] = None  # filled in __post_init__ (mutable default)
    dropout_rate: float = 0.3

    def __post_init__(self):
        # Default MLP head sizes; set here because a list cannot be a
        # dataclass field default.
        if self.classifier_hidden_dims is None:
            self.classifier_hidden_dims = [512, 256]
40
+
41
class StablePositionalEncoding(nn.Module):
    """Additive positional encoding: fixed sinusoids plus a small learnable term.

    Both components are scaled by 0.1 before being added to the input, so
    the encoding perturbs rather than dominates the BERT hidden states.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        self.d_model = d_model

        # Standard transformer sinusoidal table (max_len, d_model).
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        # Buffer (not a parameter): moves with the module but is not trained.
        self.register_buffer('pe', pe.unsqueeze(0))
        # Small random learnable offset per position.
        self.learnable_pe = nn.Parameter(torch.randn(max_len, d_model) * 0.01)

    def forward(self, x):
        # x: (batch, seq_len, d_model) — assumed; TODO confirm at call sites.
        seq_len = x.size(1)
        sinusoidal = self.pe[:, :seq_len, :].to(x.device)
        learnable = self.learnable_pe[:seq_len, :].unsqueeze(0).expand(x.size(0), -1, -1)
        return x + 0.1 * (sinusoidal + learnable)
62
+
63
class StableMultiHeadAttention(nn.Module):
    """Multi-head self-attention block with residual connection and LayerNorm.

    Output is layer_norm(attention(x) + x); feature_dim must be divisible
    by num_heads.
    """

    def __init__(self, feature_dim: int, num_heads: int = 4, dropout: float = 0.3):
        super().__init__()
        self.num_heads = num_heads
        self.feature_dim = feature_dim
        self.head_dim = feature_dim // num_heads

        assert feature_dim % num_heads == 0

        self.query = nn.Linear(feature_dim, feature_dim)
        self.key = nn.Linear(feature_dim, feature_dim)
        self.value = nn.Linear(feature_dim, feature_dim)
        self.dropout = nn.Dropout(dropout)
        self.output_proj = nn.Linear(feature_dim, feature_dim)
        self.layer_norm = nn.LayerNorm(feature_dim)

    def forward(self, x, mask=None):
        # x: (batch, seq_len, feature_dim); mask: optional padding mask where
        # 0 marks positions to ignore (2-D masks are broadcast over heads).
        batch_size, seq_len, _ = x.size()

        # Project and reshape to (batch, heads, seq_len, head_dim).
        Q = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.key(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.value(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention scores.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if mask is not None:
            if mask.dim() == 2:
                mask = mask.unsqueeze(1).unsqueeze(1)
            # In-place fill: masked positions get a large negative logit.
            scores.masked_fill_(mask == 0, -1e9)

        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context = torch.matmul(attn_weights, V)
        # Merge heads back into (batch, seq_len, feature_dim).
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.feature_dim)

        output = self.output_proj(context)
        # Residual + LayerNorm.
        return self.layer_norm(output + x)
101
+
102
class StableLinguisticFeatureExtractor(nn.Module):
    """Encode POS, grammar, duration and prosody streams into one pooled vector.

    Each stream is projected independently, concatenated along the feature
    axis, fused by a small MLP, then mean-pooled over valid (masked) steps.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        # POS ids -> embeddings (+ self-attention over the sequence).
        self.pos_embedding = nn.Embedding(config.pos_vocab_size, config.pos_emb_dim, padding_idx=0)
        self.pos_attention = StableMultiHeadAttention(config.pos_emb_dim, num_heads=4)

        # Grammar triples (3 ints) -> hidden vector.
        self.grammar_projection = nn.Sequential(
            nn.Linear(config.grammar_dim, config.grammar_hidden_dim),
            nn.Tanh(),
            nn.LayerNorm(config.grammar_hidden_dim),
            nn.Dropout(config.dropout_rate * 0.3)
        )

        # Scalar duration -> hidden vector.
        self.duration_projection = nn.Sequential(
            nn.Linear(1, config.duration_hidden_dim),
            nn.Tanh(),
            nn.LayerNorm(config.duration_hidden_dim)
        )

        # Prosody vector -> same-size projection.
        self.prosody_projection = nn.Sequential(
            nn.Linear(config.prosody_dim, config.prosody_dim),
            nn.ReLU(),
            nn.LayerNorm(config.prosody_dim)
        )

        total_feature_dim = (config.pos_emb_dim + config.grammar_hidden_dim +
                             config.duration_hidden_dim + config.prosody_dim)
        # Fuse the concatenated streams and halve the dimensionality.
        self.feature_fusion = nn.Sequential(
            nn.Linear(total_feature_dim, total_feature_dim // 2),
            nn.Tanh(),
            nn.LayerNorm(total_feature_dim // 2),
            nn.Dropout(config.dropout_rate)
        )

    def forward(self, pos_ids, grammar_ids, durations, prosody_features, attention_mask):
        # All sequence inputs are assumed (batch, seq_len, ...) aligned with
        # attention_mask — TODO confirm against the data pipeline.
        batch_size, seq_len = pos_ids.size()

        # Clamp out-of-vocabulary POS ids to keep the embedding lookup safe.
        pos_ids_clamped = pos_ids.clamp(0, self.config.pos_vocab_size - 1)
        pos_embeds = self.pos_embedding(pos_ids_clamped)
        pos_features = self.pos_attention(pos_embeds, attention_mask)

        grammar_features = self.grammar_projection(grammar_ids.float())
        duration_features = self.duration_projection(durations.unsqueeze(-1).float())
        prosody_features = self.prosody_projection(prosody_features.float())

        combined_features = torch.cat([
            pos_features, grammar_features, duration_features, prosody_features
        ], dim=-1)

        fused_features = self.feature_fusion(combined_features)

        # Masked mean pooling over the sequence dimension.
        mask_expanded = attention_mask.unsqueeze(-1).float()
        pooled_features = torch.sum(fused_features * mask_expanded, dim=1) / torch.sum(mask_expanded, dim=1)

        return pooled_features
159
+
160
class StableAphasiaClassifier(nn.Module):
    """BERT + linguistic-feature classifier for aphasia type.

    Produces class logits plus two auxiliary heads: a 4-way severity
    distribution and a scalar fluency score in [0, 1].
    """

    def __init__(self, config: ModelConfig, num_labels: int):
        super().__init__()
        self.config = config
        self.num_labels = num_labels

        self.bert = AutoModel.from_pretrained(config.model_name)
        self.bert_config = self.bert.config

        self.positional_encoder = StablePositionalEncoding(
            d_model=self.bert_config.hidden_size,
            max_len=config.max_length
        )

        self.linguistic_extractor = StableLinguisticFeatureExtractor(config)

        bert_dim = self.bert_config.hidden_size
        # Must equal the output size of StableLinguisticFeatureExtractor
        # (its fusion layer halves the concatenated dimension).
        linguistic_dim = (config.pos_emb_dim + config.grammar_hidden_dim +
                          config.duration_hidden_dim + config.prosody_dim) // 2

        self.feature_fusion = nn.Sequential(
            nn.Linear(bert_dim + linguistic_dim, bert_dim),
            nn.LayerNorm(bert_dim),
            nn.Tanh(),
            nn.Dropout(config.dropout_rate)
        )

        self.classifier = self._build_classifier(bert_dim, num_labels)

        # Auxiliary head: probability over 4 severity levels.
        self.severity_head = nn.Sequential(
            nn.Linear(bert_dim, 4),
            nn.Softmax(dim=-1)
        )

        # Auxiliary head: fluency score in [0, 1].
        self.fluency_head = nn.Sequential(
            nn.Linear(bert_dim, 1),
            nn.Sigmoid()
        )

    def _build_classifier(self, input_dim: int, num_labels: int):
        """Stack Linear/LayerNorm/Tanh/Dropout blocks per classifier_hidden_dims."""
        layers = []
        current_dim = input_dim

        for hidden_dim in self.config.classifier_hidden_dims:
            layers.extend([
                nn.Linear(current_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.Tanh(),
                nn.Dropout(self.config.dropout_rate)
            ])
            current_dim = hidden_dim

        layers.append(nn.Linear(current_dim, num_labels))
        return nn.Sequential(*layers)

    def forward(self, input_ids, attention_mask, labels=None,
                word_pos_ids=None, word_grammar_ids=None, word_durations=None,
                prosody_features=None, **kwargs):
        # labels is accepted for API compatibility but no loss is computed
        # here ("loss" is always None in the returned dict).

        bert_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = bert_outputs.last_hidden_state

        position_enhanced = self.positional_encoder(sequence_output)
        pooled_output = self._attention_pooling(position_enhanced, attention_mask)

        if all(x is not None for x in [word_pos_ids, word_grammar_ids, word_durations]):
            # Default prosody to zeros when the caller provides none.
            if prosody_features is None:
                batch_size, seq_len = input_ids.size()
                prosody_features = torch.zeros(
                    batch_size, seq_len, self.config.prosody_dim,
                    device=input_ids.device
                )

            linguistic_features = self.linguistic_extractor(
                word_pos_ids, word_grammar_ids, word_durations,
                prosody_features, attention_mask
            )
        else:
            # No linguistic inputs: substitute a zero vector of matching size.
            linguistic_features = torch.zeros(
                input_ids.size(0),
                (self.config.pos_emb_dim + self.config.grammar_hidden_dim +
                 self.config.duration_hidden_dim + self.config.prosody_dim) // 2,
                device=input_ids.device
            )

        combined_features = torch.cat([pooled_output, linguistic_features], dim=1)
        fused_features = self.feature_fusion(combined_features)

        logits = self.classifier(fused_features)
        severity_pred = self.severity_head(fused_features)
        fluency_pred = self.fluency_head(fused_features)

        return {
            "logits": logits,
            "severity_pred": severity_pred,
            "fluency_pred": fluency_pred,
            "loss": None
        }

    def _attention_pooling(self, sequence_output, attention_mask):
        """Pool (batch, seq, dim) to (batch, dim) with content-derived weights.

        Weights come from a softmax over the per-position feature sums,
        re-masked and re-normalized so padding contributes nothing.
        """
        attention_weights = torch.softmax(
            torch.sum(sequence_output, dim=-1, keepdim=True), dim=1
        )
        attention_weights = attention_weights * attention_mask.unsqueeze(-1).float()
        # +1e-9 guards against division by zero on an all-masked row.
        attention_weights = attention_weights / (torch.sum(attention_weights, dim=1, keepdim=True) + 1e-9)
        pooled = torch.sum(sequence_output * attention_weights, dim=1)
        return pooled
267
+
268
+
269
+ class AphasiaInferenceSystem:
270
+ """失語症分類推理系統"""
271
+
272
+ def __init__(self, model_dir: str):
273
+ """
274
+ 初始化推理系統
275
+ Args:
276
+ model_dir: 訓練好的模型目錄路徑
277
+ """
278
+ self.model_dir = '/workspace/SH001/adaptive_aphasia_model'
279
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
280
+
281
+ # 失語症類型描述
282
+ self.aphasia_descriptions = {
283
+ "BROCA": {
284
+ "name": "Broca's Aphasia (Non-fluent)",
285
+ "description": "Characterized by limited speech output, difficulty with grammar and sentence formation, but relatively preserved comprehension. Speech is typically effortful and halting.",
286
+ "features": ["Non-fluent speech", "Preserved comprehension", "Grammar difficulties", "Word-finding problems"]
287
+ },
288
+ "TRANSMOTOR": {
289
+ "name": "Trans-cortical Motor Aphasia",
290
+ "description": "Similar to Broca's aphasia but with preserved repetition abilities. Speech is non-fluent with good comprehension.",
291
+ "features": ["Non-fluent speech", "Good repetition", "Preserved comprehension", "Grammar difficulties"]
292
+ },
293
+ "NOTAPHASICBYWAB": {
294
+ "name": "Not Aphasic by WAB",
295
+ "description": "Individuals who do not meet the criteria for aphasia according to the Western Aphasia Battery assessment.",
296
+ "features": ["Normal language function", "No significant language impairment", "Good comprehension", "Fluent speech"]
297
+ },
298
+ "CONDUCTION": {
299
+ "name": "Conduction Aphasia",
300
+ "description": "Characterized by fluent speech with good comprehension but severely impaired repetition. Often involves phonemic paraphasias.",
301
+ "features": ["Fluent speech", "Good comprehension", "Poor repetition", "Phonemic errors"]
302
+ },
303
+ "WERNICKE": {
304
+ "name": "Wernicke's Aphasia (Fluent)",
305
+ "description": "Fluent but often meaningless speech with poor comprehension. Speech may contain neologisms and jargon.",
306
+ "features": ["Fluent speech", "Poor comprehension", "Jargon speech", "Neologisms"]
307
+ },
308
+ "ANOMIC": {
309
+ "name": "Anomic Aphasia",
310
+ "description": "Primarily characterized by word-finding difficulties with otherwise relatively preserved language abilities.",
311
+ "features": ["Word-finding difficulties", "Good comprehension", "Fluent speech", "Circumlocution"]
312
+ },
313
+ "GLOBAL": {
314
+ "name": "Global Aphasia",
315
+ "description": "Severe impairment in all language modalities - comprehension, production, repetition, and naming.",
316
+ "features": ["Severe comprehension deficit", "Non-fluent speech", "Poor repetition", "Severe naming difficulties"]
317
+ },
318
+ "ISOLATION": {
319
+ "name": "Isolation Syndrome",
320
+ "description": "Rare condition with preserved repetition but severely impaired comprehension and spontaneous speech.",
321
+ "features": ["Good repetition", "Poor comprehension", "Limited spontaneous speech", "Echolalia"]
322
+ },
323
+ "TRANSSENSORY": {
324
+ "name": "Trans-cortical Sensory Aphasia",
325
+ "description": "Fluent speech with good repetition but impaired comprehension, similar to Wernicke's but with preserved repetition.",
326
+ "features": ["Fluent speech", "Good repetition", "Poor comprehension", "Semantic errors"]
327
+ }
328
+ }
329
+
330
+ # 載入模型配置和映射
331
+ self.load_configuration()
332
+
333
+ # 載入模型
334
+ self.load_model()
335
+
336
+ print(f"推理系統初始化完成,使用設備: {self.device}")
337
+
338
+ def load_configuration(self):
339
+ """載入模型配置"""
340
+ config_path = os.path.join(self.model_dir, "config.json")
341
+ if os.path.exists(config_path):
342
+ with open(config_path, "r", encoding="utf-8") as f:
343
+ config_data = json.load(f)
344
+
345
+ self.aphasia_types_mapping = config_data.get("aphasia_types_mapping", {
346
+ "BROCA": 0, "TRANSMOTOR": 1, "NOTAPHASICBYWAB": 2,
347
+ "CONDUCTION": 3, "WERNICKE": 4, "ANOMIC": 5,
348
+ "GLOBAL": 6, "ISOLATION": 7, "TRANSSENSORY": 8
349
+ })
350
+ self.num_labels = config_data.get("num_labels", 9)
351
+ self.model_name = config_data.get("model_name", "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext")
352
+ else:
353
+ # 預設配置
354
+ self.aphasia_types_mapping = {
355
+ "BROCA": 0, "TRANSMOTOR": 1, "NOTAPHASICBYWAB": 2,
356
+ "CONDUCTION": 3, "WERNICKE": 4, "ANOMIC": 5,
357
+ "GLOBAL": 6, "ISOLATION": 7, "TRANSSENSORY": 8
358
+ }
359
+ self.num_labels = 9
360
+ self.model_name = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
361
+
362
+ # 建立反向映射
363
+ self.id_to_aphasia_type = {v: k for k, v in self.aphasia_types_mapping.items()}
364
+
365
+ def load_model(self):
366
+ """載入訓練好的模型"""
367
+ # 載入tokenizer
368
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
369
+ if self.tokenizer.pad_token is None:
370
+ self.tokenizer.pad_token = self.tokenizer.eos_token
371
+ added_tokens_path = os.path.join(self.model_dir, "added_tokens.json")
372
+ if os.path.exists(added_tokens_path):
373
+ with open(added_tokens_path, "r", encoding="utf-8") as f:
374
+ data = json.load(f)
375
+ # 如果是 dict,就取出所有 key 當作要新增的 token 清單
376
+ if isinstance(data, dict):
377
+ tokens = list(data.keys())
378
+ else:
379
+ tokens = data # 萬一已經是 list,就直接用
380
+ num_added = self.tokenizer.add_tokens(tokens)
381
+ print(f"新增到 tokenizer 的 token 數量: {num_added}")
382
+ # 建立模型配置
383
+ self.config = ModelConfig()
384
+ self.config.model_name = self.model_name
385
+
386
+ # 建立模型
387
+ self.model = StableAphasiaClassifier(self.config, self.num_labels)
388
+ self.model.bert.resize_token_embeddings(len(self.tokenizer))
389
+ # 載入模型權重
390
+ model_path = os.path.join(self.model_dir, "pytorch_model.bin")
391
+ if os.path.exists(model_path):
392
+ state_dict = torch.load(model_path, map_location=self.device)
393
+ self.model.load_state_dict(state_dict)
394
+ self.model.load_state_dict(state_dict)
395
+ print("模型權重載入成功")
396
+ else:
397
+ raise FileNotFoundError(f"模型權重文件不存在: {model_path}")
398
+
399
+ # 調整tokenizer尺寸
400
+ self.model.bert.resize_token_embeddings(len(self.tokenizer))
401
+
402
+ # 移動到設備並設置為評估模式
403
+ self.model.to(self.device)
404
+ self.model.eval()
405
+
406
    def preprocess_sentence(self, sentence_data: dict) -> Optional[dict]:
        """Preprocess one sentence record into model-ready tensors.

        Flattens all PAR utterances (dialogues separated by a "[DIALOGUE]"
        marker token), tokenizes the joined text, aligns the word-level
        features to BERT sub-tokens, and builds the prosody tensor.

        Returns:
            A dict of tensors plus bookkeeping fields, or None when the
            record contains no PAR tokens.
        """
        all_tokens, all_pos, all_grammar, all_durations = [], [], [], []

        # Walk the dialogue turns and flatten them into parallel lists.
        for dialogue_idx, dialogue in enumerate(sentence_data.get("dialogues", [])):
            if dialogue_idx > 0:
                # Separator between consecutive dialogues, with neutral features.
                all_tokens.append("[DIALOGUE]")
                all_pos.append(0)
                all_grammar.append([0, 0, 0])
                all_durations.append(0.0)

            # Only the participant's (PAR) speech is used; INV is ignored here.
            for par in dialogue.get("PAR", []):
                if "tokens" in par and par["tokens"]:
                    tokens = par["tokens"]
                    # Missing feature lists default to neutral values.
                    pos_ids = par.get("word_pos_ids", [0] * len(tokens))
                    grammar_ids = par.get("word_grammar_ids", [[0, 0, 0]] * len(tokens))
                    durations = par.get("word_durations", [0.0] * len(tokens))

                    all_tokens.extend(tokens)
                    all_pos.extend(pos_ids)
                    all_grammar.extend(grammar_ids)
                    all_durations.extend(durations)

        if not all_tokens:
            return None

        # Text tokenization (padded/truncated to max_length).
        text = " ".join(all_tokens)
        encoded = self.tokenizer(
            text,
            max_length=self.config.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        # Align word-level features to BERT sub-token positions.
        aligned_pos, aligned_grammar, aligned_durations = self._align_features(
            all_tokens, all_pos, all_grammar, all_durations, encoded
        )

        # Build the prosody feature vector and tile it over every position.
        prosody_features = self._extract_prosodic_features(all_durations, all_tokens)
        prosody_tensor = torch.tensor(prosody_features).unsqueeze(0).repeat(
            self.config.max_length, 1
        )

        return {
            "input_ids": encoded["input_ids"].squeeze(0),
            "attention_mask": encoded["attention_mask"].squeeze(0),
            "word_pos_ids": torch.tensor(aligned_pos, dtype=torch.long),
            "word_grammar_ids": torch.tensor(aligned_grammar, dtype=torch.long),
            "word_durations": torch.tensor(aligned_durations, dtype=torch.float),
            "prosody_features": prosody_tensor.float(),
            "sentence_id": sentence_data.get("sentence_id", "unknown"),
            "original_tokens": all_tokens,
            "text": text
        }
466
+
467
    def _align_features(self, tokens, pos_ids, grammar_ids, durations, encoded):
        """Align word-level features to BERT sub-token positions.

        Each word's features are repeated once per sub-token it produces.
        Positions 0 ([CLS]) and max_length-1 ([SEP]), and any slack beyond
        the real sub-tokens, get neutral values (0 / [0,0,0] / 0.0).

        Returns:
            (aligned_pos, aligned_grammar, aligned_durations), each of
            length self.config.max_length.
        """
        # Map every sub-token index back to the index of its source word.
        subtoken_to_token = []

        for token_idx, token in enumerate(tokens):
            subtokens = self.tokenizer.tokenize(token)
            subtoken_to_token.extend([token_idx] * len(subtokens))

        aligned_pos = [0]  # [CLS]
        aligned_grammar = [[0, 0, 0]]  # [CLS]
        aligned_durations = [0.0]  # [CLS]

        for subtoken_idx in range(1, self.config.max_length - 1):
            if subtoken_idx - 1 < len(subtoken_to_token):
                original_idx = subtoken_to_token[subtoken_idx - 1]
                aligned_pos.append(pos_ids[original_idx] if original_idx < len(pos_ids) else 0)
                aligned_grammar.append(grammar_ids[original_idx] if original_idx < len(grammar_ids) else [0, 0, 0])

                # Duration entries may be scalars or [start, end] pairs
                # (the .cha converter emits [token, duration] pairs, whose
                # difference is taken the same way — TODO confirm intent).
                raw_duration = durations[original_idx] if original_idx < len(durations) else 0.0
                if isinstance(raw_duration, list) and len(raw_duration) >= 2:
                    try:
                        duration_val = float(raw_duration[1]) - float(raw_duration[0])
                    except (ValueError, TypeError):
                        duration_val = 0.0
                elif isinstance(raw_duration, (int, float)):
                    duration_val = float(raw_duration)
                else:
                    duration_val = 0.0

                aligned_durations.append(duration_val)
            else:
                # Padding region beyond the real sub-tokens.
                aligned_pos.append(0)
                aligned_grammar.append([0, 0, 0])
                aligned_durations.append(0.0)

        aligned_pos.append(0)  # [SEP]
        aligned_grammar.append([0, 0, 0])  # [SEP]
        aligned_durations.append(0.0)  # [SEP]

        return aligned_pos, aligned_grammar, aligned_durations
508
+
509
+ def _extract_prosodic_features(self, durations, tokens):
510
+ """提取韻律特徵"""
511
+ if not durations:
512
+ return [0.0] * self.config.prosody_dim
513
+
514
+ # 處理duration數據並提取數值
515
+ processed_durations = []
516
+ for d in durations:
517
+ if isinstance(d, list) and len(d) >= 2:
518
+ try:
519
+ processed_durations.append(float(d[1]) - float(d[0]))
520
+ except (ValueError, TypeError):
521
+ continue
522
+ elif isinstance(d, (int, float)):
523
+ processed_durations.append(float(d))
524
+
525
+ if not processed_durations:
526
+ return [0.0] * self.config.prosody_dim
527
+
528
+ # 計算基本統計特徵
529
+ features = [
530
+ np.mean(processed_durations),
531
+ np.std(processed_durations),
532
+ np.median(processed_durations),
533
+ len([d for d in processed_durations if d > np.mean(processed_durations) * 1.5])
534
+ ]
535
+
536
+ # 填充至所需維度
537
+ while len(features) < self.config.prosody_dim:
538
+ features.append(0.0)
539
+
540
+ return features[:self.config.prosody_dim]
541
+
542
    def predict_single(self, sentence_data: dict) -> dict:
        """Run inference on one sentence record and build a result report.

        Returns:
            A dict with the predicted class, per-class probabilities,
            class description, and auxiliary severity/fluency outputs;
            or {"error": ..., "sentence_id": ...} when preprocessing fails.
        """
        # Preprocess into model-ready tensors.
        processed_data = self.preprocess_sentence(sentence_data)
        if processed_data is None:
            return {
                "error": "無法處理輸入數據",
                "sentence_id": sentence_data.get("sentence_id", "unknown")
            }

        # Add the batch dimension and move everything to the model's device.
        input_data = {
            "input_ids": processed_data["input_ids"].unsqueeze(0).to(self.device),
            "attention_mask": processed_data["attention_mask"].unsqueeze(0).to(self.device),
            "word_pos_ids": processed_data["word_pos_ids"].unsqueeze(0).to(self.device),
            "word_grammar_ids": processed_data["word_grammar_ids"].unsqueeze(0).to(self.device),
            "word_durations": processed_data["word_durations"].unsqueeze(0).to(self.device),
            "prosody_features": processed_data["prosody_features"].unsqueeze(0).to(self.device)
        }

        # Forward pass without gradients.
        with torch.no_grad():
            outputs = self.model(**input_data)

        logits = outputs["logits"]
        probabilities = F.softmax(logits, dim=1).cpu().numpy()[0]
        predicted_class_id = np.argmax(probabilities)

        severity_pred = outputs["severity_pred"].cpu().numpy()[0]
        fluency_pred = outputs["fluency_pred"].cpu().numpy()[0][0]

        # Decode the winning class and its confidence.
        predicted_type = self.id_to_aphasia_type[predicted_class_id]
        confidence = float(probabilities[predicted_class_id])

        # Per-class probability breakdown.
        probability_distribution = {}
        for aphasia_type, type_id in self.aphasia_types_mapping.items():
            probability_distribution[aphasia_type] = {
                "probability": float(probabilities[type_id]),
                "percentage": f"{probabilities[type_id]*100:.2f}%"
            }

        # Sort classes by descending probability for readability.
        sorted_probabilities = sorted(
            probability_distribution.items(),
            key=lambda x: x[1]["probability"],
            reverse=True
        )

        result = {
            "sentence_id": processed_data["sentence_id"],
            "input_text": processed_data["text"],
            "original_tokens": processed_data["original_tokens"],
            "prediction": {
                "predicted_class": predicted_type,
                "confidence": confidence,
                "confidence_percentage": f"{confidence*100:.2f}%"
            },
            "class_description": self.aphasia_descriptions.get(predicted_type, {
                "name": predicted_type,
                "description": "Description not available",
                "features": []
            }),
            "probability_distribution": dict(sorted_probabilities),
            "additional_predictions": {
                "severity_distribution": {
                    "level_0": float(severity_pred[0]),
                    "level_1": float(severity_pred[1]),
                    "level_2": float(severity_pred[2]),
                    "level_3": float(severity_pred[3])
                },
                "predicted_severity_level": int(np.argmax(severity_pred)),
                "fluency_score": float(fluency_pred),
                # Thresholds 0.7 / 0.4 split the sigmoid output into three bands.
                "fluency_rating": "High" if fluency_pred > 0.7 else "Medium" if fluency_pred > 0.4 else "Low"
            }
        }

        return result
621
+
622
    def predict_batch(self, input_file: str, output_file: str = None) -> dict:
        """Predict every sentence in a JSON file and optionally save the results.

        Args:
            input_file: Path to a JSON file with a top-level "sentences" list.
            output_file: Optional path to write the aggregated results to.

        Returns:
            A dict with "summary", "total_sentences" and "predictions"
            (note: a dict, not a list — the original annotation was wrong).
        """
        # Load the input file.
        with open(input_file, "r", encoding="utf-8") as f:
            data = json.load(f)

        sentences = data.get("sentences", [])
        results = []

        print(f"開始處理 {len(sentences)} 個句子...")

        for i, sentence in enumerate(sentences):
            print(f"處理第 {i+1}/{len(sentences)} 個句子...")
            result = self.predict_single(sentence)
            results.append(result)

        # Aggregate statistics over all predictions.
        summary = self._generate_summary(results)

        final_output = {
            "summary": summary,
            "total_sentences": len(results),
            "predictions": results
        }

        # Persist if an output path was given.
        if output_file:
            with open(output_file, "w", encoding="utf-8") as f:
                json.dump(final_output, f, ensure_ascii=False, indent=2)
            print(f"結果已保存到: {output_file}")

        return final_output
654
+
655
+ def _generate_summary(self, results: List[dict]) -> dict:
656
+ """生成預測結果摘要"""
657
+ if not results:
658
+ return {}
659
+
660
+ # 統計各類別預測數量
661
+ class_counts = defaultdict(int)
662
+ confidence_scores = []
663
+ fluency_scores = []
664
+ severity_levels = defaultdict(int)
665
+
666
+ for result in results:
667
+ if "error" not in result:
668
+ predicted_class = result["prediction"]["predicted_class"]
669
+ confidence = result["prediction"]["confidence"]
670
+ fluency = result["additional_predictions"]["fluency_score"]
671
+ severity = result["additional_predictions"]["predicted_severity_level"]
672
+
673
+ class_counts[predicted_class] += 1
674
+ confidence_scores.append(confidence)
675
+ fluency_scores.append(fluency)
676
+ severity_levels[severity] += 1
677
+
678
+ # 計算統計數據
679
+ avg_confidence = np.mean(confidence_scores) if confidence_scores else 0
680
+ avg_fluency = np.mean(fluency_scores) if fluency_scores else 0
681
+
682
+ summary = {
683
+ "classification_distribution": dict(class_counts),
684
+ "classification_percentages": {
685
+ k: f"{v/len(results)*100:.1f}%"
686
+ for k, v in class_counts.items()
687
+ },
688
+ "average_confidence": f"{avg_confidence:.3f}",
689
+ "average_fluency_score": f"{avg_fluency:.3f}",
690
+ "severity_distribution": dict(severity_levels),
691
+ "confidence_statistics": {
692
+ "mean": f"{np.mean(confidence_scores):.3f}",
693
+ "std": f"{np.std(confidence_scores):.3f}",
694
+ "min": f"{np.min(confidence_scores):.3f}",
695
+ "max": f"{np.max(confidence_scores):.3f}"
696
+ } if confidence_scores else {},
697
+ "most_common_prediction": max(class_counts.items(), key=lambda x: x[1])[0] if class_counts else "None"
698
+ }
699
+
700
+ return summary
701
+
702
+ def generate_detailed_report(self, results: List[dict], output_dir: str = "./inference_results"):
703
+ """生成詳細的分析報告"""
704
+ os.makedirs(output_dir, exist_ok=True)
705
+
706
+ # 建立詳細的CSV報告
707
+ report_data = []
708
+ for result in results:
709
+ if "error" not in result:
710
+ row = {
711
+ "sentence_id": result["sentence_id"],
712
+ "predicted_class": result["prediction"]["predicted_class"],
713
+ "confidence": result["prediction"]["confidence"],
714
+ "class_name": result["class_description"]["name"],
715
+ "severity_level": result["additional_predictions"]["predicted_severity_level"],
716
+ "fluency_score": result["additional_predictions"]["fluency_score"],
717
+ "fluency_rating": result["additional_predictions"]["fluency_rating"],
718
+ "input_text": result["input_text"]
719
+ }
720
+
721
+ # 添加各類別機率
722
+ for aphasia_type in self.aphasia_types_mapping.keys():
723
+ row[f"prob_{aphasia_type}"] = result["probability_distribution"][aphasia_type]["probability"]
724
+
725
+ report_data.append(row)
726
+
727
+ # 保存CSV
728
+ if report_data:
729
+ df = pd.DataFrame(report_data)
730
+ df.to_csv(os.path.join(output_dir, "detailed_predictions.csv"), index=False, encoding='utf-8')
731
+
732
+ # 生成統計摘要
733
+ summary_stats = {
734
+ "total_predictions": len(report_data),
735
+ "class_distribution": df["predicted_class"].value_counts().to_dict(),
736
+ "average_confidence": df["confidence"].mean(),
737
+ "confidence_std": df["confidence"].std(),
738
+ "average_fluency": df["fluency_score"].mean(),
739
+ "fluency_std": df["fluency_score"].std(),
740
+ "severity_distribution": df["severity_level"].value_counts().to_dict()
741
+ }
742
+
743
+ with open(os.path.join(output_dir, "summary_statistics.json"), "w", encoding="utf-8") as f:
744
+ json.dump(summary_stats, f, ensure_ascii=False, indent=2)
745
+
746
+ print(f"詳細報告已生成並保存到: {output_dir}")
747
+ return df
748
+
749
+ return None
750
+
751
+
752
def main():
    """Command-line entry point.

    Loads the trained model, runs batch inference over an input JSON file,
    optionally emits a detailed CSV report, and prints a summary to stdout.
    """
    import argparse

    parser = argparse.ArgumentParser(description="失語症分類推理系統")
    parser.add_argument("--model_dir", type=str, default='/workspace/SH001/adaptive_aphasia_model',
                        help="訓練好的模型目錄路徑")
    parser.add_argument("--input_file", type=str, default='/workspace/SH001/website/sample.input.json',
                        help="輸入JSON文件路徑")
    parser.add_argument("--output_file", type=str, default="./aphasia_predictions.json",
                        help="輸出JSON文件路徑")
    parser.add_argument("--report_dir", type=str, default="./inference_results",
                        help="詳細報告輸出目錄")
    parser.add_argument("--generate_report", action="store_true",
                        help="是否生成詳細的CSV報告")
    args = parser.parse_args()

    try:
        # Build the inference system from the trained model directory.
        print("正在初始化推理系統...")
        system = AphasiaInferenceSystem(args.model_dir)

        # Predict every sentence in the input file.
        print("開始執行批次預測...")
        batch_output = system.predict_batch(args.input_file, args.output_file)

        # Optional per-sentence CSV report.
        if args.generate_report:
            print("生成詳細報告...")
            system.generate_detailed_report(batch_output["predictions"], args.report_dir)

        summary = batch_output["summary"]
        print("\n=== 預測摘要 ===")
        print(f"總句子數: {batch_output['total_sentences']}")
        print(f"平均信心度: {summary.get('average_confidence', 'N/A')}")
        print(f"平均流利度: {summary.get('average_fluency_score', 'N/A')}")
        print(f"最常見預測: {summary.get('most_common_prediction', 'N/A')}")

        print("\n類別分佈:")
        distribution = summary.get("classification_distribution", {})
        percentages = summary.get("classification_percentages", {})
        for class_name, count in distribution.items():
            print(f" {class_name}: {count} ({percentages.get(class_name, '0%')})")

        print(f"\n結果已保存到: {args.output_file}")

    except Exception as e:
        print(f"錯誤: {str(e)}")
        import traceback
        traceback.print_exc()
803
+
804
+
805
+ # 使用範例
806
def example_usage():
    """Usage example: writes sample_input.json and prints CLI/API instructions."""

    # 1. Basic usage banner
    print("=== 失語症分類推理系統使用範例 ===\n")

    # Example input payload in the format expected by predict_batch.
    sample_input = {
        "sentences": [
            {
                "sentence_id": "S1",
                "aphasia_type": "BROCA",  # ignored at inference time
                "dialogues": [
                    {
                        "INV": [
                            {
                                "tokens": ["how", "are", "you", "feeling"],
                                "word_pos_ids": [9, 10, 5, 6],
                                "word_grammar_ids": [[1, 4, 11], [2, 4, 2], [3, 4, 1], [4, 0, 3]],
                                "word_durations": [["how", 300], ["are", 200], ["you", 150], ["feeling", 500]]
                            }
                        ],
                        "PAR": [
                            {
                                "tokens": ["I", "feel", "good"],
                                "word_pos_ids": [1, 6, 8],
                                "word_grammar_ids": [[1, 2, 1], [2, 3, 2], [3, 4, 8]],
                                "word_durations": [["I", 200], ["feel", 400], ["good", 600]]
                            }
                        ]
                    }
                ]
            }
        ]
    }

    # Save the example input next to the current working directory.
    with open("sample_input.json", "w", encoding="utf-8") as f:
        json.dump(sample_input, f, ensure_ascii=False, indent=2)

    print("範例輸入文件已創建: sample_input.json")

    # Show usage instructions (printed verbatim).
    usage_instructions = """
    使用方法:

    1. 命令行使用:
    python aphasia_inference.py \\
    --model_dir ./adaptive_aphasia_model \\
    --input_file sample_input.json \\
    --output_file predictions.json \\
    --generate_report \\
    --report_dir ./results

    2. Python代碼使用:
    from aphasia_inference import AphasiaInferenceSystem

    # 初始化系統
    system = AphasiaInferenceSystem("./adaptive_aphasia_model")

    # 單個預測
    with open("sample_input.json", "r") as f:
        data = json.load(f)
    result = system.predict_single(data["sentences"][0])

    # 批次預測
    results = system.predict_batch("sample_input.json", "output.json")

    3. 輸出格式:
    - JSON格式包含詳細的預測結果和機率分佈
    - CSV格式包含表格化的預測數據
    - 統計摘要包含整體分析結果

    4. 支援的失語症類型:
    - BROCA: 布若卡失語症
    - WERNICKE: 韋尼克失語症
    - ANOMIC: 命名性失語症
    - CONDUCTION: 傳導性失語症
    - GLOBAL: 全面性失語症
    - 以及其他類型...
    """

    print(usage_instructions)
889
+
890
+
891
if __name__ == "__main__":
    # Run the CLI entry point when executed as a script.
    main()

    # To see the usage example instead, uncomment the line below.
    # example_usage()
Output.json ADDED
The diff for this file is too large to render. See raw diff
 
README.md CHANGED
@@ -1,12 +1 @@
1
- ---
2
- title: Aphasia Classification
3
- emoji: 💬
4
- colorFrom: yellow
5
- colorTo: purple
6
- sdk: gradio
7
- sdk_version: 5.0.1
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
 
1
+ # Aphasia-Classifier
 
 
 
 
 
 
 
 
 
 
 
added_tokens.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "[DIALOGUE]": 30522,
3
+ "[HESITATION]": 30526,
4
+ "[PAUSE]": 30524,
5
+ "[REPEAT]": 30525,
6
+ "[TURN]": 30523
7
+ }
aphasia_class_2025_8_5--testing.py ADDED
@@ -0,0 +1,1712 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Advanced Multi-Modal Aphasia Classification System
4
+ With Adaptive Learning Rate and Comprehensive Reporting
5
+ """
6
+
7
+ import re
8
+ import json
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import time
13
+ import datetime
14
+ import numpy as np
15
+ import os
16
+ import random
17
+ import csv
18
+ import math
19
+ from collections import Counter, defaultdict
20
+ from typing import Dict, List, Optional, Tuple, Union
21
+ from dataclasses import dataclass
22
+
23
+ import torch.optim as optim
24
+ from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler, Subset
25
+ from transformers import (
26
+ AutoTokenizer, AutoModel, AutoConfig,
27
+ TrainingArguments, Trainer, TrainerCallback,
28
+ EarlyStoppingCallback, get_cosine_schedule_with_warmup,
29
+ default_data_collator, set_seed
30
+ )
31
+
32
+ import seaborn as sns
33
+ import matplotlib.pyplot as plt
34
+ import pandas as pd
35
+ from sklearn.metrics import (
36
+ accuracy_score, f1_score, precision_score, recall_score,
37
+ confusion_matrix, classification_report, roc_auc_score
38
+ )
39
+ from sklearn.model_selection import StratifiedKFold
40
+ import gc
41
+ from scipy import stats
42
+
43
+ # Environment setup for stability
44
+ os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
45
+ os.environ["TORCH_USE_CUDA_DSA"] = "1"
46
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
47
+ json_file = '/workspace/SH001/aphasia_data_augmented.json'
48
+
49
+ # Set seeds for reproducibility
50
def set_all_seeds(seed=42):
    """Seed every RNG source used by the pipeline for reproducibility.

    Covers the Python hash seed, the `random` module, NumPy, and PyTorch
    (CPU and all CUDA devices).
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
56
+
57
+ set_all_seeds(42)
58
+
59
+ # Configuration
60
@dataclass
class ModelConfig:
    """Hyper-parameters for the aphasia classifier and its training loop.

    Field order defines the generated __init__ signature; do not reorder.
    """
    # Model architecture
    model_name: str = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
    max_length: int = 512
    hidden_size: int = 768

    # Feature dimensions
    pos_vocab_size: int = 150
    pos_emb_dim: int = 64
    grammar_dim: int = 3
    grammar_hidden_dim: int = 64
    duration_hidden_dim: int = 128
    prosody_dim: int = 32

    # Multi-head attention
    num_attention_heads: int = 8
    attention_dropout: float = 0.3

    # Classification head
    classifier_hidden_dims: List[int] = None  # defaulted in __post_init__
    dropout_rate: float = 0.3
    activation_fn: str = "tanh"

    # Training
    learning_rate: float = 5e-4
    weight_decay: float = 0.01
    warmup_ratio: float = 0.1
    batch_size: int = 10
    num_epochs: int = 500
    gradient_accumulation_steps: int = 4

    # Adaptive learning-rate parameters
    adaptive_lr: bool = True
    lr_patience: int = 3  # Patience for learning rate adjustment
    lr_factor: float = 0.8  # Factor to multiply learning rate
    lr_increase_factor: float = 1.2  # Factor to increase learning rate
    min_lr: float = 1e-6
    max_lr: float = 1e-3
    oscillation_amplitude: float = 0.1  # For sinusoidal oscillation

    # Advanced techniques
    use_focal_loss: bool = True
    focal_alpha: float = 1.0
    focal_gamma: float = 2.0
    use_mixup: bool = False
    mixup_alpha: float = 0.2
    use_label_smoothing: bool = True
    label_smoothing: float = 0.1

    def __post_init__(self):
        # Default classifier MLP: two hidden layers.
        if self.classifier_hidden_dims is None:
            self.classifier_hidden_dims = [512, 256]
113
+
114
+ # Utility functions
115
def log_message(message):
    """Append a timestamped line to ./training_log.txt and echo it to stdout."""
    stamp = datetime.datetime.now().isoformat()
    line = f"{stamp}: {message}"
    with open("./training_log.txt", "a", encoding="utf-8") as fh:
        fh.write(line + "\n")
    print(line, flush=True)
122
+
123
def clear_memory():
    """Force a garbage-collection pass and, when CUDA is present, release
    PyTorch's cached GPU memory back to the driver."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
127
+
128
def normalize_type(t):
    """Canonicalize an aphasia-type label.

    Strings are stripped and upper-cased; any non-string value (e.g. None)
    is returned unchanged.
    """
    if not isinstance(t, str):
        return t
    return t.strip().upper()
130
+
131
+ # Adaptive Learning Rate Scheduler
132
class AdaptiveLearningRateScheduler:
    """Adaptive learning-rate scheduler combining several strategies.

    Mixes plateau detection (via recent-metric slopes), exponential and
    logarithmic adjustment factors, a sinusoidal oscillation term and a
    cosine decay term to compute a new learning rate each step, clamped to
    [config.min_lr, config.max_lr].
    """
    def __init__(self, optimizer, config: ModelConfig, total_steps: int):
        self.optimizer = optimizer
        self.config = config
        self.total_steps = total_steps

        # Metric histories, appended once per adaptive_lr_calculation call.
        self.loss_history = []
        self.f1_history = []
        self.accuracy_history = []
        self.lr_history = []

        # State tracking. NOTE(review): plateau_counter / best_f1 / best_loss
        # are initialized but never updated in this class — confirm intent.
        self.plateau_counter = 0
        self.best_f1 = 0.0
        self.best_loss = float('inf')
        self.step_count = 0

        # Starting learning rate.
        self.base_lr = config.learning_rate
        self.current_lr = self.base_lr

        log_message(f"Adaptive LR Scheduler initialized with base_lr={self.base_lr}")

    def calculate_slope(self, values, window=3):
        """Return the linear-regression slope of the last `window` values
        (0.0 when fewer than `window` values exist)."""
        if len(values) < window:
            return 0.0

        recent_values = values[-window:]
        x = np.arange(len(recent_values))
        slope, _, _, _, _ = stats.linregress(x, recent_values)
        return slope

    def exponential_adjustment(self, current_value, target_value, base_factor=1.1):
        """Exponential adjustment factor: base_factor * exp(-current/target)."""
        ratio = current_value / target_value if target_value != 0 else 1.0
        factor = math.exp(-ratio) * base_factor
        return factor

    def logarithmic_adjustment(self, current_value, threshold=0.1):
        """Logarithmic adjustment factor, clamped to [0.5, 2.0]."""
        if current_value <= 0:
            return 1.0
        factor = math.log(1 + current_value / threshold)
        return max(0.5, min(2.0, factor))

    def sinusoidal_oscillation(self, step, amplitude=None):
        """Sinusoidal multiplier around 1.0 with 4 cycles over total_steps."""
        if amplitude is None:
            amplitude = self.config.oscillation_amplitude

        # Step-based phase: 4 full periods across the whole training run.
        phase = 2 * math.pi * step / (self.total_steps / 4)
        oscillation = 1 + amplitude * math.sin(phase)
        return oscillation

    def cosine_decay(self, step):
        """Cosine decay from 1.0 down to 0.0 over total_steps."""
        progress = step / self.total_steps
        decay = 0.5 * (1 + math.cos(math.pi * progress))
        return decay

    def adaptive_lr_calculation(self, current_loss, current_f1, current_acc):
        """Compute, apply and return a new learning rate from latest metrics."""
        # Record history.
        self.loss_history.append(current_loss)
        self.f1_history.append(current_f1)
        self.accuracy_history.append(current_acc)

        # Recent slope of each metric. NOTE(review): acc_slope is computed
        # but not used in the adjustment below — confirm intent.
        loss_slope = self.calculate_slope(self.loss_history)
        f1_slope = self.calculate_slope(self.f1_history)
        acc_slope = self.calculate_slope(self.accuracy_history)

        # Base adjustment factor, multiplied by each triggered rule.
        adjustment_factor = 1.0

        # 1. Adjustment based on the loss slope.
        if abs(loss_slope) < 0.001:  # Loss plateau
            log_message(f"Loss plateau detected (slope: {loss_slope:.6f})")
            # Exponentially increase the learning rate.
            exp_factor = self.exponential_adjustment(abs(loss_slope), 0.01, 1.15)
            adjustment_factor *= exp_factor

        elif current_loss > 2.0:  # Loss too high
            log_message(f"High loss detected: {current_loss:.4f}")
            # Logarithmic adjustment.
            log_factor = self.logarithmic_adjustment(current_loss, 1.0)
            adjustment_factor *= log_factor

        # 2. Adjustment based on the F1 score.
        if current_f1 < 0.3:  # F1 too low
            log_message(f"Low F1 detected: {current_f1:.4f}")
            # Exponentially increase the learning rate.
            exp_factor = self.exponential_adjustment(0.3, current_f1, 1.2)
            adjustment_factor *= exp_factor

        elif abs(f1_slope) < 0.001:  # F1 plateau
            log_message(f"F1 plateau detected (slope: {f1_slope:.6f})")
            adjustment_factor *= 1.1

        # 3. Sinusoidal oscillation term.
        sin_factor = self.sinusoidal_oscillation(self.step_count)

        # 4. Cosine decay term.
        cos_factor = self.cosine_decay(self.step_count)

        # Combine: rule factor * oscillation * (partially decayed baseline).
        final_factor = adjustment_factor * sin_factor * (0.3 + 0.7 * cos_factor)

        # New learning rate, clamped to the configured range.
        new_lr = self.current_lr * final_factor
        new_lr = max(self.config.min_lr, min(self.config.max_lr, new_lr))

        # Apply only when the change is large enough to matter.
        if abs(new_lr - self.current_lr) > 1e-7:
            self.current_lr = new_lr
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = new_lr

            log_message(f"Learning rate adjusted: {new_lr:.2e} (factor: {final_factor:.3f})")
            log_message(f" - Loss slope: {loss_slope:.6f}, F1 slope: {f1_slope:.6f}")
            log_message(f" - Sin factor: {sin_factor:.3f}, Cos factor: {cos_factor:.3f}")

        self.lr_history.append(self.current_lr)
        self.step_count += 1

        return self.current_lr
264
+
265
+ # Training History Tracker
266
class TrainingHistoryTracker:
    """Collects per-epoch training metrics and renders them as CSV and plots."""
    def __init__(self):
        # One parallel list per tracked metric, indexed by epoch.
        self.history = {
            'epoch': [],
            'train_loss': [],
            'eval_loss': [],
            'train_accuracy': [],
            'eval_accuracy': [],
            'train_f1': [],
            'eval_f1': [],
            'learning_rate': [],
            'train_precision': [],
            'eval_precision': [],
            'train_recall': [],
            'eval_recall': []
        }

    def update(self, epoch, metrics):
        """Append this epoch's metrics; keys absent from self.history are ignored."""
        self.history['epoch'].append(epoch)
        for key, value in metrics.items():
            if key in self.history:
                self.history[key].append(value)

    def save_history(self, output_dir):
        """Write the history to <output_dir>/training_history.csv and return it as a DataFrame.

        Note: pd.DataFrame requires all metric lists to have equal length.
        """
        df = pd.DataFrame(self.history)
        df.to_csv(os.path.join(output_dir, "training_history.csv"), index=False)
        return df

    def plot_training_curves(self, output_dir):
        """Render a 2x3 grid of loss/accuracy/F1/LR/precision/recall curves
        to <output_dir>/training_curves.png. No-op when no epochs recorded."""
        if not self.history['epoch']:
            return

        # Chart styling.
        plt.style.use('seaborn-v0_8')
        fig, axes = plt.subplots(2, 3, figsize=(18, 12))

        epochs = self.history['epoch']

        # 1. Loss curves
        axes[0, 0].plot(epochs, self.history['train_loss'], 'b-', label='Train Loss', linewidth=2)
        axes[0, 0].plot(epochs, self.history['eval_loss'], 'r-', label='Eval Loss', linewidth=2)
        axes[0, 0].set_title('Loss Over Time', fontsize=14, fontweight='bold')
        axes[0, 0].set_xlabel('Epoch')
        axes[0, 0].set_ylabel('Loss')
        axes[0, 0].legend()
        axes[0, 0].grid(True, alpha=0.3)

        # 2. Accuracy curves
        axes[0, 1].plot(epochs, self.history['train_accuracy'], 'b-', label='Train Accuracy', linewidth=2)
        axes[0, 1].plot(epochs, self.history['eval_accuracy'], 'r-', label='Eval Accuracy', linewidth=2)
        axes[0, 1].set_title('Accuracy Over Time', fontsize=14, fontweight='bold')
        axes[0, 1].set_xlabel('Epoch')
        axes[0, 1].set_ylabel('Accuracy')
        axes[0, 1].legend()
        axes[0, 1].grid(True, alpha=0.3)

        # 3. F1 curves
        axes[0, 2].plot(epochs, self.history['train_f1'], 'b-', label='Train F1', linewidth=2)
        axes[0, 2].plot(epochs, self.history['eval_f1'], 'r-', label='Eval F1', linewidth=2)
        axes[0, 2].set_title('F1 Score Over Time', fontsize=14, fontweight='bold')
        axes[0, 2].set_xlabel('Epoch')
        axes[0, 2].set_ylabel('F1 Score')
        axes[0, 2].legend()
        axes[0, 2].grid(True, alpha=0.3)

        # 4. Learning-rate curve (log scale)
        axes[1, 0].plot(epochs, self.history['learning_rate'], 'g-', linewidth=2)
        axes[1, 0].set_title('Learning Rate Over Time', fontsize=14, fontweight='bold')
        axes[1, 0].set_xlabel('Epoch')
        axes[1, 0].set_ylabel('Learning Rate')
        axes[1, 0].set_yscale('log')
        axes[1, 0].grid(True, alpha=0.3)

        # 5. Precision curves
        axes[1, 1].plot(epochs, self.history['train_precision'], 'b-', label='Train Precision', linewidth=2)
        axes[1, 1].plot(epochs, self.history['eval_precision'], 'r-', label='Eval Precision', linewidth=2)
        axes[1, 1].set_title('Precision Over Time', fontsize=14, fontweight='bold')
        axes[1, 1].set_xlabel('Epoch')
        axes[1, 1].set_ylabel('Precision')
        axes[1, 1].legend()
        axes[1, 1].grid(True, alpha=0.3)

        # 6. Recall curves
        axes[1, 2].plot(epochs, self.history['train_recall'], 'b-', label='Train Recall', linewidth=2)
        axes[1, 2].plot(epochs, self.history['eval_recall'], 'r-', label='Eval Recall', linewidth=2)
        axes[1, 2].set_title('Recall Over Time', fontsize=14, fontweight='bold')
        axes[1, 2].set_xlabel('Epoch')
        axes[1, 2].set_ylabel('Recall')
        axes[1, 2].legend()
        axes[1, 2].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, "training_curves.png"), dpi=300, bbox_inches='tight')
        plt.close()
364
+
365
+ # Focal loss implementation
366
class FocalLoss(nn.Module):
    """Focal loss for class-imbalanced classification.

    Computes alpha * (1 - p_t)^gamma * CE, where p_t is the model's
    probability of the true class. With gamma=0 and alpha=1 it reduces to
    plain cross-entropy.
    """
    def __init__(self, alpha=1.0, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        # Per-sample cross-entropy; p_t recovered as exp(-CE).
        per_sample_ce = F.cross_entropy(inputs, targets, reduction='none')
        prob_true = torch.exp(-per_sample_ce)
        loss = self.alpha * (1 - prob_true) ** self.gamma * per_sample_ce

        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss
384
+
385
+ # Stable positional encoding
386
class StablePositionalEncoding(nn.Module):
    """Positional encoding mixing a fixed sinusoidal table with a small
    learnable component, added to the input at 0.1 scale."""
    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        self.d_model = d_model

        # Classic sinusoidal table: sin on even dims, cos on odd dims.
        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        scale = torch.exp(torch.arange(0, d_model, 2).float() *
                          (-math.log(10000.0) / d_model))
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(positions * scale)
        table[:, 1::2] = torch.cos(positions * scale)
        self.register_buffer('pe', table.unsqueeze(0))

        # Small learnable offset (initialized near zero).
        self.learnable_pe = nn.Parameter(torch.randn(max_len, d_model) * 0.01)

    def forward(self, x):
        # x: (batch, seq, d_model); add both encodings truncated to seq length.
        n = x.size(1)
        fixed = self.pe[:, :n, :].to(x.device)
        learned = self.learnable_pe[:n, :].unsqueeze(0).expand(x.size(0), -1, -1)
        return x + 0.1 * (fixed + learned)
411
+
412
+ # Stable multi-head attention
413
class StableMultiHeadAttention(nn.Module):
    """Stable multi-head self-attention with a residual connection and
    post-LayerNorm, used for feature fusion."""
    def __init__(self, feature_dim: int, num_heads: int = 4, dropout: float = 0.3):
        super().__init__()
        self.num_heads = num_heads
        self.feature_dim = feature_dim
        self.head_dim = feature_dim // num_heads

        # feature_dim must divide evenly across heads.
        assert feature_dim % num_heads == 0

        self.query = nn.Linear(feature_dim, feature_dim)
        self.key = nn.Linear(feature_dim, feature_dim)
        self.value = nn.Linear(feature_dim, feature_dim)
        self.dropout = nn.Dropout(dropout)
        self.output_proj = nn.Linear(feature_dim, feature_dim)
        self.layer_norm = nn.LayerNorm(feature_dim)

    def forward(self, x, mask=None):
        """Self-attend over x of shape (batch, seq, feature_dim).

        mask: optional (batch, seq) padding mask; positions where mask == 0
        are excluded from attention. Returns a tensor of the same shape as x.
        """
        batch_size, seq_len, _ = x.size()

        # Project and reshape to (batch, heads, seq, head_dim).
        Q = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.key(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.value(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention scores.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if mask is not None:
            if mask.dim() == 2:
                # Broadcast (batch, seq) -> (batch, 1, 1, seq).
                mask = mask.unsqueeze(1).unsqueeze(1)
            # Large negative fill so masked positions vanish after softmax.
            scores.masked_fill_(mask == 0, -1e9)

        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context = torch.matmul(attn_weights, V)
        # Merge heads back to (batch, seq, feature_dim).
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.feature_dim)

        output = self.output_proj(context)
        # Residual connection + LayerNorm.
        return self.layer_norm(output + x)
452
+
453
+ # Stable linguistic feature extractor
454
class StableLinguisticFeatureExtractor(nn.Module):
    """Encodes POS, grammar, duration and prosody inputs into one pooled
    feature vector per sequence (dimension: total feature dim // 2)."""
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        # POS embeddings refined by self-attention (index 0 is padding).
        self.pos_embedding = nn.Embedding(config.pos_vocab_size, config.pos_emb_dim, padding_idx=0)
        self.pos_attention = StableMultiHeadAttention(config.pos_emb_dim, num_heads=4)

        # Grammar feature projection.
        self.grammar_projection = nn.Sequential(
            nn.Linear(config.grammar_dim, config.grammar_hidden_dim),
            nn.Tanh(),
            nn.LayerNorm(config.grammar_hidden_dim),
            nn.Dropout(config.dropout_rate * 0.3)
        )

        # Scalar word-duration projection.
        self.duration_projection = nn.Sequential(
            nn.Linear(1, config.duration_hidden_dim),
            nn.Tanh(),
            nn.LayerNorm(config.duration_hidden_dim)
        )

        # Prosody projection (dimension-preserving).
        self.prosody_projection = nn.Sequential(
            nn.Linear(config.prosody_dim, config.prosody_dim),
            nn.ReLU(),
            nn.LayerNorm(config.prosody_dim)
        )

        # Fuse the concatenated features down to half their total width.
        total_feature_dim = (config.pos_emb_dim + config.grammar_hidden_dim +
                             config.duration_hidden_dim + config.prosody_dim)
        self.feature_fusion = nn.Sequential(
            nn.Linear(total_feature_dim, total_feature_dim // 2),
            nn.Tanh(),
            nn.LayerNorm(total_feature_dim // 2),
            nn.Dropout(config.dropout_rate)
        )

    def forward(self, pos_ids, grammar_ids, durations, prosody_features, attention_mask):
        """Return masked-mean-pooled fused features of shape
        (batch, total_feature_dim // 2).

        Assumes pos_ids/durations are (batch, seq), grammar_ids is
        (batch, seq, grammar_dim), prosody_features is
        (batch, seq, prosody_dim) — TODO confirm against callers.
        """
        batch_size, seq_len = pos_ids.size()

        # POS: clamp out-of-vocabulary ids, embed, then self-attend.
        pos_ids_clamped = pos_ids.clamp(0, self.config.pos_vocab_size - 1)
        pos_embeds = self.pos_embedding(pos_ids_clamped)
        pos_features = self.pos_attention(pos_embeds, attention_mask)

        # Grammar features.
        grammar_features = self.grammar_projection(grammar_ids.float())

        # Durations: expand scalar per token to a vector.
        duration_features = self.duration_projection(durations.unsqueeze(-1).float())

        # Prosodic features.
        prosody_features = self.prosody_projection(prosody_features.float())

        # Concatenate all modalities along the feature axis.
        combined_features = torch.cat([
            pos_features, grammar_features, duration_features, prosody_features
        ], dim=-1)

        # Fuse to half width.
        fused_features = self.feature_fusion(combined_features)

        # Mean-pool over valid (unmasked) tokens only.
        mask_expanded = attention_mask.unsqueeze(-1).float()
        pooled_features = torch.sum(fused_features * mask_expanded, dim=1) / torch.sum(mask_expanded, dim=1)

        return pooled_features
526
+
527
+ # Main classifier with stability improvements
528
class StableAphasiaClassifier(nn.Module):
    """Stable aphasia classification model.

    Fuses a pre-trained transformer encoder with hand-crafted linguistic
    features (POS ids, grammar ids, per-word durations, prosody) and emits
    three predictions from the fused representation:

      * ``logits``        -- class scores for ``num_labels`` aphasia types
      * ``severity_pred`` -- 4-way softmax distribution (severity levels)
      * ``fluency_pred``  -- scalar in (0, 1) via sigmoid

    "Stable" refers to the choices made for training stability: frozen
    BERT embeddings, LayerNorm + Tanh in the fusion/classifier stacks.
    """
    def __init__(self, config: ModelConfig, num_labels: int):
        """
        Args:
            config: Hyper-parameter container (model name, feature dims,
                dropout, loss options, ...).
            num_labels: Number of aphasia-type classes to predict.
        """
        super().__init__()
        self.config = config
        self.num_labels = num_labels

        # Pre-trained model
        self.bert = AutoModel.from_pretrained(config.model_name)
        self.bert_config = self.bert.config

        # Freeze embeddings for stability (only the embedding table; the
        # encoder layers above it remain trainable)
        for param in self.bert.embeddings.parameters():
            param.requires_grad = False

        # Positional encoding applied on top of BERT's hidden states
        self.positional_encoder = StablePositionalEncoding(
            d_model=self.bert_config.hidden_size,
            max_len=config.max_length
        )

        # Linguistic feature extractor (POS / grammar / duration / prosody)
        self.linguistic_extractor = StableLinguisticFeatureExtractor(config)

        # Calculate dimensions.
        # NOTE: linguistic_dim must equal the output width of
        # StableLinguisticFeatureExtractor — the `// 2` here mirrors the
        # halving done by that module's fusion layer; keep in sync.
        bert_dim = self.bert_config.hidden_size
        linguistic_dim = (config.pos_emb_dim + config.grammar_hidden_dim +
                          config.duration_hidden_dim + config.prosody_dim) // 2

        # Feature fusion: project [bert ; linguistic] back to bert_dim
        self.feature_fusion = nn.Sequential(
            nn.Linear(bert_dim + linguistic_dim, bert_dim),
            nn.LayerNorm(bert_dim),
            nn.Tanh(),
            nn.Dropout(config.dropout_rate)
        )

        # Main classifier head
        self.classifier = self._build_classifier(bert_dim, num_labels)

        # Multi-task heads (simplified).
        # NOTE(review): these heads are predicted in forward() but no loss
        # term is attached to them here — they are auxiliary outputs only.
        self.severity_head = nn.Sequential(
            nn.Linear(bert_dim, 4),
            nn.Softmax(dim=-1)
        )

        self.fluency_head = nn.Sequential(
            nn.Linear(bert_dim, 1),
            nn.Sigmoid()
        )

    def _build_classifier(self, input_dim: int, num_labels: int):
        """Build the MLP classifier: Linear→LayerNorm→Tanh→Dropout per
        hidden layer from ``config.classifier_hidden_dims``, then a final
        Linear projection to ``num_labels``."""
        layers = []
        current_dim = input_dim

        for hidden_dim in self.config.classifier_hidden_dims:
            layers.extend([
                nn.Linear(current_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.Tanh(),
                nn.Dropout(self.config.dropout_rate)
            ])
            current_dim = hidden_dim

        layers.append(nn.Linear(current_dim, num_labels))
        return nn.Sequential(*layers)

    def forward(self, input_ids, attention_mask, labels=None,
                word_pos_ids=None, word_grammar_ids=None, word_durations=None,
                prosody_features=None, **kwargs):
        """Run the full model.

        Extra keyword arguments (e.g. ``sentence_ids`` from the collator)
        are accepted and ignored via ``**kwargs``.

        Returns:
            dict with keys ``logits``, ``severity_pred``, ``fluency_pred``
            and ``loss`` (None when ``labels`` is not given).
        """
        # BERT encoding
        bert_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = bert_outputs.last_hidden_state

        # Apply positional encoding
        position_enhanced = self.positional_encoder(sequence_output)

        # Attention pooling over the token dimension
        pooled_output = self._attention_pooling(position_enhanced, attention_mask)

        # Process linguistic features; all three word-level inputs must be
        # present, otherwise fall back to a zero linguistic vector.
        if all(x is not None for x in [word_pos_ids, word_grammar_ids, word_durations]):
            if prosody_features is None:
                # Missing prosody is tolerated: substitute zeros
                batch_size, seq_len = input_ids.size()
                prosody_features = torch.zeros(
                    batch_size, seq_len, self.config.prosody_dim,
                    device=input_ids.device
                )

            linguistic_features = self.linguistic_extractor(
                word_pos_ids, word_grammar_ids, word_durations,
                prosody_features, attention_mask
            )
        else:
            # Zero vector with the same width the extractor would produce
            linguistic_features = torch.zeros(
                input_ids.size(0),
                (self.config.pos_emb_dim + self.config.grammar_hidden_dim +
                 self.config.duration_hidden_dim + self.config.prosody_dim) // 2,
                device=input_ids.device
            )

        # Feature fusion
        combined_features = torch.cat([pooled_output, linguistic_features], dim=1)
        fused_features = self.feature_fusion(combined_features)

        # Predictions
        logits = self.classifier(fused_features)
        severity_pred = self.severity_head(fused_features)
        fluency_pred = self.fluency_head(fused_features)

        # Loss computation (main classification task only)
        loss = None
        if labels is not None:
            loss = self._compute_loss(logits, labels)

        return {
            "logits": logits,
            "severity_pred": severity_pred,
            "fluency_pred": fluency_pred,
            "loss": loss
        }

    def _attention_pooling(self, sequence_output, attention_mask):
        """Attention-based pooling.

        Weights come from a softmax over the per-token sum of hidden
        units. Masking *after* the softmax and renormalizing is
        mathematically equivalent to a softmax restricted to unmasked
        positions (relative weights among valid tokens are preserved).
        """
        attention_weights = torch.softmax(
            torch.sum(sequence_output, dim=-1, keepdim=True), dim=1
        )
        attention_weights = attention_weights * attention_mask.unsqueeze(-1).float()
        # +1e-9 guards against an all-masked row producing a 0/0
        attention_weights = attention_weights / (torch.sum(attention_weights, dim=1, keepdim=True) + 1e-9)
        pooled = torch.sum(sequence_output * attention_weights, dim=1)
        return pooled

    def _compute_loss(self, logits, labels):
        """Classification loss: focal loss when configured (helps with
        class imbalance), otherwise cross-entropy with optional label
        smoothing."""
        if self.config.use_focal_loss:
            focal_loss = FocalLoss(
                alpha=self.config.focal_alpha,
                gamma=self.config.focal_gamma,
                reduction='mean'
            )
            return focal_loss(logits, labels)
        else:
            if self.config.use_label_smoothing:
                return F.cross_entropy(
                    logits, labels,
                    label_smoothing=self.config.label_smoothing
                )
            else:
                return F.cross_entropy(logits, labels)
677
+
678
+ # Stable dataset class
679
+ class StableAphasiaDataset(Dataset):
680
+ """Stable dataset with simplified processing"""
681
+ def __init__(self, sentences, tokenizer, aphasia_types_mapping, config: ModelConfig):
682
+ self.samples = []
683
+ self.tokenizer = tokenizer
684
+ self.config = config
685
+ self.aphasia_types_mapping = aphasia_types_mapping
686
+
687
+ # Add special tokens
688
+ special_tokens = ["[DIALOGUE]", "[TURN]", "[PAUSE]", "[REPEAT]", "[HESITATION]"]
689
+ tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})
690
+
691
+ for idx, item in enumerate(sentences):
692
+ sentence_id = item.get("sentence_id", f"S{idx}")
693
+ aphasia_type = normalize_type(item.get("aphasia_type", ""))
694
+
695
+ if aphasia_type not in aphasia_types_mapping:
696
+ log_message(f"Skipping Sentence {sentence_id}: Invalid aphasia type '{aphasia_type}'")
697
+ continue
698
+
699
+ self._process_sentence(item, sentence_id, aphasia_type)
700
+
701
+ if not self.samples:
702
+ raise ValueError("No valid samples found in dataset!")
703
+
704
+ log_message(f"Dataset created with {len(self.samples)} samples")
705
+ self._print_class_distribution()
706
+
707
+ def _process_sentence(self, item, sentence_id, aphasia_type):
708
+ """Process sentence with stable approach"""
709
+ all_tokens, all_pos, all_grammar, all_durations = [], [], [], []
710
+
711
+ for dialogue_idx, dialogue in enumerate(item.get("dialogues", [])):
712
+ if dialogue_idx > 0:
713
+ all_tokens.append("[DIALOGUE]")
714
+ all_pos.append(0)
715
+ all_grammar.append([0, 0, 0])
716
+ all_durations.append(0.0)
717
+
718
+ for par in dialogue.get("PAR", []):
719
+ if "tokens" in par and par["tokens"]:
720
+ tokens = par["tokens"]
721
+ pos_ids = par.get("word_pos_ids", [0] * len(tokens))
722
+ grammar_ids = par.get("word_grammar_ids", [[0, 0, 0]] * len(tokens))
723
+ durations = par.get("word_durations", [0.0] * len(tokens))
724
+
725
+ all_tokens.extend(tokens)
726
+ all_pos.extend(pos_ids)
727
+ all_grammar.extend(grammar_ids)
728
+ all_durations.extend(durations)
729
+
730
+ if not all_tokens:
731
+ return
732
+
733
+ # Create sample
734
+ self._create_sample(all_tokens, all_pos, all_grammar, all_durations,
735
+ sentence_id, aphasia_type)
736
+
737
+ def _create_sample(self, tokens, pos_ids, grammar_ids, durations,
738
+ sentence_id, aphasia_type):
739
+ """Create training sample"""
740
+ # Tokenize
741
+ text = " ".join(tokens)
742
+ encoded = self.tokenizer(
743
+ text,
744
+ max_length=self.config.max_length,
745
+ padding="max_length",
746
+ truncation=True,
747
+ return_tensors="pt"
748
+ )
749
+
750
+ # Align features
751
+ aligned_pos, aligned_grammar, aligned_durations = self._align_features(
752
+ tokens, pos_ids, grammar_ids, durations, encoded
753
+ )
754
+
755
+ # Create prosody features
756
+ prosody_features = self._extract_prosodic_features(durations, tokens)
757
+ prosody_tensor = torch.tensor(prosody_features).unsqueeze(0).repeat(
758
+ self.config.max_length, 1
759
+ )
760
+
761
+ label = self.aphasia_types_mapping[aphasia_type]
762
+
763
+ sample = {
764
+ "input_ids": encoded["input_ids"].squeeze(0),
765
+ "attention_mask": encoded["attention_mask"].squeeze(0),
766
+ "labels": torch.tensor(label, dtype=torch.long),
767
+ "word_pos_ids": torch.tensor(aligned_pos, dtype=torch.long),
768
+ "word_grammar_ids": torch.tensor(aligned_grammar, dtype=torch.long),
769
+ "word_durations": torch.tensor(aligned_durations, dtype=torch.float),
770
+ "prosody_features": prosody_tensor.float(),
771
+ "sentence_id": sentence_id
772
+ }
773
+ self.samples.append(sample)
774
+
775
+ def _align_features(self, tokens, pos_ids, grammar_ids, durations, encoded):
776
+ """Align features with BERT subtokens"""
777
+ subtoken_to_token = []
778
+
779
+ for token_idx, token in enumerate(tokens):
780
+ subtokens = self.tokenizer.tokenize(token)
781
+ subtoken_to_token.extend([token_idx] * len(subtokens))
782
+
783
+ aligned_pos = [0] # [CLS]
784
+ aligned_grammar = [[0, 0, 0]] # [CLS]
785
+ aligned_durations = [0.0] # [CLS]
786
+
787
+ for subtoken_idx in range(1, self.config.max_length - 1):
788
+ if subtoken_idx - 1 < len(subtoken_to_token):
789
+ original_idx = subtoken_to_token[subtoken_idx - 1]
790
+ aligned_pos.append(pos_ids[original_idx] if original_idx < len(pos_ids) else 0)
791
+ aligned_grammar.append(grammar_ids[original_idx] if original_idx < len(grammar_ids) else [0, 0, 0])
792
+ raw = durations[original_idx] if original_idx < len(durations) else 0.0
793
+ if isinstance(raw, list) and (isinstance(raw[1], int) and isinstance(raw[0], int)):
794
+ if len(raw) >= 2:
795
+ duration_val = int(raw[1]) - int(raw[0])
796
+ else:
797
+ duration_val = raw[0]
798
+ else:
799
+ duration_val = 0.0
800
+ aligned_durations.append(duration_val)
801
+ else:
802
+ aligned_pos.append(0)
803
+ aligned_grammar.append([0, 0, 0])
804
+ aligned_durations.append(0.0)
805
+
806
+ aligned_pos.append(0) # [SEP]
807
+ aligned_grammar.append([0, 0, 0]) # [SEP]
808
+ aligned_durations.append(0.0) # [SEP]
809
+
810
+ return aligned_pos, aligned_grammar, aligned_durations
811
+
812
+ def _extract_prosodic_features(self, durations, tokens):
813
+ """Extract prosodic features"""
814
+ if not durations:
815
+ return [0.0] * self.config.prosody_dim
816
+
817
+ valid_durations = [d for d in durations if isinstance(d, (int, float)) and d > 0]
818
+ if not valid_durations:
819
+ return [0.0] * self.config.prosody_dim
820
+
821
+ features = [
822
+ np.mean(valid_durations),
823
+ np.std(valid_durations),
824
+ np.median(valid_durations),
825
+ len([d for d in valid_durations if d > np.mean(valid_durations) * 1.5])
826
+ ]
827
+
828
+ # Pad to prosody_dim
829
+ while len(features) < self.config.prosody_dim:
830
+ features.append(0.0)
831
+
832
+ return features[:self.config.prosody_dim]
833
+
834
+ def _print_class_distribution(self):
835
+ """Print class distribution"""
836
+ label_counts = Counter(sample["labels"].item() for sample in self.samples)
837
+ reverse_mapping = {v: k for k, v in self.aphasia_types_mapping.items()}
838
+
839
+ log_message("\nClass Distribution:")
840
+ for label_id, count in sorted(label_counts.items()):
841
+ class_name = reverse_mapping.get(label_id, f"Unknown_{label_id}")
842
+ log_message(f" {class_name}: {count} samples")
843
+
844
+ def __len__(self):
845
+ return len(self.samples)
846
+
847
+ def __getitem__(self, idx):
848
+ return self.samples[idx]
849
+
850
+ # Stable data collator
851
def stable_collate_fn(batch):
    """Collate a list of sample dicts into a batch of stacked tensors.

    Returns None for empty/invalid batches or on any collation failure,
    so the training loop can skip the batch instead of crashing. Optional
    per-word feature keys that are missing from a sample are replaced by
    zero tensors of the appropriate shape.
    """
    if not batch or batch[0] is None:
        return None

    try:
        seq_len = batch[0]["input_ids"].size(0)

        # Zero-tensor factories for optional fields, keyed by field name.
        fallbacks = {
            "word_pos_ids": lambda: torch.zeros(seq_len, dtype=torch.long),
            "word_grammar_ids": lambda: torch.zeros(seq_len, 3, dtype=torch.long),
            "word_durations": lambda: torch.zeros(seq_len, dtype=torch.float),
            "prosody_features": lambda: torch.zeros(seq_len, 32, dtype=torch.float),
        }

        out = {
            field: torch.stack([sample[field] for sample in batch])
            for field in ("input_ids", "attention_mask", "labels")
        }
        out["sentence_ids"] = [sample.get("sentence_id", "N/A") for sample in batch]
        for field, make_default in fallbacks.items():
            out[field] = torch.stack(
                [sample.get(field, make_default()) for sample in batch]
            )
        return out
    except Exception as err:
        log_message(f"Collation error: {err}")
        return None
873
+
874
+ # Enhanced Training callback with adaptive learning rate
875
class AdaptiveTrainingCallback(TrainerCallback):
    """Enhanced training callback with adaptive learning rate and
    comprehensive tracking.

    Responsibilities per HF Trainer hook:
      * on_train_begin -- build the AdaptiveLearningRateScheduler
      * on_log         -- cache the latest train loss / learning rate
      * on_evaluate    -- record history, adjust LR, early-stop on
                          patience or on eval accuracy > 0.84
      * on_train_end   -- persist training history and curve plots
    """
    def __init__(self, config: ModelConfig, patience=5, min_delta=0.8):
        """
        Args:
            config: Hyper-parameter container (adaptive LR settings).
            patience: Evaluations without improvement before stopping.
            min_delta: Minimum F1 improvement to count as "new best".
                NOTE(review): 0.8 is huge on a [0, 1] F1 scale — with this
                default the "new best" branch can fire at most once, so
                early stopping degenerates to a fixed evaluation budget.
                Confirm this is intentional.
        """
        self.config = config
        self.patience = patience
        self.min_delta = min_delta
        self.best_metric = float('-inf')
        self.patience_counter = 0

        # Learning rate scheduler (created lazily in on_train_begin)
        self.lr_scheduler = None

        # History tracker
        self.history_tracker = TrainingHistoryTracker()

        # Metrics cached for the current epoch
        self.current_train_metrics = {}
        self.current_eval_metrics = {}

    def on_train_begin(self, args, state, control, **kwargs):
        """Initialize learning rate scheduler (only when adaptive LR is
        enabled and the Trainer passed both model and optimizer)."""
        if self.config.adaptive_lr:
            model = kwargs.get('model')
            optimizer = kwargs.get('optimizer')
            if optimizer and model:
                # Fall back to dataloader-length * epochs when max_steps
                # was not set explicitly
                total_steps = state.max_steps if state.max_steps > 0 else len(kwargs.get('train_dataloader', [])) * args.num_train_epochs
                self.lr_scheduler = AdaptiveLearningRateScheduler(optimizer, self.config, total_steps)
                log_message("Adaptive learning rate scheduler initialized")

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Capture training metrics from the Trainer's logging stream."""
        if logs:
            # Store training metrics for later pairing with eval metrics
            if 'train_loss' in logs:
                self.current_train_metrics['loss'] = logs['train_loss']
            if 'learning_rate' in logs:
                self.current_train_metrics['lr'] = logs['learning_rate']

    def on_evaluate(self, args, state, control, logs=None, **kwargs):
        """Handle evaluation: update history, adapt the LR, and apply the
        target-accuracy / patience-based stopping rules."""
        if logs is not None:
            # 'eval_f1' is the weighted F1 from compute_comprehensive_metrics
            current_metric = logs.get('eval_f1', 0)
            current_loss = logs.get('eval_loss', float('inf'))
            current_acc = logs.get('eval_accuracy', 0)

            # Store evaluation metrics
            self.current_eval_metrics = {
                'loss': current_loss,
                'f1': current_metric,
                'accuracy': current_acc,
                'precision': logs.get('eval_precision_macro', 0),
                'recall': logs.get('eval_recall_macro', 0)
            }

            # Update history (train-side accuracy/F1/precision/recall are
            # not computed by the Trainer, hence the 0 placeholders)
            epoch_metrics = {
                'train_loss': self.current_train_metrics.get('loss', 0),
                'eval_loss': current_loss,
                'train_accuracy': 0,  # Will be computed separately if needed
                'eval_accuracy': current_acc,
                'train_f1': 0,  # Will be computed separately if needed
                'eval_f1': current_metric,
                'learning_rate': self.current_train_metrics.get('lr', self.config.learning_rate),
                'train_precision': 0,
                'eval_precision': logs.get('eval_precision_macro', 0),
                'train_recall': 0,
                'eval_recall': logs.get('eval_recall_macro', 0)
            }

            self.history_tracker.update(state.epoch, epoch_metrics)

            # Adaptive learning rate adjustment (scheduler mutates the
            # optimizer in place; the returned value is informational)
            if self.lr_scheduler and self.config.adaptive_lr:
                new_lr = self.lr_scheduler.adaptive_lr_calculation(current_loss, current_metric, current_acc)

            # Hard stop once the target accuracy is hit; save first
            if current_acc > 0.84:
                log_message(f"Target accuracy reached ({current_acc:.2%}) → stopping and saving model")
                control.should_save = True
                control.should_training_stop = True
                return control

            # Early stopping logic (see NOTE(review) on min_delta above)
            if current_metric > self.best_metric + self.min_delta:
                self.best_metric = current_metric
                self.patience_counter = 0
                log_message(f"New best F1 score: {current_metric:.4f}")
            else:
                self.patience_counter += 1
                log_message(f"No improvement for {self.patience_counter} evaluations")

            if self.patience_counter >= self.patience:
                log_message("Early stopping triggered")
                control.should_training_stop = True

        clear_memory()

    def on_train_end(self, args, state, control, **kwargs):
        """Save training history and learning curves at the end."""
        output_dir = args.output_dir
        self.history_tracker.save_history(output_dir)
        self.history_tracker.plot_training_curves(output_dir)
        log_message("Training history and curves saved")
975
+
976
+ # Metrics computation
977
def compute_comprehensive_metrics(pred):
    """Compute comprehensive evaluation metrics for the HF Trainer.

    Returns accuracy, weighted F1 (under the key ``f1``, which the
    Trainer surfaces as ``eval_f1``), macro F1/precision/recall, and the
    standard deviation of each metric across classes (a balance
    indicator: high std means performance is uneven between classes).
    """
    raw = pred.predictions
    logits = raw[0] if isinstance(raw, tuple) else raw
    y_true = pred.label_ids
    y_pred = np.argmax(logits, axis=1)

    # Per-class vectors, used only for the spread statistics below
    per_class = {
        "f1": f1_score(y_true, y_pred, average=None, zero_division=0),
        "precision": precision_score(y_true, y_pred, average=None, zero_division=0),
        "recall": recall_score(y_true, y_pred, average=None, zero_division=0),
    }

    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "f1": f1_score(y_true, y_pred, average='weighted', zero_division=0),
        "f1_macro": f1_score(y_true, y_pred, average='macro', zero_division=0),
        "precision_macro": precision_score(y_true, y_pred, average='macro', zero_division=0),
        "recall_macro": recall_score(y_true, y_pred, average='macro', zero_division=0),
        "f1_std": np.std(per_class["f1"]),
        "precision_std": np.std(per_class["precision"]),
        "recall_std": np.std(per_class["recall"]),
    }
1005
+
1006
+ # Enhanced analysis and visualization
1007
def generate_comprehensive_reports(trainer, eval_dataset, aphasia_types_mapping, tokenizer, output_dir):
    """Generate comprehensive analysis reports and visualizations.

    Runs the trained model over ``eval_dataset`` and writes to ``output_dir``:
      * enhanced_confusion_matrix.png           (counts + row percentages)
      * comprehensive_classification_report.csv
      * per_class_metrics.csv / per_class_performance.png
      * confidence_distribution.png
      * comprehensive_results.csv               (per-sample predictions)
      * summary_statistics.json

    Returns:
        (results_df, df_report, summary_stats)
    """
    log_message("Generating comprehensive reports...")

    # Unwrap DataParallel/DDP if present
    model = trainer.model
    if hasattr(model, 'module'):
        model = model.module

    model.eval()
    device = next(model.parameters()).device

    predictions = []
    true_labels = []
    sentence_ids = []
    severity_preds = []
    fluency_preds = []
    prediction_probs = []

    # Evaluation pass (no grad, fixed batch size)
    dataloader = DataLoader(eval_dataset, batch_size=8, collate_fn=stable_collate_fn)

    with torch.no_grad():
        for batch_idx, batch in enumerate(dataloader):
            if batch is None:
                # stable_collate_fn returns None for broken batches
                continue

            # Move tensor fields to the model's device
            for key in ['input_ids', 'attention_mask', 'word_pos_ids',
                        'word_grammar_ids', 'word_durations', 'labels', 'prosody_features']:
                if key in batch:
                    batch[key] = batch[key].to(device)

            # Extra keys like sentence_ids are absorbed by **kwargs
            outputs = model(**batch)

            logits = outputs["logits"]
            probs = F.softmax(logits, dim=1)
            preds = torch.argmax(logits, dim=1).cpu().numpy()

            predictions.extend(preds)
            true_labels.extend(batch["labels"].cpu().numpy())
            sentence_ids.extend(batch["sentence_ids"])
            severity_preds.extend(outputs["severity_pred"].cpu().numpy())
            fluency_preds.extend(outputs["fluency_pred"].cpu().numpy())
            prediction_probs.extend(probs.cpu().numpy())

    # Analysis
    reverse_mapping = {v: k for k, v in aphasia_types_mapping.items()}

    # 1. Detailed per-sample predictions (first 20 only, for the log)
    log_message("=== DETAILED PREDICTIONS (First 20) ===")
    for i in range(min(20, len(predictions))):
        true_type = reverse_mapping.get(true_labels[i], 'Unknown')
        pred_type = reverse_mapping.get(predictions[i], 'Unknown')
        severity_level = np.argmax(severity_preds[i])
        fluency_score = fluency_preds[i][0] if isinstance(fluency_preds[i], np.ndarray) else fluency_preds[i]
        confidence = np.max(prediction_probs[i])

        log_message(f"ID: {sentence_ids[i]} | True: {true_type} | Pred: {pred_type} | "
                    f"Confidence: {confidence:.3f} | Severity: {severity_level} | Fluency: {fluency_score:.3f}")

    # 2. Confusion matrix
    cm = confusion_matrix(true_labels, predictions)

    # Enhanced confusion matrix plot
    plt.figure(figsize=(14, 12))

    # Row-normalized percentages.
    # NOTE(review): a class with zero true samples makes the row sum 0 and
    # this division emits NaNs — confirm every class is represented.
    cm_percentage = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] * 100

    # Annotate each cell with "count\n(percent)"
    annotations = np.empty_like(cm, dtype=object)
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            annotations[i, j] = f'{cm[i, j]}\n({cm_percentage[i, j]:.1f}%)'

    sns.heatmap(cm, annot=annotations, fmt='', cmap="Blues",
                xticklabels=list(aphasia_types_mapping.keys()),
                yticklabels=list(aphasia_types_mapping.keys()),
                cbar_kws={'label': 'Count'})

    plt.xlabel("Predicted Label", fontsize=12, fontweight='bold')
    plt.ylabel("True Label", fontsize=12, fontweight='bold')
    plt.title("Enhanced Confusion Matrix\n(Count and Percentage)", fontsize=14, fontweight='bold')
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "enhanced_confusion_matrix.png"), dpi=300, bbox_inches='tight')
    plt.close()

    # 3. Classification report (all classes forced via labels=, so absent
    # classes still appear with zero scores)
    all_label_ids = list(aphasia_types_mapping.values())
    report_dict = classification_report(
        true_labels,
        predictions,
        labels=all_label_ids,
        target_names=list(aphasia_types_mapping.keys()),
        output_dict=True,
        zero_division=0
    )

    df_report = pd.DataFrame(report_dict).transpose()
    df_report.to_csv(os.path.join(output_dir, "comprehensive_classification_report.csv"))

    # 4. Per-class performance visualization
    class_names = list(aphasia_types_mapping.keys())
    metrics_data = []

    for i, class_name in enumerate(class_names):
        if class_name in report_dict:
            metrics_data.append({
                'Class': class_name,
                'Precision': report_dict[class_name]['precision'],
                'Recall': report_dict[class_name]['recall'],
                'F1-Score': report_dict[class_name]['f1-score'],
                'Support': report_dict[class_name]['support']
            })

    df_metrics = pd.DataFrame(metrics_data)
    df_metrics.to_csv(os.path.join(output_dir, "per_class_metrics.csv"), index=False)

    # Plot per-class performance: 2x2 grid of bar charts
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))

    # Precision
    axes[0, 0].bar(df_metrics['Class'], df_metrics['Precision'], color='skyblue', alpha=0.8)
    axes[0, 0].set_title('Precision by Class', fontweight='bold')
    axes[0, 0].set_ylabel('Precision')
    axes[0, 0].tick_params(axis='x', rotation=45)
    axes[0, 0].grid(True, alpha=0.3)

    # Recall
    axes[0, 1].bar(df_metrics['Class'], df_metrics['Recall'], color='lightcoral', alpha=0.8)
    axes[0, 1].set_title('Recall by Class', fontweight='bold')
    axes[0, 1].set_ylabel('Recall')
    axes[0, 1].tick_params(axis='x', rotation=45)
    axes[0, 1].grid(True, alpha=0.3)

    # F1-Score
    axes[1, 0].bar(df_metrics['Class'], df_metrics['F1-Score'], color='lightgreen', alpha=0.8)
    axes[1, 0].set_title('F1-Score by Class', fontweight='bold')
    axes[1, 0].set_ylabel('F1-Score')
    axes[1, 0].tick_params(axis='x', rotation=45)
    axes[1, 0].grid(True, alpha=0.3)

    # Support
    axes[1, 1].bar(df_metrics['Class'], df_metrics['Support'], color='gold', alpha=0.8)
    axes[1, 1].set_title('Support by Class', fontweight='bold')
    axes[1, 1].set_ylabel('Support (Number of Samples)')
    axes[1, 1].tick_params(axis='x', rotation=45)
    axes[1, 1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "per_class_performance.png"), dpi=300, bbox_inches='tight')
    plt.close()

    # 5. Prediction confidence distribution
    confidences = [np.max(prob) for prob in prediction_probs]
    correct_predictions = [pred == true for pred, true in zip(predictions, true_labels)]

    plt.figure(figsize=(12, 8))

    # Separate correct and incorrect predictions
    correct_confidences = [conf for conf, correct in zip(confidences, correct_predictions) if correct]
    incorrect_confidences = [conf for conf, correct in zip(confidences, correct_predictions) if not correct]

    plt.hist(correct_confidences, bins=30, alpha=0.7, label='Correct Predictions', color='green', density=True)
    plt.hist(incorrect_confidences, bins=30, alpha=0.7, label='Incorrect Predictions', color='red', density=True)

    plt.xlabel('Prediction Confidence', fontsize=12)
    plt.ylabel('Density', fontsize=12)
    plt.title('Distribution of Prediction Confidence', fontsize=14, fontweight='bold')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "confidence_distribution.png"), dpi=300, bbox_inches='tight')
    plt.close()

    # 6. Auxiliary-head feature analysis (severity / fluency)
    log_message("=== FEATURE ANALYSIS ===")
    avg_severity = np.mean(severity_preds, axis=0)
    avg_fluency = np.mean(fluency_preds)
    std_fluency = np.std(fluency_preds)

    log_message(f"Average Severity Distribution: {avg_severity}")
    log_message(f"Average Fluency Score: {avg_fluency:.3f} ± {std_fluency:.3f}")

    # 7. Save detailed per-sample results
    results_df = pd.DataFrame({
        'sentence_id': sentence_ids,
        'true_label': [reverse_mapping[label] for label in true_labels],
        'predicted_label': [reverse_mapping[pred] for pred in predictions],
        'prediction_confidence': confidences,
        'correct_prediction': correct_predictions,
        'severity_level': [np.argmax(severity) for severity in severity_preds],
        'fluency_score': [fluency[0] if isinstance(fluency, np.ndarray) else fluency for fluency in fluency_preds]
    })

    # Add probability columns for each class
    for i, class_name in enumerate(aphasia_types_mapping.keys()):
        results_df[f'prob_{class_name}'] = [prob[i] for prob in prediction_probs]

    results_df.to_csv(os.path.join(output_dir, "comprehensive_results.csv"), index=False)

    # 8. Summary statistics
    summary_stats = {
        'Overall Accuracy': accuracy_score(true_labels, predictions),
        'Macro F1': f1_score(true_labels, predictions, average='macro'),
        'Weighted F1': f1_score(true_labels, predictions, average='weighted'),
        'Macro Precision': precision_score(true_labels, predictions, average='macro'),
        'Macro Recall': recall_score(true_labels, predictions, average='macro'),
        'Average Confidence': np.mean(confidences),
        'Confidence Std': np.std(confidences),
        'Average Severity': avg_severity.tolist(),
        'Average Fluency': avg_fluency,
        'Fluency Std': std_fluency
    }

    # Cast numpy scalars to native types so json.dump doesn't choke
    serializable_summary = {
        k: float(v) if isinstance(v, (np.floating, np.integer)) else v
        for k, v in summary_stats.items()
    }
    with open(os.path.join(output_dir, "summary_statistics.json"), "w") as f:
        json.dump(serializable_summary, f, indent=2)

    log_message("Comprehensive Classification Report:")
    log_message(df_report.to_string())
    log_message(f"Comprehensive results saved to {output_dir}")

    return results_df, df_report, summary_stats
1236
+
1237
+ # Main training function with adaptive learning rate
1238
def train_adaptive_model(json_file: str, output_dir: str = "./adaptive_aphasia_model"):
    """Main training function with adaptive learning rate.

    Pipeline: load JSON dataset → normalize/filter aphasia types →
    build StableAphasiaDataset → 80/20 random split → train a
    StableAphasiaClassifier via the HF Trainer with
    AdaptiveTrainingCallback → final evaluation, comprehensive reports,
    and model/config export to ``output_dir``.

    Args:
        json_file: Path to the dataset JSON (expects a top-level
            ``sentences`` list).
        output_dir: Directory for checkpoints, reports and the exported
            model/config.

    Returns:
        (trainer, eval_results, results_df)
    """

    log_message("Starting Adaptive Aphasia Classification Training")
    log_message("=" * 60)

    # Setup
    config = ModelConfig()
    os.makedirs(output_dir, exist_ok=True)

    # Device setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    log_message(f"Using device: {device}")

    # Load data
    log_message("Loading dataset...")
    with open(json_file, "r", encoding="utf-8") as f:
        dataset_json = json.load(f)

    sentences = dataset_json.get("sentences", [])

    # Normalize aphasia type strings in place
    for item in sentences:
        if "aphasia_type" in item:
            item["aphasia_type"] = normalize_type(item["aphasia_type"])

    # Fixed label mapping for the 9 supported aphasia types
    aphasia_types_mapping = {
        "BROCA": 0,
        "TRANSMOTOR": 1,
        "NOTAPHASICBYWAB": 2,
        "CONDUCTION": 3,
        "WERNICKE": 4,
        "ANOMIC": 5,
        "GLOBAL": 6,
        "ISOLATION": 7,
        "TRANSSENSORY": 8
    }

    log_message(f"Aphasia Types Mapping: {aphasia_types_mapping}")

    num_labels = len(aphasia_types_mapping)
    log_message(f"Number of labels: {num_labels}")

    # Drop sentences whose (normalized) type is not in the mapping
    filtered_sentences = []
    for item in sentences:
        aphasia_type = item.get("aphasia_type", "")
        if aphasia_type in aphasia_types_mapping:
            filtered_sentences.append(item)
        else:
            log_message(f"Excluding sentence with invalid type: {aphasia_type}")

    log_message(f"Filtered dataset: {len(filtered_sentences)} sentences")

    # Initialize tokenizer (GPT-style tokenizers have no pad token)
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Create dataset (shuffle first so the later index-based split is
    # effectively random at the sentence level too)
    random.shuffle(filtered_sentences)
    dataset_all = StableAphasiaDataset(
        filtered_sentences, tokenizer, aphasia_types_mapping, config
    )

    # Split dataset 80/20
    total_samples = len(dataset_all)
    train_size = int(0.8 * total_samples)
    eval_size = total_samples - train_size

    train_dataset, eval_dataset = torch.utils.data.random_split(
        dataset_all, [train_size, eval_size]
    )

    log_message(f"Train size: {train_size}, Eval size: {eval_size}")

    # Setup weighted sampling for class imbalance.
    # NOTE(review): `sampler` is constructed here but never passed to the
    # Trainer below, so weighted sampling is effectively disabled — the
    # Trainer builds its own dataloader. Confirm whether this was meant
    # to be wired in (e.g. via a custom get_train_dataloader).
    train_labels = [dataset_all.samples[idx]["labels"].item() for idx in train_dataset.indices]
    label_counts = Counter(train_labels)
    sample_weights = [1.0 / label_counts[label] for label in train_labels]
    sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True
    )

    # Model initialization (fresh model per Trainer init; embeddings
    # resized to account for the special tokens added by the dataset)
    def model_init():
        model = StableAphasiaClassifier(config, num_labels)
        model.bert.resize_token_embeddings(len(tokenizer))
        return model.to(device)

    # Training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        eval_strategy="epoch",
        save_strategy="epoch",
        learning_rate=config.learning_rate,
        per_device_train_batch_size=config.batch_size,
        per_device_eval_batch_size=config.batch_size,
        num_train_epochs=config.num_epochs,
        weight_decay=config.weight_decay,
        warmup_ratio=config.warmup_ratio,
        logging_strategy="steps",
        logging_steps=50,
        seed=42,
        dataloader_num_workers=0,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        max_grad_norm=1.0,
        fp16=False,
        dataloader_drop_last=True,
        report_to=None,
        load_best_model_at_end=True,
        # 'eval_f1' is the weighted F1 emitted by compute_comprehensive_metrics
        metric_for_best_model="eval_f1",
        greater_is_better=True,
        save_total_limit=3,
        # Keep custom keys (word_pos_ids, prosody_features, ...) in batches
        remove_unused_columns=False,
    )

    # Initialize trainer with adaptive callback.
    # NOTE(review): min_delta=0.8 on an F1 scale of [0, 1] effectively
    # disables the "new best" reset — see AdaptiveTrainingCallback.
    trainer = Trainer(
        model_init=model_init,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_comprehensive_metrics,
        data_collator=stable_collate_fn,
        callbacks=[AdaptiveTrainingCallback(config, patience=5, min_delta=0.8)]
    )

    # Start training
    log_message("Starting adaptive training...")
    try:
        trainer.train()
        log_message("Training completed successfully!")
    except Exception as e:
        # Log the full traceback before propagating so failures in long
        # runs are diagnosable from the log file
        log_message(f"Training error: {str(e)}")
        import traceback
        log_message(traceback.format_exc())
        raise

    # Final evaluation
    log_message("Starting final evaluation...")
    eval_results = trainer.evaluate()
    log_message(f"Final evaluation results: {eval_results}")

    # Generate comprehensive reports
    results_df, report_df, summary_stats = generate_comprehensive_reports(
        trainer, eval_dataset, aphasia_types_mapping, tokenizer, output_dir
    )

    # Save model (unwrap DataParallel/DDP if present)
    model_to_save = trainer.model
    if hasattr(model_to_save, 'module'):
        model_to_save = model_to_save.module

    torch.save(model_to_save.state_dict(), os.path.join(output_dir, "pytorch_model.bin"))
    tokenizer.save_pretrained(output_dir)

    # Save configuration alongside the weights for reproducibility
    config_dict = {
        "model_name": config.model_name,
        "num_labels": num_labels,
        "aphasia_types_mapping": aphasia_types_mapping,
        "training_args": training_args.to_dict(),
        "adaptive_lr_config": {
            "adaptive_lr": config.adaptive_lr,
            "lr_patience": config.lr_patience,
            "lr_factor": config.lr_factor,
            "lr_increase_factor": config.lr_increase_factor,
            "min_lr": config.min_lr,
            "max_lr": config.max_lr,
            "oscillation_amplitude": config.oscillation_amplitude
        }
    }

    with open(os.path.join(output_dir, "config.json"), "w") as f:
        json.dump(config_dict, f, indent=2)

    log_message(f"Adaptive model and comprehensive reports saved to {output_dir}")
    clear_memory()

    return trainer, eval_results, results_df
1422
+
1423
# Cross-validation with adaptive learning rate
def train_adaptive_cross_validation(json_file: str, output_dir: str = "./adaptive_cv_results", n_folds: int = 5):
    """Run stratified k-fold cross-validation training with adaptive learning rate.

    Loads a sentence-level dataset from ``json_file`` (top-level ``"sentences"``
    list), normalizes and filters aphasia-type labels against a fixed 9-class
    mapping, trains one model per fold via ``train_adaptive_single_fold``, then
    aggregates per-fold metrics into a CSV/JSON summary and saves an overall
    confusion matrix plus per-metric bar charts into ``output_dir``.

    Args:
        json_file: Path to the JSON dataset; each sentence item is expected to
            carry an ``"aphasia_type"`` field (normalized via ``normalize_type``).
        output_dir: Directory for per-fold checkpoints, summaries, and plots
            (created if missing).
        n_folds: Number of stratified folds.

    Returns:
        Tuple ``(results_df, cv_summary)``: a pandas DataFrame with one row of
        metrics per fold, and a dict of mean/std statistics across folds.
    """
    log_message("Starting Adaptive Cross-Validation Training")

    config = ModelConfig()
    os.makedirs(output_dir, exist_ok=True)

    # Load and prepare data
    with open(json_file, "r", encoding="utf-8") as f:
        dataset_json = json.load(f)

    sentences = dataset_json.get("sentences", [])

    # Normalize labels in place so they match the canonical mapping keys below.
    for item in sentences:
        if "aphasia_type" in item:
            item["aphasia_type"] = normalize_type(item["aphasia_type"])

    # Fixed label-to-id mapping (9 aphasia classes); kept identical to the
    # mapping used by the single-model training path for consistent reports.
    aphasia_types_mapping = {
        "BROCA": 0, "TRANSMOTOR": 1, "NOTAPHASICBYWAB": 2,
        "CONDUCTION": 3, "WERNICKE": 4, "ANOMIC": 5,
        "GLOBAL": 6, "ISOLATION": 7, "TRANSSENSORY": 8
    }

    # Drop sentences whose (normalized) type is not one of the known classes.
    filtered_sentences = [s for s in sentences if s.get("aphasia_type") in aphasia_types_mapping]

    # Initialize tokenizer; fall back to EOS as pad token if the model has none.
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Create the full dataset once; folds are realized as index Subsets below.
    full_dataset = StableAphasiaDataset(
        filtered_sentences, tokenizer, aphasia_types_mapping, config
    )

    # Extract labels for stratification (one label id per prepared sample).
    sample_labels = [sample["labels"].item() for sample in full_dataset.samples]

    # Cross-validation: fixed seed so fold assignment is reproducible.
    skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)
    fold_results = []
    all_predictions = []
    all_true_labels = []

    # StratifiedKFold only needs y for splitting, so a zero placeholder is
    # passed as X.
    for fold, (train_idx, val_idx) in enumerate(skf.split(np.zeros(len(sample_labels)), sample_labels)):
        log_message(f"\n=== Fold {fold + 1}/{n_folds} ===")

        train_subset = Subset(full_dataset, train_idx)
        val_subset = Subset(full_dataset, val_idx)

        # Train single fold.
        # NOTE(review): fold_trainer is not used after this call; only the
        # metrics dict and predictions are consumed.
        fold_trainer, fold_results_dict, fold_predictions = train_adaptive_single_fold(
            train_subset, val_subset, config, aphasia_types_mapping,
            tokenizer, fold, output_dir
        )

        fold_results.append({
            'fold': fold + 1,
            **fold_results_dict
        })

        # Collect predictions across folds for the pooled confusion matrix.
        all_predictions.extend(fold_predictions['predictions'])
        all_true_labels.extend(fold_predictions['true_labels'])

        clear_memory()

    # Aggregate per-fold results into one table and persist as CSV.
    results_df = pd.DataFrame(fold_results)
    results_df.to_csv(os.path.join(output_dir, "adaptive_cv_summary.csv"), index=False)

    # Cross-validation summary statistics (mean/std over folds).
    # NOTE(review): column names here assume train_adaptive_single_fold's
    # eval dict uses the "accuracy"/"f1"/... keys (without the "eval_" prefix
    # the HF Trainer normally adds) — confirm against compute_comprehensive_metrics.
    cv_summary = {
        'mean_accuracy': results_df['accuracy'].mean(),
        'std_accuracy': results_df['accuracy'].std(),
        'mean_f1': results_df['f1'].mean(),
        'std_f1': results_df['f1'].std(),
        'mean_f1_macro': results_df['f1_macro'].mean(),
        'std_f1_macro': results_df['f1_macro'].std(),
        'mean_precision': results_df['precision_macro'].mean(),
        'std_precision': results_df['precision_macro'].std(),
        'mean_recall': results_df['recall_macro'].mean(),
        'std_recall': results_df['recall_macro'].std()
    }

    with open(os.path.join(output_dir, "cv_statistics.json"), "w") as f:
        json.dump(cv_summary, f, indent=2)

    # Overall confusion matrix pooled across all folds' validation predictions.
    overall_cm = confusion_matrix(all_true_labels, all_predictions)

    plt.figure(figsize=(12, 10))
    sns.heatmap(overall_cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=list(aphasia_types_mapping.keys()),
                yticklabels=list(aphasia_types_mapping.keys()))
    plt.xlabel("Predicted Label")
    plt.ylabel("True Label")
    plt.title("Overall Confusion Matrix (All Folds)")
    plt.xticks(rotation=45)
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "overall_confusion_matrix.png"), dpi=300, bbox_inches='tight')
    plt.close()

    # 2x2 grid of per-fold bar charts: accuracy, F1, precision, recall — each
    # with a dashed line at the cross-fold mean.
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))

    # Accuracy across folds
    axes[0, 0].bar(range(1, n_folds + 1), results_df['accuracy'], color='skyblue', alpha=0.8)
    axes[0, 0].axhline(y=results_df['accuracy'].mean(), color='red', linestyle='--',
                       label=f'Mean: {results_df["accuracy"].mean():.3f}')
    axes[0, 0].set_title('Accuracy Across Folds')
    axes[0, 0].set_xlabel('Fold')
    axes[0, 0].set_ylabel('Accuracy')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)

    # F1 Score across folds
    axes[0, 1].bar(range(1, n_folds + 1), results_df['f1'], color='lightgreen', alpha=0.8)
    axes[0, 1].axhline(y=results_df['f1'].mean(), color='red', linestyle='--',
                       label=f'Mean: {results_df["f1"].mean():.3f}')
    axes[0, 1].set_title('F1 Score Across Folds')
    axes[0, 1].set_xlabel('Fold')
    axes[0, 1].set_ylabel('F1 Score')
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)

    # Precision across folds
    axes[1, 0].bar(range(1, n_folds + 1), results_df['precision_macro'], color='coral', alpha=0.8)
    axes[1, 0].axhline(y=results_df['precision_macro'].mean(), color='red', linestyle='--',
                       label=f'Mean: {results_df["precision_macro"].mean():.3f}')
    axes[1, 0].set_title('Precision Across Folds')
    axes[1, 0].set_xlabel('Fold')
    axes[1, 0].set_ylabel('Precision')
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3)

    # Recall across folds
    axes[1, 1].bar(range(1, n_folds + 1), results_df['recall_macro'], color='gold', alpha=0.8)
    axes[1, 1].axhline(y=results_df['recall_macro'].mean(), color='red', linestyle='--',
                       label=f'Mean: {results_df["recall_macro"].mean():.3f}')
    axes[1, 1].set_title('Recall Across Folds')
    axes[1, 1].set_xlabel('Fold')
    axes[1, 1].set_ylabel('Recall')
    axes[1, 1].legend()
    axes[1, 1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "cv_performance_comparison.png"), dpi=300, bbox_inches='tight')
    plt.close()

    log_message("\n=== Adaptive Cross-Validation Summary ===")
    log_message(results_df.to_string(index=False))

    # Statistics
    log_message(f"\nMean F1: {results_df['f1'].mean():.4f} ± {results_df['f1'].std():.4f}")
    log_message(f"Mean Accuracy: {results_df['accuracy'].mean():.4f} ± {results_df['accuracy'].std():.4f}")
    log_message(f"Mean F1 Macro: {results_df['f1_macro'].mean():.4f} ± {results_df['f1_macro'].std():.4f}")

    return results_df, cv_summary

1586
def train_adaptive_single_fold(train_dataset, val_dataset, config, aphasia_types_mapping,
                               tokenizer, fold, output_dir):
    """Train a single cross-validation fold with the adaptive-LR callback.

    Builds a fresh ``StableAphasiaClassifier`` via ``model_init``, trains it on
    ``train_dataset`` with HF ``Trainer``, evaluates on ``val_dataset``, saves
    the fold's weights under ``output_dir/fold_{fold}``, and returns the
    predictions needed for the pooled confusion matrix.

    Args:
        train_dataset: Training split for this fold (torch ``Subset``).
        val_dataset: Validation split for this fold (torch ``Subset``).
        config: ModelConfig with learning-rate/batch/epoch hyperparameters.
        aphasia_types_mapping: Label-name -> id mapping (defines num_labels).
        tokenizer: Tokenizer whose vocabulary size the model embedding is
            resized to (special tokens may have been added).
        fold: Zero-based fold index, used to name the output subdirectory.
        output_dir: Root directory for per-fold outputs.

    Returns:
        Tuple ``(trainer, eval_results, fold_predictions)`` where
        ``fold_predictions`` is ``{'predictions': [...], 'true_labels': [...]}``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_labels = len(aphasia_types_mapping)

    # Setup weighted sampling (inverse class frequency).
    # NOTE(review): this sampler is constructed but never handed to the
    # Trainer, so it currently has no effect on training — confirm whether a
    # custom get_train_dataloader was intended.
    train_labels = [train_dataset[i]["labels"].item() for i in range(len(train_dataset))]
    label_counts = Counter(train_labels)
    sample_weights = [1.0 / label_counts[label] for label in train_labels]
    sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)

    # Model initialization: a fresh model per fold; embeddings are resized to
    # cover any special tokens added to the tokenizer.
    def model_init():
        model = StableAphasiaClassifier(config, num_labels)
        model.bert.resize_token_embeddings(len(tokenizer))
        return model.to(device)

    # Per-fold output directory for checkpoints and the saved state dict.
    fold_output_dir = os.path.join(output_dir, f"fold_{fold}")
    os.makedirs(fold_output_dir, exist_ok=True)

    training_args = TrainingArguments(
        output_dir=fold_output_dir,
        eval_strategy="epoch",
        save_strategy="epoch",
        learning_rate=config.learning_rate,
        per_device_train_batch_size=config.batch_size,
        per_device_eval_batch_size=config.batch_size,
        num_train_epochs=config.num_epochs,
        weight_decay=config.weight_decay,
        warmup_ratio=config.warmup_ratio,
        logging_steps=50,
        seed=42,
        dataloader_num_workers=0,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        max_grad_norm=1.0,
        fp16=False,
        dataloader_drop_last=True,
        # NOTE(review): the documented way to disable integrations is
        # report_to="none"; passing None may fall back to the default set —
        # verify against the installed transformers version.
        report_to=None,
        load_best_model_at_end=True,
        metric_for_best_model="eval_f1",
        greater_is_better=True,
        save_total_limit=1,  # keep only the best checkpoint per fold
        remove_unused_columns=False,
    )

    # Trainer with adaptive callback (adjusts LR based on eval progress).
    trainer = Trainer(
        model_init=model_init,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_comprehensive_metrics,
        data_collator=stable_collate_fn,
        callbacks=[AdaptiveTrainingCallback(config, patience=5, min_delta=0.8)]
    )

    # Train
    trainer.train()

    # Evaluate
    eval_results = trainer.evaluate()

    # Get predictions for ensemble analysis. The model may return a tuple of
    # outputs; the first element is taken as the classification logits.
    predictions = trainer.predict(val_dataset)
    pred_labels = np.argmax(predictions.predictions[0] if isinstance(predictions.predictions, tuple) else predictions.predictions, axis=1)
    true_labels = predictions.label_ids

    fold_predictions = {
        'predictions': pred_labels.tolist(),
        'true_labels': true_labels.tolist()
    }

    # Save fold model; unwrap DataParallel/DDP wrapper if present.
    model_to_save = trainer.model
    if hasattr(model_to_save, 'module'):
        model_to_save = model_to_save.module

    torch.save(model_to_save.state_dict(), os.path.join(fold_output_dir, "pytorch_model.bin"))

    return trainer, eval_results, fold_predictions

1669
# Main execution
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Adaptive Learning Rate Aphasia Classification Training")
    parser.add_argument("--output_dir", type=str, default="./adaptive_aphasia_model", help="Output directory")
    parser.add_argument("--cross_validation", action="store_true", help="Use cross-validation")
    parser.add_argument("--n_folds", type=int, default=5, help="Number of CV folds")
    # NOTE(review): `json_file` must be a module-level variable defined earlier
    # in this file for this default to resolve — confirm it exists.
    parser.add_argument("--json_file", type=str, default=json_file, help="Path to JSON dataset file")
    parser.add_argument("--learning_rate", type=float, default=5e-4, help="Initial learning rate")
    parser.add_argument("--batch_size", type=int, default=24, help="Batch size")
    parser.add_argument("--num_epochs", type=int, default=3, help="Number of epochs")
    # Fix: the original declared --adaptive_lr with action="store_true" AND
    # default=True, which made the flag a no-op (adaptive LR could never be
    # disabled from the CLI). Keep --adaptive_lr accepted for backward
    # compatibility and add --no_adaptive_lr to turn it off; default stays True.
    parser.add_argument("--adaptive_lr", dest="adaptive_lr", action="store_true",
                        help="Use adaptive learning rate (default)")
    parser.add_argument("--no_adaptive_lr", dest="adaptive_lr", action="store_false",
                        help="Disable the adaptive learning rate")
    parser.set_defaults(adaptive_lr=True)

    args = parser.parse_args()

    # Update config with command line arguments.
    config = ModelConfig()
    config.learning_rate = args.learning_rate
    config.batch_size = args.batch_size
    config.num_epochs = args.num_epochs
    config.adaptive_lr = args.adaptive_lr

    try:
        clear_memory()

        log_message(f"Starting training with adaptive_lr={config.adaptive_lr}")
        log_message(f"Config: lr={config.learning_rate}, batch_size={config.batch_size}, epochs={config.num_epochs}")

        # Dispatch to k-fold CV or a single train/eval run.
        if args.cross_validation:
            results_df, cv_summary = train_adaptive_cross_validation(args.json_file, args.output_dir, args.n_folds)
            log_message("Cross-validation training completed!")
        else:
            trainer, eval_results, results_df = train_adaptive_model(args.json_file, args.output_dir)
            log_message("Single model training completed!")

        log_message("All adaptive training completed successfully!")

    except Exception as e:
        # Best-effort top-level boundary: log the failure with its traceback
        # instead of crashing silently (original behavior preserved).
        log_message(f"Training failed: {str(e)}")
        import traceback
        log_message(traceback.format_exc())
    finally:
        clear_memory()
aphasia_predictions.json ADDED
@@ -0,0 +1,435 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "summary": {
3
+ "classification_distribution": {
4
+ "BROCA": 1
5
+ },
6
+ "classification_percentages": {
7
+ "BROCA": "100.0%"
8
+ },
9
+ "average_confidence": "0.995",
10
+ "average_fluency_score": "0.571",
11
+ "severity_distribution": {
12
+ "3": 1
13
+ },
14
+ "confidence_statistics": {
15
+ "mean": "0.995",
16
+ "std": "0.000",
17
+ "min": "0.995",
18
+ "max": "0.995"
19
+ },
20
+ "most_common_prediction": "BROCA"
21
+ },
22
+ "total_sentences": 1,
23
+ "predictions": [
24
+ {
25
+ "sentence_id": "S1",
26
+ "input_text": "yeah well [DIALOGUE] I yeah you_know dada dada [DIALOGUE] [DIALOGUE] [DIALOGUE] yes beg it~cop two thousand two day-PL no after New_Year's_Day two thousand [DIALOGUE] [DIALOGUE] [DIALOGUE] I do~neg remember I do~neg remember [DIALOGUE] [DIALOGUE] [DIALOGUE] oh beg yeah beg x I aphasia oh beg yes [DIALOGUE] [DIALOGUE] [DIALOGUE] [DIALOGUE] yeah [DIALOGUE] [DIALOGUE] [DIALOGUE] [DIALOGUE] oh beg yes Turkey to China China oh beg yes up at Beijing yes oh and walk on the wall yes beg oh beg god beg I love-PAST it yes beg oh beg just amaze-PRESP oh beg just amaze-PRESP oh beg I just oh yeah [DIALOGUE] oh beg yes [DIALOGUE] [DIALOGUE] [DIALOGUE] [DIALOGUE] [DIALOGUE] kick-PRESP the ball window accident window break&PASTP and it~cop all big end yeah and the window break&PASTP and a ball end yeah [DIALOGUE] [DIALOGUE] [DIALOGUE] [DIALOGUE] oh beg I no I do~neg want it no and rain yes beg rain rain rain yes beg oh_no no no no yes beg mother look-PRESP at son and son get-3S a umbrella [DIALOGUE] [DIALOGUE] [DIALOGUE] cat up the tree darling get cat out the tree ladder break&PASTP there~cop a up the tree tree bark-PRESP oh beg bark-PRESP end yeah and x get-3S mother to down the tree [DIALOGUE] [DIALOGUE] [DIALOGUE] yep [DIALOGUE] [DIALOGUE] [DIALOGUE] [DIALOGUE] oh [DIALOGUE] [DIALOGUE] [DIALOGUE] [DIALOGUE] well beg Cinderella be&PAST&13S a poor child in in oh_god Cinderella poor child in to do many thing-PL in and oh_god [DIALOGUE] oh and you troll child in child&PL oh_god child&PL want-PAST to go to dance and beautiful dadada dadadada and Cinderella be&PAST~neg sure about go-PRESP to dance and oh_god I just [DIALOGUE] she get&PAST to go to the dance in shoe-PL and oh_god oh_god [DIALOGUE] dance and yeah and she be&PAST&13S dance-PRESP around night and she suppose-PAST to be somewhere else and she put&ZERO her foot in the so and she ride&PAST off with the prince [DIALOGUE] [DIALOGUE] [DIALOGUE] [DIALOGUE] [DIALOGUE] oh beg bread two piece-PL of bread and 
jelly and peanut butter and turn them over and make a peanut butter sandwich [DIALOGUE]",
27
+ "original_tokens": [
28
+ "yeah",
29
+ "well",
30
+ "[DIALOGUE]",
31
+ "I",
32
+ "yeah",
33
+ "you_know",
34
+ "dada",
35
+ "dada",
36
+ "[DIALOGUE]",
37
+ "[DIALOGUE]",
38
+ "[DIALOGUE]",
39
+ "yes",
40
+ "beg",
41
+ "it~cop",
42
+ "two",
43
+ "thousand",
44
+ "two",
45
+ "day-PL",
46
+ "no",
47
+ "after",
48
+ "New_Year's_Day",
49
+ "two",
50
+ "thousand",
51
+ "[DIALOGUE]",
52
+ "[DIALOGUE]",
53
+ "[DIALOGUE]",
54
+ "I",
55
+ "do~neg",
56
+ "remember",
57
+ "I",
58
+ "do~neg",
59
+ "remember",
60
+ "[DIALOGUE]",
61
+ "[DIALOGUE]",
62
+ "[DIALOGUE]",
63
+ "oh",
64
+ "beg",
65
+ "yeah",
66
+ "beg",
67
+ "x",
68
+ "I",
69
+ "aphasia",
70
+ "oh",
71
+ "beg",
72
+ "yes",
73
+ "[DIALOGUE]",
74
+ "[DIALOGUE]",
75
+ "[DIALOGUE]",
76
+ "[DIALOGUE]",
77
+ "yeah",
78
+ "[DIALOGUE]",
79
+ "[DIALOGUE]",
80
+ "[DIALOGUE]",
81
+ "[DIALOGUE]",
82
+ "oh",
83
+ "beg",
84
+ "yes",
85
+ "Turkey",
86
+ "to",
87
+ "China",
88
+ "China",
89
+ "oh",
90
+ "beg",
91
+ "yes",
92
+ "up",
93
+ "at",
94
+ "Beijing",
95
+ "yes",
96
+ "oh",
97
+ "and",
98
+ "walk",
99
+ "on",
100
+ "the",
101
+ "wall",
102
+ "yes",
103
+ "beg",
104
+ "oh",
105
+ "beg",
106
+ "god",
107
+ "beg",
108
+ "I",
109
+ "love-PAST",
110
+ "it",
111
+ "yes",
112
+ "beg",
113
+ "oh",
114
+ "beg",
115
+ "just",
116
+ "amaze-PRESP",
117
+ "oh",
118
+ "beg",
119
+ "just",
120
+ "amaze-PRESP",
121
+ "oh",
122
+ "beg",
123
+ "I",
124
+ "just",
125
+ "oh",
126
+ "yeah",
127
+ "[DIALOGUE]",
128
+ "oh",
129
+ "beg",
130
+ "yes",
131
+ "[DIALOGUE]",
132
+ "[DIALOGUE]",
133
+ "[DIALOGUE]",
134
+ "[DIALOGUE]",
135
+ "[DIALOGUE]",
136
+ "kick-PRESP",
137
+ "the",
138
+ "ball",
139
+ "window",
140
+ "accident",
141
+ "window",
142
+ "break&PASTP",
143
+ "and",
144
+ "it~cop",
145
+ "all",
146
+ "big",
147
+ "end",
148
+ "yeah",
149
+ "and",
150
+ "the",
151
+ "window",
152
+ "break&PASTP",
153
+ "and",
154
+ "a",
155
+ "ball",
156
+ "end",
157
+ "yeah",
158
+ "[DIALOGUE]",
159
+ "[DIALOGUE]",
160
+ "[DIALOGUE]",
161
+ "[DIALOGUE]",
162
+ "oh",
163
+ "beg",
164
+ "I",
165
+ "no",
166
+ "I",
167
+ "do~neg",
168
+ "want",
169
+ "it",
170
+ "no",
171
+ "and",
172
+ "rain",
173
+ "yes",
174
+ "beg",
175
+ "rain",
176
+ "rain",
177
+ "rain",
178
+ "yes",
179
+ "beg",
180
+ "oh_no",
181
+ "no",
182
+ "no",
183
+ "no",
184
+ "yes",
185
+ "beg",
186
+ "mother",
187
+ "look-PRESP",
188
+ "at",
189
+ "son",
190
+ "and",
191
+ "son",
192
+ "get-3S",
193
+ "a",
194
+ "umbrella",
195
+ "[DIALOGUE]",
196
+ "[DIALOGUE]",
197
+ "[DIALOGUE]",
198
+ "cat",
199
+ "up",
200
+ "the",
201
+ "tree",
202
+ "darling",
203
+ "get",
204
+ "cat",
205
+ "out",
206
+ "the",
207
+ "tree",
208
+ "ladder",
209
+ "break&PASTP",
210
+ "there~cop",
211
+ "a",
212
+ "up",
213
+ "the",
214
+ "tree",
215
+ "tree",
216
+ "bark-PRESP",
217
+ "oh",
218
+ "beg",
219
+ "bark-PRESP",
220
+ "end",
221
+ "yeah",
222
+ "and",
223
+ "x",
224
+ "get-3S",
225
+ "mother",
226
+ "to",
227
+ "down",
228
+ "the",
229
+ "tree",
230
+ "[DIALOGUE]",
231
+ "[DIALOGUE]",
232
+ "[DIALOGUE]",
233
+ "yep",
234
+ "[DIALOGUE]",
235
+ "[DIALOGUE]",
236
+ "[DIALOGUE]",
237
+ "[DIALOGUE]",
238
+ "oh",
239
+ "[DIALOGUE]",
240
+ "[DIALOGUE]",
241
+ "[DIALOGUE]",
242
+ "[DIALOGUE]",
243
+ "well",
244
+ "beg",
245
+ "Cinderella",
246
+ "be&PAST&13S",
247
+ "a",
248
+ "poor",
249
+ "child",
250
+ "in",
251
+ "in",
252
+ "oh_god",
253
+ "Cinderella",
254
+ "poor",
255
+ "child",
256
+ "in",
257
+ "to",
258
+ "do",
259
+ "many",
260
+ "thing-PL",
261
+ "in",
262
+ "and",
263
+ "oh_god",
264
+ "[DIALOGUE]",
265
+ "oh",
266
+ "and",
267
+ "you",
268
+ "troll",
269
+ "child",
270
+ "in",
271
+ "child&PL",
272
+ "oh_god",
273
+ "child&PL",
274
+ "want-PAST",
275
+ "to",
276
+ "go",
277
+ "to",
278
+ "dance",
279
+ "and",
280
+ "beautiful",
281
+ "dadada",
282
+ "dadadada",
283
+ "and",
284
+ "Cinderella",
285
+ "be&PAST~neg",
286
+ "sure",
287
+ "about",
288
+ "go-PRESP",
289
+ "to",
290
+ "dance",
291
+ "and",
292
+ "oh_god",
293
+ "I",
294
+ "just",
295
+ "[DIALOGUE]",
296
+ "she",
297
+ "get&PAST",
298
+ "to",
299
+ "go",
300
+ "to",
301
+ "the",
302
+ "dance",
303
+ "in",
304
+ "shoe-PL",
305
+ "and",
306
+ "oh_god",
307
+ "oh_god",
308
+ "[DIALOGUE]",
309
+ "dance",
310
+ "and",
311
+ "yeah",
312
+ "and",
313
+ "she",
314
+ "be&PAST&13S",
315
+ "dance-PRESP",
316
+ "around",
317
+ "night",
318
+ "and",
319
+ "she",
320
+ "suppose-PAST",
321
+ "to",
322
+ "be",
323
+ "somewhere",
324
+ "else",
325
+ "and",
326
+ "she",
327
+ "put&ZERO",
328
+ "her",
329
+ "foot",
330
+ "in",
331
+ "the",
332
+ "so",
333
+ "and",
334
+ "she",
335
+ "ride&PAST",
336
+ "off",
337
+ "with",
338
+ "the",
339
+ "prince",
340
+ "[DIALOGUE]",
341
+ "[DIALOGUE]",
342
+ "[DIALOGUE]",
343
+ "[DIALOGUE]",
344
+ "[DIALOGUE]",
345
+ "oh",
346
+ "beg",
347
+ "bread",
348
+ "two",
349
+ "piece-PL",
350
+ "of",
351
+ "bread",
352
+ "and",
353
+ "jelly",
354
+ "and",
355
+ "peanut",
356
+ "butter",
357
+ "and",
358
+ "turn",
359
+ "them",
360
+ "over",
361
+ "and",
362
+ "make",
363
+ "a",
364
+ "peanut",
365
+ "butter",
366
+ "sandwich",
367
+ "[DIALOGUE]"
368
+ ],
369
+ "prediction": {
370
+ "predicted_class": "BROCA",
371
+ "confidence": 0.994691789150238,
372
+ "confidence_percentage": "99.47%"
373
+ },
374
+ "class_description": {
375
+ "name": "Broca's Aphasia (Non-fluent)",
376
+ "description": "Characterized by limited speech output, difficulty with grammar and sentence formation, but relatively preserved comprehension. Speech is typically effortful and halting.",
377
+ "features": [
378
+ "Non-fluent speech",
379
+ "Preserved comprehension",
380
+ "Grammar difficulties",
381
+ "Word-finding problems"
382
+ ]
383
+ },
384
+ "probability_distribution": {
385
+ "BROCA": {
386
+ "probability": 0.994691789150238,
387
+ "percentage": "99.47%"
388
+ },
389
+ "CONDUCTION": {
390
+ "probability": 0.001859842101112008,
391
+ "percentage": "0.19%"
392
+ },
393
+ "GLOBAL": {
394
+ "probability": 0.0015279082581400871,
395
+ "percentage": "0.15%"
396
+ },
397
+ "ANOMIC": {
398
+ "probability": 0.0014873514883220196,
399
+ "percentage": "0.15%"
400
+ },
401
+ "TRANSMOTOR": {
402
+ "probability": 0.00028855769778601825,
403
+ "percentage": "0.03%"
404
+ },
405
+ "NOTAPHASICBYWAB": {
406
+ "probability": 9.208399569615722e-05,
407
+ "percentage": "0.01%"
408
+ },
409
+ "WERNICKE": {
410
+ "probability": 4.5590277295559645e-05,
411
+ "percentage": "0.00%"
412
+ },
413
+ "ISOLATION": {
414
+ "probability": 6.9648622229578905e-06,
415
+ "percentage": "0.00%"
416
+ },
417
+ "TRANSSENSORY": {
418
+ "probability": 8.662294881389698e-09,
419
+ "percentage": "0.00%"
420
+ }
421
+ },
422
+ "additional_predictions": {
423
+ "severity_distribution": {
424
+ "level_0": 0.22366976737976074,
425
+ "level_1": 0.1340962052345276,
426
+ "level_2": 0.2849337160587311,
427
+ "level_3": 0.3573003113269806
428
+ },
429
+ "predicted_severity_level": 3,
430
+ "fluency_score": 0.571057915687561,
431
+ "fluency_rating": "Medium"
432
+ }
433
+ }
434
+ ]
435
+ }
config.json ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_name": "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
3
+ "num_labels": 9,
4
+ "aphasia_types_mapping": {
5
+ "BROCA": 0,
6
+ "TRANSMOTOR": 1,
7
+ "NOTAPHASICBYWAB": 2,
8
+ "CONDUCTION": 3,
9
+ "WERNICKE": 4,
10
+ "ANOMIC": 5,
11
+ "GLOBAL": 6,
12
+ "ISOLATION": 7,
13
+ "TRANSSENSORY": 8
14
+ },
15
+ "training_args": {
16
+ "output_dir": "./adaptive_aphasia_model",
17
+ "overwrite_output_dir": false,
18
+ "do_train": false,
19
+ "do_eval": true,
20
+ "do_predict": false,
21
+ "eval_strategy": "epoch",
22
+ "prediction_loss_only": false,
23
+ "per_device_train_batch_size": 10,
24
+ "per_device_eval_batch_size": 10,
25
+ "per_gpu_train_batch_size": null,
26
+ "per_gpu_eval_batch_size": null,
27
+ "gradient_accumulation_steps": 4,
28
+ "eval_accumulation_steps": null,
29
+ "eval_delay": 0,
30
+ "torch_empty_cache_steps": null,
31
+ "learning_rate": 0.0005,
32
+ "weight_decay": 0.01,
33
+ "adam_beta1": 0.9,
34
+ "adam_beta2": 0.999,
35
+ "adam_epsilon": 1e-08,
36
+ "max_grad_norm": 1.0,
37
+ "num_train_epochs": 500,
38
+ "max_steps": -1,
39
+ "lr_scheduler_type": "linear",
40
+ "lr_scheduler_kwargs": {},
41
+ "warmup_ratio": 0.1,
42
+ "warmup_steps": 0,
43
+ "log_level": "passive",
44
+ "log_level_replica": "warning",
45
+ "log_on_each_node": true,
46
+ "logging_dir": "./adaptive_aphasia_model/runs/Aug06_00-31-47_ikm-gpu-9104",
47
+ "logging_strategy": "steps",
48
+ "logging_first_step": false,
49
+ "logging_steps": 50,
50
+ "logging_nan_inf_filter": true,
51
+ "save_strategy": "epoch",
52
+ "save_steps": 500,
53
+ "save_total_limit": 3,
54
+ "save_safetensors": true,
55
+ "save_on_each_node": false,
56
+ "save_only_model": false,
57
+ "restore_callback_states_from_checkpoint": false,
58
+ "no_cuda": false,
59
+ "use_cpu": false,
60
+ "use_mps_device": false,
61
+ "seed": 42,
62
+ "data_seed": null,
63
+ "jit_mode_eval": false,
64
+ "use_ipex": false,
65
+ "bf16": false,
66
+ "fp16": false,
67
+ "fp16_opt_level": "O1",
68
+ "half_precision_backend": "auto",
69
+ "bf16_full_eval": false,
70
+ "fp16_full_eval": false,
71
+ "tf32": null,
72
+ "local_rank": 1,
73
+ "ddp_backend": null,
74
+ "tpu_num_cores": null,
75
+ "tpu_metrics_debug": false,
76
+ "debug": [],
77
+ "dataloader_drop_last": true,
78
+ "eval_steps": null,
79
+ "dataloader_num_workers": 0,
80
+ "dataloader_prefetch_factor": null,
81
+ "past_index": -1,
82
+ "run_name": "./adaptive_aphasia_model",
83
+ "disable_tqdm": false,
84
+ "remove_unused_columns": false,
85
+ "label_names": null,
86
+ "load_best_model_at_end": true,
87
+ "metric_for_best_model": "eval_f1",
88
+ "greater_is_better": true,
89
+ "ignore_data_skip": false,
90
+ "fsdp": [],
91
+ "fsdp_min_num_params": 0,
92
+ "fsdp_config": {
93
+ "min_num_params": 0,
94
+ "xla": false,
95
+ "xla_fsdp_v2": false,
96
+ "xla_fsdp_grad_ckpt": false
97
+ },
98
+ "fsdp_transformer_layer_cls_to_wrap": null,
99
+ "accelerator_config": {
100
+ "split_batches": false,
101
+ "dispatch_batches": null,
102
+ "even_batches": true,
103
+ "use_seedable_sampler": true,
104
+ "non_blocking": false,
105
+ "gradient_accumulation_kwargs": null
106
+ },
107
+ "deepspeed": null,
108
+ "label_smoothing_factor": 0.0,
109
+ "optim": "adamw_torch",
110
+ "optim_args": null,
111
+ "adafactor": false,
112
+ "group_by_length": false,
113
+ "length_column_name": "length",
114
+ "report_to": [],
115
+ "ddp_find_unused_parameters": null,
116
+ "ddp_bucket_cap_mb": null,
117
+ "ddp_broadcast_buffers": null,
118
+ "dataloader_pin_memory": true,
119
+ "dataloader_persistent_workers": false,
120
+ "skip_memory_metrics": true,
121
+ "use_legacy_prediction_loop": false,
122
+ "push_to_hub": false,
123
+ "resume_from_checkpoint": null,
124
+ "hub_model_id": null,
125
+ "hub_strategy": "every_save",
126
+ "hub_token": "<HUB_TOKEN>",
127
+ "hub_private_repo": null,
128
+ "hub_always_push": false,
129
+ "gradient_checkpointing": false,
130
+ "gradient_checkpointing_kwargs": null,
131
+ "include_inputs_for_metrics": false,
132
+ "include_for_metrics": [],
133
+ "eval_do_concat_batches": true,
134
+ "fp16_backend": "auto",
135
+ "push_to_hub_model_id": null,
136
+ "push_to_hub_organization": null,
137
+ "push_to_hub_token": "<PUSH_TO_HUB_TOKEN>",
138
+ "mp_parameters": "",
139
+ "auto_find_batch_size": false,
140
+ "full_determinism": false,
141
+ "torchdynamo": null,
142
+ "ray_scope": "last",
143
+ "ddp_timeout": 1800,
144
+ "torch_compile": false,
145
+ "torch_compile_backend": null,
146
+ "torch_compile_mode": null,
147
+ "include_tokens_per_second": false,
148
+ "include_num_input_tokens_seen": false,
149
+ "neftune_noise_alpha": null,
150
+ "optim_target_modules": null,
151
+ "batch_eval_metrics": false,
152
+ "eval_on_start": false,
153
+ "use_liger_kernel": false,
154
+ "eval_use_gather_object": false,
155
+ "average_tokens_across_devices": false
156
+ },
157
+ "adaptive_lr_config": {
158
+ "adaptive_lr": true,
159
+ "lr_patience": 3,
160
+ "lr_factor": 0.8,
161
+ "lr_increase_factor": 1.2,
162
+ "min_lr": 1e-06,
163
+ "max_lr": 0.001,
164
+ "oscillation_amplitude": 0.1
165
+ }
166
+ }
sample.input.json ADDED
The diff for this file is too large to render. See raw diff
 
special_tokens_map.json ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ {
4
+ "content": "[DIALOGUE]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false
9
+ },
10
+ {
11
+ "content": "[TURN]",
12
+ "lstrip": false,
13
+ "normalized": false,
14
+ "rstrip": false,
15
+ "single_word": false
16
+ },
17
+ {
18
+ "content": "[PAUSE]",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ },
24
+ {
25
+ "content": "[REPEAT]",
26
+ "lstrip": false,
27
+ "normalized": false,
28
+ "rstrip": false,
29
+ "single_word": false
30
+ },
31
+ {
32
+ "content": "[HESITATION]",
33
+ "lstrip": false,
34
+ "normalized": false,
35
+ "rstrip": false,
36
+ "single_word": false
37
+ }
38
+ ],
39
+ "cls_token": "[CLS]",
40
+ "mask_token": "[MASK]",
41
+ "pad_token": "[PAD]",
42
+ "sep_token": "[SEP]",
43
+ "unk_token": "[UNK]"
44
+ }
summary_statistics.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "Overall Accuracy": 0.8802153432032301,
3
+ "Macro F1": 0.8909792764791806,
4
+ "Weighted F1": 0.8772149647566893,
5
+ "Macro Precision": 0.8990448362732847,
6
+ "Macro Recall": 0.8876134036897266,
7
+ "Average Confidence": 0.9344870448112488,
8
+ "Confidence Std": 0.13039512932300568,
9
+ "Average Severity": [
10
+ 0.23586010932922363,
11
+ 0.2251170426607132,
12
+ 0.29972559213638306,
13
+ 0.2392973005771637
14
+ ],
15
+ "Average Fluency": 0.5604473352432251,
16
+ "Fluency Std": 0.08302813023328781
17
+ }
to_cha.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
# to_cha.py — run a Batchalign pipeline over one WAV recording and write the
# result as a CLAN .cha transcript.
import batchalign as ba

# Pipeline stages as named in the string: ASR, speaker assignment, and
# morphosyntactic tagging; English audio with two speakers expected.
nlp = ba.BatchalignPipeline.new("asr,speaker,morphosyntax", lang="eng", num_speakers=2)
# Wrap the media file in a Batchalign document and run the full pipeline on it.
doc = ba.Document.new(media_path="/workspace/SH001/videos/ACWT07a.wav", lang="eng")
doc = nlp(doc)
chat = ba.CHATFile(doc=doc)
# NOTE(review): write_wor=True presumably adds word-level %wor alignment
# tiers to the output — confirm against the batchalign documentation.
chat.write("/workspace/SH001/vid_output/output.cha", write_wor=True)
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "[PAD]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "[UNK]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "2": {
20
+ "content": "[CLS]",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "3": {
28
+ "content": "[SEP]",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "4": {
36
+ "content": "[MASK]",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ },
43
+ "30522": {
44
+ "content": "[DIALOGUE]",
45
+ "lstrip": false,
46
+ "normalized": false,
47
+ "rstrip": false,
48
+ "single_word": false,
49
+ "special": true
50
+ },
51
+ "30523": {
52
+ "content": "[TURN]",
53
+ "lstrip": false,
54
+ "normalized": false,
55
+ "rstrip": false,
56
+ "single_word": false,
57
+ "special": true
58
+ },
59
+ "30524": {
60
+ "content": "[PAUSE]",
61
+ "lstrip": false,
62
+ "normalized": false,
63
+ "rstrip": false,
64
+ "single_word": false,
65
+ "special": true
66
+ },
67
+ "30525": {
68
+ "content": "[REPEAT]",
69
+ "lstrip": false,
70
+ "normalized": false,
71
+ "rstrip": false,
72
+ "single_word": false,
73
+ "special": true
74
+ },
75
+ "30526": {
76
+ "content": "[HESITATION]",
77
+ "lstrip": false,
78
+ "normalized": false,
79
+ "rstrip": false,
80
+ "single_word": false,
81
+ "special": true
82
+ }
83
+ },
84
+ "additional_special_tokens": [
85
+ "[DIALOGUE]",
86
+ "[TURN]",
87
+ "[PAUSE]",
88
+ "[REPEAT]",
89
+ "[HESITATION]"
90
+ ],
91
+ "clean_up_tokenization_spaces": true,
92
+ "cls_token": "[CLS]",
93
+ "do_basic_tokenize": true,
94
+ "do_lower_case": true,
95
+ "extra_special_tokens": {},
96
+ "mask_token": "[MASK]",
97
+ "model_max_length": 1000000000000000019884624838656,
98
+ "never_split": null,
99
+ "pad_token": "[PAD]",
100
+ "sep_token": "[SEP]",
101
+ "strip_accents": null,
102
+ "tokenize_chinese_chars": true,
103
+ "tokenizer_class": "BertTokenizer",
104
+ "unk_token": "[UNK]"
105
+ }
vocab.txt ADDED
The diff for this file is too large to render. See raw diff