pvduy commited on
Commit
a7706d8
·
verified ·
1 Parent(s): 70d4f40

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -60
app.py CHANGED
@@ -1,83 +1,56 @@
1
- import argparse
2
- import os
3
  import spaces
4
 
5
- import gradio as gr
6
-
7
  import json
8
- from threading import Thread
9
- import torch
10
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
11
-
12
- MAX_LENGTH = 4096
13
- DEFAULT_MAX_NEW_TOKENS = 1024
14
-
15
 
16
- def parse_args():
17
- parser = argparse.ArgumentParser()
18
- parser.add_argument("--base_model", type=str) # model path
19
- parser.add_argument("--n_gpus", type=int, default=1) # n_gpu
20
- return parser.parse_args()
21
 
22
  @spaces.GPU()
23
  def predict(message, history, system_prompt, temperature, max_tokens):
24
- global model, tokenizer, device
25
- messages = [{'role': 'system', 'content': system_prompt}]
26
  for human, assistant in history:
27
- messages.append({'role': 'user', 'content': human})
28
- messages.append({'role': 'assistant', 'content': assistant})
29
- messages.append({'role': 'user', 'content': message})
30
- problem = [tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)]
31
- stop_tokens = ["<|endoftext|>", "<|im_end|>"]
32
- streamer = TextIteratorStreamer(tokenizer, timeout=100.0, skip_prompt=True, skip_special_tokens=True)
33
- enc = tokenizer(problem, return_tensors="pt", padding=True, truncation=True)
34
- input_ids = enc.input_ids
35
- attention_mask = enc.attention_mask
36
-
37
- if input_ids.shape[1] > MAX_LENGTH:
38
- input_ids = input_ids[:, -MAX_LENGTH:]
39
-
40
- input_ids = input_ids.to(device)
41
- attention_mask = attention_mask.to(device)
42
- generate_kwargs = dict(
43
- {"input_ids": input_ids, "attention_mask": attention_mask},
44
- streamer=streamer,
45
- do_sample=True,
46
- top_p=0.95,
47
- temperature=0.5,
48
- max_new_tokens=DEFAULT_MAX_NEW_TOKENS,
49
- use_cache=True,
50
- eos_token_id=100278 # <|im_end|>
51
- )
52
- t = Thread(target=model.generate, kwargs=generate_kwargs)
53
- t.start()
54
- outputs = []
55
- for text in streamer:
56
- outputs.append(text)
57
- yield "".join(outputs)
58
-
59
 
60
 
61
  if __name__ == "__main__":
62
- args = parse_args()
63
- tokenizer = AutoTokenizer.from_pretrained("stabilityai/stablelm-2-chat", trust_remote_code=True)
64
- model = AutoModelForCausalLM.from_pretrained("stabilityai/stablelm-2-chat", trust_remote_code=True, torch_dtype=torch.bfloat16)
65
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
66
- model = model.to(device)
67
  gr.ChatInterface(
68
  predict,
69
- title="StableLM 2 Chat - Demo",
70
- description="StableLM 2 Chat - StabilityAI",
71
  theme="soft",
72
- chatbot=gr.Chatbot(label="Chat History",),
73
  textbox=gr.Textbox(placeholder="input", container=False, scale=7),
74
  retry_btn=None,
75
  undo_btn="Delete Previous",
76
  clear_btn="Clear",
77
  additional_inputs=[
78
- gr.Textbox("You are a helpful assistant.", label="System Prompt"),
79
- gr.Slider(0, 1, 0.5, label="Temperature"),
80
  gr.Slider(100, 2048, 1024, label="Max Tokens"),
81
  ],
82
  additional_inputs_accordion_name="Parameters",
 
 
 
 
 
83
  ).queue().launch()
 
 
 
1
  import spaces
2
 
3
+ import os
 
4
  import json
5
+ from vllm import LLM, SamplingParams
6
+ from transformers import AutoTokenizer
 
 
 
 
 
7
 
 
 
 
 
 
8
 
9
  @spaces.GPU()
10
  def predict(message, history, system_prompt, temperature, max_tokens):
11
+ messages = [{"role": "system", "content": system_prompt}]
 
12
  for human, assistant in history:
13
+ messages.append({"role": "user", "content": human})
14
+ messages.append({"role": "assistant", "content": assistant})
15
+ messages.append({"role": "user", "content": message})
16
+ prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
17
+ stop_tokens = ["<|im_end|>", "<|endoftext|>", "<|im_start|>"]
18
+ sampling_params = SamplingParams(temperature=temperature, top_p=1, max_tokens=max_tokens, stop=stop_tokens)
19
+ completions = llm.generate(prompt, sampling_params)
20
+ for output in completions:
21
+ prompt = output.prompt
22
+ print('==========================question=============================')
23
+ print(prompt)
24
+ generated_text = output.outputs[0].text
25
+ print('===========================answer=============================')
26
+ print(generated_text)
27
+ for idx in range(len(generated_text)):
28
+ yield generated_text[:idx+1]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
 
31
  if __name__ == "__main__":
32
+ path = "stabilityai/stablelm-2-12b-chat"
33
+ tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
34
+ llm = LLM(model=path, tensor_parallel_size=1, trust_remote_code=True)
 
 
35
  gr.ChatInterface(
36
  predict,
37
+ title="LLM playground",
38
+ description="This is a LLM playground for StableLM",
39
  theme="soft",
40
+ chatbot=gr.Chatbot(height=1400, label="Chat History",),
41
  textbox=gr.Textbox(placeholder="input", container=False, scale=7),
42
  retry_btn=None,
43
  undo_btn="Delete Previous",
44
  clear_btn="Clear",
45
  additional_inputs=[
46
+ gr.Textbox("You are a hepful assistant.", label="System Prompt"),
47
+ gr.Slider(0, 1, 0.7, label="Temperature"),
48
  gr.Slider(100, 2048, 1024, label="Max Tokens"),
49
  ],
50
  additional_inputs_accordion_name="Parameters",
51
+ examples=[
52
+ ["implement snake game using pygame"],
53
+ ["Can you explain briefly to me what is the Python programming language?"],
54
+ ["write a program to find the factorial of a number"],
55
+ ],
56
  ).queue().launch()