# readCtrl — text classifier finetune/eval script (author: mshahidul)
# Initial commit of readCtrl code without large models (commit 030876e)
import logging
import os
import re
# Silence noisy transformers loggers before the library is imported below:
# a known deprecation warning message contains a literal '%' plus extra args,
# which breaks %-formatting inside the logging module and raises TypeError.
for _logger_name in ("transformers", "transformers.modeling_attn_mask_utils", "transformers.utils.logging"):
    logging.getLogger(_logger_name).setLevel(logging.ERROR)
# If a handler still hits the buggy warning, don't crash the script.
logging.raiseExceptions = False
# Pin GPU selection; these must be set before torch/CUDA is imported below.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
import json
from datetime import datetime
import torch
from datasets import Dataset
from unsloth import FastLanguageModel
from trl import SFTConfig, SFTTrainer
# ---- Run configuration ----
model_name = "unsloth/Llama-3.2-3B-Instruct"  # base model to finetune/evaluate
data_path = "/home/mshahidul/readctrl/code/text_classifier/bn/testing_bn_full.json"  # raw labeled records
test_size = 0.2  # 1 - train_ratio (0.8), same as Gemma script
seed = 42  # shared seed for dataset split and training
prompt_language = "bn"  # "bn" (Bangla) or "en" (English)
run_mode = "finetune_and_eval"  # "finetune_and_eval" or "eval_base_only"
save_fp16_merged = False  # whether to save merged fp16 model after finetuning
max_seq_length = 4096  # context length passed to model load and trainer
load_in_4bit = False  # load base weights 4-bit quantized if True
def get_model_size_from_name(name):
    """Extract the parameter-size token (e.g. "3B", "9b", "125M") from a model id.

    Splits the final path component of *name* on "-" and returns the first
    part that is an explicit numeric size: digits (optionally with one
    decimal point) followed by "b" or "m", case-insensitive. Returns
    "unknown" when no part matches.

    Fix over the original: the old check was a bare ``endswith("b"/"m")``,
    which false-positived on non-size name parts such as "medium" or "lab";
    requiring a leading number avoids that.
    """
    base = name.rsplit("/", 1)[-1]
    for part in base.split("-"):
        if re.fullmatch(r"\d+(?:\.\d+)?[bm]", part.lower()):
            return part  # preserve original casing, e.g. "3B"
    return "unknown"
# Size tag (e.g. "3B") parsed from the model id; used in output filenames below.
model_size = get_model_size_from_name(model_name)
def formatting_prompts_func(examples):
    """Render each conversation to a single training string via the chat template.

    Expects a batched dataset dict with a "conversations" column; returns a
    dict with a "text" column. The template's leading "<|begin_of_text|>"
    marker is stripped from each rendered string.
    """
    rendered = []
    for conversation in examples["conversations"]:
        chat_text = tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=False,
        )
        rendered.append(chat_text.removeprefix("<|begin_of_text|>"))
    return {"text": rendered}
def build_classification_user_prompt(fulltext, gen_text):
    """Build the user-turn prompt asking for a health-literacy label.

    The prompt embeds the case full text and the generated summary and asks
    for exactly one of three labels. Language is selected by the module-level
    ``prompt_language``: "en" gives English, anything else gives Bangla.
    """
    if prompt_language != "en":
        # Bangla (default)
        return (
            "আপনাকে একটি মেডিকেল কেসের পূর্ণ বর্ণনা (full text) এবং তৈরি করা সারাংশ (generated text) দেওয়া হবে। "
            "রোগীর স্বাস্থ্যজ্ঞান (health literacy) কোন স্তরের তা নির্ধারণ করুন।\n\n"
            f"Full text:\n{fulltext}\n\n"
            f"Generated text:\n{gen_text}\n\n"
            "শুধু নিচের সেট থেকে একটি লেবেল দিয়ে উত্তর দিন:\n"
            "low_health_literacy, intermediate_health_literacy, high_health_literacy"
        )
    # English
    return (
        "You will be given a medical case description (full text) and a generated summary. "
        "Classify the patient's health literacy level.\n\n"
        f"Full text:\n{fulltext}\n\n"
        f"Generated text:\n{gen_text}\n\n"
        "Reply with exactly one label from this set:\n"
        "low_health_literacy, intermediate_health_literacy, high_health_literacy"
    )
