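"""Sampling script for a causal LM with an 8-token vocabulary: bit tokens 0/1
plus the special tokens BOS/EOS/BOI/EOI/BOR/EOR defined below.

The model is loaded with transformers' AutoModelForCausalLM (trust_remote_code)
from a Hugging Face repo id or local path. The prompt is built either from raw
token digits (--prompt_ids) or from two integers encoded as 10-bit blocks
(--prompt_int). Sampling supports repetition, presence/frequency, encoder
repetition and no-repeat n-gram penalties, and the first BOR..EOR bit block of
the output can be decoded back into an integer (--print_int)."""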

import sys
import argparse
import random
from collections import Counter
from typing import List, Dict, Tuple, Optional

import torch
from transformers import AutoModelForCausalLM


# Special token ids (the remaining ids 0 and 1 are the bit tokens).
TOK_BOS = 2
TOK_EOS = 3
TOK_BOI = 4  # begin input bit block (see build_prompt_from_ints)
TOK_EOI = 5  # end input bit block
TOK_BOR = 6  # begin result bit block (see extract_first_bor_eor_bits)
TOK_EOR = 7  # end result bit block

TOK_NAMES = {
    0: "0",
    1: "1",
    TOK_BOS: "BOS",
    TOK_EOS: "EOS",
    TOK_BOI: "BOI",
    TOK_EOI: "EOI",
    TOK_BOR: "BOR",
    TOK_EOR: "EOR",
}


# Fixed "context" tokens placed right after BOS when the prompt is built
# from --prompt_int (they appear as t0, t1 in the prompt layout).
PROMPT_INT_T0 = 0
PROMPT_INT_T1 = 0


def apply_repetition_penalty_(logits: torch.Tensor, token_ids: List[int], penalty: float) -> None:
    # HF-style repetition penalty: dampen logits of tokens already present in the sequence.
    if penalty is None or penalty == 1.0 or penalty <= 0:
        return
    for t in set(token_ids):
        val = logits[0, t]
        logits[0, t] = val * penalty if val < 0 else val / penalty


def apply_encoder_repetition_penalty_(logits: torch.Tensor, prompt_token_ids: List[int], penalty: float) -> None:
    # Inverse of the repetition penalty, applied to prompt tokens: penalty > 1 boosts them.
    if penalty is None or penalty == 1.0 or penalty <= 0:
        return
    for t in set(prompt_token_ids):
        val = logits[0, t]
        logits[0, t] = val / penalty if val < 0 else val * penalty


def apply_presence_frequency_penalties_(
    logits: torch.Tensor,
    token_ids: List[int],
    presence_penalty: float,
    frequency_penalty: float,
) -> None:
    # OpenAI-style penalties: a flat offset per seen token, plus an offset scaled by its count.
    counts = Counter(token_ids)

    if presence_penalty:
        for t in counts:
            logits[0, t] -= presence_penalty

    if frequency_penalty:
        for t, c in counts.items():
            logits[0, t] -= frequency_penalty * c


def get_banned_tokens_no_repeat_ngram(seq: List[int], n: int) -> set:
    # Return the tokens that would complete an n-gram already present in seq.
    if n <= 0 or len(seq) < n:
        return set()

    prefix_len = n - 1
    ngrams: Dict[Tuple[int, ...], set] = {}
    for i in range(len(seq) - n + 1):
        prefix = tuple(seq[i:i + prefix_len])
        nxt = seq[i + prefix_len]
        ngrams.setdefault(prefix, set()).add(nxt)

    # prefix_len == 0 (n == 1) means every previously seen token is banned.
    current_prefix = tuple(seq[-prefix_len:]) if prefix_len > 0 else ()
    return ngrams.get(current_prefix, set())
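
# Example: with seq = [0, 1, 0, 1] and n = 3, the seen trigram prefixes are
# (0, 1) -> {0} and (1, 0) -> {1}; the current prefix is (0, 1), so token 0 is banned:
#   get_banned_tokens_no_repeat_ngram([0, 1, 0, 1], 3) == {0}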


def mask_banned_tokens_(logits: torch.Tensor, banned: set) -> None:
    if banned:
        logits[0, list(banned)] = float("-inf")


def parse_prompt_ids_str(s: str, vocab_size: int = 8) -> List[int]:
    s = "" if s is None else str(s)
    s = s.strip()
    if s == "":
        return []

    if not s.isdigit():
        raise ValueError("prompt_ids must contain only digits (0..7), with no spaces.")

    ids: List[int] = []
    for ch in s:
        t = ord(ch) - ord("0")
        if t < 0 or t >= vocab_size:
            raise ValueError(f"token id out of vocab: {t} (vocab_size={vocab_size})")
        ids.append(t)
    return ids
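
# Example: parse_prompt_ids_str("240") == [2, 4, 0]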


def format_ids_readable(ids: List[int]) -> str:
    out: List[str] = []
    for t in ids:
        out.append(TOK_NAMES.get(int(t), str(int(t))))
    return " ".join(out)


def format_ids_compact(ids: List[int]) -> str:
    # Like format_ids_readable, but runs of consecutive bit tokens are merged into one string.
    s: List[str] = []
    for t in ids:
        ti = int(t)
        if ti in (0, 1):
            if s and s[-1] and s[-1][-1] in ("0", "1"):
                s[-1] = s[-1] + str(ti)
            else:
                s.append(str(ti))
        else:
            s.append(TOK_NAMES.get(ti, str(ti)))
    return " ".join(s)
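
# Example: format_ids_compact([2, 0, 0, 4, 1, 0, 1, 5]) == "BOS 00 BOI 101 EOI"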


def int_to_10bits_tokens(x: int) -> List[int]:
    if x < 0 or x > 1023:
        raise ValueError(f"int out of range for 10 bits: {x} (expected 0..1023)")
    b = format(int(x), "010b")
    return [0 if ch == "0" else 1 for ch in b]
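
# Example: int_to_10bits_tokens(12) == [0, 0, 0, 0, 0, 0, 1, 1, 0, 0]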


def parse_prompt_int_str(s: str) -> Tuple[int, int]:
    s = "" if s is None else str(s)
    s = s.strip()
    if s == "":
        raise ValueError('--prompt_int is empty. Expected: "int,int"')

    parts = s.split(",")
    if len(parts) != 2:
        raise ValueError(f'invalid --prompt_int: {s!r}. Expected: "int,int"')

    try:
        a = int(parts[0].strip())
        b = int(parts[1].strip())
    except Exception:
        raise ValueError(f"invalid --prompt_int: {s!r}. Both values must be ints.")

    return a, b


def build_prompt_from_ints(int1: int, int2: int) -> List[int]:
    # Prompt layout: BOS t0 t1 BOI <10 bits of int1> EOI BOI <10 bits of int2> EOI
    seq: List[int] = []
    seq.append(TOK_BOS)
    seq.append(int(PROMPT_INT_T0))
    seq.append(int(PROMPT_INT_T1))

    seq.append(TOK_BOI)
    seq.extend(int_to_10bits_tokens(int1))
    seq.append(TOK_EOI)

    seq.append(TOK_BOI)
    seq.extend(int_to_10bits_tokens(int2))
    seq.append(TOK_EOI)

    return seq
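
# Worked example (with PROMPT_INT_T0 = PROMPT_INT_T1 = 0):
#   build_prompt_from_ints(12, 900)
#   == [2, 0, 0, 4, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 5, 4, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 5]
#   i.e. BOS, t0, t1, then 12 and 900 as two 10-bit BOI..EOI blocks.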


def extract_first_bor_eor_bits(ids: List[int], min_bits: int = 1) -> Optional[Tuple[List[int], int, int]]:
    # Find the first BOR, collect bit tokens until EOR (or end of sequence), and return
    # (bits, integer value, index of BOR). Returns None if there is no BOR or too few bits.
    try:
        i = ids.index(TOK_BOR)
    except ValueError:
        return None

    bits: List[int] = []
    j = i + 1
    while j < len(ids):
        t = int(ids[j])
        if t == TOK_EOR:
            break
        if t in (0, 1):
            bits.append(t)
        j += 1

    if len(bits) < int(min_bits):
        return None

    val = 0
    for b in bits:
        val = (val << 1) | int(b)

    return bits, val, i
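
# Example: extract_first_bor_eor_bits([6, 1, 0, 1, 7], min_bits=1) == ([1, 0, 1], 5, 0)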


def main() -> None:
    parser = argparse.ArgumentParser()

    parser.add_argument("--repo", type=str, required=True, help='HF repo id or local path (e.g. "PhysiQuanty/xxx")')
    parser.add_argument("--revision", type=str, default=None, help="HF revision/branch/tag/commit (optional)")

    g = parser.add_mutually_exclusive_group(required=False)
    g.add_argument("--prompt_ids", type=str, default=None, help='E.g. "240000001540000015" (digits 0..7 only) or ""')
    g.add_argument("--prompt_int", type=str, default=None, help='E.g. "12,900" -> BOS t0 t1 BOI 10b EOI BOI 10b EOI')

    parser.add_argument("--print_int", action="store_true", help="Print the first BOR..EOR block (bits) as an int")

    parser.add_argument("--max_new_tokens", type=int, default=40)
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--top_k", type=int, default=50)

    parser.add_argument("--repetition_penalty", type=float, default=1.0)
    parser.add_argument("--presence_penalty", type=float, default=0.0)
    parser.add_argument("--frequency_penalty", type=float, default=0.0)
    parser.add_argument("--encoder_repetition_penalty", type=float, default=1.0)
    parser.add_argument("--no_repeat_ngram_size", type=int, default=0)

    parser.add_argument("--seed", type=int, default=-1)
    parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda"])

    parser.add_argument("--stream_ids", action="store_true", help="Stream generated IDs as they are produced")
    parser.add_argument("--print_prompt_readable", action="store_true", help="Print the prompt as readable tokens")
    parser.add_argument("--print_final_readable", action="store_true", help="Print the final output as readable tokens")
    parser.add_argument("--stop_on_eos", action="store_true", help="Stop as soon as EOS(3) is generated")

    args = parser.parse_args()

    seed = args.seed if args.seed >= 0 else random.randint(0, 2**31 - 1)
    print(f"[Seed] {seed}", flush=True)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    device = torch.device("cuda" if (args.device == "cuda" and torch.cuda.is_available()) else "cpu")
    print(f"[Device] {device}", flush=True)

    torch_dtype = torch.float16 if device.type == "cuda" else torch.float32

    model = AutoModelForCausalLM.from_pretrained(
        args.repo,
        revision=args.revision,
        trust_remote_code=True,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
    )
    model.to(device)
    model.eval()

    vocab_size_cfg = int(getattr(model.config, "vocab_size", -1))
    print(f"[Model] loaded from {args.repo} | vocab_size={vocab_size_cfg}", flush=True)
    if vocab_size_cfg != 8:
        print(f"[Warn] vocab_size={vocab_size_cfg} (expected 8).", flush=True)

    if args.prompt_int is not None:
        int1, int2 = parse_prompt_int_str(args.prompt_int)
        prompt_ids = build_prompt_from_ints(int1, int2)
        prompt_origin = f'prompt_int="{args.prompt_int}" (t0,t1={PROMPT_INT_T0},{PROMPT_INT_T1})'
    else:
        s = "" if args.prompt_ids is None else args.prompt_ids
        prompt_ids = parse_prompt_ids_str(s, vocab_size=8)
        prompt_origin = 'prompt_ids' if args.prompt_ids is not None else 'prompt_ids="" (default)'

    print(f"[Prompt Origin] {prompt_origin}", flush=True)

    if args.print_prompt_readable:
        print(f"[Prompt IDs] {prompt_ids}", flush=True)
        print(f"[Prompt readable] {format_ids_readable(prompt_ids)}", flush=True)
        print(f"[Prompt compact] {format_ids_compact(prompt_ids)}", flush=True)
    else:
        if len(prompt_ids) == 0:
            print("[Prompt IDs] len=0 (empty prompt)", flush=True)
        else:
            print(f"[Prompt IDs] len={len(prompt_ids)} first32={prompt_ids[:32]}", flush=True)

    # An empty prompt is seeded with a single BOS token so the model has something to condition on.
    seeded_with_bos = False
    if len(prompt_ids) == 0:
        tokens = torch.tensor([TOK_BOS], device=device, dtype=torch.long).unsqueeze(0)
        seeded_with_bos = True
    else:
        tokens = torch.tensor(prompt_ids, device=device, dtype=torch.long).unsqueeze(0)

    generated_raw: List[int] = []

    if args.stream_ids:
        sys.stdout.write("[Stream IDS] ")
        sys.stdout.flush()

    with torch.no_grad():
        for _ in range(int(args.max_new_tokens)):
            # Full forward pass at every step (no KV cache); fine for short sequences.
            out = model(input_ids=tokens)
            logits = out.logits[:, -1, :]

            logits_work = logits.clone()
            full_seq = tokens[0].tolist()

            apply_encoder_repetition_penalty_(logits_work, prompt_ids, float(args.encoder_repetition_penalty))
            apply_repetition_penalty_(logits_work, full_seq, float(args.repetition_penalty))
            apply_presence_frequency_penalties_(
                logits_work,
                full_seq,
                float(args.presence_penalty),
                float(args.frequency_penalty),
            )

            if int(args.no_repeat_ngram_size) > 0:
                banned = get_banned_tokens_no_repeat_ngram(full_seq, int(args.no_repeat_ngram_size))
                mask_banned_tokens_(logits_work, banned)

            # Temperature scaling, optional top-k filtering, then multinomial sampling.
            logits_work /= max(float(args.temperature), 1e-6)

            if 0 < int(args.top_k) < logits_work.size(-1):
                v, _ = torch.topk(logits_work, int(args.top_k))
                logits_work[logits_work < v[:, [-1]]] = float("-inf")

            probs = torch.softmax(logits_work, dim=-1)
            next_token = torch.multinomial(probs, 1)
            tok_id = int(next_token.item())
            generated_raw.append(tok_id)

            if args.stream_ids:
                sys.stdout.write(str(tok_id))
                sys.stdout.flush()

            tokens = torch.cat([tokens, next_token], dim=1)

            if args.stop_on_eos and tok_id == TOK_EOS:
                break

    if args.stream_ids:
        sys.stdout.write("\n")
        sys.stdout.flush()

    if seeded_with_bos:
        print("\n[Prompt] empty prompt -> internal BOS(2) seed used only to initialize the logits", flush=True)

    print("\n[Generated RAW IDS]", flush=True)
    print(generated_raw, flush=True)

    print("\n[Generated RAW IDS (as digits)]", flush=True)
    print("".join(str(x) for x in generated_raw), flush=True)

    if args.print_final_readable or args.print_int:
        full = prompt_ids + generated_raw

        if args.print_final_readable:
            print("\n[Full sequence readable]", flush=True)
            print(format_ids_readable(full), flush=True)
            print("\n[Full sequence compact]", flush=True)
            print(format_ids_compact(full), flush=True)

        if args.print_int:
            got = extract_first_bor_eor_bits(full, min_bits=10)
            if got is None:
                print("\n[PrintInt] No valid BOR..EOR block found.", flush=True)
            else:
                bits, val, pos = got
                bits_str = "".join(str(b) for b in bits)
                print("\n[PrintInt] First BOR..EOR", flush=True)
                print(f"[PrintInt] pos={pos} nbits={len(bits)} bits={bits_str} int={val}", flush=True)


if __name__ == "__main__":
    main()
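
# Example invocation (the script name is a placeholder; the repo id follows the --repo help text):
#   python generate_bits.py --repo PhysiQuanty/xxx --device cpu \
#       --prompt_int "12,900" --max_new_tokens 40 --print_final_readable --print_int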