"""Streamlit front end for the LoraHub demo."""
import contextlib
import os
import shutil
from functools import wraps
from io import StringIO
from random import shuffle

import pandas as pd
import streamlit as st
import torch

import redirect as rd
from hub_name import LORA_HUB_NAMES

# How many random modules the "lucky" button selects.
LUCKY_MODULE_COUNT = 20

# Directory whose contents are zipped for download; the learned LoRA weights
# are written into it as adapter_model.bin.
LORA_DIR = "lora"
# Base name passed to shutil.make_archive (the archive is ZIP_BASE_NAME + ".zip").
ZIP_BASE_NAME = "lora_module"

# Default few-shot inputs: one example per line, as the text area label asks.
DEFAULT_EXAMPLE_INPUTS = "\n".join([
    "Infer the date from context. Q: Today, 8/3/1997, is a day that we will never forget. "
    "What is the date one week ago from today in MM/DD/YYYY? "
    "Options: (A) 03/27/1998 (B) 09/02/1997 (C) 07/27/1997 (D) 06/29/1997 "
    "(E) 07/27/1973 (F) 12/27/1997 A:",
    "Infer the date from context. Q: May 6, 1992 is like yesterday to Jane, "
    "but that is actually ten years ago. What is the date tomorrow in MM/DD/YYYY? "
    "Options: (A) 04/16/2002 (B) 04/07/2003 (C) 05/07/2036 (D) 05/28/2002 "
    "(E) 05/07/2002 A:",
    "Infer the date from context. Q: Today is the second day of the third month of 1966. "
    "What is the date one week ago from today in MM/DD/YYYY? "
    "Options: (A) 02/26/1966 (B) 01/13/1966 (C) 02/02/1966 (D) 10/23/1966 "
    "(E) 02/23/1968 (F) 02/23/1966 A:",
])
# Expected answers, one per line, aligned with DEFAULT_EXAMPLE_INPUTS.
DEFAULT_EXAMPLE_OUTPUTS = "\n".join(["(C)", "(E)", "(F)"])

css = """ """

st.markdown(css, unsafe_allow_html=True)


def _set_lucky_modules():
    """Button callback: select up to LUCKY_MODULE_COUNT random module indices.

    The indices are written to ``st.session_state["select_names"]``, which is
    the key of the sidebar multiselect, so that widget picks up the selection.
    """
    indices = list(range(len(LORA_HUB_NAMES)))
    shuffle(indices)
    st.session_state["select_names"] = indices[:LUCKY_MODULE_COUNT]


def _run_lorahub_learning(module_names, txt_input, txt_output, max_step, buffer):
    """Run LoraHub learning on the chosen modules and offer the result as a ZIP.

    Args:
        module_names: Hub names of the candidate LoRA modules.
        txt_input: Few-shot example inputs, one per line.
        txt_output: Few-shot example outputs, one per line.
        max_step: Value passed as ``max_inference_step`` to ``lorahub_learning``.
        buffer: The "Learning Logs" expander that receives log output.
    """
    buffer.text("* begin to perform lorahub learning *")
    # Deferred import: the learning code is only loaded once the user starts a run.
    from util import lorahub_learning

    # rd.stderr(to=buffer) is the project's redirect helper; it is expected to
    # surface stderr output from the learning run inside the log expander.
    with rd.stderr(to=buffer):
        recommendation, final_lora = lorahub_learning(
            module_names, txt_input, txt_output, max_inference_step=max_step)

    st.success("Lorahub learning finished! You got the following recommendation:")
    st.table({
        "modules": module_names,
        "weights": recommendation.value,
    })

    # Save the composed LoRA weights into LORA_DIR (creating it if needed, so
    # torch.save cannot fail on a missing directory), then zip that directory.
    os.makedirs(LORA_DIR, exist_ok=True)
    torch.save(final_lora, os.path.join(LORA_DIR, "adapter_model.bin"))
    shutil.make_archive(ZIP_BASE_NAME, "zip", LORA_DIR)

    with open(f"{ZIP_BASE_NAME}.zip", "rb") as fp:
        st.download_button(
            label="Download ZIP",
            data=fp,
            file_name="lora_module.zip",
            mime="application/zip",
        )


def main():
    """Render the LoraHub demo page and run LoraHub learning when requested.

    The sidebar lets the user choose candidate modules from ``LORA_HUB_NAMES``.
    The choice is stored as indices in ``st.session_state["select_names"]``.
    When "Start!" is clicked and the inputs are valid, the selected modules,
    few-shot examples and iteration budget are handed to
    ``_run_lorahub_learning``.
    """
    st.title("LoraHub")
    st.markdown("Low-rank adaptations (LoRA) are techniques for fine-tuning large language models on new tasks. We propose LoraHub, a framework that allows composing multiple LoRA modules trained on different tasks. The goal is to achieve good performance on unseen tasks using just a few examples, without needing extra parameters or training. And we want to build a marketplace where users can share their trained LoRA modules, thereby facilitating the application of these modules to new tasks.")
    st.markdown("In this demo, you will use available lora modules selected in the left sidebar to tackle your few-shot examples. When the LoraHub learning is done, you can download the final LoRA module and use it for your new task. You can check out more details in our [paper](https://huggingface.co/papers/2307.13269).")

    with st.sidebar:
        st.title("LoRA Module Pool")
        st.markdown(
            "The following modules are available for you to compose for your new task. "
            "Every module name is a peft repository in Huggingface Hub, and you can find them "
            "[here](https://huggingface.co/models?search=lorahub).")

        df = pd.DataFrame({
            "Index": list(range(len(LORA_HUB_NAMES))),
            "Module Name": LORA_HUB_NAMES,
        })
        # The edited frame is never read back, so lock every column; the
        # disabled names must match the DataFrame's actual column labels.
        st.data_editor(df, disabled=["Module Name", "Index"], hide_index=True)

        st.multiselect(
            'Select your favorite modules as the candidate for LoRA composition',
            list(range(len(LORA_HUB_NAMES))),
            [],
            key="select_names")
        st.button(f":game_die: Give {LUCKY_MODULE_COUNT} Lucky Modules",
                  on_click=_set_lucky_modules)
        st.write('We will use the following modules', [
            LORA_HUB_NAMES[i] for i in st.session_state["select_names"]])

    st.subheader("Prepare your few-shot examples")

    txt_input = st.text_area('Examples Inputs (One Line One Input)', DEFAULT_EXAMPLE_INPUTS)
    txt_output = st.text_area('Examples Outputs (One Line One Output)', DEFAULT_EXAMPLE_OUTPUTS)

    max_step = st.slider('Maximum iteration step', 10, 1000, step=10)

    buffer = st.expander("Learning Logs")

    if st.button(':rocket: Start!'):
        selected_names = [LORA_HUB_NAMES[i] for i in st.session_state["select_names"]]
        if not selected_names:
            st.error("Please select at least 1 module!")
        elif max_step < len(selected_names):
            st.error(
                "Please specify a larger maximum iteration step than the number of selected modules!")
        else:
            _run_lorahub_learning(selected_names, txt_input, txt_output, max_step, buffer)


if __name__ == "__main__":
    main()