MatheusHRV commited on
Commit
ac7d0c9
·
verified ·
1 Parent(s): 2919744

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -10
app.py CHANGED
@@ -1,26 +1,52 @@
1
  import streamlit as st
2
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
3
  from langchain.schema import AIMessage, HumanMessage, SystemMessage
4
 
 
 
 
5
  st.set_page_config(page_title="LangChain Demo", page_icon=":robot:")
6
  st.header("MHRV Chatbot")
7
 
 
 
 
8
  if "sessionMessages" not in st.session_state:
9
  st.session_state.sessionMessages = [
10
- SystemMessage(content="You are a helpful customer support chatbot for a website.")
 
 
 
 
 
 
 
11
  ]
12
 
13
- # Load Flan-T5-Small (CPU-friendly)
14
- model_name = "google/flan-t5-small"
 
 
15
  tokenizer = AutoTokenizer.from_pretrained(model_name)
16
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
17
 
18
- generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer, device=-1, max_new_tokens=256)
 
 
 
 
 
 
 
 
19
 
 
 
 
20
  def load_answer(question):
21
  st.session_state.sessionMessages.append(HumanMessage(content=question))
22
 
23
- # Concatenate messages into a single prompt
24
  prompt = ""
25
  for msg in st.session_state.sessionMessages:
26
  if isinstance(msg, SystemMessage):
@@ -30,9 +56,9 @@ def load_answer(question):
30
  elif isinstance(msg, AIMessage):
31
  prompt += f"AI: {msg.content}\n"
32
 
33
- # Generate response
34
- output = generator(prompt)
35
- answer_text = output[0]["generated_text"].strip()
36
 
37
  st.session_state.sessionMessages.append(AIMessage(content=answer_text))
38
  return answer_text
@@ -40,6 +66,9 @@ def load_answer(question):
40
  def get_text():
41
  return st.text_input("You: ", key="input")
42
 
 
 
 
43
  user_input = get_text()
44
  submit = st.button("Generate")
45
 
 
1
  import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
3
  from langchain.schema import AIMessage, HumanMessage, SystemMessage
4
 
5
+ # ------------------------
6
+ # Streamlit UI
7
+ # ------------------------
8
  st.set_page_config(page_title="LangChain Demo", page_icon=":robot:")
9
  st.header("MHRV Chatbot")
10
 
11
+ # ------------------------
12
+ # Session memory
13
+ # ------------------------
14
  if "sessionMessages" not in st.session_state:
15
  st.session_state.sessionMessages = [
16
+ SystemMessage(
17
+ content=(
18
+ "You are a highly intelligent and helpful customer support assistant. "
19
+ "Answer user questions clearly, politely, and professionally. "
20
+ "If you don’t know the answer, say so instead of making things up. "
21
+ "Provide step-by-step instructions if relevant and helpful."
22
+ )
23
+ )
24
  ]
25
 
26
+ # ------------------------
27
+ # Load model and tokenizer
28
+ # ------------------------
29
+ model_name = "bigscience/bloom-560m" # CPU-compatible
30
  tokenizer = AutoTokenizer.from_pretrained(model_name)
31
+ model = AutoModelForCausalLM.from_pretrained(model_name)
32
 
33
+ # Create text-generation pipeline
34
+ generator = pipeline(
35
+ "text-generation",
36
+ model=model,
37
+ tokenizer=tokenizer,
38
+ device=-1, # CPU
39
+ max_new_tokens=256,
40
+ temperature=0.3
41
+ )
42
 
43
+ # ------------------------
44
+ # Helper functions
45
+ # ------------------------
46
  def load_answer(question):
47
  st.session_state.sessionMessages.append(HumanMessage(content=question))
48
 
49
+ # Build prompt from session messages
50
  prompt = ""
51
  for msg in st.session_state.sessionMessages:
52
  if isinstance(msg, SystemMessage):
 
56
  elif isinstance(msg, AIMessage):
57
  prompt += f"AI: {msg.content}\n"
58
 
59
+ # Generate answer
60
+ output = generator(prompt, max_new_tokens=256, do_sample=True, temperature=0.3)
61
+ answer_text = output[0]["generated_text"][len(prompt):].strip()
62
 
63
  st.session_state.sessionMessages.append(AIMessage(content=answer_text))
64
  return answer_text
 
66
  def get_text():
67
  return st.text_input("You: ", key="input")
68
 
69
+ # ------------------------
70
+ # Main app
71
+ # ------------------------
72
  user_input = get_text()
73
  submit = st.button("Generate")
74