qgyd2021 committed
Commit
a845f24
1 Parent(s): c092f7a

[update]add main

Files changed (4)
  1. .gitignore +6 -0
  2. main.py +138 -0
  3. project_settings.py +12 -0
  4. requirements.txt +3 -0
.gitignore ADDED
@@ -0,0 +1,6 @@
+
+ .git/
+ .idea/
+
+ **/flagged/
+ **/__pycache__/
main.py ADDED
@@ -0,0 +1,138 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ from typing import List, Tuple
+ from threading import Thread
+
+ import gradio as gr
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ import torch
+
+ from project_settings import project_path
+
+
+ def greet(question: str, history: List[Tuple[str, str]]):
+     answer = "Hello " + question + "!"
+     result = history + [(question, answer)]
+     return result
+
+
+ def chat_with_llm_non_stream(question: str,
+                              history: List[Tuple[str, str]],
+                              pretrained_model_name_or_path: str,
+                              max_new_tokens: int, top_p: float, temperature: float, repetition_penalty: float,
+                              ):
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     # note: the model and tokenizer are reloaded on every request; caching them would avoid the repeated load
+     model = AutoModelForCausalLM.from_pretrained(
+         pretrained_model_name_or_path,
+         trust_remote_code=True,
+         low_cpu_mem_usage=True,
+         torch_dtype=torch.bfloat16,
+         device_map="auto",
+         offload_folder="./offload",
+         offload_state_dict=True,
+         # load_in_4bit=True,
+     )
+     # device_map="auto" already dispatches the weights in bfloat16; calling .to(device) on a dispatched model raises
+     model = model.eval()
+
+     tokenizer = AutoTokenizer.from_pretrained(
+         pretrained_model_name_or_path,
+         trust_remote_code=True,
+         # llama does not support the fast tokenizer
+         use_fast=False if model.config.model_type == "llama" else True,
+         padding_side="left"
+     )
+
+     # QWenTokenizer is special: pad_token_id, bos_token_id and eos_token_id are all None; eod_id maps to <|endoftext|>
+     if tokenizer.__class__.__name__ == "QWenTokenizer":
+         tokenizer.pad_token_id = tokenizer.eod_id
+         tokenizer.bos_token_id = tokenizer.eod_id
+         tokenizer.eos_token_id = tokenizer.eod_id
+
+     input_ids = tokenizer(
+         question,
+         return_tensors="pt",
+         add_special_tokens=False,
+     ).input_ids.to(device)
+     bos_token_id = torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long).to(device)
+     eos_token_id = torch.tensor([[tokenizer.eos_token_id]], dtype=torch.long).to(device)
+     input_ids = torch.concat([bos_token_id, input_ids, eos_token_id], dim=1)  # add_special_tokens=False above, so attach BOS/EOS manually
+
+     with torch.no_grad():
+         outputs = model.generate(
+             input_ids=input_ids,
+             max_new_tokens=max_new_tokens,
+             do_sample=True,
+             top_p=top_p,
+             temperature=temperature,
+             repetition_penalty=repetition_penalty,
+             eos_token_id=tokenizer.eos_token_id
+         )
+     outputs = outputs.tolist()[0][len(input_ids[0]):]  # keep only the newly generated tokens
+     response = tokenizer.decode(outputs)
+     response = response.strip().replace(tokenizer.eos_token, "").strip()
+
+     return history + [(question, response)]
+
+
+ def main():
+     description = """
+     chat llm
+     """
+
+     with gr.Blocks() as blocks:
+         gr.Markdown(value=description)
+
+         chatbot = gr.Chatbot([], elem_id="chatbot", height=400)
+         with gr.Row():
+             with gr.Column(scale=4):
+                 text_box = gr.Textbox(show_label=False, placeholder="Enter text and press enter", container=False)
+             with gr.Column(scale=1):
+                 submit_button = gr.Button("💬Submit")
+             with gr.Column(scale=1):
+                 clear_button = gr.Button(
+                     '🗑️Clear',
+                     variant='secondary',
+                 )
+
+         with gr.Row():
+             with gr.Column(scale=1):
+                 max_new_tokens = gr.Slider(minimum=0, maximum=512, value=512, step=1, label="max_new_tokens")
+             with gr.Column(scale=1):
+                 top_p = gr.Slider(minimum=0, maximum=1, value=0.85, step=0.01, label="top_p")
+             with gr.Column(scale=1):
+                 temperature = gr.Slider(minimum=0, maximum=1, value=0.35, step=0.01, label="temperature")
+             with gr.Column(scale=1):
+                 repetition_penalty = gr.Slider(minimum=0, maximum=2, value=1.2, step=0.01, label="repetition_penalty")
+
+         with gr.Row():
+             model_name = gr.Dropdown(choices=["Qwen/Qwen-7B-Chat"],
+                                      value="Qwen/Qwen-7B-Chat",
+                                      label="model_name",
+                                      )
+         gr.Examples(examples=["你好"], inputs=text_box)
+         # component order must match the signature of chat_with_llm_non_stream
+         inputs = [
+             text_box, chatbot, model_name,
+             max_new_tokens, top_p, temperature, repetition_penalty
+         ]
+         outputs = [
+             chatbot
+         ]
+         text_box.submit(chat_with_llm_non_stream, inputs, outputs)
+         submit_button.click(chat_with_llm_non_stream, inputs, outputs)
+         clear_button.click(
+             fn=lambda: ('', []),
+             outputs=[text_box, chatbot],
+             queue=False,
+             api_name=False,
+         )
+
+     blocks.queue().launch()
+
+     return
+
+
+ if __name__ == '__main__':
+     main()
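
Note: main.py imports Thread but the handler above decodes the full completion in one pass. A minimal streaming counterpart is sketched below; chat_with_llm_stream and its argument list are hypothetical and not part of this commit, while TextIteratorStreamer is part of the pinned transformers 4.30.2.

    # Hypothetical streaming variant: run generate() in a background thread
    # and yield the partial chat history as tokens arrive.
    from threading import Thread
    from transformers import TextIteratorStreamer

    def chat_with_llm_stream(question, history, model, tokenizer, max_new_tokens):
        input_ids = tokenizer(question, return_tensors="pt").input_ids
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        thread = Thread(target=model.generate, kwargs=dict(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            streamer=streamer,
        ))
        thread.start()
        answer = ""
        for new_text in streamer:
            answer += new_text
            yield history + [(question, answer)]  # Gradio renders each partial history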
project_settings.py ADDED
@@ -0,0 +1,12 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ import os
+ from pathlib import Path
+
+
+ project_path = os.path.abspath(os.path.dirname(__file__))
+ project_path = Path(project_path)
+
+
+ if __name__ == '__main__':
+     pass
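
project_path is imported in main.py but not yet used there. A plausible use (an assumption, not in this commit) is anchoring relative paths at the repository root:

    from project_settings import project_path

    # e.g. keep the accelerate offload folder inside the project directory
    offload_folder = (project_path / "offload").as_posix()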
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ gradio==3.38.0
+ transformers==4.30.2
+ torch==1.13.0