Edit model card

witness_count_mistral_train_run5

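This model is a fine-tuned version of mistralai/Mistral-7B-Instruct-v0.1 that infers the number of human witnesses in a UFO sighting report. The fine-tuning dataset is not named in this card; see "Training and evaluation data" below.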

Model description

More information needed

Intended uses & limitations

Usage

Preparation

import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from functools import partial
import torch
from datasets import Dataset
import re

def read_data_file(data_path):
    """Load a JSON / JSON Lines or CSV file into a pandas DataFrame.

    The reader is chosen from the file *name* only (the last path component),
    case-insensitively: a name containing ``.json`` (e.g. ``.json``, ``.jsonl``)
    is read with ``pandas.read_json``, retrying as line-delimited JSON when the
    file is not a single JSON document; a name containing ``.csv`` is read with
    ``pandas.read_csv``. The first rows are printed as a quick sanity check.

    Args:
        data_path: Path to the data file, as a ``str`` or ``os.PathLike``.

    Returns:
        The loaded ``pandas.DataFrame``.

    Raises:
        ValueError: If the file name contains neither ``.json`` nor ``.csv``
            (or the file cannot be parsed as the detected format).
    """
    import os

    # Inspect only the final path component so a directory such as
    # "exports.json/" cannot cause a CSV file to be parsed as JSON.
    file_name = os.path.basename(os.fspath(data_path)).lower()

    if '.json' in file_name:
        try:
            df = pd.read_json(data_path)
        except ValueError:
            # Not a single JSON document; retry as JSON Lines (one record per line).
            df = pd.read_json(data_path, lines=True)
    elif '.csv' in file_name:
        df = pd.read_csv(data_path)
    else:
        raise ValueError('data_path file type unknown. End file name in .json or .csv')

    print(df.head())

    return df

def create_test_prompt_format(sample, args):
    """Turn a dataset row into a single-turn chat for inference.

    The instruction text ``args.system_message`` is folded into the user
    message, followed by the row's report text (``sample['text']``). The
    resulting one-message chat is stored under ``sample['chat']``; the row is
    modified in place and also returned, so this can be used with
    ``Dataset.map``.
    """
    instructions = args.system_message
    content = instructions + f" Sighting report: {sample['text']}"
    sample['chat'] = [dict(role="user", content=content)]
    return sample

def preprocess_batch(batch, tokenizer, max_length, device='cuda'):
    """Tokenize the ``"chat"`` column of a batched dataset slice.

    Texts are truncated to ``max_length`` tokens, padded (``padding=True``),
    and returned as PyTorch tensors (``return_tensors="pt"``).

    ``device`` is accepted for call-site compatibility but is not used here.
    """
    tokenizer_kwargs = dict(
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
        padding=True,
    )
    texts = batch["chat"]
    return tokenizer(texts, **tokenizer_kwargs)

def preprocess_test(tokenizer: AutoTokenizer, max_length: int, df: pd.DataFrame, args):
    """Build a ``datasets.Dataset`` of chat prompts (plus tokenization) from ``df``.

    Steps:
      1. Wrap each row's ``text`` in a user chat turn via
         ``create_test_prompt_format`` (uses ``args.system_message``).
      2. Render that chat with ``tokenizer.apply_chat_template`` (untokenized,
         with the generation prompt appended), replacing ``chat`` with the
         resulting prompt string.
      3. Tokenize the prompt strings in batches with ``preprocess_batch``,
         truncating to ``max_length``.

    Returns the mapped dataset; ``chat`` holds the rendered prompt strings
    alongside the columns produced by the tokenizer.
    """
    print("Preprocessing dataset...")

    def _to_chat(row):
        return create_test_prompt_format(row, args)

    def _render_template(row):
        prompt = tokenizer.apply_chat_template(row['chat'], tokenize=False, add_generation_prompt=True)
        return {"chat": prompt}

    dataset = Dataset.from_pandas(df).map(_to_chat).map(_render_template)

    # Only forwarded to preprocess_batch as its `device` argument.
    target_device = "cpu"
    if torch.cuda.is_available():
        target_device = "cuda"

    tokenize_batch = partial(preprocess_batch, max_length=max_length, tokenizer=tokenizer, device=target_device)
    return dataset.map(tokenize_batch, batched=True)

def find_number_in_string(s):
    """Extract the witness count from a generated answer.

    Searches for the phrase ``"Total count is <N> witnesses"``. The singular
    form ``"Total count is 1 witness"`` is accepted as well, and any trailing
    punctuation is ignored.

    Args:
        s: Generated text to search.

    Returns:
        The captured digits ``N`` as a ``str`` (e.g. ``"3"``), or ``None`` when
        the phrase is not present.
    """
    # "witness(?:es)?" matches everything the plural-only pattern matched, plus
    # the singular answer a model may produce for a count of one.
    match = re.search(r"Total count is (\d+) witness(?:es)?", s)
    if match:
        return match.group(1)
    return None

def get_max_length(model, args):
    """Return the maximum input length (in tokens) to use for ``model``.

    The first truthy value among the config attributes ``n_positions``,
    ``max_position_embeddings`` and ``seq_length`` is taken as the model's
    context size. The result is capped at ``args.max_seq_length`` to avoid
    CUDA OOM on long texts. ``args.max_seq_length`` is also used when none of
    those attributes is set.

    Args:
        model: Object exposing a ``config`` attribute (e.g. a transformers model).
        args: Settings object providing ``max_seq_length``.

    Returns:
        The maximum sequence length to use.
    """
    config = model.config
    max_length = None
    for attr_name in ("n_positions", "max_position_embeddings", "seq_length"):
        max_length = getattr(config, attr_name, None)
        if max_length:
            break
    if not max_length or max_length > args.max_seq_length:
        max_length = args.max_seq_length
    return max_length


def predict(model, tokenizer, df, labels, args, task_type='text-generation'):
    """Generate a witness-count prediction for every row of ``df``.

    Each row's ``text`` is turned into a chat prompt (see ``preprocess_test``).
    The prompt is answered by a transformers pipeline using greedy decoding and
    at most 15 new tokens. The count is extracted with ``find_number_in_string``.
    If the expected phrase is missing, the raw generated continuation is stored
    instead, so callers should validate the column (e.g. fall back to ``"1"``
    for non-integer values).

    Args:
        model: Causal LM to run.
        tokenizer: Tokenizer matching ``model``.
        df: DataFrame with a ``text`` column. It is modified in place: a
            ``prediction`` column is added.
        labels: Unused; kept for API compatibility.
        args: Settings object providing ``max_seq_length`` and
            ``system_message`` (e.g. ``ScriptArguments``).
        task_type: Pipeline task name.

    Returns:
        ``df`` with the added ``prediction`` column (strings).
    """
    # Enable the key/value cache so each decoding step reuses the attention
    # states of the tokens already processed instead of recomputing them.
    model.config.use_cache = True

    ds = preprocess_test(tokenizer, get_max_length(model, args), df, args)

    # Greedy decoding (do_sample=False) makes predictions deterministic.
    # temperature=0.0 is not a valid temperature in transformers: it triggers
    # a warning with greedy decoding and a ValueError when sampling is enabled.
    pipe = pipeline(task_type, model=model, tokenizer=tokenizer, do_sample=False, max_new_tokens=15)

    eos_id = pipe.tokenizer.eos_token_id
    pad_id = pipe.tokenizer.pad_token_id

    y_pred = []
    for prompt in ds['chat']:
        result = pipe(prompt, eos_token_id=eos_id, pad_token_id=pad_id)
        # The pipeline returns prompt + continuation by default; keep the continuation.
        completion = result[0]['generated_text'][len(prompt):]
        count = find_number_in_string(completion)
        y_pred.append(count if count else completion)

    df['prediction'] = y_pred
    return df

class ScriptArguments:
    """Inference settings used by the prediction helpers.

    Instance attributes (set in ``__init__``):
        max_seq_length: Max sequence length for the model and dataset packing
            (3072 tokens). ``get_max_length`` caps the model's context size at
            this value.
        system_message: Instruction text that ``create_test_prompt_format``
            places before every sighting report.
    """

    def __init__(self):
        # NOTE(review): the third literal opens with four quote characters, so
        # the prompt contains a stray `"` ('...or other" people present...').
        # It is kept byte-for-byte because the fine-tuned model was presumably
        # trained on this exact text. Confirm against the training script
        # before correcting it.
        prompt_parts = (
            """You are a UFO sighting investigator. Users will give you a sighting report and you will infer the total count of human witnesses that saw the UFO. """,
            """This should be an integer number larger than zero. If no such information is explicitly present, infer a witness count of 1. Do not count animals that reacted, or other""",
            """" people present at the time and location of the sighting but did not actually see the UFO. Terms like "we", or "several", or "other people" count as 3 people.""",
        )
        self.max_seq_length = 3072
        self.system_message = "".join(prompt_parts)
        

Prediction


# Prepare the data as a pandas DataFrame with the sighting description in the
# 'text' column (JSON, JSON Lines or CSV; see read_data_file).
PATH_TO_DATA_FILE = "path/to/sightings.csv"  # replace with the path to your data file

df_test = read_data_file(PATH_TO_DATA_FILE)

script_args = ScriptArguments()

merged_model_name = "e-labs/witness_count_ft_mistral_7b_v0.1_instruct"
task_type = 'text-generation'

tokenizer = AutoTokenizer.from_pretrained(merged_model_name)
# NOTE: loaded without a device/device_map argument, the model stays on the
# CPU; move it to a GPU (e.g. model.to("cuda")) for practical inference speed.
model = AutoModelForCausalLM.from_pretrained(merged_model_name)

df = predict(model, tokenizer, df_test, labels=None, args=script_args, task_type=task_type)

The prediction is stored in the 'prediction' column of the returned DataFrame. It is normally a string containing the integer witness count, such as "1" or "2". Because the model is an LLM, it may occasionally generate text that is not an integer. In that case the column holds the raw generated text, and you should default to "1".

Training and evaluation data

https://wandb.ai/enigmalabs/witness_count_ft_mistral_instruct_v0.1/runs/vi6bu5bk

Training procedure

Training hyperparameters

The following hyperparameters were used during training:

  • learning_rate: 0.0002
  • train_batch_size: 1
  • eval_batch_size: 8
  • seed: 42
  • gradient_accumulation_steps: 2
  • total_train_batch_size: 2
  • optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
  • lr_scheduler_type: constant
  • lr_scheduler_warmup_ratio: 0.03
  • num_epochs: 3

Training results

https://wandb.ai/enigmalabs/witness_count_ft_mistral_instruct_v0.1/runs/mqvknq8t

Accuracy Metrics

Overall micro accuracy: 0.875

Accuracy by witness count range

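| Witness count range | Support | Weight | Accuracy |
|---|---|---|---|
| 1 | 132 | 0.498 | 0.902 |
| 2-5 | 123 | 0.464 | 0.862 |
| 6-10 | 7 | 0.026 | 0.714 |
| >10 | 3 | 0.011 | 0.667 |

Weight is each range's share of the 265 evaluation examples (support / 265).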

Macro accuracy (unweighted mean over ranges): 0.7860640375884278

Weighted accuracy (weighted by support): 0.8754716981132076

Framework versions

  • PEFT 0.7.2.dev0
  • Transformers 4.36.2
  • Pytorch 2.1.2+cu121
  • Datasets 2.16.1
  • Tokenizers 0.15.1
Downloads last month
0
Safetensors
Model size
7.24B params
Tensor type
F32
·
Inference API
Unable to determine this model’s pipeline type. Check the docs .

Model tree for e-labs/witness_count_ft_mistral_7b_v0.1_instruct

Adapter
(340)
this model