| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
import argparse
import codecs
import os
import random
import sys
from collections import Counter
from typing import Dict, List, Optional

import torch
from transformers import AutoModelForCausalLM
| |
|
| |
|
def decode_base2_digits_strict(digits: List[int], *, encoding: str = "utf-8", errors: str = "replace") -> str:
    """Decode a sequence of base-2 token ids back into text.

    Only ids equal to 0 or 1 are treated as bits; every other id (e.g. BOS/EOS
    markers) is skipped. Bits are packed MSB-first into bytes, 8 at a time, and
    a trailing group of fewer than 8 bits is dropped. The resulting bytes are
    decoded with ``encoding`` / ``errors``.

    Args:
        digits: Token ids (anything ``int()`` accepts).
        encoding: Codec used to turn the packed bytes into text.
        errors: Codec error handler ("replace", "strict", "ignore", ...).

    Returns:
        The decoded text, or "" when fewer than 8 bits are present.

    Raises:
        UnicodeDecodeError: only when ``errors="strict"`` and the bytes are invalid.
    """
    bits = [bit for bit in (int(d) for d in digits) if bit in (0, 1)]

    n_full_bytes = len(bits) // 8
    if n_full_bytes == 0:
        return ""

    out = bytearray(n_full_bytes)
    for i in range(n_full_bytes):
        value = 0
        for bit in bits[8 * i:8 * i + 8]:  # MSB first
            value = (value << 1) | bit
        out[i] = value

    # Decoding the whole buffer at once handles multi-byte sequences that span
    # several bytes for any codec, so no codec-specific branch is needed.
    return bytes(out).decode(encoding, errors=errors)
| |
|
| |
|
def bytes_to_base2_digits_bytesafe(data: bytes) -> List[int]:
    """Expand every byte of *data* into 8 bits, most significant bit first.

    Example: ``b"A"`` (0x41) -> ``[0, 1, 0, 0, 0, 0, 0, 1]``.
    """
    return [(byte >> shift) & 1 for byte in data for shift in reversed(range(8))]
| |
|
| |
|
def text_to_base2_digits(text: str) -> List[int]:
    """Return the bits of *text* encoded as UTF-8, MSB first within each byte."""
    utf8_bytes = text.encode("utf-8")
    return bytes_to_base2_digits_bytesafe(utf8_bytes)
| |
|
| |
|
def wrap_base2_sequence_2(ids: List[int], bos_id: int, eos_id: int) -> List[int]:
    """Return a new list ``[bos_id, *ids, eos_id]``; the markers are coerced to int.

    *ids* itself is left unmodified.
    """
    wrapped = [int(bos_id)]
    wrapped.extend(ids)
    wrapped.append(int(eos_id))
    return wrapped
| |
|
| |
|
def apply_repetition_penalty_(logits: torch.Tensor, token_ids: List[int], penalty: float) -> None:
    """In place: rescale row 0 of *logits* at every distinct id in *token_ids*.

    Negative logits are multiplied by *penalty* and non-negative ones divided
    by it, so ``penalty > 1`` lowers the score of already-seen tokens.
    No-op when *penalty* is None, 1.0 or <= 0.
    """
    if penalty is None or penalty == 1.0 or penalty <= 0:
        return

    unique_ids = sorted(set(token_ids))
    if not unique_ids:
        return

    idx = torch.tensor(unique_ids, dtype=torch.long, device=logits.device)
    current = logits[0, idx]
    logits[0, idx] = torch.where(current < 0, current * penalty, current / penalty)
| |
|
| |
|
def apply_presence_frequency_penalties_(logits: torch.Tensor, token_ids: List[int], presence_penalty: float, frequency_penalty: float) -> None:
    """In place: subtract presence/frequency penalties from row 0 of *logits*.

    Each distinct id in *token_ids* loses ``presence_penalty`` once, plus
    ``frequency_penalty * count`` where ``count`` is how often it occurs.
    A falsy (0 / None) penalty is skipped.
    """
    occurrences = Counter(token_ids)
    for tok, count in occurrences.items():
        if presence_penalty:
            logits[0, tok] -= presence_penalty
        if frequency_penalty:
            logits[0, tok] -= frequency_penalty * count
