StevenChen16 committed on
Commit
8f48aeb
1 Parent(s): dd91d0c

Update app.py to use multiple threads

Browse files
Files changed (1)
  1. app.py +34 -23
app.py CHANGED
@@ -3,41 +3,53 @@ from llamafactory.chat import ChatModel
3
  from llamafactory.extras.misc import torch_gc
4
  import re
5
  import spaces
 
6
 
7
  def split_into_sentences(text):
8
  sentence_endings = re.compile(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|\!)\s')
9
  sentences = sentence_endings.split(text)
10
  return [sentence.strip() for sentence in sentences if sentence]
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  @spaces.GPU(duration=120)
13
  def process_paragraph(paragraph, progress=gr.Progress()):
14
  sentences = split_into_sentences(paragraph)
15
- results = []
16
  total_sentences = len(sentences)
 
 
17
  for i, sentence in enumerate(sentences):
18
- progress((i + 1) / total_sentences)
19
- messages.append({"role": "user", "content": sentence})
20
- sentence_response = ""
21
- for new_text in chat_model.stream_chat(messages, temperature=0.7, top_p=0.9, top_k=50, max_new_tokens=300):
22
- sentence_response += new_text.strip()
23
- category = sentence_response.strip().lower().replace(' ', '_')
24
- if category != "fair":
25
- results.append((sentence, category))
26
- else:
27
- results.append((sentence, "fair"))
28
- messages.append({"role": "assistant", "content": sentence_response})
29
- torch_gc()
30
- return results
31
 
 
32
 
33
  args = dict(
34
- model_name_or_path="princeton-nlp/Llama-3-Instruct-8B-SimPO", # 使用量化的 Llama-3-8B-Instruct 模型
35
- # model_name_or_path="StevenChen16/llama3-8b-compliance-review",
36
- # adapter_name_or_path="StevenChen16/llama3-8b-compliance-review-adapter", # 加载保存的 LoRA 适配器
37
- template="llama3", # 与训练时使用的模板相同
38
- finetuning_type="lora", # 与训练时使用的微调类型相同
39
- quantization_bit=8, # 加载 4-bit 量化模型
40
- use_unsloth=True, # 使用 UnslothAI 的 LoRA 优化以加速生成
41
  )
42
  chat_model = ChatModel(args)
43
  messages = []
@@ -56,7 +68,6 @@ label_to_color = {
56
  }
57
 
58
  with gr.Blocks() as demo:
59
-
60
  with gr.Row(equal_height=True):
61
  with gr.Column():
62
  input_text = gr.Textbox(label="Input Paragraph", lines=10, placeholder="Enter the paragraph here...")
@@ -71,4 +82,4 @@ with gr.Blocks() as demo:
71
 
72
  btn.click(on_click, inputs=input_text, outputs=[output])
73
 
74
- demo.launch(share=True)
 
3
  from llamafactory.extras.misc import torch_gc
4
  import re
5
  import spaces
6
+ from threading import Thread
7
 
8
  def split_into_sentences(text):
9
  sentence_endings = re.compile(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|\!)\s')
10
  sentences = sentence_endings.split(text)
11
  return [sentence.strip() for sentence in sentences if sentence]
12
 
13
+ @spaces.GPU(duration=120)
14
+ def process_sentence(sentence, index, results, messages, progress, total_sentences):
15
+ messages.append({"role": "user", "content": sentence})
16
+ sentence_response = ""
17
+ for new_text in chat_model.stream_chat(messages, temperature=0.7, top_p=0.9, top_k=50, max_new_tokens=300):
18
+ sentence_response += new_text.strip()
19
+ category = sentence_response.strip().lower().replace(' ', '_')
20
+ if category != "fair":
21
+ results[index] = (sentence, category)
22
+ else:
23
+ results[index] = (sentence, "fair")
24
+ messages.append({"role": "assistant", "content": sentence_response})
25
+ torch_gc()
26
+ progress((index + 1) / total_sentences)
27
+
28
  @spaces.GPU(duration=120)
29
  def process_paragraph(paragraph, progress=gr.Progress()):
30
  sentences = split_into_sentences(paragraph)
31
+ results = [None] * len(sentences)
32
  total_sentences = len(sentences)
33
+ threads = []
34
+
35
  for i, sentence in enumerate(sentences):
36
+ thread = Thread(target=process_sentence, args=(sentence, i, results, messages.copy(), progress, total_sentences))
37
+ threads.append(thread)
38
+ thread.start()
39
+
40
+ for thread in threads:
41
+ thread.join()
 
 
 
 
 
 
 
42
 
43
+ return results
44
 
45
  args = dict(
46
+ model_name_or_path="princeton-nlp/Llama-3-Instruct-8B-SimPO", # 使用量化的 Llama-3-8B-Instruct 模型
47
+ # model_name_or_path="StevenChen16/llama3-8b-compliance-review",
48
+ # adapter_name_or_path="StevenChen16/llama3-8b-compliance-review-adapter", # 加载保存的 LoRA 适配器
49
+ template="llama3", # 与训练时使用的模板相同
50
+ finetuning_type="lora", # 与训练时使用的微调类型相同
51
+ quantization_bit=8, # 加载 8-bit 量化模型
52
+ use_unsloth=True, # 使用 UnslothAI 的 LoRA 优化以加速生成
53
  )
54
  chat_model = ChatModel(args)
55
  messages = []
 
68
  }
69
 
70
  with gr.Blocks() as demo:
 
71
  with gr.Row(equal_height=True):
72
  with gr.Column():
73
  input_text = gr.Textbox(label="Input Paragraph", lines=10, placeholder="Enter the paragraph here...")
 
82
 
83
  btn.click(on_click, inputs=input_text, outputs=[output])
84
 
85
+ demo.launch(share=True)