XumengWen commited on
Commit
90fdb52
·
1 Parent(s): 91af52a

update model card

Browse files
Files changed (1) hide show
  1. README.md +156 -0
README.md CHANGED
@@ -1,3 +1,159 @@
1
  ---
2
  license: mit
 
 
 
 
 
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: mit
3
+ license_link: https://github.com/microsoft/Industrial-Foundation-Models/blob/main/LICENSE
4
+
5
+ tags:
6
+ - llm
7
+ - transfer learning
8
+ - in-context learning
9
+ - tabular data
10
  ---
11
+
12
+ ## Model Summary
13
+
14
+ The model is finetuned on over 380 tabular datasets based on LLaMA-2, designed to process a variety of industrial data, including commerce, healthcare, energy, and sustainability. The model belongs to the IFMs family, including two versions [7B](https://huggingface.co/microsoft/LLaMA-2-7b-GTL-Delta) and [13B](https://huggingface.co/microsoft/LLaMA-2-13b-GTL-Delta).
15
+
16
+ The Industrial Foundation Model is designed to accept language format data samples from various domains as input prompts. The input prompt should contain relevant information for the task at hand, such as context data, specific task instructions, or direct questions. In response to the input prompts, the model generates predictive answers. Depending on the nature of the task instruction in the input, the model can support both classification and regression tasks.
17
+
18
+ Resources and Technical Documentation:
19
+
20
+ + [IFMs Microsoft Repo](https://github.com/microsoft/Industrial-Foundation-Models)
21
+ + [Paper](https://arxiv.org/abs/2310.07338)
22
+
23
+ ## Intended Uses
24
+
25
+ **Primary use cases**
26
+
27
+ This model is designed to process and analyze diverse tabular data from various industry sectors for accurate prediction of classification and regression tasks.
28
+
29
+ ### Tokenizer
30
+
31
+ LLaMA-2-GTL supports a vocabulary size of up to `32000` tokens, which is same as the base model LLaMA2.
32
+
33
+ ### Prompt Examples
34
+
35
+ Given the nature of the training data, the LLaMA-2-GTL series model is best suited for prompts using the prompt format as follows:
36
+ ```markdown
37
+ You are an expert in healthcare data analysis.
38
+ Based on the patient medical records, please predict the length of stay in the hospital.
39
+ I will supply multiple instances with features and the corresponding label for your reference.
40
+ Please refer to the table below for detailed descriptions of the features and label:
41
+ --- feature description ---
42
+ hemo: Indicator of hemoglobin count
43
+ hematocrit: Hematocrit level
44
+ neutrophils: Neutrophils count
45
+ sodium: Sodium level
46
+ glucose: Glucose level
47
+ bloodureanitro: Blood urea nitrogen level
48
+ creatinine: Creatinine level
49
+ bmi: Body Mass Index of patient
50
+ pulse: Pulse rate of patient
51
+ respiration: Respiration rate of patient
52
+ rcount: Count of patient visits
53
+ gender: Patient gender
54
+ dialysisrenalendstage: Indicator of end stage renal disease requiring dialysis
55
+ asthma: Indicator of asthma
56
+ irondef: Indicator of iron deficiency
57
+ pneum: Indicator of pneumonia
58
+ substancedependence: Indicator of substance dependence
59
+ psychologicaldisordermajor: Indicator of major psychological disorder
60
+ depress: Indicator of depression
61
+ psychother: Indicator of psychotherapy
62
+ fibrosisandother: Indicator of fibrosis and other similar conditions
63
+ malnutrition: Indicator of malnutrition
64
+ secondarydiagnosisnonicd9: Indicator of secondary diagnosis other than ICD9
65
+ facid: Identifier of facility where treatment was provided
66
+ vdate: Date of patient visit to hospital
67
+ discharged: Date of patient discharge
68
+ --- label description ---
69
+ lengthofstay: Length of patient stay at hospital in days
70
+ --- data ---
71
+ |hemo|hematocrit|neutrophils|sodium|glucose|bloodureanitro|creatinine|bmi|pulse|respiration|rcount|gender|dialysisrenalendstage|asthma|irondef|pneum|substancedependence|psychologicaldisordermajor|depress|psychother|fibrosisandother|malnutrition|secondarydiagnosisnonicd9|facid|vdate|discharged|lengthofstay|
72
+ |0.0|15.2|12.3|141.74|188.88|21.0|0.93|33.48|76|5|0|M|1|0|0|1|0|1|0|0|0|0|0|E|9/3/2012|9/11/2012|8|
73
+ |0.0|11.0|9.9|140.98|167.7|8.0|1.24|30.98|78|8|0|F|0|0|0|0|0|1|0|0|0|0|2|E|6/13/2012|6/16/2012|3|
74
+ |0.0|11.9|9.4|138.75|148.82|12.0|1.09|29.51|53|6|3|F|0|0|0|0|0|0|0|0|0|0|1|B|10/19/2012|10/24/2012|5|
75
+ |0.0|11.9|9.4|137.19|164.71|12.0|1.09|31.98|84|6|1|F|0|0|0|0|0|0|0|0|0|0|2|B|1/16/2012|1/18/2012|2|
76
+ |0.0|15.1|11.2|134.7|132.43|12.0|1.05|29.12|73|6|0|F|0|0|0|0|0|0|0|0|0|0|1|A|2/21/2012|2/22/2012|1|
77
+ |0.0|15.8|13.9|137.13|129.93|9.0|1.38|29.93|66|6|0|M|0|0|0|0|0|0|0|0|0|0|1|B|7/16/2012|7/18/2012|2|
78
+ |0.0|11.9|9.4|140.12|161.36|12.0|1.0|28.55|63|6|0|F|0|0|0|0|0|0|0|0|0|0|1|A|8/16/2012|8/17/2012|1|
79
+ |0.0|11.9|9.4|134.43|154.18|12.0|1.16|28.14|78|6|4|M|0|0|0|0|0|0|0|0|0|0|3|B|12/8/2012|12/14/2012|6|
80
+ |0.0|11.3|5.2|137.89|119.99|19.0|1.22|27.82|91|6|0|F|0|0|0|0|0|1|0|0|0|0|0|E|2/23/2012|2/26/2012|3|
81
+ |0.0|8.9|7.3|139.25|105.44|9.0|0.85|28.89|73|6|5+|M|0|0|0|0|0|0|0|0|0|0|4|B|7/18/2012|7/26/2012|8|
82
+ |1.0|8.1|5.6|138.4|103.73|21.0|1.26|29.05|74|6|5+|M|0|0|0|0|0|0|0|1|0|0|1|E|2/19/2012|2/28/2012|9|
83
+ |0.0|11.1|9.9|138.44|115.05|12.0|1.05|29.17|72|6|0|F|0|0|0|0|0|0|0|0|0|0|3|A|12/4/2012|12/5/2012|1|
84
+ |0.0|13.7|8.1|142.21|160.48|13.0|1.24|29.06|74|6|0|F|0|0|1|0|0|0|0|0|1|0|1|D|9/2/2012|9/5/2012|3|
85
+ |0.0|13.5|6.3|136.41|96.48|15.0|0.95|30.21|89|8|2|F|0|0|0|0|0|0|0|0|0|0|1|B|6/21/2012|6/27/2012|6|
86
+ |0.0|11.9|9.4|138.96|121.66|12.0|1.21|30.75|69|6|0|M|0|0|0|0|0|0|0|0|0|0|1|B|9/28/2012|9/29/2012|1|
87
+ |1.0|8.2|7.9|144.1|145.29|11.0|0.98|30.92|73|5|0|F|0|0|1|1|0|0|0|0|0|1|10|D|8/6/2012|8/12/2012|6|
88
+ |0.0|10.9|7.8|134.47|141.91|10.0|1.24|26.75|63|6|1|M|0|0|0|0|0|0|0|0|0|0|2|A|11/7/2012|11/9/2012|<MASK>|
89
+ Please use the supplied data to predict the <MASK> lengthofstay.
90
+ Answer: 2
91
+ ```
92
+
93
+ ### Recover full model checkpoint
94
+
95
+ Please follow the document to [prepare the model checkpoint](https://github.com/xumwen/Industrial-Foundation-Models/tree/merge_refactor?tab=readme-ov-file#prepare-the-model-checkpoint).
96
+
97
+ ### Sample inference code
98
+
99
+ This code shows how to quick start with running the model on a GPU:
100
+
101
+ ```python
102
+ import torch
103
+ from transformers import AutoModelForCausalLM, AutoTokenizer
104
+
105
+ # Load the checkpoint
106
+ model = AutoModelForCausalLM.from_pretrained(
107
+ CKPT_SAVE_PATH, # CKPT_SAVE_DIR/LLaMA-2-GTL/13B
108
+ torch_dtype=torch.bfloat16
109
+ )
110
+ tokenizer = AutoTokenizer.from_pretrained(CKPT_SAVE_PATH)
111
+
112
+ # Load example prompt
113
+ example_path = "data/prompt_examples/cls_in_context_table"
114
+ with open(example_path, "r") as f:
115
+ full_prompt = f.read()
116
+ answer = full_prompt.split('Answer:')[-1].strip()
117
+ prompt_without_answer = full_prompt[:-len(answer)]
118
+ print("Prompt:\n", prompt_without_answer)
119
+ print("Label:", answer)
120
+
121
+ # Inference
122
+ inputs = tokenizer(prompt_without_answer, return_tensors="pt")
123
+ input_ids = inputs['input_ids']
124
+ max_new_tokens = 10
125
+ outputs = model.generate(
126
+ input_ids=input_ids,
127
+ attention_mask=inputs['attention_mask'],
128
+ max_new_tokens=max_new_tokens
129
+ )
130
+
131
+ # Print the answer
132
+ print("Generate answer:", tokenizer.decode(outputs[0][input_ids.shape[-1]:]))
133
+ ```
134
+
135
+ ## Responsible AI Considerations
136
+
137
+ Like other language models, the LLaMA-GTL series models can potentially behave in ways that are unfair, unreliable, or offensive. Some of the risks and limitations to be aware of include:
138
+
139
+ + Data Bias: The model is trained on data that is not representative of the full range of industrial scenarios, and it may produce biased predictions. This could include over-representation of certain types of data or under-representation of others . Biased price forecasting could result in inaccurate budgeting, misplaced investments, and other business strategy misalignments. In the healthcare sector, it can perform tasks such as health risk assessments. Unrepresentative data could lead to skewed assessments and potentially compromise patient care. We recommend the users to have a clear understanding of the context and the underlying assumptions before drawing conclusions from the predictions.
140
+ + Algorithmic Bias: Despite the advanced learning algorithm used, there might be inherent biases in the algorithm itself which could influence the prediction outcomes. We strongly recommend that users verify the predictions with other sources or domain experts before making crucial decisions based on the model's output.
141
+ + Misinterpretation: There's a risk that users may misinterpret the predictions made by the model, leading to incorrect decisions.
142
+ + Our model may inherit vulnerabilities from the base model.
143
+
144
+ Developers should apply responsible AI best practices and are responsible for ensuring that a specific use case complies with relevant laws and regulations (e.g. privacy, trade, etc.). Important areas for consideration include:
145
+
146
+ + Allocation: Models may not be suitable for scenarios that could have consequential impact on legal status or the allocation of resources or life opportunities (ex: housing, employment, credit, etc.) without further assessments and additional debiasing techniques.
147
+ + High-Risk Scenarios: Developers should assess suitability of using models in high-risk scenarios where unfair, unreliable or offensive outputs might be extremely costly or lead to harm. This includes providing advice in sensitive or expert domains where accuracy and reliability are critical (ex: legal or health advice). Additional safeguards should be implemented at the application level according to the deployment context.
148
+ + Misinformation: Models may produce inaccurate information. Developers should follow transparency best practices and inform end-users they are interacting with an AI system. At the application level, developers can build feedback mechanisms and pipelines to ground responses in use-case specific, contextual information, a technique known as Retrieval Augmented Generation (RAG).
149
+ + Generation of Harmful Content: Developers should assess outputs for their context and use available safety classifiers or custom solutions appropriate for their use case.
150
+ + Misuse: Other forms of misuse such as fraud, spam, or malware production may be possible, and developers should ensure that their applications do not violate applicable laws and regulations.
151
+
152
+
153
+ ## Training and Evaluation
154
+
155
+ Please follow the [instruction](https://github.com/microsoft/Industrial-Foundation-Models) here to reproduce our [paper](https://arxiv.org/abs/2310.07338) results.
156
+
157
+ ## License
158
+
159
+ The model is licensed under the [MIT license](https://github.com/microsoft/Industrial-Foundation-Models/blob/main/LICENSE).