Anonymous commited on
Commit
208053f
1 Parent(s): f18411b
Files changed (6) hide show
  1. app.py +211 -0
  2. generate_prompt.py +642 -0
  3. tasks/ner.py +132 -0
  4. tasks/nli.py +496 -0
  5. tasks/qa.py +770 -0
  6. tasks/summarization.py +149 -0
app.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ from openai import OpenAI
4
+ from generate_prompt import construct_generic_prompt, recommend_config
5
+
6
+ # Define available tasks and their corresponding datasets
7
+
8
+ QA = "QA"
9
+ SUMMARIZATION = "Summarization"
10
+ NLI = "NLI"
11
+ NER = "NER"
12
+
13
+ tasks_datasets = {
14
+ QA: ["XQuad", "Indicqa"],
15
+ SUMMARIZATION: ["XLSum", "HeSum"],
16
+ NLI: ["XNLI"],
17
+ NER: ["MasakaNER", "WikiANN"]
18
+ }
19
+
20
+ # List of all languages
21
+ languages = [
22
+ "English", "Spanish", "French", "German", "Chinese", "Japanese", "Korean", "Italian",
23
+ "Portuguese", "Russian", "Arabic", "Hindi", "Bengali", "Turkish", "Vietnamese", "Polish",
24
+ "Dutch", "Indonesian", "Malay", "Thai", "Greek", "Swedish", "Hungarian", "Finnish",
25
+ "Danish", "Norwegian", "Hebrew", "Czech", "Slovak", "Bulgarian", "Romanian", "Serbian",
26
+ "Croatian", "Ukrainian", "Lithuanian", "Latvian", "Estonian", "Filipino", "Icelandic",
27
+ "Irish", "Welsh", "Maltese", "Swahili", "Zulu", "Afrikaans"
28
+ ]
29
+
30
+
31
+
32
+ def get_datasets(task):
33
+ return tasks_datasets.get(task, [])
34
+
35
+
36
+ with gr.Blocks() as demo:
37
+ with gr.Row():
38
+ gr.Markdown("## Multilingual Prompt Generator")
39
+
40
+ with gr.Row():
41
+ with gr.Column(scale=2):
42
+ instruction = gr.Textbox(label="Instruction")
43
+ openai_key = gr.Textbox(label="OpenAI API key", type="password")
44
+ model = gr.Textbox(label="Model", placeholder="Enter model name (e.g., gpt-4-vision-preview)")
45
+ model_type = gr.Dropdown(label="Model Type", choices=["Multilingual", "English"], value='English')
46
+ config_recommendation = gr.Button("Recommend Configuration")
47
+ with gr.Column():
48
+ task = gr.Dropdown(label="Task", choices=list(tasks_datasets.keys()), value=QA)
49
+ language = gr.Dropdown(label="Source Language", choices=languages, value="English")
50
+ zero_shot = gr.Checkbox(label="Zero-shot", value=False)
51
+ with gr.Accordion("Prompt Configuration Selection", open=False):
52
+ prefix_selection = gr.Dropdown(["English", "Source"], label="prefix", value='English')
53
+ context_selection = gr.Dropdown(["English", "Source"], label="context", value='English')
54
+ examples_selection = gr.Dropdown(["English", "Source"], label="examples" , value='English')
55
+ output_selection = gr.Dropdown(["English", "Source"], label="output", value='English')
56
+ with gr.Accordion("Few Shot - Select Type of Examples ", open=False, visible=True) as few_shot:
57
+ dataset = gr.Dropdown(label="Dataset", choices=tasks_datasets[QA], value="XlSum")
58
+ num_examples = gr.Slider(label="Number of examples in context", minimum=1, maximum=10, step=1, value=3)
59
+ with gr.Row():
60
+ question = gr.Textbox(label="Question", visible=True)
61
+ context = gr.Textbox(label="Context", visible=True)
62
+ text = gr.Textbox(label="Text", visible=False)
63
+ sentence = gr.Textbox(label="Sentence", visible=False)
64
+ hypothesis = gr.Textbox(label="Hypothesis", visible=False)
65
+ premise = gr.Textbox(label="Premise", visible=False)
66
+ with gr.Row():
67
+ config_prompt = gr.Textbox(label="Recommended Configuration", interactive=False,
68
+ placeholder="Recommended Configuration for this scenerio")
69
+
70
+ generate_button = gr.Button("Generate Prompt")
71
+
72
+ with gr.Row():
73
+ prompt = gr.Textbox(label="Generated Prompt", interactive=False, placeholder="Generated prompt will appear here.")
74
+
75
+
76
+ def update_datasets(selected_task):
77
+ return gr.Dropdown(choices=get_datasets(selected_task))
78
+
79
+
80
+ def toggle_task_inputs(selected_task):
81
+ if selected_task == QA:
82
+ return (
83
+ gr.update(visible=True), gr.update(visible=True), gr.update(visible=False),
84
+ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
85
+ )
86
+ elif selected_task == SUMMARIZATION:
87
+ return (
88
+ gr.update(visible=False), gr.update(visible=False), gr.update(visible=True),
89
+ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
90
+ )
91
+ elif selected_task == NER:
92
+ return (
93
+ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
94
+ gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)
95
+ )
96
+ else:
97
+ return (
98
+ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
99
+ gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)
100
+ )
101
+
102
+
103
+ def toggle_num_examples(zero_shot_value):
104
+ # If zero_shot is True, hide the num_examples slider
105
+ return gr.update(visible=not zero_shot_value)
106
+
107
+ def update_language_selection(language):
108
+ return gr.update(choices=list({'English', language})), gr.update(choices=list({'English', language})), gr.update(choices=list({'English', language})), gr.update(choices=list({'English', language}))
109
+
110
+ def generatePrompt(instruction, num_examples, zero_shot,
111
+ task, selected_language, dataset, prefix_selection, context_selection, examples_selection, output_selection,
112
+ text, question, context, sentence, hypothesis, premise):
113
+
114
+ config = {'prefix': str.lower(prefix_selection), 'input': str.lower(context_selection), 'context': str.lower(examples_selection), 'output': str.lower(output_selection)}
115
+
116
+ if task == QA:
117
+ text_example = {
118
+ 'context': context,
119
+ 'question': question,
120
+ }
121
+ elif task == SUMMARIZATION:
122
+ text_example = {
123
+ 'text': text,
124
+ }
125
+ elif task == NER:
126
+ text_example = {
127
+ 'tokens': sentence,
128
+ }
129
+ else:
130
+ text_example = {
131
+ 'hypothesis': hypothesis,
132
+ 'premise': premise
133
+ }
134
+
135
+ print(text_example)
136
+ prompt = construct_generic_prompt(task, instruction, text_example, zero_shot, num_examples, selected_language, dataset, config)
137
+
138
+ return prompt
139
+
140
+
141
+ def respond(message, openai_key, url, chat_history, model, config_input, config_prefix, config_context,
142
+ config_output, task, dataset, language, num_examples, zero_shot):
143
+ os.environ["OPENAI_API_KEY"] = openai_key
144
+ client = OpenAI()
145
+
146
+ config = {
147
+ "input": config_input,
148
+ "prefix": config_prefix,
149
+ "context": config_context.split(', '),
150
+ "output": config_output,
151
+ "language": language,
152
+ "num_examples": num_examples,
153
+ "zero_shot": zero_shot
154
+ }
155
+
156
+ response = client.chat.completions.create(
157
+ model=model,
158
+ messages=[
159
+ {
160
+ "role": "user",
161
+ "content": [
162
+ {"type": "text", "text": message},
163
+ {"type": "image_url", "image_url": url},
164
+ {"type": "config", "config": config},
165
+ {"type": "task", "text": task},
166
+ {"type": "dataset", "text": dataset}
167
+ ],
168
+ },
169
+ ],
170
+ max_tokens=1000,
171
+ )
172
+
173
+ out = response.choices[0].message.content
174
+
175
+ chat_history.append((message, out))
176
+ return "", chat_history
177
+
178
+
179
+ # Bind functions to dropdown changes and button click
180
+ # task.change(fn=update_datasets, outputs=dataset)
181
+ language.change(fn=update_language_selection, inputs=language, outputs=[prefix_selection, context_selection, examples_selection, output_selection])
182
+
183
+ zero_shot.change(fn=toggle_num_examples, inputs=zero_shot, outputs=few_shot)
184
+ zero_shot.change(fn=toggle_num_examples, inputs=zero_shot, outputs=num_examples)
185
+ task.change(fn=update_datasets, inputs=task, outputs=dataset)
186
+ task.change(fn=toggle_task_inputs, inputs=task, outputs=[
187
+ question, context, text, sentence, hypothesis, premise,
188
+ ])
189
+ generate_button.click(
190
+ generatePrompt,
191
+ inputs=[
192
+ instruction, num_examples, zero_shot,
193
+ task, language, dataset, prefix_selection, context_selection, examples_selection, output_selection,
194
+ text, question, context, sentence, hypothesis, premise
195
+
196
+ ],
197
+ outputs=[prompt]
198
+ )
199
+
200
+ config_recommendation.click(
201
+ recommend_config,
202
+ inputs=[
203
+ task,
204
+ language,
205
+ model_type
206
+ ],
207
+ outputs=[config_prompt]
208
+ )
209
+
210
+ if __name__ == '__main__':
211
+ demo.launch()
generate_prompt.py ADDED
@@ -0,0 +1,642 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import enum
3
+ import json
4
+ import logging
5
+ import os
6
+ import re
7
+ import string
8
+ import sys
9
+ import unicodedata
10
+ from typing import Any, Dict, List, NewType, Union
11
+
12
+ import numpy as np
13
+ import openai
14
+ import pandas as pd
15
+ import requests
16
+ import yaml
17
+ from datasets import Dataset, load_dataset
18
+ from easygoogletranslate import EasyGoogleTranslate
19
+ from langchain.prompts import FewShotPromptTemplate, PromptTemplate
20
+ from tqdm import tqdm
21
+ from yaml.loader import SafeLoader
22
+
23
+ from selective_pre_translation.tasks import qa, summarization, ner, nli
24
+
25
+
26
+ # from models.model_completion import gpt3x_completion, gemini_completion
27
+
28
+ class LanguageType(enum.Enum):
29
+ Low = "Low"
30
+ High = "High"
31
+
32
+
33
+ class ModelType(enum.Enum):
34
+ English = "English"
35
+ Multilingual = "Multilingual"
36
+
37
+
38
+ def get_entities_gpt3_long(prompt):
39
+ response = openai.ChatCompletion.create(
40
+ engine="chatgpt", temperature=0, messages=[{"role": "user", "content": prompt}]
41
+ )
42
+ return response["choices"][0]["message"]["content"]
43
+
44
+
45
+ def gpt3x_completion(
46
+ prompt: Union[str, List[Dict[str, str]]],
47
+ ) -> str:
48
+ import os
49
+ import openai
50
+ os.environ["OPENAI_API_KEY"] = ''
51
+
52
+
53
+ def get_entities_chatGPT(final_prompt):
54
+ response = openai.ChatCompletion.create(
55
+ engine="gpt35-16k",
56
+ temperature=0,
57
+ messages=[
58
+ {"role": "user", "content": final_prompt}
59
+ ]
60
+ )
61
+ return response['choices'][0]['message']['content']
62
+
63
+ return get_entities_chatGPT(final_prompt=prompt)
64
+
65
+
66
+ def mixtral_completion(prompt):
67
+ url = "https://api.together.xyz/v1/chat/completions"
68
+
69
+ # Define your Together API key
70
+ together_api_key = "" # Replace with your actual API key
71
+
72
+ # Define the request payload
73
+ payload = {
74
+ "temperature": 0,
75
+ "max_tokens": 30,
76
+ "model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
77
+ "messages": [{"role": "user", "content": f"{prompt}"}],
78
+ }
79
+
80
+ # Define request headers
81
+ headers = {
82
+ "Authorization": f"Bearer {together_api_key}",
83
+ "Content-Type": "application/json",
84
+ }
85
+
86
+ # Send POST request
87
+ response = requests.post(url, json=payload, headers=headers)
88
+
89
+ # Check response status
90
+ if response.status_code == 200:
91
+ # Print the response content (API output)
92
+ return response.json()["choices"][0]["message"]["content"]
93
+ else:
94
+ # Print error message if request fails
95
+ print(f"Error: {response.status_code} - {response.text}")
96
+
97
+
98
+ XQUAD_LANG2CODES = {
99
+ "bengali": "bn",
100
+ "korean": "ko",
101
+ "swahili": "sw",
102
+ "english": "en",
103
+ "indonesian": "id",
104
+ "arabic": "ar",
105
+ "finnish": "fi",
106
+ "telugu": "te",
107
+ "russian": "ru",
108
+ "german": "de",
109
+ "greek": "el",
110
+ "hindi": "hi",
111
+ "vietnamese": "vi",
112
+ "romanian": "ro",
113
+ }
114
+
115
+ INDICQA_LANG2CODES = {
116
+ "indicqa": "as",
117
+ "bengali": "bn",
118
+ "gujarati": "gu",
119
+ "hindi": "hi",
120
+ "kannada": "kn",
121
+ "malayalam": "ml",
122
+ "marathi": "mr",
123
+ "odia": "or",
124
+ "punjabi": "pa",
125
+ "tamil": "ta",
126
+ "telugu": "te",
127
+ "assamese": "as",
128
+ }
129
+
130
+ PUNCT = {
131
+ chr(i)
132
+ for i in range(sys.maxunicode)
133
+ if unicodedata.category(chr(i)).startswith("P")
134
+ }.union(string.punctuation)
135
+ WHITESPACE_LANGS = ["en", "es", "hi", "vi", "de", "ar"]
136
+ MIXED_SEGMENTATION_LANGS = ["zh"]
137
+
138
+ TYDIQA_LANG2CODES = {
139
+ "bengali": "bn",
140
+ "korean": "ko",
141
+ "swahili": "sw",
142
+ "english": "en",
143
+ "indonesian": "id",
144
+ "arabic": "ar",
145
+ "finnish": "fi",
146
+ "telugu": "te",
147
+ "russian": "ru",
148
+ "assamese": "as",
149
+ "persian": "fa",
150
+ }
151
+
152
+ logger = logging.Logger("Xlsum_task")
153
+ LANGUAGE_TO_SUFFIX = {
154
+ "chinese_simplified": "zh-CN",
155
+ "french": "fr",
156
+ "portuguese": "pt",
157
+ "english": "en",
158
+ "arabic": "ar",
159
+ "hindi": "hi",
160
+ "indonesian": "id",
161
+ "amharic": "am",
162
+ "bengali": "bn",
163
+ "telugu": "te",
164
+ "burmese": "my",
165
+ "german": "de",
166
+ "greek": "el",
167
+ "tamil": "ta",
168
+ "assamese": "as",
169
+ "hindi": "hi",
170
+ "vietnamese": "vi",
171
+ "russian": "ru",
172
+ "telugu": "te",
173
+ "romanian": "ro",
174
+ "malayalam": "ml",
175
+ "persian": "fa",
176
+ }
177
+
178
+ PARAMS = NewType("PARAMS", Dict[str, Any])
179
+
180
+
181
+ def read_parameters(args_path) -> PARAMS:
182
+ with open(args_path) as f:
183
+ args = yaml.load(f, Loader=SafeLoader)
184
+ return args
185
+
186
+
187
+ def load_qa_dataset(dataset_name, lang, split, translate_test=False, limit=5):
188
+ if dataset_name == "indicqa":
189
+ if split != "train":
190
+ dataset = load_dataset(
191
+ "ai4bharat/IndicQA", f"indicqa.{INDICQA_LANG2CODES[lang]}"
192
+ )[split]
193
+ else:
194
+ dataset = load_dataset("squad_v2")[split]
195
+ elif dataset_name == "xquad":
196
+ if split != "train":
197
+ dataset = load_dataset("xquad", f"xquad.{XQUAD_LANG2CODES[lang]}")[
198
+ "validation"
199
+ ]
200
+ else:
201
+ dataset = load_dataset("squad")[split]
202
+ elif dataset_name == "tydiqa":
203
+ dataset = load_dataset("tydiqa", "secondary_task")[split]
204
+ dataset = dataset.map(
205
+ lambda example: {"lang": TYDIQA_LANG2CODES[example["id"].split("-")[0]]}
206
+ )
207
+ dataset = dataset.filter(lambda example: example["lang"] == lang)
208
+ elif dataset_name == "mlqa":
209
+ if split == "train":
210
+ print("No Training Data for MLQA, switching to validation!")
211
+ split = "validation"
212
+ if translate_test:
213
+ dataset_name = f"mlqa-translate-test.{lang}"
214
+ else:
215
+ dataset_name = f"mlqa.{lang}.{lang}"
216
+
217
+ dataset = load_dataset("mlqa", dataset_name)[split]
218
+
219
+ else:
220
+ raise NotImplementedError()
221
+ return dataset.select(np.arange(limit))
222
+
223
+
224
+ def construct_prompt(
225
+ instruction: str,
226
+ test_example: dict,
227
+ ic_examples: List[dict],
228
+ zero_shot: bool,
229
+ lang: str,
230
+ config: Dict[Any, Any],
231
+ ):
232
+ example_prompt = PromptTemplate(
233
+ input_variables=["context", "question", "answers"],
234
+ template="Context: {context}\nQuestion: {question}\n" "Answers: {answers}",
235
+ )
236
+
237
+ zero_shot_template = (
238
+ f"""{instruction}""" + "\n<Context>: {context} \n<Question>: {question} " ""
239
+ )
240
+
241
+ prompt = (
242
+ FewShotPromptTemplate(
243
+ examples=ic_examples,
244
+ prefix=instruction,
245
+ example_prompt=example_prompt,
246
+ suffix="<Context>: {context} \n<Question>: {question} \nAnswers: ?",
247
+ input_variables=["question", "context"],
248
+ )
249
+ if not zero_shot
250
+ else PromptTemplate(
251
+ input_variables=["question", "context"], template=zero_shot_template
252
+ )
253
+ )
254
+
255
+ label = test_example["answers"]
256
+ if config["input"] != lang:
257
+ test_example = _translate_example(
258
+ example=test_example, src_language=lang, target_language=config["input"]
259
+ )
260
+
261
+ return (
262
+ prompt.format(
263
+ question=test_example["question"], context=test_example["context"]
264
+ ),
265
+ label,
266
+ )
267
+
268
+
269
+ def dump_metrics(
270
+ lang: str, config: Dict[str, str], f1: float, em: float, metric_logger_path: str
271
+ ):
272
+ # Check if the metric logger file exists
273
+ file_exists = os.path.exists(metric_logger_path)
274
+
275
+ # Open the CSV file in append mode
276
+ with open(metric_logger_path, "a", newline="") as f:
277
+ csvwriter = csv.writer(f, delimiter=",")
278
+
279
+ # Write header row if the file is newly created
280
+ if not file_exists:
281
+ header = ["Language", "Prefix", "Input", "Context", "Output", "F1", "Em"]
282
+ csvwriter.writerow(header)
283
+
284
+ csvwriter.writerow(
285
+ [
286
+ lang,
287
+ config["prefix"],
288
+ config["input"],
289
+ config["context"][0],
290
+ config["output"],
291
+ f1,
292
+ em,
293
+ ]
294
+ )
295
+
296
+
297
+ def dump_predictions(idx, response, label, response_logger_file):
298
+ obj = {"q_idx": idx, "prediction": response, "label": label}
299
+ with open(response_logger_file, "a") as f:
300
+ f.write(json.dumps(obj, ensure_ascii=False) + "\n")
301
+
302
+
303
+ def _translate_instruction(basic_instruction: str, target_language: str) -> str:
304
+ translator = EasyGoogleTranslate(
305
+ source_language="en",
306
+ target_language=LANGUAGE_TO_SUFFIX[target_language],
307
+ timeout=50,
308
+ )
309
+ return translator.translate(basic_instruction)
310
+
311
+
312
+ def _translate_prediction_to_output_language(
313
+ prediction: str, prediction_language: str, output_language: str
314
+ ) -> str:
315
+ translator = EasyGoogleTranslate(
316
+ source_language=LANGUAGE_TO_SUFFIX[prediction_language],
317
+ target_language=LANGUAGE_TO_SUFFIX[output_language],
318
+ timeout=10,
319
+ )
320
+ return translator.translate(prediction)
321
+
322
+
323
+ def create_instruction(lang: str, expected_output: str):
324
+ basic_instruction = (
325
+ "Answer to the <Question> below, based only to the given <Context>, Follow these instructions:\n"
326
+ "1. The answer should include only words from the given context\n"
327
+ "2. The answer must include up to 5 words\n"
328
+ "3. The answer Should be the shortest as possible\n"
329
+ f"4. The answer must be in {expected_output} only!, not another language!!!"
330
+ )
331
+ return (
332
+ basic_instruction
333
+ if lang == "english"
334
+ else _translate_instruction(basic_instruction, target_language=lang)
335
+ )
336
+
337
+
338
+ def _translate_example(
339
+ example: Dict[str, str], src_language: str, target_language: str
340
+ ):
341
+ translator = EasyGoogleTranslate(
342
+ source_language=LANGUAGE_TO_SUFFIX[str(src_language).lower()],
343
+ target_language=LANGUAGE_TO_SUFFIX[str(target_language).lower()],
344
+ timeout=30,
345
+ )
346
+
347
+ return {
348
+ "question": translator.translate(example["question"]),
349
+ "context": translator.translate(example["context"][:2000])
350
+ + translator.translate(example["context"][2000:4000])
351
+ + translator.translate(example["context"][4000:6000]),
352
+ "answers": translator.translate(example["answers"][0]),
353
+ }
354
+ # except Exception as e:
355
+ # print(example["text"])
356
+ # print(example["summary"])
357
+ # print(e)
358
+
359
+
360
+ def choose_few_shot_examples(
361
+ train_dataset: Dataset,
362
+ few_shot_size: int,
363
+ context: List[str],
364
+ selection_criteria: str,
365
+ lang: str,
366
+ ) -> List[Dict[str, Union[str, int]]]:
367
+ """Selects few-shot examples from training datasets
368
+
369
+ Args:
370
+ train_dataset (Dataset): Training Dataset
371
+ few_shot_size (int): Number of few-shot examples
372
+ selection_criteria (few_shot_selection): How to select few-shot examples. Choices: [random, first_k]
373
+
374
+ Returns:
375
+ List[Dict[str, Union[str, int]]]: Selected examples
376
+ """
377
+ selected_examples = []
378
+
379
+ example_idxs = []
380
+ if selection_criteria == "first_k":
381
+ example_idxs = list(range(few_shot_size))
382
+ elif selection_criteria == "random":
383
+ example_idxs = (
384
+ np.random.choice(len(train_dataset), size=few_shot_size, replace=True)
385
+ .astype(int)
386
+ .tolist()
387
+ )
388
+
389
+ ic_examples = [
390
+ {
391
+ "question": train_dataset[idx]["question"],
392
+ "context": train_dataset[idx]["context"],
393
+ "answers": train_dataset[idx]["answers"]["text"],
394
+ }
395
+ for idx in example_idxs
396
+ ]
397
+
398
+ for idx, ic_language in enumerate(context):
399
+ (
400
+ selected_examples.append(ic_examples[idx])
401
+ if ic_language == lang
402
+ else (
403
+ selected_examples.append(
404
+ _translate_example(
405
+ example=ic_examples[idx],
406
+ src_language=lang,
407
+ target_language=ic_language,
408
+ )
409
+ )
410
+ )
411
+ )
412
+
413
+ return selected_examples
414
+
415
+
416
+ def normalize_answer(s):
417
+ """Lower text and remove punctuation, articles and extra whitespace."""
418
+
419
+ def remove_articles(text):
420
+ return re.sub(r"\b(a|an|the)\b", " ", text)
421
+
422
+ def white_space_fix(text):
423
+ return " ".join(text.split())
424
+
425
+ def remove_punc(text):
426
+ exclude = set(PUNCT) # set(string.punctuation)
427
+ return "".join(ch for ch in text if ch not in exclude)
428
+
429
+ def lower(text):
430
+ return text.lower()
431
+
432
+ return white_space_fix(remove_articles(remove_punc(lower(s))))
433
+
434
+
435
+ def process_test_example(
436
+ test_data, config_header, idx, test_example, config, zero_shot, lang, params
437
+ ):
438
+ try:
439
+ # Your existing code for processing each test example
440
+ instruction = create_instruction(
441
+ lang=config["prefix"], expected_output=config["output"]
442
+ )
443
+ text_example = {
444
+ "question": test_example["question"],
445
+ "context": test_example["context"],
446
+ "answers": test_example["answers"]["text"],
447
+ }
448
+
449
+ ic_examples = []
450
+ if not zero_shot:
451
+ ic_examples = choose_few_shot_examples(
452
+ train_dataset=test_data,
453
+ few_shot_size=len(config["context"]),
454
+ context=config["context"],
455
+ selection_criteria="random",
456
+ lang=params["selected_language"],
457
+ )
458
+
459
+ prompt, label = construct_prompt(
460
+ instruction=instruction,
461
+ test_example=text_example,
462
+ ic_examples=ic_examples,
463
+ zero_shot=zero_shot,
464
+ lang=lang,
465
+ config=config,
466
+ )
467
+
468
+ pred = gpt3x_completion(prompt=prompt)
469
+ print(pred)
470
+
471
+ logger.info("Saving prediction to persistent volume")
472
+ os.makedirs(
473
+ f"{params['response_logger_root']}/{params['model']}/{lang}", exist_ok=True
474
+ )
475
+ dump_predictions(
476
+ idx=idx,
477
+ response=pred,
478
+ label=label,
479
+ response_logger_file=f"{params['response_logger_root']}/{params['model']}/{lang}/{config_header}.csv",
480
+ )
481
+ except Exception as e:
482
+ # Handle exceptions here
483
+ print(f"Error processing example {idx}: {e}")
484
+
485
+
486
+ def run_one_configuration(selected_language, config, zero_shot, dataset_name, limit=10):
487
+ test_data = load_qa_dataset(
488
+ dataset_name=dataset_name,
489
+ lang=selected_language,
490
+ split="validation" if dataset_name == "xquad" else "test",
491
+ limit=limit,
492
+ )
493
+
494
+ for idx, test_example in (pbar := tqdm(enumerate(test_data))):
495
+ try:
496
+ instruction = create_instruction(
497
+ lang=config["prefix"], expected_output=config["output"]
498
+ )
499
+ text_example = {
500
+ "question": test_example["question"],
501
+ "context": test_example["context"],
502
+ "answers": test_example["answers"]["text"],
503
+ }
504
+
505
+ ic_examples = []
506
+ if not zero_shot:
507
+ ic_examples = choose_few_shot_examples(
508
+ train_dataset=test_data,
509
+ few_shot_size=len(config["context"]),
510
+ context=config["context"],
511
+ selection_criteria="random",
512
+ lang=selected_language,
513
+ )
514
+
515
+ prompt, label = construct_prompt(
516
+ instruction=instruction,
517
+ test_example=text_example,
518
+ ic_examples=ic_examples,
519
+ zero_shot=zero_shot,
520
+ lang=selected_language,
521
+ config=config,
522
+ )
523
+
524
+ pred = gpt3x_completion(prompt=prompt)
525
+
526
+ return pred
527
+
528
+ except Exception as e:
529
+ print(f"Found an exception {e}, continue to the next example")
530
+ continue
531
+
532
+
533
+ QA = "QA"
534
+ SUMMARIZATION = "Summarization"
535
+ NLI = "NLI"
536
+ NER = "NER"
537
+
538
+
539
+ def construct_generic_prompt(task, instruction, test_example, zero_shot, num_examples, selected_language, dataset,
540
+ config):
541
+ print(task)
542
+ if task == SUMMARIZATION:
543
+ prompt = summarization.construct_prompt(
544
+ instruction=instruction,
545
+ test_example=test_example,
546
+ zero_shot=zero_shot,
547
+ dataset=dataset,
548
+ num_examples=num_examples,
549
+ lang=str(selected_language).lower(),
550
+ config=config,
551
+ )
552
+ elif task == NER:
553
+ prompt = ner.construct_prompt(
554
+ instruction=instruction,
555
+ test_example=test_example,
556
+ zero_shot=zero_shot,
557
+ num_examples=num_examples,
558
+ lang=str(selected_language).lower(),
559
+ config=config,
560
+ )
561
+ elif task == QA:
562
+ prompt = qa.construct_prompt(
563
+ instruction=instruction,
564
+ test_example=test_example,
565
+ zero_shot=zero_shot,
566
+ num_examples=num_examples,
567
+ lang=str(selected_language).lower(),
568
+ config=config,
569
+ # dataset_name=dataset
570
+ )
571
+ else:
572
+ prompt = nli.construct_prompt(
573
+ instruction=instruction,
574
+ test_example=test_example,
575
+ zero_shot=zero_shot,
576
+ num_examples=num_examples,
577
+ lang=str(selected_language).lower(),
578
+ config=config,
579
+ )
580
+ return prompt
581
+
582
+
583
+ def _get_language_type(language: str):
584
+ df = pd.read_csv("utils/languages_by_word_count.csv")
585
+ number_of_words = df[df['Language'] == language]['number of words'].iloc[0]
586
+ print(number_of_words)
587
+ return LanguageType.Low if number_of_words < 150276400 else LanguageType.High
588
+
589
+
590
+ class Config:
591
+ def __init__(self, prefix="source", context="source", examples="source", output="source"):
592
+ self.prefix = prefix
593
+ self.context = context
594
+ self.examples = examples
595
+ self.output = output
596
+
597
+ def set(self, prefix=None, context=None, examples=None, output=None):
598
+ if prefix: self.prefix = prefix
599
+ if context: self.context = context
600
+ if examples: self.examples = examples
601
+ if output: self.output = output
602
+
603
+ def to_dict(self):
604
+ return {
605
+ 'prefix': self.prefix,
606
+ 'context': self.context,
607
+ 'examples': self.examples,
608
+ 'output': self.output
609
+ }
610
+
611
+
612
+ def recommend_config(task, lang, model_type):
613
+ print(task)
614
+ print(model_type)
615
+ language_type = _get_language_type(lang)
616
+ config = Config()
617
+ print(language_type)
618
+ if task == QA:
619
+ if model_type == ModelType.English.value:
620
+ config.set(prefix='source', context='source', examples='source', output='source')
621
+ else:
622
+ config.set(prefix='english', context='source', examples='source', output='source')
623
+ if task == NER:
624
+ if model_type == ModelType.English.value:
625
+ config.set(prefix='source', context='source', examples='source', output='source')
626
+ elif language_type == LanguageType.High:
627
+ config.set(prefix='english', context='source', examples='source', output='source')
628
+ else:
629
+ config.set(prefix='english', context='source', examples='source', output='english')
630
+ if task == NLI:
631
+ if model_type == ModelType.English.value:
632
+ config.set(prefix='source', context='source', examples='source', output='source')
633
+ elif language_type == LanguageType.High:
634
+ print("here")
635
+ config.set(prefix='english', context='source', examples='english')
636
+ else:
637
+ print("here1")
638
+ config.set(prefix='english', context='english', examples='english')
639
+ if task == SUMMARIZATION:
640
+ config.set(context='english')
641
+
642
+ return config.to_dict()
tasks/ner.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict, Any
2
+
3
+ from easygoogletranslate import EasyGoogleTranslate
4
+ from langchain.prompts import PromptTemplate, FewShotPromptTemplate
5
+
6
+ LANGUAGE_TO_GOOGLE_TRANSLATE_MARK = {
7
+ "english": "en",
8
+ "bambara": "bm",
9
+ "ewe": "ee",
10
+ "hausa": "ha",
11
+ "igbo": "ig",
12
+ "kinyarwanda": "rw",
13
+ "chichewa": "ny",
14
+ "twi": "ak",
15
+ "yoruba": "yo",
16
+ "slovak": "sk",
17
+ "serbian": "sr",
18
+ "swedish": "sv",
19
+ "vietnamese": "vi",
20
+ "italian": "it",
21
+ "portuguese": "pt",
22
+ "chinese": "zh",
23
+ "english": "en",
24
+ "french": "fr"
25
+
26
+
27
+
28
+ }
29
+
30
+ LANGAUGE_TO_PREFIX = {
31
+ "bambara": "bam",
32
+ "ewe": "ewe",
33
+ "fon": "fon",
34
+ "hausa": "hau",
35
+ "igbo": "ibo",
36
+ "kinyarwanda": "kin",
37
+ "chichewa": "nya",
38
+ "twi": "twi",
39
+ "yoruba": "yor",
40
+ "slovak": "sk",
41
+ "serbian": "sr",
42
+ "swedish": "sv",
43
+ "vietnamese": "vi",
44
+ "italian": "it",
45
+ "portuguese": "pt",
46
+ "chinese": "zh",
47
+ "english": "en",
48
+ "french": "fr"
49
+ }
50
+
51
+
52
+ def _translate_instruction(basic_instruction: str, target_language: str) -> str:
53
+ translator = EasyGoogleTranslate(
54
+ source_language="en",
55
+ target_language=LANGAUGE_TO_PREFIX[target_language],
56
+ timeout=10,
57
+ )
58
+ return translator.translate(basic_instruction)
59
+
60
+
61
+ def create_instruction(lang: str, expected_output: str):
62
+ basic_instruction = f"""You are an NLP assistant whose
63
+ purpose is to perform Named Entity Recognition
64
+ (NER). You will need to give each entity a tag, from the following:
65
+ PER means a person, ORG means organization.
66
+ LOC means a location entity.
67
+ The output should be a list of tuples of the format:
68
+ ['Tag: Entity', 'Tag: Entity'] for each entity in the sentence.
69
+ The entities should be in {expected_output} language"""
70
+
71
+ return (
72
+ basic_instruction
73
+ if lang == "english"
74
+ else _translate_instruction(basic_instruction, target_language=lang)
75
+ )
76
+
77
+ def construct_prompt(
78
+ instruction: str,
79
+ test_example: dict,
80
+ zero_shot: bool,
81
+ dataset: str,
82
+ num_examples: int,
83
+ lang: str,
84
+ config: Dict[str, str],
85
+ ):
86
+ if not instruction:
87
+ print(lang)
88
+ instruction = create_instruction(lang, config['prefix'])
89
+
90
+ example_prompt = PromptTemplate(
91
+ input_variables=["summary", "text"], template="Text: {text}\nSummary: {summary}"
92
+ )
93
+
94
+ zero_shot_template = f"""{instruction}""" + "\n Input: {text} " ""
95
+
96
+ test_data = load_xlsum_data(lang=lang, split="test", limit=100)
97
+
98
+ print(test_data)
99
+ print(num_examples)
100
+ print(lang)
101
+ ic_examples = []
102
+ if not zero_shot:
103
+
104
+ ic_examples = choose_few_shot_examples(
105
+ train_dataset=test_data,
106
+ few_shot_size=num_examples,
107
+ context=[config["context"]] * num_examples,
108
+ selection_criteria="random",
109
+ lang=lang,
110
+ )
111
+
112
+ prompt = (
113
+ FewShotPromptTemplate(
114
+ examples=ic_examples,
115
+ prefix=instruction,
116
+ example_prompt=example_prompt,
117
+ suffix="<Text>: {text}",
118
+ input_variables=["text"],
119
+ )
120
+ if not zero_shot
121
+ else PromptTemplate(input_variables=["text"], template=zero_shot_template)
122
+ )
123
+
124
+ print("lang", lang)
125
+ print(config["input"] , lang)
126
+ if config["input"] != lang:
127
+ test_example = _translate_example(
128
+ example=test_example, src_language=lang, target_language=config["input"]
129
+ )
130
+
131
+ print("test_example", prompt)
132
+ return prompt.format(text=test_example["text"])
tasks/nli.py ADDED
@@ -0,0 +1,496 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import time
3
+
4
+ import csv
5
+ import json
6
+ import multiprocessing as mp
7
+ import os
8
+ from typing import Any, Dict, List, NewType, Optional, Union
9
+ import openai
10
+ import numpy as np
11
+ import requests
12
+ import yaml
13
+ from datasets import Dataset, DatasetDict, load_dataset
14
+ from easygoogletranslate import EasyGoogleTranslate
15
+ from langchain.prompts import FewShotPromptTemplate, PromptTemplate
16
+ from tqdm import tqdm
17
+ from yaml.loader import SafeLoader
18
+
19
+ LANGUAGE_TO_SUFFIX = {
20
+ "chinese_simplified": "zh-CN",
21
+ "french": "fr",
22
+ "portuguese": "pt",
23
+ "english": "en",
24
+ "arabic": "ar",
25
+ "hindi": "hi",
26
+ "indonesian": "id",
27
+ "amharic": "am",
28
+ "bengali": "bn",
29
+ "burmese": "my",
30
+ "chinese": "zh-CN",
31
+ "swahili": "sw",
32
+ "bulgarian": "bg",
33
+ "thai": "th",
34
+ "urdu": "ur",
35
+ "turkish": "tr",
36
+ "spanish": "es",
37
+ "chinese": "zh",
38
+ "greek": "el",
39
+ "german": "de"
40
+
41
+
42
+ }
43
+
44
+ NUMBER_TO_TAG = {0: "entailment", 1: "neutral", 2: "contradiction"}
45
+
46
+ PARAMS = NewType("PARAMS", Dict[str, Any])
47
+
48
+
49
+ def gemini_completion(prompt):
50
+ # Define the endpoint URL
51
+ genai.configure(api_key="AIzaSyBnghQNoOS2qiacHjqutK1RpPV5y-gv7Pg")
52
+ model = genai.GenerativeModel("models/gemini-1.0-pro-latest")
53
+ return model.generate_content(prompt).text
54
+
55
+
56
+
57
+ def gpt3x_completion(
58
+ prompt: Union[str, List[Dict[str, str]]],
59
+ model: str = "chatgpt",
60
+ # run_details: Any = {},
61
+ # num_evals_per_sec: int = 2,
62
+ # **model_params,
63
+ ) -> str:
64
+ import os
65
+ import openai
66
+ os.environ["OPENAI_API_KEY"] = ''
67
+
68
+
69
+ def get_entities_chatGPT(final_prompt):
70
+ response = openai.ChatCompletion.create(
71
+ engine="gpt35-16k",
72
+ temperature=0,
73
+ messages=[
74
+ {"role": "user", "content": final_prompt}
75
+ ]
76
+ )
77
+ return response['choices'][0]['message']['content']
78
+
79
+ return get_entities_chatGPT(final_prompt=prompt)
80
+
81
+ def mixtral_completion(prompt):
82
+ url = "https://api.together.xyz/v1/chat/completions"
83
+
84
+ # Define your Together API key
85
+ together_api_key = "" # Replace with your actual API key
86
+
87
+ # Define the request payload
88
+ payload = {
89
+ "temperature": 0,
90
+ "max_tokens": 30,
91
+ "model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
92
+ "messages": [{"role": "user", "content": f"{prompt}"}],
93
+ }
94
+
95
+ # Define request headers
96
+ headers = {
97
+ "Authorization": f"Bearer {together_api_key}",
98
+ "Content-Type": "application/json",
99
+ }
100
+
101
+ # Send POST request
102
+ response = requests.post(url, json=payload, headers=headers)
103
+
104
+ # Check response status
105
+ if response.status_code == 200:
106
+ # Print the response content (API output)
107
+ return response.json()["choices"][0]["message"]["content"]
108
+ else:
109
+ # Print error message if request fails
110
+ print(f"Error: {response.status_code} - {response.text}")
111
+
112
+
113
+ def read_parameters(args_path) -> PARAMS:
114
+ with open(args_path) as f:
115
+ args = yaml.load(f, Loader=SafeLoader)
116
+ return args
117
+
118
+
119
+ def get_key(key_path):
120
+ with open(key_path) as f:
121
+ key = f.read().split("\n")[0]
122
+ return key
123
+
124
+
125
+ def _translate_example(
126
+ example: Dict[str, str], src_language: str, target_language: str
127
+ ):
128
+ translator = EasyGoogleTranslate(
129
+ source_language=LANGUAGE_TO_SUFFIX[src_language],
130
+ target_language=LANGUAGE_TO_SUFFIX[target_language],
131
+ timeout=30,
132
+ )
133
+ try:
134
+ return {
135
+ "premise": translator.translate(example["premise"]),
136
+ "hypothesis": translator.translate(example["hypothesis"]),
137
+ "label": "",
138
+ }
139
+ except Exception as e:
140
+ print(e)
141
+
142
+
143
+ def choose_few_shot_examples(
144
+ train_dataset: Dataset,
145
+ few_shot_size: int,
146
+ context: List[str],
147
+ selection_criteria: str,
148
+ lang: str,
149
+ ) -> List[Dict[str, Union[str, int]]]:
150
+ """Selects few-shot examples from training datasets
151
+
152
+ Args:
153
+ train_dataset (Dataset): Training Dataset
154
+ few_shot_size (int): Number of few-shot examples
155
+ selection_criteria (few_shot_selection): How to select few-shot examples. Choices: [random, first_k]
156
+
157
+ Returns:
158
+ List[Dict[str, Union[str, int]]]: Selected examples
159
+ """
160
+ selected_examples = []
161
+
162
+ example_idxs = []
163
+ if selection_criteria == "first_k":
164
+ example_idxs = list(range(few_shot_size))
165
+ elif selection_criteria == "random":
166
+ example_idxs = (
167
+ np.random.choice(len(train_dataset), size=few_shot_size, replace=True)
168
+ .astype(int)
169
+ .tolist()
170
+ )
171
+
172
+ ic_examples = [train_dataset[idx] for idx in example_idxs]
173
+
174
+ ic_examples = [
175
+ {
176
+ "premise": example["premise"],
177
+ "hypothesis": example["hypothesis"],
178
+ "label": NUMBER_TO_TAG[example["label"]],
179
+ }
180
+ for example in ic_examples
181
+ ]
182
+
183
+ for idx, ic_language in enumerate(context):
184
+ (
185
+ selected_examples.append(ic_examples[idx])
186
+ if ic_language == lang
187
+ else (
188
+ selected_examples.append(
189
+ _translate_example(
190
+ example=ic_examples[idx],
191
+ src_language=lang,
192
+ target_language=ic_language,
193
+ )
194
+ )
195
+ )
196
+ )
197
+
198
+ return selected_examples
199
+
200
+
201
+ def load_xnli_dataset(
202
+ dataset_name: str,
203
+ lang: str,
204
+ split: str,
205
+ limit: int = 200,
206
+ ) -> Union[Dataset, DatasetDict]:
207
+ """
208
+ Args:
209
+ lang (str): Language for which xnli dataset is to be loaded
210
+ split (str): Train test of validation split of the model to load
211
+ dataset_frac (float): Fraction of examples to load. Defaults to 1.0
212
+
213
+ Returns:
214
+ Union[Dataset, DatasetDict]: huggingface dataset object
215
+ """
216
+ if dataset_name == "indicxnli": ##PJ:To add except hindi
217
+ dataset = load_dataset("Divyanshu/indicxnli", LANGUAGE_TO_SUFFIX[lang])[split]
218
+ else:
219
+ dataset = load_dataset("xnli", LANGUAGE_TO_SUFFIX[lang])[split]
220
+ return dataset.select(np.arange(limit))
221
+
222
+
223
+ def construct_prompt(
224
+ instruction: str, test_example: dict, ic_examples: List[dict], zero_shot: bool
225
+ ):
226
+ example_prompt = PromptTemplate(
227
+ input_variables=["premise", "hypothesis", "label"],
228
+ template="Premise: {premise}\n Hypothesis: {hypothesis} \n Label{label}",
229
+ )
230
+
231
+ zero_shot_template = (
232
+ f"""{instruction}""" + "\n hypothesis: {hypothesis} + \n Premise: {premise}" ""
233
+ )
234
+
235
+ prompt = (
236
+ FewShotPromptTemplate(
237
+ examples=ic_examples,
238
+ prefix=instruction,
239
+ example_prompt=example_prompt,
240
+ suffix="Premise: {premise} \n Hypothesis: {hypothesis}",
241
+ input_variables=["hypothesis", "premise"],
242
+ )
243
+ if not zero_shot
244
+ else PromptTemplate(
245
+ input_variables=["hypothesis", "premise"], template=zero_shot_template
246
+ )
247
+ )
248
+
249
+ return (
250
+ prompt.format(
251
+ hypothesis=test_example["hypothesis"], premise=test_example["premise"]
252
+ ),
253
+ test_example["label"],
254
+ )
255
+
256
+
257
+ def dump_metrics(
258
+ lang: str,
259
+ config: Dict[str, str],
260
+ r1: float,
261
+ r2: float,
262
+ rL: float,
263
+ metric_logger_path: str,
264
+ ):
265
+ # Check if the metric logger file exists
266
+ file_exists = os.path.exists(metric_logger_path)
267
+
268
+ # Open the CSV file in append mode
269
+ with open(metric_logger_path, "a", newline="") as f:
270
+ csvwriter = csv.writer(f, delimiter=",")
271
+
272
+ # Write header row if the file is newly created
273
+ if not file_exists:
274
+ header = [
275
+ "Language",
276
+ "Prefix",
277
+ "Input",
278
+ "Context",
279
+ "Output",
280
+ "R1",
281
+ "R2",
282
+ "RL",
283
+ ]
284
+ csvwriter.writerow(header)
285
+
286
+ csvwriter.writerow(
287
+ [
288
+ lang,
289
+ config["prefix"],
290
+ config["input"],
291
+ config["context"][0],
292
+ config["output"],
293
+ r1,
294
+ r2,
295
+ rL,
296
+ ]
297
+ )
298
+
299
+
300
+ def dump_predictions(idx, response, label, response_logger_file):
301
+ obj = {"q_idx": idx, "prediction": response, "label": label}
302
+ with open(response_logger_file, "a") as f:
303
+ f.write(json.dumps(obj, ensure_ascii=False) + "\n")
304
+
305
+
306
+ def compute_rouge(scorer, pred, label):
307
+ score = scorer.score(pred, label)
308
+ return score["rouge1"], score["rouge2"], score["rougeL"]
309
+
310
+
311
+ def _translate_instruction(basic_instruction: str, target_language: str) -> str:
312
+ translator = EasyGoogleTranslate(
313
+ source_language="en",
314
+ target_language=LANGUAGE_TO_SUFFIX[target_language],
315
+ timeout=10,
316
+ )
317
+ return translator.translate(basic_instruction)
318
+
319
+
320
+ def _translate_prediction_to_output_language(
321
+ prediction: str, prediction_language: str, output_language: str
322
+ ) -> str:
323
+ translator = EasyGoogleTranslate(
324
+ source_language=LANGUAGE_TO_SUFFIX[prediction_language],
325
+ target_language=LANGUAGE_TO_SUFFIX[output_language],
326
+ timeout=10,
327
+ )
328
+ return translator.translate(prediction)
329
+
330
+
331
+ def create_instruction(lang: str):
332
+ basic_instruction = f"""
333
+ You are an NLP assistant whose purpose is to solve Natural Language Inference (NLI) problems.
334
+ NLI is the task of determining the inference relation between two texts: entailment,
335
+ contradiction, or neutral.
336
+ Your answer should be one word of the following - entailment, contradiction, or neutral.
337
+ Pay attention: The output should be only one word!!!!
338
+ """
339
+ return (
340
+ basic_instruction
341
+ if lang == "english"
342
+ else _translate_instruction(basic_instruction, target_language=lang)
343
+ )
344
+
345
+
346
+ def run_one_configuration(params: Optional[PARAMS] = None, zero: bool= False):
347
+ if not params:
348
+ params = read_parameters("../../parameters.yaml")
349
+
350
+ lang = params["selected_language"]
351
+ config = params["config"]
352
+ zero_shot = len(config["context"]) == 0
353
+
354
+ if not zero:
355
+ config_header = f"{config['input']}_{config['prefix']}_{config['context'][0]}"
356
+ else:
357
+ config_header = f"{config['input']}_{config['prefix']}_zero"
358
+ test_data = load_xnli_dataset(
359
+ dataset_name=params["dataset_name"],
360
+ lang=lang,
361
+ split="test",
362
+ limit=params["limit"],
363
+ )
364
+
365
+ pool = mp.Pool(processes=3)
366
+
367
+ # Iterate over test_data using tqdm for progress tracking
368
+ for idx, test_example in tqdm(enumerate(test_data), total=len(test_data)):
369
+ # Apply asynchronous processing of each test example
370
+ pool.apply_async(
371
+ process_test_example,
372
+ args=(
373
+ test_data,
374
+ config_header,
375
+ idx,
376
+ test_example,
377
+ config,
378
+ zero_shot,
379
+ lang,
380
+ params,
381
+ ),
382
+ )
383
+
384
+ # Close the pool and wait for all processes to finish
385
+ pool.close()
386
+ pool.join()
387
+
388
+ def process_test_example(
389
+ test_data, config_header, idx, test_example, config, zero_shot, lang, params
390
+ ):
391
+ try:
392
+ instruction = create_instruction(lang=config["prefix"])
393
+ text_example = {
394
+ "premise": test_example["premise"],
395
+ "hypothesis": test_example["hypothesis"],
396
+ "label": test_example["label"],
397
+ }
398
+
399
+ ic_examples = []
400
+ if not zero_shot:
401
+ ic_examples = choose_few_shot_examples(
402
+ train_dataset=test_data,
403
+ few_shot_size=len(config["context"]),
404
+ context=config["context"],
405
+ selection_criteria="random",
406
+ lang=params["selected_language"],
407
+ )
408
+
409
+ prompt, label = construct_prompt(
410
+ instruction=instruction,
411
+ test_example=text_example,
412
+ ic_examples=ic_examples,
413
+ zero_shot=zero_shot,
414
+ )
415
+
416
+ pred = get_prediction(prompt=prompt, endpoint_id=7327255438662041600, project_id=16514800572)
417
+ print(pred)
418
+
419
+ os.makedirs(
420
+ f"{params['response_logger_root']}/{params['model']}/{lang}", exist_ok=True
421
+ )
422
+ dump_predictions(
423
+ idx=idx,
424
+ response=pred,
425
+ label=label,
426
+ response_logger_file=f"{params['response_logger_root']}/{params['model']}/{lang}/{config_header}.csv",
427
+ )
428
+
429
+ except Exception as e:
430
+ # Handle exceptions here
431
+ print(f"Error processing example {idx}: {e}")
432
+
433
+
434
+ def construct_prompt(
435
+ instruction: str,
436
+ test_example: dict,
437
+ zero_shot: bool,
438
+ num_examples: int,
439
+ lang: str,
440
+ config: Dict[str, str],
441
+ dataset_name: str = 'xnli'
442
+ ):
443
+
444
+ if not instruction:
445
+ print(lang)
446
+ instruction = create_instruction(lang)
447
+
448
+ example_prompt = PromptTemplate(
449
+ input_variables=["premise", "hypothesis", "label"],
450
+ template="Premise {premise}\n Hypothesis {hypothesis} \n{label}",
451
+ )
452
+
453
+ zero_shot_template = (
454
+ f"""{instruction}""" + "\n Hypothesis: {hypothesis} + \n Premise: {premise}" ""
455
+ )
456
+
457
+ test_data = load_xnli_dataset(dataset_name, lang, split="test", limit=100)
458
+
459
+ print(test_data)
460
+ print(num_examples)
461
+ print(lang)
462
+ ic_examples = []
463
+ if not zero_shot:
464
+
465
+ ic_examples = choose_few_shot_examples(
466
+ train_dataset=test_data,
467
+ few_shot_size=num_examples,
468
+ context=[config["context"]] * num_examples,
469
+ selection_criteria="random",
470
+ lang=lang,
471
+ )
472
+
473
+ prompt = (
474
+ FewShotPromptTemplate(
475
+ examples=ic_examples,
476
+ prefix=instruction,
477
+ example_prompt=example_prompt,
478
+ suffix="{premise} \n{hypothesis}",
479
+ input_variables=["hypothesis", "premise"],
480
+ )
481
+ if not zero_shot
482
+ else PromptTemplate(
483
+ input_variables=["hypothesis", "premise"], template=zero_shot_template
484
+ )
485
+ )
486
+
487
+ print("lang", lang)
488
+ print(config["input"] , lang)
489
+ if config["input"] != lang:
490
+ test_example = _translate_example(
491
+ example=test_example, src_language=lang, target_language=config["input"]
492
+ )
493
+
494
+ return prompt.format(
495
+ hypothesis=test_example["hypothesis"], premise=test_example["premise"]
496
+ )
tasks/qa.py ADDED
@@ -0,0 +1,770 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import json
3
+ import logging
4
+ import multiprocessing as mp
5
+ import os
6
+ import subprocess
7
+ import re
8
+
9
+ import string
10
+ import sys
11
+ import subprocess
12
+ import time
13
+ import unicodedata
14
+ from typing import Any, Dict, List, NewType, Optional, Union
15
+
16
+ import numpy as np
17
+ import openai
18
+ import requests
19
+ import yaml
20
+ from datasets import Dataset, load_dataset
21
+ from easygoogletranslate import EasyGoogleTranslate
22
+ from evaluate import load
23
+ from langchain.prompts import FewShotPromptTemplate, PromptTemplate
24
+ from tqdm import tqdm
25
+ from yaml.loader import SafeLoader
26
+
27
+
28
+ # from models.model_completion import gpt3x_completion, gemini_completion
29
+
30
+ def gemini_completion(prompt):
31
+ # Define the endpoint URL
32
+ genai.configure(api_key="")
33
+ model = genai.GenerativeModel("models/gemini-1.0-pro-latest")
34
+ return model.generate_content(prompt).text
35
+
36
+
37
+ # checkpoint = "bigscience/mt0-base"
38
+ # from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
39
+ #
40
+ # tokenizer = AutoTokenizer.from_pretrained(checkpoint)
41
+ # model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint, torch_dtype="auto", device_map="auto")
42
+ # model.to("cuda:04")
43
+
44
+
45
+
46
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
47
+
48
+
49
+ def get_entities_gpt3_long(prompt):
50
+ response = openai.ChatCompletion.create(
51
+ engine="chatgpt", temperature=0, messages=[{"role": "user", "content": prompt}]
52
+ )
53
+ return response["choices"][0]["message"]["content"]
54
+
55
+
56
+ def gpt3x_completion(
57
+ prompt: Union[str, List[Dict[str, str]]],
58
+ model: str = "chatgpt",
59
+ # run_details: Any = {},
60
+ # num_evals_per_sec: int = 2,
61
+ # **model_params,
62
+ ) -> str:
63
+ import os
64
+ import openai
65
+ os.environ["OPENAI_API_KEY"] = ''
66
+ openai.api_type = "azure"
67
+
68
+ def get_entities_chatGPT(final_prompt):
69
+ response = openai.ChatCompletion.create(
70
+ engine="gpt35-16k",
71
+ temperature=0,
72
+ messages=[
73
+ {"role": "user", "content": final_prompt}
74
+ ]
75
+ )
76
+ return response['choices'][0]['message']['content']
77
+
78
+ return get_entities_chatGPT(final_prompt=prompt)
79
+
80
+
81
+ def mt0_completion(prompt):
82
+ inputs = tokenizer.encode(prompt, return_tensors="pt").to("cuda")
83
+ outputs = model.generate(inputs)
84
+ return tokenizer.decode(outputs[0])
85
+
86
+
87
+ def mixtral_completion(prompt):
88
+ url = "https://api.together.xyz/v1/chat/completions"
89
+
90
+ # Define your Together API key
91
+ together_api_key = "" # Replace with your actual API key
92
+
93
+ # Define the request payload
94
+ payload = {
95
+ "temperature": 0,
96
+ "max_tokens": 30,
97
+ "model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
98
+ "messages": [{"role": "user", "content": f"{prompt}"}],
99
+ }
100
+
101
+ # Define request headers
102
+ headers = {
103
+ "Authorization": f"Bearer {together_api_key}",
104
+ "Content-Type": "application/json",
105
+ }
106
+
107
+ # Send POST request
108
+ response = requests.post(url, json=payload, headers=headers)
109
+
110
+ # Check response status
111
+ if response.status_code == 200:
112
+ # Print the response content (API output)
113
+ return response.json()["choices"][0]["message"]["content"]
114
+ else:
115
+ # Print error message if request fails
116
+ print(f"Error: {response.status_code} - {response.text}")
117
+
118
+
119
+ XQUAD_LANG2CODES = {
120
+ "bengali": "bn",
121
+ "korean": "ko",
122
+ "swahili": "sw",
123
+ "english": "en",
124
+ "indonesian": "id",
125
+ "arabic": "ar",
126
+ "finnish": "fi",
127
+ "telugu": "te",
128
+ "russian": "ru",
129
+ "german": "de",
130
+ "greek": "el",
131
+ "hindi": "hi",
132
+ "vietnamese": "vi",
133
+ "romanian": "ro",
134
+ }
135
+
136
+ INDICQA_LANG2CODES = {
137
+ "indicqa": "as",
138
+ "bengali": "bn",
139
+ "gujarati": "gu",
140
+ "hindi": "hi",
141
+ "kannada": "kn",
142
+ "malayalam": "ml",
143
+ "marathi": "mr",
144
+ "odia": "or",
145
+ "punjabi": "pa",
146
+ "tamil": "ta",
147
+ "telugu": "te",
148
+ "assamese": "as",
149
+ }
150
+
151
+ PUNCT = {
152
+ chr(i)
153
+ for i in range(sys.maxunicode)
154
+ if unicodedata.category(chr(i)).startswith("P")
155
+ }.union(string.punctuation)
156
+ WHITESPACE_LANGS = ["en", "es", "hi", "vi", "de", "ar"]
157
+ MIXED_SEGMENTATION_LANGS = ["zh"]
158
+
159
+ TYDIQA_LANG2CODES = {
160
+ "bengali": "bn",
161
+ "korean": "ko",
162
+ "swahili": "sw",
163
+ "english": "en",
164
+ "indonesian": "id",
165
+ "arabic": "ar",
166
+ "finnish": "fi",
167
+ "telugu": "te",
168
+ "russian": "ru",
169
+ "assamese": "as",
170
+ "persian": "fa",
171
+ }
172
+
173
+ logger = logging.Logger("Xlsum_task")
174
+ LANGUAGE_TO_SUFFIX = {
175
+ "chinese_simplified": "zh-CN",
176
+ "french": "fr",
177
+ "portuguese": "pt",
178
+ "english": "en",
179
+ "arabic": "ar",
180
+ "hindi": "hi",
181
+ "indonesian": "id",
182
+ "amharic": "am",
183
+ "bengali": "bn",
184
+ "telugu": "te",
185
+ "burmese": "my",
186
+ "german": "de",
187
+ "greek": "el",
188
+ "tamil": "ta",
189
+ "assamese": "as",
190
+ "hindi": "hi",
191
+ "vietnamese": "vi",
192
+ "russian": "ru",
193
+ "telugu": "te",
194
+ "romanian": "ro",
195
+ "malayalam": "ml",
196
+ "persian": "fa",
197
+ }
198
+
199
+ PARAMS = NewType("PARAMS", Dict[str, Any])
200
+
201
+
202
+ def read_parameters(args_path) -> PARAMS:
203
+ with open(args_path) as f:
204
+ args = yaml.load(f, Loader=SafeLoader)
205
+ return args
206
+
207
+
208
+ def load_qa_dataset(dataset_name, lang, split, translate_test=False, limit=5):
209
+ if dataset_name == "indicqa":
210
+ if split != "train":
211
+ dataset = load_dataset(
212
+ "ai4bharat/IndicQA", f"indicqa.{INDICQA_LANG2CODES[lang]}"
213
+ )[split]
214
+ else:
215
+ dataset = load_dataset("squad_v2")[split]
216
+ elif dataset_name == "xquad":
217
+ if split != "train":
218
+ dataset = load_dataset("xquad", f"xquad.{XQUAD_LANG2CODES[lang]}")[
219
+ "validation"
220
+ ]
221
+ else:
222
+ dataset = load_dataset("squad")[split]
223
+ elif dataset_name == "tydiqa":
224
+ dataset = load_dataset("tydiqa", "secondary_task")[split]
225
+ dataset = dataset.map(
226
+ lambda example: {"lang": TYDIQA_LANG2CODES[example["id"].split("-")[0]]}
227
+ )
228
+ dataset = dataset.filter(lambda example: example["lang"] == lang)
229
+ elif dataset_name == "mlqa":
230
+ if split == "train":
231
+ print("No Training Data for MLQA, switching to validation!")
232
+ split = "validation"
233
+ if translate_test:
234
+ dataset_name = f"mlqa-translate-test.{lang}"
235
+ else:
236
+ dataset_name = f"mlqa.{lang}.{lang}"
237
+
238
+ dataset = load_dataset("mlqa", dataset_name)[split]
239
+
240
+ else:
241
+ raise NotImplementedError()
242
+ return dataset.select(np.arange(limit))
243
+
244
+
245
+ def construct_prompt(
246
+ instruction: str,
247
+ test_example: dict,
248
+ ic_examples: List[dict],
249
+ zero_shot: bool,
250
+ lang: str,
251
+ config: Any,
252
+ ):
253
+ example_prompt = PromptTemplate(
254
+ input_variables=["context", "question", "answers"],
255
+ template="Context: {context} \n Question: {question} \n " "Answers: {answers}",
256
+ )
257
+
258
+ zero_shot_template = (
259
+ f"""{instruction}""" + " \n <Context>: {context} \n <Question>: {question} " ""
260
+ )
261
+
262
+ prompt = (
263
+ FewShotPromptTemplate(
264
+ examples=ic_examples,
265
+ prefix=instruction,
266
+ example_prompt=example_prompt,
267
+ suffix="<Context>: {context} \n <Question>: {question} \n Answers: ?",
268
+ input_variables=["question", "context"],
269
+ )
270
+ if not zero_shot
271
+ else PromptTemplate(
272
+ input_variables=["question", "context"], template=zero_shot_template
273
+ )
274
+ )
275
+
276
+ label = test_example["answers"]
277
+ if config["input"] != lang:
278
+ test_example = _translate_example(
279
+ example=test_example, src_language=lang, target_language=config["input"]
280
+ )
281
+
282
+ return (
283
+ prompt.format(
284
+ question=test_example["question"], context=test_example["context"]
285
+ ),
286
+ label,
287
+ )
288
+
289
+
290
+ def dump_metrics(
291
+ lang: str, config: Dict[str, str], f1: float, em: float, metric_logger_path: str
292
+ ):
293
+ # Check if the metric logger file exists
294
+ file_exists = os.path.exists(metric_logger_path)
295
+
296
+ # Open the CSV file in append mode
297
+ with open(metric_logger_path, "a", newline="") as f:
298
+ csvwriter = csv.writer(f, delimiter=",")
299
+
300
+ # Write header row if the file is newly created
301
+ if not file_exists:
302
+ header = ["Language", "Prefix", "Input", "Context", "Output", "F1", "Em"]
303
+ csvwriter.writerow(header)
304
+
305
+ csvwriter.writerow(
306
+ [
307
+ lang,
308
+ config["prefix"],
309
+ config["input"],
310
+ config["context"][0],
311
+ config["output"],
312
+ f1,
313
+ em,
314
+ ]
315
+ )
316
+
317
+
318
+ def dump_predictions(idx, response, label, response_logger_file):
319
+ obj = {"q_idx": idx, "prediction": response, "label": label}
320
+ with open(response_logger_file, "a") as f:
321
+ f.write(json.dumps(obj, ensure_ascii=False) + " \n ")
322
+
323
+
324
+ def _translate_instruction(basic_instruction: str, target_language: str) -> str:
325
+ translator = EasyGoogleTranslate(
326
+ source_language="en",
327
+ target_language=LANGUAGE_TO_SUFFIX[target_language],
328
+ timeout=50,
329
+ )
330
+ return translator.translate(basic_instruction)
331
+
332
+
333
+ def _translate_prediction_to_output_language(
334
+ prediction: str, prediction_language: str, output_language: str
335
+ ) -> str:
336
+ translator = EasyGoogleTranslate(
337
+ source_language=LANGUAGE_TO_SUFFIX[prediction_language],
338
+ target_language=LANGUAGE_TO_SUFFIX[output_language],
339
+ timeout=10,
340
+ )
341
+ return translator.translate(prediction)
342
+
343
+
344
+ def create_instruction(lang: str, expected_output: str):
345
+ basic_instruction = (
346
+ "Answer to the <Question> below, based only to the given <Context>, Follow these instructions: \n "
347
+ "1. The answer should include only words from the given context \n "
348
+ "2. The answer must include up to 5 words \n "
349
+ "3. The answer Should be the shortest as possible \n "
350
+ f"4. The answer must be in {expected_output} only!, not another language!!!"
351
+ )
352
+ return (
353
+ basic_instruction
354
+ if expected_output == "english"
355
+ else _translate_instruction(basic_instruction, target_language=lang)
356
+ )
357
+
358
+
359
+ def _translate_example(
360
+ example: Dict[str, str], src_language: str, target_language: str
361
+ ):
362
+ translator = EasyGoogleTranslate(
363
+ source_language=LANGUAGE_TO_SUFFIX[src_language],
364
+ target_language=LANGUAGE_TO_SUFFIX[target_language],
365
+ timeout=30,
366
+ )
367
+ try:
368
+ return {
369
+ "question": translator.translate(example["question"]),
370
+ "context": translator.translate(example["context"][:2000])
371
+ + translator.translate(example["context"][2000:4000])
372
+ + translator.translate(example["context"][4000:6000]),
373
+ "answers": "",
374
+ }
375
+ except Exception as e:
376
+ pass
377
+
378
+ def choose_few_shot_examples(
379
+ train_dataset: Dataset,
380
+ few_shot_size: int,
381
+ context: List[str],
382
+ selection_criteria: str,
383
+ lang: str,
384
+ ) -> List[Dict[str, Union[str, int]]]:
385
+ """Selects few-shot examples from training datasets
386
+
387
+ Args:
388
+ train_dataset (Dataset): Training Dataset
389
+ few_shot_size (int): Number of few-shot examples
390
+ selection_criteria (few_shot_selection): How to select few-shot examples. Choices: [random, first_k]
391
+
392
+ Returns:
393
+ List[Dict[str, Union[str, int]]]: Selected examples
394
+ """
395
+ selected_examples = []
396
+
397
+ example_idxs = []
398
+ if selection_criteria == "first_k":
399
+ example_idxs = list(range(few_shot_size))
400
+ elif selection_criteria == "random":
401
+ example_idxs = (
402
+ np.random.choice(len(train_dataset), size=few_shot_size, replace=True)
403
+ .astype(int)
404
+ .tolist()
405
+ )
406
+
407
+ ic_examples = [
408
+ {
409
+ "question": train_dataset[idx]["question"],
410
+ "context": train_dataset[idx]["context"],
411
+ "answers": train_dataset[idx]["answers"]["text"],
412
+ }
413
+ for idx in example_idxs
414
+ ]
415
+
416
+ for idx, ic_language in enumerate(context):
417
+ (
418
+ selected_examples.append(ic_examples[idx])
419
+ if ic_language == lang
420
+ else (
421
+ selected_examples.append(
422
+ _translate_example(
423
+ example=ic_examples[idx],
424
+ src_language=lang,
425
+ target_language=ic_language,
426
+ )
427
+ )
428
+ )
429
+ )
430
+
431
+ return selected_examples
432
+
433
+
434
+ def normalize_answer(s):
435
+ """Lower text and remove punctuation, articles and extra whitespace."""
436
+
437
+ def remove_articles(text):
438
+ return re.sub(r"\b(a|an|the)\b", " ", text)
439
+
440
+ def white_space_fix(text):
441
+ return " ".join(text.split())
442
+
443
+ def remove_punc(text):
444
+ exclude = set(PUNCT) # set(string.punctuation)
445
+ return "".join(ch for ch in text if ch not in exclude)
446
+
447
+ def lower(text):
448
+ return text.lower()
449
+
450
+ return white_space_fix(remove_articles(remove_punc(lower(s))))
451
+
452
+
453
+ def process_test_example(
454
+ test_data, config_header, idx, test_example, config, zero_shot, lang, params
455
+ ):
456
+ try:
457
+ # Your existing code for processing each test example
458
+ instruction = create_instruction(
459
+ lang=config["prefix"], expected_output=config["output"]
460
+ )
461
+ text_example = {
462
+ "question": test_example["question"],
463
+ "context": test_example["context"],
464
+ "answers": test_example["answers"]["text"],
465
+ }
466
+
467
+ ic_examples = []
468
+ if not zero_shot:
469
+ ic_examples = choose_few_shot_examples(
470
+ train_dataset=test_data,
471
+ few_shot_size=len(config["context"]),
472
+ context=config["context"],
473
+ selection_criteria="random",
474
+ lang=params["selected_language"],
475
+ )
476
+
477
+ prompt, label = construct_prompt(
478
+ instruction=instruction,
479
+ test_example=text_example,
480
+ ic_examples=ic_examples,
481
+ zero_shot=zero_shot,
482
+ lang=lang,
483
+ config=config,
484
+ )
485
+
486
+ print(len(prompt))
487
+ pred = get_prediction(prompt=prompt, endpoint_id=7327255438662041600, project_id=16514800572)
488
+ # pred = mixtral_completion(prompt)
489
+ print(pred)
490
+
491
+ logger.info("Saving prediction to persistent volume")
492
+ os.makedirs(
493
+ f"{params['response_logger_root']}/{params['model']}/{lang}", exist_ok=True
494
+ )
495
+ dump_predictions(
496
+ idx=idx,
497
+ response=pred,
498
+ label=label,
499
+ response_logger_file=f"{params['response_logger_root']}/{params['model']}/{lang}/{config_header}.csv",
500
+ )
501
+ except Exception as e:
502
+ # Handle exceptions here
503
+ print(f"Error processing example {idx}: {e}")
504
+
505
+
506
+ def run_one_configuration(params: Optional[PARAMS] = None):
507
+ if not params:
508
+ params = read_parameters("../../parameters.yaml")
509
+
510
+ lang = params["selected_language"]
511
+ config = params["config"]
512
+ zero_shot = len(config["context"]) == 0
513
+ rouge1, rouge2, rougeL, normalized_ic_examples, batched_predictions = (
514
+ [],
515
+ [],
516
+ [],
517
+ [],
518
+ [],
519
+ )
520
+ config_header = f"{config['input']}_{config['prefix']}_{config['context'][0]}_{config['output']}"
521
+ dataset_name = params["dataset_name"]
522
+ squad_metric = load("squad")
523
+ metric = params["metric"]
524
+ f1_sum = 0
525
+ em_sum = 0
526
+ avg_em = 0
527
+ avg_f1 = 0
528
+ preds = []
529
+ labels = []
530
+ f1s, ems = [], []
531
+
532
+ test_data = load_qa_dataset(
533
+ dataset_name=params["dataset_name"],
534
+ lang=lang,
535
+ split="validation" if params["dataset_name"] == "xquad" else "test",
536
+ limit=params["limit"],
537
+ )
538
+
539
+ for idx, test_example in (pbar := tqdm(enumerate(test_data))):
540
+ try:
541
+ instruction = create_instruction(
542
+ lang=config["prefix"], expected_output=config["output"]
543
+ )
544
+ text_example = {
545
+ "question": test_example["question"],
546
+ "context": test_example["context"],
547
+ "answers": test_example["answers"]["text"],
548
+ }
549
+
550
+ ic_examples = []
551
+ if not zero_shot:
552
+ ic_examples = choose_few_shot_examples(
553
+ train_dataset=test_data,
554
+ few_shot_size=len(config["context"]),
555
+ context=config["context"],
556
+ selection_criteria="random",
557
+ lang=params["selected_language"],
558
+ )
559
+
560
+ prompt, label = construct_prompt(
561
+ instruction=instruction,
562
+ test_example=text_example,
563
+ ic_examples=ic_examples,
564
+ zero_shot=zero_shot,
565
+ lang=lang,
566
+ config=config,
567
+ )
568
+
569
+ pred = mt0_completion(prompt=prompt)
570
+ print(pred)
571
+
572
+ logger.info("Saving prediction to persistent volume")
573
+ os.makedirs(
574
+ f"{params['response_logger_root']}" + f"{params['model']}" + f"/{lang}",
575
+ exist_ok=True,
576
+ )
577
+ dump_predictions(
578
+ idx=idx,
579
+ response=pred,
580
+ label=label,
581
+ response_logger_file=f"{params['response_logger_root']}"
582
+ + f"/{params['model']}"
583
+ + f"/{lang}/"
584
+ + config_header
585
+ + ".csv",
586
+ )
587
+ #
588
+ # normalized_prediction = normalize_answer(pred)
589
+ # batched_predictions.append(normalized_prediction)
590
+ #
591
+ # if config["output"] != params["selected_language"]:
592
+ # pred = _translate_prediction_to_output_language(
593
+ # prediction=normalized_prediction,
594
+ # prediction_language=config["output"],
595
+ # output_language=params["selected_language"],
596
+ # )
597
+ # print(
598
+ # f"Translated the prediciton from {config['output']} to {params['selected_language']}"
599
+ # )
600
+ #
601
+ # logger.info("Starting evaluation")
602
+ #
603
+ # if dataset_name == "xquad":
604
+ # prediction = {"prediction_text": pred, "id": test_example["id"]}
605
+ #
606
+ # reference = {}
607
+ # reference["answers"] = test_example["answers"]
608
+ # reference["id"] = test_example["id"]
609
+ # if reference["answers"]["text"][0] == "":
610
+ # reference["answers"]["text"] = []
611
+ # reference["answers"]["answer_start"] = []
612
+ #
613
+ # if params["metric"] == "squad":
614
+ # results = squad_metric.compute(
615
+ # predictions=[prediction], references=[reference]
616
+ # )
617
+ # else:
618
+ # results = squad_metric.compute(
619
+ # predictions=[prediction],
620
+ # references=[reference],
621
+ # no_answer_threshold=0.9,
622
+ # )
623
+ #
624
+ # f1_sum += results["f1"]
625
+ # if metric == "squad":
626
+ # em_sum += results["exact_match"]
627
+ # else:
628
+ # em_sum += results["exact"]
629
+ # avg_f1 = f1_sum / (idx + 1)
630
+ # avg_em = em_sum / (idx + 1)
631
+ #
632
+ # preds.append(prediction)
633
+ # labels.append(reference)
634
+ # f1s.append(results["f1"])
635
+ # if metric == "squad":
636
+ # ems.append(results["exact_match"])
637
+ # else:
638
+ # ems.append(results["exact"])
639
+
640
+ except Exception as e:
641
+ print(f"Found an exception {e}, continue to the next example")
642
+ continue
643
+
644
+ os.makedirs(f"{params['metrics_root']}" + f"/{params['model']}", exist_ok=True)
645
+
646
+ dump_metrics(
647
+ lang,
648
+ config,
649
+ avg_f1,
650
+ avg_em,
651
+ f"{params['metrics_root']}" + f"/{params['model']}" + f"/{lang}.csv",
652
+ )
653
+
654
+
655
+ # if __name__ == "__main__":
656
+ # run_one_configuration()
657
+
658
+
659
+ def run_one_configuration_paralle(params: Optional[PARAMS] = None, zero: bool = False):
660
+ if not params:
661
+ params = read_parameters("../../parameters.yaml")
662
+
663
+ lang = params["selected_language"]
664
+ config = params["config"]
665
+ zero_shot = len(config["context"]) == 0
666
+ rouge1, rouge2, rougeL, normalized_ic_examples, batched_predictions = (
667
+ [],
668
+ [],
669
+ [],
670
+ [],
671
+ [],
672
+ )
673
+ if not zero:
674
+ config_header = f"{config['input']}_{config['prefix']}_{config['context'][0]}_{config['output']}"
675
+ else:
676
+ config_header = f"{config['input']}_{config['prefix']}_zero_{config['output']}"
677
+ test_data = load_qa_dataset(
678
+ dataset_name=params["dataset_name"],
679
+ lang=lang,
680
+ split="validation" if params["dataset_name"] == "xquad" else "test",
681
+ limit=params["limit"],
682
+ )
683
+
684
+ # Initialize multiprocessing poosl
685
+ num_processes = mp.cpu_count() # Use number of available CPU cores
686
+ pool = mp.Pool(processes=10)
687
+
688
+ # Iterate over test_data using tqdm for progress tracking
689
+ for idx, test_example in tqdm(enumerate(test_data), total=len(test_data)):
690
+ # Apply asynchronous processing of each test example
691
+ pool.apply_async(
692
+ process_test_example,
693
+ args=(
694
+ test_data,
695
+ config_header,
696
+ idx,
697
+ test_example,
698
+ config,
699
+ zero_shot,
700
+ lang,
701
+ params,
702
+ ),
703
+ )
704
+
705
+ # Close the pool and wait for all processes to finish
706
+ pool.close()
707
+ pool.join()
708
+
709
+
710
+
711
+ def construct_prompt(
712
+ instruction: str,
713
+ test_example: dict,
714
+ zero_shot: bool,
715
+ num_examples: int,
716
+ lang: str,
717
+ config: Dict[str, str],
718
+ dataset_name: str = 'xquad'
719
+ ):
720
+ if not instruction:
721
+ instruction = create_instruction(lang, config['prefix'])
722
+
723
+ example_prompt = PromptTemplate(
724
+ input_variables=["context", "question", "answers"],
725
+ template="Context: {context} \n Question: {question} \n " "Answers: {answers}",
726
+ )
727
+
728
+ zero_shot_template = (
729
+ f"""{instruction}""" + " \n <Context>: {context} \n <Question>: {question} " ""
730
+ )
731
+
732
+ test_data = load_qa_dataset(dataset_name = dataset_name, lang=lang, split="test", limit=100)
733
+
734
+ print(test_data)
735
+ print(num_examples)
736
+ print(lang)
737
+ ic_examples = []
738
+ if not zero_shot:
739
+
740
+ ic_examples = choose_few_shot_examples(
741
+ train_dataset=test_data,
742
+ few_shot_size=num_examples,
743
+ context=[config["context"]] * num_examples,
744
+ selection_criteria="random",
745
+ lang=lang,
746
+ )
747
+
748
+ prompt = (
749
+ FewShotPromptTemplate(
750
+ examples=ic_examples,
751
+ prefix=instruction,
752
+ example_prompt=example_prompt,
753
+ suffix="<Context>: {context} \n <Question>: {question} \n Answers: ?",
754
+ input_variables=["question", "context"],
755
+ )
756
+ if not zero_shot
757
+ else PromptTemplate(
758
+ input_variables=["question", "context"], template=zero_shot_template
759
+ )
760
+ )
761
+ print("lang", lang)
762
+ print(config["input"] , lang)
763
+ if config["input"] != lang:
764
+ test_example = _translate_example(
765
+ example=test_example, src_language=lang, target_language=config["input"]
766
+ )
767
+
768
+ return prompt.format(
769
+ question=test_example["question"], context=test_example["context"]
770
+ )
tasks/summarization.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict, Optional, Union
2
+
3
+ import numpy as np
4
+ from datasets import Dataset, load_dataset
5
+ from easygoogletranslate import EasyGoogleTranslate
6
+ from langchain.prompts import PromptTemplate, FewShotPromptTemplate
7
+
8
+ LANGUAGE_TO_SUFFIX = {
9
+ "chinese_simplified": "zh-CN",
10
+ "french": "fr",
11
+ "portuguese": "pt",
12
+ "english": "en",
13
+ "arabic": "ar",
14
+ "hindi": "hi",
15
+ "indonesian": "id",
16
+ "amharic": "am",
17
+ "bengali": "bn",
18
+ "burmese": "my",
19
+ "uzbek": "uz",
20
+ "nepali": "ne",
21
+ "japanese": "ja",
22
+ "spanish": "es",
23
+ "turkish": "tr",
24
+ "persian": "fa",
25
+ "azerbaijani": "az",
26
+ "korean": "ko",
27
+ }
28
+
29
+ def choose_few_shot_examples(
30
+ train_dataset: Dataset, few_shot_size: int, context: List[str], selection_criteria: str, lang: str,
31
+ ) -> List[Dict[str, Union[str, int]]]:
32
+
33
+ selected_examples = []
34
+
35
+ example_idxs = []
36
+ if selection_criteria == "first_k":
37
+ example_idxs = list(range(few_shot_size))
38
+ elif selection_criteria == "random":
39
+ example_idxs = (
40
+ np.random.choice(len(train_dataset), size=few_shot_size, replace=True)
41
+ .astype(int)
42
+ .tolist()
43
+ )
44
+
45
+ ic_examples = [{'text': train_dataset[idx]['text'], 'summary': train_dataset[idx]['summary']} for idx in
46
+ example_idxs]
47
+
48
+ for idx, ic_language in enumerate(context):
49
+ selected_examples.append(ic_examples[idx]) if ic_language == lang else (
50
+ selected_examples.append(
51
+ _translate_example(example=ic_examples[idx], src_language=lang, target_language=ic_language)))
52
+
53
+ return selected_examples
54
+
55
+
56
+ def _translate_instruction(basic_instruction: str, target_language: str) -> str:
57
+ translator = EasyGoogleTranslate(
58
+ source_language="en",
59
+ target_language=LANGUAGE_TO_SUFFIX[target_language],
60
+ timeout=50,
61
+ )
62
+ return translator.translate(basic_instruction)
63
+
64
+
65
+ def _translate_example(example: Dict[str, str], src_language: str, target_language: str):
66
+ translator = EasyGoogleTranslate(source_language=LANGUAGE_TO_SUFFIX[src_language],
67
+ target_language=LANGUAGE_TO_SUFFIX[target_language],
68
+ timeout=30)
69
+ try:
70
+ return {'text': translator.translate(example['text']), 'summary': ''}
71
+ except Exception as e:
72
+ print(e)
73
+
74
+
75
+ def create_instruction(lang: str, expected_output: str):
76
+ basic_instruction = (
77
+ f"Write a summary of the given <Text> \n The output should be in {expected_output} "
78
+ f"\n The output must be up to 2 sentences maximum!!!"
79
+ )
80
+ print(lang)
81
+ return (
82
+ basic_instruction
83
+ if expected_output == "english"
84
+ else _translate_instruction(basic_instruction, target_language=lang)
85
+ )
86
+
87
+
88
+ def load_xlsum_data(lang, split, limit = 5):
89
+ """Loads the xlsum dataset"""
90
+ dataset = load_dataset("csebuetnlp/xlsum", lang)[split]
91
+ return dataset.select(range(limit))
92
+
93
+
94
+ def construct_prompt(
95
+ instruction: str,
96
+ test_example: dict,
97
+ zero_shot: bool,
98
+ dataset: str,
99
+ num_examples: int,
100
+ lang: str,
101
+ config: Dict[str, str],
102
+ ):
103
+ if not instruction:
104
+ print(lang)
105
+ instruction = create_instruction(lang, config['prefix'])
106
+
107
+ example_prompt = PromptTemplate(
108
+ input_variables=["summary", "text"], template="Text: {text}\nSummary: {summary}"
109
+ )
110
+
111
+ zero_shot_template = f"""{instruction}""" + "\n Input: {text} " ""
112
+
113
+ test_data = load_xlsum_data(lang=lang, split="test", limit=100)
114
+
115
+ print(test_data)
116
+ print(num_examples)
117
+ print(lang)
118
+ ic_examples = []
119
+ if not zero_shot:
120
+
121
+ ic_examples = choose_few_shot_examples(
122
+ train_dataset=test_data,
123
+ few_shot_size=num_examples,
124
+ context=[config["context"]] * num_examples,
125
+ selection_criteria="random",
126
+ lang=lang,
127
+ )
128
+
129
+ prompt = (
130
+ FewShotPromptTemplate(
131
+ examples=ic_examples,
132
+ prefix=instruction,
133
+ example_prompt=example_prompt,
134
+ suffix="<Text>: {text}",
135
+ input_variables=["text"],
136
+ )
137
+ if not zero_shot
138
+ else PromptTemplate(input_variables=["text"], template=zero_shot_template)
139
+ )
140
+
141
+ print("lang", lang)
142
+ print(config["input"] , lang)
143
+ if config["input"] != lang:
144
+ test_example = _translate_example(
145
+ example=test_example, src_language=lang, target_language=config["input"]
146
+ )
147
+
148
+ print("test_example", prompt)
149
+ return prompt.format(text=test_example["text"])