ruslanmv commited on
Commit
f150754
·
verified ·
1 Parent(s): cec1f02

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -10
app.py CHANGED
@@ -3,7 +3,7 @@ import streamlit as st
3
  from models import load_model
4
 
5
  # Load the model once
6
- demo = load_model()
7
 
8
  # Page configuration
9
  st.set_page_config(
@@ -73,20 +73,25 @@ if prompt := st.chat_input("Type your message..."):
73
  try:
74
  # Generate response using the model
75
  with st.spinner("Generating response..."):
76
- # Pass inputs as positional arguments to the Gradio model
77
- response = demo(
78
- f"{system_message}\n\nUser: {prompt}\nAssistant:",
79
- max_length=max_tokens, # Gradio parameter
80
  temperature=temperature,
81
- top_p=top_p
82
- )
 
 
 
 
 
83
 
84
  # Display assistant response
85
  with st.chat_message("assistant"):
86
- st.markdown(response)
87
 
88
  # Add assistant response to chat history
89
- st.session_state.messages.append({"role": "assistant", "content": response})
90
 
91
  except Exception as e:
92
- st.error(f"An error occurred: {str(e)}")
 
3
  from models import load_model
4
 
5
  # Load the model once
6
+ generator = load_model()
7
 
8
  # Page configuration
9
  st.set_page_config(
 
73
  try:
74
  # Generate response using the model
75
  with st.spinner("Generating response..."):
76
+ full_prompt = f"{system_message}\n\nUser: {prompt}\nAssistant:"
77
+ response = generator(
78
+ full_prompt,
79
+ max_length=max_tokens,
80
  temperature=temperature,
81
+ top_p=top_p,
82
+ do_sample=True,
83
+ num_return_sequences=1
84
+ )[0]['generated_text']
85
+
86
+ # Extract only the assistant's response
87
+ assistant_response = response.split("Assistant:")[-1].strip()
88
 
89
  # Display assistant response
90
  with st.chat_message("assistant"):
91
+ st.markdown(assistant_response)
92
 
93
  # Add assistant response to chat history
94
+ st.session_state.messages.append({"role": "assistant", "content": assistant_response})
95
 
96
  except Exception as e:
97
+ st.error(f"An error occurred: {str(e)}")