# app.py
import gradio as gr
import os
import pandas as pd
import polars as pl  # needed by convert_parquet_to_jsonl_polars below
import requests
from pathlib import Path
import ctranslate2
import time
import logging
import transformers
import json
import io
from tqdm import tqdm
import subprocess
from huggingface_hub import snapshot_download, upload_file, HfApi, create_repo

# Function to download a Parquet file from a specified URL
def download_parquet(url, local_path):
    response = requests.get(url, stream=True)
    if response.status_code == 200:
        with open(local_path, 'wb') as file:
            for chunk in response.iter_content(chunk_size=1024):
                file.write(chunk)
        print("File downloaded successfully.")
    else:
        print(f"Failed to download file, status code: {response.status_code}")

# Function to convert Parquet files to JSONL format
def convert_parquet_to_jsonl_polars(input_file, output_dir, override=False):
    output_dir_path = Path(output_dir)
    output_dir_path.mkdir(parents=True, exist_ok=True)
    
    input_path = Path(input_file)
    output_file_path = output_dir_path / input_path.with_suffix(".jsonl").name

    if output_file_path.exists() and not override:
        print(f"Skipping because output exists already: {output_file_path}")
    else:
        df = pl.read_parquet(input_path)
        df.write_ndjson(output_file_path)
        print(f"Data written to {output_file_path}")

def convert_parquet_to_jsonl(parquet_filename, output_dir):
    try:
        # Read the parquet file
        df = pd.read_parquet(parquet_filename)
        logger.info(f"Read Parquet file {parquet_filename} successfully.")

        # Convert the dataframe to a JSON string and handle Unicode characters and forward slashes
        json_str = df.to_json(orient='records', lines=True, force_ascii=False)
        logger.info(f"Converted Parquet file to JSON string.")

        # Replace escaped forward slashes if needed
        json_str = json_str.replace('\\/', '/')

        # Write the modified JSON string to train.jsonl inside the output directory
        jsonl_filename = f"{output_dir}/train.jsonl"
        logger.info(f"Attempting to save to {jsonl_filename}")
        with open(jsonl_filename, 'w', encoding='utf-8') as file:
            file.write(json_str)
        logger.info(f"Data saved to {jsonl_filename}")
    except Exception as e:
        logger.error(f"Failed to convert Parquet to JSONL: {e}")
        raise

# Function to count lines in a JSONL file
def count_lines_in_jsonl(file_path):
    with open(file_path, 'r', encoding='utf-8') as file:
        line_count = sum(1 for _ in file)
    return line_count

def parse_range_specification(range_specification, file_length):
    line_indices = []
    ranges = range_specification.split(',')
    for r in ranges:
        if '-' in r:
            parts = r.split('-')
            start = int(parts[0]) - 1 if parts[0] else 0
            end = int(parts[1]) - 1 if parts[1] else file_length - 1
            if start < 0 or end >= file_length:
                logging.error(f"Range {r} is out of bounds.")
                continue  # Skip ranges that are out of bounds
            line_indices.extend(range(start, end + 1))
        else:
            single_line = int(r) - 1
            if single_line < 0 or single_line >= file_length:
                logging.error(f"Line number {r} is out of bounds.")
                continue  # Skip line numbers that are out of bounds
            line_indices.append(single_line)
    return line_indices
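# For illustration (not part of the app's execution path): parse_range_specification accepts
# 1-based, comma-separated ranges and returns 0-based line indices, e.g.
#   parse_range_specification("1-3,7", file_length=10)  ->  [0, 1, 2, 6]
#   parse_range_specification("5-", file_length=8)      ->  [4, 5, 6, 7]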

def translate_text(text, translator, tokenizer, target_language):
    """
    Translates the given text from English to the target language using CTranslate2 and the WMT21 model,
    with special handling for newlines and segmentation of text longer than 500 characters.
    Ensures sequences of newlines (\n\n, \n\n\n, etc.) are reproduced at their original positions.
    """
    try:
        # Split the text into an ordered list of parts, each either a newline sequence or a text
        # segment, so the original layout can be restored around the translated segments.
        parts = []  # list of ('newline', sequence) or ('text', segment) tuples in original order
        segment = ""

        i = 0
        while i < len(text):
            if text[i] == '\n':
                # Collect a run of consecutive newlines
                newline_sequence = '\n'
                while i + 1 < len(text) and text[i + 1] == '\n':
                    newline_sequence += '\n'
                    i += 1
                if segment:
                    parts.append(('text', segment))  # Flush the preceding text segment
                    segment = ""
                parts.append(('newline', newline_sequence))
            else:
                segment += text[i]
                # Once the segment reaches 500 characters, flush it, preferably at the last
                # sentence-ending punctuation so segments remain translatable units
                if len(segment) >= 500:
                    end_index = max(segment.rfind('.'), segment.rfind('?'), segment.rfind('!'))
                    if end_index != -1 and end_index + 1 < len(segment):
                        parts.append(('text', segment[:end_index + 1]))
                        segment = segment[end_index + 1:].lstrip()
                    else:
                        parts.append(('text', segment))
                        segment = ""
            i += 1
        if segment:
            parts.append(('text', segment))  # Flush any trailing text

        # Translate the text parts and reassemble everything in the original order
        translated_text = ""
        for kind, content in parts:
            if kind == 'newline':
                translated_text += content
                continue
            source = tokenizer.convert_ids_to_tokens(tokenizer.encode(content))
            target_prefix = [tokenizer.lang_code_to_token[target_language]]
            results = translator.translate_batch([source], target_prefix=[target_prefix])
            target = results[0].hypotheses[0][1:]
            translated_segment = tokenizer.decode(tokenizer.convert_tokens_to_ids(target))
            # Keep a separating space between consecutive translated segments of the same paragraph
            if translated_text and not translated_text.endswith(('\n', ' ')):
                translated_text += " "
            translated_text += translated_segment

        return translated_text.strip()

    except Exception as e:
        logging.error(f"An error occurred during translation: {e}")
        return None
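# Illustrative usage of translate_text (assumes a translator and tokenizer loaded as in main()
# below; the model path and the output shown are hypothetical):
#   translator = ctranslate2.Translator("path/to/wmt21ct2_int8", device="auto")
#   tokenizer = transformers.AutoTokenizer.from_pretrained("facebook/wmt21-dense-24-wide-en-x")
#   tokenizer.src_lang = "en"
#   translate_text("Hello world.\n\nHow are you?", translator, tokenizer, "de")
#   # -> "Hallo Welt.\n\nWie geht es dir?"  (example output)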

def translate_item_ufb(item, raw_file_path, translator, tokenizer, target_language):
    try:
        # Translate the prompt directly since it's a string
        translated_prompt = translate_text(item['prompt'], translator, tokenizer, target_language)

        # Translate the chosen and rejected contents
        translated_chosen = []
        for choice in item['chosen']:
            translated_content = translate_text(choice['content'], translator, tokenizer, target_language)
            translated_chosen.append({'content': translated_content, 'role': choice['role']})

        translated_rejected = []
        for choice in item['rejected']:
            translated_content = translate_text(choice['content'], translator, tokenizer, target_language)
            translated_rejected.append({'content': translated_content, 'role': choice['role']})

        # Write the raw response to a backup file
        with open(raw_file_path, 'a', encoding='utf-8') as raw_file:
            raw_file.write(f"Prompt: {translated_prompt}\n")
            raw_file.write(f"Chosen: {json.dumps(translated_chosen, ensure_ascii=False)}\n")
            raw_file.write(f"Rejected: {json.dumps(translated_rejected, ensure_ascii=False)}\n\n")

        logging.info("Translation request successful.")
        # Update the original item with the translated fields
        item['prompt'] = translated_prompt
        item['chosen'] = translated_chosen
        item['rejected'] = translated_rejected
        return item

    except Exception as e:
        logging.error(f"An error occurred during translation: {e}")
        return None

def validate_item_ufb(item):
    # Check basic required fields including 'prompt' as a simple string
    required_fields = ['source', 'prompt', 'chosen', 'rejected']
    for field in required_fields:
        if field not in item:
            logging.warning(f"Missing required field: {field}")
            return False
        if field == 'prompt' and not isinstance(item['prompt'], str):
            logging.warning("Prompt must be a string.")
            return False

    # Check 'chosen' and 'rejected' which should be lists of dictionaries
    for field in ['chosen', 'rejected']:
        if not isinstance(item[field], list) or not item[field]:
            logging.warning(f"No entries or incorrect type for section: {field}")
            return False
        for idx, message in enumerate(item[field]):
            if 'content' not in message or 'role' not in message:
                logging.warning(f"Missing 'content' or 'role' field in {field} at index {idx}")
                return False
            if not isinstance(message['content'], str) or not isinstance(message['role'], str):
                logging.warning(f"Invalid type for 'content' or 'role' field in {field} at index {idx}")
                return False

    return True
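# Example of an item that passes validate_item_ufb (illustrative record, not taken from a real dataset):
#   {"source": "example", "prompt": "What is 2+2?",
#    "chosen":   [{"role": "user", "content": "What is 2+2?"}, {"role": "assistant", "content": "4"}],
#    "rejected": [{"role": "user", "content": "What is 2+2?"}, {"role": "assistant", "content": "5"}]}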


    
def translate_item_mix(item, raw_file_path, translator, tokenizer, target_language):
    """
    Translates the relevant fields in the given item to the target language using CTranslate2 and the WMT21 model,
    and saves the raw response to a backup file.
    """
    #print ("translating:", item)
    try:
        # Translate each part of the prompt separately and preserve the order
        translated_prompts = []
        for message in item['prompt']:
            translated_content = translate_text(message['content'], translator, tokenizer, target_language)
            translated_prompts.append({'content': translated_content, 'role': message['role']})

        # Translate the chosen and rejected contents
        translated_chosen_content = translate_text(item['chosen'][0]['content'], translator, tokenizer, target_language)
        translated_rejected_content = translate_text(item['rejected'][0]['content'], translator, tokenizer, target_language)
        
        # Write the raw response to a backup file
        with open(raw_file_path, 'a', encoding='utf-8') as raw_file:
            raw_file.write("Prompt content:\n")
            for translated_prompt in translated_prompts:
                raw_file.write(f"{translated_prompt['role']}: {translated_prompt['content']}\n")
            raw_file.write(f"Chosen content: {translated_chosen_content}\n")
            raw_file.write(f"Rejected content: {translated_rejected_content}\n\n")
        
        logging.info("Translation request successful.")
    except Exception as e:
        logging.error(f"An error occurred during translation: {e}")
        return None
    
    # Update the original item with the translated fields
    item['prompt'] = translated_prompts
    item['chosen'][0]['content'] = translated_chosen_content
    item['rejected'][0]['content'] = translated_rejected_content
    
    logging.info("Translation processing successful.")
    return item

def validate_item_mix(item):
    """
    Validates the structure, presence, and content of required fields in the given item,
    allowing for multiple elements in the 'prompt' field for multi-turn conversations.
    """
    required_fields = ['dataset', 'prompt', 'chosen', 'rejected']
    for field in required_fields:
        if field not in item:
            logging.warning(f"Missing required field: {field}")
            return False
    
    # Check for at least one element in 'prompt' and exactly one element in 'chosen' and 'rejected'
    if len(item['prompt']) < 1 or len(item['chosen']) != 1 or len(item['rejected']) != 1:
        logging.warning("Invalid number of elements in 'prompt', 'chosen', or 'rejected' field.")
        return False
    
    # Validate 'content' and 'role' fields in all messages of 'prompt', and single elements of 'chosen' and 'rejected'
    for choice in item['prompt'] + item['chosen'] + item['rejected']:
        if 'content' not in choice or 'role' not in choice:
            logging.warning("Missing 'content' or 'role' field in choice.")
            return False
        if not isinstance(choice['content'], str) or not isinstance(choice['role'], str):
            logging.warning("Invalid type for 'content' or 'role' field in choice.")
            return False
    
    return True
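# Example of an item that passes validate_item_mix (illustrative record, not taken from a real dataset):
#   {"dataset": "example",
#    "prompt":   [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"},
#                 {"role": "user", "content": "Tell me a joke."}],
#    "chosen":   [{"role": "assistant", "content": "Why did the chicken cross the road? ..."}],
#    "rejected": [{"role": "assistant", "content": "No."}]}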

def translate_item_ufb_cached(item, raw_file_path, translator, tokenizer, target_language):
    try:
        translated_texts = {}  # Cache to store translated texts
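        # Note: this cache is local to each call, so it only deduplicates repeated content within a
        # single item (e.g. identical turns shared between 'chosen' and 'rejected'); it does not
        # persist across items.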

        # Translate the prompt if necessary (which is a user input and can appear again)
        if item['prompt'] not in translated_texts:
            translated_prompt = translate_text(item['prompt'], translator, tokenizer, target_language)
            translated_texts[item['prompt']] = translated_prompt
        else:
            translated_prompt = translated_texts[item['prompt']]

        # Helper function to handle content translation with caching
        def get_translated_content(content):
            if content not in translated_texts:
                translated_texts[content] = translate_text(content, translator, tokenizer, target_language)
            return translated_texts[content]

        # Process translations for chosen and rejected sections
        def translate_interactions(interactions):
            translated_interactions = []
            for interaction in interactions:
                translated_content = get_translated_content(interaction['content'])
                translated_interactions.append({'content': translated_content, 'role': interaction['role']})
            return translated_interactions

        translated_chosen = translate_interactions(item['chosen'])
        translated_rejected = translate_interactions(item['rejected'])

        # Write the raw response to a backup file
        with open(raw_file_path, 'a', encoding='utf-8') as raw_file:
            raw_file.write(f"Prompt: {translated_prompt}\n")
            raw_file.write(f"Chosen: {json.dumps(translated_chosen, ensure_ascii=False)}\n")
            raw_file.write(f"Rejected: {json.dumps(translated_rejected, ensure_ascii=False)}\n\n")

        logging.info("Translation request successful.")
        # Update the original item with the translated fields
        item['prompt'] = translated_prompt
        item['chosen'] = translated_chosen
        item['rejected'] = translated_rejected
        return item

    except Exception as e:
        logging.error(f"An error occurred during translation: {e}")
        return None

def validate_item_ufb_cached(item):
    # Check basic required fields
    required_fields = ['source', 'prompt', 'chosen', 'rejected']
    for field in required_fields:
        if field not in item:
            logging.warning(f"Missing required field: {field}")
            return False

    # Ensure 'prompt' is a string
    if not isinstance(item['prompt'], str):
        logging.warning("Prompt must be a string.")
        return False

    # Check 'chosen' and 'rejected' which should be lists of dictionaries
    for field in ['chosen', 'rejected']:
        if not isinstance(item[field], list) or not item[field]:
            logging.warning(f"No entries or incorrect type for section: {field}")
            return False
        for idx, message in enumerate(item[field]):
            if 'content' not in message or 'role' not in message:
                logging.warning(f"Missing 'content' or 'role' field in {field} at index {idx}")
                return False
            if not isinstance(message['content'], str) or not isinstance(message['role'], str):
                logging.warning(f"Invalid type for 'content' or 'role' field in {field} at index {idx}")
                return False

    return True
    
def process_file(input_file_path, output_file_path, raw_file_path, line_indices, translator, tokenizer, model_type, target_language):
    try:
        # Assigning validation and translation functions based on model_type
        if model_type == "mix":
            print ("translating a mix-style model...")
            validate_item = validate_item_mix
            translate_item = translate_item_mix
        elif model_type == "ufb_cached":
            print ("translating an ufb_cached-style model...")
            validate_item = validate_item_ufb_cached
            translate_item = translate_item_ufb_cached # def translate_item_ufb(item, raw_file_path, translator, tokenizer):
        elif model_type == "ufb":
            print ("translating an ultrafeedback-style model...")
            validate_item = validate_item_ufb
            translate_item = translate_item_ufb # def translate_item_ufb(item, raw_file_path, translator, tokenizer):
        else:
            raise ValueError(f"Unsupported model_type: {model_type}")

        with open(input_file_path, 'r', encoding='utf-8') as file:
            data_points = [json.loads(line) for line in file]

        failed_items = []
        failed_items_indices = []

        for index in tqdm(line_indices, desc="Processing lines", unit="item"):
            item = data_points[index]

            # Validate the item structure
            if not validate_item(item):
                logging.warning("Skipping item due to invalid structure.")
                failed_items.append(item)
                continue

            # Translate the relevant fields in the item
            translated_item = None
            retry_count = 0
            while translated_item is None and retry_count < 3:
                print ("going to translate the item...")
                translated_item = translate_item(item, raw_file_path, translator, tokenizer, target_language)
                retry_count += 1
                if translated_item is None:
                    logging.warning(f"Translation failed for item. Retry attempt: {retry_count}")
                    time.sleep(1)
            
            if translated_item is not None:
                translated_item['index'] = index
                with open(output_file_path, 'a', encoding='utf-8') as file:
                    file.write(json.dumps(translated_item, ensure_ascii=False) + "\n")
            else:
                failed_items_indices.append(index)
                failed_items.append(item)
                logging.error("Translation failed after multiple attempts. Skipping item.")
                continue  # Nothing to validate if translation never succeeded

            # Validate the translated item structure
            if not validate_item(translated_item):
                logging.warning("Translated item has an invalid structure; recording it as failed.")
                failed_items.append(item)
                continue
        
        with open('failed_items.jsonl', 'w', encoding='utf-8') as file:
            for item in failed_items:
                file.write(json.dumps(item, ensure_ascii=False) + "\n")

        failed_items_str = generate_failed_items_str(failed_items_indices)
        with open('failed_items_index.txt', 'w', encoding='utf-8') as f:
            f.write(failed_items_str)
        
        logging.info("Translation completed successfully.")

    except Exception as e:
        logging.error(f"An error occurred: {e}")

def generate_failed_items_str(indices):
    """
    Converts a list of failed item indices into a string.
    """
    if not indices:
        return ""

    # Sort the list of indices and initialize the first range
    indices.sort()
    range_start = indices[0]
    current = range_start
    ranges = []

    for i in indices[1:]:
        if i == current + 1:
            current = i
        else:
            if range_start == current:
                ranges.append(f"{range_start}")
            else:
                ranges.append(f"{range_start}-{current}")
            range_start = current = i

    # Add the last range
    if range_start == current:
        ranges.append(f"{range_start}")
    else:
        ranges.append(f"{range_start}-{current}")

    return ",".join(ranges)

# Function to upload the output file to Hugging Face
def upload_output_to_huggingface(output_file_path, repo_name, token):
    api = HfApi()
    
    # Check if the repository exists
    try:
        print ("checking repo:", repo_name)
        api.repo_info(repo_id=repo_name, repo_type="dataset", token=token)
    except Exception as e:
        if "404" in str(e):
            # Create the repository if it doesn't exist
            print ("creating it...")
            create_repo(repo_id=repo_name, repo_type="dataset", token=token)
            print(f"Created repository: {repo_name}")
        else:
            print(f"Failed to check repository existence: {e}")
            return

    # Upload the file to the repository
    try:
        print ("starting dataset upload from:", output_file_path)
        upload_file(
            path_or_fileobj=output_file_path,
            path_in_repo=output_file_path,
            repo_id=repo_name,
            repo_type="dataset",
            token=token
        )
        print(f"Uploaded {output_file_path} to Hugging Face repository: {repo_name}")
    except Exception as e:
        print(f"Failed to upload {output_file_path} to Hugging Face: {e}")
        raise

def translate_dataset(train_url, local_parquet_path, input_file_path, output_file_path, raw_file_path, range_specification, model_type, output_dir, output_repo_name, token, translator, tokenizer, target_language):
    try:
        # Download the Parquet file
        download_parquet(train_url, local_parquet_path)
    except Exception as e:
        logging.error(f"Failed to download the Parquet file from {train_url}: {e}")
        return

    try:
        # Convert the downloaded Parquet file to JSONL
        convert_parquet_to_jsonl(local_parquet_path, output_dir)
    except Exception as e:
        logging.error(f"Failed to convert Parquet to JSONL: {e}")
        return

    try:
        # Rename the JSONL file using subprocess to ensure correct handling
        subprocess.run(["mv", f"{output_dir}/train.jsonl", input_file_path], check=True)
    except subprocess.CalledProcessError as e:
        logging.error(f"Failed to rename the file from 'train.jsonl' to {input_file_path}: {e}")
        return

    try:
        # Count lines in the JSONL file to validate contents
        line_count = count_lines_in_jsonl(input_file_path)
        logging.info(f"Number of lines in the file: {line_count}")
    except Exception as e:
        logging.error(f"Failed to count lines in {input_file_path}: {e}")
        return

    try:
        # Parse the range specification for processing specific lines
        line_indices = parse_range_specification(range_specification, file_length=line_count)
        if not line_indices:
            logging.error("No valid line indices to process. Please check the range specifications.")
            return
    except Exception as e:
        logging.error(f"Error parsing range specification '{range_specification}': {e}")
        return

    try:
        # Process the file with specified model type and line indices
        process_file(input_file_path, output_file_path, raw_file_path, line_indices, translator, tokenizer, model_type, target_language)
    except Exception as e:
        logging.error(f"Failed to process the file {input_file_path}: {e}")
        return

    try:
        # Upload the output file to Hugging Face repository
        upload_output_to_huggingface(output_file_path, output_repo_name, token)
    except Exception as e:
        logging.error(f"Failed to upload {output_file_path} to Hugging Face: {e}")  

# Setup logging configuration
log_stream = io.StringIO()
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s',
                    handlers=[
                        logging.FileHandler("translation.log", mode='a'),
                        logging.StreamHandler(log_stream)
                    ])
logger = logging.getLogger(__name__)

# Main function to handle the translation workflow
def main(dataset_url, model_type, output_dataset_name, range_specification, target_language, token: gr.OAuthToken | None, profile: gr.OAuthProfile | None):
    try:
        # Login to Hugging Face
        if token is None or profile is None or token.token is None or profile.username is None:
            return "### You must be logged in to use this service."
        
        if token:
            logger.info("Logged in to Hugging Face")

            # Configuration and paths
            tokenizer_name = "facebook/wmt21-dense-24-wide-en-x"
            model_repo_name = "cstr/wmt21ct2_int8"  # Repository to download the model from

            # Download the model snapshot from Hugging Face
            model_path = snapshot_download(repo_id=model_repo_name, token=token.token)
            logger.info(f"Model downloaded to: {model_path}")

            # Load the CTranslate2 model
            translator = ctranslate2.Translator(model_path, device="auto")
            logger.info("CTranslate2 model loaded successfully.")

            # Load the tokenizer
            tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name)
            tokenizer.src_lang = "en"
            tokenizer.tgt_lang = target_language  # Set target language
            logger.info("Tokenizer loaded successfully.")

            # Define the task based on user input
            task = {
                "url": dataset_url,
                "local_path": "train.parquet",
                "input_file": f"{model_type}_en.jsonl",
                "output_file": f"{model_type}_{target_language}.jsonl",  # Include target language in the filename
                "raw_file": f"{model_type}_{target_language}_raw.jsonl",
                "range_spec": range_specification,
                "model_type": model_type,
                "target_language": target_language  # Include target language in the task
            }

            # Call the translate_dataset function with the provided parameters
            translate_dataset(
                train_url=task["url"],
                local_parquet_path=task["local_path"],
                input_file_path=task["input_file"],
                output_file_path=task["output_file"],
                output_dir=".",
                output_repo_name=output_dataset_name,
                raw_file_path=task["raw_file"],
                token=token.token,
                range_specification=task["range_spec"],
                model_type=task["model_type"],
                translator=translator,
                tokenizer=tokenizer,
                target_language=task["target_language"]  # Pass the target language
            )
            logger.info("Dataset translation completed!")
            return "Dataset translation completed!\n\n### Logs:\n" + log_stream.getvalue()
        else:
            return "Login failed. Please try again."
    except Exception as e:
        logger.error(f"An error occurred in the main function: {e}")
        return f"An error occurred: {e}\n\n### Logs:\n{log_stream.getvalue()}"


# Gradio interface setup
gradio_title = "🧐 WMT21 Dataset Translation"
gradio_desc = """This tool translates English datasets using the WMT21 translation model.
## 💭 What Does This Tool Do:
- Translates datasets (provided as Parquet files) with structures based on the selected dataset type (see below).
- The translation model (facebook/wmt21-dense-24-wide-en-x) supports the following target languages: Hausa (ha), Icelandic (is), Japanese (ja), Czech (cs), Russian (ru), Chinese (zh), German (de).
- Uploads the translated dataset as JSONL to Hugging Face.
- At the moment, this runs on CPU only and is therefore very slow."""
datasets_desc = """## 📊 Dataset Types:
Note: additional fields are kept (untranslated), and an index field is added to each record, which makes it easier to verify the results.
- **mix**:
  - `prompt`: List of dictionaries with 'content' and 'role' fields (multi-turn conversation).
  - `chosen`: List containing a single dictionary with 'content' and 'role' fields.
  - `rejected`: List containing a single dictionary with 'content' and 'role' fields.
- **ufb_cached**:
  - `prompt`: String (user input).
  - `chosen`: List of dictionaries with 'content' and 'role' fields.
  - `rejected`: List of dictionaries with 'content' and 'role' fields.
- **ufb**:
  - Like ufb_cached, but without checking for already-translated strings.
## 🛠️ Backend:
The translation model is an int8-quantized version of facebook/wmt21-dense-24-wide-en-x (hosted on the Hugging Face Hub) and runs via CTranslate2."""

# Define the theme
theme = gr.themes.Soft(text_size="lg", spacing_size="lg")

with gr.Blocks(theme=theme) as demo:
    gr.HTML(f"""<h1 align="center" id="space-title">{gradio_title}</h1>""")
    gr.Markdown(gradio_desc)

    with gr.Row(variant="panel"):
        gr.Markdown(value="## 🚀 Login to Hugging Face")
        gr.LoginButton(min_width=380)

    gr.Markdown(value="🚨 **This is needed to upload the resulting dataset.**")

    with gr.Row(equal_height=False):
        with gr.Column():
            dataset_url = gr.Textbox(label="Input Dataset URL", lines=2, placeholder="https://huggingface.co/datasets/alvarobartt/dpo-mix-7k-simplified/resolve/main/data/train-00000-of-00001.parquet?download=true")
            model_type = gr.Dropdown(choices=["mix", "ufb_cached", "ufb"], label="Dataset Type")
            output_dataset_name = gr.Textbox(label="Output Dataset Name", lines=1, placeholder="cstr/translated_datasets")
            range_specification = gr.Textbox(label="Range Specification", lines=1, placeholder="e.g., 1-100")
            target_language = gr.Dropdown(choices=["ha", "is", "ja", "cs", "ru", "zh", "de"], label="Target Language")  # Dropdown for target language
        
        with gr.Column():
            output = gr.Markdown(label="Output")

    submit_btn = gr.Button("Translate Dataset", variant="primary")
    submit_btn.click(main, inputs=[dataset_url, model_type, output_dataset_name, range_specification, target_language], outputs=output)


    gr.Markdown(datasets_desc)

demo.queue(max_size=10).launch(share=True, show_api=True)