File size: 3,174 Bytes
109014c
a20dfac
ecd63b4
 
 
954e857
 
 
 
 
 
 
 
 
 
 
 
 
 
ecd63b4
6e203a2
e87746b
 
558d9e8
e87746b
eaab710
e87746b
ecd63b4
 
 
954e857
eaab710
 
 
 
954e857
eaab710
8741596
ecd63b4
 
 
 
 
 
954e857
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ecd63b4
 
 
954e857
 
 
 
 
 
ecd63b4
954e857
ecd63b4
 
 
954e857
ecd63b4
 
 
 
e87746b
 
ecd63b4
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import os
import huggingface_hub
import streamlit as st
from vllm import LLM, SamplingParams


def get_system_message():
    """Return the default system prompt that primes the model as "IRAI".

    The prompt is a fixed string literal, so no caching is needed. The former
    ``@st.cache`` decorator was deprecated in Streamlit in favour of
    ``st.cache_data`` and added no value for a constant.

    Returns:
        str: the default system message. ``main()`` shows it in an editable
        text area so the user can modify it before generating.
    """
    return """#Context:
You are an AI-based automated expert financial advisor named IRAI. You have a comprehensive understanding of finance and investing because you have trained on a  extensive dataset based on of financial news, analyst reports, books, company filings, earnings call transcripts, and finance websites.
#Objective:
Answer questions accurately and truthfully given the data you have trained on.  You do not have access to up-to-date current market data; this will be available in the future. 
Style and tone:
Please answer in a friendly and engaging manner representing a top female investment professional working at a leading investment bank.
#Audience:
The questions will be asked by top technology executives and CFO of large fintech companies and successful startups.
#Response:
Answer, concise yet insightful."""


@st.cache_resource(show_spinner=False)
def init_llm():
    """Load the "InvestmentResearchAI/LLM-ADE-dev" model with vLLM and return it.

    Wrapped in ``st.cache_resource`` so the returned ``LLM`` object is created
    once and reused across Streamlit reruns instead of reloading the weights.

    Returns:
        vllm.LLM: the loaded inference engine.
    """
    token = os.getenv("HF_TOKEN")
    if token:
        # Only authenticate when a token is actually supplied:
        # huggingface_hub.login(token=None) falls back to an interactive prompt,
        # which a Streamlit server process cannot answer.
        huggingface_hub.login(token=token)
    llm = LLM(model="InvestmentResearchAI/LLM-ADE-dev")
    tok = llm.get_tokenizer()
    # Override to use turns: treat the chat end-of-turn marker as end-of-sequence.
    # NOTE(review): whether vLLM picks up this mutated eos_token depends on the
    # vLLM version reusing this same tokenizer object — confirm.
    tok.eos_token = '<|im_end|>'
    return llm

def get_response(prompt, custom_sys_msg=None, max_tokens=2000):
    """Generate one model reply to ``prompt``.

    This replaces two former definitions of ``get_response``, the first of
    which was silently shadowed by the second. Both old call shapes still
    work: ``get_response(prompt)`` and ``get_response(prompt, custom_sys_msg)``.

    Args:
        prompt: The user's message.
        custom_sys_msg: System message to send. When ``None``, the default
            from ``get_system_message()`` is used.
        max_tokens: Maximum number of tokens to generate. The default of 2000
            matches the previously effective definition.

    Returns:
        str: the generated text, or ``"An error occurred: ..."`` if any step
        raised. The error is reported to the UI rather than propagated.
    """
    try:
        sys_msg = get_system_message() if custom_sys_msg is None else custom_sys_msg
        convo = [
            {"role": "system", "content": sys_msg},
            {"role": "user", "content": prompt},
        ]
        # Use the cached loader instead of relying on a module-level global.
        llm = init_llm()
        # add_generation_prompt=True appends the assistant-turn header so the
        # model replies as the assistant rather than continuing the user turn.
        prompt_text = llm.get_tokenizer().apply_chat_template(
            convo, tokenize=False, add_generation_prompt=True
        )
        sampling_params = SamplingParams(
            temperature=0.3,
            top_p=0.95,
            max_tokens=max_tokens,
            # NOTE(review): 128009 is Llama-3's <|eot_id|> id; confirm it is a
            # meaningful stop token for this model's vocabulary.
            stop_token_ids=[128009],
        )
        outputs = llm.generate([prompt_text], sampling_params)
        # Exactly one prompt was submitted, so take its first completion. An
        # unexpected empty result raises IndexError and is reported below.
        return outputs[0].outputs[0].text
    except Exception as e:
        return f"An error occurred: {str(e)}"

def main():
    """Render the demo page.

    The page shows an editable system prompt, a text box for the user's
    question, and a Generate button that displays the model's reply.
    """
    st.title("LLM-ADE 9B Demo")

    # Pre-fill the editable preprompt with the default system message.
    sys_msg = get_system_message()
    user_modified_sys_msg = st.text_area("Preprompt: ", value=sys_msg, height=200)

    input_text = st.text_area("Enter your text here:", value="", height=200)

    if st.button("Generate"):
        # Treat whitespace-only input as empty so a blank prompt never
        # triggers a full model generation.
        if input_text.strip():
            with st.spinner('Generating response...'):
                response_text = get_response(input_text, user_modified_sys_msg)
                st.write(response_text)
        else:
            st.warning("Please enter some text to generate a response.")

# Load the model when the script starts, before the UI is rendered, so the first
# "Generate" click does not pay the load cost. init_llm() is wrapped in
# st.cache_resource, so later executions of this script reuse the same object.
# Kept as a module global because code in this file may refer to `llm` by name.
llm = init_llm()

if __name__ == "__main__":
    main()