5m4ck3r committed on
Commit 660eb2e
1 Parent(s): 7dbb473

Update app.py

Files changed (1)
  1. app.py +19 -31
app.py CHANGED
@@ -5,44 +5,32 @@ import gradio as gr
 tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-large")
 model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-large")
 
-# Initialize chat history
-chat_history_ids = None
-
-# Define a function to process user input and generate bot response
-def chat_cpu(user_input, chat_history=None):
-    global chat_history_ids
-
-    # If chat history is provided, use it
-    if chat_history is not None:
-        chat_history_ids = chat_history
+def chat_with_history(message, chat_history=None):
+    # Initialize chat history if not provided
+    if chat_history is None:
+        chat_history = []
 
     # Encode the new user input, add the eos_token, and return a tensor in PyTorch
-    new_user_input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors='pt')
+    new_user_input_ids = tokenizer.encode(message + tokenizer.eos_token, return_tensors='pt')
 
     # Append the new user input tokens to the chat history
-    bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if chat_history_ids is not None else new_user_input_ids
+    bot_input_ids = torch.cat([tokenizer.encode(pair[0] + tokenizer.eos_token, return_tensors='pt') for pair in chat_history] + [new_user_input_ids], dim=-1)
 
     # Generate a response while limiting the total chat history to 1000 tokens
     chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)
 
-    # Pretty print last output tokens from bot
+    # Decode the last output tokens from bot
     response = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
 
-    # Return the response and updated chat history
-    return response, chat_history_ids
-
-# Define a Gradio interface
-iface = gr.Interface(
-    fn=chat_cpu,
-    inputs=["text", gr.Textbox(placeholder="Chat history will appear here", lines=10)],
-    outputs=[gr.TextElement(text="BOT", id="bot_history", default="Chat AI"),
-             gr.TextElement(text="USER", id="user_history", default="")],
-    live=True,
-    layout="vertical",
-    theme="compact",
-    title="Chat AI",
-    css=".output {flex-direction: column-reverse;}",
-)
-
-# Launch the Gradio interface
-iface.launch()
+    # Update the chat history with the new user message and bot response
+    chat_history.append([message, response])
+
+    return response, chat_history
+
+demo = gr.ChatInterface(
+    fn=chat_with_history,
+    examples=["hey how are you ?", "hola", "Yo!"],
+    title="Multi Chat Bot"
+)
+
+demo.launch()
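
As a quick sanity check of the new signature, the function can be exercised on its own (a sketch; it assumes the definitions above have been run, e.g. before the blocking demo.launch() call):

# First turn: no history is passed, so chat_with_history starts a fresh list
reply, history = chat_with_history("Hello, how are you?")

# Second turn: feed the accumulated [user, bot] pairs back in
reply, history = chat_with_history("Tell me a joke", history)
print(history)  # [["Hello, how are you?", "..."], ["Tell me a joke", "..."]]

Note that the rebuilt context only re-encodes each pair's user turn (pair[0]); the bot's earlier replies never re-enter bot_input_ids, so the model conditions only on the user side of the conversation.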
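
One caveat worth flagging: gr.ChatInterface manages the pair history itself, passes it to the callback as a second argument, and (in the Gradio releases where history arrives as [user, bot] pairs) expects only the reply string back, so returning the (response, chat_history) tuple may show up verbatim in the UI. A minimal adapter under that contract might look like this (the respond wrapper is an illustration, not part of the commit):

def respond(message, history):
    # ChatInterface supplies history as [user, bot] pairs, which matches what
    # chat_with_history expects; discard the returned copy of the history.
    reply, _ = chat_with_history(message, history)
    return reply

demo = gr.ChatInterface(fn=respond, examples=["hey how are you ?", "hola", "Yo!"], title="Multi Chat Bot")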