| |
|
| |
|
def get_banned_tokens_no_repeat_ngram(seq: List[int], n: int) -> set:
    """Return the token ids that would complete an n-gram already present in *seq*.

    The last ``n - 1`` tokens of *seq* form the current prefix; every token
    that followed an earlier occurrence of that prefix is banned, since
    emitting it again would repeat an existing n-gram. With ``n == 1`` the
    prefix is empty, so every token already in *seq* is banned.

    Args:
        seq: Token history.
        n: N-gram size; values <= 0 disable the check.

    Returns:
        A new set of banned ids (empty when nothing is banned).
    """
    prefix_len = n - 1
    if n <= 0 or len(seq) < prefix_len:
        return set()

    # Slice from an explicit start index: ``seq[-0:]`` would be the WHOLE
    # sequence rather than an empty prefix when n == 1.
    current_prefix = tuple(seq[len(seq) - prefix_len:])

    banned = set()
    for i in range(len(seq) - prefix_len):
        if tuple(seq[i:i + prefix_len]) == current_prefix:
            banned.add(seq[i + prefix_len])
    return banned
| |
|
| |
|
def mask_banned_tokens_(logits: torch.Tensor, banned: set) -> None:
    """In place: set row 0 of *logits* to -inf at every id in *banned* (no-op if empty)."""
    if not banned:
        return
    banned_idx = torch.tensor(sorted(banned), dtype=torch.long, device=logits.device)
    logits[0, banned_idx] = float("-inf")
| |
|
| |
|
def _maybe_hf_token() -> str:
    """Return a Hugging Face access token from the environment, or "".

    ``HF_TOKEN`` takes precedence over ``HUGGINGFACE_HUB_TOKEN``; a variable
    that is set but empty is treated as unset.
    """
    for var_name in ("HF_TOKEN", "HUGGINGFACE_HUB_TOKEN"):
        value = os.environ.get(var_name)
        if value:
            return value
    return ""
| |
|
| |
|
def _parse_args() -> argparse.Namespace:
    """Define and parse the command-line options of the base-2 sampling script."""
    parser = argparse.ArgumentParser()

    # Model / runtime.
    parser.add_argument("--repo", type=str, required=True, help="chemin dossier HF local (./hf_binaryllm_repo) ou repo_id")
    parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda"])
    parser.add_argument("--seed", type=int, default=-1)

    # Prompt encoding (special-token ids of the base-2 vocabulary).
    parser.add_argument("--bos", type=int, default=2, help="BOS id (base2: BOS=2)")
    parser.add_argument("--eos", type=int, default=3, help="EOS id (base2: EOS=3)")
    parser.add_argument("--prompt", type=str, required=True, help="texte à encoder en base2 (UTF-8 -> bits MSB->LSB)")

    # Sampling.
    parser.add_argument("--max_new_tokens", type=int, default=800)
    parser.add_argument("--temperature", type=float, default=0.7, help="<= 0 selects greedy decoding")
    parser.add_argument("--top_k", type=int, default=50)

    # Penalties / constraints.
    parser.add_argument("--repetition_penalty", type=float, default=1.0)
    parser.add_argument("--presence_penalty", type=float, default=0.0)
    parser.add_argument("--frequency_penalty", type=float, default=0.0)
    parser.add_argument("--no_repeat_ngram_size", type=int, default=0)

    # Output.
    parser.add_argument("--decode_encoding", type=str, default="utf-8")
    parser.add_argument("--decode_errors", type=str, default="replace")
    parser.add_argument("--print_ids", action="store_true")
    parser.add_argument("--stream", action="store_true", help="stream strict (réaffiche decode strict à chaque step)")

    return parser.parse_args()


def _load_model(repo: str, device: torch.device):
    """Load the causal LM from *repo*, move it to *device* and switch to eval mode."""
    # SECURITY NOTE: trust_remote_code=True executes Python code shipped with
    # the repo; only point --repo at sources you trust.
    hf_token = _maybe_hf_token()
    if hf_token:
        model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True, token=hf_token)
    else:
        model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)

    model.to(device)
    model.eval()

    if hasattr(model, "config") and model.config is not None:
        model.config.use_cache = True
    return model


def _encode_prompt(text: str, bos_id: int, eos_id: int) -> List[int]:
    """Encode *text* as the base-2 prompt ``[BOS, bits..., EOS, BOS]`` and print it."""
    ids = text_to_base2_digits(text)
    ids = wrap_base2_sequence_2(ids, bos_id, eos_id)
    # NOTE(review): the extra BOS presumably cues the model to start a new
    # segment (its reply) - confirm against the training data format.
    ids = ids + [int(bos_id)]
    print("[+] PROMPT IDS = ", ids)
    return ids


