# Uploaded by shahidul034 via "Add files using upload-large-folder tool" (commit c7a6fe6, verified).
import os
# Pin the physical GPU: number CUDA devices by PCI bus ID and expose only GPU 3.
# These environment variables are set before unsloth/torch are imported below so
# they are already in place when CUDA is initialized.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
from unsloth import FastLanguageModel
import torch
from health_classifier import classifier
# Context length the model is loaded with.
max_seq_length = 8192

# Readability SFT checkpoint that GRPO training starts from.
SFT_MODEL_PATH = "/home/mshahidul/readctrl_model/RL_model/readability_sft_lora_model"

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=SFT_MODEL_PATH,
    max_seq_length=max_seq_length,
    # Full-precision weights; switch to True to load 4-bit and save VRAM.
    load_in_4bit=False,
    fast_inference=False,
)

# Put the model into Unsloth's training mode before handing it to the trainer.
model = FastLanguageModel.for_training(model)
import json

from datasets import Dataset

# Records with 'summary', 'fulltext', 'fulltext_subclaims' and
# 'summary_subclaims' fields (all read by _to_grpo_example below).
# NOTE(review): the file name says "test"; confirm it is meant as GRPO
# training data and is not also used for evaluation.
# JSON is read as UTF-8 explicitly: open()'s default encoding is
# locale-dependent and clinical text may contain non-ASCII characters.
with open(
    "/home/mshahidul/readctrl/data/extracting_subclaim/extracted_subclaims_multiclinsum_test_en_full.json",
    "r",
    encoding="utf-8",
) as f:
    data = json.load(f)
dataset = Dataset.from_list(data)

# System prompt shared by every training example.
with open("/home/mshahidul/readctrl/code/RL_model/prompt", "r", encoding="utf-8") as f:
    prompt_template = f.read()


def _to_grpo_example(x):
    """Build one GRPO example from a raw record.

    Returns a dict with:
      * ``prompt`` - chat messages: the shared system prompt, then a user
        message carrying the gold summary and the source text;
      * ``answer`` - the record's reference subclaims, which the trainer
        forwards to the reward functions as the ``answer`` column.
    """
    user_message = (
        "\n"
        "- Input Language: English\n"
        f"- Gold Summary (the anchor reference summary): {x['summary']}\n"
        f"- Source Text (detailed content): {x['fulltext']}\n"
    )
    return {
        "prompt": [
            {"role": "system", "content": prompt_template},
            {"role": "user", "content": user_message},
        ],
        "answer": {
            "fulltext_subclaims": x["fulltext_subclaims"],
            "summary_subclaims": x["summary_subclaims"],
        },
    }


dataset = dataset.map(_to_grpo_example)
# NOTE(review): requests is not referenced in this part of the file; confirm it is still needed.
import requests
import json
import re
from claim_verifier import MedicalClaimVerifier
# Single claim verifier created at import time; claim_reward_func reuses this
# instance for every reward computation.
verifier = MedicalClaimVerifier()
def claim_reward_func(prompts, completions, answer, **kwargs):
    """GRPO reward: score each completion against its reference subclaims.

    Each completion goes to the module-level ``verifier``'s
    ``get_reward_score`` together with the matching ``answer`` entry's
    ``summary_subclaims`` and ``fulltext_subclaims`` (the dataset's
    ``answer`` column).

    Args:
        prompts: Batch prompts; not used by this reward.
        completions: Generated completions, index-aligned with ``answer``.
        answer: Per-completion dicts with ``summary_subclaims`` and
            ``fulltext_subclaims`` keys.
        **kwargs: Other dataset columns passed by the trainer; ignored.

    Returns:
        A list with one verifier score per completion, in input order.
    """
    # NOTE(review): completions are forwarded as-is. With the conversational
    # prompts built above each one is presumably a list of message dicts (the
    # literacy reward reads completion[0]["content"]); confirm
    # get_reward_score expects that shape rather than plain text.
    return [
        verifier.get_reward_score(
            completion,
            answer[idx]["summary_subclaims"],
            answer[idx]["fulltext_subclaims"],
        )
        for idx, completion in enumerate(completions)
    ]
# def format_reward_func(completions, **kwargs):
# required_keys = ["low_health_literacy", "intermediate_health_literacy", "proficient_health_literacy"]
# scores = []
# for completion in completions:
# try:
# match = re.search(r"<SOLUTION>(.*?)</SOLUTION>", completion, re.DOTALL)
# content = match.group(1) if match else completion
# data = json.loads(content)
# if all(k in data for k in required_keys):
# scores.append(2.0)
# else:
# scores.append(-1.0)
# except:
# scores.append(-2.0)
# return scores
import json
# Target levels; each must appear in the completion JSON under
# f"{level}_health_literacy", which is also the classifier label expected for it.
_LITERACY_LEVELS = ("low", "intermediate", "proficient")


