Edit model card

THaLLE: Text Hyperlocally Augmented Large Language Extension

❗NOTICE❗: KBTG-Labs/THaLLE-0.1-7B-fa is a work-in-progress (WIP) model checkpoint, distributed so that the results in our Technical Report can be reproduced.

Training details

This model is Qwen2-7B-Instruct fine-tuned with LoRA on our internal CFA mock exams (2009–2019), which contain 9,426 questions.

Vocab Config Patching

Prior to training, we patched the `tokenizer_config.json` of Qwen/Qwen2-7B-Instruct, changing its `bos_token` field from `null` to the start token `<|im_start|>`:

{
    ...
    "bos_token": "<|im_start|>"
    ...
}

Results

For more details, see our Technical Report.

| Model | Internal 2020 | Internal 2024 | Flare CFA* |
|---|---|---|---|
| **APIs** | | | |
| gpt-3.5-turbo-0125 | 0.5458 | 0.5027 | 0.6366 |
| gemini-1.5-flash-001 | 0.6271 | 0.6278 | 0.7355 |
| gemini-1.5-pro-001 | 0.6780 | 0.6444 | 0.7829 |
| gpt-4o-2024-05-13 | 0.8000 | 0.8055 | 0.8789 |
| **HF models** | | | |
| "meta-llama/Llama-2-7b-chat-hf" | 0.3774 | 0.3639 | 0.4264 |
| "google/gemma-7b-it" | 0.5107 | 0.5333 | 0.6027 |
| "meta-llama/Meta-Llama-3-8B-Instruct" | 0.5424 | 0.5222 | 0.6386 |
| "Qwen/Qwen2-7B-Instruct" | 0.5740 | 0.5583 | 0.6831 |
| "KBTG-Labs/THaLLE-0.1-7B-fa" | 0.6678 | 0.6500 | 0.7171 |

[*] Flare CFA refers to the "ChanceFocus/flare-cfa" dataset.

Usage

Requirements

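Since KBTG-Labs/THaLLE-0.1-7B-fa is a fine-tuned version of Qwen2-7B-Instruct, you will need transformers>=4.37.0. The reproduction script below additionally requires torch, datasets, and accelerate (needed for device_map="auto").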

Reproducing results

Running the script below should give you this output:

Progress: 1032/1032 | Correct: 740 (71.71%)
import re
from typing import Literal, Optional

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hugging Face Hub id of the checkpoint under evaluation.
MODEL_ID: str = "KBTG-Labs/THaLLE-0.1-7B-fa"

# System message: frames the task as a three-option (A/B/C) CFA question.
SYSTEM_PROMPT: str = (
    "You are a CFA (chartered financial analyst) taking a test to evaluate your "
    "knowledge of finance. You will be given a question along with three possible "
    "answers (A, B, and C).\n"
    "Indicate the correct answer (A, B, or C)."
)

# User message template; filled via str.format with the keys produced by
# format_flare_cfa (question, choice_a, choice_b, choice_c).
QUESTION_TEMPLATE: str = "\n".join(
    (
        "Question:",
        "{question}",
        "A. {choice_a}",
        "B. {choice_b}",
        "C. {choice_c}",
    )
)


def format_flare_cfa(text: str) -> dict[str, str]:
    """Split a Flare CFA ``text`` field into its question and three choices.

    Runs of whitespace are first collapsed to single spaces, then the text is
    matched against ``Q: <question>, CHOICES: A: <a>, B: <b>, C: <c>``.

    Returns:
        A dict with keys ``question``, ``choice_a``, ``choice_b`` and
        ``choice_c``, each value stripped of surrounding whitespace.

    Raises:
        ValueError: if the text does not follow the expected layout.
    """
    collapsed = re.sub(r"\s+", " ", text)

    found = re.search(
        r"Q:\s*(.*?),\s*CHOICES:\s*A:\s*(.*?),\s*B:\s*(.*?),\s*C:\s*(.*)",
        collapsed,
    )
    if found is None:
        raise ValueError("Input text does not match the expected format.")

    field_names = ("question", "choice_a", "choice_b", "choice_c")
    return {name: part.strip() for name, part in zip(field_names, found.groups())}


def load_benchmark_dataset() -> list[dict[str, str]]:
    """Load the ChanceFocus/flare-cfa test split as prompt-ready problems.

    Each row's ``text`` is parsed with ``format_flare_cfa`` into question and
    choice fields, and the row's ``answer`` is added as an upper-cased string.

    Raises:
        ValueError: propagated from ``format_flare_cfa`` when a row's text
            cannot be parsed.
    """
    test_split = load_dataset("ChanceFocus/flare-cfa")["test"]
    return [
        {**format_flare_cfa(row["text"]), "answer": str(row["answer"]).upper()}
        for row in test_split
    ]


def extract_choice(
    response_text: str, choice_a: str, choice_b: str, choice_c: str
) -> Optional[Literal["A", "B", "C"]]:
    """Map a free-form model response to a choice letter.

    Strategies are tried in order; the first one that hits wins:

    1. An explicit phrase such as "The correct answer is B", or "Answer: B"
       followed by a whitespace character.
    2. A response that is exactly "A", "B" or "C".
    3. The first "<letter>." occurrence (e.g. "B.").
    4. The first choice (A, then B, then C) whose normalized text is a
       substring of the normalized response.

    Normalization replaces en dashes with hyphens, strips surrounding
    whitespace and removes newlines.

    Returns:
        "A", "B" or "C", or None when no strategy matches.
    """

    def normalize(text: str) -> str:
        return text.replace("–", "-").strip().replace("\n", "")

    # NOTE(review): inside [...] a "|" is a literal character rather than
    # alternation, so "[T|t]" also accepts "|", and "[is]?" matches a single
    # optional "i" or "s" (not the word "is"). Left as-is so the reported
    # benchmark numbers stay reproducible.
    explicit = re.search(
        r"([T|t]he correct answer is[.|:]? [ABC]|[A|a]nswer[.|:]?[is]?\W+?\n?[ABC]\s)",
        response_text,
    )
    if explicit is not None:
        # After trimming any trailing whitespace consumed by the pattern, the
        # matched phrase ends with the answer letter.
        return normalize(explicit.group(1))[-1]

    if response_text in ("A", "B", "C"):
        return response_text

    dotted = re.search(r"[ABC][.]\s?", response_text)
    if dotted is not None:
        return dotted.group(0)[0]

    haystack = normalize(response_text)
    for letter, content in (("A", choice_a), ("B", choice_b), ("C", choice_c)):
        if normalize(content) in haystack:
            return letter

    return None


def inference(
    messages: list[dict[str, str]],
    model,
    tokenizer,
    *,
    max_new_tokens: int = 768,
) -> str:
    """Generate one greedy chat completion for ``messages``.

    The conversation is rendered with the tokenizer's chat template (with the
    assistant generation prompt appended), decoded greedily, and only the
    newly generated tokens are returned as text.

    Args:
        messages: Chat messages as ``{"role": ..., "content": ...}`` dicts.
        model: Causal LM exposing ``generate`` and ``device``.
        tokenizer: Tokenizer matching ``model``.
        max_new_tokens: Upper bound on generated tokens; the default of 768 is
            the value the reproduction script has always used.

    Returns:
        The decoded completion, with special tokens removed.
    """
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)

    # Pass the attention mask explicitly (via **model_inputs) rather than only
    # input_ids. For this single unpadded prompt the mask is all ones, so
    # decoding should be unchanged, but generate() no longer has to infer it.
    # Sampling parameters are set to None, presumably so values from the
    # model's generation_config are not applied alongside do_sample=False.
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        temperature=None,
        top_p=None,
        top_k=None,
    )

    # generate() echoes the prompt; keep only the newly generated tokens.
    prompt_len = model_inputs["input_ids"].shape[1]
    completion_ids = generated_ids[:, prompt_len:]

    return tokenizer.batch_decode(completion_ids, skip_special_tokens=True)[0]


def run_benchmark(dataset: list[dict[str, str]], model, tokenizer) -> int:
    """Evaluate ``model`` on ``dataset``, printing a live accuracy line.

    Each problem is rendered into a system + user chat from SYSTEM_PROMPT and
    QUESTION_TEMPLATE and answered with ``inference``. The letter recovered by
    ``extract_choice`` is then compared to the problem's ``"answer"`` field.

    Args:
        dataset: Problems as produced by ``load_benchmark_dataset`` (keys
            ``question``, ``choice_a``, ``choice_b``, ``choice_c``, ``answer``).
        model: Causal LM passed through to ``inference``.
        tokenizer: Tokenizer passed through to ``inference``.

    Returns:
        The number of correctly answered problems.
    """
    total = len(dataset)
    total_correct = 0

    for i, problem in enumerate(dataset, start=1):
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": QUESTION_TEMPLATE.format(**problem)},
        ]
        output_text = inference(messages, model, tokenizer)
        prediction = extract_choice(
            output_text,
            problem["choice_a"],
            problem["choice_b"],
            problem["choice_c"],
        )

        # An unparseable response yields None, which never equals the answer.
        if problem["answer"] == prediction:
            total_correct += 1
        percent = total_correct / i * 100

        # "\r" rewrites the same terminal line for every problem.
        print(
            f"Progress: {i}/{total} | Correct: {total_correct} ({percent:.2f}%)",
            end="\r",
        )

    # The progress line ends with "\r", not "\n". Terminate it so the final
    # score is not overwritten by subsequent output or the shell prompt.
    if dataset:
        print()

    return total_correct


if __name__ == "__main__":
    # Load the evaluation problems, then the bf16 checkpoint, and score it.
    benchmark = load_benchmark_dataset()
    tok = AutoTokenizer.from_pretrained(MODEL_ID)
    lm = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, device_map="auto", torch_dtype=torch.bfloat16
    )

    run_benchmark(benchmark, lm, tok)

Citation

If you find our work useful, please cite:

@misc{labs2024thalle,
      title={THaLLE: Text Hyperlocally Augmented Large Language Extension -- Technical Report}, 
      author={KBTG Labs and Danupat Khamnuansin and Atthakorn Petchsod and Anuruth Lertpiya and Pornchanan Balee and Thanawat Lodkaew and Tawunrat Chalothorn and Thadpong Pongthawornkamol and Monchai Lertsutthiwong},
      year={2024},
      eprint={2406.07505},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
Downloads last month
126
Safetensors
Model size
7.62B params
Tensor type
BF16
·