def _sample_next_token(logits: torch.Tensor, history: List[int], args: argparse.Namespace) -> Optional[torch.Tensor]:
    """Apply penalties, n-gram bans, temperature and top-k, then pick the next token.

    *logits* (batch of 1, last position) is modified in place by the penalty
    helpers. Returns a ``(1, 1)`` long tensor, or None when every token has
    been masked out and nothing can be sampled.
    """
    apply_repetition_penalty_(logits, history, float(args.repetition_penalty))
    apply_presence_frequency_penalties_(logits, history, float(args.presence_penalty), float(args.frequency_penalty))

    ngram_size = int(args.no_repeat_ngram_size)
    if ngram_size > 0:
        mask_banned_tokens_(logits, get_banned_tokens_no_repeat_ngram(history, ngram_size))

    # With every entry at -inf, softmax yields NaN and multinomial would raise.
    if not torch.isfinite(logits).any():
        return None

    temperature = float(args.temperature)
    if temperature <= 0:
        # Greedy: dividing by a tiny epsilon instead can overflow to inf.
        return torch.argmax(logits, dim=-1, keepdim=True)

    logits = logits / max(temperature, 1e-6)

    top_k = int(args.top_k)
    if 0 < top_k < logits.size(-1):
        kth_best = torch.topk(logits, top_k).values[:, [-1]]
        logits[logits < kth_best] = float("-inf")

    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, 1)


def main() -> None:
    """CLI entry point: encode the prompt as bits, sample from the model and print the decoded text."""
    args = _parse_args()

    seed = args.seed if args.seed >= 0 else random.randint(0, 2**31 - 1)
    print(f"[Seed] {seed}")
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    device = torch.device("cuda" if (args.device == "cuda" and torch.cuda.is_available()) else "cpu")
    print(f"[Device] {device}")

    model = _load_model(args.repo, device)

    prompt_ids = _encode_prompt(args.prompt, args.bos, args.eos)
    tokens = torch.tensor([prompt_ids], dtype=torch.long, device=device)
    history: List[int] = list(prompt_ids)  # host-side mirror of `tokens`
    generated: List[int] = []

    print("\n[Prompt]\n", args.prompt)
    print(f"\n[Prompt IDs] len={len(prompt_ids)} | BOS={args.bos} EOS={args.eos}")
    print("\n[Stream]" if args.stream else "\n[Output]")

    # Streaming feeds each completed byte to an incremental decoder, so a
    # multi-byte character is printed once it is complete instead of being
    # shown (and left) as a replacement character while still partial.
    stream_decoder = (
        codecs.getincrementaldecoder(args.decode_encoding)(errors=args.decode_errors)
        if args.stream
        else None
    )
    pending_bits: List[int] = []

    eos_id = int(args.eos)
    with torch.no_grad():
        for _ in range(int(args.max_new_tokens)):
            # NOTE(review): past_key_values is never fed back, so the whole
            # sequence is re-run every step despite use_cache=True. Reusing the
            # cache requires the remote model to accept past_key_values - confirm.
            out = model(input_ids=tokens, use_cache=True)
            # Sample in fp32 even if the model runs in half precision.
            logits = out.logits[:, -1, :].float()

            next_token = _sample_next_token(logits, history, args)
            if next_token is None:
                print("\n[Warn] every token is masked; stopping generation early.", file=sys.stderr)
                break

            tok_id = int(next_token.item())
            if tok_id == eos_id:
                break

            tokens = torch.cat([tokens, next_token], dim=1)
            history.append(tok_id)
            generated.append(tok_id)

            # Only 0/1 ids are bits; other ids are skipped, as in the final decode.
            if stream_decoder is not None and tok_id in (0, 1):
                pending_bits.append(tok_id)
                if len(pending_bits) == 8:
                    byte = 0
                    for bit in pending_bits:  # MSB first
                        byte = (byte << 1) | bit
                    pending_bits.clear()
                    chunk = stream_decoder.decode(bytes([byte]))
                    if chunk:
                        sys.stdout.write(chunk)
                        sys.stdout.flush()

    if stream_decoder is not None:
        # Flush any trailing incomplete sequence according to --decode_errors.
        sys.stdout.write(stream_decoder.decode(b"", final=True))
        print()

    print("\n[Final Output]\n")
    print(decode_base2_digits_strict(generated, encoding=args.decode_encoding, errors=args.decode_errors))

    if args.print_ids:
        print("\n[Generated IDs]\n")
        print(generated)
| |
|
| |
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
| |
|