Update app.py
app.py CHANGED
@@ -1,11 +1,10 @@
 # app.py
 import os
-import subprocess
-import requests
 import gradio as gr
 import torch
 from tokenizers import Tokenizer
 from huggingface_hub import hf_hub_download
+
 from src.config import Config
 from src.model import TranslateModel

@@ -17,63 +16,16 @@ HF_TOKENIZER_FILE = os.getenv("TOKENIZER_FILE", "tokenizer.json")
 LOCAL_CKPT_PATH = os.getenv("LOCAL_CKPT_PATH", "checkpoints/translate-step=290000.ckpt")
 LOCAL_TOKENIZER_PATH = os.getenv("LOCAL_TOKENIZER_PATH", "checkpoints/tokenizer.json")

-# Optional: direct download URLs (useful when huggingface_hub is restricted or to force plain HTTP download)
-DIRECT_CKPT_URL = os.getenv(
-    "DIRECT_CKPT_URL",
-    "https://huggingface.co/caixiaoshun/tiny-translator-zh2en/resolve/main/translate-step%3D290000.ckpt",
-)
-DIRECT_TOKENIZER_URL = os.getenv(
-    "DIRECT_TOKENIZER_URL",
-    "https://huggingface.co/caixiaoshun/tiny-translator-zh2en/resolve/main/tokenizer.json",
-)
-
-
-def _ensure_dir(path: str):
-    d = os.path.dirname(path)
-    if d and not os.path.exists(d):
-        os.makedirs(d, exist_ok=True)
-
-
-def _download_with_wget(url: str, dest: str):
-    _ensure_dir(dest)
-    try:
-        subprocess.run(["wget", "-q", "-O", dest, url], check=True)
-        return True
-    except (subprocess.CalledProcessError, FileNotFoundError):
-        return False
-
-
-def _download_with_requests(url: str, dest: str):
-    _ensure_dir(dest)
-    with requests.get(url, stream=True, timeout=60) as r:
-        r.raise_for_status()
-        tmp = dest + ".part"
-        with open(tmp, "wb") as f:
-            for chunk in r.iter_content(chunk_size=1024 * 1024):
-                if chunk:
-                    f.write(chunk)
-        os.replace(tmp, dest)
-

 class Inference:
     def __init__(self, config: Config, ckpt_path: str):
         self.config = config
         self.device = "cuda" if torch.cuda.is_available() else "cpu"

-        # If local files are missing, try direct download to checkpoints/
-        if not os.path.exists(LOCAL_TOKENIZER_PATH):
-            ok = _download_with_wget(DIRECT_TOKENIZER_URL, LOCAL_TOKENIZER_PATH)
-            if not ok:
-                _download_with_requests(DIRECT_TOKENIZER_URL, LOCAL_TOKENIZER_PATH)
-
-        if not os.path.exists(LOCAL_CKPT_PATH):
-            ok = _download_with_wget(DIRECT_CKPT_URL, LOCAL_CKPT_PATH)
-            if not ok:
-                _download_with_requests(DIRECT_CKPT_URL, LOCAL_CKPT_PATH)
-
         # tokenizer (local-first, else hub)
         tokenizer_path = (
-            LOCAL_TOKENIZER_PATH
+            LOCAL_TOKENIZER_PATH
+            if os.path.exists(LOCAL_TOKENIZER_PATH)
             else hf_hub_download(repo_id=HF_REPO_ID, filename=HF_TOKENIZER_FILE)
         )
         self.tokenizer: Tokenizer = Tokenizer.from_file(tokenizer_path)
@@ -83,19 +35,19 @@ class Inference:

         # model
         self.model: TranslateModel = TranslateModel(config)
+
         # ckpt (local-first, else hub)
         ckpt_resolved = (
-            LOCAL_CKPT_PATH
+            LOCAL_CKPT_PATH
+            if os.path.exists(LOCAL_CKPT_PATH)
             else hf_hub_download(repo_id=HF_REPO_ID, filename=HF_CKPT_FILE)
         )
         ckpt = torch.load(ckpt_resolved, map_location="cpu")["state_dict"]
-
         prefix = "net._orig_mod."
         state_dict = {}
         for k, v in ckpt.items():
             new_k = k[len(prefix):] if k.startswith(prefix) else k
             state_dict[new_k] = v
-
         self.model.load_state_dict(state_dict, strict=True)
         self.model.to(self.device).eval()

@@ -111,7 +63,6 @@ class Inference:
             tgt = torch.cat([tgt, index.unsqueeze(-1)], dim=-1)
             if self.id_EOS is not None and index.item() == self.id_EOS:
                 break
-
         return tgt.squeeze(0).tolist()

     @torch.no_grad()
@@ -125,21 +76,17 @@ class Inference:
            if temperature != 1.0:
                logits = logits / temperature
            probs = torch.softmax(logits, dim=-1)
-
            sorted_probs, sorted_idx = torch.sort(probs, descending=True)
            cumsum = torch.cumsum(sorted_probs, dim=-1)
            mask = cumsum > top_p
            mask[..., 0] = False
            filtered = sorted_probs.masked_fill(mask, 0.0)
            filtered = filtered / filtered.sum(dim=-1, keepdim=True)
-
            next_sorted = torch.multinomial(filtered, 1)  # [1,1]
            next_id = sorted_idx.gather(-1, next_sorted)
            tgt = torch.cat([tgt, next_id], dim=-1)
-
            if self.id_EOS is not None and next_id.item() == self.id_EOS:
                break
-
        return tgt.squeeze(0).tolist()

     @torch.no_grad()
@@ -148,7 +95,6 @@ class Inference:
        src_pad_mask = (src != self.id_PAD) if (self.id_PAD is not None) else None

        beams = [(torch.tensor([[self.id_SOS]], device=self.device), 0.0)]
-
        for _ in range(1, max_len):
            new_beams = []
            for seq, logp in beams:
@@ -169,10 +115,8 @@ class Inference:

            new_beams.sort(key=lambda x: score_fn(x[0], x[1]), reverse=True)
            beams = new_beams[:beam]
-
            if all(seq[0, -1].item() == self.id_EOS for seq, _ in beams if self.id_EOS is not None):
                break
-
        return beams[0][0].squeeze(0).tolist()

     def postprocess(self, ids):
@@ -183,8 +127,16 @@ class Inference:
        text = self.tokenizer.decode(ids).strip()
        return text

-    def translate(
-
+    def translate(
+        self,
+        text,
+        method="greedy",
+        max_tokens=128,
+        top_p_val=0.9,
+        temperature=1.0,
+        beam=4,
+        len_penalty=0.6,
+    ):
        src_ids = self.tokenizer.encode(text).ids
        max_len = min(max_tokens, self.config.max_len)

@@ -196,7 +148,6 @@ class Inference:
            ids = self.beam_search(src_ids, max_len, beam, len_penalty)
        else:
            return f"未知解码方法: {method}"
-
        return self.postprocess(ids)

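Below is a minimal usage sketch of the class after this change, for orientation only; it is not part of the commit. It assumes Config() can be constructed with its defaults and that app.py can be imported outside the Space's Gradio launch, neither of which is shown in this diff.

# Hypothetical usage sketch (assumptions: Config() takes no required arguments,
# and importing app does not launch the Gradio UI as a side effect).
from app import Inference, LOCAL_CKPT_PATH
from src.config import Config

translator = Inference(Config(), LOCAL_CKPT_PATH)  # falls back to hf_hub_download when local files are missing
print(translator.translate("你好，世界。", method="greedy", max_tokens=128))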