WilliamGazeley committed on
Commit
3a830ca
1 Parent(s): 8b384d6

Fix model loading

Browse files
Files changed (1) hide show
  1. app.py +35 -29
app.py CHANGED
@@ -2,18 +2,25 @@ import streamlit as st
2
  from transformers import pipeline
3
  from concurrent.futures import ThreadPoolExecutor
4
 
5
- # Load models at startup
6
- with st.spinner(text="Loading Models..."):
7
- base_pipe = pipeline(
8
- "text-generation",
9
- model="TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
10
- max_length=512,
11
- )
12
- irai_pipe = pipeline(
13
- "text-generation",
14
- model="InvestmentResearchAI/LLM-ADE_tiny-v0.001",
15
- max_length=512,
16
- )
 
 
 
 
 
 
 
17
 
18
  prompt_template = (
19
  "<|system|>\n"
@@ -23,29 +30,28 @@ prompt_template = (
23
  "<|assistant|>\n"
24
  )
25
 
 
 
26
 
27
  def generate_base_response(input_text):
28
  return base_pipe(input_text)[0]["generated_text"]
29
 
30
 
31
  def generate_irai_response(input_text):
32
- return (
33
- irai_pipe(prompt_template.format(input_text=input_text))[0]["generated_text"]
34
- .split("<|assistant|>")[1]
35
- .strip()
36
- )
37
 
38
 
39
  def generate_response(input_text):
40
- with ThreadPoolExecutor() as executor:
41
- try:
42
- future_base = executor.submit(generate_base_response, input_text)
43
- future_irai = executor.submit(generate_irai_response, input_text)
44
- base_resp = future_base.result()
45
- irai_resp = future_irai.result()
46
- except Exception as e:
47
- st.error(f"An error occurred: {e}")
48
- return None, None
49
  return base_resp, irai_resp
50
 
51
 
@@ -54,14 +60,14 @@ user_input = st.text_area("Enter a financial question:", "")
54
 
55
  if st.button("Generate"):
56
  if user_input:
57
- with st.spinner(text="Generating text..."):
58
  base_response, irai_response = generate_response(user_input)
59
  col1, col2 = st.columns(2)
60
  with col1:
61
  st.header("Base Model Response")
62
- st.text_area("", base_response, height=300)
63
  with col2:
64
  st.header("IRAI LLM-ADE Model Response")
65
- st.text_area("", irai_response, height=300)
66
  else:
67
  st.warning("Please enter some text to generate a response.")
 
2
  from transformers import pipeline
3
  from concurrent.futures import ThreadPoolExecutor
4
 
5
+
6
# Cached model loader — Streamlit re-runs the whole script on every user
# interaction, so caching ensures the heavy pipelines are built only once.
# NOTE(review): st.cache is deprecated in newer Streamlit releases in favour
# of st.cache_resource — confirm which Streamlit version this app pins.
@st.cache(allow_output_mutation=True)
def load_models():
    """Build both text-generation pipelines once and return them.

    Returns:
        tuple: (base_pipe, irai_pipe) — the TinyLlama baseline pipeline and
        the IRAI LLM-ADE fine-tuned pipeline, in that order.
    """
    model_ids = (
        "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
        "InvestmentResearchAI/LLM-ADE_tiny-v0.001",
    )
    with st.spinner(text="Loading Models..."):
        pipes = [
            pipeline("text-generation", model=model_id, max_length=512)
            for model_id in model_ids
        ]
    return tuple(pipes)
21
+
22
+
23
+ base_pipe, irai_pipe = load_models()
24
 
25
  prompt_template = (
26
  "<|system|>\n"
 
30
  "<|assistant|>\n"
31
  )
32
 
33
+ executor = ThreadPoolExecutor(max_workers=2)
34
+
35
 
36
def generate_base_response(input_text):
    """Generate a completion from the base TinyLlama pipeline.

    The raw user text is passed straight through (no chat template), and the
    pipeline's full ``generated_text`` field is returned unchanged.
    """
    outputs = base_pipe(input_text)
    return outputs[0]["generated_text"]
38
 
39
 
40
def generate_irai_response(input_text):
    """Generate a reply from the IRAI fine-tuned pipeline.

    The question is wrapped in the chat prompt template first; because the
    pipeline echoes the prompt, only the text after the first
    ``<|assistant|>`` marker is returned (whitespace-stripped). If the marker
    is missing the resulting IndexError propagates to the caller, which
    reports it via the UI.
    """
    generated = irai_pipe(prompt_template.format(input_text=input_text))[0][
        "generated_text"
    ]
    return generated.split("<|assistant|>")[1].strip()
 
 
44
 
45
 
46
def generate_response(input_text):
    """Run both model pipelines concurrently on the shared executor.

    Returns:
        tuple: (base_response, irai_response), or (None, None) if either
        submission or generation raised — the error is shown with st.error
        instead of propagating.
    """
    try:
        # Submit both jobs before waiting so the two generations overlap.
        futures = [
            executor.submit(worker, input_text)
            for worker in (generate_base_response, generate_irai_response)
        ]
        base_resp, irai_resp = [future.result() for future in futures]
    except Exception as e:  # surface any model failure in the UI
        st.error(f"An error occurred: {e}")
        return None, None
    return base_resp, irai_resp
56
 
57
 
 
60
 
61
# --- UI: run both models on the entered question and show them side by side.
if st.button("Generate"):
    if not user_input:
        st.warning("Please enter some text to generate a response.")
    else:
        with st.spinner("Generating text..."):
            base_response, irai_response = generate_response(user_input)
        headers = ("Base Model Response", "IRAI LLM-ADE Model Response")
        responses = (base_response, irai_response)
        for column, header, response in zip(st.columns(2), headers, responses):
            with column:
                st.header(header)
                st.text_area(label="", value=response, height=300)