|
import os |
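# Redirect the Hugging Face cache to /data before importing huggingface_hub,
# so the snapshot_download call below stores files there instead of the default cache.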
|
|
|
|
|
os.environ["HF_HOME"] = "/data/.cache/huggingface" |
|
|
|
from huggingface_hub import snapshot_download |
|
import streamlit as st |
|
from utils.help import get_disclaimer |
|
from utils.format import sec_to_time, fix_latex, get_youtube_embed |
|
from utils.rag_utils import load_youtube_data, load_book_data, load_summary, embed_question_sentence_transformer, fixed_knn_retrieval, get_random_question |
|
from utils.system_prompts import get_expert_system_prompt, get_synthesis_system_prompt |
|
from utils.openai_utils import embed_question_openai, openai_domain_specific_answer_generation, openai_context_integration |
|
from utils.llama_utils import get_bnb_config, load_base_model, load_fine_tuned_model, generate_response |
|
|
|
st.set_page_config(page_title="AI University") |
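# Responsive 16:9 wrapper styling for the embedded lecture-video iframes.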
|
|
|
st.markdown(""" |
|
<style> |
|
.video-wrapper { |
|
position: relative; |
|
padding-bottom: 56.25%; |
|
height: 0; |
|
} |
|
.video-wrapper iframe { |
|
position: absolute; |
|
top: 0; |
|
left: 0; |
|
width: 100%; |
|
height: 100%; |
|
} |
|
</style> |
|
""", unsafe_allow_html=True) |
|
|
|
|
|
|
|
|
|
HOME = "/home/user/app" |
|
|
|
data_dir = HOME +"/data" |
|
private_data_dir = HOME + "/private_data" |
|
|
|
|
|
os.makedirs(private_data_dir, exist_ok=True)

# Access token for the private dataset repo, read from the environment variable "data".
token = os.getenv("data")
local_repo_path = snapshot_download(
    repo_id="my-ai-university/data",
    token=token,
    repo_type="dataset",
    local_dir=private_data_dir,
|
) |
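# Fine-tuned adapter (LLaMA-TOMMI-1.0) and the base LLaMA checkpoints used for the expert and synthesis models.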
|
|
|
adapter_path = HOME + "/LLaMA-TOMMI-1.0/" |
|
|
|
base_model_path = "meta-llama/Llama-3.2-11B-Vision-Instruct" |
|
base_model_path_3B = "meta-llama/Llama-3.2-3B-Instruct" |
|
|
|
|
|
|
|
st.title(":red[AI University] :gray[/] FEM") |
|
st.markdown(""" |
|
Welcome to <span style='color:red'><a href='https://my-ai-university.com/' target='_blank' style='text-decoration: none; color: red;'>AI University</a></span>, an AI-powered platform that answers scientific course queries while dynamically adapting to instructors' teaching styles and students' learning needs.
|
This prototype showcases the capabilities of the <span style='color:red'><a href='https://github.com/my-ai-university' target='_blank' style='text-decoration: none; color: red;'>AI University platform</a></span> by providing expert answers to queries related to a graduate-level <span style='color:red'><a href='https://www.youtube.com/playlist?list=PLJhG_d-Sp_JHKVRhfTgDqbic_4MHpltXZ' target='_blank' style='text-decoration: none; color: red;'>Finite Element Method (FEM)</a></span> course. |
|
""", unsafe_allow_html=True) |
|
|
|
st.markdown(" ") |
|
with st.container(border=False): |
|
|
|
st.info(""" |
|
Heavy traffic or GPU limits may increase response time or cause errors. Disable expert model for faster replies or try again later. |
|
""", icon="π") |
|
|
|
    st.session_state.activate_expert = st.toggle(
        "Use expert model",
        value=st.session_state.get('activate_expert', True),
        key="use_expert_model1",
        help='More accurate but slower'
    )
|
|
|
st.markdown(" ") |
|
st.markdown(" ") |
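# Sidebar: retrieval settings (embedding model, chunk size, top-k) and generation settings for the expert and synthesis models.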
|
|
|
|
|
|
|
with st.sidebar: |
|
st.header("Settings") |
|
|
|
    with st.expander('Embedding model', expanded=True):

        embedding_model = st.selectbox("Choose content embedding model", [
            "text-embedding-3-small",
            "all-MiniLM-L6-v2",
        ])
|
st.divider() |
|
|
|
        st.write('**Video lectures**')
        if embedding_model == "all-MiniLM-L6-v2":
            yt_token_choice = st.select_slider("Tokens per content", [128, 256], value=256, help="Larger values retrieve longer pieces of content", key="yt_token_len")
        elif embedding_model == "text-embedding-3-small":
            yt_token_choice = st.select_slider("Tokens per content", [256, 512, 1024], value=256, help="Larger values retrieve longer pieces of content", key="yt_token_len")
        yt_chunk_tokens = yt_token_choice
        # Larger chunks allow fewer retrievable pieces, so the total context length stays bounded.
        yt_max_content = {128: 32, 256: 16, 512: 8, 1024: 4}[yt_chunk_tokens]
        top_k_YT = st.slider("Number of content pieces to retrieve", 0, yt_max_content, 4, key="yt_token_num")
        yt_overlap_tokens = yt_chunk_tokens // 4
|
|
|
st.divider() |
|
|
|
        st.write('**Textbook**')
        # Textbook passages are retrieved for context but not displayed (show_textbook stays False).
        show_textbook = False

        if embedding_model == "all-MiniLM-L6-v2":
            latex_token_choice = st.select_slider("Tokens per content", [128, 256], value=256, help="Larger values retrieve longer pieces of content", key="latex_token_len")
        elif embedding_model == "text-embedding-3-small":
            latex_token_choice = st.select_slider("Tokens per content", [128, 256, 512, 1024], value=256, help="Larger values retrieve longer pieces of content", key="latex_token_len")
        latex_chunk_tokens = latex_token_choice
        latex_max_content = {128: 32, 256: 16, 512: 8, 1024: 4}[latex_chunk_tokens]
        top_k_Latex = st.slider("Number of content pieces to retrieve", 0, latex_max_content, 4, key="latex_token_num")

        latex_overlap_tokens = 0
|
|
|
st.write(' ') |
|
    with st.expander('Expert model', expanded=False):
        # Mirrors the main-page toggle; both controls read and write st.session_state.activate_expert.
        st.session_state.activate_expert = st.toggle("Use expert model", value=st.session_state.activate_expert)

        show_expert_response = st.toggle("Show initial expert answer", value=False)
|
|
|
st.session_state.expert_model = st.selectbox( |
|
"Choose the LLM model", |
|
["LLaMA-TOMMI-1.0-11B", "LLaMA-3.2-11B", "gpt-4.1-mini"], |
|
index=0, |
|
key='a1model' |
|
) |
|
|
|
if st.session_state.expert_model in ["LLaMA-TOMMI-1.0-11B", "LLaMA-3.2-11B"]: |
|
expert_do_sample = st.toggle("Enable Sampling", value=False, key='expert_sample') |
|
|
|
if expert_do_sample: |
|
expert_temperature = st.slider("Temperature", 0.0, 1.5, 0.7, key='expert_temp') |
|
expert_top_k = st.slider("Top K", 0, 100, 50, key='expert_top_k') |
|
expert_top_p = st.slider("Top P", 0.0, 1.0, 0.95, key='expert_top_p') |
|
else: |
|
expert_num_beams = st.slider("Num Beams", 1, 4, 1, key='expert_num_beams') |
|
|
|
expert_max_new_tokens = st.slider("Max New Tokens", 100, 2000, 500, step=50, key='expert_max_new_tokens') |
|
else: |
|
expert_api_temperature = st.slider("Temperature", 0.0, 1.5, 0.7, key='a1t') |
|
expert_api_top_p = st.slider("Top P", 0.0, 1.0, 0.9, key='a1p') |
|
|
|
    with st.expander('Synthesis model', expanded=False):
|
|
|
|
|
show_yt_context = st.toggle("Show retrieved video content", value=False) |
|
st.session_state.synthesis_model = st.selectbox( |
|
"Choose the LLM model", |
|
["LLaMA-3.2-3B", "gpt-4o-mini", "gpt-4.1-mini"], |
|
index=2, |
|
key='a2model' |
|
) |
|
|
|
if st.session_state.synthesis_model in ["LLaMA-3.2-3B", "LLaMA-3.2-11B"]: |
|
synthesis_do_sample = st.toggle("Enable Sampling", value=False, key='synthesis_sample') |
|
|
|
if synthesis_do_sample: |
|
synthesis_temperature = st.slider("Temperature", 0.0, 1.5, 0.7, key='synthesis_temp') |
|
synthesis_top_k = st.slider("Top K", 0, 100, 50, key='synthesis_top_k') |
|
synthesis_top_p = st.slider("Top P", 0.0, 1.0, 0.95, key='synthesis_top_p') |
|
else: |
|
synthesis_num_beams = st.slider("Num Beams", 1, 4, 1, key='synthesis_num_beams') |
|
|
|
synthesis_max_new_tokens = st.slider("Max New Tokens", 100, 2000, 1500, step=50, key='synthesis_max_new_tokens') |
|
        else:
            synthesis_api_temperature = st.slider("Temperature", 0.0, 0.5, 0.3, help="Defines the randomness of the next-token prediction. Lower: more predictable and focused. Higher: more adventurous and diverse.", key='a2t')

            synthesis_api_top_p = st.slider("Top P", 0.1, 0.5, 0.3, help="Defines the range of tokens the model may consider for the next prediction. Lower: restricted to high-probability options. Higher: more creative, allowing less likely options.", key='a2p')
|
|
|
|
|
if "question" not in st.session_state: |
|
st.session_state.question = "" |
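# Query input box; an st.empty placeholder lets the "Random Question" button rewrite it in place.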
|
|
|
|
|
text_area_placeholder = st.empty() |
|
question_help = "Including details or instructions improves the answer." |
|
st.session_state.question = text_area_placeholder.text_area( |
|
"**Enter your query about Finite Element Method**", |
|
height=120, |
|
value=st.session_state.question, |
|
help=question_help |
|
) |
|
|
|
_, col1, col2, _ = st.columns([4, 2, 4, 3]) |
|
with col1: |
|
submit_button_placeholder = st.empty() |
|
|
|
with col2: |
|
if st.button("Random Question"): |
|
while True: |
|
random_question = get_random_question(data_dir + "/questions.txt") |
|
if random_question != st.session_state.question: |
|
break |
|
st.session_state.question = random_question |
|
text_area_placeholder.text_area( |
|
"**Enter your query about Finite Element Method:**", |
|
height=120, |
|
value=st.session_state.question, |
|
help=question_help |
|
) |
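# Lazily load the selected local models once per session and cache them in st.session_state.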
|
|
|
with st.spinner("Loading LLaMA-TOMMI-1.0-11B..."): |
|
if st.session_state.expert_model == "LLaMA-TOMMI-1.0-11B": |
|
if 'tommi_model' not in st.session_state: |
|
tommi_model, tommi_tokenizer = load_fine_tuned_model(adapter_path, base_model_path) |
|
st.session_state.tommi_model = tommi_model |
|
st.session_state.tommi_tokenizer = tommi_tokenizer |
|
|
|
|
|
with st.spinner("Loading LLaMA-3.2-11B..."): |
|
if "LLaMA-3.2-11B" in [st.session_state.expert_model, st.session_state.synthesis_model]: |
|
if 'llama_model' not in st.session_state: |
|
llama_model, llama_tokenizer = load_base_model(base_model_path) |
|
st.session_state.llama_model = llama_model |
|
st.session_state.llama_tokenizer = llama_tokenizer |
|
|
|
with st.spinner("Loading LLaMA-3.2-3B..."): |
|
if "LLaMA-3.2-3B" in [st.session_state.expert_model, st.session_state.synthesis_model]: |
|
if 'llama_model_3B' not in st.session_state: |
|
llama_model_3B, llama_tokenizer_3B = load_base_model(base_model_path_3B) |
|
st.session_state.llama_model_3B = llama_model_3B |
|
st.session_state.llama_tokenizer_3B = llama_tokenizer_3B |
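# Load the pre-chunked lecture-transcript and textbook embeddings for the chosen embedding model and chunk sizes, plus per-video summaries.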
|
|
|
|
|
text_data_YT, context_embeddings_YT = load_youtube_data(data_dir, embedding_model, yt_chunk_tokens, yt_overlap_tokens) |
|
text_data_Latex, context_embeddings_Latex = load_book_data(private_data_dir, embedding_model, latex_chunk_tokens, latex_overlap_tokens) |
|
summary = load_summary(data_dir + '/KG_FEM_summary.json') |
|
|
|
if 'question_answered' not in st.session_state: |
|
st.session_state.question_answered = False |
|
if 'context_by_video' not in st.session_state: |
|
st.session_state.context_by_video = {} |
|
if 'context_by_section' not in st.session_state: |
|
st.session_state.context_by_section = {} |
|
if 'answer' not in st.session_state: |
|
st.session_state.answer = "" |
|
if 'playing_video_id' not in st.session_state: |
|
st.session_state.playing_video_id = None |
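# Main pipeline: embed the query, retrieve transcript/textbook context, optionally generate an expert answer, then synthesize the final answer.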
|
|
|
if submit_button_placeholder.button("AI Answer", type="primary"): |
|
if st.session_state.question == "": |
|
st.markdown("") |
|
st.write("Please enter a query. :smirk:") |
|
st.session_state.question_answered = False |
|
|
|
else: |
|
with st.spinner("Finding relevant contexts..."): |
|
|
|
if embedding_model == "all-MiniLM-L6-v2": |
|
question_embedding = embed_question_sentence_transformer(st.session_state.question, model_name="all-MiniLM-L6-v2") |
|
elif embedding_model == "text-embedding-3-small": |
|
question_embedding = embed_question_openai(st.session_state.question, embedding_model) |
|
|
|
initial_max_k = int(0.1 * context_embeddings_YT.shape[0]) |
|
idx_YT = fixed_knn_retrieval(question_embedding, context_embeddings_YT, top_k=top_k_YT, min_k=0) |
|
idx_Latex = fixed_knn_retrieval(question_embedding, context_embeddings_Latex, top_k=top_k_Latex, min_k=0) |
|
|
|
relevant_contexts_YT = sorted([text_data_YT[i] for i in idx_YT], key=lambda x: x['order']) |
|
relevant_contexts_Latex = sorted([text_data_Latex[i] for i in idx_Latex], key=lambda x: x['order']) |
|
|
|
st.session_state.context_by_video = {} |
|
for context_item in relevant_contexts_YT: |
|
video_id = context_item['video_id'] |
|
if video_id not in st.session_state.context_by_video: |
|
st.session_state.context_by_video[video_id] = [] |
|
st.session_state.context_by_video[video_id].append(context_item) |
|
|
|
st.session_state.context_by_section = {} |
|
for context_item in relevant_contexts_Latex: |
|
section_id = context_item['section'] |
|
if section_id not in st.session_state.context_by_section: |
|
st.session_state.context_by_section[section_id] = [] |
|
st.session_state.context_by_section[section_id].append(context_item) |
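            # Assemble the prompt context: transcript chunks grouped by video (with timestamps), then textbook chunks grouped by section.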
|
|
|
context = '' |
|
for i, (video_id, contexts) in enumerate(st.session_state.context_by_video.items(), start=1): |
|
for context_item in contexts: |
|
start_time = int(context_item['start']) |
|
                    context += f'Video {i}, time: {sec_to_time(start_time)}: ' + context_item['text'] + '\n\n'
|
st.session_state.yt_context = fix_latex(context) |
|
|
|
for i, (section_id, contexts) in enumerate(st.session_state.context_by_section.items(), start=1): |
|
context += f'Section {i} ({section_id}):\n' |
|
for context_item in contexts: |
|
context += context_item['text'] + '\n\n' |
|
|
|
with st.spinner("Answering the question..."): |
|
|
|
|
|
|
|
if st.session_state.activate_expert: |
|
if st.session_state.expert_model in ["LLaMA-TOMMI-1.0-11B", "LLaMA-3.2-11B"]: |
|
|
|
if st.session_state.expert_model == "LLaMA-TOMMI-1.0-11B": |
|
model_ = st.session_state.tommi_model |
|
tokenizer_ = st.session_state.tommi_tokenizer |
|
|
|
elif st.session_state.expert_model == "LLaMA-3.2-11B": |
|
model_ = st.session_state.llama_model |
|
tokenizer_ = st.session_state.llama_tokenizer |
|
|
|
messages = [ |
|
{"role": "system", "content": get_expert_system_prompt()}, |
|
{"role": "user", "content": st.session_state.question} |
|
] |
|
|
|
expert_answer = generate_response( |
|
model=model_, |
|
tokenizer=tokenizer_, |
|
messages=messages, |
|
tokenizer_max_length=500, |
|
do_sample=expert_do_sample, |
|
temperature=expert_temperature if expert_do_sample else None, |
|
top_k=expert_top_k if expert_do_sample else None, |
|
top_p=expert_top_p if expert_do_sample else None, |
|
num_beams=expert_num_beams if not expert_do_sample else 1, |
|
max_new_tokens=expert_max_new_tokens |
|
) |
|
|
|
else: |
|
expert_answer = openai_domain_specific_answer_generation( |
|
get_expert_system_prompt(), |
|
st.session_state.question, |
|
model=st.session_state.expert_model, |
|
temperature=expert_api_temperature, |
|
top_p=expert_api_top_p |
|
) |
|
|
|
st.session_state.expert_answer = fix_latex(expert_answer) |
|
|
|
else: |
|
st.session_state.expert_answer = 'No Expert Answer. Only use the context.' |
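            # Synthesis step: combine the question, the expert answer (if any), and the retrieved context into the final response.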
|
|
|
|
|
|
|
|
|
if st.session_state.synthesis_model in ["LLaMA-3.2-3B", "LLaMA-3.2-11B"]: |
|
|
|
if st.session_state.synthesis_model == "LLaMA-3.2-11B": |
|
model_s = st.session_state.llama_model |
|
tokenizer_s = st.session_state.llama_tokenizer |
|
|
|
elif st.session_state.synthesis_model == "LLaMA-3.2-3B": |
|
model_s = st.session_state.llama_model_3B |
|
tokenizer_s = st.session_state.llama_tokenizer_3B |
|
|
|
synthesis_prompt = f""" |
|
Question: |
|
{st.session_state.question} |
|
|
|
Direct Answer: |
|
{st.session_state.expert_answer} |
|
|
|
Retrieved Context: |
|
{context} |
|
|
|
Final Answer: |
|
""" |
|
messages = [ |
|
{"role": "system", "content": get_synthesis_system_prompt("Finite Element Method")}, |
|
{"role": "user", "content": synthesis_prompt} |
|
] |
|
|
|
synthesis_answer = generate_response( |
|
model=model_s, |
|
tokenizer=tokenizer_s, |
|
messages=messages, |
|
tokenizer_max_length=30000, |
|
do_sample=synthesis_do_sample, |
|
temperature=synthesis_temperature if synthesis_do_sample else None, |
|
top_k=synthesis_top_k if synthesis_do_sample else None, |
|
top_p=synthesis_top_p if synthesis_do_sample else None, |
|
num_beams=synthesis_num_beams if not synthesis_do_sample else 1, |
|
max_new_tokens=synthesis_max_new_tokens |
|
) |
|
|
|
else: |
|
synthesis_answer = openai_context_integration( |
|
get_synthesis_system_prompt("Finite Element Method"), |
|
st.session_state.question, |
|
st.session_state.expert_answer, |
|
context, |
|
model=st.session_state.synthesis_model, |
|
temperature=synthesis_api_temperature, |
|
top_p=synthesis_api_top_p |
|
) |
|
|
|
|
|
            if synthesis_answer.strip().startswith("NOT_ENOUGH_INFO"):
|
st.markdown("") |
|
st.markdown("#### Query:") |
|
st.markdown(fix_latex(st.session_state.question)) |
|
                if show_expert_response:
|
st.markdown("#### Initial Expert Answer:") |
|
st.markdown(st.session_state.expert_answer) |
|
st.markdown("#### Answer:") |
|
st.write(":smiling_face_with_tear:") |
|
st.markdown(synthesis_answer.split('NOT_ENOUGH_INFO')[1]) |
|
st.divider() |
|
st.caption(get_disclaimer()) |
|
|
|
st.session_state.question_answered = False |
|
st.stop() |
|
else: |
|
st.session_state.answer = fix_latex(synthesis_answer) |
|
st.session_state.question_answered = True |
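# Render the results: the query, the optional initial expert answer, the final answer, and the retrieved video/textbook content.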
|
|
|
if st.session_state.question_answered: |
|
st.markdown("") |
|
st.markdown("#### Query:") |
|
st.markdown(fix_latex(st.session_state.question)) |
|
    if show_expert_response:
|
st.markdown("#### Initial Expert Answer:") |
|
st.markdown(st.session_state.expert_answer) |
|
st.markdown("#### Answer:") |
|
st.markdown(st.session_state.answer) |
|
if show_yt_context: |
|
st.markdown("#### Retrieved lecture video transcripts:") |
|
st.markdown(st.session_state.yt_context) |
|
|
|
if top_k_YT > 0: |
|
st.markdown("#### Retrieved content in lecture videos") |
|
for i, (video_id, contexts) in enumerate(st.session_state.context_by_video.items(), start=1): |
|
|
|
with st.container(border=True): |
|
st.markdown(f"**Video {i} | {contexts[0]['title']}**") |
|
video_placeholder = st.empty() |
|
video_placeholder.markdown(get_youtube_embed(video_id, 0, 0), unsafe_allow_html=True) |
|
st.markdown('') |
|
with st.container(border=False): |
|
st.markdown("Retrieved Times") |
|
                    cols = st.columns([1] * len(contexts) + [max(9 - len(contexts), 1)])  # last column is a spacer; guard against a non-positive width when many chunks come from one video
|
for j, context_item in enumerate(contexts): |
|
start_time = int(context_item['start']) |
|
label = sec_to_time(start_time) |
|
if cols[j].button(label, key=f"{video_id}_{start_time}"): |
|
if st.session_state.playing_video_id is not None: |
|
st.session_state.playing_video_id = None |
|
video_placeholder.empty() |
|
video_placeholder.markdown(get_youtube_embed(video_id, start_time, 1), unsafe_allow_html=True) |
|
st.session_state.playing_video_id = video_id |
|
|
|
with st.expander("Video Summary", expanded=False): |
|
|
|
st.markdown(summary[video_id]) |
|
|
|
if show_textbook and top_k_Latex > 0: |
|
st.markdown("#### Retrieved content in textbook",help="The Finite Element Method: Linear Static and Dynamic Finite Element Analysis") |
|
for i, (section_id, contexts) in enumerate(st.session_state.context_by_section.items(), start=1): |
|
|
|
st.markdown(f"**Section {i} | {section_id}**") |
|
for context_item in contexts: |
|
st.markdown(context_item['text']) |
|
st.divider() |
|
|
|
st.markdown(" ") |
|
st.divider() |
|
st.caption(get_disclaimer()) |