CMLL committed on
Commit
8560f66
1 Parent(s): ad01e97

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -2
app.py CHANGED
@@ -44,11 +44,40 @@ def generate(
44
  system_prompt: str = "You are a helpful TCM assistant named 仲景中医大语言模型, created by 医哲未来. You can switch between Chinese and English based on user preference.",
45
  max_new_tokens: int = 1024,
46
  temperature: float = 0.6,
47
- top_p: float = 0.95,
48
  top_k: int = 50,
49
  repetition_penalty: float = 1.2,
50
  ) -> Iterator[str]:
51
- # ... (保持 generate 函数不变)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
  chat_interface = gr.ChatInterface(
54
  fn=generate,
 
44
  system_prompt: str = "You are a helpful TCM assistant named 仲景中医大语言模型, created by 医哲未来. You can switch between Chinese and English based on user preference.",
45
  max_new_tokens: int = 1024,
46
  temperature: float = 0.6,
47
+ top_p: float = 0.9,
48
  top_k: int = 50,
49
  repetition_penalty: float = 1.2,
50
  ) -> Iterator[str]:
51
+ conversation = [{"role": "system", "content": system_prompt}]
52
+ for user, assistant in chat_history:
53
+ conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
54
+ conversation.append({"role": "user", "content": message})
55
+
56
+ input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
57
+ if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
58
+ input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
59
+ gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
60
+ input_ids = input_ids.to(model.device)
61
+
62
+ streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
63
+ generate_kwargs = dict(
64
+ {"input_ids": input_ids},
65
+ streamer=streamer,
66
+ max_new_tokens=max_new_tokens,
67
+ do_sample=True,
68
+ top_p=top_p,
69
+ top_k=top_k,
70
+ temperature=temperature,
71
+ num_beams=1,
72
+ repetition_penalty=repetition_penalty,
73
+ )
74
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
75
+ t.start()
76
+
77
+ outputs = []
78
+ for text in streamer:
79
+ outputs.append(text)
80
+ yield "".join(outputs)
81
 
82
  chat_interface = gr.ChatInterface(
83
  fn=generate,