def build_classification_examples(raw_records):
    """Convert raw records into two-turn chat examples for SFT.

    Each record contributes one example: user turn = classification prompt
    built from "fulltext" and "gen_text", assistant turn = the gold "label".
    Records with a missing or blank label are skipped.
    """
    chats = []
    for rec in raw_records:
        label_text = (rec.get("label") or "").strip()
        if not label_text:
            continue  # unlabeled record: nothing to train on
        prompt = build_classification_user_prompt(
            rec.get("fulltext", ""),
            rec.get("gen_text", ""),
        )
        chats.append(
            {
                "conversations": [
                    {"role": "user", "content": prompt},
                    {"role": "assistant", "content": label_text},
                ],
            }
        )
    return chats
def generate_prediction(user_prompt):
    """Greedy-decode the model's reply to a single-turn user prompt.

    Renders the prompt with the chat template (generation prompt appended),
    generates up to 256 new tokens, and returns only the newly generated
    text with special tokens removed and surrounding whitespace stripped.

    Fix over the original: ``temperature=0.0`` was passed alongside
    ``do_sample=False``; transformers ignores temperature under greedy
    decoding and emits a warning, so the argument is dropped (behavior
    unchanged).
    """
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": user_prompt}],
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=False,  # greedy decoding
            use_cache=True,
        )
    # Slice off the prompt tokens; keep only what the model generated.
    generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
# 1. Load model and tokenizer (dtype=None lets Unsloth choose bf16/fp16).
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_name,
    max_seq_length=max_seq_length,
    dtype=None,
    load_in_4bit=load_in_4bit,
)
# 2. Add LoRA adapters (kept same as original Llama script).
# Adapters target all attention and MLP projection matrices.
model = FastLanguageModel.get_peft_model(
    model,
    r=16,  # LoRA rank
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    lora_alpha=16,  # scaling factor (alpha == r here)
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",  # Unsloth's memory-efficient checkpointing
    random_state=seed,
)
# 3. Data preparation (same dataset split and prompt style as Gemma script).
# Load the raw labeled records, split train/test deterministically by seed,
# then convert the train half to chat-format text for SFT.
with open(data_path, "r", encoding="utf-8") as f:
    raw_data = json.load(f)
raw_dataset = Dataset.from_list(raw_data)
split_dataset = raw_dataset.train_test_split(test_size=test_size, seed=seed, shuffle=True)
train_raw = split_dataset["train"]
test_raw = split_dataset["test"]
train_examples = build_classification_examples(train_raw)
train_dataset = Dataset.from_list(train_examples)
# Adds the "text" column consumed by SFTTrainer via dataset_text_field.
train_dataset = train_dataset.map(formatting_prompts_func, batched=True)
# 4. Optional finetuning
if run_mode == "finetune_and_eval":
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        dataset_text_field="text",
        max_seq_length=max_seq_length,
        args=SFTConfig(
            per_device_train_batch_size=2,
            gradient_accumulation_steps=4,  # effective batch size 2 * 4 = 8
            warmup_steps=5,
            max_steps=60,  # short step-limited finetune, not epoch-based
            learning_rate=2e-4,
            fp16=not torch.cuda.is_bf16_supported(),  # prefer bf16 when hardware allows
            bf16=torch.cuda.is_bf16_supported(),
            logging_steps=1,
            optim="adamw_8bit",
            weight_decay=0.01,
            lr_scheduler_type="linear",
            seed=seed,
            output_dir="outputs",
            report_to="none",  # no wandb/tensorboard reporting
        ),
    )
    trainer.train()
    save_dir = f"/home/mshahidul/readctrl_model/text_classifier_bn/{model_name.split('/')[-1]}"
    os.makedirs(save_dir, exist_ok=True)
    if save_fp16_merged:
        # Merge LoRA weights into the base model and save an fp16 copy.
        model.save_pretrained_merged(save_dir, tokenizer, save_method="merged_16bit")
        tokenizer.save_pretrained(save_dir)
