WilliamGazeley commited on
Commit
a615d13
1 Parent(s): bd328c0

Update to use latest model

Browse files
Files changed (1) hide show
  1. app.py +25 -13
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import streamlit as st
2
  from transformers import pipeline
3
  from concurrent.futures import ThreadPoolExecutor
@@ -6,42 +7,52 @@ from concurrent.futures import ThreadPoolExecutor
6
  # Function to load models only once using Streamlit's cache mechanism
7
  @st.cache_resource(show_spinner="Loading Models...")
8
  def load_models():
 
9
  base_pipe = pipeline(
10
  "text-generation",
11
  model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
12
- max_length=512,
13
  )
14
  irai_pipe = pipeline(
15
  "text-generation",
16
  model="InvestmentResearchAI/LLM-ADE_tiny-v0.001",
17
- max_length=512,
18
  )
19
  return base_pipe, irai_pipe
20
 
21
 
22
  base_pipe, irai_pipe = load_models()
23
 
24
- prompt_template = (
25
  "<|system|>\n"
26
- "You are a friendly chatbot who always gives helpful, detailed, and polite answers.</s>\n"
27
  "<|user|>\n"
28
  "{input_text}</s>\n"
29
  "<|assistant|>\n"
30
  )
31
 
 
 
 
 
 
 
 
 
 
32
  executor = ThreadPoolExecutor(max_workers=2)
33
 
34
 
35
  def generate_base_response(input_text):
36
- formatted_input = prompt_template.format(input_text=input_text)
37
  result = base_pipe(formatted_input)[0]["generated_text"]
38
  return result.split("<|assistant|>")[1].strip()
39
 
40
 
41
  def generate_irai_response(input_text):
42
- formatted_input = prompt_template.format(input_text=input_text)
43
  result = irai_pipe(formatted_input)[0]["generated_text"]
44
- return result.split("<|assistant|>")[1].strip()
45
 
46
 
47
  @st.cache_data(show_spinner="Generating responses...")
@@ -49,7 +60,7 @@ def generate_response(input_text):
49
  try:
50
  future_base = executor.submit(generate_base_response, input_text)
51
  future_irai = executor.submit(generate_irai_response, input_text)
52
- base_resp = future_base.result().replace(input_text, "", 1).strip()
53
  irai_resp = future_irai.result()
54
  except Exception as e:
55
  st.error(f"An error occurred: {e}")
@@ -58,17 +69,18 @@ def generate_response(input_text):
58
 
59
 
60
  st.title("Base Model vs IRAI LLM-ADE")
61
- user_input = st.text_area("Enter a financial question:", "")
 
62
 
63
  if st.button("Generate") or user_input:
64
  if user_input:
65
  base_response, irai_response = generate_response(user_input)
66
  col1, col2 = st.columns(2)
67
  with col1:
68
- st.header("Base Model")
69
- st.text_area(label="", value=base_response, height=300)
70
  with col2:
71
- st.header("LLM-ADE Enhanced")
72
- st.text_area(label="", value=irai_response, height=300)
73
  else:
74
  st.warning("Please enter some text")
 
1
+ import torch
2
  import streamlit as st
3
  from transformers import pipeline
4
  from concurrent.futures import ThreadPoolExecutor
 
7
  # Function to load models only once using Streamlit's cache mechanism
8
  @st.cache_resource(show_spinner="Loading Models...")
9
  def load_models():
10
+ device = 0 if torch.cuda.is_available() else -1
11
  base_pipe = pipeline(
12
  "text-generation",
13
  model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
14
+ device=device,
15
  )
16
  irai_pipe = pipeline(
17
  "text-generation",
18
  model="InvestmentResearchAI/LLM-ADE_tiny-v0.001",
19
+ device=device,
20
  )
21
  return base_pipe, irai_pipe
22
 
23
 
24
  base_pipe, irai_pipe = load_models()
25
 
26
+ alpaca_template = (
27
  "<|system|>\n"
28
+ "{sys}</s>\n"
29
  "<|user|>\n"
30
  "{input_text}</s>\n"
31
  "<|assistant|>\n"
32
  )
33
 
34
+ chatml_template = (
35
+ "<|im_start|>system\n"
36
+ "{sys}<|im_end|>\n"
37
+ "<|im_start|>user\n"
38
+ "{input_text}<|im_end|>\n"
39
+ "<|im_start|>assistant\n"
40
+ )
41
+
42
+ system_prompt = "You are an AI assistant trained on an extensive dataset, including technology reports, investment reports, financial texts, and other relevant sources. Please answer the following question based on the knowledge you have acquired during your training. Do not make any assumptions or use information from external sources. If you don't have enough pre-existing knowledge to provide a complete answer, simply respond with \"I don't have enough pre-existing knowledge to comprehensively answer this question.\" If you can partially answer the question based on your training, please provide that partial answer and clarify that it may not be a complete response. Assume today is June 5, 2024, and respond as if you have no knowledge of events after your training data's cut-off date."
43
  executor = ThreadPoolExecutor(max_workers=2)
44
 
45
 
46
  def generate_base_response(input_text):
47
+ formatted_input = alpaca_template.format(sys=system_prompt, input_text=input_text)
48
  result = base_pipe(formatted_input)[0]["generated_text"]
49
  return result.split("<|assistant|>")[1].strip()
50
 
51
 
52
  def generate_irai_response(input_text):
53
+ formatted_input = chatml_template.format(sys=system_prompt, input_text=input_text)
54
  result = irai_pipe(formatted_input)[0]["generated_text"]
55
+ return result.split("<|im_start|>assistant")[1].split("<|im_end|>")[0].strip()
56
 
57
 
58
  @st.cache_data(show_spinner="Generating responses...")
 
60
  try:
61
  future_base = executor.submit(generate_base_response, input_text)
62
  future_irai = executor.submit(generate_irai_response, input_text)
63
+ base_resp = future_base.result()
64
  irai_resp = future_irai.result()
65
  except Exception as e:
66
  st.error(f"An error occurred: {e}")
 
69
 
70
 
71
  st.title("Base Model vs IRAI LLM-ADE")
72
+ st.markdown("This is a demonstration of the [LLM-ADE paper](https://arxiv.org/abs/2404.13028) (knowledge cutoff is June 5, 2024)")
73
+ user_input = st.text_area("Ask about finance related questions and mega-cap (top 15) stocks!", "")
74
 
75
  if st.button("Generate") or user_input:
76
  if user_input:
77
  base_response, irai_response = generate_response(user_input)
78
  col1, col2 = st.columns(2)
79
  with col1:
80
+ st.write("### Base Model (Tiny-Llama)")
81
+ st.text_area(label="none", value=base_response, height=300, key="base_response", label_visibility="hidden")
82
  with col2:
83
+ st.write("### LLM-ADE Enhanced")
84
+ st.text_area(label="none", value=irai_response, height=300, key="irai_response", label_visibility="hidden")
85
  else:
86
  st.warning("Please enter some text")