Azure99 commited on
Commit
9f054c7
·
verified ·
1 Parent(s): 075fbd2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -21
app.py CHANGED
@@ -31,29 +31,27 @@ def get_input_ids(inst, history):
31
  return input_ids
32
 
33
 
34
- @spaces.GPU
35
- def chat(inst, history, temperature, top_p, repetition_penalty):
36
  with torch.no_grad():
37
- streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
38
- input_ids = get_input_ids(inst, history)
39
- if len(input_ids) > MAX_INPUT_LIMIT:
40
- yield "The input is too long, please clear the history."
41
- return
42
- generate_config = dict(
43
- max_new_tokens=MAX_NEW_TOKENS,
44
- temperature=temperature,
45
- top_p=top_p,
46
- repetition_penalty=repetition_penalty
47
- )
48
- print(generate_config)
49
- generation_kwargs = dict(input_ids=torch.tensor([input_ids]).to(model.device), do_sample=True,
50
- streamer=streamer, **generate_config)
51
  Thread(target=model.generate, kwargs=generation_kwargs).start()
52
 
53
- outputs = ""
54
- for new_text in streamer:
55
- outputs += new_text
56
- yield outputs
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
 
59
  additional_inputs = [
@@ -93,7 +91,8 @@ gr.ChatInterface(chat,
93
  description='Hello, I am Blossom, an open source conversational large language model.🌠'
94
  '<a href="https://github.com/Azure99/BlossomLM">GitHub</a>',
95
  theme="soft",
96
- examples=[["Hello"], ["What is MBTI"], ["用Python实现二分查找"], ["为switch写一篇小红书种草文案,带上emoji"]],
 
97
  additional_inputs=additional_inputs,
98
  additional_inputs_accordion=gr.Accordion(label="Config", open=True),
99
  clear_btn="🗑️Clear",
 
31
  return input_ids
32
 
33
 
34
+ def generate(generation_kwargs):
 
35
  with torch.no_grad():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  Thread(target=model.generate, kwargs=generation_kwargs).start()
37
 
38
+
39
+ @spaces.GPU
40
+ def chat(inst, history, temperature, top_p, repetition_penalty):
41
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
42
+ input_ids = get_input_ids(inst, history)
43
+ if len(input_ids) > MAX_INPUT_LIMIT:
44
+ yield "The input is too long, please clear the history."
45
+ return
46
+ generation_kwargs = dict(input_ids=torch.tensor([input_ids]).to(model.device),
47
+ streamer=streamer, do_sample=True, max_new_tokens=MAX_NEW_TOKENS,
48
+ temperature=temperature, top_p=top_p, repetition_penalty=repetition_penalty)
49
+ generate(generation_kwargs)
50
+
51
+ outputs = ""
52
+ for new_text in streamer:
53
+ outputs += new_text
54
+ yield outputs
55
 
56
 
57
  additional_inputs = [
 
91
  description='Hello, I am Blossom, an open source conversational large language model.🌠'
92
  '<a href="https://github.com/Azure99/BlossomLM">GitHub</a>',
93
  theme="soft",
94
+ examples=[["Hello"], ["What is MBTI"], ["用Python实现二分查找"],
95
+ ["为switch写一篇小红书种草文案,带上emoji"]],
96
  additional_inputs=additional_inputs,
97
  additional_inputs_accordion=gr.Accordion(label="Config", open=True),
98
  clear_btn="🗑️Clear",