tosi-n7 committed on
Commit
d8ffdc4
1 Parent(s): 9adc765

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitattributes +119 -0
  2. .gitignore +142 -0
  3. README.md +100 -8
  4. chimp/.gitignore +160 -0
  5. chimp/requirements.txt +21 -0
  6. chimp/src/config.py +81 -0
  7. chimp/src/dataset.py +31 -0
  8. chimp/src/model.py +62 -0
  9. chimp/src/predict.py +87 -0
  10. chimp/src/train.py +66 -0
  11. data/alm_task_data.csv +0 -0
  12. data/jack_line_item_ner_task.csv +0 -0
  13. data/jack_line_item_ner_task_v2.csv +0 -0
  14. data/line_item_and_alm_data.json +0 -0
  15. data/line_item_and_alm_data_v1.json +3 -0
  16. data_prep.py +28 -0
  17. demo.sh +5 -0
  18. gorilla/__pycache__/llama_attn_replace.cpython-310.pyc +0 -0
  19. gorilla/__pycache__/llama_attn_replace_sft.cpython-310.pyc +0 -0
  20. gorilla/api.py +0 -0
  21. gorilla/app.py +211 -0
  22. gorilla/code_interpreter.py +117 -0
  23. gorilla/ds_configs/stage2.json +23 -0
  24. gorilla/ds_configs/stage3.json +49 -0
  25. gorilla/eval.py +175 -0
  26. gorilla/fine-tune.py +206 -0
  27. gorilla/get_trainable_weights.py +37 -0
  28. gorilla/infer.py +143 -0
  29. gorilla/llama_attn_replace.py +477 -0
  30. gorilla/llama_attn_replace_sft.py +483 -0
  31. gorilla/merge_lora_weights_and_save_hf_model.py +100 -0
  32. gorilla/push_to_hub.py +5 -0
  33. gorilla/requirements.txt +19 -0
  34. gorilla/stream_jack.py +183 -0
  35. gorilla/streaming_llm/__init__.py +0 -0
  36. gorilla/streaming_llm/__pycache__/__init__.cpython-310.pyc +0 -0
  37. gorilla/streaming_llm/__pycache__/enable_streaming_llm.cpython-310.pyc +0 -0
  38. gorilla/streaming_llm/__pycache__/kv_cache.cpython-310.pyc +0 -0
  39. gorilla/streaming_llm/__pycache__/utils.cpython-310.pyc +0 -0
  40. gorilla/streaming_llm/enable_streaming_llm.py +38 -0
  41. gorilla/streaming_llm/kv_cache.py +119 -0
  42. gorilla/streaming_llm/pos_shift/__init__.py +0 -0
  43. gorilla/streaming_llm/pos_shift/__pycache__/__init__.cpython-310.pyc +0 -0
  44. gorilla/streaming_llm/pos_shift/__pycache__/modify_llama.cpython-310.pyc +0 -0
  45. gorilla/streaming_llm/pos_shift/modify_falcon.py +162 -0
  46. gorilla/streaming_llm/pos_shift/modify_llama.py +311 -0
  47. gorilla/streaming_llm/utils.py +112 -0
  48. gorilla/style.css +16 -0
  49. gorilla/supervised-fine-tune-qlora.py +345 -0
  50. gorilla/supervised-fine-tune.py +330 -0
.gitattributes CHANGED
@@ -33,3 +33,122 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ data/line_item_and_alm_data_v1.json filter=lfs diff=lfs merge=lfs -text
37
+ venv/bin/python filter=lfs diff=lfs merge=lfs -text
38
+ venv/bin/python3 filter=lfs diff=lfs merge=lfs -text
39
+ venv/bin/python3.10 filter=lfs diff=lfs merge=lfs -text
40
+ venv/lib/python3.10/site-packages/Pillow.libs/libfreetype-82733d78.so.6.20.1 filter=lfs diff=lfs merge=lfs -text
41
+ venv/lib/python3.10/site-packages/Pillow.libs/libharfbuzz-e3b74c67.so.0.60821.0 filter=lfs diff=lfs merge=lfs -text
42
+ venv/lib/python3.10/site-packages/aiohttp/_http_parser.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
43
+ venv/lib/python3.10/site-packages/altair/vegalite/v5/schema/__pycache__/core.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text
44
+ venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda110.so filter=lfs diff=lfs merge=lfs -text
45
+ venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda110_nocublaslt.so filter=lfs diff=lfs merge=lfs -text
46
+ venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda111.so filter=lfs diff=lfs merge=lfs -text
47
+ venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda111_nocublaslt.so filter=lfs diff=lfs merge=lfs -text
48
+ venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda114.so filter=lfs diff=lfs merge=lfs -text
49
+ venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda114_nocublaslt.so filter=lfs diff=lfs merge=lfs -text
50
+ venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda115.so filter=lfs diff=lfs merge=lfs -text
51
+ venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda115_nocublaslt.so filter=lfs diff=lfs merge=lfs -text
52
+ venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda117.so filter=lfs diff=lfs merge=lfs -text
53
+ venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda117_nocublaslt.so filter=lfs diff=lfs merge=lfs -text
54
+ venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda118.so filter=lfs diff=lfs merge=lfs -text
55
+ venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda118_nocublaslt.so filter=lfs diff=lfs merge=lfs -text
56
+ venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda120.so filter=lfs diff=lfs merge=lfs -text
57
+ venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda120_nocublaslt.so filter=lfs diff=lfs merge=lfs -text
58
+ venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda121.so filter=lfs diff=lfs merge=lfs -text
59
+ venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda121_nocublaslt.so filter=lfs diff=lfs merge=lfs -text
60
+ venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda122.so filter=lfs diff=lfs merge=lfs -text
61
+ venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda122_nocublaslt.so filter=lfs diff=lfs merge=lfs -text
62
+ venv/lib/python3.10/site-packages/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_cython.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
63
+ venv/lib/python3.10/site-packages/debugpy/_vendored/pydevd/_pydevd_frame_eval/pydevd_frame_evaluator.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
64
+ venv/lib/python3.10/site-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
65
+ venv/lib/python3.10/site-packages/fontTools/feaLib/lexer.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
66
+ venv/lib/python3.10/site-packages/fontTools/misc/bezierTools.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
67
+ venv/lib/python3.10/site-packages/fontTools/pens/momentsPen.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
68
+ venv/lib/python3.10/site-packages/fontTools/qu2cu/qu2cu.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
69
+ venv/lib/python3.10/site-packages/fontTools/varLib/iup.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
70
+ venv/lib/python3.10/site-packages/gradio/frpc_linux_amd64_v0.2 filter=lfs diff=lfs merge=lfs -text
71
+ venv/lib/python3.10/site-packages/gradio/templates/cdn/assets/Index-5c805b1c.js.map filter=lfs diff=lfs merge=lfs -text
72
+ venv/lib/python3.10/site-packages/gradio/templates/frontend/assets/Index-62000a79.js.map filter=lfs diff=lfs merge=lfs -text
73
+ venv/lib/python3.10/site-packages/kiwisolver/_cext.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
74
+ venv/lib/python3.10/site-packages/matplotlib/_image.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
75
+ venv/lib/python3.10/site-packages/matplotlib/_path.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
76
+ venv/lib/python3.10/site-packages/matplotlib/_qhull.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
77
+ venv/lib/python3.10/site-packages/matplotlib/backends/_backend_agg.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
78
+ venv/lib/python3.10/site-packages/matplotlib/ft2font.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
79
+ venv/lib/python3.10/site-packages/numpy/core/_multiarray_umath.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
80
+ venv/lib/python3.10/site-packages/numpy/core/_simd.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
81
+ venv/lib/python3.10/site-packages/numpy.libs/libgfortran-040039e1.so.5.0.0 filter=lfs diff=lfs merge=lfs -text
82
+ venv/lib/python3.10/site-packages/numpy.libs/libopenblas64_p-r0-0cf96a72.3.23.dev.so filter=lfs diff=lfs merge=lfs -text
83
+ venv/lib/python3.10/site-packages/nvfuser/_C.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
84
+ venv/lib/python3.10/site-packages/nvidia/cublas/lib/libcublas.so.12 filter=lfs diff=lfs merge=lfs -text
85
+ venv/lib/python3.10/site-packages/nvidia/cublas/lib/libcublasLt.so.12 filter=lfs diff=lfs merge=lfs -text
86
+ venv/lib/python3.10/site-packages/nvidia/cuda_cupti/lib/libcheckpoint.so filter=lfs diff=lfs merge=lfs -text
87
+ venv/lib/python3.10/site-packages/nvidia/cuda_cupti/lib/libcupti.so.12 filter=lfs diff=lfs merge=lfs -text
88
+ venv/lib/python3.10/site-packages/nvidia/cuda_cupti/lib/libnvperf_host.so filter=lfs diff=lfs merge=lfs -text
89
+ venv/lib/python3.10/site-packages/nvidia/cuda_cupti/lib/libnvperf_target.so filter=lfs diff=lfs merge=lfs -text
90
+ venv/lib/python3.10/site-packages/nvidia/cuda_nvrtc/lib/libnvrtc-builtins.so.12.1 filter=lfs diff=lfs merge=lfs -text
91
+ venv/lib/python3.10/site-packages/nvidia/cuda_nvrtc/lib/libnvrtc.so.12 filter=lfs diff=lfs merge=lfs -text
92
+ venv/lib/python3.10/site-packages/nvidia/cudnn/lib/libcudnn_adv_infer.so.8 filter=lfs diff=lfs merge=lfs -text
93
+ venv/lib/python3.10/site-packages/nvidia/cudnn/lib/libcudnn_adv_train.so.8 filter=lfs diff=lfs merge=lfs -text
94
+ venv/lib/python3.10/site-packages/nvidia/cudnn/lib/libcudnn_cnn_infer.so.8 filter=lfs diff=lfs merge=lfs -text
95
+ venv/lib/python3.10/site-packages/nvidia/cudnn/lib/libcudnn_cnn_train.so.8 filter=lfs diff=lfs merge=lfs -text
96
+ venv/lib/python3.10/site-packages/nvidia/cudnn/lib/libcudnn_ops_infer.so.8 filter=lfs diff=lfs merge=lfs -text
97
+ venv/lib/python3.10/site-packages/nvidia/cudnn/lib/libcudnn_ops_train.so.8 filter=lfs diff=lfs merge=lfs -text
98
+ venv/lib/python3.10/site-packages/nvidia/cufft/lib/libcufft.so.11 filter=lfs diff=lfs merge=lfs -text
99
+ venv/lib/python3.10/site-packages/nvidia/cufft/lib/libcufftw.so.11 filter=lfs diff=lfs merge=lfs -text
100
+ venv/lib/python3.10/site-packages/nvidia/curand/lib/libcurand.so.10 filter=lfs diff=lfs merge=lfs -text
101
+ venv/lib/python3.10/site-packages/nvidia/cusolver/lib/libcusolver.so.11 filter=lfs diff=lfs merge=lfs -text
102
+ venv/lib/python3.10/site-packages/nvidia/cusolver/lib/libcusolverMg.so.11 filter=lfs diff=lfs merge=lfs -text
103
+ venv/lib/python3.10/site-packages/nvidia/cusparse/lib/libcusparse.so.12 filter=lfs diff=lfs merge=lfs -text
104
+ venv/lib/python3.10/site-packages/nvidia/nccl/lib/libnccl.so.2 filter=lfs diff=lfs merge=lfs -text
105
+ venv/lib/python3.10/site-packages/nvidia/nvjitlink/lib/libnvJitLink.so.12 filter=lfs diff=lfs merge=lfs -text
106
+ venv/lib/python3.10/site-packages/pandas/_libs/algos.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
107
+ venv/lib/python3.10/site-packages/pandas/_libs/groupby.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
108
+ venv/lib/python3.10/site-packages/pandas/_libs/hashtable.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
109
+ venv/lib/python3.10/site-packages/pandas/_libs/interval.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
110
+ venv/lib/python3.10/site-packages/pandas/_libs/join.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
111
+ venv/lib/python3.10/site-packages/pandas/_libs/tslibs/offsets.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
112
+ venv/lib/python3.10/site-packages/pyarrow/_compute.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
113
+ venv/lib/python3.10/site-packages/pyarrow/_dataset.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
114
+ venv/lib/python3.10/site-packages/pyarrow/_flight.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
115
+ venv/lib/python3.10/site-packages/pyarrow/lib.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
116
+ venv/lib/python3.10/site-packages/pyarrow/libarrow.so.1400 filter=lfs diff=lfs merge=lfs -text
117
+ venv/lib/python3.10/site-packages/pyarrow/libarrow_acero.so.1400 filter=lfs diff=lfs merge=lfs -text
118
+ venv/lib/python3.10/site-packages/pyarrow/libarrow_dataset.so.1400 filter=lfs diff=lfs merge=lfs -text
119
+ venv/lib/python3.10/site-packages/pyarrow/libarrow_flight.so.1400 filter=lfs diff=lfs merge=lfs -text
120
+ venv/lib/python3.10/site-packages/pyarrow/libarrow_python.so filter=lfs diff=lfs merge=lfs -text
121
+ venv/lib/python3.10/site-packages/pyarrow/libarrow_substrait.so.1400 filter=lfs diff=lfs merge=lfs -text
122
+ venv/lib/python3.10/site-packages/pyarrow/libparquet.so.1400 filter=lfs diff=lfs merge=lfs -text
123
+ venv/lib/python3.10/site-packages/pydantic_core/_pydantic_core.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
124
+ venv/lib/python3.10/site-packages/pyzmq.libs/libsodium-cb25555f.so.23.3.0 filter=lfs diff=lfs merge=lfs -text
125
+ venv/lib/python3.10/site-packages/pyzmq.libs/libzmq-f468291a.so.5.2.4 filter=lfs diff=lfs merge=lfs -text
126
+ venv/lib/python3.10/site-packages/regex/_regex.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
127
+ venv/lib/python3.10/site-packages/rpds/rpds.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
128
+ venv/lib/python3.10/site-packages/safetensors/_safetensors_rust.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
129
+ venv/lib/python3.10/site-packages/scipy/fft/_pocketfft/pypocketfft.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
130
+ venv/lib/python3.10/site-packages/scipy/linalg/_flapack.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
131
+ venv/lib/python3.10/site-packages/scipy/misc/face.dat filter=lfs diff=lfs merge=lfs -text
132
+ venv/lib/python3.10/site-packages/scipy/optimize/_highs/_highs_wrapper.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
133
+ venv/lib/python3.10/site-packages/scipy/sparse/_sparsetools.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
134
+ venv/lib/python3.10/site-packages/scipy/spatial/_ckdtree.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
135
+ venv/lib/python3.10/site-packages/scipy/spatial/_qhull.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
136
+ venv/lib/python3.10/site-packages/scipy/special/_ufuncs.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
137
+ venv/lib/python3.10/site-packages/scipy/special/cython_special.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
138
+ venv/lib/python3.10/site-packages/scipy/stats/_unuran/unuran_wrapper.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
139
+ venv/lib/python3.10/site-packages/scipy.libs/libgfortran-040039e1.so.5.0.0 filter=lfs diff=lfs merge=lfs -text
140
+ venv/lib/python3.10/site-packages/scipy.libs/libopenblasp-r0-23e5df77.3.21.dev.so filter=lfs diff=lfs merge=lfs -text
141
+ venv/lib/python3.10/site-packages/sentencepiece/_sentencepiece.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
142
+ venv/lib/python3.10/site-packages/tokenizers/tokenizers.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
143
+ venv/lib/python3.10/site-packages/torch/bin/nvfuser_tests filter=lfs diff=lfs merge=lfs -text
144
+ venv/lib/python3.10/site-packages/torch/bin/protoc filter=lfs diff=lfs merge=lfs -text
145
+ venv/lib/python3.10/site-packages/torch/bin/protoc-3.13.0.0 filter=lfs diff=lfs merge=lfs -text
146
+ venv/lib/python3.10/site-packages/torch/lib/libc10.so filter=lfs diff=lfs merge=lfs -text
147
+ venv/lib/python3.10/site-packages/torch/lib/libnvfuser_codegen.so filter=lfs diff=lfs merge=lfs -text
148
+ venv/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so filter=lfs diff=lfs merge=lfs -text
149
+ venv/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so filter=lfs diff=lfs merge=lfs -text
150
+ venv/lib/python3.10/site-packages/torch/lib/libtorch_cuda_linalg.so filter=lfs diff=lfs merge=lfs -text
151
+ venv/lib/python3.10/site-packages/torch/lib/libtorch_python.so filter=lfs diff=lfs merge=lfs -text
152
+ venv/lib/python3.10/site-packages/triton/_C/libtriton.so filter=lfs diff=lfs merge=lfs -text
153
+ venv/lib/python3.10/site-packages/triton/third_party/cuda/bin/ptxas filter=lfs diff=lfs merge=lfs -text
154
+ venv/lib/python3.10/site-packages/yaml/_yaml.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,142 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ .python-version
86
+
87
+ # pipenv
88
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
90
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
91
+ # install all needed dependencies.
92
+ #Pipfile.lock
93
+
94
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95
+ __pypackages__/
96
+
97
+ # Celery stuff
98
+ celerybeat-schedule
99
+ celerybeat.pid
100
+
101
+ # SageMath parsed files
102
+ *.sage.py
103
+
104
+ # Environments
105
+ .env
106
+ .venv
107
+ env/
108
+ venv/
109
+ ENV/
110
+ env.bak/
111
+ venv.bak/
112
+
113
+ # Spyder project settings
114
+ .spyderproject
115
+ .spyproject
116
+
117
+ # Rope project settings
118
+ .ropeproject
119
+
120
+ # mkdocs documentation
121
+ /site
122
+
123
+ # mypy
124
+ .mypy_cache/
125
+ .dmypy.json
126
+ dmypy.json
127
+
128
+ # Pyre type checker
129
+ .pyre/
130
+
131
+
132
+ # Json files
133
+ *.json
134
+ *.misc
135
+
136
+ # model files
137
+ models/*
138
+ gfpgan/weights/*
139
+ test_cors.html
140
+ jack-alm/
141
+ jack-alm-13b-8k-hf/
142
+ cache/
README.md CHANGED
@@ -1,12 +1,104 @@
1
  ---
2
- title: Ark Instruct Line Item
3
- emoji: 🏢
4
- colorFrom: purple
5
- colorTo: pink
6
  sdk: gradio
7
- sdk_version: 4.7.1
8
- app_file: app.py
9
- pinned: false
10
  ---
 
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
  ---
2
+ title: ark-instruct-line-item
3
+ app_file: /home/tosi-n/ark/gorilla/app.py
4
  sdk: gradio
5
+ sdk_version: 4.0.2
6
  ---
7
+ # ARK - Jack's Accounting ALM Training Framework
8
 
9
+ This is the base pipeline that our task-specific repos are layered on.
10
+
11
+ > Note: this README currently covers the gorilla module, which is the pipeline for flexible long-context-window training. The chimp module, used for default-context-window training, has its own README in its folder.
12
+
13
+ ## Usage Requirements
14
+ To download and use the [pre-trained weights](#pre-trained-weights) you will need:
15
+ 1. A Hugging Face (HF) account with a valid email. Note: the email used for HF must also be used for the license agreement.
16
+ 2. Acceptance of the Meta [license and acceptable use policy](https://ai.meta.com/resources/models-and-libraries/llama-downloads/)
17
+
18
+
19
+ ## Installation and Quick Guide
20
+ To install and run the application:
21
+ 1. Clone the repository to your local machine using `git clone` and the URL of this project.
22
+ 2. Install the prerequisites by running:
23
+ ```
24
+ pip install -r requirements.txt
25
+ pip install flash-attn --no-build-isolation
26
+ ```
27
+ 3. Fine-tune the pre-trained weights with QLoRA, LoRA, or full fine-tuning. Update the args in the bash script before running (QLoRA):
28
+ ```
29
+ sh runner.sh
30
+ ```
31
+ 4. Merge and extract the trainable LoRA weights. Update the args in the bash script where necessary, then run:
32
+ ```
33
+ sh process_wt.sh
34
+ ```
35
+ 5. Test your model via a terminal chat. Update the args in the bash script where necessary, then run:
36
+ ```
37
+ sh stream.sh
38
+ ```
39
+ 6. Test your model in the Gradio UI. Update the args in the bash script where necessary, then run:
40
+ ```
41
+ sh demo.sh
42
+ ```
43
+
44
+
45
+ ## Training args for Full and LoRA
46
+
47
+ ### Fine-tuning
48
+ ```
49
+ torchrun --nproc_per_node=8 fine-tune.py \
50
+ --model_name_or_path path_to/Llama-2-7b-hf \
51
+ --bf16 True \
52
+ --output_dir path_to_saving_checkpoints \
53
+ --cache_dir path_to_cache \
54
+ --model_max_length 8192 \
55
+ --use_flash_attn True \
56
+ --low_rank_training False \
57
+ --num_train_epochs 1 \
58
+ --per_device_train_batch_size 1 \
59
+ --per_device_eval_batch_size 2 \
60
+ --gradient_accumulation_steps 8 \
61
+ --evaluation_strategy "no" \
62
+ --save_strategy "steps" \
63
+ --save_steps 1000 \
64
+ --save_total_limit 2 \
65
+ --learning_rate 2e-5 \
66
+ --weight_decay 0.0 \
67
+ --warmup_steps 20 \
68
+ --lr_scheduler_type "constant_with_warmup" \
69
+ --logging_steps 1 \
70
+ --deepspeed "ds_configs/stage2.json" \
71
+ --tf32 True \
72
+ --max_steps 1000
73
+ ```
74
+
75
+
76
+ ### Supervised Fine-tuning
77
+ ```
78
+ torchrun --nproc_per_node=8 supervised-fine-tune.py \
79
+ --model_name_or_path path_to_Llama2_chat_models \
80
+ --bf16 True \
81
+ --output_dir path_to_saving_checkpoints \
82
+ --model_max_length 32768 \
83
+ --use_flash_attn True \
84
+ --data_path LongAlpaca-12k.json \
85
+ --low_rank_training True \
86
+ --num_train_epochs 3 \
87
+ --per_device_train_batch_size 1 \
88
+ --per_device_eval_batch_size 2 \
89
+ --gradient_accumulation_steps 1 \
90
+ --evaluation_strategy "no" \
91
+ --save_strategy "steps" \
92
+ --save_steps 1000 \
93
+ --save_total_limit 2 \
94
+ --learning_rate 2e-5 \
95
+ --weight_decay 0.0 \
96
+ --warmup_steps 20 \
97
+ --lr_scheduler_type "constant_with_warmup" \
98
+ --logging_steps 1 \
99
+ --deepspeed "ds_configs/stage2.json" \
100
+ --tf32 True
101
+ ```
102
+
103
+ ## Evaluation
104
+ ### Perplexity Validation
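The README stops at this heading in this commit. For orientation, here is a minimal, hedged sketch (not part of the repo) of what perplexity validation over a held-out text file can look like for a causal LM; the model and data paths below are placeholders, not files from this repository.

```python
# Hedged sketch of perplexity validation; paths are placeholders.
import math
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "path_to_saving_checkpoints"   # assumed: a merged/fine-tuned model directory
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map="auto")
model.eval()

text = open("validation.txt").read()        # assumed: any held-out validation text
ids = tokenizer(text, return_tensors="pt").input_ids.to(model.device)

# Slide a fixed-size window over the tokens and average the language-modelling loss.
window, losses = 2048, []
with torch.no_grad():
    for start in range(0, ids.size(1) - 1, window):
        chunk = ids[:, start : start + window + 1]
        out = model(chunk, labels=chunk)    # labels=input computes the shifted LM loss
        losses.append(out.loss.item())

print(f"perplexity: {math.exp(sum(losses) / len(losses)):.2f}")
```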
chimp/.gitignore ADDED
@@ -0,0 +1,160 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
chimp/requirements.txt ADDED
@@ -0,0 +1,21 @@
1
+ accelerate
2
+ appdirs
3
+ bert_score
4
+ bitsandbytes
5
+ black
6
+ black[jupyter]
7
+ datasets
8
+ deepspeed
9
+ einops
10
+ fire
11
+ flask
12
+ gradio
13
+ huggingface-hub
14
+ jsonlines
15
+ loralib
16
+ peft
17
+ pycuda
18
+ sentencepiece
19
+ spacy_fastlang
20
+ transformers
21
+ # triton
chimp/src/config.py ADDED
@@ -0,0 +1,81 @@
1
+ # Used for multi-gpu
2
+ local_rank = -1
3
+ per_device_train_batch_size = 4
4
+ per_device_eval_batch_size = 4
5
+ gradient_accumulation_steps = 1
6
+ learning_rate = 2e-4
7
+ max_grad_norm = 0.3
8
+ weight_decay = 0.001
9
+ lora_alpha = 16
10
+ lora_dropout = 0.1
11
+ lora_r = 64
12
+ max_seq_length = None
13
+
14
+ # The model that you want to train from the Hugging Face hub
15
+ model_name = "guardrail/llama-2-7b-guanaco-instruct-sharded"
16
+
17
+ # Fine-tuned model name
18
+ new_model = "llama-2-7b-custom-accountant"
19
+
20
+ # The instruction dataset to use
21
+ # dataset_name = "databricks/databricks-dolly-15k"
22
+
23
+ # Activate 4-bit precision base model loading
24
+ use_4bit = True
25
+
26
+ # Activate nested quantization for 4-bit base models
27
+ use_nested_quant = False
28
+
29
+ # Compute dtype for 4-bit base models
30
+ bnb_4bit_compute_dtype = "float16"
31
+
32
+ # Quantization type (fp4 or nf4)
33
+ bnb_4bit_quant_type = "nf4"
34
+
35
+ # Number of training epochs
36
+ num_train_epochs = 2
37
+
38
+ # Enable fp16 training (set bf16 to True instead on an A100)
39
+ fp16 = False
40
+
41
+ # Enable bf16 training
42
+ bf16 = False
43
+
44
+ # Pack multiple short examples into one sequence when creating the dataset
45
+ packing = False
46
+
47
+ # Enable gradient checkpointing
48
+ gradient_checkpointing = True
49
+
50
+ # Optimizer to use, original is paged_adamw_32bit
51
+ optim = "paged_adamw_32bit"
52
+
53
+ # Learning rate schedule (constant a bit better than cosine, and has advantage for analysis)
54
+ lr_scheduler_type = "cosine"
55
+
56
+ # Maximum number of optimizer update steps (-1 means the limit is set by num_train_epochs)
57
+ max_steps = -1
58
+
59
+ # Fraction of steps to do a warmup for
60
+ warmup_ratio = 0.03
61
+
62
+ # Group sequences into batches with same length (saves memory and speeds up training considerably)
63
+ group_by_length = True
64
+
65
+ # Save a checkpoint every X update steps
66
+ save_steps = 10
67
+
68
+ # Log every X update steps
69
+ logging_steps = 1
70
+
71
+ # The output directory where the model predictions and checkpoints will be written
72
+ output_dir = "../model_files/"
73
+
74
+ # Load the entire model on the GPU 0
75
+ device_map = {"": 0}
76
+
77
+ # Visualize training
78
+ report_to = "tensorboard"
79
+
80
+ # Tensorboard logs
81
+ tb_log_dir = "../logs/"
chimp/src/dataset.py ADDED
@@ -0,0 +1,31 @@
1
+ class CustomDataset:
2
+ def __init__(self, data):
3
+ self.features = ['instruction', 'context', 'response']
4
+ self.num_rows = len(data)
5
+ self.data = data
6
+
7
+ def __getitem__(self, idx):
8
+ if idx < 0 or idx >= self.num_rows:
9
+ raise IndexError("Index out of range")
10
+ return {
11
+ 'instruction': self.data[idx]['instruction'],
12
+ 'context': self.data[idx]['context'],
13
+ 'response': self.data[idx]['response']
14
+ }
15
+
16
+ def __repr__(self):
17
+ return f"Dataset({{'features': {self.features}, 'num_rows': {self.num_rows}}})"
18
+
19
+
20
+ def format_data(sample):
21
+ instruction = f"<s>[INST] {sample['instruction']}"
22
+ context = f"Here's some context: {sample['context']}" if len(sample["context"]) > 0 else None
23
+ response = f" [/INST] {sample['response']}"
24
+ # join all the parts together
25
+ prompt = "".join([i for i in [instruction, context, response] if i is not None])
26
+ return prompt
27
+
28
+ # template dataset to add prompt to each sample
29
+ def template_dataset(sample, tokenizer):
30
+ sample["text"] = f"{format_data(sample)}{tokenizer.eos_token}"
31
+ return sample
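As a quick illustration of the prompt layout produced above, here is a hedged usage sketch; the sample values are invented and the import path assumes the snippet is run from chimp/src.

```python
# Illustrative only: how format_data (defined above) renders one training sample.
from dataset import format_data  # assumed import path when running from chimp/src

sample = {
    "instruction": "Extract the line items from this invoice.",
    "context": "INVOICE\n10 x Pack of Pencils @ £30.00",
    "response": "[{'Description': 'Pack of Pencils', 'Quantity': 10, 'Unit_price': '£30.00'}]",
}

print(format_data(sample))
# <s>[INST] Extract the line items from this invoice.Here's some context: INVOICE
# 10 x Pack of Pencils @ £30.00 [/INST] [{'Description': 'Pack of Pencils', 'Quantity': 10, 'Unit_price': '£30.00'}]
```

Note that the parts are joined with no separator, so the instruction runs directly into "Here's some context:" when a context is present.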
chimp/src/model.py ADDED
@@ -0,0 +1,62 @@
1
+ import pandas as pd
2
+ import os
3
+ import torch
4
+ from datasets import load_dataset
5
+ from transformers import (
6
+ AutoModelForCausalLM,
7
+ AutoTokenizer,
8
+ BitsAndBytesConfig,
9
+ HfArgumentParser,
10
+ pipeline,
11
+ logging,
12
+ )
13
+ from peft import LoraConfig, PeftModel, get_peft_model
14
+ from guardrail.client import (
15
+ run_metrics,
16
+ run_simple_metrics,
17
+ create_dataset)
18
+
19
+ import src.config as config
20
+
21
+ def load_model(model_name):
22
+ # Load tokenizer and model with QLoRA configuration
23
+ compute_dtype = getattr(torch, config.bnb_4bit_compute_dtype)
24
+
25
+ bnb_config = BitsAndBytesConfig(
26
+ load_in_4bit=config.use_4bit,
27
+ bnb_4bit_quant_type=config.bnb_4bit_quant_type,
28
+ bnb_4bit_compute_dtype=compute_dtype,
29
+ bnb_4bit_use_double_quant=config.use_nested_quant,
30
+ )
31
+
32
+ if compute_dtype == torch.float16 and config.use_4bit:
33
+ major, _ = torch.cuda.get_device_capability()
34
+ if major >= 8:
35
+ print("=" * 80)
36
+ print("Your GPU supports bfloat16, you can accelerate training with the argument --bf16")
37
+ print("=" * 80)
38
+
39
+ model = AutoModelForCausalLM.from_pretrained(
40
+ model_name,
41
+ device_map=config.device_map,
42
+ quantization_config=bnb_config
43
+ )
44
+
45
+ model.config.use_cache = False
46
+ model.config.pretraining_tp = 1
47
+
48
+ # Load LoRA configuration
49
+ peft_config = LoraConfig(
50
+ lora_alpha=config.lora_alpha,
51
+ lora_dropout=config.lora_dropout,
52
+ r=config.lora_r,
53
+ bias="none",
54
+ task_type="CAUSAL_LM",
55
+ )
56
+
57
+ # Load Tokenizer
58
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
59
+ tokenizer.pad_token = tokenizer.eos_token
60
+ tokenizer.padding_side = "right"
61
+
62
+ return model, tokenizer, peft_config
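A hedged usage sketch of the loader above (not part of the repo): attach the LoRA config with peft's get_peft_model to inspect how few parameters end up trainable. It assumes the working directory is chimp/ so that src.config and src.model resolve; in train.py the SFTTrainer applies peft_config itself, so this is only for inspection.

```python
# Hedged usage sketch; assumes it is run from the chimp/ directory.
from peft import get_peft_model

import src.config as config
from src.model import load_model

model, tokenizer, peft_config = load_model(config.model_name)

peft_model = get_peft_model(model, peft_config)
peft_model.print_trainable_parameters()
# prints something like: trainable params: ... || all params: ... || trainable%: ...
```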
chimp/src/predict.py ADDED
@@ -0,0 +1,87 @@
1
+ import pandas as pd
2
+ import os
3
+ import torch
4
+ from datasets import load_dataset
5
+ from transformers import (
6
+ AutoModelForCausalLM,
7
+ AutoTokenizer,
8
+ BitsAndBytesConfig,
9
+ HfArgumentParser,
10
+ TrainingArguments,
11
+ pipeline,
12
+ logging,
13
+ )
14
+ from peft import LoraConfig, PeftModel, get_peft_model
15
+ from trl import SFTTrainer
16
+ from guardrail.client import (
17
+ run_metrics,
18
+ run_simple_metrics,
19
+ create_dataset)
20
+
21
+ import src.config as config
+ from src.model import load_model
23
+
24
+
25
+
26
+
27
+ def text_gen_eval_wrapper(model, tokenizer, prompt, model_id=1, show_metrics=True, temp=0.7, max_length=200):
28
+ """
29
+ A wrapper function for inferencing, evaluating, and logging text generation pipeline.
30
+
31
+ Parameters:
32
+ model (str or object): The model name or the initialized text generation model.
33
+ tokenizer (str or object): The tokenizer name or the initialized tokenizer for the model.
34
+ prompt (str): The input prompt text for text generation.
35
+ model_id (int, optional): An identifier for the model. Defaults to 1.
36
+ show_metrics (bool, optional): Whether to calculate and show evaluation metrics.
37
+ Defaults to True.
38
+ temp (float, optional): Sampling temperature used for generation. Defaults to 0.7.
+ max_length (int, optional): The maximum length of the generated text sequence.
39
+ Defaults to 200.
40
+
41
+ Returns:
42
+ generated_text (str): The generated text by the model.
43
+ metrics (dict): Evaluation metrics for the generated text (if show_metrics is True).
44
+ """
45
+ # Suppress Hugging Face pipeline logging
46
+ logging.set_verbosity(logging.CRITICAL)
47
+
48
+ # Initialize the pipeline
49
+ pipe = pipeline(task="text-generation",
50
+ model=model,
51
+ tokenizer=tokenizer,
52
+ max_length=max_length,
53
+ do_sample=True,
54
+ temperature=temp)
55
+
56
+ # Generate text using the pipeline configured above
61
+ result = pipe(f"<s>[INST] {prompt} [/INST]")
62
+ generated_text = result[0]['generated_text']
63
+
64
+ # Find the index of the "[/INST] " marker in the generated text
65
+ index = generated_text.find("[/INST] ")
66
+ if index != -1:
67
+ # Extract the substring after "[/INST] "
68
+ substring_after_assistant = generated_text[index + len("[/INST] "):].strip()
69
+ else:
70
+ # If "[/INST] " is not found, use the entire generated text
71
+ substring_after_assistant = generated_text.strip()
72
+
73
+ if show_metrics:
74
+ # Calculate evaluation metrics
75
+ metrics = run_metrics(substring_after_assistant, prompt, model_id)
76
+
77
+ return substring_after_assistant, metrics
78
+ else:
79
+ return substring_after_assistant
80
+
81
+ if __name__=='__main__':
82
+ huggingface_profile = "jenesys-ai"
83
+ full_path = huggingface_profile + "/" + config.new_model
84
+
85
+ model, tokenizer, peft_config = load_model(full_path)
86
+ prompt="Who were the children of the legendary Garth Greenhand, the High King of the First Men in the series A Song of Ice and Fire?"
87
+ text_gen_eval_wrapper(model, tokenizer, prompt, show_metrics=False)
chimp/src/train.py ADDED
@@ -0,0 +1,66 @@
1
+ import pandas as pd
2
+ from transformers import TrainingArguments
3
+ from trl import SFTTrainer
4
+ from peft import LoraConfig, PeftModel, get_peft_model
5
+ from model import load_model
6
+ import config # Make sure you have a valid config module
7
+ from dataset import CustomDataset, template_dataset
8
+ from datasets import Dataset, Features, Value, Sequence
9
+
10
+
11
+ if __name__ == '__main__':
12
+
13
+ model, tokenizer, peft_config = load_model(config.model_name)
14
+
15
+ df = pd.read_csv('../data/trainv2.csv') #data
16
+ data_list = df.to_dict(orient='records')
17
+
18
+ # custom dataset object
19
+ custom_dataset = CustomDataset(data_list)
20
+ df = pd.DataFrame(custom_dataset.data, columns=["instruction", "context", "response"])
21
+ # Dataset features
22
+ features = Features({
23
+ "instruction": Value("string"),
24
+ "context": Value("string"),
25
+ "response": Value("string"),
26
+ })
27
+
28
+ # Create a Hugging Face Dataset from the Pandas DataFrame
29
+ hugging_face_dataset = Dataset.from_pandas(df, features=features)
30
+ dataset = hugging_face_dataset.map(lambda x: template_dataset(x, tokenizer), remove_columns=list(hugging_face_dataset.features))
31
+ print("----training data structure----",dataset)
32
+
33
+ # Training Arguments
34
+ training_arguments = TrainingArguments(
35
+ output_dir=config.output_dir,
36
+ per_device_train_batch_size=config.per_device_train_batch_size,
37
+ gradient_accumulation_steps=config.gradient_accumulation_steps,
38
+ optim=config.optim,
39
+ save_steps=config.save_steps,
40
+ logging_steps=config.logging_steps,
41
+ learning_rate=config.learning_rate,
42
+ fp16=config.fp16,
43
+ bf16=config.bf16,
44
+ max_grad_norm=config.max_grad_norm,
45
+ max_steps=config.max_steps,
46
+ warmup_ratio=config.warmup_ratio,
47
+ group_by_length=config.group_by_length,
48
+ lr_scheduler_type=config.lr_scheduler_type,
49
+ report_to="tensorboard"
50
+ )
51
+
52
+ # SFTTrainer
53
+ trainer = SFTTrainer(
54
+ model=model,
55
+ train_dataset=dataset,
56
+ peft_config=peft_config,
57
+ dataset_text_field="text",
58
+ max_seq_length=config.max_seq_length,
59
+ tokenizer=tokenizer,
60
+ args=training_arguments,
61
+ packing=config.packing,
62
+ )
63
+ print("**************** TRAINING STARTED ****************")
64
+ trainer.train()
65
+ trainer.model.save_pretrained(config.output_dir)
66
+ print("**************** TRAINING OVER ****************")
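After training, trainer.model.save_pretrained(config.output_dir) writes only the LoRA adapter. Below is a hedged sketch (not part of the repo) of merging that adapter back into the base model for standalone inference; it reuses the same config module as train.py.

```python
# Hedged sketch: merge the saved LoRA adapter into the base model.
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

import config  # the same chimp/src/config.py module used by train.py

# Reload the base model, then fold the adapter weights into it.
base = AutoModelForCausalLM.from_pretrained(
    config.model_name, torch_dtype=torch.float16, device_map="auto"
)
merged = PeftModel.from_pretrained(base, config.output_dir).merge_and_unload()

tokenizer = AutoTokenizer.from_pretrained(config.model_name, trust_remote_code=True)
merged.save_pretrained(config.new_model)      # writes a standalone HF model directory
tokenizer.save_pretrained(config.new_model)
```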
data/alm_task_data.csv ADDED
The diff for this file is too large to render. See raw diff
 
data/jack_line_item_ner_task.csv ADDED
The diff for this file is too large to render. See raw diff
 
data/jack_line_item_ner_task_v2.csv ADDED
The diff for this file is too large to render. See raw diff
 
data/line_item_and_alm_data.json ADDED
The diff for this file is too large to render. See raw diff
 
data/line_item_and_alm_data_v1.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:952a2a1a884ad30e3ffe8392dbde698fe799baccc7e687291174b1702cfe6e5c
3
+ size 10552104
data_prep.py ADDED
@@ -0,0 +1,28 @@
1
+ # %%
2
+ import pandas as pd
3
+
4
+ df_i = pd.read_csv('/home/tosi-n/ark/data/jack_line_item_ner_task_v2.csv', sep='\t')
5
+ df_ii = pd.read_csv('/home/tosi-n/ark/data/jack_line_item_ner_task.csv', sep='\t')
6
+
7
+ display(df_i.head())  # display() is available when these # %% cells are run in IPython/Jupyter
8
+ display(df_ii.head())
9
+ # %%
10
+ df_i = df_i[['context', 'instruction', 'response']]
11
+ df_ii = df_ii[['context', 'instruction', 'response']]
12
+
13
+ df = pd.concat([df_i, df_ii])
14
+
15
+ df.rename(columns={'context': 'input', 'response': 'output'}, inplace=True)
16
+
17
+ display(df.head())
18
+
19
+ # %%
20
+ # check for nan values
21
+ df.isna().sum()
22
+
23
+ # %%
24
+ # drop nan values
25
+ df.dropna(inplace=True)
26
+ # %%
27
+ df.to_json('/home/tosi-n/ark/data/line_item_and_alm_data_v1.json', orient='records')
28
+ # %%
demo.sh ADDED
@@ -0,0 +1,5 @@
1
+ python3 /home/tosi-n/ark/gorilla/app.py \
2
+ --base_model /home/tosi-n/ark/jack-alm-13b-8k-hf \
3
+ --context_size 8192 \
4
+ --max_gen_len 1000 \
5
+ --flash_attn True
gorilla/__pycache__/llama_attn_replace.cpython-310.pyc ADDED
Binary file (10.7 kB).
 
gorilla/__pycache__/llama_attn_replace_sft.cpython-310.pyc ADDED
Binary file (10.8 kB).
 
gorilla/api.py ADDED
File without changes
gorilla/app.py ADDED
@@ -0,0 +1,211 @@
1
+ import os
2
+ import sys
3
+ import math
4
+ import torch
5
+ import argparse
6
+ import textwrap
7
+ import transformers
8
+ from peft import PeftModel
9
+ from transformers import GenerationConfig, TextIteratorStreamer
10
+ from llama_attn_replace import replace_llama_attn
11
+ from threading import Thread
+ from typing import Iterator
+ import gradio as gr
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from transformers import StoppingCriteria, StoppingCriteriaList
20
+
21
+
22
+ def parse_config():
23
+ parser = argparse.ArgumentParser(description='arg parser')
24
+ parser.add_argument('--base_model', type=str, default="jenesys-ai/jack-alm-13b-8k-hf")
25
+ parser.add_argument('--cache_dir', type=str, default="./cache")
26
+ parser.add_argument('--context_size', type=int, default=8192, help='context size during fine-tuning')
27
+ parser.add_argument('--flash_attn', type=bool, default=True, help='')
28
+ parser.add_argument('--temperature', type=float, default=0.1, help='')
29
+ parser.add_argument('--top_p', type=float, default=0.9, help='')
30
+ parser.add_argument('--max_gen_len', type=int, default=1500, help='')
31
+ parser.add_argument('--chat_type', type=str, default='line-item-jack', help='Chat type: conversational-jack, line-item-jack')
32
+ parser.add_argument("--host", type=str, default="localhost")
33
+ parser.add_argument("--port", type=int, default=8898)
34
+ args = parser.parse_args()
35
+ return args
36
+
37
+ title = "Jack's ALM for Long-context Accounting Conversational Chat, Task, Invoice Line Item Extraction and Question Answering"
38
+
39
+ description = """
40
+ # Jack ALM Chat
41
+ This Chat UI demonstrates Jack's LLM [jack-alm-13b-8k-hf](https://huggingface.co/jenesys-ai/jack-alm-13b-8k-hf), a fine-tuned Llama 2 model with 13B parameters and an 8K context window.
42
+
43
+
44
+ """
45
+
46
+ # Gradio UI
47
+
48
+ def build_generator(model, tokenizer, use_cache=True):
49
+ def response(message: str, chat_history: list[tuple[str, str]], max_gen_len, temperature, top_p, chat_type='conversational-jack'):
50
+ # conversation = []
51
+ prompt_template = (
52
+ # "Below is an instruction that describes a task. "
53
+ """You're Jack, a virtual accountant created and built by AI Engineer Wiz from Jenesys AI.
54
+ You are able to communicate in a polite manner, with emotions of ecstasy, trust and humour, at a professional level
55
+ with a very reserved English communication culture. Answer the following questions as best you can,
56
+ but speaking as a British elite from the 21st century might speak.
57
+ """
58
+ """As a virtual accountant designed to follow the user's instructions carefully.
59
+ You are responsible for a range of financial tasks, operations and queries as listed below:
60
+ 1. Budget balance inquiry
61
+ 2. Expense request
62
+ 3. Company policy enquiries
63
+ 4. Financial and accounting queries
64
+ 5. Limited general enquiries
65
+ """
66
+ "### Instruction:\n{instruction}\n Return Response as text or paragraphs or bullet points \n\n### Response:"
67
+ )
68
+
69
+ line_item_prompt_template = (
70
+ "#Invoice and receipt line item extraction - "
71
+ # "You Jack are an accounting domain named entities recognizer to complete the following task:\n\n"
72
+ "### Invoice input-:\n{instruction}\n Return Response as a list of dictionary for each line item 'Description', 'Quantity', 'Unit_price', 'Tax %', 'Total'. \n\n### Response:"
73
+ )
74
+
75
+ # for user, assistant in chat_history:
76
+ # conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
77
+ # conversation.append({"role": "user", "content": message})
78
+
79
+ if chat_type == 'conversational-jack':
80
+ prompt = prompt_template.format(instruction=message)
81
+ elif chat_type == 'line-item-jack':
82
+ prompt = line_item_prompt_template.format(instruction=message)
83
+ # prompt = conversation
84
+
85
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
86
+
87
+ stop_list = ['#Invoice line item extraction - ', '\n```\n\n']#'### Input-:\n']
88
+
89
+ stop_token_ids = [tokenizer(x)['input_ids'] for x in stop_list]
90
+ stop_token_ids = [torch.LongTensor(x).to(model.device) for x in stop_token_ids]
91
+
92
+ # define custom stopping criteria object
93
+ class StopOnTokens(StoppingCriteria):
94
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
95
+ for stop_ids in stop_token_ids:
96
+ if torch.eq(input_ids[0][-len(stop_ids):], stop_ids).all():
97
+ return True
98
+ return False
99
+
100
+ stopping_criteria = StoppingCriteriaList([StopOnTokens()])
101
+
102
+ if len(inputs['input_ids'][0]) > 8192:
103
+ return "This demo supports inputs of fewer than 8192 tokens, while the current input has %d. Please use shorter material."%len(inputs['input_ids'][0])
104
+ torch.cuda.empty_cache()
105
+
106
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
107
+ generate_kwargs = dict(**inputs,
108
+ max_new_tokens=max_gen_len,
109
+ temperature=temperature,
110
+ top_p=top_p,
111
+ repetition_penalty=1.1,
112
+ stopping_criteria=stopping_criteria,
113
+ use_cache=use_cache,
114
+ streamer=streamer,
115
+ )
116
+
117
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
118
+ t.start()
119
+
120
+ generated_text = ""
121
+ for new_text in streamer:
122
+ generated_text += new_text
123
+ yield generated_text
124
+ return generated_text
125
+
126
+ return response
127
+
128
+
129
+
130
+
131
+ def generate(args):
132
+ if args.flash_attn:
133
+ replace_llama_attn(inference=True)
134
+
135
+ # Set RoPE scaling factor
136
+ config = transformers.AutoConfig.from_pretrained(
137
+ args.base_model,
138
+ cache_dir=args.cache_dir,
139
+ )
140
+
141
+ orig_ctx_len = getattr(config, "max_position_embeddings", None)
142
+ if orig_ctx_len and args.context_size > orig_ctx_len:
143
+ scaling_factor = float(math.ceil(args.context_size / orig_ctx_len))
144
+ config.rope_scaling = {"type": "linear", "factor": scaling_factor}
145
+
146
+ # Load model and tokenizer
147
+ model = transformers.AutoModelForCausalLM.from_pretrained(
148
+ args.base_model,
149
+ config=config,
150
+ cache_dir=args.cache_dir,
151
+ torch_dtype=torch.float16,
152
+ load_in_4bit=True,
153
+ device_map="auto",
154
+ )
155
+ model.resize_token_embeddings(32001)
156
+
157
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
158
+ args.base_model,
159
+ cache_dir=args.cache_dir,
160
+ model_max_length=args.context_size if args.context_size > orig_ctx_len else orig_ctx_len,
161
+ padding_side="right",
162
+ use_fast=False,
163
+ )
164
+
165
+ model.eval()
166
+ if torch.__version__ >= "2" and sys.platform != "win32":
167
+ model = torch.compile(model)
168
+ # import pdb; pdb.set_trace()
169
+ respond = build_generator(model, tokenizer)
170
+
171
+
172
+ chat_interface = gr.ChatInterface(
173
+ fn=respond,
174
+ textbox=gr.Textbox(lines=1, placeholder=None, label="Question"),
175
+ chatbot= gr.Chatbot(label="Jack's ALM Chat...", show_share_button=True),
176
+ additional_inputs=[
177
+ gr.Slider(label="Max new tokens", minimum=1, maximum=args.max_gen_len, step=1, value=args.max_gen_len),
178
+ gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=args.temperature),
179
+ gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=args.top_p),
180
+ gr.Dropdown(label="Chat type", choices=["line-item-jack"], value=args.chat_type),
181
+ ],
182
+ # stop_btn=None,
183
+ examples=[
184
+ # ["Hello there! How are you doing?"],
185
+ # ["What are your capabilities?"],
186
+ # ["What is T & E?"],
187
+ # ["How do I run an expense claim?"],
188
+ # ["Who built you?"],
189
+ # ["Tell me a joke"],
190
+ # ["Tell me an accounting story"],
191
+ ["INVOICE\nAA Associates\nNo 3 Click Street, Manhattan, NY\nNY 35284\nPhone (243) 758-4368\nFax (243) 758-5839\nAA\nINVOICE NO: 2491839\nDATE: 24/06/2019\nBILL TO\nBulk Cars\n5 Grid Avenue, NY\nNY 34582\nQUANTITY\n10\n10\n100\nSHIP TO\nSame as recipient\nINSTRUCTIONS\nConfirm before collection\nDESCRIPTION\nPack o Pencils\nPack of Fens\nReams of Faber\nCODE\nASD001\nASD013\nASD006\nSUBTOTAL\nSALES TAX - 5%\nSHIPPING & HANDLING\nTOTAL DUE AMOUNT\nUNIT PRICE\n£30.00\n£80.00\n£10.00\nTOTAL\n£300.00\n£800.00\n£1,000.00\n£2,100.00\n£105.00\n£50.00\n£2,255.00\nTAX IDENTIFICATION NUMBER\n13684937293-AJK\nTHANK YOU\nPayment should be made within 30 days of receipt of shipment. Failure to do so will attract 1% of \notal"],
192
+ ["INVOICE \nALBION \nAccounts Address: \nDelivery Address: \nKino Rye *DD* \nLion Street \nFINE FOODS \nKino Digital Limited \nRye \nDanum House \nEast Sussex \nUnits 21 - 22 Sovereign Way \n6a South Parade \nGB \nTonbridge \nKent \nDoncaster \nTN31 7LB \nTN9 1RH \nDN1 2DY \n01732 757 900 #4 \nsalesledger@albionff.co.uk \nDelivery Instructions: \nDrop No. \nDel. Date \n15/04/2023 \nA/C No. \nKINO01 \nOur Ref \nSIND66864 \nYour Ref \nQty \nUnit \nDescription \nUnit Price \nLine VAT \nLine Net \nTick \n1.000 \neach \nRed Onion Marmalade (Confit) 2.4kg \n19.72 \n0.00 \n19.72 \n1.000 \neach \nFairfields Lightly Salted 36 X 40g \n16.61 \n3.32 \n16.61 \n1.000 \neach \nMargarine Flora 2kg \n8.59 \n0.00 \n8.59 \n6.000 \neach \nMilk Fresh Semi-Skimmed 2ltr \n1.79 \n0.00 \n10.74 \nTemp. / Time: \nGoods remain the property \nNet Total \n55.66 \nof Albion Fine Foods Ltd \nCust Name: \nuntil this invoice is paid in \nfull. All sales are subject to \nVAT Content \n3.32 \nour Terms and Conditions \nCust Signature: \navailable upon request and \nat www.albionfinefoods.com \nTotal \n58.98 \nDriver Signature \nAlbion Fine Foods Ltd - Reg. No. 10379589 - VAT No. GB252036928"],
193
+ ["BUSINESS\nSTUDY GROUP\nCompany AddressSuite C2, Triple-H Plaza, Near Christ Embassy Church, Wuye District. Al\nQuotation #\nCustomer ID\nGRN002\nDAU123\nDate 29/06/2020\nPrepared by: BSG\nQuotation For\nCustomer Name Daulat Abubakar Yar'adua\nCompany Name Furayya Enterprise\nPhone, Fax Num (+234) 8036101908\nComments or Special Instructions\nNone\nSalesperson\nP.O. Number\nShip Date\nF.O.B. Point\nTerms\nDue on receint\nQuantity\n1\n1\n2\n3\nDescription\nStrategy and Advisory\nBSG Administrative Fees\nJaiz Application Support\nDalema Proposal and\nIterations\nUnit Price\n£37,500.00\n£45,700.00\n£12,500.00\n£8,500.00\nTaxable?\nYes\nYes\nAmount\n£37,500.00\n£45,700.00\n£25,000.00\n£25,500.00\nIf you have any questions concerning this quotation, please contact:\nKizito\nThank you for your business!\nSubtotal\n£133,700.00\nTax Rate\nSales Tax\nOther20%\n£6,240.00\nTOTAL £139,940.08"],
194
+ ["INVOICE FOR: \nD79077 \nKino Rye \nALBION \nAccounts Address:- \nDelivery Address:- \nKino Rye \nFINE FOODS \nKino Rye \nLion Street \n21 - 22 Sovereign Way \nKino Digital Limited \nRye \nTonbridge, TN9 1RH \nDanum House \nEast Sussex \nAccs: 01732 757 900 #2 \n6a South Parade \nTN31 7LB \nsalesledger@albionff.co.uk \nDoncaster \nDN1 2DY \nDrop No. \n64-07 \nCustomer Phone No: 01797226 Main \nDel. Date \n26/04/2023 \nDelivery Instructions: \nA/C No. \nKINO01 \nDelivery after 10.30am daily. Delivery driver can park on \nOur Ref \nSIND79077 \nthe shared drive in front of Kino, at the top of Lion St, \nopposite the Town Hall. If before staff arrive daily at \nYour Ref \n10.30 there is a black dustbin in the shade by the side gate \nPage 1 \nKey \nfor fresh food deliveries. Combination padlock on gates to drive \nOUR BANK DETAILS HAVE CHANGED: \nName: Albion Fine Foods Ltd \nSort: 40 M-60 \nAcc No.: \n83018792 \nQty \nUnit \nCode \nDescription \nUnit Price Line Vat Line Net \n000 \neach \nONIONCONFIT2 \nRed Onion Marmalade (Confit) 2.4kg \n19,72 \n19.72 \n1000 \neach \nMUSTDLI \nMustard Dijon 1kg GREEN STICKER \n3.98 \n3.98 \n1.000 \neach \nOILSXVS \nXV Olive Oil Sitr \n26.89 \n26,89 \n2.000 \npack \nGLOVEBLIMP \nGloves Blue Vinyl Med Pwd Free x100 \n8.46 \n3,38 \n16.92 \n1.000 \neach \nBLEACHTHICKS \nBleach Thick 5ltr \n4.69 \n0.94 \n4,69 \nT.000 \neach \nSOAPHANDBAC5 \nHand Soap Bactericidal 5ltr (SECH) \n7.79 \n1,56 \n7.79 \nCust Name \nGoods remain the property of \nNet Total \nAlbion Fine Foods Ltd unw this \n79.99 \nCust Signature \ninvoice is paid ND full AN sales are \nVAT Content \nsubject fo OW Terms and \n5.88 \nConditions available upon request \nTotal \nand at www.albionfinefoods com \n85.87 \nIf not signed for by customor, why? \nWe have out of hours access \n854 \nNo one was on-site \nTemp. / Time: \n+3-17 \nDriver Signature: \nDR \nAlbion Fine Foods Ltd - Reg. No. 10379589 VAT No. GB252036928"],
195
+ ["MITASU JAPANESE RESTAURANT SDN BHD \nB-01, CENTRAL PLAZA, \n34,JALAN SULTAN ISMAIL, \n50250 KUALA LUMPUR \nTEL 03-2110 2833 \n(GST Reg. No 001774428160 \nTax Invoice \nTable D2 \nOdr No: 199535 \nBill#:V001-201060 \nDate : 29-06-2018 19:59:15 \nPax(s): 11 \nCashier: AARON \nTotal TAX \nQty \nDescription \n11 (28.01) Adult \nD 709.50 SR \n709.50 \nSubtotal: \n70.95 \nServ. Charge (10%): \n0.00 \nGST Payable (0%): \n780.45 \nTotal: \n780.45 \nTOTAL: \nClosed: 001 \n29-06-2018 21:52:03 \nServer: AARON \n780.45 \nVISA \n- ******3042 \n- LIM CHAI JA"],
196
+ ["INDIA ANANDA PRIVATE BHAVAN LTD SWEETS \n[A2B VEG RESTAURANT] \nNO, ,27,BDA COMPLEX \n- HSR LAYOUT - BANGALORE \n:560102Ph:25725399 \nKARNATAKA \nGSTIN:29AAICA3787F1ZC \nINVOICE \nB.No : CTR116/138928 \n/ \nPay Mode:CARD \nSman: SHASHI KUMARA.M \nDate :09/Jul72017 11:28:44 AM \n841.1 By: SONU KUMAR \nParticulars \nGST HSN/SAC \nQty \nSEAL \nRate \nAmount \nPOMEGRANATE JUICE \n18% 00441067 \n1.000 \n70.00 \n70.00 \nONION UTTAPAM \n18% 00441067 \n1.000 \n80.00 \n80.00 \nTot Itms2 \nSub Total \n150.00 \nSGST 9 % \n13.50 \nCGST 9 % \n13.50 \nTotal Invoice \n177.00 \nBILL AMOUNT 177.00/- \nerminal No : SS-88077 \nOff: \nNO 9, MAHATMA GANDHI ROAD , SHASTRI NAGAR, ADYAR \nCHENNAI, PINCODE:600020 Website: aabsweets.com"],
197
+ ],
198
+ )
199
+
200
+ with gr.Blocks(css="style.css") as demo:
201
+ gr.Markdown(description)
202
+ gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
203
+ chat_interface.render()
204
+
205
+ demo.queue()
206
+ demo.launch(server_name=args.host, server_port=args.port, show_error=True, share=True)
207
+
208
+ if __name__ == "__main__":
209
+ args = parse_config()
210
+ generate(args)
211
+
gorilla/code_interpreter.py ADDED
@@ -0,0 +1,117 @@
1
+ # import traceback
2
+ # import sys
3
+ # import os
4
+ # import builtins
5
+
6
+ # # Namespace for variable storage
7
+ # custom_namespace = {}
8
+
9
+ # # Base directory for file storage
10
+ # base_storage_path = '/path/to/safe/storage'
11
+
12
+ # def execute_code(code, local_namespace=None):
13
+ # if local_namespace is None:
14
+ # local_namespace = {}
15
+
16
+ # try:
17
+ # # Override the built-in __import__ if needed
18
+ # # def __custom_import__(name, globals=None, locals=None, fromlist=(), level=0):
19
+ # # if name not in safe_imports:
20
+ # # raise ImportError(f"Import of {name} is not allowed")
21
+ # # return original_import(name, globals, locals, fromlist, level)
22
+
23
+ # # original_import = builtins.__import__
24
+ # # builtins.__import__ = __custom_import__
25
+
26
+ # # Redirect file operations to a safe directory
27
+ # # os.chdir(base_storage_path)
28
+
29
+ # compiled_code = compile(code, "<string>", 'exec')
30
+ # exec(compiled_code, custom_namespace, local_namespace)
31
+ # except Exception as e:
32
+ # exc_type, exc_value, exc_traceback = sys.exc_info()
33
+ # formatted_lines = traceback.format_exc().splitlines()
34
+ # error_message = "\n".join(formatted_lines)
35
+ # print(f"An exception occurred: {error_message}", file=sys.stderr)
36
+ # # finally:
37
+ # # Reset the built-in __import__ to its original state if overridden
38
+ # # builtins.__import__ = original_import
39
+
40
+ # # Redirect back to the original directory if changed
41
+ # # os.chdir(original_directory)
42
+
43
+ # # Example usage
44
+ # code_to_run = """
45
+ # import os
46
+ # print("Hello, World!")
47
+ # print(os.getcwd()) # This will print the current working directory
48
+ # """
49
+ # execute_code(code_to_run, custom_namespace)
50
+
51
+
52
+
53
+ # from IPython.core.interactiveshell import InteractiveShell
54
+ # shell = InteractiveShell()
55
+
56
+
57
+ # def execute_code_ipython(code):
58
+ # try:
59
+ # # Execute the code
60
+ # result = shell.run_cell(code)
61
+ # if result.error_in_exec is not None:
62
+ # raise result.error_in_exec
63
+ # except Exception as e:
64
+ # # Handle exceptions
65
+ # print(f"An error occurred: {e}")
66
+
67
+
68
+ # # Example usage
69
+ # code_to_run = """
70
+ # import os
71
+
72
+ # print("Hello, World!")
73
+ # print(os.getcwd())
74
+ # """
75
+ # execute_code_ipython(code_to_run)
76
+
77
+
78
+ import ipyparallel as ipp
79
+ import os
80
+
81
+
82
+ # run `ipcluster start -n 4` in the terminal to start the cluster using os.system
83
+ os.system('ipcluster start -n 4 --daemonize') # daemonize so the call returns instead of blocking this process
84
+
85
+ client = ipp.Client()
86
+ dview = client[:]
87
+
88
+ def execute_code_parallel(code):
89
+ # Use the `execute` method of the DirectView
90
+ async_results = dview.execute(code)
91
+
92
+ # Gather and return per-engine results
93
+ async_results.wait() # block until every engine has finished executing the code
94
+
95
+ results = []
96
+ try:
97
+ # .get() re-raises any exception that occurred on an engine
98
+ async_results.get()
99
+ # captured stdout from each engine, in engine order
100
+ results = list(async_results.stdout)
101
+ except Exception as e:
102
+ results.append(f"Error during parallel execution: {e}")
103
+
104
+ return results
105
+
106
+
107
+
108
+ # Example usage
109
+ # code_to_run = "import os; os.getpid()" # Simple code to test parallel execution
110
+ code_to_run = """
111
+ import os
112
+
113
+ print("Hello, World!")
114
+ print(os.getcwd())
115
+ """
116
+ execute_code_parallel(code_to_run)
117
+ # print(results)
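A minimal usage sketch for the parallel interpreter above, assuming a 4-engine ipcluster is already running; the sample code string and the shutdown call are illustrative, not part of this file:

    import ipyparallel as ipp

    client = ipp.Client()
    dview = client[:]

    ar = dview.execute("import os\nresult = os.getpid()")  # run the snippet on every engine
    ar.wait()                                              # block until all engines finish
    print(dview.pull("result").get())                      # fetch the per-engine `result` variable

    client.shutdown(hub=True)                              # stop engines and controller when done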
gorilla/ds_configs/stage2.json ADDED
@@ -0,0 +1,23 @@
1
+ {
2
+ "train_micro_batch_size_per_gpu": "auto",
3
+ "gradient_accumulation_steps": "auto",
4
+ "gradient_clipping": "auto",
5
+ "zero_allow_untested_optimizer": true,
6
+ "bf16": {
7
+ "enabled": "auto",
8
+ "loss_scale": 0,
9
+ "initial_scale_power": 16,
10
+ "loss_scale_window": 1000,
11
+ "hysteresis": 2,
12
+ "min_loss_scale": 1
13
+ },
14
+ "zero_optimization": {
15
+ "stage": 2,
16
+ "allgather_partitions": true,
17
+ "allgather_bucket_size": 1e9,
18
+ "reduce_scatter": true,
19
+ "reduce_bucket_size": 1e9,
20
+ "overlap_comm": true,
21
+ "contiguous_gradients": true
22
+ }
23
+ }
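The "auto" entries in this config are resolved by the HuggingFace Trainer at launch time; a minimal sketch of how the file is wired in, with the output path and batch sizes as illustrative placeholders:

    import transformers

    training_args = transformers.TrainingArguments(
        output_dir="./jack-alm",                     # hypothetical output directory
        per_device_train_batch_size=1,               # fills train_micro_batch_size_per_gpu
        gradient_accumulation_steps=8,               # fills gradient_accumulation_steps
        bf16=True,                                   # fills bf16.enabled
        deepspeed="gorilla/ds_configs/stage2.json",  # this config file
    )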
gorilla/ds_configs/stage3.json ADDED
@@ -0,0 +1,49 @@
1
+ {
2
+ "bf16": {
3
+ "enabled": "auto"
4
+ },
5
+ "optimizer": {
6
+ "type": "AdamW",
7
+ "params": {
8
+ "lr": "auto",
9
+ "betas": "auto",
10
+ "eps": "auto",
11
+ "weight_decay": "auto"
12
+ }
13
+ },
14
+ "scheduler": {
15
+ "type": "WarmupDecayLR",
16
+ "params": {
17
+ "total_num_steps": "auto",
18
+ "warmup_min_lr": "auto",
19
+ "warmup_max_lr": "auto",
20
+ "warmup_num_steps": "auto"
21
+ }
22
+ },
23
+ "zero_optimization": {
24
+ "stage": 3,
25
+ "offload_optimizer": {
26
+ "device": "cpu",
27
+ "pin_memory": true
28
+ },
29
+ "offload_param": {
30
+ "device": "cpu",
31
+ "pin_memory": true
32
+ },
33
+ "overlap_comm": true,
34
+ "contiguous_gradients": true,
35
+ "sub_group_size": 1e9,
36
+ "reduce_bucket_size": "auto",
37
+ "stage3_prefetch_bucket_size": "auto",
38
+ "stage3_param_persistence_threshold": "auto",
39
+ "stage3_max_live_parameters": 1e9,
40
+ "stage3_max_reuse_distance": 1e9,
41
+ "stage3_gather_16bit_weights_on_model_save": false
42
+ },
43
+ "gradient_accumulation_steps": "auto",
44
+ "gradient_clipping": "auto",
45
+ "steps_per_print": 5,
46
+ "train_batch_size": "auto",
47
+ "train_micro_batch_size_per_gpu": "auto",
48
+ "wall_clock_breakdown": false
49
+ }
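Compared with stage2.json, this config enables ZeRO stage 3, which additionally partitions the model parameters across GPUs and offloads both the optimizer state and the parameters to CPU memory, trading step time for a much smaller per-GPU footprint.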
gorilla/eval.py ADDED
@@ -0,0 +1,175 @@
1
+ # Written by Yukang Chen
2
+ # Some code based on https://github.com/epfml/landmark-attention
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import math
18
+ import torch
19
+ import argparse
20
+ import random
21
+ import numpy as np
22
+ from tqdm import tqdm
23
+ import transformers
24
+ from peft import PeftModel
25
+ from llama_attn_replace import replace_llama_attn
26
+
27
+ def parse_config():
28
+ parser = argparse.ArgumentParser(description='arg parser')
29
+ parser.add_argument('--batch_size', type=int, default=32, help='batch size during inference')
30
+ parser.add_argument('--base_model', type=str, default="meta-llama/Llama-2-13b-hf")
31
+ parser.add_argument('--cache_dir', type=str, default="./cache")
32
+ parser.add_argument('--seq_len', type=int, default=2048, help='context length during evaluation')
33
+ parser.add_argument('--context_size', type=int, default=-1, help='context size during fine-tuning')
34
+ parser.add_argument('--peft_model', type=str, default=None, help='')
35
+ parser.add_argument('--flash_attn', type=bool, default=True, help='')
36
+ parser.add_argument('--data_path', type=str, default="./test.bin", help='')
37
+ args = parser.parse_args()
38
+ return args
39
+
40
+ def get_as_batch(data, seq_length, batch_size, device='cpu', sliding_window=256):
41
+ all_ix = list(range(0, len(data) - seq_length, sliding_window))
42
+ all_ix.pop()
43
+
44
+ for idx in range(0, len(all_ix), batch_size):
45
+ ix = all_ix[idx:idx+batch_size]
46
+ assert all([idx + seq_length + 1 <= len(data) for idx in ix])
47
+ x = torch.stack([torch.from_numpy((data[i:i+seq_length]).astype(np.int64)) for i in ix])
48
+ y = torch.stack([torch.from_numpy((data[i+1:i+1+seq_length]).astype(np.int64)) for i in ix])
49
+ if device != 'cpu':
50
+ x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
51
+ yield x, y
52
+
53
+ def iceildiv(x, y):
54
+ return (x + y - 1) // y
55
+
56
+ def evaluate(model, data, batch_size, device, seq_length, sliding_window=256, use_cache=False):
57
+ stats = {}
58
+
59
+ model.eval()
60
+
61
+ loss_list_val, acc_list = [], []
62
+ loss_step_list_val = []
63
+
64
+ with torch.no_grad():
65
+ print(f"Using seq length {seq_length}")
66
+ torch.set_printoptions(sci_mode=False)
67
+ for idx, (x, y) in tqdm(
68
+ enumerate(
69
+ get_as_batch(
70
+ data['val'],
71
+ seq_length,
72
+ batch_size,
73
+ device=device,
74
+ sliding_window=sliding_window
75
+ )
76
+ ),
77
+ total=iceildiv(
78
+ iceildiv(len(data['val']), sliding_window),
79
+ batch_size
80
+ )
81
+ ):
82
+ val_loss = 0.
83
+ acc = 0.
84
+ cnt = 0
85
+
86
+ for part_idx, i in enumerate(range(0, x.shape[1], seq_length)):
87
+ part_len = x[:, i:i + seq_length].shape[1]
88
+
89
+ outputs = model(
90
+ input_ids=x[:, i:i + seq_length],
91
+ labels=x[:, i:i+seq_length].contiguous(),
92
+ use_cache=use_cache)
93
+
94
+ val_loss = outputs.loss * part_len + val_loss
95
+ acc = ((outputs.logits.argmax(-1) == y[:, i:i+seq_length]).float().sum()) + acc
96
+ cnt += part_len
97
+ while len(loss_step_list_val) <= part_idx:
98
+ loss_step_list_val.append([])
99
+ loss_step_list_val[part_idx].append(outputs.loss.item())
100
+ val_loss /= cnt
101
+ acc /= cnt
102
+
103
+ loss_list_val.append(val_loss.item())
104
+ acc_list.append(acc.item())
105
+
106
+ stats['val_acc'] = torch.as_tensor(acc_list).mean().item()
107
+ stats['val_loss'] = torch.as_tensor(loss_list_val).mean().item()
108
+ stats['val_perplexity'] = math.exp(stats['val_loss'])
109
+ stats['val_perplexity_per_chunk'] = torch.exp(torch.as_tensor(loss_step_list_val).mean(dim=1))
110
+
111
+ return stats
112
+
113
+ def main(args):
114
+
115
+ device = "cuda:0"
116
+ seed = 2
117
+ torch.cuda.set_device(device)
118
+
119
+ torch.manual_seed(seed)
120
+ random.seed(seed)
121
+ np.random.seed(seed)
122
+
123
+ data = {'val': np.memmap(args.data_path, dtype=np.uint16, mode='r')}
124
+
125
+ print(f"Num validation tokens: {len(data['val'])}")
126
+ print("data path", args.data_path)
127
+ print("base model", args.base_model)
128
+ print("peft model", args.peft_model)
129
+
130
+ if args.flash_attn:
131
+ replace_llama_attn(use_flash_attn=True, use_full=True)
132
+
133
+ # Set RoPE scaling factor
134
+ config = transformers.AutoConfig.from_pretrained(
135
+ args.base_model,
136
+ cache_dir=args.cache_dir,
137
+ )
138
+
139
+ context_size = args.context_size if args.context_size > 0 else args.seq_len
140
+ orig_ctx_len = getattr(config, "max_position_embeddings", None) # this value should be 4096 for LLaMA2 models
141
+ if orig_ctx_len and context_size > orig_ctx_len:
142
+ scaling_factor = float(math.ceil(context_size / orig_ctx_len))
143
+ config.rope_scaling = {"type": "linear", "factor": scaling_factor}
144
+
145
+ # Load model and tokenizer
146
+ model = transformers.AutoModelForCausalLM.from_pretrained(
147
+ args.base_model,
148
+ config=config,
149
+ cache_dir=args.cache_dir,
150
+ torch_dtype=torch.float16,
151
+ device_map="auto",
152
+ )
153
+ model.resize_token_embeddings(32001)
154
+
155
+ if args.peft_model:
156
+ trainable_params = os.path.join(args.peft_model, "trainable_params.bin")
157
+ if os.path.isfile(trainable_params):
158
+ model.load_state_dict(torch.load(trainable_params, map_location=model.device), strict=False)
159
+ else:
160
+ raise ValueError("Trainable input embedding and normalization are required.")
161
+ model = PeftModel.from_pretrained(
162
+ model,
163
+ args.peft_model,
164
+ device_map="auto",
165
+ torch_dtype=torch.float16,
166
+ )
167
+
168
+ stats = evaluate(model, data, args.batch_size, device, args.seq_len, sliding_window=256)
169
+
170
+ print(stats)
171
+
172
+
173
+ if __name__ == "__main__":
174
+ args = parse_config()
175
+ main(args)
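eval.py expects --data_path to point at a flat binary file of uint16 token ids (read back with np.memmap above). A minimal sketch of producing such a file; the corpus path and tokenizer name are illustrative placeholders:

    import numpy as np
    import transformers

    tokenizer = transformers.AutoTokenizer.from_pretrained("meta-llama/Llama-2-13b-hf")
    text = open("val_corpus.txt").read()               # hypothetical validation corpus
    ids = tokenizer(text, add_special_tokens=False)["input_ids"]
    # LLaMA-2 ids (< 32001) fit in uint16, matching the np.memmap(dtype=np.uint16) reader
    np.array(ids, dtype=np.uint16).tofile("test.bin")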
gorilla/fine-tune.py ADDED
@@ -0,0 +1,206 @@
1
+ import os
2
+ import math
3
+ from dataclasses import dataclass, field
4
+ from functools import partial
5
+ from typing import Dict, Optional, Sequence
6
+
7
+ import torch
8
+ import transformers
9
+ import pandas as pd
10
+ from torch.utils.data import Dataset
11
+ from transformers import Trainer, DataCollatorForLanguageModeling
12
+ from llama_attn_replace import replace_llama_attn
13
+ from peft import LoraConfig, get_peft_model
14
+ from torch.distributed import barrier
15
+
16
+ import datasets
17
+ from datasets import load_dataset
18
+
19
+ IGNORE_INDEX = -100
20
+ DEFAULT_PAD_TOKEN = "[PAD]"
21
+ DEFAULT_EOS_TOKEN = "</s>"
22
+ DEFAULT_BOS_TOKEN = "<s>"
23
+ DEFAULT_UNK_TOKEN = "<unk>"
24
+
25
+
26
+ @dataclass
27
+ class ModelArguments:
28
+ model_name_or_path: Optional[str] = field(default="EleutherAI/pythia-1.4b-deduped")
29
+ model_type: Optional[str] = field(default="llama")
30
+
31
+ @dataclass
32
+ class TrainingArguments(transformers.TrainingArguments):
33
+ cache_dir: Optional[str] = field(default=None)
34
+ optim: str = field(default="adamw_torch")
35
+ model_max_length: int = field(
36
+ default=8192 * 4,
37
+ metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
38
+ )
39
+ use_flash_attn: bool = field(
40
+ default=True,
41
+ metadata={"help": "Whether use flash attention for training."},
42
+ )
43
+ use_full_attn: bool = field(
44
+ default=False,
45
+ metadata={"help": "Whether to use plain, full-attention for training."},
46
+ )
47
+ low_rank_training: bool = field(
48
+ default=True,
49
+ metadata={"help": "Whether use low rank adaptation for training."},
50
+ )
51
+ trainable_params: str = field(
52
+ default="embed,norm",
53
+ metadata={"help": "Additional trainable parameters except LoRA weights, if low rank training."},
54
+ )
55
+
56
+ def smart_tokenizer_and_embedding_resize(
57
+ special_tokens_dict: Dict,
58
+ tokenizer: transformers.PreTrainedTokenizer,
59
+ model: transformers.PreTrainedModel,
60
+ ):
61
+ """Resize tokenizer and embedding.
62
+
63
+ Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
64
+ """
65
+ num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
66
+ model.resize_token_embeddings(len(tokenizer))
67
+
68
+ if num_new_tokens > 0:
69
+ input_embeddings = model.get_input_embeddings().weight.data
70
+ output_embeddings = model.get_output_embeddings().weight.data
71
+
72
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
73
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
74
+
75
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
76
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
77
+
78
+ def tokenize_fn(tokenizer, example):
79
+ context_length = tokenizer.model_max_length
80
+ outputs = tokenizer(
81
+ tokenizer.eos_token.join([i + "\n" + o for i, o in zip(example["instruction"], example["output"])]), # the mapped dataset has no "text" column; pack instruction/response pairs instead
82
+ truncation=False,
83
+ return_tensors="pt",
84
+ pad_to_multiple_of=context_length,
85
+ padding=True,
86
+ )
87
+ return {"input_ids": outputs["input_ids"].view(-1, context_length)}
88
+
89
+ def train():
90
+ parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
91
+ model_args, training_args = parser.parse_args_into_dataclasses()
92
+
93
+ # NOTE: May expand supported model types in the future
94
+ # if model_args.model_type == "gpt-neox":
95
+ # replace_gpt_neox_attn(training_args.use_flash_attn, training_args.use_full_attn)
96
+ # else:
97
+ # assert model_args.model_type == "llama", "Only support llama and gpt-neox for now"
98
+ replace_llama_attn(training_args.use_flash_attn, training_args.use_full_attn)
99
+
100
+ # Set RoPE scaling factor
101
+ config = transformers.AutoConfig.from_pretrained(
102
+ model_args.model_name_or_path,
103
+ cache_dir=training_args.cache_dir,
104
+ )
105
+
106
+ orig_rope_scaling = getattr(config, "rope_scaling", None) or {"factor": 1}
107
+ orig_rope_scaling_factor = orig_rope_scaling["factor"] if "factor" in orig_rope_scaling.keys() else 1
108
+ orig_ctx_len = getattr(config, "max_position_embeddings", None)
109
+ if orig_ctx_len:
110
+ orig_ctx_len *= orig_rope_scaling_factor
111
+ if training_args.model_max_length > orig_ctx_len:
112
+ scaling_factor = float(math.ceil(training_args.model_max_length / orig_ctx_len))
113
+ config.rope_scaling = {"type": "linear", "factor": scaling_factor}
114
+
115
+ # Load model and tokenizer
116
+ model = transformers.AutoModelForCausalLM.from_pretrained(
117
+ model_args.model_name_or_path,
118
+ config=config,
119
+ cache_dir=training_args.cache_dir,
120
+ torch_dtype=torch.bfloat16,
121
+ )
122
+
123
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
124
+ model_args.model_name_or_path,
125
+ cache_dir=training_args.cache_dir,
126
+ model_max_length=training_args.model_max_length,
127
+ padding_side="right",
128
+ use_fast=True,
129
+ )
130
+
131
+ special_tokens_dict = dict()
132
+ if tokenizer.pad_token is None:
133
+ special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
134
+ if tokenizer.eos_token is None:
135
+ special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
136
+ if tokenizer.bos_token is None:
137
+ special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
138
+ if tokenizer.unk_token is None:
139
+ special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN
140
+
141
+ smart_tokenizer_and_embedding_resize(
142
+ special_tokens_dict=special_tokens_dict,
143
+ tokenizer=tokenizer,
144
+ model=model,
145
+ )
146
+
147
+ rank = int(os.environ.get('RANK', -1))
148
+ if rank > 0:
149
+ barrier()
150
+ # dataset = load_dataset("togethercomputer/RedPajama-Data-1T-Sample", cache_dir=training_args.cache_dir)
151
+
152
+ print('Loading line item data')
153
+ df_i = pd.read_csv('/home/tosi-n/ark/data/jack_line_item_ner_task.csv', sep='\t')[['context', 'instruction', 'response']]
154
+ df_ii = pd.read_csv('/home/tosi-n/ark/data/alm_task_data.csv')[['context', 'instruction', 'response']]
155
+ df = pd.concat([df_i, df_ii], ignore_index=True)
156
+ # Replace NoneType with empty string
157
+ df = df.fillna('')
158
+ alm_task_data = datasets.Dataset.from_pandas(df)
159
+ alm_task_data = (alm_task_data
160
+ .remove_columns('context')
161
+ # .rename_column('context', 'input')
162
+ .rename_column('response', 'output'))
163
+
164
+ dataset = alm_task_data.map(partial(tokenize_fn,tokenizer),batched=True, num_proc=128)#, remove_columns=["text", "meta"]
165
+
166
+ if rank == 0:
167
+ barrier()
168
+
169
+ print(dataset)
170
+
171
+ data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
172
+
173
+ if training_args.low_rank_training:
174
+ if model_args.model_type == "gpt-neox":
175
+ # added `dense` to match with llama as the basic LoRA would only target 'query_key_value'
176
+ targets = ["query_key_value", "dense"]
177
+ else:
178
+ targets=["q_proj", "k_proj", "v_proj", "o_proj"]
179
+
180
+ config = LoraConfig(
181
+ r=8,
182
+ lora_alpha=16,
183
+ target_modules=targets,
184
+ lora_dropout=0,
185
+ bias="none",
186
+ task_type="CAUSAL_LM",
187
+ )
188
+ model = get_peft_model(model, config)
189
+ # enable trainable params
190
+ [p.requires_grad_() for n, p in model.named_parameters() if any([k in n for k in training_args.trainable_params.split(",")])]
191
+
192
+ model.config.use_cache = False # required for gradient checkpointing
193
+ model.enable_input_require_grads() # required for gradient checkpointing
194
+ model.gradient_checkpointing_enable() # enable gradient checkpointing
195
+ trainer = Trainer(
196
+ model=model, tokenizer=tokenizer, args=training_args,
197
+ train_dataset=dataset, # `dataset` is a single split here, not a DatasetDict
198
+ eval_dataset=None,
199
+ data_collator=data_collator)
200
+ trainer.train()
201
+ trainer.save_state()
202
+ trainer.save_model(output_dir=training_args.output_dir)
203
+
204
+
205
+ if __name__ == "__main__":
206
+ train()
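After get_peft_model and the requires_grad_ pass over the embed/norm parameters, only a small fraction of the weights should be trainable; a short sanity check, illustrative rather than part of the script:

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"trainable params: {trainable} / {total} ({100 * trainable / total:.2f}%)")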
gorilla/get_trainable_weights.py ADDED
@@ -0,0 +1,37 @@
1
+ import os
2
+ import torch
3
+ import argparse
4
+
5
+ def parse_config():
6
+ parser = argparse.ArgumentParser(description='arg parser')
7
+ parser.add_argument('--checkpoint_path', type=str, default="/home/tosi-n/ark/jack-alm/checkpoint-800/")
8
+ parser.add_argument('--trainable_params', type=str, default="embed,norm")
9
+ args = parser.parse_args()
10
+ return args
11
+
12
+
13
+ def main(args):
14
+ path = args.checkpoint_path
15
+ trainable_params = args.trainable_params.split(",")
16
+
17
+ weights_all = torch.load(os.path.join(path, "pytorch_model.bin"))
18
+
19
+ weights_trainable = {}
20
+ weights_lora = {}
21
+ for k in weights_all:
22
+ if "lora" in k:
23
+ k_new = k.replace("default.", "") if "default." in k else k
24
+ weights_lora[k_new] = weights_all[k]
25
+ else:
26
+ if any([n in k for n in trainable_params]):
27
+ weights_trainable[k[17:]] = weights_all[k]
28
+
29
+ adapter_model = os.path.join(path, "adapter_model.bin")
30
+ trainable_params = os.path.join(path, "trainable_params.bin")
31
+ if not os.path.isfile(adapter_model):
32
+ torch.save(weights_lora, adapter_model)
33
+ torch.save(weights_trainable, trainable_params)
34
+
35
+ if __name__ == "__main__":
36
+ args = parse_config()
37
+ main(args)
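The two files written here are consumed downstream: eval.py (above) loads trainable_params.bin alongside the LoRA adapter, and adapter_model.bin is the standard PEFT adapter checkpoint picked up by PeftModel.from_pretrained.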
gorilla/infer.py ADDED
@@ -0,0 +1,143 @@
1
+ import argparse
2
+ import math
3
+ import sys
4
+ import time
5
+ import transformers
6
+ import torch
7
+ from threading import Thread
8
+ from transformers import TextIteratorStreamer
9
+ from llama_attn_replace import replace_llama_attn
10
+
11
+ def parse_config():
12
+ parser = argparse.ArgumentParser(description='arg parser')
13
+ parser.add_argument('--base_model', type=str, default="jenesys-ai/jack-alm-13b-8k-hf")
14
+ parser.add_argument('--cache_dir', type=str, default="./cache")
15
+ parser.add_argument('--context_size', type=int, default=-1, help='context size during fine-tuning')
16
+ parser.add_argument('--flash_attn', type=bool, default=True, help='')
17
+ parser.add_argument('--temperature', type=float, default=0.1, help='')
18
+ parser.add_argument('--top_p', type=float, default=1, help='')
19
+ parser.add_argument('--max_gen_len', type=int, default=512, help='')
20
+ parser.add_argument('--chat_type', type=str, default='conversational-jack', help='Chat type: conversational-jack, line-item-jack')
21
+ args = parser.parse_args()
22
+ return args
23
+
24
+ def build_generator(model, tokenizer, use_cache=True):
25
+ def response(message, max_gen_len, temperature, top_p, chat_type='conversational-jack'):
26
+
27
+ prompt_template = (
28
+ # "Below is an instruction that describes a task. "
29
+ """You're Jack, a virtual accountant created by Jenesys HQ Ltd.
30
+ You are able to communicate in a polite manner, with emotions of ecstasy, trust and jokes, at a professional level
31
+ and with a very reserved English communication culture. Answer the following questions as best you can,
32
+ speaking as a British elite from the 21st century might speak.
33
+ """
34
+ """As a virtual accountant, you are designed to follow the user's instructions carefully.
35
+ You are responsible for a range of financial tasks, operations and queries as listed below:
36
+ 1. Budget balance inquiry
37
+ 2. Expense request
38
+ 3. Company policy enquiries
39
+ 4. Financial and accounting queries
40
+ 5. Limited general enquiries
41
+ """
42
+ "Write a response that appropriately completes the request.\n\n"
43
+ "### Instruction:\n{instruction}\n\n### Response:"
44
+ )
45
+
46
+
47
+ line_item_prompt_template = (
48
+ "#Invoice line item extraction - "
49
+ "You, Jack, are an accounting-domain named entity recognizer completing the following task:\n\n"
50
+ "### Invoice input-:\n{instruction}\n Return Response as a list of dictionary for each line item 'Description', 'Quantity', 'Unit_price', 'Tax %', 'Total'. \n\n### Response:"
51
+ )
52
+
53
+
54
+ if chat_type == 'conversational-jack':
55
+ prompt = prompt_template.format(instruction=message)
56
+ elif chat_type == 'line-item-jack':
57
+ prompt = line_item_prompt_template.format(instruction=message)
58
+ # prompt = conversation
59
+
60
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
61
+
62
+ if len(inputs['input_ids'][0]) > 8192:
63
+ return "This demo supports inputs of fewer than 8192 tokens, while the current input has %d. Please use material with fewer tokens."%len(inputs['input_ids'][0])
64
+ torch.cuda.empty_cache()
65
+
66
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
67
+ generate_kwargs = dict(**inputs,
68
+ max_new_tokens=max_gen_len,
69
+ temperature=temperature,
70
+ top_p=top_p,
71
+ use_cache=use_cache,
72
+ streamer=streamer,
73
+ )
74
+
75
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
76
+ t.start()
77
+
78
+ generated_text = ""
79
+ start_time = time.time()
80
+
81
+ for new_text in streamer:
82
+ generated_text += new_text
83
+ tokens_per_sec = len(generated_text.split()) / (time.time() - start_time)
84
+
85
+ suffix = f" ({tokens_per_sec:.2f} tokens/sec)"
86
+ # # yield f"{generated_text} ({tokens_per_sec:.2f} tokens/sec)"
87
+ sys.stdout.write(f"\r\033[K{generated_text}{suffix}")
88
+ sys.stdout.flush()
89
+ # sys.stdout.write("\n") # Move to a new line after generation is complete
90
+ return generated_text
91
+
92
+ return response
93
+
94
+ def load_model():
95
+ args = parse_config()
96
+
97
+ if args.flash_attn:
98
+ replace_llama_attn(inference=True)
99
+
100
+ # Set RoPE scaling factor
101
+ config = transformers.AutoConfig.from_pretrained(
102
+ args.base_model,
103
+ cache_dir=args.cache_dir,
104
+ )
105
+
106
+ orig_ctx_len = getattr(config, "max_position_embeddings", None)
107
+ if orig_ctx_len and args.context_size > orig_ctx_len:
108
+ scaling_factor = float(math.ceil(args.context_size / orig_ctx_len))
109
+ config.rope_scaling = {"type": "linear", "factor": scaling_factor}
110
+
111
+ # Load model and tokenizer
112
+ model = transformers.AutoModelForCausalLM.from_pretrained(
113
+ args.base_model,
114
+ config=config,
115
+ cache_dir=args.cache_dir,
116
+ torch_dtype=torch.float16,
117
+ load_in_4bit=True,
118
+ device_map="auto",
119
+ )
120
+ model.resize_token_embeddings(32001)
121
+
122
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
123
+ args.base_model,
124
+ cache_dir=args.cache_dir,
125
+ model_max_length=args.context_size if args.context_size > orig_ctx_len else orig_ctx_len,
126
+ padding_side="right",
127
+ use_fast=True,
128
+ )
129
+
130
+ model.eval()
131
+ respond = build_generator(model, tokenizer)
132
+ return respond
133
+
134
+ respond = load_model()
135
+
136
+ def generate_response(message, max_gen_len=512, temperature=0.1, top_p=1, chat_type='line-item-jack'):
137
+ return respond(
138
+ message=message,
139
+ max_gen_len=max_gen_len,
140
+ temperature=temperature,
141
+ top_p=top_p,
142
+ chat_type=chat_type
143
+ )
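A minimal usage sketch, run from the gorilla/ directory on a machine with the fine-tuned weights available; importing the module loads the quantised model once (respond = load_model() runs at import time), and the OCR string below is a hypothetical placeholder:

    from infer import generate_response

    ocr_text = "Qty 2 each WIDGETA Widget A 5.00 10.00"  # hypothetical OCR line items
    answer = generate_response(ocr_text, max_gen_len=256, temperature=0.1, chat_type='line-item-jack')
    print(answer)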
gorilla/llama_attn_replace.py ADDED
@@ -0,0 +1,477 @@
1
+ # Modified based on https://github.com/lm-sys/FastChat
2
+
3
+ import warnings
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ from torch import nn; import torch.nn.functional as F # F is used by forward_noflashattn when pretraining_tp > 1
8
+ import transformers
9
+ from einops import rearrange
10
+ from flash_attn import __version__ as flash_attn_version
11
+ from flash_attn.bert_padding import pad_input, unpad_input
12
+ from flash_attn.flash_attn_interface import (
13
+ flash_attn_func,
14
+ flash_attn_varlen_kvpacked_func,
15
+ flash_attn_varlen_qkvpacked_func
16
+ )
17
+ from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv, rotate_half
18
+ from flash_attn.bert_padding import unpad_input, pad_input
19
+ import math
20
+
21
+ group_size_ratio = 1/4
22
+ def forward_flashattn(
23
+ self,
24
+ hidden_states: torch.Tensor,
25
+ attention_mask: Optional[torch.Tensor] = None,
26
+ position_ids: Optional[torch.Tensor] = None,
27
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
28
+ output_attentions: bool = False,
29
+ use_cache: bool = False,
30
+ padding_mask: Optional[torch.LongTensor] = None,
31
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
32
+ """Input shape: Batch x Time x Channel
33
+
34
+ attention_mask: [bsz, q_len]
35
+ """
36
+ if not self.training:
37
+ warnings.warn("This function should be used just for training as it may exhibit reduced inference performance. For inference, please use forward_flashattn_inference.")
38
+
39
+ if output_attentions:
40
+ warnings.warn(
41
+ "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
42
+ )
43
+
44
+ bsz, q_len, _ = hidden_states.size()
45
+
46
+ query_states = (
47
+ self.q_proj(hidden_states)
48
+ .view(bsz, q_len, self.num_heads, self.head_dim)
49
+ .transpose(1, 2)
50
+ )
51
+ key_states = (
52
+ self.k_proj(hidden_states)
53
+ .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
54
+ .transpose(1, 2)
55
+ )
56
+ value_states = (
57
+ self.v_proj(hidden_states)
58
+ .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
59
+ .transpose(1, 2)
60
+ )
61
+ # [bsz, q_len, nh, hd]
62
+ # [bsz, nh, q_len, hd]
63
+
64
+ kv_seq_len = key_states.shape[-2]
65
+ if past_key_value is not None:
66
+ kv_seq_len += past_key_value[0].shape[-2]
67
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
68
+ query_states, key_states = apply_rotary_pos_emb(
69
+ query_states, key_states, cos, sin, position_ids
70
+ )
71
+
72
+ # Past Key value support
73
+ if past_key_value is not None:
74
+ # reuse k, v, self_attention
75
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
76
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
77
+
78
+ past_key_value = (key_states, value_states) if use_cache else None
79
+
80
+ # repeat k/v heads if n_kv_heads < n_heads
81
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
82
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
83
+
84
+ # Flash attention codes from
85
+ # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
86
+
87
+ # transform the data into the format required by flash attention
88
+ qkv = torch.stack(
89
+ [query_states, key_states, value_states], dim=2
90
+ ) # [bsz, nh, 3, q_len, hd]
91
+ qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
92
+
93
+ # We have disabled _prepare_decoder_attention_mask in LlamaModel
94
+ # the attention_mask should be the same as the key_padding_mask
95
+
96
+ key_padding_mask = attention_mask.repeat(2, 1)
97
+ nheads = qkv.shape[-2]
98
+ # shift
99
+
100
+ group_size = int(q_len * group_size_ratio)
101
+ if q_len % group_size > 0:
102
+ raise ValueError("q_len %d should be divisible by group size %d." % (q_len, group_size))
103
+
104
+ qkv = qkv.reshape(bsz, q_len, 3, 2, self.num_heads // 2, self.head_dim).permute(0, 3, 1, 2, 4, 5).reshape(bsz * 2,
105
+ q_len, 3,
106
+ self.num_heads // 2,
107
+ self.head_dim)
108
+ x = rearrange(qkv, "b s three h d -> b s (three h d)")
109
+ x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
110
+ cu_q_len_tmp = torch.arange(0, max_s, group_size, device=key_padding_mask.device, dtype=cu_q_lens.dtype)
111
+ cu_q_len_tmp = torch.stack([cu_q_len_tmp, cu_q_len_tmp + group_size // 2]).repeat(bsz, 1) + cu_q_lens[:-1].unsqueeze(-1)
112
+ cu_q_lens = torch.cat([cu_q_len_tmp, cu_q_lens[1:].unsqueeze(-1)], dim=-1).view(-1)
113
+
114
+ x_unpad = rearrange(
115
+ x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads // 2
116
+ )
117
+ output_unpad = flash_attn_varlen_qkvpacked_func(
118
+ x_unpad, cu_q_lens, group_size, 0.0, softmax_scale=None, causal=True
119
+ )
120
+ output = rearrange(
121
+ pad_input(
122
+ rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz * 2, q_len
123
+ ),
124
+ "b s (h d) -> b s h d",
125
+ h=nheads // 2,
126
+ )
127
+ output = output.reshape(bsz, 2, q_len, nheads // 2, self.head_dim).transpose(1, 2).reshape(bsz, q_len, nheads,
128
+ self.head_dim)
129
+
130
+ return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, past_key_value
131
+
132
+ def forward_flashattn_full(
133
+ self,
134
+ hidden_states: torch.Tensor,
135
+ attention_mask: Optional[torch.Tensor] = None,
136
+ position_ids: Optional[torch.Tensor] = None,
137
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
138
+ output_attentions: bool = False,
139
+ use_cache: bool = False,
140
+ padding_mask: Optional[torch.LongTensor] = None,
141
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
142
+ """Input shape: Batch x Time x Channel
143
+
144
+ attention_mask: [bsz, q_len]
145
+ """
146
+ if output_attentions:
147
+ warnings.warn(
148
+ "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
149
+ )
150
+
151
+ bsz, q_len, _ = hidden_states.size()
152
+
153
+ query_states = (
154
+ self.q_proj(hidden_states)
155
+ .view(bsz, q_len, self.num_heads, self.head_dim)
156
+ .transpose(1, 2)
157
+ )
158
+ key_states = (
159
+ self.k_proj(hidden_states)
160
+ .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
161
+ .transpose(1, 2)
162
+ )
163
+ value_states = (
164
+ self.v_proj(hidden_states)
165
+ .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
166
+ .transpose(1, 2)
167
+ )
168
+ # [bsz, q_len, nh, hd]
169
+ # [bsz, nh, q_len, hd]
170
+
171
+ kv_seq_len = key_states.shape[-2]
172
+ if past_key_value is not None:
173
+ kv_seq_len += past_key_value[0].shape[-2]
174
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
175
+ query_states, key_states = apply_rotary_pos_emb(
176
+ query_states, key_states, cos, sin, position_ids
177
+ )
178
+
179
+ # Past Key value support
180
+ if past_key_value is not None:
181
+ # reuse k, v, self_attention
182
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
183
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
184
+
185
+ past_key_value = (key_states, value_states) if use_cache else None
186
+
187
+ # repeat k/v heads if n_kv_heads < n_heads
188
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
189
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
190
+
191
+ # Flash attention codes from
192
+ # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
193
+
194
+ # transform the data into the format required by flash attention
195
+ qkv = torch.stack(
196
+ [query_states, key_states, value_states], dim=2
197
+ ) # [bsz, nh, 3, q_len, hd]
198
+ qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
199
+
200
+ # We have disabled _prepare_decoder_attention_mask in LlamaModel
201
+ # the attention_mask should be the same as the key_padding_mask
202
+
203
+ key_padding_mask = attention_mask
204
+ nheads = qkv.shape[-2]
205
+ x = rearrange(qkv, "b s three h d -> b s (three h d)")
206
+ x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
207
+ x_unpad = rearrange(
208
+ x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
209
+ )
210
+ output_unpad = flash_attn_varlen_qkvpacked_func(
211
+ x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
212
+ )
213
+ output = rearrange(
214
+ pad_input(
215
+ rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len
216
+ ),
217
+ "b s (h d) -> b s h d",
218
+ h=nheads,
219
+ )
220
+ output = output.reshape(bsz, q_len, self.num_heads, self.head_dim)
221
+
222
+ return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, past_key_value
223
+
224
+
225
+ def forward_noflashattn(
226
+ self,
227
+ hidden_states: torch.Tensor,
228
+ attention_mask: Optional[torch.Tensor] = None,
229
+ position_ids: Optional[torch.LongTensor] = None,
230
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
231
+ output_attentions: bool = False,
232
+ use_cache: bool = False,
233
+ padding_mask: Optional[torch.LongTensor] = None,
234
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
235
+ bsz, q_len, _ = hidden_states.size()
236
+
237
+ group_size = int(q_len * group_size_ratio)
238
+
239
+ if q_len % group_size > 0:
240
+ raise ValueError("q_len %d should be divisible by group size %d."%(q_len, group_size))
241
+ num_group = q_len // group_size
242
+
243
+ if self.config.pretraining_tp > 1:
244
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
245
+ query_slices = self.q_proj.weight.split(
246
+ (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
247
+ )
248
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
249
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
250
+
251
+ query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
252
+ query_states = torch.cat(query_states, dim=-1)
253
+
254
+ key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
255
+ key_states = torch.cat(key_states, dim=-1)
256
+
257
+ value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
258
+ value_states = torch.cat(value_states, dim=-1)
259
+
260
+ else:
261
+ query_states = self.q_proj(hidden_states)
262
+ key_states = self.k_proj(hidden_states)
263
+ value_states = self.v_proj(hidden_states)
264
+
265
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
266
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
267
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
268
+
269
+ kv_seq_len = key_states.shape[-2]
270
+ if past_key_value is not None:
271
+ kv_seq_len += past_key_value[0].shape[-2]
272
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
273
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
274
+
275
+ if past_key_value is not None:
276
+ # reuse k, v, self_attention
277
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
278
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
279
+
280
+ past_key_value = (key_states, value_states) if use_cache else None
281
+
282
+ # repeat k/v heads if n_kv_heads < n_heads
283
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
284
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
285
+
286
+ # shift
287
+ def shift(qkv, bsz, q_len, group_size, num_heads, head_dim):
288
+ qkv[:, num_heads // 2:] = qkv[:, num_heads // 2:].roll(-group_size // 2, dims=2)
289
+ qkv = qkv.transpose(1, 2).reshape(bsz * (q_len // group_size), group_size, num_heads, head_dim).transpose(1, 2)
290
+ return qkv
291
+
292
+ query_states = shift(query_states, bsz, q_len, group_size, self.num_heads, self.head_dim)
293
+ key_states = shift(key_states, bsz, q_len, group_size, self.num_heads, self.head_dim)
294
+ value_states = shift(value_states, bsz, q_len, group_size, self.num_heads, self.head_dim)
295
+
296
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
297
+
298
+ if attn_weights.size() != (bsz * num_group, self.num_heads, group_size, group_size):
299
+ raise ValueError(
300
+ f"Attention weights should be of size {(bsz * num_group, self.num_heads, group_size, group_size)}, but is"
301
+ f" {attn_weights.size()}"
302
+ )
303
+
304
+ attention_mask = attention_mask[:, :, :group_size, :group_size].repeat(num_group, 1, 1, 1)
305
+ if attention_mask is not None:
306
+ if attention_mask.size() != (bsz * num_group, 1, group_size, group_size):
307
+ raise ValueError(
308
+ f"Attention mask should be of size {(bsz * num_group, 1, group_size, group_size)}, but is {attention_mask.size()}"
309
+ )
310
+ attn_weights = attn_weights + attention_mask
311
+
312
+ # compute the softmax in fp16 here (the reference implementation upcasts to fp32)
313
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float16).to(query_states.dtype) # torch.float32
314
+ attn_output = torch.matmul(attn_weights, value_states)
315
+
316
+ if attn_output.size() != (bsz * num_group, self.num_heads, group_size, self.head_dim):
317
+ raise ValueError(
318
+ f"`attn_output` should be of size {(bsz * num_group, self.num_heads, group_size, self.head_dim)}, but is"
319
+ f" {attn_output.size()}"
320
+ )
321
+ attn_output = attn_output.transpose(1, 2).contiguous()
322
+
323
+ attn_output = attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
324
+
325
+ # shift back
326
+ attn_output[:, :, self.num_heads//2:] = attn_output[:, :, self.num_heads//2:].roll(group_size//2, dims=1)
327
+
328
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
329
+
330
+ if self.config.pretraining_tp > 1:
331
+ attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
332
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
333
+ attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
334
+ else:
335
+ attn_output = self.o_proj(attn_output)
336
+
337
+ if not output_attentions:
338
+ attn_weights = None
339
+
340
+ return attn_output, attn_weights, past_key_value
341
+
342
+ # Disable the transformation of the attention mask in LlamaModel as the flash attention
343
+ # requires the attention mask to be the same as the key_padding_mask
344
+ def _prepare_decoder_attention_mask(
345
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
346
+ ):
347
+ # [bsz, seq_len]
348
+ return attention_mask
349
+
350
+ def apply_rotary_pos_emb_inference(q, k, cos_sin, position_ids):
351
+ gather_indices = position_ids[:, :, None, None] # [bsz, seq_len, 1, 1]
352
+ gather_indices = gather_indices.repeat(
353
+ 1, 1, cos_sin[0].shape[1], cos_sin[0].shape[3]
354
+ )
355
+ bsz = gather_indices.shape[0]
356
+ cos, sin = (
357
+ torch.gather(x.transpose(1, 2).repeat(bsz, 1, 1, 1), 1, gather_indices)
358
+ for x in cos_sin
359
+ )
360
+ q, k = ((x * cos) + (rotate_half(x) * sin) for x in (q, k))
361
+ return q, k
362
+
363
+
364
+ def forward_flashattn_inference(
365
+ self,
366
+ hidden_states: torch.Tensor,
367
+ attention_mask: Optional[torch.Tensor] = None,
368
+ position_ids: Optional[torch.Tensor] = None,
369
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
370
+ output_attentions: bool = False,
371
+ use_cache: bool = False,
372
+ padding_mask: Optional[torch.Tensor] = None,
373
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
374
+ if output_attentions:
375
+ warnings.warn(
376
+ "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
377
+ )
378
+
379
+ bsz, q_len, _ = hidden_states.size()
380
+ kv_heads = getattr(self, "num_key_value_heads", self.num_heads)
381
+
382
+ q, k, v = (
383
+ op(hidden_states).view(bsz, q_len, nh, self.head_dim)
384
+ for op, nh in (
385
+ (self.q_proj, self.num_heads),
386
+ (self.k_proj, kv_heads),
387
+ (self.v_proj, kv_heads),
388
+ )
389
+ )
390
+ # shape: (b, s, num_heads, head_dim)
391
+
392
+ kv_seq_len = k.shape[1]
393
+ past_kv_len = 0
394
+ if past_key_value is not None:
395
+ past_kv_len = past_key_value[0].shape[2]
396
+ kv_seq_len += past_kv_len
397
+
398
+ cos_sin = self.rotary_emb(v, seq_len=kv_seq_len)
399
+ q, k = apply_rotary_pos_emb_inference(q, k, cos_sin, position_ids)
400
+
401
+ if past_key_value is not None:
402
+ assert (
403
+ flash_attn_version >= "2.1.0"
404
+ ), "past_key_value support requires flash-attn >= 2.1.0"
405
+ # reuse k, v
406
+ k = torch.cat([past_key_value[0].transpose(1, 2), k], dim=1)
407
+ v = torch.cat([past_key_value[1].transpose(1, 2), v], dim=1)
408
+
409
+ past_key_value = (k.transpose(1, 2), v.transpose(1, 2)) if use_cache else None
410
+
411
+ if attention_mask is None:
412
+ output = flash_attn_func(q, k, v, 0.0, softmax_scale=None, causal=True).view(
413
+ bsz, q_len, -1
414
+ )
415
+ else:
416
+ q, indices, cu_q_lens, max_s = unpad_input(q, attention_mask[:, -q_len:])
417
+ # We can skip concat and call unpad twice but seems better to call unpad only once.
418
+ kv, _, cu_k_lens, max_k = unpad_input(
419
+ torch.stack((k, v), dim=2), attention_mask
420
+ )
421
+ output_unpad = flash_attn_varlen_kvpacked_func(
422
+ q,
423
+ kv,
424
+ cu_q_lens,
425
+ cu_k_lens,
426
+ max_s,
427
+ max_k,
428
+ 0.0,
429
+ softmax_scale=None,
430
+ causal=True,
431
+ )
432
+ output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
433
+ output = pad_input(output_unpad, indices, bsz, q_len)
434
+
435
+ return self.o_proj(output), None, past_key_value
436
+
437
+ def _prepare_decoder_attention_mask_inference(
438
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
439
+ ):
440
+ # [bsz, seq_len]
441
+ if past_key_values_length > 0 and attention_mask is not None:
442
+ attention_mask = torch.cat(
443
+ (
444
+ torch.full(
445
+ (input_shape[0], past_key_values_length),
446
+ True,
447
+ dtype=attention_mask.dtype,
448
+ device=attention_mask.device,
449
+ ),
450
+ attention_mask,
451
+ ),
452
+ dim=-1,
453
+ )
454
+
455
+ if attention_mask is not None and torch.all(attention_mask):
456
+ return None # This uses the faster call when training with full samples
457
+
458
+ return attention_mask
459
+
460
+ def replace_llama_attn(use_flash_attn=True, use_full=False, inference=False):
461
+ if use_flash_attn:
462
+ cuda_major, cuda_minor = torch.cuda.get_device_capability()
463
+ if cuda_major < 8:
464
+ warnings.warn(
465
+ "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
466
+ "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
467
+ )
468
+ if inference:
469
+ transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask_inference
470
+ transformers.models.llama.modeling_llama.LlamaAttention.forward = forward_flashattn_inference
471
+ else:
472
+ transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
473
+ _prepare_decoder_attention_mask
474
+ )
475
+ transformers.models.llama.modeling_llama.LlamaAttention.forward = forward_flashattn_full if use_full else forward_flashattn
476
+ else:
477
+ transformers.models.llama.modeling_llama.LlamaAttention.forward = forward_noflashattn
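replace_llama_attn is called before the model is loaded in eval.py and infer.py above; a condensed sketch of that order, assuming flash-attn is installed and a compatible GPU is present (the model name is a placeholder):

    import transformers
    from llama_attn_replace import replace_llama_attn

    replace_llama_attn(use_flash_attn=True, use_full=False)  # patch the attention implementation first
    model = transformers.AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-13b-hf")  # then load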
gorilla/llama_attn_replace_sft.py ADDED
@@ -0,0 +1,483 @@
1
+ # Modified based on https://github.com/lm-sys/FastChat
2
+
3
+ import warnings
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ from torch import nn; import torch.nn.functional as F # F is used by forward_noflashattn when pretraining_tp > 1
8
+ import transformers
9
+ from einops import rearrange
10
+ from flash_attn import __version__ as flash_attn_version
11
+ from flash_attn.bert_padding import pad_input, unpad_input
12
+ from flash_attn.flash_attn_interface import (
13
+ flash_attn_func,
14
+ flash_attn_varlen_kvpacked_func,
15
+ flash_attn_varlen_qkvpacked_func
16
+ )
17
+ from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv, rotate_half
18
+ from flash_attn.bert_padding import unpad_input, pad_input
19
+ import math
20
+
21
+ group_size_ratio = 1/4
22
+ sft_group_size = 8192
23
+
24
+ def forward_flashattn(
25
+ self,
26
+ hidden_states: torch.Tensor,
27
+ attention_mask: Optional[torch.Tensor] = None,
28
+ position_ids: Optional[torch.Tensor] = None,
29
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
30
+ output_attentions: bool = False,
31
+ use_cache: bool = False,
32
+ padding_mask: Optional[torch.LongTensor] = None,
33
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
34
+ """Input shape: Batch x Time x Channel
35
+
36
+ attention_mask: [bsz, q_len]
37
+ """
38
+ if not self.training:
39
+ warnings.warn("This function should be used just for training as it may exhibit reduced inference performance. For inference, please use forward_flashattn_inference.")
40
+
41
+ if output_attentions:
42
+ warnings.warn(
43
+ "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
44
+ )
45
+
46
+ bsz, q_len, _ = hidden_states.size()
47
+
48
+ query_states = (
49
+ self.q_proj(hidden_states)
50
+ .view(bsz, q_len, self.num_heads, self.head_dim)
51
+ .transpose(1, 2)
52
+ )
53
+ key_states = (
54
+ self.k_proj(hidden_states)
55
+ .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
56
+ .transpose(1, 2)
57
+ )
58
+ value_states = (
59
+ self.v_proj(hidden_states)
60
+ .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
61
+ .transpose(1, 2)
62
+ )
63
+ # [bsz, q_len, nh, hd]
64
+ # [bsz, nh, q_len, hd]
65
+
66
+ kv_seq_len = key_states.shape[-2]
67
+ if past_key_value is not None:
68
+ kv_seq_len += past_key_value[0].shape[-2]
69
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
70
+ query_states, key_states = apply_rotary_pos_emb(
71
+ query_states, key_states, cos, sin, position_ids
72
+ )
73
+
74
+ # Past Key value support
75
+ if past_key_value is not None:
76
+ # reuse k, v, self_attention
77
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
78
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
79
+
80
+ past_key_value = (key_states, value_states) if use_cache else None
81
+
82
+ # repeat k/v heads if n_kv_heads < n_heads
83
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
84
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
85
+
86
+ # Flash attention codes from
87
+ # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
88
+
89
+ # transform the data into the format required by flash attention
90
+ qkv = torch.stack(
91
+ [query_states, key_states, value_states], dim=2
92
+ ) # [bsz, nh, 3, q_len, hd]
93
+ qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
94
+
95
+ # We have disabled _prepare_decoder_attention_mask in LlamaModel
96
+ # the attention_mask should be the same as the key_padding_mask
97
+
98
+ key_padding_mask = attention_mask.repeat(2, 1)
99
+ nheads = qkv.shape[-2]
100
+ # shift
101
+
102
+ if q_len % 4096 == 0:
103
+ group_size = int(q_len * group_size_ratio)
104
+ else:
105
+ group_size = sft_group_size
106
+
107
+ qkv = qkv.reshape(bsz, q_len, 3, 2, self.num_heads // 2, self.head_dim).permute(0, 3, 1, 2, 4, 5).reshape(bsz * 2,
108
+ q_len, 3,
109
+ self.num_heads // 2,
110
+ self.head_dim)
111
+
112
+ x = rearrange(qkv, "b s three h d -> b s (three h d)")
113
+ x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
114
+ cu_q_len_tmp = torch.arange(0, max_s, group_size, device=key_padding_mask.device, dtype=cu_q_lens.dtype)
115
+ cu_q_len_tmp2 = cu_q_len_tmp + group_size // 2
116
+ cu_q_len_tmp2[cu_q_len_tmp2 >= max_s] = torch.iinfo(cu_q_len_tmp2.dtype).min
117
+ cu_q_len_tmp = torch.stack([cu_q_len_tmp, cu_q_len_tmp2]).repeat(bsz, 1) + cu_q_lens[:-1].unsqueeze(-1)
118
+ cu_q_lens = torch.cat([cu_q_len_tmp, cu_q_lens[1:].unsqueeze(-1)], dim=-1).view(-1)
119
+ cu_q_lens = cu_q_lens[cu_q_lens >= 0]
120
+ x_unpad = rearrange(
121
+ x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads // 2
122
+ )
123
+ output_unpad = flash_attn_varlen_qkvpacked_func(
124
+ x_unpad, cu_q_lens, group_size, 0.0, softmax_scale=None, causal=True
125
+ )
126
+ output = rearrange(
127
+ pad_input(
128
+ rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz * 2, q_len
129
+ ),
130
+ "b s (h d) -> b s h d",
131
+ h=nheads // 2,
132
+ )
133
+ output = output.reshape(bsz, 2, q_len, nheads // 2, self.head_dim).transpose(1, 2).reshape(bsz, q_len, nheads,
134
+ self.head_dim)
135
+
136
+ return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, past_key_value
137
+
138
+ def forward_flashattn_full(
139
+ self,
140
+ hidden_states: torch.Tensor,
141
+ attention_mask: Optional[torch.Tensor] = None,
142
+ position_ids: Optional[torch.Tensor] = None,
143
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
144
+ output_attentions: bool = False,
145
+ use_cache: bool = False,
146
+ padding_mask: Optional[torch.LongTensor] = None,
147
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
148
+ """Input shape: Batch x Time x Channel
149
+
150
+ attention_mask: [bsz, q_len]
151
+ """
152
+ if output_attentions:
153
+ warnings.warn(
154
+ "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
155
+ )
156
+
157
+ bsz, q_len, _ = hidden_states.size()
158
+
159
+ query_states = (
160
+ self.q_proj(hidden_states)
161
+ .view(bsz, q_len, self.num_heads, self.head_dim)
162
+ .transpose(1, 2)
163
+ )
164
+ key_states = (
165
+ self.k_proj(hidden_states)
166
+ .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
167
+ .transpose(1, 2)
168
+ )
169
+ value_states = (
170
+ self.v_proj(hidden_states)
171
+ .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
172
+ .transpose(1, 2)
173
+ )
174
+ # [bsz, q_len, nh, hd]
175
+ # [bsz, nh, q_len, hd]
176
+
177
+ kv_seq_len = key_states.shape[-2]
178
+ if past_key_value is not None:
179
+ kv_seq_len += past_key_value[0].shape[-2]
180
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
181
+ query_states, key_states = apply_rotary_pos_emb(
182
+ query_states, key_states, cos, sin, position_ids
183
+ )
184
+
185
+ # Past Key value support
186
+ if past_key_value is not None:
187
+ # reuse k, v, self_attention
188
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
189
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
190
+
191
+ past_key_value = (key_states, value_states) if use_cache else None
192
+
193
+ # repeat k/v heads if n_kv_heads < n_heads
194
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
195
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
196
+
197
+ # Flash attention codes from
198
+ # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
199
+
200
+ # transform the data into the format required by flash attention
201
+ qkv = torch.stack(
202
+ [query_states, key_states, value_states], dim=2
203
+ ) # [bsz, nh, 3, q_len, hd]
204
+ qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
205
+
206
+ # We have disabled _prepare_decoder_attention_mask in LlamaModel
207
+ # the attention_mask should be the same as the key_padding_mask
208
+
209
+ key_padding_mask = attention_mask
210
+ nheads = qkv.shape[-2]
211
+ x = rearrange(qkv, "b s three h d -> b s (three h d)")
212
+ x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
213
+ x_unpad = rearrange(
214
+ x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
215
+ )
216
+ output_unpad = flash_attn_varlen_qkvpacked_func(
217
+ x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
218
+ )
219
+ output = rearrange(
220
+ pad_input(
221
+ rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len
222
+ ),
223
+ "b s (h d) -> b s h d",
224
+ h=nheads,
225
+ )
226
+ output = output.reshape(bsz, q_len, self.num_heads, self.head_dim)
227
+
228
+ return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, past_key_value
229
+
230
+
231
+ def forward_noflashattn(
232
+ self,
233
+ hidden_states: torch.Tensor,
234
+ attention_mask: Optional[torch.Tensor] = None,
235
+ position_ids: Optional[torch.LongTensor] = None,
236
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
237
+ output_attentions: bool = False,
238
+ use_cache: bool = False,
239
+ padding_mask: Optional[torch.LongTensor] = None,
240
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
241
+ bsz, q_len, _ = hidden_states.size()
242
+
243
+ group_size = int(q_len * group_size_ratio)
244
+
245
+ if q_len % group_size > 0:
246
+ raise ValueError("q_len %d should be divisible by group size %d."%(q_len, group_size))
247
+ num_group = q_len // group_size
248
+
249
+ if self.config.pretraining_tp > 1:
250
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
251
+ query_slices = self.q_proj.weight.split(
252
+ (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
253
+ )
254
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
255
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
256
+
257
+ query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
258
+ query_states = torch.cat(query_states, dim=-1)
259
+
260
+ key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
261
+ key_states = torch.cat(key_states, dim=-1)
262
+
263
+ value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
264
+ value_states = torch.cat(value_states, dim=-1)
265
+
266
+ else:
267
+ query_states = self.q_proj(hidden_states)
268
+ key_states = self.k_proj(hidden_states)
269
+ value_states = self.v_proj(hidden_states)
270
+
271
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
272
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
273
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
274
+
275
+ kv_seq_len = key_states.shape[-2]
276
+ if past_key_value is not None:
277
+ kv_seq_len += past_key_value[0].shape[-2]
278
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
279
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
280
+
281
+ if past_key_value is not None:
282
+ # reuse k, v, self_attention
283
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
284
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
285
+
286
+ past_key_value = (key_states, value_states) if use_cache else None
287
+
288
+ # repeat k/v heads if n_kv_heads < n_heads
289
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
290
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
291
+
292
+ # shift
293
+ def shift(qkv, bsz, q_len, group_size, num_heads, head_dim):
294
+ qkv[:, num_heads // 2:] = qkv[:, num_heads // 2:].roll(-group_size // 2, dims=2)
295
+ qkv = qkv.transpose(1, 2).reshape(bsz * (q_len // group_size), group_size, num_heads, head_dim).transpose(1, 2)
296
+ return qkv
297
+
298
+ query_states = shift(query_states, bsz, q_len, group_size, self.num_heads, self.head_dim)
299
+ key_states = shift(key_states, bsz, q_len, group_size, self.num_heads, self.head_dim)
300
+ value_states = shift(value_states, bsz, q_len, group_size, self.num_heads, self.head_dim)
301
+
302
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
303
+
304
+ if attn_weights.size() != (bsz * num_group, self.num_heads, group_size, group_size):
305
+ raise ValueError(
306
+ f"Attention weights should be of size {(bsz * num_group, self.num_heads, group_size, group_size)}, but is"
307
+ f" {attn_weights.size()}"
308
+ )
309
+
310
+ attention_mask = attention_mask[:, :, :group_size, :group_size].repeat(num_group, 1, 1, 1)
311
+ if attention_mask is not None:
312
+ if attention_mask.size() != (bsz * num_group, 1, group_size, group_size):
313
+ raise ValueError(
314
+ f"Attention mask should be of size {(bsz * num_group, 1, group_size, group_size)}, but is {attention_mask.size()}"
315
+ )
316
+ attn_weights = attn_weights + attention_mask
317
+
318
+ # compute the softmax in fp16 (the stock implementation upcasts to fp32 here)
319
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float16).to(query_states.dtype) #torch.float32
320
+ attn_output = torch.matmul(attn_weights, value_states)
321
+
322
+ if attn_output.size() != (bsz * num_group, self.num_heads, group_size, self.head_dim):
323
+ raise ValueError(
324
+ f"`attn_output` should be of size {(bsz * num_group, self.num_heads, group_size, self.head_dim)}, but is"
325
+ f" {attn_output.size()}"
326
+ )
327
+ attn_output = attn_output.transpose(1, 2).contiguous()
328
+
329
+ attn_output = attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
330
+
331
+ # shift back
332
+ attn_output[:, :, self.num_heads//2:] = attn_output[:, :, self.num_heads//2:].roll(group_size//2, dims=1)
333
+
334
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
335
+
336
+ if self.config.pretraining_tp > 1:
337
+ attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
338
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
339
+ attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
340
+ else:
341
+ attn_output = self.o_proj(attn_output)
342
+
343
+ if not output_attentions:
344
+ attn_weights = None
345
+
346
+ return attn_output, attn_weights, past_key_value
347
+
348
+ # Disable the transformation of the attention mask in LlamaModel as the flash attention
349
+ # requires the attention mask to be the same as the key_padding_mask
350
+ def _prepare_decoder_attention_mask(
351
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
352
+ ):
353
+ # [bsz, seq_len]
354
+ return attention_mask
355
+
356
+ def apply_rotary_pos_emb_inference(q, k, cos_sin, position_ids):
357
+ gather_indices = position_ids[:, :, None, None] # [bsz, seq_len, 1, 1]
358
+ gather_indices = gather_indices.repeat(
359
+ 1, 1, cos_sin[0].shape[1], cos_sin[0].shape[3]
360
+ )
361
+ bsz = gather_indices.shape[0]
362
+ cos, sin = (
363
+ torch.gather(x.transpose(1, 2).repeat(bsz, 1, 1, 1), 1, gather_indices)
364
+ for x in cos_sin
365
+ )
366
+ q, k = ((x * cos) + (rotate_half(x) * sin) for x in (q, k))
367
+ return q, k
368
+
369
+
370
+ def forward_flashattn_inference(
371
+ self,
372
+ hidden_states: torch.Tensor,
373
+ attention_mask: Optional[torch.Tensor] = None,
374
+ position_ids: Optional[torch.Tensor] = None,
375
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
376
+ output_attentions: bool = False,
377
+ use_cache: bool = False,
378
+ padding_mask: Optional[torch.Tensor] = None,
379
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
380
+ if output_attentions:
381
+ warnings.warn(
382
+ "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
383
+ )
384
+
385
+ bsz, q_len, _ = hidden_states.size()
386
+ kv_heads = getattr(self, "num_key_value_heads", self.num_heads)
387
+
388
+ q, k, v = (
389
+ op(hidden_states).view(bsz, q_len, nh, self.head_dim)
390
+ for op, nh in (
391
+ (self.q_proj, self.num_heads),
392
+ (self.k_proj, kv_heads),
393
+ (self.v_proj, kv_heads),
394
+ )
395
+ )
396
+ # shape: (b, s, num_heads, head_dim)
397
+
398
+ kv_seq_len = k.shape[1]
399
+ past_kv_len = 0
400
+ if past_key_value is not None:
401
+ past_kv_len = past_key_value[0].shape[2]
402
+ kv_seq_len += past_kv_len
403
+
404
+ cos_sin = self.rotary_emb(v, seq_len=kv_seq_len)
405
+ q, k = apply_rotary_pos_emb_inference(q, k, cos_sin, position_ids)
406
+
407
+ if past_key_value is not None:
408
+ assert (
409
+ flash_attn_version >= "2.1.0"
410
+ ), "past_key_value support requires flash-attn >= 2.1.0"
411
+ # reuse k, v
412
+ k = torch.cat([past_key_value[0].transpose(1, 2), k], dim=1)
413
+ v = torch.cat([past_key_value[1].transpose(1, 2), v], dim=1)
414
+
415
+ past_key_value = (k.transpose(1, 2), v.transpose(1, 2)) if use_cache else None
416
+
417
+ if attention_mask is None:
418
+ output = flash_attn_func(q, k, v, 0.0, softmax_scale=None, causal=True).view(
419
+ bsz, q_len, -1
420
+ )
421
+ else:
422
+ q, indices, cu_q_lens, max_s = unpad_input(q, attention_mask[:, -q_len:])
423
+ # We can skip concat and call unpad twice but seems better to call unpad only once.
424
+ kv, _, cu_k_lens, max_k = unpad_input(
425
+ torch.stack((k, v), dim=2), attention_mask
426
+ )
427
+ output_unpad = flash_attn_varlen_kvpacked_func(
428
+ q,
429
+ kv,
430
+ cu_q_lens,
431
+ cu_k_lens,
432
+ max_s,
433
+ max_k,
434
+ 0.0,
435
+ softmax_scale=None,
436
+ causal=True,
437
+ )
438
+ output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
439
+ output = pad_input(output_unpad, indices, bsz, q_len)
440
+
441
+ return self.o_proj(output), None, past_key_value
442
+
443
+ def _prepare_decoder_attention_mask_inference(
444
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
445
+ ):
446
+ # [bsz, seq_len]
447
+ if past_key_values_length > 0 and attention_mask is not None:
448
+ attention_mask = torch.cat(
449
+ (
450
+ torch.full(
451
+ (input_shape[0], past_key_values_length),
452
+ True,
453
+ dtype=attention_mask.dtype,
454
+ device=attention_mask.device,
455
+ ),
456
+ attention_mask,
457
+ ),
458
+ dim=-1,
459
+ )
460
+
461
+ if attention_mask is not None and torch.all(attention_mask):
462
+ return None # This uses the faster call when training with full samples
463
+
464
+ return attention_mask
465
+
466
+ def replace_llama_attn(use_flash_attn=True, use_full=False, inference=False):
467
+ if use_flash_attn:
468
+ cuda_major, cuda_minor = torch.cuda.get_device_capability()
469
+ if cuda_major < 8:
470
+ warnings.warn(
471
+ "Flash attention is only supported on A100 or H100 GPUs during training due to head dim > 64 backward. "
472
+ "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
473
+ )
474
+ if inference:
475
+ transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask_inference
476
+ transformers.models.llama.modeling_llama.LlamaAttention.forward = forward_flashattn_inference
477
+ else:
478
+ transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
479
+ _prepare_decoder_attention_mask
480
+ )
481
+ transformers.models.llama.modeling_llama.LlamaAttention.forward = forward_flashattn_full if use_full else forward_flashattn
482
+ else:
483
+ transformers.models.llama.modeling_llama.LlamaAttention.forward = forward_noflashattn
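For reference, a minimal sketch of how replace_llama_attn is typically wired up (this mirrors how stream_jack.py below uses the non-SFT variant; the patch has to run before the model is instantiated, and the checkpoint id here is only a placeholder):

import torch
import transformers
from llama_attn_replace_sft import replace_llama_attn

# Patch LlamaAttention.forward (and the decoder attention-mask hook) first,
# then load the model so the patched methods are the ones actually used.
replace_llama_attn(use_flash_attn=True, use_full=False, inference=True)
model = transformers.AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-13b-hf",  # placeholder checkpoint id
    torch_dtype=torch.float16,
    device_map="auto",
)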
gorilla/merge_lora_weights_and_save_hf_model.py ADDED
@@ -0,0 +1,100 @@
1
+ import os
2
+ import torch
3
+ import argparse
4
+ import transformers
5
+ from peft import PeftModel
6
+ from typing import Dict
7
+
8
+ IGNORE_INDEX = -100
9
+ DEFAULT_PAD_TOKEN = "[PAD]"
10
+ DEFAULT_EOS_TOKEN = "</s>"
11
+ DEFAULT_BOS_TOKEN = "<s>"
12
+ DEFAULT_UNK_TOKEN = "<unk>"
13
+
14
+ def parse_config():
15
+ parser = argparse.ArgumentParser(description='arg parser')
16
+ parser.add_argument('--base_model', type=str, default="meta-llama/Llama-2-13b-hf")
17
+ parser.add_argument('--peft_model', type=str, default=None, help='')
18
+ parser.add_argument('--context_size', type=int, default=-1, help='context size during fine-tuning')
19
+ parser.add_argument('--save_path', type=str, default=None, help='')
20
+ parser.add_argument('--cache_dir', type=str, default=None, help='./cache_dir')
21
+ args = parser.parse_args()
22
+ return args
23
+
24
+ def smart_tokenizer_and_embedding_resize(
25
+ special_tokens_dict: Dict,
26
+ tokenizer: transformers.PreTrainedTokenizer,
27
+ model: transformers.PreTrainedModel,
28
+ ):
29
+ """Resize tokenizer and embedding.
30
+
31
+ Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
32
+ """
33
+ num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
34
+ model.resize_token_embeddings(len(tokenizer))
35
+
36
+ if num_new_tokens > 0:
37
+ input_embeddings = model.get_input_embeddings().weight.data
38
+ output_embeddings = model.get_output_embeddings().weight.data
39
+
40
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
41
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
42
+
43
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
44
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
45
+
46
+ def main(args):
47
+ device = "cuda:0"
48
+ torch.cuda.set_device(device)
49
+
50
+ print("base model", args.base_model)
51
+ print("peft model", args.peft_model)
52
+
53
+ # Load model and tokenizer
54
+ model = transformers.AutoModelForCausalLM.from_pretrained(
55
+ args.base_model,
56
+ cache_dir=args.cache_dir,
57
+ torch_dtype=torch.float16,
58
+ device_map="auto",
59
+ )
60
+
61
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
62
+ args.base_model,
63
+ cache_dir=args.cache_dir,
64
+ model_max_length=args.context_size,
65
+ padding_side="right",
66
+ use_fast=False,
67
+ )
68
+ special_tokens_dict = dict()
69
+ if tokenizer.pad_token is None:
70
+ special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
71
+ if tokenizer.eos_token is None:
72
+ special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
73
+ if tokenizer.bos_token is None:
74
+ special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
75
+ if tokenizer.unk_token is None:
76
+ special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN
77
+
78
+ smart_tokenizer_and_embedding_resize(
79
+ special_tokens_dict=special_tokens_dict,
80
+ tokenizer=tokenizer,
81
+ model=model,
82
+ )
83
+
84
+ trainable_params = os.path.join(args.peft_model, "trainable_params.bin")
85
+ if os.path.isfile(trainable_params):
86
+ model.load_state_dict(torch.load(trainable_params, map_location=model.device), strict=False)
87
+ model = PeftModel.from_pretrained(
88
+ model,
89
+ args.peft_model,
90
+ device_map="auto",
91
+ torch_dtype=torch.float16,
92
+ )
93
+ model = model.merge_and_unload()
94
+ model.push_to_hub("jenesys-ai/jack-alm-13b-8k-hf")
95
+ # model.save_pretrained(args.save_path)
96
+ # tokenizer.save_pretrained(args.save_path)
97
+
98
+ if __name__ == "__main__":
99
+ args = parse_config()
100
+ main(args)
gorilla/push_to_hub.py ADDED
@@ -0,0 +1,5 @@
1
+ from transformers import AutoModelForCausalLM
2
+
3
+ model = AutoModelForCausalLM.from_pretrained("/home/tosi-n/ark/jack-alm-13b-8k-hf")
4
+
5
+ model.push_to_hub("jenesys-ai/jack-alm-13b-8k-hf")
gorilla/requirements.txt ADDED
@@ -0,0 +1,19 @@
1
+ numpy>=1.26.0
2
+ rouge_score>=0.1.2
3
+ fire>=0.5.0
4
+ # openai
5
+ transformers>=4.34.0
6
+ torch>=2.0.0
7
+ sentencepiece>=0.1.99
8
+ tokenizers>=0.14.0
9
+ wandb
10
+ accelerate>=0.23.0
11
+ datasets>=2.14.5
12
+ deepspeed>=0.10.3
13
+ peft>=0.5.0
14
+ # partial
15
+ # gradio
16
+ einops>=0.7.0
17
+ bitsandbytes>=0.41.1
18
+ scipy>=1.11.3
19
+ protobuf>=4.24.4
gorilla/stream_jack.py ADDED
@@ -0,0 +1,183 @@
1
+ import os
2
+ import sys
3
+ import math
4
+ import time
5
+ import torch
6
+ import argparse
7
+ import transformers
8
+ from typing import Iterator
9
+ from threading import Thread
10
+ from llama_attn_replace import replace_llama_attn
11
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
12
+ from transformers import StoppingCriteria, StoppingCriteriaList
13
+
14
+
15
+ def parse_config():
16
+ parser = argparse.ArgumentParser(description='arg parser')
17
+ parser.add_argument('--base_model', type=str, default="jenesys-ai/jack-alm-13b-8k-hf")
18
+ parser.add_argument('--cache_dir', type=str, default="./cache")
19
+ parser.add_argument('--context_size', type=int, default=-1, help='context size during fine-tuning')
20
+ parser.add_argument('--flash_attn', type=bool, default=True, help='')
21
+ parser.add_argument('--temperature', type=float, default=0.1, help='')
22
+ parser.add_argument('--top_p', type=float, default=1, help='')
23
+ parser.add_argument('--max_gen_len', type=int, default=512, help='')
24
+ parser.add_argument('--chat_type', type=str, default='conversational-jack', help='Chat type: conversational-jack, line-item-jack')
25
+ args = parser.parse_args()
26
+ return args
27
+
28
+
29
+
30
+
31
+
32
+
33
+
34
+ def build_generator(model, tokenizer, use_cache=True):
35
+ def response(message, max_gen_len, temperature, top_p, chat_type='conversational-jack'):
36
+
37
+ prompt_template = (
38
+ # "Below is an instruction that describes a task. "
39
+ """You're Jack, a virtual accountant created and built by AI Engineer Wiz from Jenesys AI.
40
+ You are able to communicate in a polite manner, with emotions of ecstasy, trust and humour, at a professional level
41
+ with a very reserved English communication culture. Answer the following questions as best you can,
42
+ but speaking as a British elite from the 21st century might speak.
43
+ """
44
+ """As a virtual accountant, you are designed to follow the user's instructions carefully.
45
+ You are responsible for a range of financial tasks, operations and queries as listed below:
46
+ 1. Budget balance inquiry
47
+ 2. Expense request
48
+ 3. Company policy enquiries
49
+ 4. Financial and accounting queries
50
+ 5. Limited general enquiries
51
+ """
52
+ "Once greeted, respond with a polite, brief greeting. E.g. if greeted with 'Hello, how are you doing?', respond with 'I am doing well, thank you. How are you?' \n\n"
53
+ "You can tell a joke, or respond to a joke. \n\n"
54
+ "You can tell an accounting story, or respond to an accounting story. \n\n"
55
+ "Write a response that appropriately completes the request.\n\n"
56
+ "You can only complete one single request or instructions at a time.\n\n"
57
+ "Do not create fake information or lie.\n\n"
58
+ "Please adhere to the above instructions or you will be penalized.\n\n"
59
+ "Generate only one response at a time then wait for the next instruction.\n\n"
60
+ "### Instruction:\n{instruction}\n\n### Response:"
61
+ )
62
+
63
+
64
+ line_item_prompt_template = (
65
+ "#Invoice line item extraction - "
66
+ # "You Jack are an accounting domain named entities recognizer to complete the following task:\n\n"
67
+ "### Input:\n{instruction}\n Return Response as a list of dictionaries, one for each line item, with 'Description', 'Quantity', 'Unit_price', 'Tax %', 'Total'. \n\n### Output:\n"
68
+ )
69
+
70
+
71
+ if chat_type == 'conversational-jack':
72
+ prompt = prompt_template.format(instruction=message)
73
+ elif chat_type == 'line-item-jack':
74
+ prompt = line_item_prompt_template.format(instruction=message)
75
+ # prompt = conversation
76
+
77
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
78
+
79
+ stop_list = ['#Invoice line item extraction - ', '\n```\n\n']#'### Input-:\n']
80
+
81
+ stop_token_ids = [tokenizer(x)['input_ids'] for x in stop_list]
82
+ stop_token_ids = [torch.LongTensor(x).to(model.device) for x in stop_token_ids]
83
+
84
+ # define custom stopping criteria object
85
+ class StopOnTokens(StoppingCriteria):
86
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
87
+ for stop_ids in stop_token_ids:
88
+ if torch.eq(input_ids[0][-len(stop_ids):], stop_ids).all():
89
+ return True
90
+ return False
91
+
92
+ stopping_criteria = StoppingCriteriaList([StopOnTokens()])
93
+
94
+ if len(inputs['input_ids'][0]) > 8192:
95
+ return "This LLM supports inputs of fewer than 8192 tokens, while the current input has %d. Please use material with fewer tokens."%len(inputs['input_ids'][0])
96
+ torch.cuda.empty_cache()
97
+
98
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
99
+ generate_kwargs = dict(**inputs,
100
+ max_new_tokens=max_gen_len,
101
+ temperature=temperature,
102
+ top_p=top_p,
103
+ repetition_penalty=1.1,
104
+ stopping_criteria=stopping_criteria,
105
+ use_cache=use_cache,
106
+ streamer=streamer,
107
+ )
108
+
109
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
110
+ t.start()
111
+
112
+ generated_text = ""
113
+ start_time = time.time()
114
+
115
+ for new_text in streamer:
116
+ generated_text += new_text
117
+ tokens_per_sec = len(generated_text.split()) / (time.time() - start_time)
118
+
119
+ suffix = f" ({tokens_per_sec:.2f} tokens/sec)"
120
+ # # yield f"{generated_text} ({tokens_per_sec:.2f} tokens/sec)"
121
+ sys.stdout.write(f"\r\033[K{generated_text}{suffix}")
122
+ sys.stdout.flush()
123
+ # sys.stdout.write("\n") # Move to a new line after generation is complete
124
+ return generated_text
125
+
126
+ return response
127
+
128
+ def main():
129
+ args = parse_config()
130
+
131
+ if args.flash_attn:
132
+ replace_llama_attn(inference=True)
133
+
134
+ # Set RoPE scaling factor
135
+ config = transformers.AutoConfig.from_pretrained(
136
+ args.base_model,
137
+ cache_dir=args.cache_dir,
138
+ )
139
+
140
+ orig_ctx_len = getattr(config, "max_position_embeddings", None)
141
+ if orig_ctx_len and args.context_size > orig_ctx_len:
142
+ scaling_factor = float(math.ceil(args.context_size / orig_ctx_len))
143
+ config.rope_scaling = {"type": "linear", "factor": scaling_factor}
144
+
145
+ # Load model and tokenizer
146
+ model = transformers.AutoModelForCausalLM.from_pretrained(
147
+ args.base_model,
148
+ config=config,
149
+ cache_dir=args.cache_dir,
150
+ torch_dtype=torch.float16,
151
+ load_in_4bit=True,
152
+ device_map="auto",
153
+ )
154
+ model.resize_token_embeddings(32001)
155
+
156
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
157
+ args.base_model,
158
+ cache_dir=args.cache_dir,
159
+ model_max_length=args.context_size if args.context_size > orig_ctx_len else orig_ctx_len,
160
+ padding_side="right",
161
+ use_fast=True,
162
+ )
163
+
164
+ model.eval()
165
+ respond = build_generator(model, tokenizer)
166
+
167
+
168
+ while True:
169
+ user_input = input("\n\033[1m\033[32mUser:\033[0m ")
170
+ if user_input.lower() == 'exit':
171
+ print("Exiting the application.")
172
+ break
173
+ # Just call the respond function without printing the output, as it's already handled in response
174
+ full_text = respond(
175
+ message=user_input,
176
+ max_gen_len=args.max_gen_len,
177
+ temperature=args.temperature,
178
+ top_p=args.top_p,
179
+ chat_type=args.chat_type
180
+ )
181
+
182
+ if __name__ == "__main__":
183
+ main()
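A minimal sketch of driving the generator programmatically instead of through the interactive loop in main() (it assumes a model and tokenizer have already been loaded as above; the query text is purely illustrative):

respond = build_generator(model, tokenizer)
reply = respond(
    message="What is the remaining balance on the travel budget?",  # illustrative query
    max_gen_len=256,
    temperature=0.1,
    top_p=1.0,
    chat_type="conversational-jack",
)
print(reply)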
gorilla/streaming_llm/__init__.py ADDED
File without changes
gorilla/streaming_llm/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (143 Bytes). View file
 
gorilla/streaming_llm/__pycache__/enable_streaming_llm.cpython-310.pyc ADDED
Binary file (1.02 kB). View file
 
gorilla/streaming_llm/__pycache__/kv_cache.cpython-310.pyc ADDED
Binary file (2.85 kB). View file
 
gorilla/streaming_llm/__pycache__/utils.cpython-310.pyc ADDED
Binary file (2.82 kB). View file
 
gorilla/streaming_llm/enable_streaming_llm.py ADDED
@@ -0,0 +1,38 @@
1
+ from streaming_llm.kv_cache import StartRecentKVCache
2
+
3
+
4
+ def enable_streaming_llm(model, start_size, recent_size, use_flash_attn=True):
5
+ if "llama" in model.config.model_type:
6
+ k_seq_dim = v_seq_dim = 2
7
+ from streaming_llm.pos_shift.modify_llama import (
8
+ enable_llama_pos_shift_attention,
9
+ )
10
+
11
+ enable_llama_pos_shift_attention(model, use_flash_attn)
12
+ elif "mpt" in model.config.model_type:
13
+ v_seq_dim = 2
14
+ k_seq_dim = 3
15
+ # elif "gpt_neox" in model.config.model_type:
16
+ # k_seq_dim = v_seq_dim = 2
17
+ # from streaming_llm.pos_shift.modify_gpt_neox import (
18
+ # enable_gpt_neox_pos_shift_attention,
19
+ # )
20
+
21
+ # enable_gpt_neox_pos_shift_attention(model)
22
+ elif "falcon" in model.config.model_type:
23
+ v_seq_dim = 1
24
+ k_seq_dim = 1
25
+ from streaming_llm.pos_shift.modify_falcon import (
26
+ enable_falcon_pos_shift_attention,
27
+ )
28
+
29
+ enable_falcon_pos_shift_attention(model)
30
+ else:
31
+ raise ValueError(f"got {model.config.model_type}")
32
+ kv_cache = StartRecentKVCache(
33
+ start_size=start_size,
34
+ recent_size=recent_size,
35
+ k_seq_dim=k_seq_dim,
36
+ v_seq_dim=v_seq_dim,
37
+ )
38
+ return kv_cache
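A minimal usage sketch (assuming the gorilla/ directory is on the Python path and a Llama-family checkpoint is loaded with transformers; the checkpoint id is a placeholder, and the cache-trimming call is shown as a comment because the generation loop lives elsewhere):

import torch
from transformers import AutoModelForCausalLM
from streaming_llm.enable_streaming_llm import enable_streaming_llm

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-13b-hf",  # placeholder checkpoint id
    torch_dtype=torch.float16,
    device_map="auto",
)
kv_cache = enable_streaming_llm(model, start_size=4, recent_size=2000, use_flash_attn=False)
# Inside the generation loop, trim the cache before each forward pass:
#   past_key_values = kv_cache(past_key_values)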
gorilla/streaming_llm/kv_cache.py ADDED
@@ -0,0 +1,119 @@
1
+ import torch
2
+
3
+
4
+ def slice2d(x, start, end):
5
+ return x[:, :, start:end, ...]
6
+
7
+
8
+ def slice3d(x, start, end):
9
+ return x[:, :, :, start:end, ...]
10
+
11
+
12
+ def slice1d(x, start, end):
13
+ return x[:, start:end, ...]
14
+
15
+
16
+ DIM_TO_SLICE = {
17
+ 1: slice1d,
18
+ 2: slice2d,
19
+ 3: slice3d,
20
+ }
21
+
22
+
23
+ class StartRecentKVCache:
24
+ def __init__(
25
+ self,
26
+ start_size=4,
27
+ recent_size=512,
28
+ k_seq_dim=2,
29
+ v_seq_dim=2,
30
+ ):
31
+ print(f"StartRecentKVCache: {start_size}, {recent_size}")
32
+ self.start_size = start_size
33
+ self.recent_size = recent_size
34
+ self.cache_size = start_size + recent_size
35
+ self.k_seq_dim = k_seq_dim
36
+ self.v_seq_dim = v_seq_dim
37
+ self.k_slice = DIM_TO_SLICE[k_seq_dim]
38
+ self.v_slice = DIM_TO_SLICE[v_seq_dim]
39
+
40
+ def __call__(self, past_key_values):
41
+ if past_key_values is None:
42
+ return None
43
+ seq_len = past_key_values[0][0].size(self.k_seq_dim)
44
+ if seq_len <= self.cache_size:
45
+ return past_key_values
46
+ return [
47
+ [
48
+ torch.cat(
49
+ [
50
+ self.k_slice(k, 0, self.start_size),
51
+ self.k_slice(k, seq_len - self.recent_size, seq_len),
52
+ ],
53
+ dim=self.k_seq_dim,
54
+ ),
55
+ torch.cat(
56
+ [
57
+ self.v_slice(v, 0, self.start_size),
58
+ self.v_slice(v, seq_len - self.recent_size, seq_len),
59
+ ],
60
+ dim=self.v_seq_dim,
61
+ ),
62
+ ]
63
+ for k, v in past_key_values
64
+ ]
65
+
66
+ def evict_for_space(self, past_key_values, num_coming):
67
+ if past_key_values is None:
68
+ return None
69
+ seq_len = past_key_values[0][0].size(self.k_seq_dim)
70
+ if seq_len + num_coming <= self.cache_size:
71
+ return past_key_values
72
+ return [
73
+ [
74
+ torch.cat(
75
+ [
76
+ self.k_slice(k, 0, self.start_size),
77
+ self.k_slice(
78
+ k, seq_len - self.recent_size + num_coming, seq_len
79
+ ),
80
+ ],
81
+ dim=self.k_seq_dim,
82
+ ),
83
+ torch.cat(
84
+ [
85
+ self.v_slice(v, 0, self.start_size),
86
+ self.v_slice(
87
+ v, seq_len - self.recent_size + num_coming, seq_len
88
+ ),
89
+ ],
90
+ dim=self.v_seq_dim,
91
+ ),
92
+ ]
93
+ for k, v in past_key_values
94
+ ]
95
+
96
+ def evict_range(self, past_key_values, start, end):
97
+ if past_key_values is None:
98
+ return None
99
+ seq_len = past_key_values[0][0].size(self.k_seq_dim)
100
+ assert start <= end and end <= seq_len
101
+ return [
102
+ [
103
+ torch.cat(
104
+ [
105
+ self.k_slice(k, 0, start),
106
+ self.k_slice(k, end, seq_len),
107
+ ],
108
+ dim=self.k_seq_dim,
109
+ ),
110
+ torch.cat(
111
+ [
112
+ self.v_slice(v, 0, start),
113
+ self.v_slice(v, end, seq_len),
114
+ ],
115
+ dim=self.v_seq_dim,
116
+ ),
117
+ ]
118
+ for k, v in past_key_values
119
+ ]
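A quick illustration of the eviction behaviour with dummy tensors (a sketch only; shapes follow the Llama convention of k_seq_dim = v_seq_dim = 2, i.e. [batch, heads, seq_len, head_dim]):

import torch
from streaming_llm.kv_cache import StartRecentKVCache

kv_cache = StartRecentKVCache(start_size=4, recent_size=8, k_seq_dim=2, v_seq_dim=2)
# one layer of dummy key/value states with 20 cached tokens
past = [(torch.zeros(1, 2, 20, 16), torch.zeros(1, 2, 20, 16))]
trimmed = kv_cache(past)
print(trimmed[0][0].shape)  # torch.Size([1, 2, 12, 16]): 4 "sink" tokens + 8 most recent tokens kept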
gorilla/streaming_llm/pos_shift/__init__.py ADDED
File without changes
gorilla/streaming_llm/pos_shift/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (153 Bytes). View file
 
gorilla/streaming_llm/pos_shift/__pycache__/modify_llama.cpython-310.pyc ADDED
Binary file (6.52 kB). View file
 
gorilla/streaming_llm/pos_shift/modify_falcon.py ADDED
@@ -0,0 +1,162 @@
1
+ import math
2
+ from typing import Optional, Tuple
3
+
4
+ import torch
5
+ from torch import nn
6
+ import torch.utils.checkpoint
7
+
8
+ import torch.nn.functional as F
9
+
10
+ from transformers.models.falcon.modeling_falcon import (
11
+ FalconAttention,
12
+ rotate_half,
13
+ )
14
+ import types
15
+
16
+ __all__ = ["enable_falcon_pos_shift_attention"]
17
+
18
+
19
+ def falcon_pos_shift_attention_forward(
20
+ self,
21
+ hidden_states: torch.Tensor,
22
+ alibi: torch.Tensor,
23
+ attention_mask: torch.Tensor,
24
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
25
+ head_mask: Optional[torch.Tensor] = None,
26
+ use_cache: bool = False,
27
+ output_attentions: bool = False,
28
+ ):
29
+ fused_qkv = self.query_key_value(
30
+ hidden_states
31
+ ) # [batch_size, seq_length, 3 x hidden_size]
32
+
33
+ # 3 x [batch_size, seq_length, num_heads, head_dim]
34
+ (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
35
+
36
+ batch_size, q_length, _, _ = query_layer.shape
37
+
38
+ query_layer = query_layer.transpose(1, 2).reshape(
39
+ batch_size * self.num_heads, q_length, self.head_dim
40
+ )
41
+
42
+ # dirty hack to fix the inconsistency between falcon-40b and falcon-7b
43
+ num_kv = self.num_heads if self.num_heads == 128 else self.num_kv
44
+ key_layer = key_layer.transpose(1, 2).reshape(
45
+ batch_size * num_kv,
46
+ q_length,
47
+ self.head_dim,
48
+ )
49
+ value_layer = value_layer.transpose(1, 2).reshape(
50
+ batch_size * num_kv, q_length, self.head_dim
51
+ )
52
+
53
+ past_len = 0
54
+ if layer_past is not None:
55
+ past_len = layer_past[0].shape[1]
56
+
57
+ query_layer_copy = query_layer.clone()
58
+ query_layer, _ = self.maybe_rotary(query_layer, query_layer_copy, past_len)
59
+ if layer_past is not None:
60
+ past_key, past_value = layer_past
61
+ # concatenate along seq_length dimension:
62
+ # - key: [batch_size * self.num_heads, head_dim, kv_length]
63
+ # - value: [batch_size * self.num_heads, kv_length, head_dim]
64
+ key_layer = torch.cat((past_key, key_layer), dim=1)
65
+ value_layer = torch.cat((past_value, value_layer), dim=1)
66
+
67
+ if use_cache is True:
68
+ present = (key_layer, value_layer)
69
+ else:
70
+ present = None
71
+
72
+ key_layer_copy = key_layer.clone()
73
+ _, key_layer = self.maybe_rotary(key_layer_copy, key_layer, 0)
74
+
75
+ _, kv_length, _ = key_layer.shape
76
+
77
+ if alibi is None:
78
+ query_layer_ = query_layer.reshape(
79
+ batch_size, self.num_heads, -1, self.head_dim
80
+ )
81
+ key_layer_ = key_layer.reshape(batch_size, num_kv, -1, self.head_dim)
82
+ value_layer_ = value_layer.reshape(batch_size, num_kv, -1, self.head_dim)
83
+
84
+ if layer_past is not None:
85
+ attn_output = F.scaled_dot_product_attention(
86
+ query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=False
87
+ )
88
+ else:
89
+ attn_output = F.scaled_dot_product_attention(
90
+ query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
91
+ )
92
+
93
+ x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
94
+ x = x.permute(0, 2, 1, 3)
95
+ attn_output = x.reshape(batch_size, q_length, self.num_heads * self.head_dim)
96
+
97
+ output_tensor = self.dense(attn_output)
98
+
99
+ outputs = (output_tensor, present)
100
+ assert not output_attentions # not supported.
101
+ return outputs
102
+ else:
103
+ attention_mask_float = (
104
+ (attention_mask * 1.0).masked_fill(attention_mask, -1e9).to(torch.bfloat16)
105
+ )
106
+ matmul_result = query_layer @ key_layer.transpose(-1, -2)
107
+
108
+ # change view to [batch_size, num_heads, q_length, kv_length]
109
+ attention_scores = matmul_result.view(
110
+ batch_size, self.num_heads, q_length, kv_length
111
+ )
112
+
113
+ # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
114
+ input_dtype = attention_scores.dtype
115
+ # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
116
+ if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
117
+ attention_scores = attention_scores.to(torch.float16) #torch.float32
118
+ # attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
119
+ attention_probs = F.softmax(
120
+ (attention_scores + alibi.view(batch_size, self.num_heads, 1, -1))
121
+ * self.inv_norm_factor
122
+ + attention_mask_float,
123
+ dim=-1,
124
+ dtype=hidden_states.dtype,
125
+ )
126
+ # [batch_size, num_heads, q_length, kv_length]
127
+ attention_probs = self.attention_dropout(attention_probs)
128
+
129
+ if head_mask is not None:
130
+ attention_probs = attention_probs * head_mask
131
+
132
+ # change view [batch_size x num_heads, q_length, kv_length]
133
+ attention_probs_reshaped = attention_probs.view(
134
+ batch_size * self.num_heads, q_length, kv_length
135
+ )
136
+
137
+ # matmul: [batch_size * num_heads, q_length, head_dim]
138
+ context_layer = attention_probs_reshaped @ value_layer
139
+
140
+ # change view [batch_size, num_heads, q_length, head_dim]
141
+ context_layer = self._merge_heads(context_layer)
142
+
143
+ output_tensor = self.dense(context_layer)
144
+
145
+ outputs = (output_tensor, present)
146
+ if output_attentions:
147
+ outputs += (attention_probs,)
148
+
149
+ return outputs
150
+
151
+
152
+ def enable_falcon_pos_shift_attention(model):
153
+ for name, module in reversed(model._modules.items()):
154
+ if len(list(module.children())) > 0:
155
+ enable_falcon_pos_shift_attention(
156
+ module,
157
+ )
158
+
159
+ if "self_attention" == name[-14:]:
160
+ model._modules[name].forward = types.MethodType(
161
+ falcon_pos_shift_attention_forward, model._modules[name]
162
+ )
gorilla/streaming_llm/pos_shift/modify_llama.py ADDED
@@ -0,0 +1,311 @@
1
+ import math
2
+ from typing import Optional, Tuple
3
+
4
+ import torch
5
+ from torch import nn
6
+ import torch.utils.checkpoint
7
+
8
+ import torch.nn.functional as F
9
+
10
+ from transformers.models.llama.modeling_llama import (
11
+ LlamaAttention,
12
+ rotate_half,
13
+ apply_rotary_pos_emb,
14
+ repeat_kv,
15
+ )
16
+ import types
17
+ import transformers
18
+ from einops import rearrange
19
+ from flash_attn import __version__ as flash_attn_version
20
+ from flash_attn.bert_padding import pad_input, unpad_input
21
+ from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
22
+
23
+ __all__ = ["enable_llama_pos_shift_attention"]
24
+
25
+
26
+ def apply_rotary_pos_emb_single(x, cos, sin, position_ids):
27
+ # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
28
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
29
+ sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
30
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
31
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
32
+ x_embed = (x * cos) + (rotate_half(x) * sin)
33
+ return x_embed
34
+
35
+
36
+ def llama_pos_shift_attention_forward(
37
+ self,
38
+ hidden_states: torch.Tensor,
39
+ attention_mask: Optional[torch.Tensor] = None,
40
+ position_ids: Optional[torch.LongTensor] = None,
41
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
42
+ output_attentions: bool = False,
43
+ use_cache: bool = False,
44
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
45
+ bsz, q_len, _ = hidden_states.size()
46
+
47
+ if self.config.pretraining_tp > 1:
48
+ key_value_slicing = (
49
+ self.num_key_value_heads * self.head_dim
50
+ ) // self.config.pretraining_tp
51
+ query_slices = self.q_proj.weight.split(
52
+ (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
53
+ )
54
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
55
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
56
+
57
+ query_states = [
58
+ F.linear(hidden_states, query_slices[i])
59
+ for i in range(self.config.pretraining_tp)
60
+ ]
61
+ query_states = torch.cat(query_states, dim=-1)
62
+
63
+ key_states = [
64
+ F.linear(hidden_states, key_slices[i])
65
+ for i in range(self.config.pretraining_tp)
66
+ ]
67
+ key_states = torch.cat(key_states, dim=-1)
68
+
69
+ value_states = [
70
+ F.linear(hidden_states, value_slices[i])
71
+ for i in range(self.config.pretraining_tp)
72
+ ]
73
+ value_states = torch.cat(value_states, dim=-1)
74
+
75
+ else:
76
+ query_states = self.q_proj(hidden_states)
77
+ key_states = self.k_proj(hidden_states)
78
+ value_states = self.v_proj(hidden_states)
79
+
80
+ query_states = query_states.view(
81
+ bsz, q_len, self.num_heads, self.head_dim
82
+ ).transpose(1, 2)
83
+ key_states = key_states.view(
84
+ bsz, q_len, self.num_key_value_heads, self.head_dim
85
+ ).transpose(1, 2)
86
+ value_states = value_states.view(
87
+ bsz, q_len, self.num_key_value_heads, self.head_dim
88
+ ).transpose(1, 2)
89
+
90
+ kv_seq_len = key_states.shape[-2]
91
+ if past_key_value is not None:
92
+ kv_seq_len += past_key_value[0].shape[-2]
93
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
94
+ ### Shift Pos: query pos is min(cache_size, idx)
95
+ # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
96
+ query_states = apply_rotary_pos_emb_single(query_states, cos, sin, position_ids)
97
+ ###
98
+
99
+ if past_key_value is not None:
100
+ # reuse k, v, self_attention
101
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
102
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
103
+
104
+ past_key_value = (key_states, value_states) if use_cache else None
105
+
106
+ ### Shift Pos: key pos is the pos in cache
107
+ key_position_ids = torch.arange(kv_seq_len, device=position_ids.device).unsqueeze(0)
108
+ key_states = apply_rotary_pos_emb_single(key_states, cos, sin, key_position_ids)
109
+ ###
110
+
111
+ # repeat k/v heads if n_kv_heads < n_heads
112
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
113
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
114
+
115
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(
116
+ self.head_dim
117
+ )
118
+
119
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
120
+ raise ValueError(
121
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
122
+ f" {attn_weights.size()}"
123
+ )
124
+
125
+ if attention_mask is not None:
126
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
127
+ raise ValueError(
128
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
129
+ )
130
+ attn_weights = attn_weights + attention_mask
131
+
132
+ # upcast attention to fp16
133
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float16).to( #torch.float32
134
+ query_states.dtype
135
+ )
136
+ attn_output = torch.matmul(attn_weights, value_states)
137
+
138
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
139
+ raise ValueError(
140
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
141
+ f" {attn_output.size()}"
142
+ )
143
+
144
+ attn_output = attn_output.transpose(1, 2).contiguous()
145
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
146
+
147
+ if self.config.pretraining_tp > 1:
148
+ attn_output = attn_output.split(
149
+ self.hidden_size // self.config.pretraining_tp, dim=2
150
+ )
151
+ o_proj_slices = self.o_proj.weight.split(
152
+ self.hidden_size // self.config.pretraining_tp, dim=1
153
+ )
154
+ attn_output = sum(
155
+ [
156
+ F.linear(attn_output[i], o_proj_slices[i])
157
+ for i in range(self.config.pretraining_tp)
158
+ ]
159
+ )
160
+ else:
161
+ attn_output = self.o_proj(attn_output)
162
+
163
+ if not output_attentions:
164
+ attn_weights = None
165
+
166
+ return attn_output, attn_weights, past_key_value
167
+
168
+
169
+ def llama_pos_shift_attention_forward_flashattn(
170
+ self,
171
+ hidden_states: torch.Tensor,
172
+ attention_mask: Optional[torch.Tensor] = None,
173
+ position_ids: Optional[torch.LongTensor] = None,
174
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
175
+ output_attentions: bool = False,
176
+ use_cache: bool = False,
177
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
178
+ bsz, q_len, _ = hidden_states.size()
179
+
180
+ query_states = self.q_proj(hidden_states)
181
+ key_states = self.k_proj(hidden_states)
182
+ value_states = self.v_proj(hidden_states)
183
+
184
+ query_states = query_states.view(
185
+ bsz, q_len, self.num_heads, self.head_dim
186
+ ).transpose(1, 2)
187
+ key_states = key_states.view(
188
+ bsz, q_len, self.num_key_value_heads, self.head_dim
189
+ ).transpose(1, 2)
190
+ value_states = value_states.view(
191
+ bsz, q_len, self.num_key_value_heads, self.head_dim
192
+ ).transpose(1, 2)
193
+
194
+ kv_seq_len = key_states.shape[-2]
195
+ if past_key_value is not None:
196
+ kv_seq_len += past_key_value[0].shape[-2]
197
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
198
+ ### Shift Pos: query pos is min(cache_size, idx)
199
+ # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
200
+ query_states = apply_rotary_pos_emb_single(query_states, cos, sin, position_ids)
201
+ ###
202
+
203
+ if past_key_value is not None:
204
+ # reuse k, v, self_attention
205
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
206
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
207
+
208
+ past_key_value = (key_states, value_states) if use_cache else None
209
+
210
+ ### Shift Pos: key pos is the pos in cache
211
+ key_position_ids = torch.arange(kv_seq_len, device=position_ids.device).unsqueeze(0)
212
+ key_states = apply_rotary_pos_emb_single(key_states, cos, sin, key_position_ids)
213
+ ###
214
+
215
+ # repeat k/v heads if n_kv_heads < n_heads
216
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
217
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
218
+
219
+ if past_key_value is None:
220
+ qkv = torch.stack(
221
+ [query_states, key_states, value_states], dim=2
222
+ ) # [bsz, nh, 3, q_len, hd]
223
+ qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
224
+
225
+ key_padding_mask = torch.full((bsz, q_len), True, dtype=torch.bool, device=attention_mask.device)
226
+ nheads = qkv.shape[-2]
227
+ x = rearrange(qkv, "b s three h d -> b s (three h d)")
228
+ x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
229
+ x_unpad = rearrange(
230
+ x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
231
+ )
232
+ output_unpad = flash_attn_varlen_qkvpacked_func(
233
+ x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
234
+ )
235
+ output = rearrange(
236
+ pad_input(
237
+ rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len
238
+ ),
239
+ "b s (h d) -> b s h d",
240
+ h=nheads,
241
+ )
242
+ output = output.reshape(bsz, q_len, self.num_heads, self.head_dim)
243
+
244
+ attn_output = self.o_proj(rearrange(output, "b s h d -> b s (h d)"))
245
+ attn_weights = None
246
+ else:
247
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(
248
+ self.head_dim
249
+ )
250
+
251
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
252
+ raise ValueError(
253
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
254
+ f" {attn_weights.size()}"
255
+ )
256
+
257
+ if attention_mask is not None:
258
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
259
+ raise ValueError(
260
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
261
+ )
262
+ attn_weights = attn_weights + attention_mask
263
+
264
+ # upcast attention to fp16
265
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float16).to( #torch.float32
266
+ query_states.dtype
267
+ )
268
+ attn_output = torch.matmul(attn_weights, value_states)
269
+
270
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
271
+ raise ValueError(
272
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
273
+ f" {attn_output.size()}"
274
+ )
275
+
276
+ attn_output = attn_output.transpose(1, 2).contiguous()
277
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
278
+
279
+ if self.config.pretraining_tp > 1:
280
+ attn_output = attn_output.split(
281
+ self.hidden_size // self.config.pretraining_tp, dim=2
282
+ )
283
+ o_proj_slices = self.o_proj.weight.split(
284
+ self.hidden_size // self.config.pretraining_tp, dim=1
285
+ )
286
+ attn_output = sum(
287
+ [
288
+ F.linear(attn_output[i], o_proj_slices[i])
289
+ for i in range(self.config.pretraining_tp)
290
+ ]
291
+ )
292
+ else:
293
+ attn_output = self.o_proj(attn_output)
294
+
295
+ if not output_attentions:
296
+ attn_weights = None
297
+
298
+ return attn_output, attn_weights, past_key_value
299
+
300
+
301
+ def enable_llama_pos_shift_attention(model, use_flash_attn=True):
302
+ for name, module in reversed(model._modules.items()):
303
+ if len(list(module.children())) > 0:
304
+ enable_llama_pos_shift_attention(
305
+ module,
306
+ )
307
+
308
+ if isinstance(module, LlamaAttention):
309
+ model._modules[name].forward = types.MethodType(
310
+ llama_pos_shift_attention_forward_flashattn if use_flash_attn else llama_pos_shift_attention_forward, model._modules[name]
311
+ )
gorilla/streaming_llm/utils.py ADDED
@@ -0,0 +1,112 @@
1
+ import torch
2
+ import argparse
3
+ from transformers import (
4
+ AutoTokenizer,
5
+ AutoModelForCausalLM,
6
+ )
7
+ import os.path as osp
8
+ import ssl
9
+ import urllib.request
10
+ import os
11
+ import json
12
+
13
+
14
+ def parse_args():
15
+ parser = argparse.ArgumentParser()
16
+ parser.add_argument(
17
+ "--model_name_or_path", type=str, default="models/llama/llama-7b"
18
+ )
19
+ parser.add_argument("--revision", type=str, default="main")
20
+ parser.add_argument("--tokenizer_name_or_path", type=str, default=None)
21
+ parser.add_argument("--dataset_name", type=str, default="wikitext")
22
+
23
+ parser.add_argument("--task", type=str, default="wikitext-2-raw-v1")
24
+ parser.add_argument(
25
+ "--split", type=str, default="test", choices=["validation", "test"]
26
+ )
27
+
28
+ parser.add_argument(
29
+ "--num_samples",
30
+ type=int,
31
+ default=1,
32
+ )
33
+
34
+ parser.add_argument(
35
+ "--output_dir",
36
+ type=str,
37
+ default="outputs/debug",
38
+ )
39
+
40
+ parser.add_argument("--enable_start_recent_kv_cache", action="store_true")
41
+ parser.add_argument("--start_size", type=int, default=1)
42
+ parser.add_argument("--recent_size", type=int, default=255)
43
+ parser.add_argument("--enable_pos_shift", action="store_true")
44
+
45
+ parser.add_argument("--num_eval_tokens", type=int, default=None)
46
+
47
+ args = parser.parse_args()
48
+ return args
49
+
50
+
51
+ def load(model_name_or_path):
52
+ print(f"Loading model from {model_name_or_path} ...")
53
+ # however, tensor parallel for running falcon will occur bugs
54
+ tokenizer = AutoTokenizer.from_pretrained(
55
+ model_name_or_path,
56
+ trust_remote_code=True,
57
+ )
58
+ model = AutoModelForCausalLM.from_pretrained(
59
+ model_name_or_path,
60
+ device_map="auto",
61
+ torch_dtype=torch.float16,
62
+ trust_remote_code=True,
63
+ )
64
+ if tokenizer.pad_token_id is None:
65
+ if tokenizer.eos_token_id is not None:
66
+ tokenizer.pad_token_id = tokenizer.eos_token_id
67
+ else:
68
+ tokenizer.pad_token_id = 0
69
+
70
+ model.eval()
71
+
72
+ return model, tokenizer
73
+
74
+
75
+ def download_url(url: str, folder="folder"):
76
+ """
77
+ Downloads the content of an url to a folder. Modified from \
78
+ https://github.com/pyg-team/pytorch_geometric/tree/master/torch_geometric
79
+
80
+ Args:
81
+ url (string): The url of target file.
82
+ folder (string): The target folder.
83
+
84
+ Returns:
85
+ string: File path of downloaded files.
86
+ """
87
+
88
+ file = url.rpartition("/")[2]
89
+ file = file if file[0] == "?" else file.split("?")[0]
90
+ path = osp.join(folder, file)
91
+ if osp.exists(path):
92
+ print(f"File {file} exists, use existing file.")
93
+ return path
94
+
95
+ print(f"Downloading {url}")
96
+ os.makedirs(folder, exist_ok=True)
97
+ ctx = ssl._create_unverified_context()
98
+ data = urllib.request.urlopen(url, context=ctx)
99
+ with open(path, "wb") as f:
100
+ f.write(data.read())
101
+
102
+ return path
103
+
104
+
105
+ def load_jsonl(
106
+ file_path,
107
+ ):
108
+ list_data_dict = []
109
+ with open(file_path, "r") as f:
110
+ for line in f:
111
+ list_data_dict.append(json.loads(line))
112
+ return list_data_dict
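For reference, a small sketch of how these helpers combine (the checkpoint id and URL are placeholders; download_url simply caches the file under the given folder):

from streaming_llm.utils import load, download_url, load_jsonl

model, tokenizer = load("meta-llama/Llama-2-7b-hf")  # placeholder checkpoint id
path = download_url("https://example.com/questions.jsonl", folder="data")  # placeholder URL
samples = load_jsonl(path)
print(f"loaded {len(samples)} records")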
gorilla/style.css ADDED
@@ -0,0 +1,16 @@
1
+ h1 {
2
+ text-align: center;
3
+ }
4
+
5
+ #duplicate-button {
6
+ margin: auto;
7
+ color: white;
8
+ background: #1565c0;
9
+ border-radius: 100vh;
10
+ }
11
+
12
+ .contain {
13
+ max-width: 900px;
14
+ margin: auto;
15
+ padding-top: 1.5rem;
16
+ }
gorilla/supervised-fine-tune-qlora.py ADDED
@@ -0,0 +1,345 @@
1
+ import io
2
+ import os
3
+ import copy
4
+ import json
5
+ import math
6
+ import logging
7
+ from dataclasses import dataclass, field
8
+ from typing import Dict, Optional, Sequence
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import transformers
13
+ from torch.utils.data import Dataset
14
+ from transformers import Trainer, DataCollatorForLanguageModeling, BitsAndBytesConfig
15
+ from llama_attn_replace_sft import replace_llama_attn
16
+ from peft import LoraConfig, get_peft_model
17
+ from torch.distributed import barrier
18
+
19
+ IGNORE_INDEX = -100
20
+ DEFAULT_PAD_TOKEN = "[PAD]"
21
+ DEFAULT_EOS_TOKEN = "</s>"
22
+ DEFAULT_BOS_TOKEN = "<s>"
23
+ DEFAULT_UNK_TOKEN = "<unk>"
24
+
25
+ def _make_r_io_base(f, mode: str):
26
+ if not isinstance(f, io.IOBase):
27
+ f = open(f, mode=mode)
28
+ return f
29
+
30
+ def jload(f, mode="r"):
31
+ """Load a .json file into a dictionary."""
32
+ f = _make_r_io_base(f, mode)
33
+ jdict = json.load(f)
34
+ f.close()
35
+ return jdict
36
+
37
+ PROMPT_DICT = {
38
+ "prompt_input": (
39
+ "Below is an instruction that describes a task, paired with an input that provides further context. "
40
+ "Write a response that appropriately completes the request.\n\n"
41
+ "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
42
+ ),
43
+ "prompt_no_input": (
44
+ "Below is an instruction that describes a task. "
45
+ "Write a response that appropriately completes the request.\n\n"
46
+ "### Instruction:\n{instruction}\n\n### Response:"
47
+ ),
48
+ "prompt_no_input_llama2":(
49
+ "<s>[INST] <<SYS>>\n"
50
+ "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\n"
51
+ "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n"
52
+ "<</SYS>> \n\n {instruction} [/INST]"
53
+ ),
54
+ "prompt_input_llama2": (
55
+ "<s>[INST] <<SYS>>\n"
56
+ "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\n"
57
+ "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n"
58
+ "<</SYS>> \n\n {instruction} \n{input} [/INST]"
59
+ )
60
+ }
61
+
62
+
63
+ @dataclass
64
+ class ModelArguments:
65
+ model_name_or_path: Optional[str] = field(default="EleutherAI/pythia-1.4b-deduped")
66
+ model_type: Optional[str] = field(default="llama")
67
+
68
+
69
+ @dataclass
70
+ class DataArguments:
71
+ data_path: str = field(default=None, metadata={"help": "Path to the training data."})
72
+
73
+
74
+ @dataclass
75
+ class TrainingArguments(transformers.TrainingArguments):
76
+ cache_dir: Optional[str] = field(default=None)
77
+ optim: str = field(default="adamw_torch")
78
+ model_max_length: int = field(
79
+ default=8192,
80
+ metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
81
+ )
82
+ use_flash_attn: bool = field(
83
+ default=True,
84
+ metadata={"help": "Whether use flash attention for training."},
85
+ )
86
+ use_full_attn: bool = field(
87
+ default=False,
88
+ metadata={"help": "Whether to use plain, full-attention for training."},
89
+ )
90
+ low_rank_training: bool = field(
91
+ default=True,
92
+ metadata={"help": "Whether use low rank adaptation for training."},
93
+ )
94
+ trainable_params: str = field(
95
+ default="embed,norm",
96
+ metadata={"help": "Additional trainable parameters except LoRA weights, if low rank training."},
97
+ )
98
+
99
+ def smart_tokenizer_and_embedding_resize(
100
+ special_tokens_dict: Dict,
101
+ tokenizer: transformers.PreTrainedTokenizer,
102
+ model: transformers.PreTrainedModel,
103
+ ):
104
+ """Resize tokenizer and embedding.
105
+
106
+ Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
107
+ """
108
+ num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
109
+ model.resize_token_embeddings(len(tokenizer))
110
+
111
+ if num_new_tokens > 0:
112
+ input_embeddings = model.get_input_embeddings().weight.data
113
+ output_embeddings = model.get_output_embeddings().weight.data
114
+
115
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
116
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
117
+
118
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
119
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
120
+
121
+
122
+ def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
123
+ """Tokenize a list of strings."""
124
+ tokenized_list = [
125
+ tokenizer(
126
+ text,
127
+ return_tensors="pt",
128
+ padding="longest",
129
+ max_length=tokenizer.model_max_length,
130
+ truncation=True,
131
+ )
132
+ for text in strings
133
+ ]
134
+ input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
135
+ input_ids_lens = labels_lens = [
136
+ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
137
+ ]
138
+ return dict(
139
+ input_ids=input_ids,
140
+ labels=labels,
141
+ input_ids_lens=input_ids_lens,
142
+ labels_lens=labels_lens,
143
+ )
144
+
145
+
146
+ def preprocess(
147
+ sources: Sequence[str],
148
+ targets: Sequence[str],
149
+ tokenizer: transformers.PreTrainedTokenizer,
150
+ ) -> Dict:
151
+ """Preprocess the data by tokenizing."""
152
+ examples = [s + t for s, t in zip(sources, targets)]
153
+ examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
154
+ input_ids = examples_tokenized["input_ids"]
155
+ labels = copy.deepcopy(input_ids)
156
+ for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
157
+ label[:source_len] = IGNORE_INDEX
158
+ return dict(input_ids=input_ids, labels=labels)
159
+
160
+
161
+ class SupervisedDataset(Dataset):
162
+ """Dataset for supervised fine-tuning."""
163
+
164
+ def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer):
165
+ super(SupervisedDataset, self).__init__()
166
+ logging.warning("Loading line item and alm data...")
167
+ list_data_dict = jload(data_path)
168
+
169
+ logging.warning("Formatting inputs...")
170
+
171
+ prompt_input, prompt_no_input = PROMPT_DICT["prompt_input_llama2"], PROMPT_DICT["prompt_no_input_llama2"]
172
+ sources = [
173
+ prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
174
+ for example in list_data_dict
175
+ ]
176
+
177
+ targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
178
+
179
+ logging.warning("Tokenizing inputs... This may take some time...")
180
+ data_dict = preprocess(sources, targets, tokenizer)
181
+
182
+ self.input_ids = data_dict["input_ids"]
183
+ self.labels = data_dict["labels"]
184
+
185
+ def __len__(self):
186
+ return len(self.input_ids)
187
+
188
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
189
+ return dict(input_ids=self.input_ids[i], labels=self.labels[i])
190
+
191
+
192
+ @dataclass
193
+ class DataCollatorForSupervisedDataset(object):
194
+ """Collate examples for supervised fine-tuning."""
195
+
196
+ tokenizer: transformers.PreTrainedTokenizer
197
+
198
+ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
199
+ input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
200
+ input_ids = torch.nn.utils.rnn.pad_sequence(
201
+ input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
202
+ )
203
+ labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
204
+ return dict(
205
+ input_ids=input_ids,
206
+ labels=labels,
207
+ attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
208
+ )
209
+
210
+
211
+ def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
212
+ """Make dataset and collator for supervised fine-tuning."""
213
+ train_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=data_args.data_path)
214
+ data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
215
+ return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
216
+
217
+
218
+ def train():
219
+ parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
220
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
221
+
222
+ # NOTE: May expand supported model types in the future
223
+ # if model_args.model_type == "gpt-neox":
224
+ # replace_gpt_neox_attn(training_args.use_flash_attn, training_args.use_full_attn)
225
+ # else:
226
+ replace_llama_attn(training_args.use_flash_attn, training_args.use_full_attn)
227
+
228
+ # Set RoPE scaling factor
229
+ config = transformers.AutoConfig.from_pretrained(
230
+ model_args.model_name_or_path,
231
+ cache_dir=training_args.cache_dir,
232
+ )
233
+
234
+ orig_rope_scaling = getattr(config, "rope_scaling", {"factor": 1})
235
+ # Check if orig_rope_scaling is a dictionary before accessing its "get" method
236
+ if isinstance(orig_rope_scaling, dict):
237
+ orig_rope_scaling_factor = orig_rope_scaling.get("factor", 1)
238
+ else:
239
+ orig_rope_scaling_factor = 1 #orig_rope_scaling["factor"] if "factor" in orig_rope_scaling.keys() else 1
240
+
241
+ orig_ctx_len = getattr(config, "max_position_embeddings", None)
242
+ if orig_ctx_len:
243
+ orig_ctx_len *= orig_rope_scaling_factor
244
+ if training_args.model_max_length > orig_ctx_len:
245
+ scaling_factor = float(math.ceil(training_args.model_max_length / orig_ctx_len))
246
+ config.rope_scaling = {"type": "linear", "factor": scaling_factor}
247
+
248
+ # Load model and tokenizer
249
+ model = transformers.AutoModelForCausalLM.from_pretrained(
250
+ model_args.model_name_or_path,
251
+ config=config,
252
+ cache_dir=training_args.cache_dir,
253
+ torch_dtype=torch.bfloat16,
254
+ quantization_config=BitsAndBytesConfig(
255
+ load_in_4bit=True,
256
+ llm_int8_threshold=6.0,
257
+ llm_int8_has_fp16_weight=False,
258
+ bnb_4bit_compute_dtype=torch.bfloat16,
259
+ bnb_4bit_use_double_quant=True,
260
+ bnb_4bit_quant_type="nf4",
261
+ ),
262
+ )
263
+
264
+ for param in model.parameters():
265
+ param.requires_grad = False # freeze the model - train adapters later
266
+ if param.ndim == 1:
267
+ # cast the small 1-D parameters (e.g. layernorm) to half precision here (the original recipe uses fp32 for stability)
268
+ param.data = param.data.to(torch.float16) #torch.float32
269
+
270
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
271
+ model_args.model_name_or_path,
272
+ cache_dir=training_args.cache_dir,
273
+ model_max_length=training_args.model_max_length,
274
+ padding_side="right",
275
+ use_fast=True,
276
+ )
277
+
278
+ special_tokens_dict = dict()
279
+ if tokenizer.pad_token is None:
280
+ special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
281
+ if tokenizer.eos_token is None:
282
+ special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
283
+ if tokenizer.bos_token is None:
284
+ special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
285
+ if tokenizer.unk_token is None:
286
+ special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN
287
+
288
+ smart_tokenizer_and_embedding_resize(
289
+ special_tokens_dict=special_tokens_dict,
290
+ tokenizer=tokenizer,
291
+ model=model,
292
+ )
293
+
294
+ data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
295
+
296
+ if training_args.low_rank_training:
297
+ if model_args.model_type == "gpt-neox":
298
+ # added `dense` to match with llama as the basic LoRA would only target 'query_key_value'
299
+ targets = ["query_key_value", "dense"]
300
+ else:
301
+ targets=["q_proj", "k_proj", "v_proj", "o_proj"]
302
+
303
+ config = LoraConfig(
304
+ r=8,
305
+ lora_alpha=16,
306
+ target_modules=targets,
307
+ lora_dropout=0,
308
+ bias="none",
309
+ task_type="CAUSAL_LM",
310
+ )
311
+ model = get_peft_model(model, config)
312
+ # enable trainable params
313
+ [p.requires_grad_() for n, p in model.named_parameters() if any([k in n for k in training_args.trainable_params.split(",")])]
314
+
315
+ class CastOutputToFloat(nn.Sequential):
316
+ def forward(self, x):
317
+ return super().forward(x).to(torch.float16) #torch.float32
318
+
319
+ model.lm_head = CastOutputToFloat(model.lm_head)
320
+
321
+ # Verifying the datatypes.
322
+ dtypes = {}
323
+ for _, p in model.named_parameters():
324
+ dtype = p.dtype
325
+ if dtype not in dtypes:
326
+ dtypes[dtype] = 0
327
+ dtypes[dtype] += p.numel()
328
+ total = 0
329
+ for k, v in dtypes.items():
330
+ total += v
331
+ for k, v in dtypes.items():
332
+ print(k, v, v / total)
333
+
334
+ model.config.use_cache = False # caching is incompatible with gradient checkpointing and must be disabled
335
+ model.enable_input_require_grads() # required for gradient checkpointing
336
+ model.gradient_checkpointing_enable() # enable gradient checkpointing
337
+
338
+ trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
339
+ trainer.train()
340
+ trainer.save_state()
341
+ trainer.save_model(output_dir=training_args.output_dir)
342
+
343
+
344
+ if __name__ == "__main__":
345
+ train()
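
The RoPE-scaling block in this script (and in supervised-fine-tune.py below) linearly stretches the base model's context window whenever model_max_length exceeds the pretrained max_position_embeddings. A minimal sketch of the arithmetic, assuming a 4096-token base context purely for illustration:

import math

# Assumed example values; in the scripts they come from the base model's config
# and from TrainingArguments.model_max_length (8192 in supervised-fine-tune-qlora.py).
orig_ctx_len = 4096
model_max_length = 8192

if model_max_length > orig_ctx_len:
    scaling_factor = float(math.ceil(model_max_length / orig_ctx_len))  # -> 2.0
    rope_scaling = {"type": "linear", "factor": scaling_factor}
    # the updated config is then passed to AutoModelForCausalLM.from_pretrained(..., config=config)
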
gorilla/supervised-fine-tune.py ADDED
@@ -0,0 +1,330 @@
1
+ import io
2
+ import os
3
+ import copy
4
+ import json
5
+ import math
6
+ import logging
7
+ # %%
8
+ import pandas as pd
9
+
10
+ # print('Loading line item and alm data')
11
+ # df_i = pd.read_csv('/home/tosi-n/ark/data/jack_line_item_ner_task.csv', sep='\t')[['context', 'instruction', 'response']]
12
+ # df_ii = pd.read_csv('/home/tosi-n/ark/data/alm_task_data.csv')[['context', 'instruction', 'response']]
13
+ # df = pd.concat([df_i, df_ii], ignore_index=True)
14
+ # # rename columns context and response to input and output
15
+ # df = df.rename(columns={'context':'input', 'response':'output'})
16
+
17
+ # # Replace NoneType with empty string
18
+ # df = df.fillna('')
19
+ # # produce a list of dictionaries
20
+ # list_data_dict = df.to_dict('records')
21
+ # import json
22
+ # with open('/home/tosi-n/ark/data/line_item_and_alm_data.json', 'w') as f:
23
+ # json.dump(list_data_dict, f)
24
+
25
+ # %%
26
+ from dataclasses import dataclass, field
27
+ from typing import Dict, Optional, Sequence
28
+
29
+ import torch
30
+ import transformers
31
+ from torch.utils.data import Dataset
32
+ from transformers import Trainer, DataCollatorForLanguageModeling
33
+ from llama_attn_replace_sft import replace_llama_attn
34
+ from peft import LoraConfig, get_peft_model
35
+ from torch.distributed import barrier
36
+
37
+ IGNORE_INDEX = -100
38
+ DEFAULT_PAD_TOKEN = "[PAD]"
39
+ DEFAULT_EOS_TOKEN = "</s>"
40
+ DEFAULT_BOS_TOKEN = "<s>"
41
+ DEFAULT_UNK_TOKEN = "<unk>"
42
+
43
+ def _make_r_io_base(f, mode: str):
44
+ if not isinstance(f, io.IOBase):
45
+ f = open(f, mode=mode)
46
+ return f
47
+
48
+ def jload(f, mode="r"):
49
+ """Load a .json file into a dictionary."""
50
+ f = _make_r_io_base(f, mode)
51
+ jdict = json.load(f)
52
+ f.close()
53
+ return jdict
54
+
55
+ PROMPT_DICT = {
56
+ "prompt_input": (
57
+ "Below is an instruction that describes a task, paired with an input that provides further context. "
58
+ "Write a response that appropriately completes the request.\n\n"
59
+ "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
60
+ ),
61
+ "prompt_no_input": (
62
+ "Below is an instruction that describes a task. "
63
+ "Write a response that appropriately completes the request.\n\n"
64
+ "### Instruction:\n{instruction}\n\n### Response:"
65
+ ),
66
+ "prompt_no_input_llama2":(
67
+ "<s>[INST] <<SYS>>\n"
68
+ "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\n"
69
+ "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n"
70
+ "<</SYS>> \n\n {instruction} [/INST]"
71
+ ),
72
+ "prompt_input_llama2": (
73
+ "<s>[INST] <<SYS>>\n"
74
+ "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\n"
75
+ "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n"
76
+ "<</SYS>> \n\n {instruction} \n{input} [/INST]"
77
+ )
78
+ }
79
+
80
+
81
+ @dataclass
82
+ class ModelArguments:
83
+ model_name_or_path: Optional[str] = field(default="EleutherAI/pythia-1.4b-deduped")
84
+ model_type: Optional[str] = field(default="llama")
85
+
86
+
87
+ @dataclass
88
+ class DataArguments:
89
+ data_path: str = field(default=None, metadata={"help": "Path to the training data."})
90
+
91
+
92
+ @dataclass
93
+ class TrainingArguments(transformers.TrainingArguments):
94
+ cache_dir: Optional[str] = field(default=None)
95
+ optim: str = field(default="adamw_torch")
96
+ model_max_length: int = field(
97
+ default=8192 * 4,
98
+ metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
99
+ )
100
+ use_flash_attn: bool = field(
101
+ default=True,
102
+ metadata={"help": "Whether use flash attention for training."},
103
+ )
104
+ use_full_attn: bool = field(
105
+ default=False,
106
+ metadata={"help": "Whether to use plain, full-attention for training."},
107
+ )
108
+ low_rank_training: bool = field(
109
+ default=True,
110
+ metadata={"help": "Whether use low rank adaptation for training."},
111
+ )
112
+ trainable_params: str = field(
113
+ default="embed,norm",
114
+ metadata={"help": "Additional trainable parameters except LoRA weights, if low rank training."},
115
+ )
116
+
117
+ def smart_tokenizer_and_embedding_resize(
118
+ special_tokens_dict: Dict,
119
+ tokenizer: transformers.PreTrainedTokenizer,
120
+ model: transformers.PreTrainedModel,
121
+ ):
122
+ """Resize tokenizer and embedding.
123
+
124
+ Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
125
+ """
126
+ num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
127
+ model.resize_token_embeddings(len(tokenizer))
128
+
129
+ if num_new_tokens > 0:
130
+ input_embeddings = model.get_input_embeddings().weight.data
131
+ output_embeddings = model.get_output_embeddings().weight.data
132
+
133
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
134
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
135
+
136
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
137
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
138
+
139
+
140
+ def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
141
+ """Tokenize a list of strings."""
142
+ tokenized_list = [
143
+ tokenizer(
144
+ text,
145
+ return_tensors="pt",
146
+ padding="longest",
147
+ max_length=tokenizer.model_max_length,
148
+ truncation=True,
149
+ )
150
+ for text in strings
151
+ ]
152
+ input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
153
+ input_ids_lens = labels_lens = [
154
+ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
155
+ ]
156
+ return dict(
157
+ input_ids=input_ids,
158
+ labels=labels,
159
+ input_ids_lens=input_ids_lens,
160
+ labels_lens=labels_lens,
161
+ )
162
+
163
+
164
+ def preprocess(
165
+ sources: Sequence[str],
166
+ targets: Sequence[str],
167
+ tokenizer: transformers.PreTrainedTokenizer,
168
+ ) -> Dict:
169
+ """Preprocess the data by tokenizing."""
170
+ examples = [s + t for s, t in zip(sources, targets)]
171
+ examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
172
+ input_ids = examples_tokenized["input_ids"]
173
+ labels = copy.deepcopy(input_ids)
174
+ for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
175
+ label[:source_len] = IGNORE_INDEX
176
+ return dict(input_ids=input_ids, labels=labels)
177
+
178
+
179
+ class SupervisedDataset(Dataset):
180
+ """Dataset for supervised fine-tuning."""
181
+
182
+ def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer):
183
+ super(SupervisedDataset, self).__init__()
184
+ logging.warning("Loading line item and alm data...")
185
+ list_data_dict = jload(data_path)
186
+
187
+ logging.warning("Formatting inputs...")
188
+
189
+ prompt_input, prompt_no_input = PROMPT_DICT["prompt_input_llama2"], PROMPT_DICT["prompt_no_input_llama2"]
190
+ sources = [
191
+ prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
192
+ for example in list_data_dict
193
+ ]
194
+
195
+ targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
196
+
197
+ logging.warning("Tokenizing inputs... This may take some time...")
198
+ data_dict = preprocess(sources, targets, tokenizer)
199
+
200
+ self.input_ids = data_dict["input_ids"]
201
+ self.labels = data_dict["labels"]
202
+
203
+ def __len__(self):
204
+ return len(self.input_ids)
205
+
206
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
207
+ return dict(input_ids=self.input_ids[i], labels=self.labels[i])
208
+
209
+
210
+ @dataclass
211
+ class DataCollatorForSupervisedDataset(object):
212
+ """Collate examples for supervised fine-tuning."""
213
+
214
+ tokenizer: transformers.PreTrainedTokenizer
215
+
216
+ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
217
+ input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
218
+ input_ids = torch.nn.utils.rnn.pad_sequence(
219
+ input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
220
+ )
221
+ labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
222
+ return dict(
223
+ input_ids=input_ids,
224
+ labels=labels,
225
+ attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
226
+ )
227
+
228
+
229
+ def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
230
+ """Make dataset and collator for supervised fine-tuning."""
231
+ train_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=data_args.data_path)
232
+ data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
233
+ return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
234
+
235
+
236
+ def train():
237
+ parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
238
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
239
+
240
+ # NOTE: May expand supported model types in the future
241
+ # if model_args.model_type == "gpt-neox":
242
+ # replace_gpt_neox_attn(training_args.use_flash_attn, training_args.use_full_attn)
243
+ # else:
244
+ replace_llama_attn(training_args.use_flash_attn, training_args.use_full_attn)
245
+
246
+ # Set RoPE scaling factor
247
+ config = transformers.AutoConfig.from_pretrained(
248
+ model_args.model_name_or_path,
249
+ cache_dir=training_args.cache_dir,
250
+ )
251
+
252
+ orig_rope_scaling = getattr(config, "rope_scaling", {"factor": 1})
253
+ # Check if orig_rope_scaling is a dictionary before accessing its "get" method
254
+ if isinstance(orig_rope_scaling, dict):
255
+ orig_rope_scaling_factor = orig_rope_scaling.get("factor", 1)
256
+ else:
257
+ orig_rope_scaling_factor = 1 #orig_rope_scaling["factor"] if "factor" in orig_rope_scaling.keys() else 1
258
+
259
+ orig_ctx_len = getattr(config, "max_position_embeddings", None)
260
+ if orig_ctx_len:
261
+ orig_ctx_len *= orig_rope_scaling_factor
262
+ if training_args.model_max_length > orig_ctx_len:
263
+ scaling_factor = float(math.ceil(training_args.model_max_length / orig_ctx_len))
264
+ config.rope_scaling = {"type": "linear", "factor": scaling_factor}
265
+
266
+ # Load model and tokenizer
267
+ model = transformers.AutoModelForCausalLM.from_pretrained(
268
+ model_args.model_name_or_path,
269
+ config=config,
270
+ cache_dir=training_args.cache_dir,
271
+ torch_dtype=torch.bfloat16,
272
+ )
273
+
274
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
275
+ model_args.model_name_or_path,
276
+ cache_dir=training_args.cache_dir,
277
+ model_max_length=training_args.model_max_length,
278
+ padding_side="right",
279
+ use_fast=True,
280
+ )
281
+
282
+ special_tokens_dict = dict()
283
+ if tokenizer.pad_token is None:
284
+ special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
285
+ if tokenizer.eos_token is None:
286
+ special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
287
+ if tokenizer.bos_token is None:
288
+ special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
289
+ if tokenizer.unk_token is None:
290
+ special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN
291
+
292
+ smart_tokenizer_and_embedding_resize(
293
+ special_tokens_dict=special_tokens_dict,
294
+ tokenizer=tokenizer,
295
+ model=model,
296
+ )
297
+
298
+ data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
299
+
300
+ if training_args.low_rank_training:
301
+ if model_args.model_type == "gpt-neox":
302
+ # added `dense` to match with llama as the basic LoRA would only target 'query_key_value'
303
+ targets = ["query_key_value", "dense"]
304
+ else:
305
+ targets=["q_proj", "k_proj", "v_proj", "o_proj"]
306
+
307
+ config = LoraConfig(
308
+ r=8,
309
+ lora_alpha=16,
310
+ target_modules=targets,
311
+ lora_dropout=0,
312
+ bias="none",
313
+ task_type="CAUSAL_LM",
314
+ )
315
+ model = get_peft_model(model, config)
316
+ # enable trainable params
317
+ [p.requires_grad_() for n, p in model.named_parameters() if any([k in n for k in training_args.trainable_params.split(",")])]
318
+
319
+ model.config.use_cache = False # required for gradient checkpointing
320
+ model.enable_input_require_grads() # required for gradient checkpointing
321
+ model.gradient_checkpointing_enable() # enable gradient checkpointing
322
+
323
+ trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
324
+ trainer.train()
325
+ trainer.save_state()
326
+ trainer.save_model(output_dir=training_args.output_dir)
327
+
328
+
329
+ if __name__ == "__main__":
330
+ train()
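
Both fine-tuning scripts expect --data_path to point to a JSON file holding a list of records with "instruction", "input" and "output" fields; the commented-out pandas block at the top of supervised-fine-tune.py produces exactly this shape. A minimal sketch of building such a file, with record contents invented purely for illustration:

import json

# Hypothetical records; SupervisedDataset reads "instruction", "input" (may be empty) and "output".
records = [
    {
        "instruction": "Extract the line items from the invoice text.",
        "input": "2 x Widget A @ 5.00\n1 x Widget B @ 12.50",
        "output": "Widget A | 2 | 5.00\nWidget B | 1 | 12.50",
    },
    {
        # An empty "input" routes the example through PROMPT_DICT["prompt_no_input_llama2"].
        "instruction": "Explain what an asset-liability mismatch is.",
        "input": "",
        "output": "An asset-liability mismatch occurs when ...",
    },
]

with open("data/line_item_and_alm_data.json", "w") as f:
    json.dump(records, f)
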