elif run_mode == "eval_base_only":
    # No finetuning; evaluate base model
    save_dir = f"BASE_MODEL:{model_name}"
else:
    raise ValueError(f"Unsupported run_mode: {run_mode}")
# 5. Test-set inference + accuracy (same pattern and folders as Gemma script)
FastLanguageModel.for_inference(model)  # switch Unsloth model into inference mode
model.eval()
model_info_dir = "/home/mshahidul/readctrl/code/text_classifier/bn/model_info"
ablation_dir = "/home/mshahidul/readctrl/code/text_classifier/bn/ablation_studies"
os.makedirs(model_info_dir, exist_ok=True)
os.makedirs(ablation_dir, exist_ok=True)
# Timestamp and sanitized model tag are embedded in output filenames.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_tag = model_name.split("/")[-1].replace(".", "_")
def evaluate_classification_mode(test_split):
    """Predict labels over the test split and compute exact-match accuracy.

    Samples with a missing/blank gold label are skipped (their enumerate
    index is still consumed). Returns ``(results, metrics)`` where results
    is a per-sample list of dicts and metrics is the run-level summary.
    """
    records = []
    n_scored = 0
    n_hits = 0
    for sample_idx, sample in enumerate(test_split):
        full_text = sample.get("fulltext", "")
        generated = sample.get("gen_text", "")
        gold = (sample.get("label") or "").strip()
        if not gold:
            continue  # unlabeled sample: cannot score it
        prompt = build_classification_user_prompt(full_text, generated)
        raw_prediction = generate_prediction(prompt)
        prediction = (raw_prediction or "").strip()
        hit = prediction == gold  # exact string match against the gold label
        n_scored += 1
        n_hits += int(hit)
        records.append(
            {
                "sample_index": sample_idx,
                "fulltext": full_text,
                "gen_text": generated,
                "gold_label": gold,
                "predicted_label": prediction,
                "correct": hit,
            }
        )
    accuracy = n_hits / n_scored if n_scored else 0.0
    metrics = {
        "mode": "fulltext_gen_text_classification",
        "model_name": model_name,
        "model_save_dir": save_dir,
        "dataset_path": data_path,
        "prompt_language": prompt_language,
        "seed": seed,
        "test_size": test_size,
        "examples_evaluated": n_scored,
        "accuracy": accuracy,
        "timestamp": timestamp,
    }
    return records, metrics
# Run evaluation on the held-out test split.
results, accuracy_summary = evaluate_classification_mode(test_raw)
# Tag the summary with run metadata for downstream ablation tables.
accuracy_summary["finetune_mode"] = "classification"
accuracy_summary["model_size"] = model_size
accuracy_summary["run_mode"] = run_mode
accuracy_summary["prompt_language"] = prompt_language
predictions_path = os.path.join(
    model_info_dir,
    f"{model_tag}_test_inference_{timestamp}.json",
)
accuracy_path = os.path.join(
    ablation_dir,
    f"{model_tag}_classification_{model_size}_{run_mode}_{timestamp}.json",
)
# Persist per-sample predictions and the aggregate summary as UTF-8 JSON.
with open(predictions_path, "w", encoding="utf-8") as f:
    json.dump(results, f, ensure_ascii=False, indent=2)
with open(accuracy_path, "w", encoding="utf-8") as f:
    json.dump(accuracy_summary, f, ensure_ascii=False, indent=2)
print(f"Saved test inference to: {predictions_path}")
print(f"Saved test accuracy to: {accuracy_path}")
print(f"Accuracy: {accuracy_summary.get('accuracy', 0.0):.4f}")