Spaces:
Running
Running
import os | |
import streamlit as st | |
from langchain.llms import HuggingFaceHub | |
from models import return_sum_models | |
class LLM_Langchain():
    """Streamlit page that summarizes code with a Hugging Face Hub model.

    Constructing an instance renders the page header and the sidebar widgets
    (API key, language, checkpoint and generation settings) and stores the
    chosen values on the instance. ``form_data`` then renders the chat
    history, reads one chat message and displays the model's answer.
    """

    def __init__(self):
        """Render the static page elements and read the sidebar settings."""
        st.header('🦜 Code summarization')
        st.warning("Warning: input function needs cleaning and may take long to be processed at first time")
        st.info("Reference: [CodeT5](https://arxiv.org/abs/2109.00859), [The Vault](https://arxiv.org/abs/2305.06156), [CodeXGLUE](https://arxiv.org/abs/2102.04664)")
        st.info("About me: namnh113")

        self.api_key_area = st.sidebar.text_input(
            'API key (not necessary for now)',
            type='password',
            help="Type in your HuggingFace API key to use this app")
        # A missing API_KEY variable must not crash the page before anything
        # is rendered; form_data() warns when no usable key is available.
        env_key = os.environ.get("API_KEY", "")
        # A key typed into the sidebar takes precedence over the deployment's.
        self.API_KEY = self._resolve_api_key(self.api_key_area, env_key)

        model_parent = st.sidebar.selectbox(
            label="Choose language",
            options=["python", "java", "javascript", "php", "ruby", "go"],
            help="Choose languages",
        )
        # The checkpoint selector is disabled while no language is selected.
        model_name_visibility = model_parent is None

        model_name = return_sum_models(model_parent)
        list_model = self._candidate_models(model_name, model_parent)
        self.checkpoint = st.sidebar.selectbox(
            label="Choose model (nam194/... is my model)",
            options=list_model,
            help="Model used to predict",
            disabled=model_name_visibility,
        )

        self.max_new_tokens = st.sidebar.slider(
            label="Token Length",
            min_value=32,
            max_value=1024,
            step=32,
            value=64,
            help="Set the max tokens to get accurate results",
        )
        self.num_beams = st.sidebar.slider(
            label="num beams",
            min_value=1,
            max_value=10,
            step=1,
            value=4,
            help="Set num beam",
        )
        self.top_k = st.sidebar.slider(
            label="top k",
            min_value=1,
            max_value=50,
            step=1,
            value=30,
            help="Set the top_k",
        )
        self.top_p = st.sidebar.slider(
            label="top p",
            min_value=0.1,
            max_value=1.0,
            step=0.05,
            value=0.95,
            help="Set the top_p",
        )

        # Generation settings forwarded unchanged to the hosted model.
        self.model_kwargs = {
            "max_new_tokens": self.max_new_tokens,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "num_beams": self.num_beams,
        }

        # Only the deployment's own key is exported: os.environ is shared by
        # the whole process, so a key typed by one visitor must not be written
        # there. The typed key is passed per call in generate_response().
        if env_key:
            os.environ['HUGGINGFACEHUB_API_TOKEN'] = env_key

    @staticmethod
    def _resolve_api_key(typed_key, env_key):
        """Return the API key to use, preferring the one typed in the sidebar.

        Args:
            typed_key: Value of the sidebar text input (may be empty or None).
            env_key: Value of the ``API_KEY`` environment variable (may be
                empty or None).

        Returns:
            The stripped typed key when it is non-empty, otherwise ``env_key``,
            otherwise an empty string.
        """
        typed_key = (typed_key or "").strip()
        return typed_key or (env_key or "")

    @staticmethod
    def _candidate_models(model_name, language):
        """Return the checkpoint ids offered for ``language``.

        Args:
            model_name: Checkpoint id returned by ``return_sum_models``.
            language: Language chosen in the sidebar.

        Returns:
            A list starting with ``model_name``; for ``"python"`` a ``_v2``
            variant follows; for every language except ``"C++"`` the two
            Salesforce CodeT5 checkpoints are appended.
        """
        candidates = [model_name]
        if language == "python":
            candidates.append(f"{model_name}_v2")
        # The sidebar never offers "C++"; the guard is kept for other callers.
        if language != "C++":
            candidates += [
                "Salesforce/codet5-base-multi-sum",
                f"Salesforce/codet5-base-codexglue-sum-{language}",
            ]
        return candidates

    def generate_response(self, input_text):
        """Run ``input_text`` through the selected checkpoint.

        Args:
            input_text: Text sent to the model as the prompt.

        Returns:
            Whatever the ``HuggingFaceHub`` LLM returns for the prompt.
        """
        llm = HuggingFaceHub(
            repo_id=self.checkpoint,
            model_kwargs=self.model_kwargs,
            # None lets LangChain fall back to HUGGINGFACEHUB_API_TOKEN.
            huggingfacehub_api_token=self.API_KEY or None,
        )
        return llm(input_text)

    def form_data(self):
        """Render the chat history and handle one chat submission.

        Typing ``clear`` drops the stored history instead of querying the
        model. Any exception raised while rendering or generating is shown on
        the page with ``st.error`` rather than propagated.
        """
        try:
            if not self.API_KEY.startswith('hf_'):
                st.warning('Please enter your API key!', icon='⚠️')

            if "messages" not in st.session_state:
                st.session_state.messages = []

            st.write(f"You are using {self.checkpoint} model")

            # Replay the conversation kept in the session state.
            for message in st.session_state.messages:
                with st.chat_message(message.get('role')):
                    st.write(message.get("content"))

            text = st.chat_input(disabled=False)
            if not text:
                return

            st.session_state.messages.append(
                {
                    "role": "user",
                    "content": text,
                }
            )
            with st.chat_message("user"):
                st.write(text)

            if text.strip().lower() == "clear":
                # The history is recreated empty on the next script run.
                del st.session_state.messages
                return

            result = self.generate_response(text)
            st.session_state.messages.append(
                {
                    "role": "assistant",
                    "content": result,
                }
            )
            with st.chat_message('assistant'):
                st.markdown(result)
        except Exception as e:
            # Page-level boundary: show the failure instead of a traceback.
            st.error(e, icon="🚨")
# Script entry: build the page and sidebar widgets, then render the chat area.
model = LLM_Langchain()
model.form_data()