johann22 commited on
Commit
49e0a2f
·
1 Parent(s): cdb1abe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -2
app.py CHANGED
@@ -16,7 +16,9 @@ client = InferenceClient(
16
  history = []
17
  hist_out= []
18
  summary =[]
 
19
  summary.append("")
 
20
  def format_prompt(message, history):
21
  prompt = "<s>"
22
  for user_prompt, bot_response in history:
@@ -38,7 +40,8 @@ repetition_penalty=1.0,
38
  def compress_history(formatted_prompt):
39
 
40
  seed = random.randint(1,1111111111111111)
41
- agent=prompts.COMPRESS_HISTORY_PROMPT.format(history=summary[0])
 
42
  system_prompt=agent
43
  temperature = 0.9
44
  if temperature < 1e-2:
@@ -62,13 +65,14 @@ def compress_history(formatted_prompt):
62
  output += response.token.text
63
  #history.append((output,history))
64
  print(output)
 
65
  return output
66
 
67
 
68
  def question_generate(prompt, history, agent_name=agents[0], sys_prompt="", temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,):
69
  #def question_generate(prompt, history):
70
  seed = random.randint(1,1111111111111111)
71
- agent=prompts.QUESTION_GENERATOR
72
  system_prompt=agent
73
  temperature = float(temperature)
74
  if temperature < 1e-2:
@@ -105,6 +109,7 @@ def create_valid_filename(invalid_filename: str) -> str:
105
  return ''.join(char for char in valid_chars if char in allowed_chars)
106
 
107
  def generate(prompt, history, agent_name=agents[0], sys_prompt="", temperature=0.9, max_new_tokens=1048, top_p=0.95, repetition_penalty=1.0,):
 
108
  #print(datetime.datetime.now())
109
  uid=uuid.uuid4()
110
  current_time = str(datetime.datetime.now())
 
16
  history = []
17
  hist_out= []
18
  summary =[]
19
+ main_point=[]
20
  summary.append("")
21
+ main_point.append("")
22
  def format_prompt(message, history):
23
  prompt = "<s>"
24
  for user_prompt, bot_response in history:
 
40
  def compress_history(formatted_prompt):
41
 
42
  seed = random.randint(1,1111111111111111)
43
+ agent=prompts.COMPRESS_HISTORY_PROMPT.format(history=summary[0],focus=main_point[0])
44
+
45
  system_prompt=agent
46
  temperature = 0.9
47
  if temperature < 1e-2:
 
65
  output += response.token.text
66
  #history.append((output,history))
67
  print(output)
68
+ print(main_point[0])
69
  return output
70
 
71
 
72
  def question_generate(prompt, history, agent_name=agents[0], sys_prompt="", temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,):
73
  #def question_generate(prompt, history):
74
  seed = random.randint(1,1111111111111111)
75
+ agent=prompts.QUESTION_GENERATOR.format(focus=main_point[0])
76
  system_prompt=agent
77
  temperature = float(temperature)
78
  if temperature < 1e-2:
 
109
  return ''.join(char for char in valid_chars if char in allowed_chars)
110
 
111
  def generate(prompt, history, agent_name=agents[0], sys_prompt="", temperature=0.9, max_new_tokens=1048, top_p=0.95, repetition_penalty=1.0,):
112
+ main_point[0]=prompt
113
  #print(datetime.datetime.now())
114
  uid=uuid.uuid4()
115
  current_time = str(datetime.datetime.now())