SivilTaram committed on
Commit
470be5c
1 Parent(s): 753c587

update demo

Browse files
Files changed (6) hide show
  1. app.py +115 -0
  2. hub_name.py +198 -0
  3. lora/adapter_config.json +20 -0
  4. redirect.py +128 -0
  5. requirements.txt +3 -0
  6. util.py +170 -0
app.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from hub_name import LORA_HUB_NAMES
3
+ from random import shuffle
4
+ import pandas as pd
5
+ import streamlit as st
6
+ import contextlib
7
+ from functools import wraps
8
+ from io import StringIO
9
+ import contextlib
10
+ import redirect as rd
11
+ import torch
12
+ import shutil
13
+ import os
14
+
15
+
16
# Inline CSS injected once at import time: force Streamlit dataframes to
# use the full container width (the sidebar module table is wide).
css = """
<style>
.stDataFrame { width: 100% !important; }
</style>
"""
st.markdown(css, unsafe_allow_html=True)
22
+
23
+
24
def main():
    """Render the LoraHub demo page.

    Flow: pick candidate LoRA modules in the sidebar, provide parallel
    few-shot input/output examples, run LoraHub learning (gradient-free
    weight search over the selected modules), then download the merged
    LoRA module as a zip.
    """
    st.title("LoraHub")
    st.markdown("Low-rank adaptations (LoRA) are techniques for fine-tuning large language models on new tasks. We propose LoraHub, a framework that allows composing multiple LoRA modules trained on different tasks. The goal is to achieve good performance on unseen tasks using just a few examples, without needing extra parameters or training. And we want to build a marketplace where users can share their trained LoRA modules, thereby facilitating the application of these modules to new tasks.")

    # Typo fix: "avaiable" -> "available" in the user-facing description.
    st.markdown("In this demo, you will use available lora modules selected in the left sidebar to tackle your few-shot examples. When the LoraHub learning is done, you can download the final LoRA module and use it for your new task. You can check out more details in our [paper](https://huggingface.co/papers/2307.13269).")

    with st.sidebar:
        st.title("LoRA Module Pool")
        st.markdown(
            "The following modules are available for you to compose for your new task. Every module name is a peft repository in Huggingface Hub, and you can find them [here](https://huggingface.co/models?search=lorahub).")

        df = pd.DataFrame({
            "Index": list(range(len(LORA_HUB_NAMES))),
            "Module Name": LORA_HUB_NAMES,
        })
        st.data_editor(df,
                       disabled=["LoRA Module", "Index"],
                       hide_index=True)

        # Selection is stored in session state under "select_names" so the
        # "lucky modules" button callback can overwrite it.
        st.multiselect(
            'Select your favorite modules as the candidate for LoRA composition',
            list(range(len(LORA_HUB_NAMES))),
            [],
            key="select_names")

        def set_lucky_modules():
            # Put 20 random module indices into the multiselect state.
            names = list(range(len(LORA_HUB_NAMES)))
            shuffle(names)
            names = names[:20]
            st.session_state["select_names"] = names

        st.button(":game_die: Give 20 Lucky Modules",
                  on_click=set_lucky_modules)
        st.write('We will use the following modules', [
            LORA_HUB_NAMES[i] for i in st.session_state["select_names"]])

    st.subheader("Prepare your few-shot examples")

    txt_input = st.text_area('Examples Inputs (One Line One Input)',
                             '''
Infer the date from context. Q: Today, 8/3/1997, is a day that we will never forget. What is the date one week ago from today in MM/DD/YYYY? Options: (A) 03/27/1998 (B) 09/02/1997 (C) 07/27/1997 (D) 06/29/1997 (E) 07/27/1973 (F) 12/27/1997 A:
Infer the date from context. Q: May 6, 1992 is like yesterday to Jane, but that is actually ten years ago. What is the date tomorrow in MM/DD/YYYY? Options: (A) 04/16/2002 (B) 04/07/2003 (C) 05/07/2036 (D) 05/28/2002 (E) 05/07/2002 A:
Infer the date from context. Q: Today is the second day of the third month of 1966. What is the date one week ago from today in MM/DD/YYYY? Options: (A) 02/26/1966 (B) 01/13/1966 (C) 02/02/1966 (D) 10/23/1966 (E) 02/23/1968 (F) 02/23/1966 A:
'''.strip())

    txt_output = st.text_area('Examples Outputs (One Line One Output)', '''
(C)
(E)
(F)
'''.strip())

    max_step = st.slider('Maximum iteration step', 10, 1000, step=10)

    # st.subheader("Watch the logs below")
    buffer = st.expander("Learning Logs")

    if st.button(':rocket: Start!'):
        if len(st.session_state["select_names"]) == 0:
            st.error("Please select at least 1 module!")
        elif max_step < len(st.session_state["select_names"]):
            st.error(
                "Please specify a larger maximum iteration step than the number of selected modules!")
        else:
            buffer.text("* begin to perform lorahub learning *")
            # Imported lazily so the heavy model stack only loads on demand.
            from util import lorahub_learning
            # Mirror stderr (tqdm/optimizer progress) into the log expander.
            with rd.stderr(to=buffer):
                recommendation, final_lora = lorahub_learning(
                    [LORA_HUB_NAMES[i] for i in st.session_state["select_names"]],
                    txt_input, txt_output, max_inference_step=max_step)

            st.success("Lorahub learning finished! You got the following recommendation:")
            df = {
                "modules": [LORA_HUB_NAMES[i] for i in st.session_state["select_names"]],
                "weights": recommendation.value,
            }
            st.table(df)

            # Save the merged module next to lora/adapter_config.json so the
            # zip bundle is a complete, loadable peft adapter.
            torch.save(final_lora, "lora/adapter_model.bin")
            shutil.make_archive("lora_module", 'zip', "lora")
            with open("lora_module.zip", "rb") as fp:
                # Note: the previous unused `btn =` binding was dropped.
                st.download_button(
                    label="Download ZIP",
                    data=fp,
                    file_name="lora_module.zip",
                    mime="application/zip"
                )


