johann22 committed
Commit a1e37d7 · Parent: e3c90f3

Update app.py

Files changed (1): app.py (+48 -8)
app.py CHANGED
@@ -23,6 +23,37 @@ agents =[
     "MEME_GENERATOR",
     "QUESTION_GENERATOR",
 ]
+
+
+def question_generate(
+    prompt, history, agent_name=agents[0], sys_prompt="", temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
+):
+    seed = random.randint(1, 1111111111111111)
+    agent = prompts.QUESTION_GENERATOR
+    system_prompt = agent
+    temperature = float(temperature)
+    if temperature < 1e-2:
+        temperature = 1e-2
+    top_p = float(top_p)
+
+    generate_kwargs = dict(
+        temperature=temperature,
+        max_new_tokens=max_new_tokens,
+        top_p=top_p,
+        repetition_penalty=repetition_penalty,
+        do_sample=True,
+        seed=seed,
+    )
+
+    formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
+    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+    output = ""
+
+    for response in stream:
+        output += response.token.text
+        yield output
+    return output
+
 def generate(
     prompt, history, agent_name=agents[0], sys_prompt="", temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
 ):
@@ -59,14 +90,23 @@ def generate(
         do_sample=True,
         seed=seed,
     )
-
-    formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
-    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
-    output = ""
-
-    for response in stream:
-        output += response.token.text
-        yield output
+    while True:
+        formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
+        stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+        output = ""
+
+        for response in stream:
+            output += response.token.text
+
+        yield output
+        prompt = question_generate(
+            output, history, agent_name=agents[0], sys_prompt="", temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
+        )
+
+
+
+
+
     return output
112