thorunna committed
Commit
5cdaab3
1 Parent(s): 62753ff

Script to run model updated

Files changed (1)
  1. run_model.py +29 -183
run_model.py CHANGED
@@ -1,188 +1,34 @@
- """This script runs the trained model on data and saves the predictions to a file."""
-
- from transformers import AutoModelForCausalLM, AutoTokenizer
- import torch
- import logging
- import random
- import tqdm
- import json
- import argparse
-
- # Set the logging level to info
- logging.basicConfig(level=logging.INFO)
-
- # Set the device to GPU if available
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- logging.info(f"Device: {device}")
-
- # Prompts for the different tasks
- START_PROMPT_TASK1 = "Hér er texti sem ég vil að þú skoðir vel og vandlega. Þú skalt skoða hvert einasta orð, orðasamband, og setningu og meta hvort þér finnist eitthvað athugavert, til dæmis hvað varðar málfræði, stafsetningu, skringilega merkingu og svo framvegis.\nHér er textinn:\n\n"
- END_PROMPT_TASK1 = "Sérðu eitthvað sem mætti betur fara í textanum? Búðu til lista af öllum slíkum tilvikum þar sem hver lína tilgreinir hver villan er, hvar hún er, og hvað væri gert í staðinn fyrir villuna.\n\n"
-
- START_PROMPT_TASK2 = "Hér er texti sem ég vil að þú skoðir vel og vandlega. Þú skalt skoða hvert einasta orð, orðasamband, og setningu og meta hvort þér finnist eitthvað athugavert, til dæmis hvað varðar málfræði, stafsetningu, skringilega merkingu og svo framvegis.Ég er með tvær útgáfur af textanum, A og B, og önnur þeirra gæti verið betri en hin á einhvern hátt, t.d. hvað varðar stafsetningu, málfræði o.s.frv.\nHér er texti A:\n\n"
- MIDDLE_PROMPT_TASK2 = "Hér er texti B:\n\n"
- END_PROMPT_TASK2 = "Hvorn textann líst þér betur á?\n\n"
-
- START_PROMPT_TASK3 = "Hér er texti sem ég vil að þú skoðir vel og vandlega. Þú skalt skoða hvert einasta orð, orðasamband, og setningu og meta hvort þér finnist eitthvað athugavert, til dæmis hvað varðar málfræði, stafsetningu, skringilega merkingu og svo framvegis.\nHér er textinn:\n\n"
- END_PROMPT_TASK3 = "Reyndu nú að laga textann þannig að hann líti betur út, eins og þér finnst best við hæfi.\n\n"
-
- START_PROMPT_TASK = {
-     1: START_PROMPT_TASK1,
-     2: START_PROMPT_TASK2,
-     3: START_PROMPT_TASK3,
+ """
+ Script for running the model using the Hugging Face endpoint. An authorized Hugging Face API key is required.
+ """
+
+ import requests
+ import os
+
+ API_URL = "https://otaf5w2ge8huxngl.eu-west-1.aws.endpoints.huggingface.cloud"
+ # Set your Hugging Face API key as an environment variable
+ api_key = os.environ.get("HF_API_KEY")
+ headers = {
+     "Accept": "application/json",
+     "Authorization": f"Bearer {api_key}",
+     "Content-Type": "application/json",
  }
- END_PROMPT_TASK = {1: END_PROMPT_TASK1, 2: END_PROMPT_TASK2, 3: END_PROMPT_TASK3}
-
- SEP = "\n\n"
-
-
- def set_seed(seed):
-     """Set the random seed for reproducibility."""
-     torch.manual_seed(seed)
-     if torch.cuda.is_available():
-         torch.cuda.manual_seed_all(seed)
-     torch.backends.cudnn.deterministic = True
-     torch.backends.cudnn.benchmark = False
-     random.seed(seed)
-
-
- def tokenize_data(tokenizer, data, task, max_length):
-     """Tokenize the data and return the input_ids and attention_mask."""
-     tokenized_start = tokenizer(START_PROMPT_TASK[task])["input_ids"]
-     tokenized_end = tokenizer(END_PROMPT_TASK[task])["input_ids"]
-     if task == 2:
-         tokenized_middle = tokenizer(MIDDLE_PROMPT_TASK2)["input_ids"]
-
-     # Tokenize the data
-     tokenized_data = []
-     if task == 1 or task == 3:
-         for sentence in data:
-             tokenized_sentence = tokenizer(sentence + SEP)["input_ids"]
-
-             # Concatenate the tokenized data
-             concatted_data = (
-                 [tokenizer.bos_token_id]
-                 + tokenized_start
-                 + tokenized_sentence
-                 + tokenized_end
-             )
-
-             # Truncate the data
-             concatted_data = concatted_data[:max_length]
-
-             tokenized_data.append(concatted_data)
-     elif task == 2:
-         for line in data:
-             data_a = line["a"]
-             data_b = line["b"]
-             tokenized_sentence_a = tokenizer(data_a + SEP)["input_ids"]
-             tokenized_sentence_b = tokenizer(data_b + SEP)["input_ids"]
-
-             # Concatenate the tokenized data
-             concatted_data = (
-                 [tokenizer.bos_token_id]
-                 + tokenized_start
-                 + tokenized_sentence_a
-                 + tokenized_middle
-                 + tokenized_sentence_b
-                 + tokenized_end
-             )
-
-             # Truncate the data
-             concatted_data = concatted_data[:max_length]
-
-             tokenized_data.append(concatted_data)
-
-     return tokenized_data
-
-
- def run_model_on_data(model_path, tokenizer_name, arguments):
-     """Run the model on the data and save the predictions to a file."""
-     # Load the model and tokenizer
-     model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16)
-     model.to(device)
-     model.eval()
-     tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
-
-     # Load the data
-     if arguments.task == 1 or arguments.task == 3:
-         with open(arguments.input_file, "r") as file:
-             data = file.read().splitlines()
-     elif arguments.task == 2:
-         with open(arguments.input_file, "r") as file:
-             data = file.read().splitlines()
-         data = [json.loads(line) for line in data]
-
-     # Tokenize the data
-     data_tokenized = tokenize_data(
-         tokenizer, data, arguments.task, tokenizer.model_max_length
-     )
-     logging.info(f"Number of examples: {len(data_tokenized)}")
-
-     # Run the model on the data
-     predictions = []
-     progress_bar = tqdm.tqdm(total=len(data_tokenized), desc="Running model on data")
-
-     for input_ids in data_tokenized:
-         progress_bar.update(1)
-
-         # Generate the predictions
-         with torch.cuda.amp.autocast(dtype=torch.bfloat16):
-             input_ids_tensor = torch.tensor(input_ids).unsqueeze(0).to(device)
-             output = model.generate(
-                 input_ids=input_ids_tensor, max_new_tokens=500, num_return_sequences=1
-             )
-
-         # Only get the part of the prediction that was generated
-         prediction = tokenizer.decode(
-             output[0][len(input_ids) :], skip_special_tokens=True
-         )
-         predictions.append(prediction)
-
-     progress_bar.close()
-
-     # Save the predictions to a file
-     with open(arguments.output_file, "w") as file:
-         if arguments.task == 1:
-             # We want to include the original text in the output file
-             counter = 0
-             for prediction in predictions:
-                 file.write(data[counter] + "\n")
-                 file.write(prediction.split("\n\n")[0] + "\n\n")
-                 counter += 1
-         else:
-             for prediction in predictions:
-                 file.write(prediction.split("\n\n")[0] + "\n")
-
-     logging.info(f"Predictions written to file: {arguments.output_file}")


- if __name__ == "__main__":
-     # Parse the arguments
-     parser = argparse.ArgumentParser()
-     parser.add_argument(
-         "--task",
-         type=int,
-         choices=(1,2,3),
-         required=True,
-         help="The task type (1, 2, or 3)",
-     )
-     parser.add_argument(
-         "--input-file",
-         type=str,
-         required=True,
-         help="The path to the input file with data to be corrected",
-     )
-     parser.add_argument(
-         "--output-file",
-         type=str,
-         required=True,
-         help="The path to the output file where the corrected data will be saved",
-     )
-     args = parser.parse_args()
+
+ def query(payload):
+     response = requests.post(API_URL, headers=headers, json=payload)
+     return response.json()

-     model_path = "."
-     tokenizer_name = "AI-Sweden-Models/gpt-sw3-6.7b"

-     set_seed(42)
-     run_model_on_data(model_path, tokenizer_name, args)
+ output = query(
+     {
+         "inputs": "", # Can be left empty.
+         "input_a": "<text A>", # Required for all tasks.
+         "input_b": "<text B>", # Required for task 2 but not for task 1 or 3.
+         "task": 1 | 2 | 3, # Choose the task number.
+         "parameters": {
+             # Can be left empty
+         },
+     }
+ )
+ print(output)
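
For reference, the task numbers correspond to the prompts in the replaced script: task 1 asks the model to list everything that could be improved in a text, task 2 asks which of two versions (A and B) of a text it prefers, and task 3 asks it to rewrite the text. Below is a minimal end-to-end sketch of calling the endpoint for task 2. The payload fields (inputs, input_a, input_b, task, parameters) come from the template in the commit; the concrete example texts and the error handling are illustrative assumptions, not part of the commit.

import os
import requests

API_URL = "https://otaf5w2ge8huxngl.eu-west-1.aws.endpoints.huggingface.cloud"
headers = {
    "Accept": "application/json",
    "Authorization": f"Bearer {os.environ['HF_API_KEY']}",  # set HF_API_KEY in your shell first
    "Content-Type": "application/json",
}

# Task 2 compares two versions of the same text; these sentences are made-up placeholders.
payload = {
    "inputs": "",
    "input_a": "Fyrri útgáfa textans.",
    "input_b": "Seinni útgáfa textans.",
    "task": 2,
    "parameters": {},
}

response = requests.post(API_URL, headers=headers, json=payload)
response.raise_for_status()  # fail loudly on HTTP errors instead of decoding an error body
print(response.json())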
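
The replaced script worked file-to-file: plain text with one example per line for tasks 1 and 3, and JSON lines with "a" and "b" fields for task 2. A sketch of reproducing that batch workflow against the endpoint, assuming one request per example; correct_file is a hypothetical helper, and responses are written out as raw JSON because the endpoint's response schema is not shown in this commit.

import json
import os
import requests

API_URL = "https://otaf5w2ge8huxngl.eu-west-1.aws.endpoints.huggingface.cloud"
headers = {
    "Accept": "application/json",
    "Authorization": f"Bearer {os.environ['HF_API_KEY']}",
    "Content-Type": "application/json",
}


def correct_file(input_file, output_file, task):
    """Send each line of input_file to the endpoint and write one JSON response per line."""
    with open(input_file, "r") as f:
        lines = f.read().splitlines()

    with open(output_file, "w") as out:
        for line in lines:
            if task == 2:
                # Same input format as the old script: {"a": ..., "b": ...} per line.
                pair = json.loads(line)
                payload = {
                    "inputs": "",
                    "input_a": pair["a"],
                    "input_b": pair["b"],
                    "task": 2,
                    "parameters": {},
                }
            else:
                # Tasks 1 and 3 take a single text; the template suggests "input_b" can be omitted.
                payload = {
                    "inputs": "",
                    "input_a": line,
                    "task": task,
                    "parameters": {},
                }
            response = requests.post(API_URL, headers=headers, json=payload)
            response.raise_for_status()
            out.write(json.dumps(response.json(), ensure_ascii=False) + "\n")


correct_file("input.txt", "predictions.txt", task=1)  # illustrative file names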