replace-wrong-weights
#3
by
gardari
- opened
- README.md +20 -7
- config.json +1 -1
- handler.py +0 -120
- run_model.py +175 -29
- spiece.model +0 -3
- tokenizer_config.json +0 -5
README.md
CHANGED
@@ -10,21 +10,34 @@ ICELANDIC GPT-SW3 FOR SPELL AND GRAMMAR CHECKING
|
|
10 |
|
11 |
This is a model for correcting spelling and grammar errors in Icelandic text. It is a GPT-SW3 model (https://huggingface.co/AI-Sweden-Models/gpt-sw3-6.7b) finetuned on Icelandic and particularly on the spell and grammar checking task.
|
12 |
|
13 |
-
Provided here is the model along with a script for running it
|
14 |
|
15 |
-
To run the
|
16 |
|
17 |
> pip install -r requirements.txt
|
18 |
|
19 |
-
The current version of transformers includes a bug
|
20 |
|
21 |
-
|
|
|
|
|
22 |
- Task 1: The model evaluates one text with regards to e.g. grammar and spelling, and returns all errors in the input text as a list, with their position in the text and their corrections.
|
23 |
- Task 2: The model evaluates two texts and chooses which one is better with regards to e.g. grammar and spelling.
|
24 |
- Task 3: The model evaluates one text with regards to e.g. grammar and spelling, and returns a corrected version of the text.
|
25 |
|
26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
|
28 |
-
|
29 |
|
30 |
-
Input text(s) and the task type need to be specified in the script.
|
|
|
10 |
|
11 |
This is a model for correcting spelling and grammar errors in Icelandic text. It is a GPT-SW3 model (https://huggingface.co/AI-Sweden-Models/gpt-sw3-6.7b) finetuned on Icelandic and particularly on the spell and grammar checking task.
|
12 |
|
13 |
+
Provided here is the model along with a script for running it locally.
|
14 |
|
15 |
+
To run the script you will need a python3 environment. Install the required dependencies by running
|
16 |
|
17 |
> pip install -r requirements.txt
|
18 |
|
19 |
+
The current version of transformers includes a bug which has to be fixed in the user's environment before the model can be run. To fix it, change "gpt-sw3-7b" in line no. 138 in transformers/models/gpt_sw3/tokenization_gpt_sw3.py to "gpt-sw3-6.7b".
|
20 |
|
21 |
+
After that you can run the script with an input file consisting of text to correct.
|
22 |
+
|
23 |
+
The model is fine-tuned on the following three tasks:
|
24 |
- Task 1: The model evaluates one text with regards to e.g. grammar and spelling, and returns all errors in the input text as a list, with their position in the text and their corrections.
|
25 |
- Task 2: The model evaluates two texts and chooses which one is better with regards to e.g. grammar and spelling.
|
26 |
- Task 3: The model evaluates one text with regards to e.g. grammar and spelling, and returns a corrected version of the text.
|
27 |
|
28 |
+
The script which runs the model takes the following three arguments:
|
29 |
+
- --task: A number (1-3) representing the intended task. The script includes prompts for each task.
|
30 |
+
- --input-file: A file containing text to be evaluated. The format of the input file differs between tasks, and is described further below.
|
31 |
+
- --output-file: A path to a desired output file to be created by the script. The format of the file differs between tasks, and is described further below.
|
32 |
+
|
33 |
+
An input file for tasks 1 and 3 should be a .txt file consisting of texts per line. An example of both files can be found under ./example_inputs.
|
34 |
+
An input file for task 2 should be a .jsonl file, where each line is a dictionary object showing two texts. Keys in the dictionary are "a" and "b" and texts to be evaluated are their values. An example of this file can be found under ./example_inputs.
|
35 |
+
|
36 |
+
All output files are .txt files and output examples for each task are shown in ./example_outputs. An output file for task 1 shows each text which was evaluated, followed by a list of corrections. Text outputs are separated by an empty line. An output file for task 2 shows 'A' or 'B' for which text is preferred, one choice per line. An output file for task 3 shows the corrected text, one text per line.
|
37 |
+
|
38 |
+
Run the script with
|
39 |
+
|
40 |
+
> python run_model.py --task 3 --input-file example_inputs/task3_example.txt --output-file example_outputs/task3_example.txt
|
41 |
|
42 |
+
The script we provide runs in CPU-only mode and should work on most systems that have enough RAM to load the model. Users that wish to accelerate their corrections with specialized hardware (eg GPUs) will need to install appropriate support packages for their hardware. We refer to the PyTorch documentation: https://pytorch.org/get-started/locally/ . After the extra packages are installed, add the `device` parameter to the pipeline constructor. See the HuggingFace documentation (https://huggingface.co/docs/transformers/main_classes/pipelines) for more details.
|
43 |
|
|
config.json
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
{
|
2 |
-
"_name_or_path": "mideind/icelandic-gpt-sw3
|
3 |
"activation_function": "gelu",
|
4 |
"apply_query_key_layer_scaling": true,
|
5 |
"architectures": [
|
|
|
1 |
{
|
2 |
+
"_name_or_path": "mideind/icelandic-gpt-sw3",
|
3 |
"activation_function": "gelu",
|
4 |
"apply_query_key_layer_scaling": true,
|
5 |
"architectures": [
|
handler.py
DELETED
@@ -1,120 +0,0 @@
|
|
1 |
-
from typing import Dict, List, Any
|
2 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
3 |
-
import torch
|
4 |
-
import logging
|
5 |
-
|
6 |
-
logging.basicConfig(level=logging.INFO)
|
7 |
-
LOGGER = logging.getLogger(__name__)
|
8 |
-
|
9 |
-
|
10 |
-
# Prompts for the different tasks
|
11 |
-
START_PROMPT_TASK1 = "Hér er texti sem ég vil að þú skoðir vel og vandlega. Þú skalt skoða hvert einasta orð, orðasamband, og setningu og meta hvort þér finnist eitthvað athugavert, til dæmis hvað varðar málfræði, stafsetningu, skringilega merkingu og svo framvegis.\nHér er textinn:\n\n"
|
12 |
-
END_PROMPT_TASK1 = "Sérðu eitthvað sem mætti betur fara í textanum? Búðu til lista af öllum slíkum tilvikum þar sem hver lína tilgreinir hver villan er, hvar hún er, og hvað væri gert í staðinn fyrir villuna.\n\n"
|
13 |
-
|
14 |
-
START_PROMPT_TASK2 = "Hér er texti sem ég vil að þú skoðir vel og vandlega. Þú skalt skoða hvert einasta orð, orðasamband, og setningu og meta hvort þér finnist eitthvað athugavert, til dæmis hvað varðar málfræði, stafsetningu, skringilega merkingu og svo framvegis.Ég er með tvær útgáfur af textanum, A og B, og önnur þeirra gæti verið betri en hin á einhvern hátt, t.d. hvað varðar stafsetningu, málfræði o.s.frv.\nHér er texti A:\n\n"
|
15 |
-
MIDDLE_PROMPT_TASK2 = "Hér er texti B:\n\n"
|
16 |
-
END_PROMPT_TASK2 = "Hvorn textann líst þér betur á?\n\n"
|
17 |
-
|
18 |
-
START_PROMPT_TASK3 = "Hér er texti sem ég vil að þú skoðir vel og vandlega. Þú skalt skoða hvert einasta orð, orðasamband, og setningu og meta hvort þér finnist eitthvað athugavert, til dæmis hvað varðar málfræði, stafsetningu, skringilega merkingu og svo framvegis.\nHér er textinn:\n\n"
|
19 |
-
END_PROMPT_TASK3 = "Reyndu nú að laga textann þannig að hann líti betur út, eins og þér finnst best við hæfi.\n\n"
|
20 |
-
|
21 |
-
START_PROMPT_TASK = {
|
22 |
-
1: START_PROMPT_TASK1,
|
23 |
-
2: START_PROMPT_TASK2,
|
24 |
-
3: START_PROMPT_TASK3,
|
25 |
-
}
|
26 |
-
END_PROMPT_TASK = {1: END_PROMPT_TASK1, 2: END_PROMPT_TASK2, 3: END_PROMPT_TASK3}
|
27 |
-
|
28 |
-
SEP = "\n\n"
|
29 |
-
|
30 |
-
|
31 |
-
class EndpointHandler:
|
32 |
-
def __init__(self, path=""):
|
33 |
-
self.model = AutoModelForCausalLM.from_pretrained(
|
34 |
-
path, device_map="auto", torch_dtype=torch.bfloat16
|
35 |
-
)
|
36 |
-
self.tokenizer = AutoTokenizer.from_pretrained(path)
|
37 |
-
LOGGER.info(f"Inference model loaded from {path}")
|
38 |
-
LOGGER.info(f"Model device: {self.model.device}")
|
39 |
-
|
40 |
-
def check_valid_inputs(self, input_a: str, input_b: str, task: int) -> bool:
|
41 |
-
"""
|
42 |
-
Check if the inputs are valid
|
43 |
-
"""
|
44 |
-
if task not in [1, 2, 3]:
|
45 |
-
return False
|
46 |
-
if task == 1 or task == 3:
|
47 |
-
if input_a is None:
|
48 |
-
return False
|
49 |
-
elif task == 2:
|
50 |
-
if input_a is None or input_b is None:
|
51 |
-
return False
|
52 |
-
return True
|
53 |
-
|
54 |
-
def tokenize_input(self, input_a: str, input_b: str, task: int) -> List[int]:
|
55 |
-
"""
|
56 |
-
Tokenize the input
|
57 |
-
"""
|
58 |
-
if task == 1 or task == 3:
|
59 |
-
tokenized_start = self.tokenizer(START_PROMPT_TASK[task])["input_ids"]
|
60 |
-
tokenized_end = self.tokenizer(END_PROMPT_TASK[task])["input_ids"]
|
61 |
-
tokenized_sentence = self.tokenizer(input_a + SEP)["input_ids"]
|
62 |
-
concatted_data = (
|
63 |
-
[self.tokenizer.bos_token_id]
|
64 |
-
+ tokenized_start
|
65 |
-
+ tokenized_sentence
|
66 |
-
+ tokenized_end
|
67 |
-
)
|
68 |
-
elif task == 2:
|
69 |
-
tokenized_start = self.tokenizer(START_PROMPT_TASK[task])["input_ids"]
|
70 |
-
tokenized_middle = self.tokenizer(MIDDLE_PROMPT_TASK2)["input_ids"]
|
71 |
-
tokenized_end = self.tokenizer(END_PROMPT_TASK[task])["input_ids"]
|
72 |
-
tokenized_sentence_a = self.tokenizer(input_a + SEP)["input_ids"]
|
73 |
-
tokenized_sentence_b = self.tokenizer(input_b + SEP)["input_ids"]
|
74 |
-
concatted_data = (
|
75 |
-
[self.tokenizer.bos_token_id]
|
76 |
-
+ tokenized_start
|
77 |
-
+ tokenized_sentence_a
|
78 |
-
+ tokenized_middle
|
79 |
-
+ tokenized_sentence_b
|
80 |
-
+ tokenized_end
|
81 |
-
)
|
82 |
-
return concatted_data
|
83 |
-
|
84 |
-
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
85 |
-
"""
|
86 |
-
data args:
|
87 |
-
inputs (:obj: `str` | `PIL.Image` | `np.array`)
|
88 |
-
kwargs
|
89 |
-
Return:
|
90 |
-
A :obj:`list` | `dict`: will be serialized and returned
|
91 |
-
"""
|
92 |
-
LOGGER.info(f"Received data: {data}")
|
93 |
-
|
94 |
-
# Get inputs
|
95 |
-
input_a = data.pop("input_a", None)
|
96 |
-
input_b = data.pop("input_b", None)
|
97 |
-
task = data.pop("task", None)
|
98 |
-
parameters = data.pop("parameters", {})
|
99 |
-
|
100 |
-
# Check valid inputs
|
101 |
-
if not self.check_valid_inputs(input_a, input_b, task):
|
102 |
-
return [{"error": "Invalid inputs"}]
|
103 |
-
if "max_new_tokens" not in parameters and "max_length" not in parameters:
|
104 |
-
parameters["max_new_tokens"] = 512
|
105 |
-
|
106 |
-
# Tokenize the input
|
107 |
-
tokenized_input = self.tokenize_input(input_a, input_b, task)
|
108 |
-
|
109 |
-
# Move the input to the device
|
110 |
-
input_ids = torch.tensor(tokenized_input).to(self.model.device)
|
111 |
-
input_ids = input_ids.unsqueeze(0)
|
112 |
-
|
113 |
-
# Generate the output
|
114 |
-
output = self.model.generate(input_ids, **parameters)
|
115 |
-
|
116 |
-
# Decode only the new part of the output
|
117 |
-
decoded_output = self.tokenizer.decode(
|
118 |
-
output[0][len(tokenized_input) :], skip_special_tokens=True
|
119 |
-
).strip()
|
120 |
-
return [{"output": decoded_output}]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
run_model.py
CHANGED
@@ -1,34 +1,180 @@
|
|
1 |
-
"""
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
import
|
6 |
-
import
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
|
18 |
-
|
19 |
-
|
20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
|
|
|
|
22 |
|
23 |
-
|
24 |
-
|
25 |
-
"inputs": "", # Can be left empty.
|
26 |
-
"input_a": "<text A>", # Required for all tasks.
|
27 |
-
"input_b": "<text B>", # Required for task 2 but not for task 1 or 3.
|
28 |
-
"task": 1 | 2 | 3, # Choose the task number.
|
29 |
-
"parameters": {
|
30 |
-
# Can be left empty
|
31 |
-
},
|
32 |
-
}
|
33 |
-
)
|
34 |
-
print(output)
|
|
|
1 |
+
"""This script runs the trained model on data and saves the predictions to a file."""
|
2 |
+
|
3 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
4 |
+
import torch
|
5 |
+
import logging
|
6 |
+
import random
|
7 |
+
import tqdm
|
8 |
+
import json
|
9 |
+
import argparse
|
10 |
+
|
11 |
+
# Set the logging level to info
|
12 |
+
logging.basicConfig(level=logging.INFO)
|
13 |
+
|
14 |
+
# Set the device to GPU if available
|
15 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
16 |
+
logging.info(f"Device: {device}")
|
17 |
+
|
18 |
+
# Prompts for the different tasks
|
19 |
+
START_PROMPT_TASK1 = "Hér er texti sem ég vil að þú skoðir vel og vandlega. Þú skalt skoða hvert einasta orð, orðasamband, og setningu og meta hvort þér finnist eitthvað athugavert, til dæmis hvað varðar málfræði, stafsetningu, skringilega merkingu og svo framvegis.\nHér er textinn:\n\n"
|
20 |
+
END_PROMPT_TASK1 = "Sérðu eitthvað sem mætti betur fara í textanum? Búðu til lista af öllum slíkum tilvikum þar sem hver lína tilgreinir hver villan er, hvar hún er, og hvað væri gert í staðinn fyrir villuna.\n\n"
|
21 |
+
|
22 |
+
START_PROMPT_TASK2 = "Hér er texti sem ég vil að þú skoðir vel og vandlega. Þú skalt skoða hvert einasta orð, orðasamband, og setningu og meta hvort þér finnist eitthvað athugavert, til dæmis hvað varðar málfræði, stafsetningu, skringilega merkingu og svo framvegis.Ég er með tvær útgáfur af textanum, A og B, og önnur þeirra gæti verið betri en hin á einhvern hátt, t.d. hvað varðar stafsetningu, málfræði o.s.frv.\nHér er texti A:\n\n"
|
23 |
+
MIDDLE_PROMPT_TASK2 = "Hér er texti B:\n\n"
|
24 |
+
END_PROMPT_TASK2 = "Hvorn textann líst þér betur á?\n\n"
|
25 |
+
|
26 |
+
START_PROMPT_TASK3 = "Hér er texti sem ég vil að þú skoðir vel og vandlega. Þú skalt skoða hvert einasta orð, orðasamband, og setningu og meta hvort þér finnist eitthvað athugavert, til dæmis hvað varðar málfræði, stafsetningu, skringilega merkingu og svo framvegis.\nHér er textinn:\n\n"
|
27 |
+
END_PROMPT_TASK3 = "Reyndu nú að laga textann þannig að hann líti betur út, eins og þér finnst best við hæfi.\n\n"
|
28 |
+
|
29 |
+
START_PROMPT_TASK = {
|
30 |
+
1: START_PROMPT_TASK1,
|
31 |
+
2: START_PROMPT_TASK2,
|
32 |
+
3: START_PROMPT_TASK3,
|
33 |
}
|
34 |
+
END_PROMPT_TASK = {1: END_PROMPT_TASK1, 2: END_PROMPT_TASK2, 3: END_PROMPT_TASK3}
|
35 |
+
|
36 |
+
SEP = "\n\n"
|
37 |
+
|
38 |
+
|
39 |
+
def set_seed(seed):
|
40 |
+
"""Set the random seed for reproducibility."""
|
41 |
+
torch.manual_seed(seed)
|
42 |
+
if torch.cuda.is_available():
|
43 |
+
torch.cuda.manual_seed_all(seed)
|
44 |
+
torch.backends.cudnn.deterministic = True
|
45 |
+
torch.backends.cudnn.benchmark = False
|
46 |
+
random.seed(seed)
|
47 |
+
|
48 |
+
|
49 |
+
def tokenize_data(tokenizer, data, task, max_length):
|
50 |
+
"""Tokenize the data and return the input_ids and attention_mask."""
|
51 |
+
tokenized_start = tokenizer(START_PROMPT_TASK[task])["input_ids"]
|
52 |
+
tokenized_end = tokenizer(END_PROMPT_TASK[task])["input_ids"]
|
53 |
+
if task == 2:
|
54 |
+
tokenized_middle = tokenizer(MIDDLE_PROMPT_TASK2)["input_ids"]
|
55 |
+
|
56 |
+
# Tokenize the data
|
57 |
+
tokenized_data = []
|
58 |
+
if task == 1 or task == 3:
|
59 |
+
for sentence in data:
|
60 |
+
tokenized_sentence = tokenizer(sentence + SEP)["input_ids"]
|
61 |
+
|
62 |
+
# Concatenate the tokenized data
|
63 |
+
concatted_data = (
|
64 |
+
[tokenizer.bos_token_id]
|
65 |
+
+ tokenized_start
|
66 |
+
+ tokenized_sentence
|
67 |
+
+ tokenized_end
|
68 |
+
)
|
69 |
+
|
70 |
+
# Truncate the data
|
71 |
+
concatted_data = concatted_data[:max_length]
|
72 |
+
|
73 |
+
tokenized_data.append(concatted_data)
|
74 |
+
elif task == 2:
|
75 |
+
for line in data:
|
76 |
+
data_a = line["a"]
|
77 |
+
data_b = line["b"]
|
78 |
+
tokenized_sentence_a = tokenizer(data_a + SEP)["input_ids"]
|
79 |
+
tokenized_sentence_b = tokenizer(data_b + SEP)["input_ids"]
|
80 |
+
|
81 |
+
# Concatenate the tokenized data
|
82 |
+
concatted_data = (
|
83 |
+
[tokenizer.bos_token_id]
|
84 |
+
+ tokenized_start
|
85 |
+
+ tokenized_sentence_a
|
86 |
+
+ tokenized_middle
|
87 |
+
+ tokenized_sentence_b
|
88 |
+
+ tokenized_end
|
89 |
+
)
|
90 |
+
|
91 |
+
# Truncate the data
|
92 |
+
concatted_data = concatted_data[:max_length]
|
93 |
+
|
94 |
+
tokenized_data.append(concatted_data)
|
95 |
+
|
96 |
+
return tokenized_data
|
97 |
+
|
98 |
+
|
99 |
+
def run_model_on_data(model_path, tokenizer_name, arguments):
|
100 |
+
"""Run the model on the data and save the predictions to a file."""
|
101 |
+
# Load the model and tokenizer
|
102 |
+
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16)
|
103 |
+
model.to(device)
|
104 |
+
model.eval()
|
105 |
+
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
106 |
+
|
107 |
+
# Load the data
|
108 |
+
if arguments.task == 1 or arguments.task == 3:
|
109 |
+
with open(arguments.input_file, "r") as file:
|
110 |
+
data = file.read().splitlines()
|
111 |
+
elif arguments.task == 2:
|
112 |
+
with open(arguments.input_file, "r") as file:
|
113 |
+
data = file.read().splitlines()
|
114 |
+
data = [json.loads(line) for line in data]
|
115 |
+
|
116 |
+
# Tokenize the data
|
117 |
+
data_tokenized = tokenize_data(
|
118 |
+
tokenizer, data, arguments.task, tokenizer.model_max_length
|
119 |
+
)
|
120 |
+
logging.info(f"Number of examples: {len(data_tokenized)}")
|
121 |
+
|
122 |
+
# Run the model on the data
|
123 |
+
predictions = []
|
124 |
+
progress_bar = tqdm.tqdm(total=len(data_tokenized), desc="Running model on data")
|
125 |
+
|
126 |
+
for input_ids in data_tokenized:
|
127 |
+
progress_bar.update(1)
|
128 |
+
|
129 |
+
# Generate the predictions
|
130 |
+
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
131 |
+
input_ids_tensor = torch.tensor(input_ids).unsqueeze(0).to(device)
|
132 |
+
output = model.generate(
|
133 |
+
input_ids=input_ids_tensor, max_new_tokens=500, num_return_sequences=1
|
134 |
+
)
|
135 |
+
|
136 |
+
# Only get the part of the prediction that was generated
|
137 |
+
prediction = tokenizer.decode(
|
138 |
+
output[0][len(input_ids) :], skip_special_tokens=True
|
139 |
+
)
|
140 |
+
predictions.append(prediction)
|
141 |
+
|
142 |
+
progress_bar.close()
|
143 |
+
|
144 |
+
# Save the predictions to a file
|
145 |
+
with open(arguments.output_file, "w") as file:
|
146 |
+
if arguments.task == 1:
|
147 |
+
# We want to include the original text in the output file
|
148 |
+
counter = 0
|
149 |
+
for prediction in predictions:
|
150 |
+
file.write(data[counter] + "\n")
|
151 |
+
file.write(prediction.split("\n\n")[0] + "\n\n")
|
152 |
+
counter += 1
|
153 |
+
else:
|
154 |
+
for prediction in predictions:
|
155 |
+
file.write(prediction.split("\n\n")[0] + "\n")
|
156 |
+
|
157 |
+
logging.info(f"Predictions written to file: {arguments.output_file}")
|
158 |
|
159 |
|
160 |
+
if __name__ == "__main__":
|
161 |
+
# Parse the arguments
|
162 |
+
parser = argparse.ArgumentParser()
|
163 |
+
parser.add_argument("--task", type=int, help="The task type (1, 2, or 3)")
|
164 |
+
parser.add_argument(
|
165 |
+
"--input-file",
|
166 |
+
type=str,
|
167 |
+
help="The path to the input file with data to be corrected",
|
168 |
+
)
|
169 |
+
parser.add_argument(
|
170 |
+
"--output-file",
|
171 |
+
type=str,
|
172 |
+
help="The path to the output file where the corrected data will be saved",
|
173 |
+
)
|
174 |
+
args = parser.parse_args()
|
175 |
|
176 |
+
model_path = "./gpt-sw3-model"
|
177 |
+
tokenizer_name = "AI-Sweden-Models/gpt-sw3-6.7b"
|
178 |
|
179 |
+
set_seed(42)
|
180 |
+
run_model_on_data(model_path, tokenizer_name, args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
spiece.model
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:8a76244a65ab35adda1b1cdb7b49be970d143bcc489d7b05d87551a12de78878
|
3 |
-
size 1071963
|
|
|
|
|
|
|
|
tokenizer_config.json
DELETED
@@ -1,5 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"name_or_path": "AI-Sweden-Models/gpt-sw3-6.7b",
|
3 |
-
"bos_token": "<|endoftext|>",
|
4 |
-
"pad_token": "<unk>"
|
5 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|