# machine-translation/llm_toolkit/translation_utils.py
# Initial code for Chinese/English translation.
import os
import re
import pandas as pd
import evaluate
import seaborn as sns
import matplotlib.pyplot as plt
from datasets import load_dataset
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from tqdm import tqdm
print(f"loading {__file__}")
bleu = evaluate.load("bleu")
rouge = evaluate.load("rouge")
meteor = evaluate.load("meteor")
accuracy = evaluate.load("accuracy")
def extract_answer(text, debug=False):
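    """Strip chat-template wrappers from generated text.

    Removes the prompt prefix up to and including the assistant / [/INST]
    marker, then drops any trailing special-token block (everything from the
    first <...> tag onward) and Llama-3 style header suffixes, leaving only
    the model's answer.
    """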
if text:
# Remove the begin and end tokens
text = re.sub(
r".*?(assistant|\[/INST\]).+?\b", "", text, flags=re.DOTALL | re.MULTILINE
)
if debug:
print("--------\nstep 1:", text)
text = re.sub(r"<.+?>.*", "", text, flags=re.DOTALL | re.MULTILINE)
if debug:
print("--------\nstep 2:", text)
text = re.sub(
r".*?end_header_id\|>\n\n", "", text, flags=re.DOTALL | re.MULTILINE
)
if debug:
print("--------\nstep 3:", text)
return text
def calc_metrics(references, predictions, debug=False):
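    """Compute exact-match accuracy, METEOR, BLEU and ROUGE scores for
    `predictions` against `references`; predictions are cleaned with
    extract_answer first."""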
    assert len(references) == len(
        predictions
    ), f"lengths are different: {len(references)} != {len(predictions)}"
predictions = [extract_answer(text) for text in predictions]
correct = [1 if ref == pred else 0 for ref, pred in zip(references, predictions)]
accuracy = sum(correct) / len(references)
results = {"accuracy": accuracy}
if debug:
correct_ids = [i for i, c in enumerate(correct) if c == 1]
results["correct_ids"] = correct_ids
results["meteor"] = meteor.compute(predictions=predictions, references=references)[
"meteor"
]
results["bleu_scores"] = bleu.compute(
predictions=predictions, references=references, max_order=4
)
results["rouge_scores"] = rouge.compute(
predictions=predictions, references=references
)
return results
def save_results(model_name, results_path, dataset, predictions, debug=False):
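    """Store `predictions` as a column named `model_name` in the results CSV
    at `results_path`, creating the file from `dataset` (without the "text"
    and "prompt" columns) if it does not exist yet."""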
if not os.path.exists(results_path):
# Get the directory part of the file path
dir_path = os.path.dirname(results_path)
# Create all directories in the path (if they don't exist)
os.makedirs(dir_path, exist_ok=True)
df = dataset.to_pandas()
df.drop(columns=["text", "prompt"], inplace=True)
else:
df = pd.read_csv(results_path, on_bad_lines="warn")
df[model_name] = predictions
if debug:
print(df.head(1))
df.to_csv(results_path, index=False)
def load_translation_dataset(data_path, tokenizer=None):
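    """Load the Chinese/English TSV at `data_path`, generating and caching an
    80/20 train/test split on first use. When a tokenizer is supplied, add
    chat-formatted "prompt" and "text" columns for training and evaluation."""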
train_data_file = data_path.replace(".tsv", "-train.tsv")
test_data_file = data_path.replace(".tsv", "-test.tsv")
if not os.path.exists(train_data_file):
print("generating train/test data files")
dataset = load_dataset(
"csv", data_files=data_path, delimiter="\t", split="train"
)
        print(f"loaded {len(dataset)} rows")
dataset = dataset.filter(lambda x: x["chinese"] and x["english"])
datasets = dataset.train_test_split(test_size=0.2)
        print(f"{len(dataset)} rows left after dropping empty entries")
# Convert to pandas DataFrame
train_df = pd.DataFrame(datasets["train"])
test_df = pd.DataFrame(datasets["test"])
# Save to TSV
train_df.to_csv(train_data_file, sep="\t", index=False)
test_df.to_csv(test_data_file, sep="\t", index=False)
print("loading train/test data files")
datasets = load_dataset(
"csv",
data_files={"train": train_data_file, "test": test_data_file},
delimiter="\t",
)
if tokenizer:
translation_prompt = "Please translate the following Chinese text into English and provide only the translated content, nothing else.\n{}"
def formatting_prompts_func(examples):
inputs = examples["chinese"]
outputs = examples["english"]
messages = [
{
"role": "system",
"content": "You are an expert in translating Chinese to English.",
},
None,
]
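            # The None entry is a placeholder that is replaced with the user
            # message for each example in the loop below.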
            # Mistral chat templates have no system role, so drop the system message.
            model_name = os.getenv("MODEL_NAME") or ""
            if "mistral" in model_name.lower():
                messages = messages[1:]
texts = []
prompts = []
for input, output in zip(inputs, outputs):
prompt = translation_prompt.format(input)
messages[-1] = {"role": "user", "content": prompt}
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
prompts.append(prompt)
texts.append(prompt + output + tokenizer.eos_token)
return {"text": texts, "prompt": prompts}
datasets = datasets.map(
formatting_prompts_func,
batched=True,
)
print(datasets)
return datasets
def eval_model(model, tokenizer, eval_dataset):
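    """Generate a translation for each prompt in `eval_dataset` (one example
    at a time, on CUDA) and return the predictions cleaned with
    extract_answer."""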
total = len(eval_dataset)
predictions = []
for i in tqdm(range(total)):
inputs = tokenizer(
eval_dataset["prompt"][i : i + 1],
return_tensors="pt",
).to("cuda")
outputs = model.generate(**inputs, max_new_tokens=4096, use_cache=False)
decoded_output = tokenizer.batch_decode(outputs)
debug = i == 0
decoded_output = [
extract_answer(output, debug=debug) for output in decoded_output
]
predictions.extend(decoded_output)
return predictions
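

# NOTE: `get_model_names` is called by `save_model` below but is not defined or
# imported in this file. The helper below is a minimal sketch of what it is
# assumed to return (a dict with "local", "hub", "local-gguf" and "hub-gguf"
# keys); replace it with the project's real implementation if one exists.
def get_model_names(model_name, save_method="lora", quantization_method="q5_k_m"):
    base_name = model_name.split("/")[-1]
    local_name = f"models/{base_name}-{save_method}"
    # HF_USERNAME is a hypothetical env var naming the Hub namespace to push to.
    hub_owner = os.getenv("HF_USERNAME", "")
    hub_name = f"{hub_owner}/{base_name}-{save_method}" if hub_owner else f"{base_name}-{save_method}"
    return {
        "local": local_name,
        "local-gguf": f"{local_name}-{quantization_method}-gguf",
        "hub": hub_name,
        "hub-gguf": f"{hub_name}-{quantization_method}-gguf",
    }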
def save_model(
model,
tokenizer,
include_gguf=True,
include_merged=True,
publish=True,
):
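    """Save the LoRA adapter and tokenizer locally and optionally push them to
    the Hugging Face Hub, along with merged and GGUF variants when the
    corresponding flags are set."""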
try:
token = os.getenv("HF_TOKEN") or None
model_name = os.getenv("MODEL_NAME")
save_method = "lora"
quantization_method = "q5_k_m"
model_names = get_model_names(
model_name, save_method=save_method, quantization_method=quantization_method
)
model.save_pretrained(model_names["local"])
tokenizer.save_pretrained(model_names["local"])
if publish:
model.push_to_hub(
model_names["hub"],
token=token,
)
tokenizer.push_to_hub(
model_names["hub"],
token=token,
)
if include_merged:
model.save_pretrained_merged(
model_names["local"] + "-merged", tokenizer, save_method=save_method
)
if publish:
model.push_to_hub_merged(
model_names["hub"] + "-merged",
tokenizer,
save_method="lora",
token="",
)
if include_gguf:
model.save_pretrained_gguf(
model_names["local-gguf"],
tokenizer,
quantization_method=quantization_method,
)
if publish:
model.push_to_hub_gguf(
model_names["hub-gguf"],
tokenizer,
quantization_method=quantization_method,
token=token,
)
except Exception as e:
print(e)
def get_metrics(df):
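    """Build a summary DataFrame with accuracy, METEOR, BLEU and ROUGE-L for
    each model column in `df` (every column after the first two), scored
    against the "english" reference column."""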
metrics_df = pd.DataFrame(df.columns.T)[2:]
metrics_df.rename(columns={0: "model"}, inplace=True)
metrics_df["model"] = metrics_df["model"].apply(lambda x: x.split("/")[-1])
metrics_df.reset_index(inplace=True)
metrics_df = metrics_df.drop(columns=["index"])
accuracy = []
meteor = []
bleu_1 = []
rouge_l = []
all_metrics = []
for col in df.columns[2:]:
metrics = calc_metrics(df["english"], df[col], debug=True)
print(f"{col}: {metrics}")
accuracy.append(metrics["accuracy"])
meteor.append(metrics["meteor"])
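        # Note: this appends the overall BLEU score (computed with max_order=4),
        # even though it is labelled "bleu_1" here and in the plots.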
bleu_1.append(metrics["bleu_scores"]["bleu"])
rouge_l.append(metrics["rouge_scores"]["rougeL"])
all_metrics.append(metrics)
metrics_df["accuracy"] = accuracy
metrics_df["meteor"] = meteor
metrics_df["bleu_1"] = bleu_1
metrics_df["rouge_l"] = rouge_l
metrics_df["all_metrics"] = all_metrics
return metrics_df
def plot_metrics(metrics_df, figsize=(14, 5), ylim=(0, 0.44)):
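    """Draw a grouped, hatched bar chart comparing METEOR, BLEU and ROUGE-L
    scores across models, annotating each bar with its value."""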
plt.figure(figsize=figsize)
df_melted = pd.melt(
metrics_df, id_vars="model", value_vars=["meteor", "bleu_1", "rouge_l"]
)
barplot = sns.barplot(x="variable", y="value", hue="model", data=df_melted)
# Set different hatches for each model
hatches = ["/", "\\", "|", "-", "+", "x", "o", "O", ".", "*", "//", "\\\\"]
# Create a dictionary to map models to hatches
model_hatches = {
model: hatches[i % len(hatches)]
for i, model in enumerate(metrics_df["model"].unique())
}
# Apply hatches based on the model
num_vars = len(df_melted["variable"].unique())
for i, bar in enumerate(barplot.patches):
model = df_melted["model"].iloc[i // num_vars]
bar.set_hatch(model_hatches[model])
# Manually update legend to match the bar hatches
handles, labels = barplot.get_legend_handles_labels()
for handle, model in zip(handles, metrics_df["model"].unique()):
handle.set_hatch(model_hatches[model])
barplot.set_xticklabels(["METEOR", "BLEU-1", "ROUGE-L"])
for p in barplot.patches:
if p.get_height() == 0:
continue
barplot.annotate(
f"{p.get_height():.2f}",
(p.get_x() + p.get_width() / 2.0, p.get_height()),
ha="center",
va="center",
xytext=(0, 10),
textcoords="offset points",
)
barplot.set(ylim=ylim, ylabel="Scores", xlabel="Metrics")
plt.legend(bbox_to_anchor=(0.5, -0.1), loc="upper center")
plt.show()
def plot_times(perf_df, ylim=0.421):
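    """Plot stacked training/evaluation time bars per model and, when present,
    overlay METEOR scores on a secondary axis."""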
    # Stacked bar chart: "train-time" (red) at the bottom, "eval-time" on top.
fig, ax1 = plt.subplots(figsize=(12, 10))
color_train = "tab:red"
color_eval = "orange"
ax1.set_xlabel("Models")
ax1.set_ylabel("Time (mins)")
ax1.set_xticks(range(len(perf_df["model"]))) # Set x-ticks positions
ax1.set_xticklabels(perf_df["model"], rotation=90)
# Plot "train-time" first so it's at the bottom
ax1.bar(
perf_df["model"],
perf_df["train-time(mins)"],
color=color_train,
label="train-time",
)
# Then, plot "eval-time" on top of "train-time"
ax1.bar(
perf_df["model"],
perf_df["eval-time(mins)"],
bottom=perf_df["train-time(mins)"],
color=color_eval,
label="eval-time",
)
ax1.tick_params(axis="y")
ax1.legend(loc="upper left")
if "meteor" in perf_df.columns:
ax2 = ax1.twinx()
color_meteor = "tab:blue"
ax2.set_ylabel("METEOR", color=color_meteor)
ax2.plot(
perf_df["model"],
perf_df["meteor"],
color=color_meteor,
marker="o",
label="meteor",
)
ax2.tick_params(axis="y", labelcolor=color_meteor)
ax2.legend(loc="upper right")
ax2.set_ylim(ax2.get_ylim()[0], ylim)
# Show numbers in bars
for p in ax1.patches:
height = p.get_height()
if height == 0: # Skip bars with height 0
continue
ax1.annotate(
f"{height:.2f}",
(p.get_x() + p.get_width() / 2.0, p.get_y() + height),
ha="center",
va="center",
xytext=(0, -10),
textcoords="offset points",
)
fig.tight_layout()
plt.show()
def translate_via_llm(text):
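    """Translate one Chinese text into English through an OpenAI-compatible
    chat endpoint (OPENAI_BASE_URL, defaulting to http://localhost:8000/v1)."""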
base_url = os.getenv("OPENAI_BASE_URL") or "http://localhost:8000/v1"
llm = ChatOpenAI(
model="gpt-4o",
temperature=0,
max_tokens=None,
timeout=None,
max_retries=2,
base_url=base_url,
)
prompt = ChatPromptTemplate.from_messages(
[
(
"human",
"Please translate the following Chinese text into English and provide only the translated content, nothing else.\n{input}",
),
]
)
chain = prompt | llm
response = chain.invoke(
{
"input": text,
}
)
return response.content
def translate(text, cache_dict):
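    """Return the translation of `text`, using `cache_dict` as an in-memory
    cache to avoid repeated LLM calls for identical inputs."""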
if text in cache_dict:
return cache_dict[text]
else:
translated_text = translate_via_llm(text)
cache_dict[text] = translated_text
return translated_text
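

# Example end-to-end usage (illustrative sketch only; the file paths, the
# results location, and the `model`/`tokenizer` objects below are assumptions,
# not part of this module):
#
#     datasets = load_translation_dataset("data/translations.tsv", tokenizer)
#     predictions = eval_model(model, tokenizer, datasets["test"])
#     save_results(os.getenv("MODEL_NAME"), "results/translations.csv",
#                  datasets["test"], predictions)
#     metrics_df = get_metrics(pd.read_csv("results/translations.csv"))
#     plot_metrics(metrics_df)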