Zhe-Zhang committed on
Commit
b8828ec
·
verified ·
1 Parent(s): 0b58da8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -119
app.py CHANGED
@@ -1,5 +1,3 @@
1
- # debug_app.py — 把它放到 HF Space 替换原来的 app.py
2
- import os, hashlib, json
3
  import numpy as np
4
  import torch
5
  import torch.nn as nn
@@ -8,7 +6,7 @@ import joblib
8
  from collections import Counter
9
  import gradio as gr
10
 
11
- # --- utils (同训练代码) ---
12
  def ngrams(sentence, n=1, lc=True):
13
  ngram_l = []
14
  sentence = sentence.lower()
@@ -32,16 +30,16 @@ def reproducible_hash(string):
32
  h = hashlib.md5(string.encode("utf-8"), usedforsecurity=False)
33
  return int.from_bytes(h.digest()[0:8], 'big', signed=True)
34
 
35
- def hash_ngrams(ngrams_list, modulos):
36
  hash_codes = []
37
- for ngram_list, modulo in zip(ngrams_list, modulos):
38
  codes = [(reproducible_hash(x) % modulo) for x in ngram_list]
39
  hash_codes.append(codes)
40
  return hash_codes
41
 
42
  def calc_rel_freq(codes):
43
  cnt = Counter(codes)
44
- total = sum(cnt.values()) if cnt else 1
45
  for k in cnt:
46
  cnt[k] /= total
47
  return cnt
@@ -59,131 +57,45 @@ def shift_keys(dicts, MAX_SHIFT):
59
 
60
  def build_freq_dict(sentence, MAXES=MAXES, MAX_SHIFT=MAX_SHIFT):
61
  hngrams = hash_ngrams(all_ngrams(sentence), MAXES)
62
- fhcodes = list(map(calc_rel_freq, hngrams))
63
  return shift_keys(fhcodes, MAX_SHIFT)
64
 
65
- # --- helper diagnostics ---
66
- def file_md5(path):
67
- if not os.path.exists(path):
68
- return None
69
- with open(path, "rb") as f:
70
- return hashlib.md5(f.read()).hexdigest()
71
-
72
- def model_param_stats(m):
73
- mins, maxs, means = [], [], []
74
- for p in m.parameters():
75
- arr = p.detach().cpu().numpy().ravel()
76
- if arr.size == 0:
77
- continue
78
- mins.append(float(arr.min()))
79
- maxs.append(float(arr.max()))
80
- means.append(float(arr.mean()))
81
- if not mins:
82
- return {"min": None, "max": None, "mean": None}
83
- return {"min": min(mins), "max": max(maxs), "mean": float(np.mean(means))}
84
-
85
- # --- load artifacts (ensure these files exist in your repo) ---
86
- VEC_FN = "nld_vectorizer.joblib"
87
- LANG_FN = "nld_lang_codes.joblib"
88
- MODEL_FN = "nld.pth"
89
-
90
- vectorizer = joblib.load(VEC_FN)
91
- idx2lang = joblib.load(LANG_FN)
92
 
93
  input_dim = len(vectorizer.vocabulary_)
94
  nbr_classes = len(idx2lang)
95
 
96
- # build model skeleton same as training
97
  model = nn.Sequential(
98
  nn.Linear(input_dim, 50),
99
  nn.ReLU(),
100
  nn.Linear(50, nbr_classes)
101
  )
102
- model.load_state_dict(torch.load(MODEL_FN, map_location="cpu"))
103
  model.eval()
104
 
