caixiaoshun committed on
Commit 3ce0244 (verified)
Parent: a07ae21

Update app.py

Files changed (1)
  1. app.py +16 -65
app.py CHANGED
@@ -1,11 +1,10 @@
 # app.py
 import os
-import subprocess
-import requests
 import gradio as gr
 import torch
 from tokenizers import Tokenizer
 from huggingface_hub import hf_hub_download
+
 from src.config import Config
 from src.model import TranslateModel
 
@@ -17,63 +16,16 @@ HF_TOKENIZER_FILE = os.getenv("TOKENIZER_FILE", "tokenizer.json")
 LOCAL_CKPT_PATH = os.getenv("LOCAL_CKPT_PATH", "checkpoints/translate-step=290000.ckpt")
 LOCAL_TOKENIZER_PATH = os.getenv("LOCAL_TOKENIZER_PATH", "checkpoints/tokenizer.json")
 
-# Optional: direct download URLs (useful when huggingface_hub is restricted or to force plain HTTP download)
-DIRECT_CKPT_URL = os.getenv(
-    "DIRECT_CKPT_URL",
-    "https://huggingface.co/caixiaoshun/tiny-translator-zh2en/resolve/main/translate-step%3D290000.ckpt",
-)
-DIRECT_TOKENIZER_URL = os.getenv(
-    "DIRECT_TOKENIZER_URL",
-    "https://huggingface.co/caixiaoshun/tiny-translator-zh2en/resolve/main/tokenizer.json",
-)
-
-
-def _ensure_dir(path: str):
-    d = os.path.dirname(path)
-    if d and not os.path.exists(d):
-        os.makedirs(d, exist_ok=True)
-
-
-def _download_with_wget(url: str, dest: str):
-    _ensure_dir(dest)
-    try:
-        subprocess.run(["wget", "-q", "-O", dest, url], check=True)
-        return True
-    except (subprocess.CalledProcessError, FileNotFoundError):
-        return False
-
-
-def _download_with_requests(url: str, dest: str):
-    _ensure_dir(dest)
-    with requests.get(url, stream=True, timeout=60) as r:
-        r.raise_for_status()
-        tmp = dest + ".part"
-        with open(tmp, "wb") as f:
-            for chunk in r.iter_content(chunk_size=1024 * 1024):
-                if chunk:
-                    f.write(chunk)
-    os.replace(tmp, dest)
-
 
 class Inference:
     def __init__(self, config: Config, ckpt_path: str):
         self.config = config
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
 
-        # If local files are missing, try direct download to checkpoints/
-        if not os.path.exists(LOCAL_TOKENIZER_PATH):
-            ok = _download_with_wget(DIRECT_TOKENIZER_URL, LOCAL_TOKENIZER_PATH)
-            if not ok:
-                _download_with_requests(DIRECT_TOKENIZER_URL, LOCAL_TOKENIZER_PATH)
-
-        if not os.path.exists(LOCAL_CKPT_PATH):
-            ok = _download_with_wget(DIRECT_CKPT_URL, LOCAL_CKPT_PATH)
-            if not ok:
-                _download_with_requests(DIRECT_CKPT_URL, LOCAL_CKPT_PATH)
-
         # tokenizer (local-first, else hub)
         tokenizer_path = (
-            LOCAL_TOKENIZER_PATH if os.path.exists(LOCAL_TOKENIZER_PATH)
+            LOCAL_TOKENIZER_PATH
+            if os.path.exists(LOCAL_TOKENIZER_PATH)
             else hf_hub_download(repo_id=HF_REPO_ID, filename=HF_TOKENIZER_FILE)
         )
         self.tokenizer: Tokenizer = Tokenizer.from_file(tokenizer_path)
@@ -83,19 +35,19 @@ class Inference:
 
         # model
         self.model: TranslateModel = TranslateModel(config)
+
         # ckpt (local-first, else hub)
         ckpt_resolved = (
-            LOCAL_CKPT_PATH if os.path.exists(LOCAL_CKPT_PATH)
+            LOCAL_CKPT_PATH
+            if os.path.exists(LOCAL_CKPT_PATH)
             else hf_hub_download(repo_id=HF_REPO_ID, filename=HF_CKPT_FILE)
         )
         ckpt = torch.load(ckpt_resolved, map_location="cpu")["state_dict"]
-
         prefix = "net._orig_mod."
         state_dict = {}
         for k, v in ckpt.items():
             new_k = k[len(prefix):] if k.startswith(prefix) else k
             state_dict[new_k] = v
-
         self.model.load_state_dict(state_dict, strict=True)
         self.model.to(self.device).eval()
 
@@ -111,7 +63,6 @@ class Inference:
             tgt = torch.cat([tgt, index.unsqueeze(-1)], dim=-1)
             if self.id_EOS is not None and index.item() == self.id_EOS:
                 break
-
         return tgt.squeeze(0).tolist()
 
     @torch.no_grad()
@@ -125,21 +76,17 @@ class Inference:
             if temperature != 1.0:
                 logits = logits / temperature
             probs = torch.softmax(logits, dim=-1)
-
             sorted_probs, sorted_idx = torch.sort(probs, descending=True)
             cumsum = torch.cumsum(sorted_probs, dim=-1)
             mask = cumsum > top_p
             mask[..., 0] = False
             filtered = sorted_probs.masked_fill(mask, 0.0)
             filtered = filtered / filtered.sum(dim=-1, keepdim=True)
-
             next_sorted = torch.multinomial(filtered, 1)  # [1,1]
             next_id = sorted_idx.gather(-1, next_sorted)
             tgt = torch.cat([tgt, next_id], dim=-1)
-
             if self.id_EOS is not None and next_id.item() == self.id_EOS:
                 break
-
         return tgt.squeeze(0).tolist()
 
     @torch.no_grad()
@@ -148,7 +95,6 @@ class Inference:
         src_pad_mask = (src != self.id_PAD) if (self.id_PAD is not None) else None
 
         beams = [(torch.tensor([[self.id_SOS]], device=self.device), 0.0)]
-
         for _ in range(1, max_len):
             new_beams = []
             for seq, logp in beams:
@@ -169,10 +115,8 @@ class Inference:
 
             new_beams.sort(key=lambda x: score_fn(x[0], x[1]), reverse=True)
             beams = new_beams[:beam]
-
             if all(seq[0, -1].item() == self.id_EOS for seq, _ in beams if self.id_EOS is not None):
                 break
-
         return beams[0][0].squeeze(0).tolist()
 
     def postprocess(self, ids):
@@ -183,8 +127,16 @@ class Inference:
         text = self.tokenizer.decode(ids).strip()
         return text
 
-    def translate(self, text, method="greedy", max_tokens=128,
-                  top_p_val=0.9, temperature=1.0, beam=4, len_penalty=0.6):
+    def translate(
+        self,
+        text,
+        method="greedy",
+        max_tokens=128,
+        top_p_val=0.9,
+        temperature=1.0,
+        beam=4,
+        len_penalty=0.6,
+    ):
         src_ids = self.tokenizer.encode(text).ids
         max_len = min(max_tokens, self.config.max_len)
 
@@ -196,7 +148,6 @@ class Inference:
             ids = self.beam_search(src_ids, max_len, beam, len_penalty)
         else:
             return f"未知解码方法: {method}"
-
         return self.postprocess(ids)
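
A note on the nucleus-sampling hunk retained above: tokens are sorted by probability, everything past the cumulative cutoff top_p is zeroed out, and the `mask[..., 0] = False` line guarantees the single most likely token always survives, even when its probability alone already exceeds top_p. A standalone sketch of the same filtering step, runnable on a toy distribution (the `top_p_filter` helper name is mine, not part of app.py):

import torch

def top_p_filter(probs: torch.Tensor, top_p: float) -> torch.Tensor:
    # Mirror the filtering in Inference.top_p: sort, cut the tail, renormalize.
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cumsum = torch.cumsum(sorted_probs, dim=-1)
    mask = cumsum > top_p  # True for tokens past the cumulative cutoff
    mask[..., 0] = False   # always keep at least the top-1 token
    filtered = sorted_probs.masked_fill(mask, 0.0)
    filtered = filtered / filtered.sum(dim=-1, keepdim=True)
    # Scatter the renormalized mass back to vocabulary order.
    return torch.zeros_like(probs).scatter(-1, sorted_idx, filtered)

probs = torch.tensor([[0.5, 0.3, 0.15, 0.05]])
print(top_p_filter(probs, top_p=0.9))
# tensor([[0.6250, 0.3750, 0.0000, 0.0000]]): the cumulative sum crosses 0.9
# at the third token, so only the top two tokens keep probability mass.

As written, the token at which the cumulative sum first crosses top_p is itself dropped (except for the top-1); some implementations shift the mask by one position to keep it.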
 
 
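For completeness, a minimal usage sketch of the reworked translate entry point after this commit. Importing Inference from app, constructing Config() with defaults, and the exact method strings "top_p" and "beam" are assumptions: only the "greedy" default and the beam_search branch are visible in the hunks above.

from src.config import Config
from app import Inference  # app.py at the repo root (assumed importable)

config = Config()  # assumption: default Config matches the checkpoint
infer = Inference(config, ckpt_path="checkpoints/translate-step=290000.ckpt")

# Greedy decoding (the default method in the new signature)
print(infer.translate("你好,世界"))

# Nucleus sampling and beam search via the keyword arguments exposed by
# the reworked signature (method strings "top_p" / "beam" are assumed):
print(infer.translate("你好,世界", method="top_p", top_p_val=0.9, temperature=1.0, max_tokens=128))
print(infer.translate("你好,世界", method="beam", beam=4, len_penalty=0.6))

Since resolution is local-first, pre-placing tokenizer.json and the checkpoint under checkpoints/ avoids any network access on startup; otherwise both are fetched from HF_REPO_ID via hf_hub_download.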