app.py
CHANGED
@@ -15,6 +15,9 @@ class Word:
     logprob: float
     context: list[int]
 
+def starts_with_space(token: str) -> bool:
+    return token.startswith(chr(9601)) or token.startswith(chr(288))
+
 def split_into_words(token_probs: list[tuple[int, float]], tokenizer: Tokenizer) -> list[Word]:
     words: list[Word] = []
     current_word: list[int] = []
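Note: chr(9601) is "▁" (U+2581), the word-boundary marker that SentencePiece tokenizers (used by Llama and Mistral) prepend to word-initial tokens, and chr(288) is "Ġ" (U+0120), the equivalent marker in GPT-2-style byte-level BPE vocabularies. A standalone sanity check of what the helper matches (the example tokens are illustrative):

    # Both markers are ordinary characters inside the token string, not real spaces.
    assert chr(9601) == "▁"  # U+2581, SentencePiece boundary (Llama, Mistral)
    assert chr(288) == "Ġ"   # U+0120, GPT-2-style byte-level BPE boundary

    def starts_with_space(token: str) -> bool:
        return token.startswith(chr(9601)) or token.startswith(chr(288))

    print(starts_with_space("▁refused"))  # True: token opens a new word
    print(starts_with_space("fused"))     # False: continues the previous word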
@@ -31,7 +34,7 @@ def split_into_words(token_probs: list[tuple[int, float]], tokenizer: Tokenizer)
 
     for i, (token_id, logprob) in enumerate(token_probs):
         token: str = tokenizer.convert_ids_to_tokens([token_id])[0]
-        if not token
+        if not starts_with_space(token) and token.isalpha():
             current_word.append(token_id)
             current_log_probs.append(logprob)
         else:
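With this change a token extends the current word only when it carries no boundary marker and is purely alphabetic, so punctuation and digit tokens now close the current word (falling through to the else branch) instead of being glued onto it. For instance, under an illustrative token split ▁pro / strate, the second token is appended and the pair groups into "prostrate", while a comma token terminates the word before it.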
@@ -80,12 +83,13 @@ def generate_outputs(model: PreTrainedModel, inputs: BatchEncoding, num_samples:
     outputs = model.generate(
         input_ids=input_ids,
         attention_mask=attention_mask,
-
+        max_new_tokens=4,
         num_return_sequences=num_samples,
         temperature=1.0,
         top_k=50,
         top_p=0.95,
         do_sample=True
+        # num_beams=num_samples
     )
     return outputs
 
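max_new_tokens=4 caps each sampled continuation at four tokens, which is cheap and sufficient here because downstream code only inspects the first generated token of each sample. The commented-out num_beams suggests a beam-search variant was considered; under the Hugging Face generate API that would look roughly like this (a sketch, not part of the diff):

    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=4,
        num_beams=num_samples,             # beam search instead of sampling
        num_return_sequences=num_samples,  # must not exceed num_beams
        do_sample=False,
    )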
@@ -96,8 +100,8 @@ def extract_replacements(outputs: GenerateOutput | torch.LongTensor, tokenizer:
         for j in range(num_samples):
             generated_ids = outputs[i * num_samples + j][input_len:]
             new_word = tokenizer.convert_ids_to_tokens(generated_ids.tolist())[0]
-            if new_word
-                replacements.append(new_word)
+            if starts_with_space(new_word):
+                replacements.append(new_word[1:])
         all_new_words.append(replacements)
     return all_new_words
 
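Only the first token of each continuation is kept (the [0] index); a sample now counts as a replacement only when that token opens a new word, i.e. carries a boundary marker, and new_word[1:] strips that single marker character before the candidate is stored, so an illustrative "▁refused" is recorded as "refused".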
@@ -105,20 +109,18 @@ def extract_replacements(outputs: GenerateOutput | torch.LongTensor, tokenizer:
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-model_name = "mistralai/Mistral-7B-v0.1"
+# model_name = "mistralai/Mistral-7B-v0.1"
+model_name = "unsloth/Llama-3.2-1B"
 model, tokenizer = load_model_and_tokenizer(model_name, device)
 
 #%%
-
 input_text = "He asked me to prostrate myself before the king, but I rifused."
 inputs: BatchEncoding = tokenize(input_text, tokenizer, device)
 
 #%%
-
 token_probs: list[tuple[int, float]] = calculate_log_probabilities(model, tokenizer, inputs)
 
 #%%
-
 words = split_into_words(token_probs, tokenizer)
 log_prob_threshold = -5.0
 low_prob_words = [word for word in words if word.logprob < log_prob_threshold]
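Two notes on this cell: the switch to unsloth/Llama-3.2-1B swaps the 7B Mistral for a much smaller model, presumably for faster iteration, and the misspelling "rifused" in the input text looks deliberate, since it is exactly the kind of word the -5.0 log-probability threshold is meant to flag.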
@@ -129,7 +131,6 @@ inputs = prepare_inputs(contexts, tokenizer, device)
 input_ids = inputs["input_ids"]
 
 #%%
-
 num_samples = 5
 start_time = time.time()
 outputs = generate_outputs(model, inputs, num_samples)
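Because generate() runs with num_return_sequences=num_samples, the returned tensor has input_ids.shape[0] * num_samples rows, grouped per input row; that layout is what lets extract_replacements address the j-th sample of the i-th context as outputs[i * num_samples + j]. A minimal check, assuming the variables defined above:

    assert outputs.shape[0] == input_ids.shape[0] * num_samples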
@@ -140,13 +141,11 @@ print(f"Total time taken for replacements: {end_time - start_time:.4f} seconds")
 replacements_batch = extract_replacements(outputs, tokenizer, input_ids.shape[0], input_ids.shape[1], num_samples)
 
 #%%
-
 for word, replacements in zip(low_prob_words, replacements_batch):
     print(f"Original word: {word.text}, Log Probability: {word.logprob:.4f}")
     print(f"Proposed replacements: {replacements}")
 
 # %%
-
 generated_ids = outputs[:, input_ids.shape[-1]:]
 for g in generated_ids:
     print(tokenizer.convert_ids_to_tokens(g.tolist()))
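The zip at the end assumes low_prob_words and replacements_batch line up one-to-one, i.e. that prepare_inputs built exactly one context per flagged word; the last cell then dumps the raw generated tokens of every sample, which is handy for eyeballing where the boundary heuristic cuts. A cheap guard, assuming that one-context-per-word invariant:

    assert len(low_prob_words) == len(replacements_batch)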