105
- # quick sanity info (will also print to logs)
106
- print(">>> artifact md5:", MODEL_FN, file_md5(MODEL_FN))
107
- print(">>> artifact md5:", VEC_FN, file_md5(VEC_FN))
108
- print(">>> artifact md5:", LANG_FN, file_md5(LANG_FN))
109
- print("vocab size:", len(vectorizer.vocabulary_))
110
- print("sample vocab items:", list(vectorizer.vocabulary_.items())[:10])
111
- print("idx2lang sample:", list(idx2lang.items())[:10])
112
- print("model param stats:", model_param_stats(model))
113
-
114
- # --- prediction + debug function ---
115
- def detect_lang_debug(src_sentence: str):
116
- debug = {}
117
- debug['md5_model'] = file_md5(MODEL_FN)
118
- debug['md5_vectorizer'] = file_md5(VEC_FN)
119
- debug['md5_idx2lang'] = file_md5(LANG_FN)
120
- debug['vocab_size'] = len(vectorizer.vocabulary_)
121
- debug['idx2lang_len'] = len(idx2lang)
122
- debug['idx2lang_sample'] = dict(list(idx2lang.items())[:10])
123
- debug['model_param_stats'] = model_param_stats(model)
124
-
125
- feat_dict = build_freq_dict(src_sentence)
126
- X_test = vectorizer.transform([feat_dict]) # ensure a single dict in list
127
- if hasattr(X_test, "toarray"):
128
- X_arr = X_test.toarray()
129
- else:
130
- X_arr = np.array(X_test)
131
- debug['nonzero_features'] = int(np.count_nonzero(X_arr))
132
- debug['X_shape'] = X_arr.shape
133
-
134
- X_tensor = torch.from_numpy(X_arr.astype("float32"))
135
- with torch.no_grad():
136
- logits = model(X_tensor)
137
- probs = torch.softmax(logits, dim=-1).cpu().numpy().ravel()
138
-
139
- topk = list(np.argsort(probs)[::-1][:5])
140
- topk_info = [(int(k), idx2lang[int(k)], float(probs[int(k)])) for k in topk]
141
- pred_idx = int(topk[0])
142
- pred_lang = idx2lang[pred_idx]
143
-
144
- debug_text = json.dumps({
145
- "pred_lang": pred_lang,
146
- "pred_idx": pred_idx,
147
- "topk": topk_info,
148
- "debug": debug
149
- }, ensure_ascii=False, indent=2)
150
- print("DEBUG:", debug_text) # Visible in Spaces logs
151
- return pred_lang, debug_text
152
-
153
- # --- self-test example set ---
154
- SELF_TESTS = {
155
- "eng": "Hello, how are you?",
156
- "fra": "Bonjour, comment allez-vous?",
157
- "cmn": "你好,你在做什么?",
158
- "jpn": "こんにちは、お元気ですか?",
159
- "kor": "안녕하세요. 잘 지내세요?",
160
- "ara": "مرحبا كيف حالك",
161
- "swe": "Hej, hur mår du?",
162
- "dan": "Godmorgen, hvordan har du det?"
163
- }
164
-
165
- def run_self_test():
166
- results = []
167
- for lang, sent in SELF_TESTS.items():
168
- pred, dbg = detect_lang_debug(sent)
169
- ok = (pred == lang) or (pred == lang) # best-effort equality
170
- results.append(f"{lang} | sent: {sent} | pred: {pred} | ok: {ok}")
171
- out = "\n".join(results)
172
- print("SELF-TEST RESULTS:\n", out)
173
- return out
174
-
175
  # --- Gradio UI ---
176
- with gr.Blocks(title="Antons language detector (debug)") as demo:
177
- gr.Markdown("# Antons language detector — debug build")
178
  with gr.Row():
179
- with gr.Column(scale=3):
180
- src = gr.Textbox(label="Text", placeholder="Write your text...")
181
- btn = gr.Button("Guess the language!")
182
- selftest_btn = gr.Button("Run self-test")
183
- with gr.Column(scale=2):
184
- out_lang = gr.Textbox(label="Language", interactive=False)
185
- out_debug = gr.Textbox(label="Debug info (JSON)", interactive=False, lines=20)
186
- btn.click(fn=detect_lang_debug, inputs=[src], outputs=[out_lang, out_debug])
187
- selftest_btn.click(fn=run_self_test, inputs=[], outputs=[out_debug])
188
-
189
- demo.launch()
 
 
 
1
import hashlib

import numpy as np
import torch
import torch.nn as nn
import joblib

from collections import Counter
import gradio as gr
8
 
