wuxiaojun committed on
Commit
38ee4b3
1 Parent(s): ee9cb3f

init commit

Browse files
Files changed (1) hide show
  1. README.md +8 -7
README.md CHANGED
@@ -46,18 +46,18 @@ import torch
46
  from transformers import T5Tokenizer, T5Config, T5ForConditionalGeneration
47
 
48
  # load tokenizer and model
49
- pretrained_model = "IDEA-CCNL/Randeng-T5-784M-MultiTask-Chinese"
50
 
51
  special_tokens = ["<extra_id_{}>".format(i) for i in range(100)]
52
  tokenizer = T5Tokenizer.from_pretrained(
53
- args.pretrained_model,
54
  do_lower_case=True,
55
  max_length=512,
56
  truncation=True,
57
  additional_special_tokens=special_tokens,
58
  )
59
- config = T5Config.from_pretrained(args.pretrained_model)
60
- model = T5ForConditionalGeneration.from_pretrained(args.pretrained_model, config=config)
61
  model.resize_token_embeddings(len(tokenizer))
62
  model.eval()
63
 
@@ -66,8 +66,8 @@ text = "新闻分类任务:【微软披露拓扑量子计算机计划!】这
66
  encode_dict = tokenizer(text, max_length=512, padding='max_length',truncation=True)
67
 
68
  inputs = {
69
- "input_ids": torch.tensor(encode_dict['input_ids']).long(),
70
- "attention_mask": torch.tensor(encode_dict['attention_mask']).long(),
71
  }
72
 
73
  # generate answer
@@ -80,8 +80,9 @@ logits = model.generate(
80
 
81
  logits=logits[:,1:]
82
  predict_label = [tokenizer.decode(i,skip_special_tokens=True) for i in logits]
 
83
 
84
- # model Output: 科技
85
  ```
86
 
87
  ## 引用 Citation
 
46
  from transformers import T5Tokenizer, T5Config, T5ForConditionalGeneration
47
 
48
  # load tokenizer and model
49
+ pretrained_model = "/cognitive_comp/wuxiaojun/pretrained/pytorch/huggingface/Randeng-T5-784M-MultiTask-Chinese"
50
 
51
  special_tokens = ["<extra_id_{}>".format(i) for i in range(100)]
52
  tokenizer = T5Tokenizer.from_pretrained(
53
+ pretrained_model,
54
  do_lower_case=True,
55
  max_length=512,
56
  truncation=True,
57
  additional_special_tokens=special_tokens,
58
  )
59
+ config = T5Config.from_pretrained(pretrained_model)
60
+ model = T5ForConditionalGeneration.from_pretrained(pretrained_model, config=config)
61
  model.resize_token_embeddings(len(tokenizer))
62
  model.eval()
63
 
 
66
  encode_dict = tokenizer(text, max_length=512, padding='max_length',truncation=True)
67
 
68
  inputs = {
69
+ "input_ids": torch.tensor([encode_dict['input_ids']]).long(),
70
+ "attention_mask": torch.tensor([encode_dict['attention_mask']]).long(),
71
  }
72
 
73
  # generate answer
 
80
 
81
  logits=logits[:,1:]
82
  predict_label = [tokenizer.decode(i,skip_special_tokens=True) for i in logits]
83
+ print(predict_label)
84
 
85
+ # model output: 科技
86
  ```
87
 
88
  ## 引用 Citation