bart_rom_dev_tl

This model is a fine-tuned version of ar5entum/bart_hin_eng_mt on the ar5entum/hindi-english-roman-devnagiri-transliteration-corpus dataset. It achieves the following results on the evaluation set:

  • Loss: 0.0998
  • Bleu: 63.9396
  • Gen Len: 114.6678

Model description

This model is trained on a transliteration dataset of paired Roman-script and Devanagari sentences. The goal of this experiment was to transliterate romanized text (Hindi as well as English) into Devanagari correctly, using each sentence's context to choose the right spelling.

Inference and Evaluation

import torch
import evaluate
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

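def batch_long_string(text):
    """Split `text` into chunks of whole words, closing a chunk as soon as
    its combined word length (spaces not counted) exceeds 40 characters."""
    batch = []
    temp = []
    count = 0
    for word in text.split():
        count += len(word)
        temp.append(word)
        if count > 40:
            # Chunk is long enough: emit it and start a new one.
            batch.append(" ".join(temp))
            temp = []
            count = 0
    if temp:
        # Emit any leftover words as a final, shorter chunk.
        batch.append(" ".join(temp))
    return batch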

class BartSmall:
    def __init__(self, model_path='ar5entum/bart_rom_dev_tl', device=None):
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
        if not device:
            # Use the GPU when one is available, otherwise fall back to the CPU.
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        self.model.to(device)

    def predict(self, input_text):
        """Transliterate a single string with 4-beam search."""
        inputs = self.tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True).to(self.device)
        with torch.no_grad():
            pred_ids = self.model.generate(**inputs, max_length=512, num_beams=4, early_stopping=True)
        return self.tokenizer.decode(pred_ids[0], skip_special_tokens=True)

    def predict_batch(self, input_texts, batch_size=32):
        """Transliterate a list of strings, `batch_size` strings at a time."""
        all_predictions = []
        for i in range(0, len(input_texts), batch_size):
            batch_texts = input_texts[i:i + batch_size]
            # Pad every text to the longest one in the batch.
            inputs = self.tokenizer(batch_texts, return_tensors="pt", max_length=512,
                                    truncation=True, padding=True).to(self.device)

            with torch.no_grad():
                # Pass input_ids *and* attention_mask so padded positions are ignored.
                pred_ids = self.model.generate(**inputs,
                                               max_length=512,
                                               num_beams=4,
                                               early_stopping=True)

            predictions = self.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
            all_predictions.extend(predictions)

        return all_predictions

model = BartSmall()  # picks CUDA when available, otherwise the CPU

input_texts = [
    "the education researcher evaluated the effectiveness of online learning.",
    "yah abhishek jal, ikshuras, dudh, chaval ka ataa, laal chandan, haldi, ashtagandh, chandan chura, char kalash, kesar vrishti, aarti, sugandhit kalash, mahashantidhara evam mahaarghya ke saath bhagvan Neminath ko samarpit kiya jata hai.",
    "kuch ne kaha ye chand hai kuch ne kaha chehra ter"
    ]
ground_truths = [
    "द एजुकेशन रिसर्चर इवैल्युएटेड द इफेक्टिवनेस ऑफ ऑनलाइन लर्निंग", 
    "यह अभिषेक जल, इक्षुरस, दुध, चावल का आटा, लाल चंदन, हल्दी, अष्टगंध, चंदन चुरा, चार कलश, केसर वृष्टि, आरती, सुगंधित कलश, महाशांतिधारा एवं महाअर्घ्य के साथ भगवान नेमिनाथ को समर्पित किया जाता है।",
    "कुछ ने कहा ये चांद है कुछ ने कहा चेहरा तेरा"
    ]
import time
start = time.time()

# Split each input into short word chunks, transliterate all chunks of one text
# as a single batch, and rejoin the outputs in their original order.
predictions = [" ".join(model.predict_batch(batch_long_string(text))) for text in input_texts]
end = time.time()
print("TIME: ", end-start)
for i in range(len(input_texts)):
    print("‾‾‾‾‾‾‾‾‾‾‾‾")
    print("Input text:\t", input_texts[i])
    print("Prediction:\t", predictions[i])
    print("Ground Truth:\t", ground_truths[i])
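# Corpus-level BLEU over the three examples. The `bleu` metric from `evaluate`
# reports a score in [0, 1], whereas the BLEU figures in this card are on a 0-100 scale.
bleu = evaluate.load("bleu")
results = bleu.compute(predictions=predictions, references=ground_truths)
print(results)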

# TIME:  9.683340787887573
# ‾‾‾‾‾‾‾‾‾‾‾‾
# Input text:	 the education researcher evaluated the effectiveness of online learning.
# Prediction:	 द एजुकेशन रिसर्चर इवैल्युएट्स द इफेक्टिंग ओफ ऑनाइनल लर्निंग
# Ground Truth:	 द एजुकेशन रिसर्चर इवैल्युएटेड द इफेक्टिवनेस ऑफ ऑनलाइन लर्निंग
# ‾‾‾‾‾‾‾‾‾‾‾‾
# Input text:	 yah abhishek jal, ikshuras, dudh, chaval ka ataa, laal chandan, haldi, ashtagandh, chandan chura, char kalash, kesar vrishti, aarti, sugandhit kalash, mahashantidhara evam mahaarghya ke saath bhagvan Neminath ko samarpit kiya jata hai.
# Prediction:	 यह अभिषेक जल, इक्षुरस, दुध, चावल का आता, लाल चन्दन, हालडी, अष्टगंध, चन्दन चुरा, चार कलाश, केसर वृष्टि, आर्ती, सुगंधित कलाश, महासंतिधारा एवं महार्घ्य के साथ भगवान नेमीनाथ को समर्पित किया जाता है।
# Ground Truth:	 यह अभिषेक जल, इक्षुरस, दुध, चावल का आटा, लाल चंदन, हल्दी, अष्टगंध, चंदन चुरा, चार कलश, केसर वृष्टि, आरती, सुगंधित कलश, महाशांतिधारा एवं महाअर्घ्य के साथ भगवान नेमिनाथ को समर्पित किया जाता है।
# ‾‾‾‾‾‾‾‾‾‾‾‾
# Input text:	 kuch ne kaha ye chand hai kuch ne kaha chehra ter
# Prediction:	 कुछ ने कहा ये चाँद है कुछ ने कहा चेहरा तेर
# Ground Truth:	 कुछ ने कहा ये चांद है कुछ ने कहा चेहरा तेरा
# {'bleu': 0.43170068926336663, 'precisions': [0.7538461538461538, 0.532258064516129, 0.3728813559322034, 0.23214285714285715], 'brevity_penalty': 1.0, 'length_ratio': 1.0, 'translation_length': 65, 'reference_length': 65}
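The corpus BLEU printed on the last line is the brevity penalty multiplied by the geometric mean of the four n-gram precisions. The sketch below recomputes it by hand. The match counts 49, 33, 22 and 13 are back-derived from the printed precisions (for example 49/65 = 0.7538...) rather than reported directly; the totals follow from 65 tokens across 3 sentences (62 bigrams, 59 trigrams, 56 4-grams).

```python
import math

# Clipped n-gram matches and candidate n-gram totals for n = 1..4,
# back-derived from the precisions printed above.
matches = [49, 33, 22, 13]
totals = [65, 62, 59, 56]

precisions = [m / t for m, t in zip(matches, totals)]

# Brevity penalty: 1.0 unless the candidate is shorter than the reference.
translation_length, reference_length = 65, 65
if translation_length > reference_length:
    brevity_penalty = 1.0
else:
    brevity_penalty = math.exp(1 - reference_length / translation_length)

# BLEU = BP * exp(mean of the log precisions), i.e. BP times their geometric mean.
bleu = brevity_penalty * math.exp(sum(math.log(p) for p in precisions) / len(precisions))
print(round(bleu, 4))  # 0.4317
```

Because the lengths match exactly, the brevity penalty is 1.0 and the score is driven entirely by the falling 3-gram and 4-gram precisions. This is why a few misspelled words (such as "कलाश" for "कलश") cost more BLEU than their character-level error would suggest.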

Training Procedure

Training hyperparameters

The following hyperparameters were used during training:

  • learning_rate: 3e-05
  • train_batch_size: 100
  • eval_batch_size: 40
  • seed: 42
  • distributed_type: multi-GPU
  • num_devices: 2
  • total_train_batch_size: 200
  • total_eval_batch_size: 80
  • optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
  • lr_scheduler_type: linear
  • lr_scheduler_warmup_steps: 80
  • num_epochs: 100.0
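With `lr_scheduler_type: linear` and 80 warmup steps, the learning rate ramps linearly from 0 up to 3e-05 and then decays linearly back to 0 at the final step. The 7,100 total steps come from the training results table (71 steps per epoch × 100 epochs). The sketch below mirrors the linear-with-warmup formula used by Hugging Face's `get_linear_schedule_with_warmup`, written in plain Python as an illustration rather than the actual training code. It also works out what 71 steps per epoch imply about the training-set size. This assumes no gradient accumulation (none is listed) and that the last, possibly partial, batch of each epoch is kept.

```python
BASE_LR = 3e-5
WARMUP_STEPS = 80
STEPS_PER_EPOCH = 71                        # from the "Step" column of the results table
NUM_EPOCHS = 100
TOTAL_STEPS = STEPS_PER_EPOCH * NUM_EPOCHS  # 7100, the final step in the table


def linear_warmup_lr(step: int) -> float:
    """Learning rate at `step`: linear warmup, then linear decay to 0."""
    if step < WARMUP_STEPS:
        return BASE_LR * step / max(1, WARMUP_STEPS)
    remaining = TOTAL_STEPS - step
    return BASE_LR * max(0.0, remaining / max(1, TOTAL_STEPS - WARMUP_STEPS))


# Effective batch size: per-device batch size x number of devices.
total_train_batch_size = 100 * 2
# If the last batch of an epoch may be partial, the training set holds
# between (71 - 1) * 200 + 1 and 71 * 200 examples.
min_examples = (STEPS_PER_EPOCH - 1) * total_train_batch_size + 1
max_examples = STEPS_PER_EPOCH * total_train_batch_size

for step in (0, 40, 80, 3590, 7100):
    print(step, f"{linear_warmup_lr(step):.2e}")
print(total_train_batch_size, min_examples, max_examples)
```

Under these assumptions the run saw roughly 14,000 training pairs per epoch. Warmup covers only about the first epoch (80 of 7,100 steps), so for almost all of training the learning rate is decaying.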

Training results

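Each row lists, left to right: training loss, epoch, global step, validation loss, BLEU (0-100 scale), and mean generated length (Gen Len).

Training Loss  Epoch  Step  Validation Loss  Bleu  Gen Len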
1.1468 1.0 71 1.0356 0.1783 127.8914
0.9193 2.0 142 0.7876 0.7522 120.098
0.714 3.0 213 0.5704 2.2388 116.7362
0.5751 4.0 284 0.4415 5.169 115.8671
0.4807 5.0 355 0.3694 9.2386 114.9026
0.4178 6.0 426 0.3220 13.4352 114.9967
0.3717 7.0 497 0.2920 16.5527 114.3776
0.3355 8.0 568 0.2728 18.8968 113.7553
0.3103 9.0 639 0.2502 22.688 114.4191
0.2916 10.0 710 0.2346 24.9505 114.3487
0.2696 11.0 781 0.2237 26.5227 114.2283
0.2583 12.0 852 0.2129 28.6141 114.0349
0.2438 13.0 923 0.2019 30.3471 114.3934
0.23 14.0 994 0.1972 31.3042 114.2145
0.2158 15.0 1065 0.1871 33.5445 114.5664
0.2108 16.0 1136 0.1811 34.5349 114.2928
0.2033 17.0 1207 0.1749 35.8154 114.4217
0.1901 18.0 1278 0.1706 36.853 114.55
0.1879 19.0 1349 0.1665 37.8791 114.4046
0.1772 20.0 1420 0.1605 39.197 114.6211
0.167 21.0 1491 0.1582 40.4274 114.5737
0.1678 22.0 1562 0.1549 40.4937 114.377
0.1621 23.0 1633 0.1508 42.0233 114.5882
0.1585 24.0 1704 0.1477 42.7916 114.573
0.1494 25.0 1775 0.1449 43.8836 114.6026
0.1477 26.0 1846 0.1424 44.1819 114.5197
0.1441 27.0 1917 0.1399 44.9919 114.6526
0.1379 28.0 1988 0.1375 45.8493 114.5329
0.1354 29.0 2059 0.1358 45.7367 114.4757
0.1325 30.0 2130 0.1330 46.9613 114.698
0.1288 31.0 2201 0.1315 47.5834 114.6257
0.1262 32.0 2272 0.1300 47.9596 114.5145
0.1232 33.0 2343 0.1277 48.2481 114.6474
0.1173 34.0 2414 0.1264 48.8469 114.623
0.1138 35.0 2485 0.1248 49.5157 114.6112
0.1126 36.0 2556 0.1237 49.6457 114.5947
0.1125 37.0 2627 0.1225 50.4627 114.6875
0.1101 38.0 2698 0.1207 50.9736 114.6388
0.1069 39.0 2769 0.1198 51.5928 114.6579
0.1035 40.0 2840 0.1185 52.0712 114.6132
0.096 41.0 2911 0.1175 52.6016 114.6441
0.0958 42.0 2982 0.1172 52.9595 114.6066
0.0967 43.0 3053 0.1160 52.6965 114.6461
0.0948 44.0 3124 0.1151 53.5073 114.6737
0.0957 45.0 3195 0.1144 53.5772 114.6822
0.0922 46.0 3266 0.1135 54.2084 114.6612
0.0903 47.0 3337 0.1127 54.2512 114.6368
0.088 48.0 3408 0.1119 55.1423 114.6947
0.0869 49.0 3479 0.1109 55.4669 114.6467
0.0849 50.0 3550 0.1110 55.7087 114.5855
0.0825 51.0 3621 0.1105 55.5851 114.6349
0.0818 52.0 3692 0.1097 57.163 114.727
0.0811 53.0 3763 0.1089 57.233 114.5928
0.0767 54.0 3834 0.1083 57.0785 114.6822
0.0751 55.0 3905 0.1081 57.4657 114.6487
0.0737 56.0 3976 0.1078 57.6215 114.848
0.0766 57.0 4047 0.1071 57.8275 114.5743
0.0766 58.0 4118 0.1064 58.1423 114.6309
0.0716 59.0 4189 0.1056 58.5167 114.7026
0.071 60.0 4260 0.1053 59.226 114.627
0.0715 61.0 4331 0.1054 59.1511 114.6697
0.0709 62.0 4402 0.1046 59.3669 114.6816
0.0703 63.0 4473 0.1046 59.418 114.6171
0.0686 64.0 4544 0.1039 60.1412 114.6961
0.066 65.0 4615 0.1037 60.4565 114.7559
0.0647 66.0 4686 0.1039 59.9588 114.6382
0.0668 67.0 4757 0.1030 60.5026 114.7447
0.0649 68.0 4828 0.1035 60.2735 114.6099
0.0637 69.0 4899 0.1032 60.6524 114.6171
0.0641 70.0 4970 0.1029 60.7721 114.7461
0.0639 71.0 5041 0.1025 61.1837 114.6901
0.062 72.0 5112 0.1024 61.3516 114.7447
0.0588 73.0 5183 0.1025 61.3766 114.6539
0.0609 74.0 5254 0.1019 61.8364 114.7467
0.0592 75.0 5325 0.1020 61.7948 114.7072
0.0604 76.0 5396 0.1019 61.8981 114.6921
0.0593 77.0 5467 0.1013 61.9623 114.6921
0.057 78.0 5538 0.1013 62.2082 114.6553
0.0595 79.0 5609 0.1011 62.3174 114.6684
0.0565 80.0 5680 0.1010 62.1364 114.6158
0.0592 81.0 5751 0.1009 62.6892 114.6671
0.0563 82.0 5822 0.1010 62.431 114.7099
0.0544 83.0 5893 0.1007 62.78 114.6579
0.0546 84.0 5964 0.1009 62.8921 114.6112
0.0558 85.0 6035 0.1007 62.7137 114.7289
0.0529 86.0 6106 0.1008 62.859 114.6401
0.0549 87.0 6177 0.1003 63.1903 114.6934
0.0544 88.0 6248 0.1003 63.2949 114.6888
0.0535 89.0 6319 0.1005 63.3252 114.6546
0.0547 90.0 6390 0.0999 63.3835 114.7
0.0533 91.0 6461 0.0999 63.5284 114.6875
0.0523 92.0 6532 0.1000 63.6207 114.7145
0.0533 93.0 6603 0.0999 63.5598 114.723
0.0545 94.0 6674 0.0999 63.6451 114.7303
0.052 95.0 6745 0.0999 63.6712 114.7283
0.0527 96.0 6816 0.1001 63.7187 114.6711
0.0511 97.0 6887 0.0999 63.9161 114.6671
0.0531 98.0 6958 0.0999 63.8758 114.6645
0.0539 99.0 7029 0.0999 63.9162 114.6566
0.0533 100.0 7100 0.0998 63.9396 114.6678
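Gen Len is the average length of the generated sequences. The Hugging Face seq2seq example scripts (for example `run_translation.py`) compute it as the mean number of non-padding token IDs per prediction, rounded to 4 decimals. Special tokens such as BOS/EOS still count toward it. The sketch below assumes this run used the same convention. `PAD_ID = 1` is BART's usual padding ID, and the token IDs are hypothetical, for illustration only.

```python
PAD_ID = 1  # assumption: BART's default <pad> token ID


def mean_generated_length(batch_of_ids):
    """Average number of non-padding token IDs per generated sequence."""
    lengths = [sum(1 for tok in seq if tok != PAD_ID) for seq in batch_of_ids]
    return round(sum(lengths) / len(lengths), 4)


# Hypothetical generated IDs, right-padded to a common length with PAD_ID.
generated = [
    [2, 0, 713, 16, 10, 1296, 2, 1, 1, 1],    # 7 non-pad tokens
    [2, 0, 100, 2, 1, 1, 1, 1, 1, 1],         # 4 non-pad tokens
    [2, 0, 50, 60, 70, 80, 90, 110, 120, 2],  # 10 non-pad tokens
]
print(mean_generated_length(generated))  # 7.0
```

Read this way, the Gen Len column staying near 114.6 from epoch 5 onward means the model settled on output lengths early. Most later gains show up in BLEU, that is, in getting the characters right rather than the length.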

Framework versions

  • Transformers 4.45.0.dev0
  • Pytorch 2.4.0+cu121
  • Datasets 2.21.0
  • Tokenizers 0.19.1

Model size

  • Parameters: 49.4M
  • Weights: Safetensors, tensor type F32