wop commited on
Commit
e8cf90a
1 Parent(s): 55ff663

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -3
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import gradio as gr
2
- import os
3
  os.system('pip install transformers torch')
4
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
5
 
@@ -8,16 +8,27 @@ model_name = "microsoft/DialoGPT-small"
8
  model = GPT2LMHeadModel.from_pretrained(model_name)
9
  tokenizer = GPT2Tokenizer.from_pretrained(model_name)
10
 
 
 
 
 
11
  def generate_response(prompt, max_length=50, temperature=0.8):
12
- input_ids = tokenizer.encode(prompt, return_tensors="pt")
 
 
13
  output_ids = model.generate(input_ids, max_length=max_length, temperature=temperature, num_return_sequences=1)
14
  response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
 
 
 
 
15
  return response
16
 
17
  iface = gr.Interface(
18
  fn=generate_response,
19
  inputs=gr.Textbox(),
20
- outputs="text"
 
21
  )
22
 
23
  iface.launch()
 
1
  import gradio as gr
2
+ import os
3
  os.system('pip install transformers torch')
4
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
5
 
 
8
  model = GPT2LMHeadModel.from_pretrained(model_name)
9
  tokenizer = GPT2Tokenizer.from_pretrained(model_name)
10
 
11
+ # Initial system prompt and chat history
12
+ system_prompt = "You are a helpful assistant."
13
+ chat_history = system_prompt
14
+
15
  def generate_response(prompt, max_length=50, temperature=0.8):
16
+ global chat_history
17
+ input_text = chat_history + " User: " + prompt
18
+ input_ids = tokenizer.encode(input_text, return_tensors="pt")
19
  output_ids = model.generate(input_ids, max_length=max_length, temperature=temperature, num_return_sequences=1)
20
  response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
21
+
22
+ # Update chat history
23
+ chat_history += f" User: {prompt} Assistant: {response}"
24
+
25
  return response
26
 
27
  iface = gr.Interface(
28
  fn=generate_response,
29
  inputs=gr.Textbox(),
30
+ outputs="text",
31
+ live=True
32
  )
33
 
34
  iface.launch()