import torch
import streamlit as st
from transformers import pipeline
from concurrent.futures import ThreadPoolExecutor


# Function to load models only once using Streamlit's cache mechanism
@st.cache_resource(show_spinner="Loading Models...")
def load_models():
    # Use the first CUDA GPU if available, otherwise fall back to CPU (-1).
    device = 0 if torch.cuda.is_available() else -1
    # Baseline TinyLlama chat model used for comparison.
    base_pipe = pipeline(
        "text-generation",
        model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        device=device,
    )
    # LLM-ADE fine-tuned model being demoed.
    irai_pipe = pipeline(
        "text-generation",
        model="InvestmentResearchAI/LLM-ADE_tiny-v0.001",
        device=device,
    )
    return base_pipe, irai_pipe


base_pipe, irai_pipe = load_models()
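
# Generation bounds shared by both pipelines. These values are illustrative
# assumptions (not tuned and not taken from the LLM-ADE paper); without them
# the pipelines fall back to each model's default generation config.
GEN_KWARGS = {
    "max_new_tokens": 256,  # assumed cap on response length
    "do_sample": True,      # assumed; set False for greedy decoding
    "temperature": 0.7,     # assumed sampling temperature
}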

# Zephyr-style chat template expected by TinyLlama-1.1B-Chat.
tinyllama_template = (
    "<|system|>\n"
    "{sys}</s>\n"
    "<|user|>\n"
    "{input_text}</s>\n"
    "<|assistant|>\n"
)

# ChatML-style template expected by the LLM-ADE fine-tuned model.
chatml_template = (
    "<|im_start|>system\n"
    "{sys}<|im_end|>\n"
    "<|im_start|>user\n"
    "{input_text}<|im_end|>\n"
    "<|im_start|>assistant\n"
)
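
# For reference, chatml_template.format(sys="You are helpful.", input_text="Hi")
# renders as:
#   <|im_start|>system
#   You are helpful.<|im_end|>
#   <|im_start|>user
#   Hi<|im_end|>
#   <|im_start|>assistant
# leaving the assistant turn open for the model to complete.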

system_prompt = "You are an AI model with extensive knowledge in investing, finance, and economics, trained on a diverse dataset including technology reports, investment reports, financial texts, economic analyses, and other relevant sources up to June 5, 2024. Your purpose is to provide informative and helpful responses to questions related to these topics. Your answer is only for demonstration purposes only and will not be used as professional investment advice. When answering questions, focus on the data and information from your training data. If a question is outside your knowledge domain or you don't have enough pre-existing information to provide a complete answer, respond with \"I don't have sufficient knowledge from my training data to answer this question.\" Always strive to provide direct and concise answer to the question in a formal tone without opinion or sentiment. Please note that your knowledge is based on your training data up to June 5, 2024 and you have no knowledge of events or developments after your training data's cut-off date."
executor = ThreadPoolExecutor(max_workers=2)


def generate_base_response(input_text):
    # Wrap the question in the TinyLlama chat template and generate.
    formatted_input = tinyllama_template.format(sys=system_prompt, input_text=input_text)
    result = base_pipe(formatted_input, **GEN_KWARGS)[0]["generated_text"]
    # The pipeline echoes the prompt, so keep only the text after the assistant tag.
    return result.split("<|assistant|>")[1].strip()


def generate_irai_response(input_text):
    # Wrap the question in the ChatML template expected by the LLM-ADE model.
    formatted_input = chatml_template.format(sys=system_prompt, input_text=input_text)
    result = irai_pipe(formatted_input, **GEN_KWARGS)[0]["generated_text"]
    # Keep only the assistant turn, trimming the end-of-turn marker if present.
    return result.split("<|im_start|>assistant")[1].split("<|im_end|>")[0].strip()


@st.cache_data(show_spinner="Generating responses...")
def generate_response(input_text):
    try:
        # Run both models in parallel on the two executor threads.
        future_base = executor.submit(generate_base_response, input_text)
        future_irai = executor.submit(generate_irai_response, input_text)
        base_resp = future_base.result()
        irai_resp = future_irai.result()
    except Exception as e:
        st.error(f"An error occurred: {e}")
        return None, None
    return base_resp, irai_resp


st.title("Base Model vs IRAI LLM-ADE")
st.markdown("This is demo of the [LLM-ADE paper](https://arxiv.org/abs/2404.13028) using two instruct versions of [TinyLlama](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0).")
st.markdown("LLM-ADE models work best for data curation and agentic pipelines, not as pure answer generator. This is a toy demo with no RAG involved, using same prompts - knowledge cutoff is June 5, 2024")
user_input = st.text_area("Please ask about investing/finance related questions and mega-cap (top 20) stocks!", "")

# Generate when the button is clicked or when text has already been entered.
if st.button("Generate") or user_input:
    if user_input:
        base_response, irai_response = generate_response(user_input)
        # Show the two models' answers side by side for easy comparison.
        col1, col2 = st.columns(2)
        with col1:
            st.write("### Base Model (TinyLlama)")
            st.text_area(label="Base model response", value=base_response, height=300, key="base_response", label_visibility="hidden")
        with col2:
            st.write("### LLM-ADE Enhanced")
            st.text_area(label="LLM-ADE response", value=irai_response, height=300, key="irai_response", label_visibility="hidden")
    else:
        st.warning("Please enter some text")
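
# To try the demo locally (assuming this file is saved as app.py):
#   streamlit run app.py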