9
+ # --- utils (from the notebook) ---
10
  def ngrams(sentence, n=1, lc=True):
11
  ngram_l = []
12
  sentence = sentence.lower()
 
30
  h = hashlib.md5(string.encode("utf-8"), usedforsecurity=False)
31
  return int.from_bytes(h.digest()[0:8], 'big', signed=True)
32
 
33
def hash_ngrams(ngrams, modulos):
    """Hash every n-gram and fold it into the bucket range for its order.

    ngrams  : sequence of n-gram lists, one list per n-gram order.
    modulos : matching sequence of bucket counts (one per order).
    Returns a list of code lists, parallel to *ngrams*.
    """
    # NOTE(review): the parameter shadows the module-level ngrams() function
    # inside this body; harmless here, but worth renaming eventually.
    return [
        [reproducible_hash(gram) % modulo for gram in gram_list]
        for gram_list, modulo in zip(ngrams, modulos)
    ]
39
 
40
def calc_rel_freq(codes):
    """Return a Counter mapping each code to its relative frequency.

    An empty input yields an empty Counter; the division never runs in
    that case, so total == 0 is safe.
    """
    freq = Counter(codes)
    total = sum(freq.values())
    for code, count in freq.items():
        freq[code] = count / total
    return freq
 
57
 
58
def build_freq_dict(sentence, MAXES=MAXES, MAX_SHIFT=MAX_SHIFT):
    """Turn a sentence into a shifted {feature-id: rel-freq} mapping for the vectorizer."""
    # Hash the per-order n-gram lists into their per-order bucket ranges.
    hngrams = hash_ngrams(all_ngrams(sentence), MAXES)
    # NOTE(review): this was list(map(...)) in the previous revision; the lazy
    # map is only safe if shift_keys iterates its argument exactly once — confirm.
    fhcodes = map(calc_rel_freq, hngrams)
    return shift_keys(fhcodes, MAX_SHIFT)
62
 
63
# --- load models ---
# Artifacts exported by the training run: sklearn classifier, the
# vectorizer over hashed n-gram features, and the index -> language-code map.
clf = joblib.load("nld.joblib")
vectorizer = joblib.load("nld_vectorizer.joblib")
idx2lang = joblib.load("nld_lang_codes.joblib")

# Feature/class dimensions are derived from the loaded artifacts.
input_dim = len(vectorizer.vocabulary_)
nbr_classes = len(idx2lang)

# Same MLP skeleton as training; weights restored from nld.pth on CPU.
model = nn.Sequential(
    nn.Linear(input_dim, 50),
    nn.ReLU(),
    nn.Linear(50, nbr_classes)
)
model.load_state_dict(torch.load("nld.pth", map_location="cpu"))
# NOTE(review): detect_lang below uses the sklearn clf, not this torch
# model — confirm whether the torch path is still needed.
model.eval()
78
 
79
# --- prediction function ---
def detect_lang(src_sentence):
    """Guess the language of *src_sentence*; returns a language-code string."""
    # Build hashed n-gram frequency features (same pipeline as training).
    feature_dict = build_freq_dict(src_sentence)
    X = vectorizer.transform([feature_dict])
    # Classify with the sklearn model and map the class index to its code.
    predicted = clf.predict(X)[0]
    return idx2lang[predicted]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
# --- Gradio UI ---
# NOTE(review): indentation reconstructed from the diff view — the button is
# assumed to sit at Blocks level, below the row; confirm the intended layout.
with gr.Blocks(title="Antons language detector") as demo:
    gr.Markdown("# Antons language detector")
    with gr.Row():
        with gr.Column():
            # Free-text input to classify.
            src_sentence = gr.Textbox(
                label="Text", placeholder="Write your text...")
        with gr.Column():
            # Read-only-by-convention output box for the predicted code.
            tgt_sentence = gr.Textbox(
                label="Language", placeholder="Language will show here...")
    # Clicking runs detect_lang on the input box and writes the result.
    btn = gr.Button("Guess the language!")
    btn.click(fn=detect_lang, inputs=[src_sentence], outputs=[tgt_sentence])

demo.launch()