# indictrans2-conversation / indictrans2.py
import torch
from transformers import AutoModelForSeq2SeqLM, BitsAndBytesConfig
from IndicTransTokenizer.IndicTransTokenizer.utils import IndicProcessor
from IndicTransTokenizer.IndicTransTokenizer.tokenizer import IndicTransTokenizer
from peft import PeftModel
from config import lora_repo_id, model_repo_id, batch_size, src_lang, tgt_lang

# Translation direction and optional quantization mode ("4-bit", "8-bit", or None)
DIRECTION = "en-indic"
QUANTIZATION = None

IP = IndicProcessor(inference=True)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
HALF = torch.cuda.is_available()

def initialize_model_and_tokenizer():
    # Build an optional bitsandbytes quantization config.
    if QUANTIZATION == "4-bit":
        qconfig = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
    elif QUANTIZATION == "8-bit":
        qconfig = BitsAndBytesConfig(
            load_in_8bit=True,
            bnb_8bit_use_double_quant=True,
            bnb_8bit_compute_dtype=torch.bfloat16,
        )
    else:
        qconfig = None

    tokenizer = IndicTransTokenizer(direction=DIRECTION)

    # Load two copies of the base model: one kept as-is and one to be
    # wrapped with the LoRA adapter.
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_repo_id,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        quantization_config=qconfig,
    )
    model2 = AutoModelForSeq2SeqLM.from_pretrained(
        model_repo_id,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        quantization_config=qconfig,
    )

    # Quantized models are already placed on the GPU by bitsandbytes;
    # only move unquantized models explicitly.
    if qconfig is None:
        model = model.to(DEVICE)
        model2 = model2.to(DEVICE)

    model.eval()
    model2.eval()

    # Attach the fine-tuned LoRA weights to the second copy.
    lora_model = PeftModel.from_pretrained(model2, lora_repo_id)

    return tokenizer, model, lora_model

def batch_translate(input_sentences, model, tokenizer):
    translations = []
    for i in range(0, len(input_sentences), batch_size):
        batch = input_sentences[i : i + batch_size]

        # Preprocess the batch and extract entity mappings
        batch = IP.preprocess_batch(batch, src_lang=src_lang, tgt_lang=tgt_lang)

        # Tokenize the batch and generate input encodings
        inputs = tokenizer(
            batch,
            src=True,
            truncation=True,
            padding="longest",
            return_tensors="pt",
            return_attention_mask=True,
        ).to(DEVICE)

        # Generate translations using the model
        with torch.inference_mode():
            generated_tokens = model.generate(
                **inputs,
                use_cache=True,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1,
            )

        # Decode the generated tokens into text
        generated_tokens = tokenizer.batch_decode(
            generated_tokens.detach().cpu().tolist(), src=False
        )

        # Postprocess the translations, including entity replacement
        translations += IP.postprocess_batch(generated_tokens, lang=tgt_lang)

        del inputs

    return translations
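

# --- Usage sketch (illustrative addition, not part of the original module) ---
# A minimal example of how the two helpers above fit together. The sample
# sentence and the choice to translate with the LoRA-adapted model are
# assumptions; src_lang, tgt_lang, and batch_size still come from config.py.
if __name__ == "__main__":
    tokenizer, base_model, lora_model = initialize_model_and_tokenizer()
    sample_sentences = ["How are you doing today?"]  # hypothetical input
    print(batch_translate(sample_sentences, lora_model, tokenizer))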