Update app.py
Browse files
app.py
CHANGED
|
@@ -127,41 +127,41 @@ else:
|
|
| 127 |
|
| 128 |
|
| 129 |
|
| 130 |
-
#
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
|
| 135 |
-
#
|
| 136 |
-
|
| 137 |
|
| 138 |
-
#
|
| 139 |
-
|
| 140 |
|
| 141 |
-
|
| 142 |
-
|
| 143 |
|
| 144 |
-
|
| 145 |
-
#
|
| 146 |
-
|
| 147 |
|
| 148 |
-
#
|
| 149 |
-
|
| 150 |
-
|
| 151 |
|
| 152 |
-
#
|
| 153 |
-
|
| 154 |
-
|
| 155 |
|
| 156 |
-
#
|
| 157 |
-
|
| 158 |
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
|
| 163 |
-
#
|
| 164 |
-
|
| 165 |
|
| 166 |
|
| 167 |
#############
|
|
|
|
| 127 |
|
| 128 |
|
| 129 |
|
| 130 |
+
# Load pre-trained GPT-2 model and tokenizer
|
| 131 |
+
model_name = "gpt2-medium" # "gpt2" # Use "gpt-3.5-turbo" or another model from Hugging Face if needed
|
| 132 |
+
model = GPT2LMHeadModel.from_pretrained(model_name)
|
| 133 |
+
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
|
| 134 |
|
| 135 |
+
# Initialize the text generation pipeline
|
| 136 |
+
gpt_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
|
| 137 |
|
| 138 |
+
# Streamlit UI
|
| 139 |
+
st.markdown("<h3 style='text-align: center; font-size: 20px;'>Chat with GPT</h3>", unsafe_allow_html=True)
|
| 140 |
|
| 141 |
+
if 'conversation' not in st.session_state:
|
| 142 |
+
st.session_state.conversation = ""
|
| 143 |
|
| 144 |
+
def chat_with_gpt(user_input):
|
| 145 |
+
# Append user input to the conversation
|
| 146 |
+
st.session_state.conversation += f"User: {user_input}\n"
|
| 147 |
|
| 148 |
+
# Generate response
|
| 149 |
+
response = gpt_pipeline(user_input, max_length=100, num_return_sequences=1)[0]['generated_text']
|
| 150 |
+
response_text = response.replace(user_input, '') # Strip the user input part from response
|
| 151 |
|
| 152 |
+
# Append GPT's response to the conversation
|
| 153 |
+
st.session_state.conversation += f"GPT: {response_text}\n"
|
| 154 |
+
return response_text
|
| 155 |
|
| 156 |
+
# Text input for user query
|
| 157 |
+
user_input = st.text_input("You:", "")
|
| 158 |
|
| 159 |
+
if st.button("Send"):
|
| 160 |
+
if user_input:
|
| 161 |
+
chat_with_gpt(user_input)
|
| 162 |
|
| 163 |
+
# Display conversation history
|
| 164 |
+
st.text_area("Conversation", value=st.session_state.conversation, height=400)
|
| 165 |
|
| 166 |
|
| 167 |
#############
|