|
import ast |
|
import streamlit as st |
|
|
|
def generate_sidebar():
    """Render the sidebar: about text, pipeline progress, supported fields and feedback link.

    Each pipeline stage is coloured by comparing its 1-based position with
    ``st.session_state["global_state_step"]``: positions greater than the step
    are drawn as pending, a position equal to the step as in progress, and
    positions below the step as done.
    """
    st.sidebar.header("About", divider="rainbow")
    st.sidebar.markdown(
        ("SciPIP will generate ideas step by step. The generation pipeline is the same as "
         "one-click generation, While you can improve each part manually after SciPIP providing the manuscript.")
    )

    # NOTE: "done" and "in progress" currently share the same colour.
    DONE_COLOR = "black"
    UNDONE_COLOR = "gray"
    INPROGRESS_COLOR = "black"

    pipeline_list = ["1. Input Background", "2. Brainstorming", "3. Extracting Entities", "4. Retrieving Related Works",
                     "5. Generate Initial Ideas", "6. Generate Final Ideas"]

    # Read the step once; default to the first stage if it was never initialised.
    step = st.session_state.get("global_state_step", 1.0)

    st.sidebar.header("Pipeline", divider="red")
    # Derive colours from the list itself so the number of stages is defined in one place.
    for position, label in enumerate(pipeline_list, start=1):
        if step < position:
            color = UNDONE_COLOR
        elif step == position:
            color = INPROGRESS_COLOR
        else:
            color = DONE_COLOR
        st.sidebar.markdown(f"<font color='{color}'>{label}</font>", unsafe_allow_html=True)

    st.sidebar.header("Supported Fields", divider="orange")
    st.sidebar.caption("The supported fields are temporarily limited because we only collect literature "
                       "from ICML, ICLR, NeurIPS, ACL, and EMNLP. Support for other fields are in progress.")
    # Disabled checkboxes used purely as a read-only status display.
    st.sidebar.checkbox("Natural Language Processing (NLP)", value=True, disabled=True)
    st.sidebar.checkbox("Computer Vision (CV)", value=False, disabled=True)
    st.sidebar.checkbox("[Partial] Multimodal", value=True, disabled=True)
    st.sidebar.checkbox("Incoming Other Fields", value=False, disabled=True)

    st.sidebar.header("Help Us To Improve", divider="green")
    st.sidebar.markdown("https://forms.gle/YpLUrhqs1ahyCAe99", unsafe_allow_html=True)
