import streamlit as st
import pandas as pd
import torch
import shutil
import os
import uuid
import json
from random import shuffle

from hub_name import LORA_HUB_NAMES
import redirect as rd


from google.oauth2 import service_account
import gspread


css = """
<style>
.stDataFrame { width: 100% !important; }
</style>
"""
st.markdown(css, unsafe_allow_html=True)

def main():
    st.title("💡 LoraHub")
    st.markdown("Low-rank adaptation (LoRA) is a technique for fine-tuning large language models on new tasks. We propose LoraHub, a framework for composing multiple LoRA modules trained on different tasks, aiming to achieve good performance on unseen tasks from just a few examples, without extra parameters or training. We also want to build a marketplace where users can share their trained LoRA modules, making it easy to apply them to new tasks.")

    st.image(open("lorahub_demo.jpg", "rb").read(),
             "The Illustration of LoraHub Learning", use_column_width=True)
    
    st.markdown("In this demo, you will use the available LoRA modules selected in the left sidebar to tackle your new task. When LoraHub learning is done, you can download the final LoRA module and use it for your new task. You can check out more details in our [paper](https://huggingface.co/papers/2307.13269).")
        
    with st.sidebar:
        st.title("🛒 LoRA Module Market", help="Feel free to clone this demo and add more modules to the marketplace. Make sure your LoRA modules share the same base model and have the same rank.")
        st.markdown(
            "The following modules are available for you to compose for your new task. Every module name is a PEFT repository on the Huggingface Hub, and you can find them [here](https://huggingface.co/models?search=lorahub).")
        
        df = pd.DataFrame({
            "Index": list(range(len(LORA_HUB_NAMES))),
            "Module Name": LORA_HUB_NAMES,
        })
        st.data_editor(df,
                       disabled=["Module Name", "Index"],
                       hide_index=True)

        st.multiselect(
            'Choose the modules you want to add',
            list(range(len(LORA_HUB_NAMES))),
            [],
            key="select_names")

        def set_lucky_modules():
            # pick 20 random module indices
            indices = list(range(len(LORA_HUB_NAMES)))
            shuffle(indices)
            st.session_state["select_names"] = indices[:20]

        st.button(":game_die: Give 20 Lucky Modules",
                  on_click=set_lucky_modules)
        st.write('We will use the following modules', [
                 LORA_HUB_NAMES[i] for i in st.session_state["select_names"]])

    st.subheader("Choose the Module Candidates")
    st.markdown("Please check out the sidebar on the left to select the modules you want to compose for your new task. You can also click the button to **get 20 lucky modules**.")

    st.subheader("Upload Examples of Your Task")
    st.markdown("When faced with a new task, our method requires a few examples of that task in order to perform the LoRA module composition. Below you should provide a few examples of the task you want to perform. The default examples are from the Date Understanding task of the BBH benchmark.")
    
    txt_input = st.text_area('*Example Inputs (One Line per Input)*',
                             '''
Infer the date from context.  Q: Today, 8/3/1997, is a day that we will never forget. What is the date one week ago from today in MM/DD/YYYY? Options: (A) 03/27/1998 (B) 09/02/1997 (C) 07/27/1997 (D) 06/29/1997 (E) 07/27/1973 (F) 12/27/1997 A:
Infer the date from context.  Q: May 6, 1992 is like yesterday to Jane, but that is actually ten years ago. What is the date tomorrow in MM/DD/YYYY? Options: (A) 04/16/2002 (B) 04/07/2003 (C) 05/07/2036 (D) 05/28/2002 (E) 05/07/2002 A:
Infer the date from context.  Q: Today is the second day of the third month of 1966. What is the date one week ago from today in MM/DD/YYYY? Options: (A) 02/26/1966 (B) 01/13/1966 (C) 02/02/1966 (D) 10/23/1966 (E) 02/23/1968 (F) 02/23/1966 A:
'''.strip())

    txt_output = st.text_area('*Example Outputs (One Line per Output)*', '''
(C)
(E)
(F)
'''.strip())

    st.subheader("Set Iteration Steps")
    st.markdown("Our method performs multiple inference iterations to compose the LoRA modules; the resulting module can then be integrated into the LLM to carry out the new task. The maximum number of inference steps impacts performance and speed. We suggest setting it to 40 steps if 20 modules were chosen, with more steps typically needed for more modules.")

    max_step = st.slider('Maximum iteration step', 10, 100, 40, step=5)

    st.subheader("Start LoraHub Learning")
    
    st.markdown("Note that the learning process may take a while (depending on the maximum iteration step), and downloading LoRA modules from the Hugging Face Hub also takes some time. This demo runs on CPU by default, and you can monitor the learning logs below.")
    # st.subheader("Watch the logs below")
    buffer = st.expander("Learning Logs")

    if st.button(':rocket: Start!'):
        if len(st.session_state["select_names"]) == 0:
            st.error("Please select at least 1 module!")
        elif max_step < len(st.session_state["select_names"]):
            st.error(
                "Please specify a larger maximum iteration step than the number of selected modules!")
        else:
            buffer.text("* begin to perform lorahub learning *")
            # imported here so the heavy learning dependencies load only when needed
            from util import lorahub_learning
            with rd.stderr(to=buffer):
                recommendation, final_lora = lorahub_learning(
                    [LORA_HUB_NAMES[i] for i in st.session_state["select_names"]],
                    txt_input, txt_output, max_inference_step=max_step)
                
            st.success("Lorahub learning finished! You got the following recommendation:")

            result = {
                "modules": [LORA_HUB_NAMES[i] for i in st.session_state["select_names"]],
                "weights": recommendation.value,
            }

            def share():
                # authorize with the service account stored in Streamlit secrets
                credentials = service_account.Credentials.from_service_account_info(
                    json.loads(st.secrets["gcp_service_account"]),
                    scopes=["https://www.googleapis.com/auth/spreadsheets"],
                )
                gc = gspread.authorize(credentials)
                sh = gc.open_by_url(st.secrets["private_gsheets_url"])
                ws = sh.sheet1
                ws.insert_rows([
                    [LORA_HUB_NAMES[i] for i in st.session_state["select_names"]],
                    recommendation.value.tolist(),
                    [],
                ])

            st.table(result)
            random_id = uuid.uuid4().hex
            os.makedirs(f"lora/{random_id}")
            # copy the adapter config file
            shutil.copyfile("lora/adapter_config.json", f"lora/{random_id}/adapter_config.json")
            # save the final lora module weights
            torch.save(final_lora, f"lora/{random_id}/adapter_model.bin")
            # zip the module directory for download
            shutil.make_archive(f"lora_{random_id}", 'zip', f"lora/{random_id}")
            with open(f"lora_{random_id}.zip", "rb") as fp:
                st.download_button(
                    label="📥 Download the final LoRA Module",
                    data=fp,
                    file_name=f"lora_{random_id}.zip",
                    mime="application/zip"
                )
            with open(f"lora_{random_id}.zip", "rb") as fp:
                st.download_button(
                    label="📥 Download the final LoRA Module and share your results",
                    data=fp,
                    file_name=f"lora_{random_id}.zip",
                    mime="application/zip",
                    on_click=share
                )
            st.button("📥 Share your results", on_click=share)
            st.warning("The page will be refreshed once you click a download button. Sharing results may take 1-2 minutes.")



if __name__ == "__main__":
    main()