"""Interactive REPL for smoke-testing a chat checkpoint with a custom EOS stopping rule."""
import argparse

import torch
from transformers import AutoModel, AutoTokenizer, StoppingCriteria


class EosListStoppingCriteria(StoppingCriteria):
    """Stop generation as soon as the most recent token is one of the EOS token ids."""

    def __init__(self, eos_sequence=None):
        # FIX: the original used a mutable default argument ([137625, 137632, 2]),
        # which is shared across every instance. Use None as the sentinel instead.
        if eos_sequence is None:
            eos_sequence = [137625, 137632, 2]
        self.eos_sequence = eos_sequence

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Only the last generated token of each sequence in the batch is inspected;
        # stop when any batch element just emitted one of the EOS ids.
        last_ids = input_ids[:, -1].tolist()
        return any(eos_id in last_ids for eos_id in self.eos_sequence)


def test_model(ckpt):
    """Load the checkpoint at *ckpt* and run an endless stdin/stdout chat loop.

    Each turn reads one line from stdin, wraps it in the chat template, generates
    up to 300 new tokens (greedy-ish, top_p=1.0), strips the prompt and the final
    stop token from the output, and prints the reply.
    """
    model = AutoModel.from_pretrained(ckpt, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(ckpt, trust_remote_code=True)
    init_prompt = "<|im_start|>user\n{input_message}<|end_of_user|>\n<|im_start|>"
    while True:
        # NOTE(review): history is reset every turn and the model reply is never
        # appended, so each turn is effectively single-shot — if multi-turn context
        # is intended, move this initialization above the loop and append the
        # decoded reply after generation. Preserved as-is pending confirmation.
        history = ""
        print(">>>让我们开始对话吧<<<")
        input_message = input()
        history += init_prompt.format(input_message=input_message)
        input_ids = tokenizer.encode(history, return_tensors="pt")
        # Inference only: disable autograd bookkeeping around generate().
        with torch.no_grad():
            output = model.generate(
                input_ids,
                top_p=1.0,
                max_new_tokens=300,
                stopping_criteria=[EosListStoppingCriteria()],
            ).squeeze()
        # Slice off the prompt tokens and the trailing stop token before decoding.
        output_str = tokenizer.decode(output[input_ids.shape[1]:-1])
        print(output_str)
        print(">>>>>>>><<<<<<<<<<")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt", type=str, help="path to the checkpoint", required=True)
    args = parser.parse_args()
    test_model(args.ckpt)