Spaces:
Paused
Paused
import streamlit as st | |
import torch | |
from langchain import HuggingFacePipeline | |
from langchain.chains import RetrievalQA | |
from streamlit_extras.row import row | |
if 'model' not in st.session_state: | |
st.session_state['model'] = 0 | |
if 'max_length' not in st.session_state: | |
st.session_state['max_length'] = 0 | |
if 'temperature' not in st.session_state: | |
st.session_state['temperature'] = 0 | |
if 'repetition_penalty' not in st.session_state: | |
st.session_state['repetition_penalty'] = 0 | |
def load_llm_model(max_length, temperature, repetition_penalty): | |
# llm = HuggingFacePipeline.from_model_id(model_id= 'lmsys/fastchat-t5-3b-v1.0', | |
# task= 'text2text-generation', | |
# model_kwargs={ "device_map": "auto", | |
# "load_in_8bit": True,"max_length": 256, "temperature": 0, | |
# "repetition_penalty": 1.5}) | |
llm = HuggingFacePipeline.from_model_id(model_id= 'lmsys/fastchat-t5-3b-v1.0', | |
task= 'text2text-generation', | |
model_kwargs={ "max_length": max_length, "temperature": temperature, | |
"torch_dtype":torch.float32, | |
"repetition_penalty": repetition_penalty}) | |
return llm | |
st.title("Model Download") | |
# st.subheader("This page allows users to adjust some parameters of the model before downloading") | |
st.subheader("ผู้ใช้สามารถปรับเลือกการตั้งค่าต่อไปนี้เพื่อทำการดาวน์โลดโมเดล") | |
# model_row = row([2, 2, 2], vertical_align="bottom") | |
# max_length = model_row.number_input("max_length", value = 256) | |
# temperature = model_row.number_input("temperature", value = 0) | |
# repetition_penalty = model_row.number_input("repetition_penalty", value = 1.3) | |
max_length = st.number_input("max_length", value = 256, step = 128) | |
st.caption(""" | |
กำหนดจำนวนคำของโมเดลภาษา หากตั้งค่าน้อย โมเดลจะตอบสั้นและกระชับ ถ้าตั้งค่าให้มากๆ การตอบกลับอาจมีรายละเอียดมากขึ้น แต่ต้องระวังเพราะคำตอบที่ยาวเกินไปอาจไม่มีจุดโฟกัส | |
""") | |
st.divider() | |
temperature = st.number_input("temperature", value = 0.0, step = 0.1, max_value = 1.0) | |
st.caption(""" | |
กำหนดความคิดสร้างสรรค์และความหลากหลายในการตอบของโมเดล ค่าที่ต่ำสุด 0 จะเป็นการตอบแบบมีการควบคุมสูงสุด โดย 1 จะมีความหลากหลายสูงสุด และสร้างสรรค์มากสุด แต่ความมีเหตุผลอาจลดลง | |
""") | |
st.divider() | |
repetition_penalty = st.number_input("repetition_penalty", value = 1.3, step = 0.1, max_value = 2.0) | |
st.caption(""" | |
กำหนดให้โมเดลพยายามหลีกเลี่ยงการใช้คำหรือวลีเดียวกันซ้ำๆ ค่าที่สูงขึ้นจะทำให้โมเดลเลี่ยงการตอบโดยใช้คำ หรือวลีเดิมๆ | |
""") | |
load_model_button = st.button("ดาวน์โลดโมเดล") | |
if load_model_button: | |
st.session_state['max_length'] = max_length | |
st.session_state['temperature'] = temperature | |
st.session_state['repetition_penalty'] = repetition_penalty | |
st.session_state['model'] = load_llm_model(max_length, temperature, repetition_penalty) | |
st.write("⚠️ Please expect to wait **1 - 2 minutes ** for the application to download the 3-billion-parameter LLM") | |
st.write('Successfully model loaded ✅') | |
# st.write('Successfully mผ te['repetition_penalty']) | |
# st.markdown(type(st.session_state['model'])) |