File size: 5,261 Bytes
470be5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fe11039
470be5c
 
 
 
 
 
 
 
 
 
 
 
 
 
fe11039
 
470be5c
 
fe11039
470be5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import streamlit as st
from hub_name import LORA_HUB_NAMES
from random import shuffle
import pandas as pd
import streamlit as st
import contextlib
from functools import wraps
from io import StringIO
import contextlib
import redirect as rd
import torch
import shutil
import os


css = """
<style>
.stDataFrame { width: 100% !important; }
</style>
"""
st.markdown(css, unsafe_allow_html=True)


def main():
    st.title("LoraHub")
    st.markdown("Low-rank adaptations (LoRA) are techniques for fine-tuning large language models on new tasks. We propose LoraHub, a framework that allows composing multiple LoRA modules trained on different tasks. The goal is to achieve good performance on unseen tasks using just a few examples, without needing extra parameters or training. And we want to build a marketplace where users can share their trained LoRA modules, thereby facilitating the application of these modules to new tasks.")

    st.markdown("In this demo, you will use avaiable lora modules selected in the left sidebar to tackle your few-shot examples. When the LoraHub learning is done, you can download the final LoRA module and use it for your new task. You can check out more details in our [paper](https://huggingface.co/papers/2307.13269).")
        
    with st.sidebar:
        st.title("LoRA Module Pool")
        st.markdown(
            "The following modules are available for you to compose for your new task. Every module name is a peft repository in Huggingface Hub, and you can find them [here](https://huggingface.co/models?search=lorahub).")
        
        df = pd.DataFrame({
            "Index": list(range(len(LORA_HUB_NAMES))),
            "Module Name": LORA_HUB_NAMES,
        })
        st.data_editor(df,
                       disabled=["LoRA Module", "Index"],
                       hide_index=True)

        st.multiselect(
            'Select your favorite modules as the candidate for LoRA composition',
            list(range(len(LORA_HUB_NAMES))),
            [],
            key="select_names")

        def set_lucky_modules():
            names = list(range(len(LORA_HUB_NAMES)))
            shuffle(names)
            names = names[:20]
            st.session_state["select_names"] = names

        st.button(":game_die: Give 20 Lucky Modules",
                  on_click=set_lucky_modules)
        st.write('We will use the following modules', [
                 LORA_HUB_NAMES[i] for i in st.session_state["select_names"]])

    st.subheader("Prepare Few-shot Examples")

    txt_input = st.text_area('Examples Inputs (One Line One Input)',
                             '''
Infer the date from context.  Q: Today, 8/3/1997, is a day that we will never forget. What is the date one week ago from today in MM/DD/YYYY? Options: (A) 03/27/1998 (B) 09/02/1997 (C) 07/27/1997 (D) 06/29/1997 (E) 07/27/1973 (F) 12/27/1997 A:
Infer the date from context.  Q: May 6, 1992 is like yesterday to Jane, but that is actually ten years ago. What is the date tomorrow in MM/DD/YYYY? Options: (A) 04/16/2002 (B) 04/07/2003 (C) 05/07/2036 (D) 05/28/2002 (E) 05/07/2002 A:
Infer the date from context.  Q: Today is the second day of the third month of 1966. What is the date one week ago from today in MM/DD/YYYY? Options: (A) 02/26/1966 (B) 01/13/1966 (C) 02/02/1966 (D) 10/23/1966 (E) 02/23/1968 (F) 02/23/1966 A:
'''.strip())

    txt_output = st.text_area('Examples Outputs (One Line One Output)', '''
(C)
(E)
(F)
'''.strip())

    st.subheader("Set Hyper-parameter")

    max_step = st.slider('Maximum iteration step', 10, 1000, step=10)

    st.subheader("Start LoraHub Learning")
    # st.subheader("Watch the logs below")
    buffer = st.expander("Learning Logs")

    if st.button(':rocket: Start!'):
        if len(st.session_state["select_names"]) == 0:
            st.error("Please select at least 1 module!")
        elif max_step < len(st.session_state["select_names"]):
            st.error(
                "Please specify a larger maximum iteration step than the number of selected modules!")
        else:
            buffer.text("* begin to perform lorahub learning *")
            from util import lorahub_learning
            with rd.stderr(to=buffer):
                recommendation, final_lora = lorahub_learning([LORA_HUB_NAMES[i] for i in st.session_state["select_names"]],
                                                  txt_input, txt_output, max_inference_step=max_step)
                
            st.success("Lorahub learning finished! You got the following recommendation:")
            df = {
                "modules": [LORA_HUB_NAMES[i] for i in st.session_state["select_names"]],
                "weights": recommendation.value,
            }
            st.table(df)
            
            # zip the final lora module
            torch.save(final_lora, "lora/adapter_model.bin")
            # create a zip file
            shutil.make_archive("lora_module", 'zip', "lora")
            with open("lora_module.zip", "rb") as fp:
                btn = st.download_button(
                    label="Download ZIP",
                    data=fp,
                    file_name="lora_module.zip",
                    mime="application/zip"
                )



if __name__ == "__main__":
    main()