Spaces:
Running
Running
File size: 4,804 Bytes
455b92e e002a1d 768da65 22a6156 455b92e 768da65 455b92e d3e8b15 e002a1d 70849f0 d3e8b15 768da65 e002a1d 0e9a542 e420bbe 0e9a542 455b92e 7aa0c88 455b92e c03594b 455b92e c03594b 455b92e c03594b 455b92e aeaad97 455b92e c03594b 455b92e d3e8b15 455b92e bafde03 455b92e 0e9a542 455b92e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
import os
import streamlit as st
from langchain.llms import HuggingFaceHub
from models import return_sum_models
class LLM_Langchain():
    """Streamlit front end for summarizing code with a HuggingFace Hub model.

    Instantiating the class draws the page header and the sidebar controls
    (API key, language, checkpoint, generation settings) and stores the chosen
    values on the instance. ``form_data`` then renders the chat history and
    answers a newly submitted message via ``generate_response``.
    """

    def __init__(self):
        st.header('🦜 Code summarization')
        st.warning("Warning: input function needs cleaning and may take long to be processed at first time")
        st.info("Reference: [CodeT5](https://arxiv.org/abs/2109.00859), [The Vault](https://arxiv.org/abs/2305.06156), [CodeXGLUE](https://arxiv.org/abs/2102.04664)")
        st.info("About me: namnh113")
        self.api_key_area = st.sidebar.text_input(
            'API key (not necessary for now)',
            type='password',
            help="Type in your HuggingFace API key to use this app")
        # Prefer the key typed in the sidebar; otherwise fall back to the
        # server's API_KEY environment variable, or "" if neither is set, so a
        # missing variable does not abort page construction.
        self.API_KEY = self._resolve_api_key(self.api_key_area, os.environ)

        model_parent = st.sidebar.selectbox(
            label="Choose language",
            options=["python", "java", "javascript", "php", "ruby", "go"],
            help="Choose languages",
        )
        # The checkpoint picker is disabled only when no language is selected.
        disable_model_choice = model_parent is None
        model_name = return_sum_models(model_parent)
        self.checkpoint = st.sidebar.selectbox(
            label="Choose model (nam194/... is my model)",
            options=self._candidate_models(model_parent, model_name),
            help="Model used to predict",
            disabled=disable_model_choice
        )

        self.max_new_tokens = st.sidebar.slider(
            label="Token Length",
            min_value=32,
            max_value=1024,
            step=32,
            value=64,
            help="Set the max tokens to get accurate results"
        )
        self.num_beams = st.sidebar.slider(
            label="num beams",
            min_value=1,
            max_value=10,
            step=1,
            value=4,
            help="Set num beam"
        )
        self.top_k = st.sidebar.slider(
            label="top k",
            min_value=1,
            max_value=50,
            step=1,
            value=30,
            help="Set the top_k"
        )
        self.top_p = st.sidebar.slider(
            label="top p",
            min_value=0.1,
            max_value=1.0,
            step=0.05,
            value=0.95,
            help="Set the top_p"
        )
        # Generation settings forwarded verbatim to the Hub model.
        self.model_kwargs = {
            "max_new_tokens": self.max_new_tokens,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "num_beams": self.num_beams
        }
        # The token is deliberately NOT written to os.environ. The environment
        # is process-wide, so a per-user key stored there could leak to other
        # sessions served by the same process. generate_response passes it
        # explicitly instead.

    @staticmethod
    def _resolve_api_key(user_key, env):
        """Return the API key to use.

        Args:
            user_key: Value typed into the sidebar (may be empty or None).
            env: Mapping to read the fallback ``API_KEY`` from (``os.environ``).

        Returns:
            The stripped user key if it is non-blank, else ``env["API_KEY"]``,
            else ``""``.
        """
        user_key = (user_key or "").strip()
        if user_key:
            return user_key
        return env.get("API_KEY", "")

    @staticmethod
    def _candidate_models(model_parent, model_name):
        """Return the checkpoints offered for language *model_parent*.

        The project's own checkpoint *model_name* comes first. Python also gets
        its ``_v2`` variant, followed by the Salesforce CodeT5 baselines.
        """
        candidates = [model_name]
        if model_parent == "python":
            candidates.append(model_name + "_v2")
        # "C++" is not among the language options offered above, so this branch
        # is currently always taken. The exclusion presumably exists because no
        # matching codexglue checkpoint exists for C++ (TODO confirm).
        if model_parent != "C++":
            candidates += [
                "Salesforce/codet5-base-multi-sum",
                f"Salesforce/codet5-base-codexglue-sum-{model_parent}",
            ]
        return candidates

    def generate_response(self, input_text):
        """Send *input_text* to the selected Hub checkpoint and return its output.

        The API key is passed per call. When no key is configured, ``None`` is
        passed so that HuggingFaceHub can apply its own
        HUGGINGFACEHUB_API_TOKEN environment lookup.
        """
        llm = HuggingFaceHub(
            repo_id=self.checkpoint,
            model_kwargs=self.model_kwargs,
            huggingfacehub_api_token=self.API_KEY or None,
        )
        return llm(input_text)

    def form_data(self):
        """Render the chat history and answer a newly submitted message.

        Sending ``clear`` (case- and whitespace-insensitive) drops the stored
        history instead of querying the model. Any exception is displayed in
        the page with ``st.error`` rather than propagated.
        """
        try:
            if not self.API_KEY.startswith('hf_'):
                st.warning('Please enter your API key!', icon='⚠')
            if "messages" not in st.session_state:
                st.session_state.messages = []
            st.write(f"You are using {self.checkpoint} model")
            # Replay the conversation stored in the session.
            for message in st.session_state.messages:
                with st.chat_message(message.get('role')):
                    st.write(message.get("content"))

            text = st.chat_input(disabled=False)
            if not text:
                return
            st.session_state.messages.append({"role": "user", "content": text})
            with st.chat_message("user"):
                st.write(text)
            if text.strip().lower() == "clear":
                del st.session_state.messages
                return

            result = self.generate_response(text)
            st.session_state.messages.append({"role": "assistant", "content": result})
            with st.chat_message('assistant'):
                st.markdown(result)
        except Exception as e:
            # UI boundary: show the failure in the page instead of a traceback.
            st.error(e, icon="🚨")
# Constructing the app draws the header and the sidebar controls; form_data()
# then renders the chat history and handles any newly submitted message.
model: LLM_Langchain = LLM_Langchain()
model.form_data()
|