vilarin committed on
Commit
9443a16
1 Parent(s): b01d655

Create app.py

Files changed (1)
  1. app.py +116 -0
app.py ADDED
@@ -0,0 +1,116 @@
+ import os
+ import spaces
+ import gradio as gr
+ from transformers import AutoTokenizer
+ from vllm import LLM, SamplingParams
+
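+ # The model repo id is read from the Space's MODEL_ID environment variable
+ # (e.g. "Qwen/Qwen2-7B-Instruct"; example value only).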
+ model = os.environ.get("MODEL_ID")
+ MODEL_NAME = model.split("/")[-1]
+
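+ # Assumed minimal page title; TITLE is referenced in the UI below but not defined elsewhere.
+ TITLE = f"<h1>{MODEL_NAME}</h1>"
+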
+ DESCRIPTION = f"""
+ <h3>MODEL: <a href="https://hf.co/{model}">{MODEL_NAME}</a></h3>
+ <center>
+ <p>Qwen is a large language model series built by Alibaba Cloud.
+ <br>
+ Feel free to try it without logging in.
+ </p>
+ </center>
+ """
+
+ css = """
+ h1 {
+     text-align: center;
+ }
+ footer {
+     visibility: hidden;
+ }
+ """
+
+
+ # Initialize the tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(model)
+
+ # Uses the default decoding hyperparameters of Qwen2-7B-Instruct;
+ # max_tokens sets the maximum generation length.
+
+ # Pass the model name or path; GPTQ and AWQ models also work.
+ llm = LLM(model=model)
+
+ @spaces.GPU
+ def generate(message, history, system, max_tokens, temperature, top_p, top_k, penalty):
+     # Prepare your prompts
+     conversation = [
+         {"role": "system", "content": system}
+     ]
+     for prompt, answer in history:
+         conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])
+     conversation.append({"role": "user", "content": message})
+
+     print(f"Conversation is -\n{conversation}")
+
+     text = tokenizer.apply_chat_template(
+         conversation,
+         tokenize=False,
+         add_generation_prompt=True
+     )
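+     # 151645 (<|im_end|>) and 151643 (<|endoftext|>) are Qwen2's stop token ids.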
+     sampling_params = SamplingParams(
+         temperature=temperature,
+         top_p=top_p,
+         top_k=top_k,
+         repetition_penalty=penalty,
+         max_tokens=max_tokens,
+         stop_token_ids=[151645, 151643],
+     )
+     # Generate outputs
+     outputs = llm.generate([text], sampling_params)
+
+     # Print the outputs.
+     for output in outputs:
+         prompt = output.prompt
+         generated_text = output.outputs[0].text
+         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+     return generated_text
+
+
+ with gr.Blocks(css=css, fill_height=True) as demo:
+     gr.HTML(TITLE)
+     gr.HTML(DESCRIPTION)
+     gr.ChatInterface(
+         fn=generate,
+         chatbot=gr.Chatbot(scale=1),
+         additional_inputs=[
+             gr.Textbox(value="You are a helpful assistant.", label="System message"),
+             gr.Slider(minimum=1, maximum=30720, value=2048, step=1, label="Max tokens"),
+             gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
+             gr.Slider(
+                 minimum=0.1,
+                 maximum=1.0,
+                 value=0.95,
+                 step=0.05,
+                 label="Top-p",
+             ),
+             gr.Slider(
+                 minimum=0,
+                 maximum=20,
+                 value=20,
+                 step=1,
+                 label="Top-k",
+             ),
+             gr.Slider(
+                 minimum=0.0,
+                 maximum=2.0,
+                 value=1.0,
+                 step=0.1,
+                 label="Repetition penalty",
+             ),
+         ],
+         retry_btn="Retry",
+         undo_btn="Undo",
+         clear_btn="Clear",
+         submit_btn="Send",
+     )
+
+ if __name__ == "__main__":
+     demo.launch()