---
license: cc-by-nc-sa-4.0
language:
- 'no'
datasets:
- NorGLM/NO-CNN-DailyMail
pipeline_tag: summarization
---

# Model Card

NorGPT-3B-continue-summarization-peft is trained on top of the [NorGPT-3B-continue](https://huggingface.co/NorGLM/NorGPT-3B-continue) model on the [NO-CNN-DailyMail](https://huggingface.co/datasets/NorGLM/NO-CNN-DailyMail) dataset.

Prompt format:
```
Summarise the article:\\n{article} |||\\n{positive_sample}
```

Inference prompt:
```
Summarise the article:\\n{article} |||\\n
```
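
For reference, a minimal sketch of how the two templates above are filled in (the article and summary strings are illustrative placeholders, not taken from the dataset):

```python
# Illustrative placeholders, not taken from the dataset.
article = "Kongen holdt tale til nasjonen i går kveld ..."
positive_sample = "Kongen talte til nasjonen."

# Training example: the article followed by its reference summary.
training_text = f"Summarise the article:\\n{article} |||\\n{positive_sample}"

# Inference prompt: the model generates the summary after the separator.
inference_prompt = f"Summarise the article:\\n{article} |||\\n"
```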

## Run the Model
```python
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

source_model_id = "NorGLM/NorGPT-3B-continue"
peft_model_id = "NorGLM/NorGPT-3B-continue-summarization-peft"

# Load the base model, then attach the PEFT adapter weights on top of it.
config = PeftConfig.from_pretrained(peft_model_id)
model = AutoModelForCausalLM.from_pretrained(source_model_id, device_map='balanced')

tokenizer_max_len = 2048
tokenizer_config = {'pretrained_model_name_or_path': source_model_id,
                    'model_max_length': tokenizer_max_len}
tokenizer = AutoTokenizer.from_pretrained(**tokenizer_config)
tokenizer.pad_token = tokenizer.eos_token

model = PeftModel.from_pretrained(model, peft_model_id)
```
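
Once loaded, the model can be prompted directly with the inference format shown above. A minimal usage sketch (the article text is an illustrative placeholder):

```python
# Illustrative usage only; the article string is a made-up placeholder.
article = "Kongen holdt tale til nasjonen i går kveld ..."
prompt = 'Summarise the article:\\n' + article + ' |||\\n'

model.eval()
inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
with torch.no_grad():
    output = model.generate(**inputs, do_sample=False, max_new_tokens=200)

# The generated summary follows the '|||' separator.
print(tokenizer.decode(output[0], skip_special_tokens=True).split("|||\\n")[-1])
```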

## Inference on test set
Use the loaded model to run inference on the test set of the NO-CNN-DailyMail dataset:
```python
from datasets import load_dataset
import pandas as pd
import torch

torch_device = "cuda" if torch.cuda.is_available() else "cpu"

def generate_texts(model, tokenizer, prompts, max_seq_length=200, do_sample=False, top_p=0.95, top_k=10):
    # prompts is a list of news articles
    results = []
    for prompt in prompts:
        # Skip articles that exceed the context budget.
        if len(prompt.split()) > 1024:
            results.append('')
            continue

        prompt = 'Summarise the article:\\n' + prompt + ' |||\\n'

        model_inputs = tokenizer(prompt, return_tensors='pt').to(torch_device)
        if do_sample:
            output = model.generate(**model_inputs, do_sample=True, top_p=top_p,
                                    top_k=top_k, max_new_tokens=max_seq_length)
        else:
            # Greedy decoding; top_p/top_k only apply when sampling.
            output = model.generate(**model_inputs, do_sample=False,
                                    max_new_tokens=max_seq_length)
        result = tokenizer.decode(output[0], skip_special_tokens=True)
        # Keep only the generated summary after the '|||' separator.
        results.append(result.split("|||\\n")[-1])
    return results

print("--LOADING EVAL DATA---")
# load_dataset with data_files places the data in the 'train' split.
eval_data = load_dataset("NorGLM/NO-CNN-DailyMail", data_files="test.csv")
prompts = eval_data['train']['article']
positive_samples = eval_data['train']['positive_sample']

print("--MAKING PREDICTIONS---")
model.eval()

output_file = "<output file name>"
with torch.no_grad():
    results = generate_texts(model, tokenizer, prompts)

df = pd.DataFrame({'article': prompts, 'generated_text': results, 'positive_sample': positive_samples})

print("Save results to csv file...")
df.to_csv(output_file)
```
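
The card does not state an official evaluation metric; purely as an illustration, ROUGE (a common summarization metric) can be computed over the saved predictions with the `evaluate` library:

```python
# Illustrative only: the model card does not specify the official metric.
import evaluate
import pandas as pd

df = pd.read_csv("<output file name>")     # the CSV written above
df = df.dropna(subset=['generated_text'])  # drop articles skipped for length

rouge = evaluate.load("rouge")
scores = rouge.compute(predictions=df['generated_text'].tolist(),
                       references=df['positive_sample'].tolist())
print(scores)
```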

## Note
More training details will be released soon!