qgyd2021 committed on
Commit
febd133
1 Parent(s): d2b43fb
Files changed (2) hide show
  1. .gitignore +2 -0
  2. main.py +12 -5
.gitignore CHANGED
@@ -4,3 +4,5 @@
4
 
5
  **/flagged/
6
  **/__pycache__/
 
 
 
4
 
5
  **/flagged/
6
  **/__pycache__/
7
+
8
+ trained_models/
main.py CHANGED
@@ -3,6 +3,7 @@
3
  import argparse
4
  from collections import defaultdict
5
  import os
 
6
 
7
  import gradio as gr
8
  from threading import Thread
@@ -71,7 +72,6 @@ def main():
71
  input_ids = torch.tensor([input_ids], dtype=torch.long)
72
  input_ids = input_ids.to(device)
73
 
74
- output: str = ""
75
  streamer = TextIteratorStreamer(tokenizer=tokenizer)
76
 
77
  generation_kwargs = dict(
@@ -88,17 +88,24 @@ def main():
88
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
89
  thread.start()
90
 
 
 
91
  for output_ in streamer:
92
- output_ = output_.replace(" ", "")
93
- output_ = output_.replace("[CLS]", "")
 
 
 
94
  output_ = output_.replace("[SEP]", "\n")
95
  output_ = output_.replace("[UNK]", "")
96
- output_ = output_.replace(text, "")
97
 
98
  output += output_.strip()
99
  output_text_box.value += output
100
  yield output
101
 
 
 
102
  demo = gr.Interface(
103
  fn=fn_stream,
104
  inputs=[
@@ -107,7 +114,7 @@ def main():
107
  gr.Slider(minimum=0, maximum=1, value=0.85, step=0.01, label="top_p"),
108
  gr.Slider(minimum=0, maximum=1, value=0.35, step=0.01, label="temperature"),
109
  gr.Slider(minimum=0, maximum=2, value=1.2, step=0.01, label="repetition_penalty"),
110
- gr.Dropdown(choices=["qgyd2021/lib_service_4chan"], value="qgyd2021/lib_service_4chan", label="model_name"),
111
  gr.Checkbox(value=True, label="is_chat")
112
  ],
113
  outputs=[output_text_box],
 
3
  import argparse
4
  from collections import defaultdict
5
  import os
6
+ import platform
7
 
8
  import gradio as gr
9
  from threading import Thread
 
72
  input_ids = torch.tensor([input_ids], dtype=torch.long)
73
  input_ids = input_ids.to(device)
74
 
 
75
  streamer = TextIteratorStreamer(tokenizer=tokenizer)
76
 
77
  generation_kwargs = dict(
 
88
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
89
  thread.start()
90
 
91
+ output: str = ""
92
+ first_answer = True
93
  for output_ in streamer:
94
+ if first_answer:
95
+ first_answer = False
96
+ continue
97
+ # output_ = output_.replace(text, "")
98
+ # output_ = output_.replace("[CLS]", "")
99
  output_ = output_.replace("[SEP]", "\n")
100
  output_ = output_.replace("[UNK]", "")
101
+ output_ = output_.replace(" ", "")
102
 
103
  output += output_.strip()
104
  output_text_box.value += output
105
  yield output
106
 
107
+ model_name_choices = ["trained_models/lib_service_4chan"] \
108
+ if platform.system() == "Windows" else ["qgyd2021/lib_service_4chan"]
109
  demo = gr.Interface(
110
  fn=fn_stream,
111
  inputs=[
 
114
  gr.Slider(minimum=0, maximum=1, value=0.85, step=0.01, label="top_p"),
115
  gr.Slider(minimum=0, maximum=1, value=0.35, step=0.01, label="temperature"),
116
  gr.Slider(minimum=0, maximum=2, value=1.2, step=0.01, label="repetition_penalty"),
117
+ gr.Dropdown(choices=model_name_choices, value=model_name_choices[0], label="model_name"),
118
  gr.Checkbox(value=True, label="is_chat")
119
  ],
120
  outputs=[output_text_box],