if __name__ == "__main__":
    main()
hub_name.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Hub ids of the candidate LoRA modules offered in the demo sidebar.
# Each entry is a peft adapter repository on the Hugging Face Hub; the
# name prefix indicates they were trained on top of flan-t5-large.
LORA_HUB_NAMES = [
    "lorahub/flan_t5_large-qasc_qa_with_separated_facts_3",
    "lorahub/flan_t5_large-ag_news_subset",
    "lorahub/flan_t5_large-web_questions_whats_the_answer",
    "lorahub/flan_t5_large-wiki_hop_original_choose_best_object_affirmative_1",
    "lorahub/flan_t5_large-quoref_What_Is_The_Answer",
    "lorahub/flan_t5_large-qasc_is_correct_1",
    "lorahub/flan_t5_large-ropes_given_background_situation",
    "lorahub/flan_t5_large-duorc_SelfRC_title_generation",
    "lorahub/flan_t5_large-wiki_hop_original_choose_best_object_affirmative_3",
    "lorahub/flan_t5_large-wiki_hop_original_generate_subject",
    "lorahub/flan_t5_large-coqa",
    "lorahub/flan_t5_large-adversarial_qa_droberta_question_context_answer",
    "lorahub/flan_t5_large-amazon_polarity_flattering_or_not",
    "lorahub/flan_t5_large-quarel_choose_between",
    "lorahub/flan_t5_large-adversarial_qa_dbidaf_based_on",
    "lorahub/flan_t5_large-adversarial_qa_dbert_answer_the_following_q",
    "lorahub/flan_t5_large-dbpedia_14_given_a_list_of_category_what_does_the_title_belong_to",
    "lorahub/flan_t5_large-wiki_hop_original_choose_best_object_interrogative_1",
    "lorahub/flan_t5_large-trec",
    "lorahub/flan_t5_large-race_high_Write_a_multi_choice_question_options_given_",
    "lorahub/flan_t5_large-social_i_qa_Show_choices_and_generate_answer",
    "lorahub/flan_t5_large-app_reviews_categorize_rating_using_review",
    "lorahub/flan_t5_large-wiki_hop_original_generate_subject_and_object",
    "lorahub/flan_t5_large-true_case",
    "lorahub/flan_t5_large-wiki_qa_Topic_Prediction_Answer_Only",
    "lorahub/flan_t5_large-quartz_given_the_fact_answer_the_q",
    "lorahub/flan_t5_large-quail_context_question_description_answer_text",
    "lorahub/flan_t5_large-dbpedia_14_given_a_choice_of_categories_",
    "lorahub/flan_t5_large-dream_baseline",
    "lorahub/flan_t5_large-wiki_qa_Is_This_True_",
    "lorahub/flan_t5_large-glue_wnli",
    "lorahub/flan_t5_large-adversarial_qa_dbert_based_on",
    "lorahub/flan_t5_large-quoref_Read_And_Extract_",
    "lorahub/flan_t5_large-amazon_polarity_User_recommend_this_product",
    "lorahub/flan_t5_large-wiqa_what_is_the_final_step_of_the_following_process",
    "lorahub/flan_t5_large-ropes_plain_no_background",
    "lorahub/flan_t5_large-wiki_hop_original_choose_best_object_affirmative_2",
    "lorahub/flan_t5_large-race_middle_Select_the_best_answer_generate_span_",
    "lorahub/flan_t5_large-quoref_Answer_Question_Given_Context",
    "lorahub/flan_t5_large-wmt16_translate_tr-en",
    "lorahub/flan_t5_large-quoref_Found_Context_Online",
    "lorahub/flan_t5_large-wiki_qa_Decide_good_answer",
    "lorahub/flan_t5_large-para_crawl_enes",
    "lorahub/flan_t5_large-race_middle_Taking_a_test",
    "lorahub/flan_t5_large-ropes_background_new_situation_answer",
    "lorahub/flan_t5_large-fix_punct",
    "lorahub/flan_t5_large-super_glue_rte",
    "lorahub/flan_t5_large-ropes_background_situation_middle",
    "lorahub/flan_t5_large-race_high_Taking_a_test",
    "lorahub/flan_t5_large-wiki_bio_who",
    "lorahub/flan_t5_large-quartz_paragraph_question_plain_concat",
    "lorahub/flan_t5_large-ropes_plain_background_situation",
    "lorahub/flan_t5_large-quoref_Given_Context_Answer_Question",
    "lorahub/flan_t5_large-adversarial_qa_dbidaf_question_context_answer",
    "lorahub/flan_t5_large-wmt16_translate_ro-en",
    "lorahub/flan_t5_large-adversarial_qa_dbert_question_context_answer",
    "lorahub/flan_t5_large-duorc_ParaphraseRC_question_answering",
    "lorahub/flan_t5_large-race_high_Is_this_the_right_answer",
    "lorahub/flan_t5_large-sciq_Direct_Question",
    "lorahub/flan_t5_large-super_glue_wsc.fixed",
    "lorahub/flan_t5_large-super_glue_wic",
    "lorahub/flan_t5_large-quoref_Answer_Friend_Question",
    "lorahub/flan_t5_large-imdb_reviews_plain_text",
    "lorahub/flan_t5_large-race_middle_Select_the_best_answer",
    "lorahub/flan_t5_large-quail_context_question_answer_description_id",
    "lorahub/flan_t5_large-wiki_qa_found_on_google",
    "lorahub/flan_t5_large-glue_sst2",
    "lorahub/flan_t5_large-quail_context_description_question_answer_id",
    "lorahub/flan_t5_large-super_glue_cb",
    "lorahub/flan_t5_large-ropes_prompt_bottom_no_hint",
    "lorahub/flan_t5_large-anli_r1",
    "lorahub/flan_t5_large-ropes_read_background_situation",
    "lorahub/flan_t5_large-qasc_qa_with_separated_facts_2",
    "lorahub/flan_t5_large-quarel_heres_a_story",
    "lorahub/flan_t5_large-social_i_qa_Generate_the_question_from_the_answer",
    "lorahub/flan_t5_large-sciq_Multiple_Choice_Closed_Book_",
    "lorahub/flan_t5_large-math_dataset_algebra__linear_1d",
    "lorahub/flan_t5_large-yelp_polarity_reviews",
    "lorahub/flan_t5_large-adversarial_qa_droberta_tell_what_it_is",
    "lorahub/flan_t5_large-wiqa_what_might_be_the_last_step_of_the_process",
    "lorahub/flan_t5_large-adversarial_qa_dbidaf_answer_the_following_q",
    "lorahub/flan_t5_large-quoref_Guess_Answer",
    "lorahub/flan_t5_large-amazon_polarity_convey_negative_or_positive_sentiment",
    "lorahub/flan_t5_large-wiki_qa_Topic_Prediction_Question_Only",
    "lorahub/flan_t5_large-ropes_new_situation_background_answer",
    "lorahub/flan_t5_large-web_questions_potential_correct_answer",
    "lorahub/flan_t5_large-qasc_is_correct_2",
    "lorahub/flan_t5_large-quoref_Find_Answer",
    "lorahub/flan_t5_large-app_reviews_convert_to_rating",
    "lorahub/flan_t5_large-quail_description_context_question_answer_text",
    "lorahub/flan_t5_large-qasc_qa_with_separated_facts_4",
    "lorahub/flan_t5_large-qasc_qa_with_separated_facts_5",
    "lorahub/flan_t5_large-quoref_Guess_Title_For_Context",
    "lorahub/flan_t5_large-wiki_hop_original_explain_relation",
    "lorahub/flan_t5_large-ropes_prompt_beginning",
    "lorahub/flan_t5_large-gem_e2e_nlg",
    "lorahub/flan_t5_large-race_high_Select_the_best_answer_no_instructions_",
    "lorahub/flan_t5_large-quail_context_question_description_answer_id",
    "lorahub/flan_t5_large-qasc_qa_with_combined_facts_1",
    "lorahub/flan_t5_large-glue_cola",
    "lorahub/flan_t5_large-quail_description_context_question_answer_id",
    "lorahub/flan_t5_large-wiqa_which_of_the_following_is_the_supposed_perturbation",
    "lorahub/flan_t5_large-sciq_Direct_Question_Closed_Book_",
    "lorahub/flan_t5_large-wmt14_translate_fr-en",
    "lorahub/flan_t5_large-quoref_Context_Contains_Answer",
    "lorahub/flan_t5_large-kilt_tasks_hotpotqa_complex_question",
    "lorahub/flan_t5_large-amazon_polarity_negative_or_positive_tone",
    "lorahub/flan_t5_large-amazon_polarity_would_you_buy",
    "lorahub/flan_t5_large-wiki_qa_exercise",
    "lorahub/flan_t5_large-adversarial_qa_dbert_tell_what_it_is",
    "lorahub/flan_t5_large-word_segment",
    "lorahub/flan_t5_large-gem_dart",
    "lorahub/flan_t5_large-duorc_ParaphraseRC_extract_answer",
    "lorahub/flan_t5_large-duorc_ParaphraseRC_title_generation",
    "lorahub/flan_t5_large-ropes_plain_bottom_hint",
    "lorahub/flan_t5_large-wiki_bio_comprehension",
    "lorahub/flan_t5_large-anli_r2",
    "lorahub/flan_t5_large-quail_context_question_answer_description_text",
    "lorahub/flan_t5_large-wiki_hop_original_generate_object",
    "lorahub/flan_t5_large-squad_v1.1",
    "lorahub/flan_t5_large-wiki_qa_Jeopardy_style",
    "lorahub/flan_t5_large-lambada",
    "lorahub/flan_t5_large-quartz_having_read_above_passage",
    "lorahub/flan_t5_large-quartz_use_info_from_question_paragraph",
    "lorahub/flan_t5_large-wiki_bio_key_content",
    "lorahub/flan_t5_large-duorc_SelfRC_answer_question",
    "lorahub/flan_t5_large-duorc_ParaphraseRC_answer_question",
    "lorahub/flan_t5_large-wiki_qa_Topic_Prediction_Question_and_Answer_Pair",
    "lorahub/flan_t5_large-anli_r3",
    "lorahub/flan_t5_large-glue_mnli",
    "lorahub/flan_t5_large-wiki_bio_guess_person",
    "lorahub/flan_t5_large-race_high_Select_the_best_answer_generate_span_",
    "lorahub/flan_t5_large-glue_stsb",
    "lorahub/flan_t5_large-gem_web_nlg_en",
    "lorahub/flan_t5_large-adversarial_qa_droberta_based_on",
    "lorahub/flan_t5_large-duorc_SelfRC_question_answering",
    "lorahub/flan_t5_large-dream_read_the_following_conversation_and_answer_the_question",
    "lorahub/flan_t5_large-duorc_SelfRC_generate_question_by_answer",
    "lorahub/flan_t5_large-definite_pronoun_resolution",
    "lorahub/flan_t5_large-quartz_read_passage_below_choose",
    "lorahub/flan_t5_large-race_middle_Is_this_the_right_answer",
    "lorahub/flan_t5_large-wiqa_effect_with_label_answer",
    "lorahub/flan_t5_large-wiqa_what_might_be_the_first_step_of_the_process",
    "lorahub/flan_t5_large-sciq_Multiple_Choice",
    "lorahub/flan_t5_large-quartz_use_info_from_paragraph_question",
    "lorahub/flan_t5_large-quarel_do_not_use",
    "lorahub/flan_t5_large-quac",
    "lorahub/flan_t5_large-glue_qqp",
    "lorahub/flan_t5_large-quail_no_prompt_text",
    "lorahub/flan_t5_large-duorc_ParaphraseRC_decide_worth_it",
    "lorahub/flan_t5_large-wiqa_effect_with_string_answer",
    "lorahub/flan_t5_large-wiki_hop_original_choose_best_object_interrogative_2",
    "lorahub/flan_t5_large-bool_q",
    "lorahub/flan_t5_large-social_i_qa_Check_if_a_random_answer_is_valid_or_not",
    "lorahub/flan_t5_large-ropes_prompt_bottom_hint_beginning",
    "lorahub/flan_t5_large-newsroom",
    "lorahub/flan_t5_large-ropes_prompt_mix",
    "lorahub/flan_t5_large-quartz_answer_question_based_on",
    "lorahub/flan_t5_large-qasc_qa_with_separated_facts_1",
    "lorahub/flan_t5_large-race_high_Select_the_best_answer",
    "lorahub/flan_t5_large-duorc_ParaphraseRC_movie_director",
    "lorahub/flan_t5_large-amazon_polarity_user_satisfied",
    "lorahub/flan_t5_large-sentiment140",
    "lorahub/flan_t5_large-glue_mrpc",
    "lorahub/flan_t5_large-super_glue_multirc",
    "lorahub/flan_t5_large-quoref_Answer_Test",
    "lorahub/flan_t5_large-wiqa_what_is_the_missing_first_step",
    "lorahub/flan_t5_large-race_middle_Select_the_best_answer_no_instructions_",
    "lorahub/flan_t5_large-snli",
    "lorahub/flan_t5_large-dbpedia_14_pick_one_category_for_the_following_text",
    "lorahub/flan_t5_large-amazon_polarity_Is_this_review_negative",
    "lorahub/flan_t5_large-quarel_testing_students",
    "lorahub/flan_t5_large-glue_qnli",
    "lorahub/flan_t5_large-kilt_tasks_hotpotqa_final_exam",
    "lorahub/flan_t5_large-web_questions_get_the_answer",
    "lorahub/flan_t5_large-duorc_SelfRC_decide_worth_it",
    "lorahub/flan_t5_large-paws_wiki",
    "lorahub/flan_t5_large-social_i_qa_Show_choices_and_generate_index",
    "lorahub/flan_t5_large-duorc_SelfRC_extract_answer",
    "lorahub/flan_t5_large-drop",
    "lorahub/flan_t5_large-adversarial_qa_droberta_answer_the_following_q",
    "lorahub/flan_t5_large-amazon_polarity_Is_this_product_review_positive",
    "lorahub/flan_t5_large-quail_no_prompt_id",
    "lorahub/flan_t5_large-wiki_qa_automatic_system",
    "lorahub/flan_t5_large-sciq_Multiple_Choice_Question_First",
    "lorahub/flan_t5_large-squad_v2.0",
    "lorahub/flan_t5_large-wiqa_does_the_supposed_perturbation_have_an_effect",
    "lorahub/flan_t5_large-wiki_bio_what_content",
    "lorahub/flan_t5_large-duorc_SelfRC_movie_director",
    "lorahub/flan_t5_large-quarel_logic_test",
    "lorahub/flan_t5_large-quartz_answer_question_below",
    "lorahub/flan_t5_large-dbpedia_14_given_list_what_category_does_the_paragraph_belong_to",
    "lorahub/flan_t5_large-amazon_polarity_Is_this_review",
    "lorahub/flan_t5_large-race_middle_Write_a_multi_choice_question_options_given_",
    "lorahub/flan_t5_large-adversarial_qa_dbidaf_tell_what_it_is",
    "lorahub/flan_t5_large-quail_context_description_question_answer_text"
]
lora/adapter_config.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "base_model_name_or_path": "google/flan-t5-large",
3
+ "bias": "none",
4
+ "fan_in_fan_out": false,
5
+ "inference_mode": true,
6
+ "init_lora_weights": true,
7
+ "layers_pattern": null,
8
+ "layers_to_transform": null,
9
+ "lora_alpha": 32,
10
+ "lora_dropout": 0.1,
11
+ "modules_to_save": null,
12
+ "peft_type": "LORA",
13
+ "r": 16,
14
+ "revision": null,
15
+ "target_modules": [
16
+ "q",
17
+ "v"
18
+ ],
19
+ "task_type": "SEQ_2_SEQ_LM"
20
+ }
redirect.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import io
3
+ import contextlib
4
+ import sys
5
+ import re
6
+
7
+
8
+ class _Redirect:
9
+ class IOStuff(io.StringIO):
10
+ def __init__(self, trigger, max_buffer, buffer_separator, regex, dup=None):
11
+ super().__init__()
12
+ self._trigger = trigger
13
+ self._max_buffer = max_buffer
14
+ self._buffer_separator = buffer_separator
15
+ self._regex = regex and re.compile(regex)
16
+ self._dup = dup
17
+
18
+ def write(self, __s: str) -> int:
19
+ if self._max_buffer:
20
+ concatenated_len = super().tell() + len(__s)
21
+ if concatenated_len > self._max_buffer:
22
+ rest = self.get_filtered_output()[concatenated_len - self._max_buffer:]
23
+ if self._buffer_separator is not None:
24
+ rest = rest.split(self._buffer_separator, 1)[-1]
25
+ super().seek(0)
26
+ super().write(rest)
27
+ super().truncate(super().tell() + len(__s))
28
+ res = super().write(__s)
29
+ if self._dup is not None:
30
+ self._dup.write(__s)
31
+ self._trigger(self.get_filtered_output())
32
+ return res
33
+
34
+ def get_filtered_output(self):
35
+ if self._regex is None or self._buffer_separator is None:
36
+ return self.getvalue()
37
+
38
+ return self._buffer_separator.join(filter(self._regex.search, self.getvalue().split(self._buffer_separator)))
39
+
40
+ def print_at_end(self):
41
+ self._trigger(self.get_filtered_output())
42
+
43
+ def __init__(self, stdout=None, stderr=False, format=None, to=None, max_buffer=None, buffer_separator='\n',
44
+ regex=None, duplicate_out=False):
45
+ self.io_args = {'trigger': self._write, 'max_buffer': max_buffer, 'buffer_separator': buffer_separator,
46
+ 'regex': regex}
47
+ self.redirections = []
48
+ self.st = None
49
+ self.stderr = stderr is True
50
+ self.stdout = stdout is True or (stdout is None and not self.stderr)
51
+ self.format = format or 'code'
52
+ self.to = to
53
+ self.fun = None
54
+ self.duplicate_out = duplicate_out or None
55
+ self.active_nested = None
56
+
57
+ if not self.stdout and not self.stderr:
58
+ raise ValueError("one of stdout or stderr must be True")
59
+
60
+ if self.format not in ['text', 'markdown', 'latex', 'code', 'write']:
61
+ raise ValueError(
62
+ f"format need oneof the following: {', '.join(['text', 'markdown', 'latex', 'code', 'write'])}")
63
+
64
+ if self.to and (not hasattr(self.to, 'text') or not hasattr(self.to, 'empty')):
65
+ raise ValueError(f"'to' is not a streamlit container object")
66
+
67
+ def __enter__(self):
68
+ if self.st is not None:
69
+ if self.to is None:
70
+ if self.active_nested is None:
71
+ self.active_nested = self(format=self.format, max_buffer=self.io_args['max_buffer'],
72
+ buffer_separator=self.io_args['buffer_separator'],
73
+ regex=self.io_args['regex'], duplicate_out=self.duplicate_out)
74
+ return self.active_nested.__enter__()
75
+ else:
76
+ raise Exception("Already entered")
77
+ to = self.to or st
78
+
79
+ # to.text(f"{'stdout and stderr' if self.stdout and self.stderr else 'stdout' if self.stdout else 'stderr'}"
80
+ # f"{' [' + self.io_args['regex'] + ']' if self.io_args['regex'] else ''}"
81
+ # f":")
82
+ self.st = to.empty()
83
+ self.fun = getattr(self.st, self.format)
84
+
85
+ io_obj = None
86
+
87
+ def redirect(to_duplicate):
88
+ nonlocal io_obj
89
+ io_obj = _Redirect.IOStuff(dup=self.duplicate_out and to_duplicate, **self.io_args)
90
+ redirection = contextlib.redirect_stdout(io_obj)
91
+ self.redirections.append((redirection, io_obj))
92
+ redirection.__enter__()
93
+
94
+ if self.stderr:
95
+ redirect(sys.stderr)
96
+ if self.stdout:
97
+ redirect(sys.stdout)
98
+
99
+ return io_obj
100
+
101
+ def __call__(self, to=None, format=None, max_buffer=None, buffer_separator='\n', regex=None, duplicate_out=False):
102
+ return _Redirect(self.stdout, self.stderr, format=format, to=to, max_buffer=max_buffer,
103
+ buffer_separator=buffer_separator, regex=regex, duplicate_out=duplicate_out)
104
+
105
+ def __exit__(self, *exc):
106
+ if self.active_nested is not None:
107
+ nested = self.active_nested
108
+ if nested.active_nested is None:
109
+ self.active_nested = None
110
+ return nested.__exit__(*exc)
111
+
112
+ res = None
113
+ for redirection, io_obj in reversed(self.redirections):
114
+ res = redirection.__exit__(*exc)
115
+ io_obj.print_at_end()
116
+
117
+ self.redirections = []
118
+ self.st = None
119
+ self.fun = None
120
+ return res
121
+
122
+ def _write(self, data):
123
+ self.fun(data)
124
+
125
+
126
# Ready-made redirectors; calling one (e.g. `stderr(to=container)`) returns
# a fresh _Redirect instance to use as a context manager.
stdout = _Redirect()
stderr = _Redirect(stderr=True)
stdouterr = _Redirect(stdout=True, stderr=True)
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
peft
transformers
pandas
torch
datasets
nevergrad
util.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForSeq2SeqLM
2
+ import torch
3
+ from datasets import Dataset
4
+ from torch.utils.data import DataLoader
5
+ from transformers import default_data_collator
6
+ from transformers import AutoTokenizer
7
+ from tqdm import tqdm
8
+ import pandas as pd
9
+ import numpy
10
+ import random
11
+ import nevergrad as ng
12
+ from peft.utils.save_and_load import set_peft_model_state_dict, get_peft_model_state_dict
13
+ from peft import PeftModel, PeftConfig
14
+ from functools import partial
15
+
16
+ random.seed(42)
17
+ numpy.random.seed(42)
18
+
19
def load_base_model_and_lora_modules(lora_module_list):
    """Load the base seq2seq model, its tokenizer, and every candidate LoRA.

    The first entry of `lora_module_list` determines the base model (via its
    peft config) and also serves as the default adapter wrapped around it.

    Returns:
        (peft_model, tokenizer, cache) where `cache` maps each hub id to its
        LoRA state dict.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    first_module_id = lora_module_list[0]
    # The first adapter's peft config names the base model to load.
    base_model_name = PeftConfig.from_pretrained(first_module_id).base_model_name_or_path
    base_model = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)

    # Wrap the base model with the first adapter as the default one.
    peft_model = PeftModel.from_pretrained(base_model, first_module_id).to(device)
    peft_model.eval()

    print("> Begin to load lora modules")
    cache = {}
    for module_id in tqdm(lora_module_list):
        print("> Loading {} ...".format(module_id))
        loaded = PeftModel.from_pretrained(base_model, module_id)
        cache[module_id] = get_peft_model_state_dict(loaded)

    return peft_model, tokenizer, cache
42
+
43
+
44
def preprocess_function(examples, tokenizer):
    """Tokenize a batch of input/output example pairs for seq2seq training.

    Inputs are padded/truncated to 2048 tokens and targets to 256; label
    positions holding the pad token are replaced with -100 so the loss
    ignores them.
    """
    model_inputs = tokenizer(
        examples["input"],
        max_length=2048,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    tokenized_targets = tokenizer(
        examples["output"],
        max_length=256,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    label_ids = tokenized_targets["input_ids"]
    # -100 is the conventional ignore_index for seq2seq cross-entropy.
    label_ids[label_ids == tokenizer.pad_token_id] = -100
    model_inputs["labels"] = label_ids
    return model_inputs
65
+
66
+
67
def load_dataset_and_run(example_inputs, example_outputs, tokenizer):
    """Build a HF Dataset from parallel input/output lists and tokenize it."""
    records = []
    for idx in range(len(example_inputs)):
        records.append({"input": example_inputs[idx], "output": example_outputs[idx]})
    dataset = Dataset.from_pandas(pd.DataFrame(records))
    tokenize = partial(preprocess_function, tokenizer=tokenizer)
    return dataset.map(
        tokenize,
        batched=True,
        num_proc=1,
        desc="Running tokenizer on dataset",
    )
81
+
82
+
83
def get_score(weights, model, cache, example_dataset):
    """Objective function minimized by the gradient-free optimizer.

    Merges the cached LoRA state dicts with the given `weights`, loads the
    merged weights into `model` in place, and returns the average training
    loss on `example_dataset` plus a small L1 penalty on the weights.
    """
    # Compose a single state dict as the weighted sum of all cached LoRAs.
    final_state_dict = {}
    lora_module_list = list(cache.keys())
    # All adapters share the same parameter names.
    keys = cache[lora_module_list[0]].keys()
    for i, peft_model_id in enumerate(lora_module_list):
        lora_state_dict = cache[peft_model_id]
        if i == 0:
            for key in keys:
                final_state_dict[key] = weights[i] * lora_state_dict[key]
        else:
            for key in keys:
                final_state_dict[key] = (
                    final_state_dict[key] + weights[i] * lora_state_dict[key]
                )
    # Load the merged adapter weights into the model (mutates `model`).
    set_peft_model_state_dict(model, final_state_dict)

    def get_loss():
        """Average training loss of the current model over the examples."""
        train_dataset = example_dataset
        train_dataloader = DataLoader(
            train_dataset,
            collate_fn=default_data_collator,
            batch_size=len(train_dataset),
            pin_memory=True,
        )
        train_loss = 0
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # Single no_grad context (the original nested a redundant second one).
        with torch.no_grad():
            for _, batch in enumerate(train_dataloader):
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(**batch)
                train_loss += outputs.loss.detach().float()
        return float(train_loss.float()) / len(train_dataset["input"])

    loss = get_loss()
    # L1 regularization keeps the composition weights small; the original
    # local was misleadingly named `sum_of_squares` (it sums absolutes).
    l1_penalty = sum(abs(x) for x in weights) / len(weights)
    metric_val = loss + 0.05 * l1_penalty

    return metric_val
131
+
132
def get_final_weights(weights, lora_module_list, cache):
    """Return the weighted sum of the cached LoRA state dicts.

    `weights[i]` scales the state dict cached for `lora_module_list[i]`;
    the result maps each parameter name to the merged value.
    """
    parameter_names = cache[lora_module_list[0]].keys()
    merged = {}
    for name in parameter_names:
        total = 0
        for idx, module_id in enumerate(lora_module_list):
            total = total + weights[idx] * cache[module_id][name]
        merged[name] = total
    return merged
146
+
147
+
148
+
149
def lorahub_learning(lora_module_list, text_input, text_output, max_inference_step):
    """Run LoraHub learning: gradient-free search for composition weights.

    Args:
        lora_module_list: hub ids of the candidate LoRA modules.
        text_input: newline-separated example inputs.
        text_output: newline-separated example outputs (parallel to inputs).
        max_inference_step: optimization budget passed to nevergrad.

    Returns:
        (recommendation, final_lora): the nevergrad recommendation and the
        merged LoRA state dict, or (None, None) when no modules were given.
    """
    number_of_loras = len(lora_module_list)
    if number_of_loras == 0:
        # BUG FIX: return a 2-tuple so callers that unpack
        # `recommendation, final_lora = lorahub_learning(...)` do not crash.
        return None, None

    model, tokenizer, cache = load_base_model_and_lora_modules(lora_module_list)
    dataset = load_dataset_and_run(text_input.split("\n"), text_output.split("\n"), tokenizer)

    get_score_partial = partial(get_score, model=model, cache=cache,
                                example_dataset=dataset)
    # Constrain each composition weight to [-1.5, 1.5], starting from zeros.
    instrum = ng.p.Array(
        init=[0] * number_of_loras,
        upper=[1.5] * number_of_loras,
        lower=[-1.5] * number_of_loras,
    )
    optimizer = ng.optimizers.NGOpt(parametrization=instrum, budget=max_inference_step)
    print("> Begin to perform gradient-free optimization ...")
    recommendation = optimizer.minimize(get_score_partial, verbosity=1)
    final_lora = get_final_weights(recommendation.value, lora_module_list, cache)
    return recommendation, final_lora