def _completion_text(completion):
    """Return the generated text of a single GRPO completion.

    Conversational completions (message lists) are read from
    ``completion[0]["content"]``; a plain string is returned unchanged.
    """
    if isinstance(completion, str):
        return completion
    return completion[0]["content"]


def _parse_literacy_json(completion):
    """Parse a completion into a dict, or return ``None`` if it is malformed.

    If the text starts with a Markdown code fence, only the text between the
    first pair of ``` markers is parsed, minus an optional ``json`` language
    tag (matched case-insensitively).
    """
    try:
        text = _completion_text(completion).strip()
    except (AttributeError, IndexError, KeyError, TypeError):
        # No message, no "content" key, or content that is not a string.
        return None
    if text.startswith("```"):
        text = text.split("```")[1]
        if text[:4].lower() == "json":
            text = text[4:]
    try:
        data = json.loads(text.strip())
    except json.JSONDecodeError:
        return None
    # Only a JSON object can carry the per-level keys.
    return data if isinstance(data, dict) else None


def _score_literacy_alignment(data):
    """Sum the per-level rewards for one parsed completion.

    For each level, a non-blank string under ``<level>_health_literacy`` is
    classified: +1.0 if the predicted label equals that key, -0.5 otherwise.
    A missing, empty, whitespace-only or non-string value costs -0.3.
    """
    alignment_score = 0.0
    for level in _LITERACY_LEVELS:
        key = f"{level}_health_literacy"
        text_to_test = data.get(key, "")
        if not (isinstance(text_to_test, str) and text_to_test.strip()):
            alignment_score -= 0.3
            continue
        # Expected label format: "low_health_literacy", etc.
        predicted = classifier(summary_text=text_to_test).label
        if predicted == key:
            alignment_score += 1.0
        else:
            alignment_score -= 0.5
    return alignment_score


def literacy_classifier_reward_func(completions, **kwargs):
    """GRPO reward: do the generated texts match their target literacy levels?

    Each completion should be a JSON object, optionally wrapped in a Markdown
    code fence, with ``low_health_literacy``, ``intermediate_health_literacy``
    and ``proficient_health_literacy`` text fields.

    Scoring per completion:
      * malformed output (unparseable, or not a JSON object): -1.0;
      * otherwise, summed over the three levels: +1.0 when the classifier
        agrees, -0.5 when it disagrees, -0.3 when the level is missing or
        blank. The total lies in [-1.5, 3.0].

    If the classifier raises, the error is logged and the completion scores
    -1.0, so a single failed call does not abort the reward computation.

    NOTE(review): malformed output (-1.0) outscores three misclassified
    levels (-1.5), and ``{}`` scores about -0.9; confirm this ordering does
    not reward emitting nothing.

    Returns:
        A list of float scores, one per completion, in input order.
    """
    import logging  # local import: the script defines no module-level logger

    scores = []
    for completion in completions:
        data = _parse_literacy_json(completion)
        if data is None:
            scores.append(-1.0)
            continue
        try:
            scores.append(_score_literacy_alignment(data))
        except Exception:
            # Classifier/infrastructure failure, not a model-output problem:
            # keep training alive but make the failure visible.
            logging.getLogger(__name__).exception(
                "literacy classifier failed; scoring completion as -1.0"
            )
            scores.append(-1.0)
    return scores
from trl import GRPOConfig, GRPOTrainer

# NOTE(review): max_prompt_length + max_completion_length (8192 + 4096) is
# larger than the max_seq_length (8192) the model was loaded with; confirm
# long prompts plus completions fit, or raise max_seq_length / lower a limit.
training_args = GRPOConfig(
    learning_rate=5e-6,
    lr_scheduler_type="cosine",
    weight_decay=0.1,
    max_prompt_length=8192,
    max_completion_length=4096,
    num_generations=4,  # GRPO group size
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    max_steps=500,
    bf16=True,
    output_dir="medical_grpo_outputs",
)

trainer = GRPOTrainer(
    model=model,
    reward_funcs=[
        claim_reward_func,
        literacy_classifier_reward_func,
    ],
    args=training_args,
    train_dataset=dataset,
    # GRPOTrainer takes the tokenizer through `processing_class`.
    processing_class=tokenizer,
)
trainer.train()

# NOTE(review): if `model` is a PEFT/LoRA model, save_pretrained writes only
# the adapter weights; confirm that is the intended artifact.
GRPO_OUTPUT_DIR = "/home/mshahidul/readctrl_model/readability_GRPO_model_v1"
model.save_pretrained(GRPO_OUTPUT_DIR)
tokenizer.save_pretrained(GRPO_OUTPUT_DIR)