# lorahub / app.py
# SivilTaram
# update demo
# 470be5c
# raw history blame
# No virus
# 5.18 kB
import streamlit as st
from hub_name import LORA_HUB_NAMES
from random import shuffle
import pandas as pd
import streamlit as st
import contextlib
from functools import wraps
from io import StringIO
import contextlib
import redirect as rd
import torch
import shutil
import os
# Page-wide style override injected as raw HTML: makes every element carrying
# the `stDataFrame` class span the full width of its container.
css = """
<style>
.stDataFrame { width: 100% !important; }
</style>
"""
# unsafe_allow_html=True is passed so the <style> tag is emitted as HTML.
st.markdown(css, unsafe_allow_html=True)
def _selected_module_names():
    """Return the hub names of the modules currently selected in the sidebar.

    Reads the list of indices stored under ``st.session_state["select_names"]``
    (written by the multiselect widget and by the lucky-modules button) and
    maps every index into ``LORA_HUB_NAMES``.
    """
    return [LORA_HUB_NAMES[i] for i in st.session_state["select_names"]]


def _render_module_pool_sidebar():
    """Draw the sidebar: module table, selection widget and lucky button."""
    module_indices = list(range(len(LORA_HUB_NAMES)))

    with st.sidebar:
        st.title("LoRA Module Pool")
        st.markdown(
            "The following modules are available for you to compose for your new task. Every module name is a peft repository in Huggingface Hub, and you can find them [here](https://huggingface.co/models?search=lorahub).")

        pool_table = pd.DataFrame({
            "Index": module_indices,
            "Module Name": LORA_HUB_NAMES,
        })
        # Both columns are listed as disabled so the table is read-only.
        # The names must match the DataFrame columns exactly; the previous
        # "LoRA Module" entry matched nothing and left "Module Name" editable.
        st.data_editor(pool_table,
                       disabled=["Module Name", "Index"],
                       hide_index=True)

        # The widget key doubles as the session-state slot that the rest of
        # the app reads the selection from.
        st.multiselect(
            'Select your favorite modules as the candidate for LoRA composition',
            module_indices,
            [],
            key="select_names")

        def set_lucky_modules():
            """Replace the current selection with up to 20 random indices."""
            lucky = list(module_indices)
            shuffle(lucky)
            st.session_state["select_names"] = lucky[:20]

        # Runs as an on_click callback so the session-state write happens
        # before the multiselect is instantiated on the next rerun.
        st.button(":game_die: Give 20 Lucky Modules",
                  on_click=set_lucky_modules)
        st.write('We will use the following modules',
                 _selected_module_names())


def _offer_lora_download(final_lora):
    """Save ``final_lora``, zip the ``lora`` folder and show a download button.

    Writes ``lora/adapter_model.bin`` and ``lora_module.zip`` relative to the
    current working directory.
    """
    # torch.save does not create missing parent directories.
    os.makedirs("lora", exist_ok=True)
    torch.save(final_lora, "lora/adapter_model.bin")
    # Archive the whole "lora" directory as lora_module.zip.
    shutil.make_archive("lora_module", 'zip', "lora")
    with open("lora_module.zip", "rb") as fp:
        st.download_button(
            label="Download ZIP",
            data=fp,
            file_name="lora_module.zip",
            mime="application/zip"
        )


def main():
    """Render the LoraHub demo page and run a learning job on request.

    Lays out the introduction, the module-pool sidebar and the few-shot
    example form. When the start button is pressed and the selection is
    valid, calls ``lorahub_learning`` on the selected modules, shows the
    resulting weights in a table and offers the composed module as a ZIP.
    """
    st.title("LoraHub")
    st.markdown("Low-rank adaptations (LoRA) are techniques for fine-tuning large language models on new tasks. We propose LoraHub, a framework that allows composing multiple LoRA modules trained on different tasks. The goal is to achieve good performance on unseen tasks using just a few examples, without needing extra parameters or training. And we want to build a marketplace where users can share their trained LoRA modules, thereby facilitating the application of these modules to new tasks.")
    st.markdown("In this demo, you will use avaiable lora modules selected in the left sidebar to tackle your few-shot examples. When the LoraHub learning is done, you can download the final LoRA module and use it for your new task. You can check out more details in our [paper](https://huggingface.co/papers/2307.13269).")

    _render_module_pool_sidebar()

    st.subheader("Prepare your few-shot examples")
    txt_input = st.text_area('Examples Inputs (One Line One Input)',
                             '''
Infer the date from context. Q: Today, 8/3/1997, is a day that we will never forget. What is the date one week ago from today in MM/DD/YYYY? Options: (A) 03/27/1998 (B) 09/02/1997 (C) 07/27/1997 (D) 06/29/1997 (E) 07/27/1973 (F) 12/27/1997 A:
Infer the date from context. Q: May 6, 1992 is like yesterday to Jane, but that is actually ten years ago. What is the date tomorrow in MM/DD/YYYY? Options: (A) 04/16/2002 (B) 04/07/2003 (C) 05/07/2036 (D) 05/28/2002 (E) 05/07/2002 A:
Infer the date from context. Q: Today is the second day of the third month of 1966. What is the date one week ago from today in MM/DD/YYYY? Options: (A) 02/26/1966 (B) 01/13/1966 (C) 02/02/1966 (D) 10/23/1966 (E) 02/23/1968 (F) 02/23/1966 A:
'''.strip())
    txt_output = st.text_area('Examples Outputs (One Line One Output)', '''
(C)
(E)
(F)
'''.strip())

    max_step = st.slider('Maximum iteration step', 10, 1000, step=10)

    # Created before the button so the log expander is always on the page.
    buffer = st.expander("Learning Logs")

    if not st.button(':rocket: Start!'):
        return

    selected = _selected_module_names()
    if not selected:
        st.error("Please select at least 1 module!")
        return
    if max_step < len(selected):
        st.error(
            "Please specify a larger maximum iteration step than the number of selected modules!")
        return

    buffer.text("* begin to perform lorahub learning *")
    # Imported lazily, only once a run is actually requested.
    from util import lorahub_learning
    # NOTE(review): rd.stderr presumably forwards stderr output into the
    # expander while learning runs -- confirm against the redirect module.
    with rd.stderr(to=buffer):
        recommendation, final_lora = lorahub_learning(
            list(selected), txt_input, txt_output,
            max_inference_step=max_step)

    st.success("Lorahub learning finished! You got the following recommendation:")
    st.table({
        "modules": selected,
        "weights": recommendation.value,
    })
    _offer_lora_download(final_lora)
# Build the page only when this file is executed as the entry script.
if __name__ == "__main__":
    main()