# witness_reliability_run1_merged
This model is a fine-tuned version of mistralai/Mistral-7B-Instruct-v0.1 on the latest version of the [labeled dataset](https://git.enigmalabs.io/data-science-playground/model-data/-/tree/master/models/witness_reliability?ref_type=heads).
## Model description
More information needed
## Intended uses & limitations

### Usage

#### Preparation

```python
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from functools import partial
import torch
from datasets import Dataset
import re
def read_data_file(data_path):
    # Load a .json (regular or JSON-lines) or .csv file into a DataFrame.
    if '.json' in data_path:
        try:
            df = pd.read_json(data_path)
        except ValueError:
            df = pd.read_json(data_path, lines=True)
    elif '.csv' in data_path:
        df = pd.read_csv(data_path)
    else:
        raise ValueError('data_path file type unknown. End file name in .json or .csv')
    print(df.head())
    return df

# Transform a sample into the test instruct/prompt format.
def create_test_prompt_format(sample, args):
    system_message = args.system_message
    user_message = system_message + f" Sighting report: {sample['text']}"
    sample['chat'] = [
        {"role": "user", "content": user_message},
    ]
    return sample

def preprocess_batch(batch, tokenizer, max_length, device='cuda'):
    # Tokenize a batch of chat-formatted prompts, truncating to max_length.
    return tokenizer(
        batch["chat"],
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
        padding=True,
    )
def preprocess_test(tokenizer: AutoTokenizer, max_length: int, df: pd.DataFrame, args):
    # Format each prompt and tokenize the resulting dataset.
    print("Preprocessing dataset...")
    ds = Dataset.from_pandas(df)
    ds = ds.map(lambda row: create_test_prompt_format(row, args))
    ds = ds.map(lambda row: {"chat": tokenizer.apply_chat_template(row['chat'], tokenize=False, add_generation_prompt=True)})
    device = "cuda" if torch.cuda.is_available() else "cpu"
    _preprocessing_function = partial(preprocess_batch, max_length=max_length, tokenizer=tokenizer, device=device)
    ds = ds.map(
        _preprocessing_function,
        batched=True,
    )
    return ds

def get_max_length(model, args):
    # The context length is exposed under different config attribute names
    # depending on the architecture.
    max_length = None
    for length_setting in ["n_positions", "max_position_embeddings", "seq_length"]:
        max_length = getattr(model.config, length_setting, None)
        if max_length:
            break
    # Cap at args.max_seq_length to avoid long inputs causing CUDA OOM.
    if not max_length or max_length > args.max_seq_length:
        max_length = args.max_seq_length
    return max_length
def predict(model, tokenizer, df, labels, args, task_type='text-generation'):
    model.config.use_cache = True  # enable the KV cache to speed up generation
    y_pred = []
    ds = preprocess_test(tokenizer, get_max_length(model, args), df, args)
    pipe = pipeline(task_type, model=model, tokenizer=tokenizer, temperature=0.0, max_new_tokens=15)
    for i in range(len(ds)):
        prompt = ds[i]['chat']
        result = pipe(prompt, eos_token_id=pipe.tokenizer.eos_token_id, pad_token_id=pipe.tokenizer.pad_token_id)
        # Strip the prompt from the generated text, keeping only the answer.
        answer = result[0]['generated_text'][len(prompt):].strip()
        if answer == 'a':
            y_pred.append('average')
        elif answer == 'question':
            y_pred.append('questionable')
        elif answer == 're':
            y_pred.append('reliable')
        elif answer == 'second':
            y_pred.append('second')
        else:
            # Unmapped generations are recorded as 'none'.
            y_pred.append('none')
    df['prediction'] = y_pred
    return df

class ScriptArguments:
    def __init__(self):
        self.max_seq_length = 3072  # max sequence length for the model and packing of the dataset
        # The prompt text is kept verbatim (including its quirks) to match
        # the format used during fine-tuning.
        self.system_message = (
            "You are a UFO sighting investigator. Users will give you a sighting report and you will classify the reliability of the witness involved in"
            " the sighting using one the following labels: 'reliable': The witness(es) is of reliable backgrounds such as military, veteran, aerospace specialist, "
            "scientist, pilots, law enforcement. 'questionable': The witness is not an adult (under 16 years old) or is noticeably old with deteriorated capabilities to"
            " judge (above 70 years old). 'second': The report is second-handed. The writer of the report is not the witness. 'average': The background of the witness"
            " is not of any other categories or is not specified. Reply with only one of the labels in ['reliable', 'average', 'questionable', 'secondhand']."
        )
```
#### Prediction

```python
# Prepare the data as a pandas DataFrame with the sighting description in the 'text' field.
# $PATH_TO_DATA_FILE is a placeholder for the path to your .json or .csv file.
df_test = read_data_file($PATH_TO_DATA_FILE)

script_args = ScriptArguments()
merged_model_name = "e-labs/witness_reliability_ft_mistral_7b_v0.1_instruct"
task_type = 'text-generation'

tokenizer = AutoTokenizer.from_pretrained(merged_model_name)
model = AutoModelForCausalLM.from_pretrained(merged_model_name)

df = predict(model, tokenizer, df_test, labels=None, args=script_args)
```
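`predict` returns the input DataFrame with a new `prediction` column holding the mapped labels. A quick way to eyeball the output (a usage sketch):

```python
# Show a few sighting reports alongside their predicted reliability labels.
print(df[['text', 'prediction']].head())
```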
Notice that in the `predict` function we map the raw LLM outputs to labels as follows:

| answer | inference |
|---|---|
| `a` | average |
| `question` | questionable |
| `re` | reliable |
| `second` | second-hand |
| all else | average |

Since the model is fundamentally an LLM, it might generate text that is not in the defined set of values `['a', 'question', 're', 'second']`. The `predict` function records those cases as `none`; default them to `average`, as indicated by the "all else" row in the table above.
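That fallback can be expressed directly in pandas. A minimal sketch (the `final_prediction` column name is ours, not part of the released code):

```python
# Collapse the 'none' fallback into 'average', keeping the raw column intact.
df['final_prediction'] = df['prediction'].replace('none', 'average')
```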
## Training and evaluation data
https://wandb.ai/enigmalabs/witness_reliability_ft_mistral_instruct_v0.1/runs/0skl7iac
## Training procedure

### Training hyperparameters

The following hyperparameters were used during training (a `TrainingArguments` sketch follows the list):
- learning_rate: 0.0002
- train_batch_size: 1
- eval_batch_size: 8
- seed: 42
- gradient_accumulation_steps: 2
- total_train_batch_size: 2
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
- lr_scheduler_type: constant
- lr_scheduler_warmup_ratio: 0.03
- num_epochs: 3
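As a rough illustration, these values map onto Hugging Face `TrainingArguments` along the following lines. This is a sketch only, not the actual training script (which, per the framework versions below, also involved PEFT); `output_dir` is a placeholder:

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="witness_reliability_run1",  # placeholder
    learning_rate=2e-4,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=8,
    seed=42,
    gradient_accumulation_steps=2,  # total train batch size = 1 * 2 = 2
    optim="adamw_torch",            # Adam with betas=(0.9, 0.999), eps=1e-8 (the defaults)
    lr_scheduler_type="constant",
    warmup_ratio=0.03,
    num_train_epochs=3,
)
```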
### Training results
https://wandb.ai/enigmalabs/witness_reliability_ft_mistral_instruct_v0.1/runs/2etycpye
#### Accuracy Metrics

- Accuracy: 0.958
- Accuracy for label questionable: 1.000
- Accuracy for label second: 0.941
- Accuracy for label reliable: 0.958
- Accuracy for label average: 0.933

Per-label accuracy is computed over the test samples whose gold label is the one named, which is why it matches the per-class recall in the report below.
Classification Report:

| label | precision | recall | f1-score | support |
|---|---|---|---|---|
| average | 0.97 | 0.93 | 0.95 | 30 |
| none | 0.00 | 0.00 | 0.00 | 0 |
| questionable | 0.97 | 1.00 | 0.98 | 30 |
| reliable | 0.92 | 0.96 | 0.94 | 24 |
| second | 1.00 | 0.94 | 0.97 | 34 |
| accuracy | | | 0.96 | 118 |
| macro avg | 0.77 | 0.77 | 0.77 | 118 |
| weighted avg | 0.97 | 0.96 | 0.96 | 118 |
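The table has the shape of scikit-learn's `classification_report`. For reference, a report like it can be produced from the raw `predict` output, assuming the test DataFrame carries its gold annotations in a `label` column (an assumption; the evaluation script is not part of this card):

```python
from sklearn.metrics import classification_report

# y_true: gold labels; y_pred: raw predictions (before the 'none' -> 'average'
# fallback, which is why a zero-support 'none' row can appear).
print(classification_report(df_test['label'], df['prediction'], zero_division=0))
```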
### Framework versions
- PEFT 0.7.2.dev0
- Transformers 4.36.2
- Pytorch 2.1.2+cu121
- Datasets 2.16.1
- Tokenizers 0.15.1