tags:
  - autotrain
  - text-generation-inference
  - text-generation
  - peft
library_name: transformers
widget:
  - messages:
      - role: user
        content: What is your favorite condiment?
license: other
datasets:
  - Pennlaine/entity-extraction
pipeline_tag: text-generation

Model Introduction

This model is based on mistral-community/Mistral-7B-v0.2 and has been fine-tuned with LoRA on an English-language instruction dataset.

It is intended to extract keywords and entities from unstructured data and return them as JSON.

Training Details

- GPU: NVIDIA_L4 x 4, 100 GB
- Time: 5 mins
- Platform: Google Cloud Vertex AI

Prompt Format

<s>[INST]{question}Answer the question, extract the {entities}, and return in Json format.[/INST]```json

Because this matches the training data, the model performs best when the prompt follows this format. You can also simply enter your question on its own, and most of the time that works too!🥰
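As a minimal sketch, the template above can be filled in with a small helper. The function name `build_prompt` and the `FENCE` constant are hypothetical and not part of this repository. The space inserted between the question and "Answer the question" follows the example prompts in this card rather than the literal template, so treat it as an assumption.

```python
from typing import Sequence

# Three backticks, assembled at runtime so this snippet stays easy to embed in Markdown.
FENCE = "`" * 3


def build_prompt(question: str, entities: Sequence[str]) -> str:
    """Fill the card's prompt template (hypothetical helper, not part of the model repo)."""
    entity_list = ", ".join(entities)
    # Assumption: one space separates the question from the instruction,
    # as in the example prompts shown in this card.
    return (
        f"<s>[INST]{question.strip()} Answer the question, "
        f"extract the {entity_list}, and return in Json format.[/INST]{FENCE}json"
    )


prompt = build_prompt(
    "Aspirin is a nonsteroidal anti-inflammatory drug. What is it used for?",
    ["drug", "drug class"],
)
print(prompt)
```

Ending the prompt with the opening JSON fence nudges the model to start its completion directly with the JSON object.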

E.g:

Input

prompt = '''<s>[INST]Type 2 Diabetes Mellitus is a chronic metabolic disorder characterized by insulin resistance and relative insulin deficiency. This condition leads to chronic hyperglycemia, which can cause significant damage to various body systems over time. Write a short summary of how to treat patients with diabetes. Answer the question, extract the disorder, type of disorder, causes, effect, ICD Code, and return in Json format.[/INST]```json'''

Output

[INST]Type 2 Diabetes Mellitus is a chronic metabolic disorder characterized by insulin resistance and relative insulin deficiency. This condition leads to chronic hyperglycemia, which can cause significant damage to various body systems over time. Write a short summary of how to treat patients with diabetes. Answer the question, extract the disorder, type of disorder, causes, effect, ICD Code, and return in Json format.[/INST]```json
{
    "question": "Type 2 Diabetes Mellitus is a chronic metabolic disorder characterized by insulin resistance and relative insulin deficiency. This condition leads to chronic hyperglycemia, which can cause significant damage to various body systems over time. Write a short summary of how to treat patients with diabetes.",
    "answer": "Treatment for diabetes typically involves a combination of lifestyle modifications, medication, and insulin therapy.",
    "entities": [
        {
            "Disorder": "Type 2 Diabetes Mellitus"
        },
        {
            "Type of Disorder": "Chronic Metabolic Disorder"
        },
        {
            "Causes": "Insulin Resistance and Relative Insulin Deficiency"
        },
        {
            "Effect": "Chronic Hyperglycemia"
        },
        {
            "ICD Code": "E11"
        }
    ]
}
```

Usage

from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "Pennlaine/Mistral-7B-v02-Entity-Extraction"

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)

prompt = '''<s>[INST]John Doe, a 45-year-old male, presents with a two-week history of progressive chest pain and shortness of breath. The chest pain is described as a tight, squeezing sensation located centrally and radiating to the left arm and jaw. It is aggravated by physical exertion and alleviated by rest. The patient reports associated symptoms of palpitations, diaphoresis, and nausea. He has a history of hypertension, hyperlipidemia, and type 2 diabetes mellitus. His family history is significant for myocardial infarction in his father at age 60. The patient has a 20-pack-year smoking history but quit 5 years ago. He occasionally consumes alcohol and leads a sedentary lifestyle with a diet high in processed foods and red meat. Current medications include metformin, lisinopril, atorvastatin, and aspirin.
Extract the name, age, symptoms, medical history, family history, father death age, medication history  and return in Json format.[/INST]```json'''

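# Upper bound on generated tokens; 512 is an assumed value, adjust it to the size of the expected JSON.
max_new_tokens = 512

inputs = tokenizer.encode(prompt, return_tensors='pt')
outputs = model.generate(inputs, max_new_tokens=max_new_tokens, num_return_sequences=1)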
response = tokenizer.decode(outputs[0], skip_special_tokens=True)

print(response)
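
The decoded `response` contains the prompt followed by the generated JSON, so the JSON usually has to be cut out before it can be used. The following is a rough sketch of one way to do that, assuming the completion follows the output format shown earlier (an opening JSON fence, the object, then a closing fence). `extract_json` and `FENCE` are hypothetical names, not part of this repository.

```python
import json

# Three backticks, assembled at runtime so this snippet stays easy to embed in Markdown.
FENCE = "`" * 3


def extract_json(response: str) -> dict:
    """Parse the JSON object that follows the first json fence in a decoded response."""
    marker = FENCE + "json"
    start = response.find(marker)
    if start == -1:
        raise ValueError("no json fence found in response")
    body = response[start + len(marker):]
    # The closing fence may be missing if generation stopped at max_new_tokens.
    end = body.find(FENCE)
    if end != -1:
        body = body[:end]
    # json.loads raises json.JSONDecodeError (a ValueError) on truncated or invalid JSON.
    return json.loads(body)


# Synthetic response used only to illustrate the expected shape.
sample = (
    "[INST]Extract the name and return in Json format.[/INST]"
    + FENCE + 'json\n{"entities": [{"Name": "John Doe"}]}\n' + FENCE
)
print(extract_json(sample)["entities"])
```

If the model stops before closing the object, parsing fails with a `ValueError`; raising `max_new_tokens` is one thing to try in that case.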