|
|
|
def get_textarea_height(text_content, chars_per_line=96, line_height=23, padding=20,
                        none_height=100, min_height=68):
    """Estimate a pixel height for an ``st.text_area`` that should fit ``text_content``.

    Every newline-separated line counts as one row. Each line also adds one
    extra row per ``chars_per_line`` characters, as a rough soft-wrap
    estimate. The height is ``rows * line_height + padding``, clamped to at
    least ``min_height``.

    Args:
        text_content: Text that will be shown in the widget, or ``None``.
        chars_per_line: Approximate number of characters that fit on one row.
        line_height: Pixel height of one row.
        padding: Extra pixels added on top of the rows.
        none_height: Height used when ``text_content`` is ``None``.
        min_height: Lower bound for the result. The default of 68 is needed
            because ``st.text_area`` rejects heights below 68 pixels.
            Without the clamp, empty or one-line text would yield 43 px.

    Returns:
        int: Estimated height in pixels, never below ``min_height``.
    """
    if text_content is None:
        return max(none_height, min_height)

    lines = text_content.split("\n")
    rows = len(lines) + sum(len(line) // chars_per_line for line in lines)
    return max(rows * line_height + padding, min_height)
|
|
|
def genrate_mainpage(backend):
    """Render the step-by-step idea-generation page.

    The page is a vertical sequence of sections, one per pipeline stage:
    background, brainstorms, entities, related works, initial ideas and
    final ideas. Each section shows an editable text area with a Markdown
    preview. Its "Submit" button runs the next backend stage and stores the
    result in ``st.session_state``. The step counter
    ``st.session_state["global_state_step"]`` is set to ``N.0`` just before
    stage N runs and to ``N.5`` once it has finished.

    If a stage's required inputs are missing or malformed, an error is shown
    on the page instead of raising.

    Args:
        backend: Object providing ``get_demo_i``,
            ``background2brainstorm_callback``, ``brainstorm2entities_callback``,
            ``entities2literature_callback``,
            ``literature2initial_ideas_callback`` and ``initial2final_callback``.
    """
    st.title('π¦ Generate Idea Step-by-step')

    # ---- Stage 1: background ------------------------------------------------
    st.markdown("# π³ Background")
    with st.form('background_form'):
        background = st.session_state.get("background", "")
        background = st.text_area("Input your field background", background, placeholder="Input your field background", height=200, label_visibility="collapsed")

        def click_demo_i(i):
            """Store demo background number ``i`` in session state."""
            # The text area above reads session_state["background"] on the next run.
            st.session_state["background"] = backend.get_demo_i(i)

        for i, col in enumerate(st.columns(4)):
            col.form_submit_button(f"Example {i + 1}", use_container_width=True, on_click=click_demo_i, args=(i,))

        col1, _ = st.columns([2, 30])
        if col1.form_submit_button('Submit', type="primary"):
            st.session_state["global_state_step"] = 2.0
            with st.spinner(text="Brainstorming..."):
                st.session_state["brainstorms"] = backend.background2brainstorm_callback(background)
            st.session_state["brainstorms_expand"] = True
            st.session_state["global_state_step"] = 2.5

    # ---- Stage 2: brainstorms -----------------------------------------------
    st.markdown("# π» Brainstorms")
    with st.expander("Here is the generated brainstorms", expanded=st.session_state.get("brainstorms_expand", False)):
        col1, col2 = st.columns(2)
        widget_height = get_textarea_height(st.session_state.get("brainstorms", ""))
        brainstorms = col1.text_area(label="brainstorms", value=st.session_state.get("brainstorms", ""),
                                     label_visibility="collapsed", height=widget_height)
        # Persist manual edits so they survive reruns.
        st.session_state["brainstorms"] = brainstorms
        if brainstorms:
            col2.markdown(brainstorms)
        else:
            col2.markdown("Please input the brainstorms on the left.")

        col1, _ = st.columns([2, 30])
        if col1.button('Submit'):
            st.session_state["global_state_step"] = 3.0
            with st.spinner(text="Extracting entities..."):
                st.session_state["entities"] = backend.brainstorm2entities_callback(background, brainstorms)
            st.session_state["global_state_step"] = 3.5
            st.session_state["entities_expand"] = True

    # ---- Stage 3: extracted entities ----------------------------------------
    st.markdown("# π± Extracted Entities")
    with st.expander("Here is the extracted entities", expanded=st.session_state.get("entities_expand", False)):
        col1, col2 = st.columns(2)
        entities_text = col1.text_area(label="entities", value=st.session_state.get("entities", "[]"), label_visibility="collapsed")
        # The text area holds a Python literal such as "['a', 'b']". literal_eval
        # evaluates literals only, but it raises on malformed user input.
        entities_valid = True
        try:
            entities = ast.literal_eval(entities_text)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as err:
            entities = None
            entities_valid = False
            col2.error(f"Could not parse the entities as a Python literal: {err}")
        else:
            st.session_state["entities"] = entities
            if entities:
                col2.markdown(f"{entities}")
            else:
                col2.markdown("Please input the entities on the left.")

        # When parsing failed, the error above is shown and nothing is submitted.
        if col1.button('Submit', key="entities_button") and entities_valid:
            st.session_state["global_state_step"] = 4.0
            with st.spinner(text="Retrieving related works..."):
                st.session_state["related_works"], st.session_state["related_works_intact"] = backend.entities2literature_callback(background, entities)
            st.session_state["global_state_step"] = 4.5
            st.session_state["related_works_expand"] = True

    # ---- Stage 4: retrieved related works -----------------------------------
    st.markdown("# π Retrieved Related Works")
    with st.expander("Here is the retrieved related works", expanded=st.session_state.get("related_works_expand", False)):
        col1, col2 = st.columns(2)
        widget_height = get_textarea_height(st.session_state.get("related_works", ""))
        # NOTE(review): edits in this text area are only previewed. The next
        # stage is fed session_state["related_works_intact"] instead.
        related_works_title = col1.text_area(label="related_works", value=st.session_state.get("related_works", ""),
                                             label_visibility="collapsed", height=widget_height)
        if related_works_title:
            col2.markdown(related_works_title)
        else:
            col2.markdown("Please input the related works on the left.")

        if col1.button('Submit', key="related_works_button"):
            if "related_works_intact" not in st.session_state:
                col2.error("Please retrieve the related works first (submit the extracted entities above).")
            else:
                st.session_state["global_state_step"] = 5.0
                with st.spinner(text="Generating initial ideas..."):
                    res = backend.literature2initial_ideas_callback(background, brainstorms, st.session_state["related_works_intact"])
                    st.session_state["initial_ideas"] = res[0]
                    st.session_state["final_ideas"] = res[1]
                st.session_state["global_state_step"] = 5.5
                st.session_state["initial_ideas_expand"] = True

    # ---- Stage 5: initial ideas ---------------------------------------------
    st.markdown("# πΌ Generated Initial Ideas")
    with st.expander("Here is the generated initial ideas", expanded=st.session_state.get("initial_ideas_expand", False)):
        col1, col2 = st.columns(2)
        widget_height = get_textarea_height(st.session_state.get("initial_ideas", ""))
        initial_ideas = col1.text_area(label="initial_ideas", value=st.session_state.get("initial_ideas", ""),
                                       label_visibility="collapsed", height=widget_height)
        if initial_ideas:
            col2.markdown(initial_ideas)
        else:
            col2.markdown("Please input the initial ideas on the left.")

        if col1.button('Submit', key="initial_ideas_button"):
            # initial2final_callback also needs the second output of the
            # previous stage, which is stored under "final_ideas".
            if "final_ideas" not in st.session_state:
                col2.error("Please generate the initial ideas first (submit the related works above).")
            else:
                st.session_state["global_state_step"] = 6.0
                with st.spinner(text="Generating final ideas..."):
                    st.session_state["final_ideas"] = backend.initial2final_callback(initial_ideas, st.session_state["final_ideas"])
                st.session_state["global_state_step"] = 6.5
                st.session_state["final_ideas_expand"] = True

    # ---- Stage 6: final ideas -----------------------------------------------
    st.markdown("# πΈ Generated Final Ideas")
    with st.expander("Here is the generated final ideas", expanded=st.session_state.get("final_ideas_expand", False)):
        col1, col2 = st.columns(2)
        widget_height = get_textarea_height(st.session_state.get("final_ideas", ""))
        # NOTE(review): edits here are only previewed, not stored.
        user_input = col1.text_area(label="final_ideas", value=st.session_state.get("final_ideas", ""),
                                    label_visibility="collapsed", height=widget_height)
        if user_input:
            col2.markdown(user_input)
        else:
            col2.markdown("Please input the final ideas on the left.")
        # NOTE(review): this button currently triggers no action.
        col1.button('Submit', key="final_ideas_button")
|
|
|
def step_by_step_generation(backend):
    """Entry point for the step-by-step generation page.

    Makes sure the pipeline step counter exists in session state, starting at
    1.0 (the "Input Background" stage). Then it draws the main page and
    afterwards the sidebar.

    Args:
        backend: Backend object forwarded to ``genrate_mainpage``.
    """
    state = st.session_state
    if "global_state_step" not in state:
        state["global_state_step"] = 1.0

    # The main page goes first: its Submit handlers advance the step counter,
    # so the sidebar's progress list, drawn afterwards, shows the updated step.
    genrate_mainpage(backend)
    generate_sidebar()
|
|