BAAI
/

shunxing1234 commited on
Commit
a0ade15
1 Parent(s): 3c3936a

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +9 -14
README.md CHANGED
@@ -46,27 +46,26 @@ We will continue to release improved versions of Aquila model as open source. Fo
46
  ```python
47
  from transformers import AutoTokenizer, AutoModelForCausalLM
48
  import torch
49
- from cyg_conversation import covert_prompt_to_input_ids_with_history
50
 
51
- tokenizer = AutoTokenizer.from_pretrained("BAAI/AquilaChat-7B")
52
- model = AutoModelForCausalLM.from_pretrained("BAAI/AquilaChat-7B")
 
 
 
53
  model.eval()
54
- model.to("cuda:0")
55
- vocab = tokenizer.vocab
56
- print(len(vocab))
57
 
58
  text = "请给出10个要到北京旅游的理由。"
59
 
60
- tokens = covert_prompt_to_input_ids_with_history(text, history=[], tokenizer=tokenizer, max_token=512)
61
 
62
- tokens = torch.tensor(tokens)[None,].to("cuda:0")
63
 
64
 
65
  with torch.no_grad():
66
  out = model.generate(tokens, do_sample=True, max_length=512, eos_token_id=100007)[0]
67
 
68
  out = tokenizer.decode(out.cpu().numpy().tolist())
69
-
70
  if "###" in out:
71
  special_index = out.index("###")
72
  out = out[: special_index]
@@ -74,17 +73,13 @@ with torch.no_grad():
74
  if "[UNK]" in out:
75
  special_index = out.index("[UNK]")
76
  out = out[:special_index]
77
-
78
  if "</s>" in out:
79
  special_index = out.index("</s>")
80
  out = out[: special_index]
81
 
82
  if len(out) > 0 and out[0] == " ":
83
  out = out[1:]
84
-
85
- convert_tokens = convert_tokens[1:]
86
- probs = probs[1:]
87
-
88
  print(out)
89
  ```
90
 
 
46
  ```python
47
  from transformers import AutoTokenizer, AutoModelForCausalLM
48
  import torch
 
49
 
50
+ device = torch.device("cuda:1")
51
+
52
+ model_info = "BAAI/AquilaChat-7B"
53
+ tokenizer = AutoTokenizer.from_pretrained(model_info, trust_remote_code=True)
54
+ model = AutoModelForCausalLM.from_pretrained(model_info, trust_remote_code=True)
55
  model.eval()
56
+ model.to(device)
 
 
57
 
58
  text = "请给出10个要到北京旅游的理由。"
59
 
60
+ tokens = tokenizer.encode_plus(text)['input_ids'][:-1]
61
 
62
+ tokens = torch.tensor(tokens)[None,].to(device)
63
 
64
 
65
  with torch.no_grad():
66
  out = model.generate(tokens, do_sample=True, max_length=512, eos_token_id=100007)[0]
67
 
68
  out = tokenizer.decode(out.cpu().numpy().tolist())
 
69
  if "###" in out:
70
  special_index = out.index("###")
71
  out = out[: special_index]
 
73
  if "[UNK]" in out:
74
  special_index = out.index("[UNK]")
75
  out = out[:special_index]
76
+
77
  if "</s>" in out:
78
  special_index = out.index("</s>")
79
  out = out[: special_index]
80
 
81
  if len(out) > 0 and out[0] == " ":
82
  out = out[1:]
 
 
 
 
83
  print(out)
84
  ```
85