KhantKyaw committed
Commit a85678b
1 Parent(s): 25572fa

Update app.py

Files changed (1)
  1. app.py +76 -30
app.py CHANGED
@@ -2,39 +2,85 @@ import streamlit as st
 
 from transformers import GPT2Tokenizer, GPT2LMHeadModel
 
-model_path='KhantKyaw/GPT2_chatbot2'
-
-tokenizer = GPT2Tokenizer.from_pretrained(model_path)
-tokenizer.pad_token = tokenizer.eos_token
-model = GPT2LMHeadModel.from_pretrained(model_path)
-
 def generate_response(input_text):
-    input_ids = tokenizer.encode(input_text, return_tensors='pt')
-
-    output_sequences = model.generate(
-        input_ids=input_ids,
-        max_length=200,
-        temperature=0.7,
-        num_return_sequences=1,
-        no_repeat_ngram_size=2,
-        top_k=50,
-        top_p=0.9,
-        do_sample=True,
-        pad_token_id=tokenizer.eos_token_id,
     )
-
-    response_with_prefix = tokenizer.decode(output_sequences[0], skip_special_tokens=True)
-    response_start_idx = response_with_prefix.find("answer: ")
-    if response_start_idx != -1:
-        response = response_with_prefix[response_start_idx + len("answer: "):]
-    else:
-        response = response_with_prefix
-    return response
-
-prompt = st.chat_input(placeholder="Say Something!", key=None, max_chars=None, disabled=False, on_submit=None, args=None, kwargs=None)
-if prompt:
-    with st.chat_message(name="AI", avatar=None):
-        st.write(generate_response(prompt))
 
 
 from transformers import GPT2Tokenizer, GPT2LMHeadModel
 
+# Function to generate a response
 def generate_response(input_text):
+    # Adjusted input to include the [Bot] marker
+    #adjusted_input = f"{input_text} [Bot]"
+
+    # Encode the adjusted input
+    inputs = tokenizer(input_text, return_tensors="pt")
+
+    # Generate a sequence of text with a slightly increased max_length to account for the prompt length
+    output_sequences = model.generate(
+        input_ids=inputs['input_ids'],
+        attention_mask=inputs['attention_mask'],
+        max_length=100,  # Adjusted max_length
+        temperature=0.7,
+        top_k=50,
+        top_p=0.95,
+        no_repeat_ngram_size=2,
+        pad_token_id=tokenizer.eos_token_id,
+        #early_stopping=True,
+        do_sample=True
     )
+
+    # Decode the generated sequence
+    full_generated_text = tokenizer.decode(output_sequences[0], skip_special_tokens=True)
+
+    # Extract the generated response after the [Bot] marker
+    bot_response_start = full_generated_text.find('[Bot]') + len('[Bot]')
+    bot_response = full_generated_text[bot_response_start:]
+
+    # Trim the response to end at the last period within the specified max_length
+    last_period_index = bot_response.rfind('.')
+    if last_period_index != -1:
+        bot_response = bot_response[:last_period_index + 1]
+
+    return bot_response.strip()
+
+# Load pre-trained model tokenizer (vocabulary) and model
+model_name = 'KhantKyaw/Chat_GPT-2'
+tokenizer = GPT2Tokenizer.from_pretrained(model_name)
+model = GPT2LMHeadModel.from_pretrained(model_name)
+
+# Chat loop
+print("Chatbot is ready. Type 'quit' to exit.")
+while True:
+    user_input = input("You: ")
+    if user_input.lower() == "quit":
+        break
+    response = generate_response(user_input)
+    print("Chatbot:", response)
+
+
+
+st.title("Simple Streamlit Chatbot")
+
+# User input text box
+user_input = st.text_input("You: ", key="user_input")
+
+# Button to send the message
+if st.button("Send"):
+    # Generating a response
+    response = generate_response(user_input)
 
+    # Displaying the conversation
+    # Here, we use st.session_state to keep track of the conversation
+    if 'conversation' not in st.session_state:
+        st.session_state.conversation = []

+    # Append the user input and bot response to the conversation
+    st.session_state.conversation.append("You: " + user_input)
+    st.session_state.conversation.append("Bot: " + response)
+
+    # Display each line in the conversation
+    for line in st.session_state.conversation:
+        st.text(line)
+
+


+#prompt = st.chat_input(placeholder="Say Something!", key=None, max_chars=None, disabled=False, on_submit=None, args=None, kwargs=None)
+#if prompt:
+#    with st.chat_message(name="AI", avatar=None):
+#        st.write(generate_response(prompt))
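
Note on the committed app.py: the console `while True` / `input()` chat loop runs before the Streamlit widgets and would block the script when the Space starts, and the tokenizer and model are reloaded on every Streamlit rerun. The following is a minimal sketch, not part of the commit, of how the same pieces could be wired as a Streamlit-only chat app. It assumes the `KhantKyaw/Chat_GPT-2` checkpoint named in this commit and a Streamlit release that provides `st.cache_resource`, `st.chat_input`, and `st.chat_message`.

```python
# Sketch only: Streamlit-only wiring of the committed pieces (not the committed code).
import streamlit as st
from transformers import GPT2Tokenizer, GPT2LMHeadModel


@st.cache_resource  # load the checkpoint once per server process, not on every rerun
def load_model(model_name: str = "KhantKyaw/Chat_GPT-2"):
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    model = GPT2LMHeadModel.from_pretrained(model_name)
    return tokenizer, model


tokenizer, model = load_model()


def generate_response(input_text: str) -> str:
    # Same sampling settings as the committed generate_response
    inputs = tokenizer(input_text, return_tensors="pt")
    output_sequences = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=100,
        temperature=0.7,
        top_k=50,
        top_p=0.95,
        no_repeat_ngram_size=2,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
    )
    text = tokenizer.decode(output_sequences[0], skip_special_tokens=True)
    # Keep only what follows the [Bot] marker, as in the committed code
    marker = text.find("[Bot]")
    if marker != -1:
        text = text[marker + len("[Bot]"):]
    # Trim to the last complete sentence if there is one
    last_period = text.rfind(".")
    if last_period != -1:
        text = text[: last_period + 1]
    return text.strip()


st.title("Simple Streamlit Chatbot")

# Conversation history survives reruns via session_state
if "conversation" not in st.session_state:
    st.session_state.conversation = []

# Replay earlier turns, then handle the new prompt
for role, content in st.session_state.conversation:
    with st.chat_message(role):
        st.write(content)

prompt = st.chat_input("Say something!")
if prompt:
    with st.chat_message("user"):
        st.write(prompt)
    reply = generate_response(prompt)
    with st.chat_message("assistant"):
        st.write(reply)
    st.session_state.conversation.append(("user", prompt))
    st.session_state.conversation.append(("assistant", reply))
```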