DDDSSS committed on
Commit
3ce1634
1 Parent(s): 8a48a42

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +90 -1
README.md CHANGED
@@ -10,4 +10,93 @@ datasets:
10
  metrics:
11
  - bleu
12
  - sacrebleu
13
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  metrics:
11
  - bleu
12
  - sacrebleu
13
+ ---
14
+ 该模型主要的训练数据是opus100和CodeAlpaca_20K中的英文作为翻译内容,采用chatglm作为翻译器翻译成中文,并将脏数据筛选后得到DDDSSS/en-zh-dataset数据集
15
+
16
+
17
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
18
+ parser.add_argument('--device', default="cpu", type=str, help='"cuda:1"、"cuda:2"……')
19
+ mode_name = opt.model
20
+ device = opt.device
21
+ model = AutoModelForSeq2SeqLM.from_pretrained(mode_name)
22
+ tokenizer = AutoTokenizer.from_pretrained(mode_name)
23
+ translation = pipeline("translation_en_to_zh", model=model, tokenizer=tokenizer,
24
+ torch_dtype="float", device_map=True,device=device)
25
+ x=["If nothing is detected and there is a config.json file, it’s assumed the library is transformers.","By looking into the presence of files such as *.nemo or *saved_model.pb*, the Hub can determine if a model is from NeMo or Keras."]
26
+ re = translation(x, max_length=450)
27
+ print('翻译为:' ,re)
28
+
29
+
30
+ 微调:
31
# Fine-tuning setup: load the parallel corpus, split it, and build the tokenizer.
#
# BUG FIX: the original snippet used `load_dataset`, `AutoTokenizer`,
# `DataCollatorForSeq2Seq` and `evaluate` later in the script without
# importing them; all required imports are gathered here.
import numpy as np
import torch
import evaluate
from datasets import load_dataset
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

# books = load_from_disk("")
books = load_dataset("json", data_files=".json")  # TODO: replace ".json" placeholder with the real data file
books = books["train"].train_test_split(test_size=0.2)  # 80/20 train/eval split

checkpoint = "./opus-mt-en-zh"
# checkpoint = "./model/checkpoint-19304"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# Field names inside each {"translation": {...}} record of the dataset.
source_lang = "en"
target_lang = "zh"
42
def preprocess_function(examples):
    """Tokenize a batch of translation pairs for seq2seq training.

    Reads the module-level ``tokenizer``, ``source_lang`` and ``target_lang``.
    Each record in ``examples["translation"]`` is a dict mapping language code
    to text; returns the tokenizer output (inputs plus ``labels``), truncated
    to 512 tokens.
    """
    sources, targets = [], []
    for pair in examples["translation"]:
        sources.append(pair[source_lang])
        targets.append(pair[target_lang])
    return tokenizer(sources, text_target=targets, max_length=512, truncation=True)
47
# Prepare everything the trainer needs: the sacreBLEU metric, a collator that
# pads each batch dynamically, and the fully tokenized corpus.
metric = evaluate.load("sacrebleu")
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=checkpoint)
tokenized_books = books.map(preprocess_function, batched=True)
50
+
51
def postprocess_text(preds, labels):
    """Normalize decoded text for sacreBLEU.

    Strips surrounding whitespace from every prediction, and wraps each
    stripped reference in a single-element list (sacreBLEU expects a list of
    reference lists per example). Returns ``(preds, labels)``.
    """
    stripped_preds = list(map(str.strip, preds))
    wrapped_labels = [[ref.strip()] for ref in labels]
    return stripped_preds, wrapped_labels
55
+
56
def compute_metrics(eval_preds):
    """Decode an evaluation batch and compute corpus-level sacreBLEU.

    Uses the module-level ``tokenizer`` and ``metric``. Returns
    ``{"bleu": score, "gen_len": mean_generated_length}`` with every value
    rounded to 4 decimal places.
    """
    preds, labels = eval_preds
    # Some models return (logits, ...) tuples; keep only the token ids.
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

    # -100 marks ignored label positions; swap in pad tokens so they decode cleanly.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    bleu = metric.compute(predictions=decoded_preds, references=decoded_labels)
    result = {"bleu": bleu["score"]}

    # Average number of non-pad tokens per generated sequence.
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)

    return {k: round(v, 4) for k, v in result.items()}
74
# Load the base model and fine-tune it on the prepared corpus.
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

batchsize = 4
training_args = Seq2SeqTrainingArguments(
    output_dir="./my_awesome_opus_books_model",
    evaluation_strategy="epoch",        # evaluate at the end of every epoch
    learning_rate=2e-4,
    per_device_train_batch_size=batchsize,
    per_device_eval_batch_size=batchsize,
    weight_decay=0.01,
    # save_total_limit=3,
    num_train_epochs=4,
    predict_with_generate=True,         # generate text at eval time so BLEU can be computed
    fp16=True,                          # mixed precision — requires a CUDA device
    push_to_hub=False,
    save_strategy="epoch",
    jit_mode_eval=True,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_books["train"],
    eval_dataset=tokenized_books["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)
trainer.train()