import argparse
import gc
import math
import os
import re
import sys
import time

import gradio as gr
import torch
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM
from unidecode import unidecode

|
# Most recent logits captured by the forward hook on lm_head.
captured_logits = []


def capture_logits_hook(module, input_args, output):
    """
    Hook function to capture the output of the lm_head layer.
    The output might be a tensor or a tuple containing the tensor;
    we are interested in the tensor containing the logits.
    """
    if isinstance(output, torch.Tensor):
        logits = output
    elif isinstance(output, tuple) and len(output) > 0 and isinstance(output[0], torch.Tensor):
        logits = output[0]
    else:
        print(f"Warning: Hook captured unexpected output type: {type(output)}")
        return
    # Keep only the most recent capture so memory use stays bounded.
    captured_logits.clear()
    captured_logits.append(logits.detach())
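# The hook is attached in main() via model.lm_head.register_forward_hook, after
# which captured_logits[-1] holds the logits of the most recent forward pass.
# (An illustrative access pattern; the generation loop below actually reads
# outputs.scores rather than the hook's capture.)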
|
|
|
parser = argparse.ArgumentParser(
    description="LifGenerator for CPU with Hugging Face models with greedy decoding",
    epilog="Help Documentation"
)
parser.add_argument(
    "-input_file", "-i",
    type=str,
    help="The path to the input file."
)
parser.add_argument(
    "-output_file", "-o",
    type=str,
    help="Name and path of output file"
)
parser.add_argument(
    "-prompt_id", "-pid",
    type=str,
    help="Overall name of item"
)
parser.add_argument(
    "-prompt_topic", "-pt",
    type=str,
    help="Topic given to LLM before stem words"
)
parser.add_argument(
    "-multi_pv", "-mpv",
    type=int,
    help="Number of options to consider at each turn"
)
parser.add_argument(
    "-num_words", "-nw",
    type=int,
    help="Cap on # of text words to iterate"
)
parser.add_argument(
    "-num_tokens", "-nt",
    type=int,
    help="# of tokens to search for text word match"
)
parser.add_argument(
    "-beam_width", "-bw",
    type=int,
    help="Width of beam search, 0 or 1 for greedy"
)
parser.add_argument(
    "-alpha_mode", "-a",
    type=int,
    help="0 = all tokens, up through 4 = alpha chars plus ' only"
)
parser.add_argument(
    "-start_turn", "-st",
    type=int,
    help="1 by default, adds st-1 words to prompt"
)
parser.add_argument(
    "-model", "-m",
    type=str,
    help="DS for DeepSeek, QWEN for Qwen"
)

args = parser.parse_args()
print("Welcome to the LifGenerator CPU script!")
print("This script generates .lif files using a Hugging Face model and greedy decoding.")
|
INPUT_FILE = args.input_file if args.input_file else "Kanga20.txt"
INPUT_FILE_STEM = os.path.splitext(INPUT_FILE)[0]
OUTPUT_FILE = args.output_file if args.output_file else (INPUT_FILE_STEM + ".lif")
print(f"Input file path: {INPUT_FILE}")
print(f"Output file path: {OUTPUT_FILE}")
print("But using text output window instead...")
PROMPT_ID = args.prompt_id if args.prompt_id else INPUT_FILE
PROMPT_TOPIC = args.prompt_topic if args.prompt_topic else INPUT_FILE
MULTI_PV = args.multi_pv if args.multi_pv is not None else 100
NUM_WORDS = args.num_words if args.num_words is not None else 20
NUM_TOKENS = args.num_tokens if args.num_tokens is not None else 10000
# A beam width of 0 falls back to 1, i.e. plain greedy decoding.
BEAM_WIDTH = args.beam_width if args.beam_width else 1
ALPHA_MODE = args.alpha_mode if args.alpha_mode is not None else 0
START_TURN = args.start_turn if args.start_turn is not None else 1
MODEL_TAG = args.model if args.model else "Qwen"
MINUS_INF = -1000.0
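# Example invocation (illustrative; assumes this file is saved as lifgen_cpu.py):
#   python lifgen_cpu.py -i Kanga20.txt -o Kanga20.lif -pt "Sample topic" -mpv 100 -nw 20 -bw 1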
|
|
|
|
|
""" |
|
Match if arg occurs in st surrounded by ends or non-alpha chars. |
|
|
|
Intent is e.g. for "Karp" to match "Karp, R" but not "Karpov". |
|
Whether "Karp" matches "Karp-Lipton" depends on whether hyphen is part of name. |
|
Works even if arg itself has non-alpha characters. |
|
Used for player and event names AND to identify tokens in command streams. |
|
Uses C++ "isalpha" for local definition of names. |
|
Prefer to override it to count underscore as a non-delimiting char. |
|
Hyphen is always part of tokens but can be used to delimit place and person names, |
|
so "Khanty" and "Khanty-Mansiysk" can both match "Khanty-Mansiysk" and |
|
"Vachier" can match "Vachier-Lagrave". |
|
|
|
With LLM tokens, this allows arg="abc" to match st=" abc" but not vice-versa. |
|
However, if called with arg.strip() then vice-versa is fine. |
|
If the token is @-@ then it will match "--" but NOT match a hyphenated word. |
|
""" |
|
|
|
|
|
def borderedMatch(arg, st, hyphenDelimits=False, underscoreDelimits=False): |
|
fromPos = st.find(arg) |
|
while fromPos != -1: |
|
leftOK = (fromPos == 0) |
|
if (fromPos > 0): |
|
c = st[fromPos - 1] |
|
if c == '-': |
|
leftOK = hyphenDelimits |
|
elif c == '_': |
|
leftOK = underscoreDelimits |
|
else: |
|
leftOK = (not c.isalnum()) |
|
|
|
rightEdge = fromPos + len(arg) |
|
rightOK = (rightEdge == len(st)) |
|
if (not rightOK): |
|
d = st[rightEdge] |
|
if d == '-': |
|
rightOK = hyphenDelimits |
|
elif d == '_': |
|
rightOK = underscoreDelimits |
|
else: |
|
rightOK = (not d.isalnum()) |
|
|
|
if rightOK and leftOK: |
|
return True |
|
else: |
|
fromPos = st.find(arg, fromPos + 1) |
|
|
|
return False |
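# Illustrative checks of the matching rules described above:
#   borderedMatch("Karp", "Karp, R")                    -> True   (comma delimits)
#   borderedMatch("Karp", "Karpov")                     -> False  (no border)
#   borderedMatch("Vachier", "Vachier-Lagrave", True)   -> True   (hyphen delimits)
#   borderedMatch("Vachier", "Vachier-Lagrave")         -> False  (hyphen kept in name)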
|
|
|
|
|
def reprat(tok):
    # Render a token ASCII-safely: unidecode(repr(tok)) gives a quoted, escaped
    # string; escape any literal '@', strip repr's outer quotes, and wrap the
    # result in @...@ delimiters.
    rep = unidecode(repr(tok))
    return f"@{rep.replace('@', '(at)')[1:-1]}@"
|
|
|
# Log in to the Hugging Face Hub if a token is available. Note: the token is
# read from a custom environment variable name, not the usual HF_TOKEN.
hf_token = os.getenv('May8WriteRepos')

if hf_token:
    print("Logging in to Hugging Face Hub...")
    login(token=hf_token)
else:
    print("HF Token not found. Gated model download might fail.")
|
|
|
|
|
def main(INPUT_FILE, OUTPUT_FILE, PROMPT_ID, PROMPT_TOPIC, MULTI_PV, NUM_WORDS,
         NUM_TOKENS, BEAM_WIDTH, ALPHA_MODE, START_TURN, MODEL_TAG):

    # The model is currently fixed; PROMPT_TOPIC, ALPHA_MODE, and MODEL_TAG are
    # accepted but not yet used below.
    MODEL_NAME = "Qwen/Qwen3-1.7B"

    DEVICE = "cpu"
    TORCH_DTYPE = torch.float32
    DEPTH_RANGE = 1

    if not os.path.exists(INPUT_FILE):
        print(f"Warning: Input file '{INPUT_FILE}' not found. Creating a dummy file.")
        with open(INPUT_FILE, 'w', encoding='utf-8') as f:
            f.write("The quick brown fox jumps over the lazy dog")

    textoutput = ""
    MODEL_CONTEXT_WINDOW = 128_000
    SAFETY_THRESHOLD = 2_000
    MAX_INPUT_TOKENS = MODEL_CONTEXT_WINDOW - SAFETY_THRESHOLD
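    # With a 128,000-token context window and a 2,000-token safety margin,
    # MAX_INPUT_TOKENS = 126,000; longer inputs are truncated from the left
    # when slices are taken in the generation loop below.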
|
|
|
|
|
print("Step 1: Loading model...") |
|
|
|
model = AutoModelForCausalLM.from_pretrained( |
|
MODEL_NAME, |
|
torch_dtype=TORCH_DTYPE, |
|
trust_remote_code=True, |
|
token=hf_token |
|
).to(DEVICE) |
|
print(f" Model loaded to {DEVICE}.") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hook_handle = model.lm_head.register_forward_hook(capture_logits_hook) |
|
|
|
|
|
|
|
|
|
|
|
model.eval() |
|
print(" Quantization complete. Model is ready for inference.\n") |
|
|
|
print("Step 3: Loading tokenizer...") |
|
|
|
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True, token=hf_token) |
|
if tokenizer.pad_token is None: |
|
print(" Tokenizer missing pad token; setting pad_token = eos_token") |
|
tokenizer.pad_token = tokenizer.eos_token |
|
|
|
if hasattr(model, 'config'): |
|
model.config.pad_token_id = tokenizer.eos_token_id |
|
print(" Tokenizer loaded and configured.\n") |
|
|
|
|
|
print("Step 4: Prompting user for inputs...") |
|
|
|
promptID = PROMPT_ID |
|
|
|
MultiPV = MULTI_PV |
|
|
|
LegalNumberOfMove = NUM_WORDS |
|
|
|
EngineID = f"Qwen/Qwen3-1.7B" |
|
|
|
|
|
|
|
Depth = 1 |
|
print(" User inputs captured.\n") |
|
|
|
|
|
print("Step 5: Pre-tokenizing input sequence...") |
|
initial_prompt = "Complete successive parts of a sentence given one word at a time:" |
|
initial_prompt_ids = tokenizer.encode(initial_prompt, add_special_tokens=False) |
|
|
|
print(f" Reading words from {INPUT_FILE}...") |
|
lines = [] |
|
try: |
|
with open(INPUT_FILE, 'r', encoding='utf-8') as f: |
|
|
|
lines = f.readlines() |
|
words_from_file = "".join(line.replace('\n', '') for line in lines) |
|
wordList = re.split(r'([a-zA-Z]+|\d+)', words_from_file) |
|
wordList = [x for x in wordList if x != ' ' and x != ''] |
|
|
|
|
|
numChars = 0 |
|
numTextTokens = len(wordList) |
|
for word in wordList: |
|
numChars += len(word) |
|
avgTokenLength = round(numChars/numTextTokens, 4) |
|
print(f"\nFound {numTextTokens} text word/tokens with average length {avgTokenLength}.\n") |
|
|
|
except FileNotFoundError: |
|
print(f"Error: Input file '{INPUT_FILE}' not found. Exiting.") |
|
exit() |
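    # Illustrative effect of the split above: re.split(r'([a-zA-Z]+|\d+)', "It's 1984!")
    # yields ['', 'It', "'", 's', ' ', '1984', '!'], and the filter then leaves
    # ['It', "'", 's', '1984', '!'].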
|
|
|
    all_tokens = list(initial_prompt_ids)
    word_end_indices = [len(initial_prompt_ids)]
    processed_words = []

    print("   Tokenizing words and building full sequence...")
    for word in wordList:
        # Prepend a space so each word tokenizes as it would mid-sentence.
        word_tokens = tokenizer.encode(" " + word, add_special_tokens=False)
        all_tokens.extend(word_tokens)
        word_end_indices.append(len(all_tokens))
        processed_words.append(word)

    full_token_tensor = torch.tensor(all_tokens, dtype=torch.long).unsqueeze(0)
    print(f"   Pre-tokenized {len(processed_words)} words into a sequence of {len(all_tokens)} tokens.\n")
|
|
|
    num_words_to_process = max(0, min(len(processed_words), LegalNumberOfMove) - (START_TURN - 1))
    if num_words_to_process == 0:
        print("   Warning: No words to process based on input file or limits.\n")
    elif num_words_to_process < len(processed_words) - (START_TURN - 1):
        print(f"   Will process the first {num_words_to_process} words due to NUM_WORDS limit.\n")
|
|
|
print("Step 8: Preparing output file header...") |
|
header_lines = [ |
|
f'[PromptID "{promptID}"]\n', |
|
f'[EngineID "{EngineID}"]\n', |
|
f'[MultiPV "{MultiPV}"]\n', |
|
f'[DepthRange "1:1"]\n\n', |
|
] + lines + [f'\n\n'] |
|
print(" Header prepared.\n") |
|
|
|
|
|
print("Step 9: Entering main generation loop (using pre-tokenized slicing and greedy decoding)...\n") |
|
PrevEval = "n.a." |
|
start_time = time.time() |
|
current_time = start_time |
|
numMatchedWords = 0 |
|
numMatchedChars = 0 |
|
|
|
if num_words_to_process > 0: |
|
if (START_TURN > 1): |
|
OUTPUT_FILE = OUTPUT_FILE.split('.')[0]+"from"+str(START_TURN)+".lif" |
|
with open(OUTPUT_FILE, 'w', encoding='utf-8') as writer: |
|
print(" Writing header to output file...") |
|
writer.write(''.join(header_lines)) |
|
textoutput += ''.join(header_lines) |
|
print(" Header written. Starting word-by-word prediction.\n") |
|
|
|
            for turnCount in range(START_TURN, START_TURN + num_words_to_process):
                current_word = processed_words[turnCount - 1].strip()

                # Slice the pre-tokenized sequence: the prompt plus all words
                # before the current one, capped at MAX_INPUT_TOKENS from the left.
                slice_end_index = word_end_indices[turnCount - 1]
                slice_start_index = max(0, slice_end_index - MAX_INPUT_TOKENS)
                input_tensor = full_token_tensor[:, slice_start_index:slice_end_index]
                current_input_len = input_tensor.shape[1]
                input_tensor_dev = input_tensor.to(DEVICE)

                start_time_gen = time.time()
|
|
|
|
|
                with torch.no_grad():
                    outputs = model.generate(
                        input_tensor_dev,
                        max_new_tokens=2,
                        min_new_tokens=2,
                        output_scores=True,
                        return_dict_in_generate=True,
                        do_sample=False,
                        pad_token_id=tokenizer.pad_token_id,
                        num_beams=BEAM_WIDTH,
                        num_return_sequences=BEAM_WIDTH,
                        # Explicitly null the sampling knobs so transformers does
                        # not warn about them when do_sample=False.
                        temperature=None,
                        top_k=None,
                        top_p=None,
                    )
                end_time_gen = time.time()
                gen_duration = end_time_gen - start_time_gen
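                # outputs.sequences has shape (num_return_sequences, input_len + 2).
                # outputs.scores is a tuple with one entry per generated step; with
                # num_beams == 1 each entry holds the processed logits over the
                # vocabulary (e.g. min_new_tokens can force the EOS logit to -inf,
                # which is why scores are clamped to MINUS_INF further below).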
|
                # Scores for the first generated step; row 0 corresponds to the
                # greedy / first-beam path.
                logits_for_step = outputs.scores[0]
                logits_for_greedy_path = logits_for_step[0]

                top_k_logits_values, top_k_logits_indices = torch.topk(
                    logits_for_greedy_path, k=MultiPV, dim=-1
                )
                top_k_logits_values = top_k_logits_values.tolist()
                top_k_logits_indices = top_k_logits_indices.tolist()
                top_k_tokens = [tokenizer.decode(tid) for tid in top_k_logits_indices]

                # Debug dump of the top-MultiPV logits (disabled):
                # print(f"Top {MultiPV} Logits from greedy path (Token | Logit Value):")
                # for i in range(MultiPV):
                #     token_str_cleaned = top_k_tokens[i].strip()
                #     print(f"   - '{token_str_cleaned}': {top_k_logits_values[i]:.4f} (ID: {top_k_logits_indices[i]})")
|
|
|
                # First newly generated token on the best path. Two tokens were
                # generated, so index from the end of the prompt rather than -1.
                greedy_selected_token_id = outputs.sequences[0, current_input_len].item()
                greedy_selected_token_str = tokenizer.decode(greedy_selected_token_id).strip()

                modelToken = reprat(top_k_tokens[0])
                modelEval = round(top_k_logits_values[0], 4)

                print("Turn ", turnCount, " now matching text word ", current_word, " ...", end='', sep='')
|
|
|
                topNUMTvals, topNUMTindices = torch.topk(logits_for_greedy_path, k=NUM_TOKENS, dim=-1)
                topNUMTvalList = topNUMTvals.tolist()
                topNUMTindList = topNUMTindices.tolist()
                topNUMTtokens = [reprat(tokenizer.decode(tind)) for tind in topNUMTindList]
                matchingTextToken = "@@"

                # Walk down the model's ranking until a token matches the text word.
                textTokenIndex = 0
                textTokenValue = 0
                for tok in topNUMTtokens:
                    if borderedMatch(current_word, tok, True, True):
                        matchingTextToken = tok
                        textTokenValue = topNUMTvalList[textTokenIndex]
                        if math.isinf(textTokenValue) and textTokenValue < 0.0:
                            # Clamp -inf scores to a finite sentinel value.
                            textTokenValue = MINUS_INF
                        else:
                            textTokenValue = round(textTokenValue, 4)
                        if textTokenIndex == 0:
                            print("***matches top model token", modelToken, "with score ", textTokenValue)
                            numMatchedWords += 1
                            numMatchedChars += len(current_word)
                        else:
                            print("found at index", textTokenIndex, "in token", matchingTextToken,
                                  "with score ", textTokenValue, "; top is ", modelToken, modelEval)
                        break
                    textTokenIndex += 1

                if textTokenIndex >= NUM_TOKENS:
                    textTokenValue = round(topNUMTvalList[-1], 4)
                    print("not found, using bottom score", textTokenValue)
|
|
|
|
|
                NextEval = textTokenValue
|
                current_stem = initial_prompt + " " + " ".join(processed_words[:turnCount])
                record_lines = [
                    f'[PID "{promptID}"]\n',
                    f'[EID "{MODEL_NAME}"]\n',
                    f'[Turn "{turnCount}-w"]\n',
                    f'[TextToken "@{current_word}@"]\n',
                    f'[ModelToken "{modelToken}"]\n',
                    f'[TextTokenIndex "{textTokenIndex}"]\n',
                    f'[TextTokenValue "{textTokenValue}"]\n',
                    f'[Eval "{modelEval}"]\n',
                    f'[PrevEval "{PrevEval}"]\n',
                    f'[NextEval "{NextEval}"]\n',
                    f'[Depth "{Depth}"]\n',
                    f'[STEM "{current_stem}"]\n',
                    f'[NumLegalMoves "{MultiPV}"]\n',
                    "---------------\n",
                    f"{DEPTH_RANGE}\n",
                    "---------------\n"
                ]
                for token_str, logit_val in zip(top_k_tokens, top_k_logits_values):
                    rep = reprat(token_str)
                    record_lines.append(f"{rep} {logit_val:.4f}\n")

                record_lines.append(
                    "===========================================================================================================\n")
                record_lines.append("[Comments]\n")
                record_lines.append("[EndMove]\n\n")

                writer.write(''.join(record_lines))
                textoutput += ''.join(record_lines)

                PrevEval = modelEval
|
                # Periodic progress report.
                status_interval = min(100, num_words_to_process // 2 if num_words_to_process >= 10 else 10)
                if turnCount % status_interval == 0 or turnCount == START_TURN + num_words_to_process - 1:
                    last_time = current_time
                    current_time = time.time()
                    elapsed = current_time - start_time
                    elapsedLast = current_time - last_time
                    rate = (turnCount - START_TURN + 1) / elapsed if elapsed > 0 else 0
                    rateLast = status_interval / elapsedLast if elapsedLast > 0 else 0
                    print()
                    print(f"Processed Turn {turnCount}. Rate: {rate:.2f} words/sec., "
                          f"last {status_interval} rate: {rateLast:.2f}")
|
            averageCharsMatched = 0 if numMatchedWords == 0 else round(numMatchedChars / numMatchedWords, 4)
            matchPercent = 0.0 if numTextTokens == 0 else round(100.0 * numMatchedWords / numTextTokens, 2)
            matchPercentStr = f"({matchPercent}%)"
            print("Done: matched", numMatchedWords, matchPercentStr, "tokens of average length", averageCharsMatched)
            print("from", numTextTokens, "tokens of average length", avgTokenLength)

    else:
        print("Skipping main generation loop as there are no words to process.")

    hook_handle.remove()
    print("Removed forward hook.")
|
|
|
print("Step 10: Reporting final statistics...") |
|
total_time = time.time() - start_time |
|
avg_rate = (num_words_to_process / total_time) if total_time > 0 and num_words_to_process > 0 else 0 |
|
print(f" Total turns processed: {num_words_to_process}") |
|
print(f" Total time: {total_time:.2f} seconds") |
|
print(f" Average speed: {avg_rate:.2f} words/second") |
|
print(f" Output written to {OUTPUT_FILE}") |
|
|
|
|
|
print("\nCleaning up resources...") |
|
del model |
|
del tokenizer |
|
del full_token_tensor |
|
if 'outputs' in locals(): |
|
del outputs |
|
if 'input_tensor' in locals(): |
|
del input_tensor |
|
if 'input_tensor_dev' in locals(): |
|
del input_tensor_dev |
|
gc.collect() |
|
if DEVICE == 'cuda': |
|
print("Emptying CUDA cache...") |
|
torch.cuda.empty_cache() |
|
print("\nScript finished.") |
|
|
|
return textoutput |
|
|
|
def premain(ip):
    # NOTE: the Gradio text input `ip` is currently ignored; the run below is
    # configured with fixed arguments.
    txtoutput = main("Romance.txt", "RomanceByQ06.lif", "KWR Romance Essay", "Essay on Romance",
                     1000, 1000, 10000, 1, 0, 1, "Qwen3-06B")
    return txtoutput
|
|
|
demo = gr.Interface(fn=premain, inputs="text", outputs="text")
demo.launch()
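# launch() serves the interface locally; Gradio defaults to http://127.0.0.1:7860.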
|
|
|
|