import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
from transformers import MarianTokenizer


# Model architecture (copied from the previous/training definition).
class MultiHeadAttention(nn.Module):
    def __init__(self, d_k, d_model, n_heads, max_len, causal=False):
        super().__init__()
        self.d_k = d_k
        self.n_heads = n_heads

        self.key = nn.Linear(d_model, d_k * n_heads)
        self.query = nn.Linear(d_model, d_k * n_heads)
        self.value = nn.Linear(d_model, d_k * n_heads)
        self.fc = nn.Linear(d_k * n_heads, d_model)

        self.causal = causal
        if causal:
            cm = torch.tril(torch.ones(max_len, max_len))
            self.register_buffer("causal_mask", cm.view(1, 1, max_len, max_len))

    def forward(self, q, k, v, pad_mask=None):
        q = self.query(q)  # N x T_output x (h * d_k)
        k = self.key(k)    # N x T_input x (h * d_k)
        v = self.value(v)  # N x T_input x (h * d_k)

        N = q.shape[0]
        T_output = q.shape[1]
        T_input = k.shape[1]

        # Reshape to (N, n_heads, T, d_k) so attention is computed per head.
        q = q.view(N, T_output, self.n_heads, self.d_k).transpose(1, 2)
        k = k.view(N, T_input, self.n_heads, self.d_k).transpose(1, 2)
        v = v.view(N, T_input, self.n_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product attention scores: (N, n_heads, T_output, T_input).
        attn_scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_k)
        if pad_mask is not None:
            attn_scores = attn_scores.masked_fill(
                pad_mask[:, None, None, :] == 0, float('-inf'))
        if self.causal:
            attn_scores = attn_scores.masked_fill(
                self.causal_mask[:, :, :T_output, :T_input] == 0, float('-inf'))
        attn_weights = F.softmax(attn_scores, dim=-1)

        # Weighted sum of the values, then merge the heads back together.
        A = attn_weights @ v
        A = A.transpose(1, 2).contiguous().view(N, T_output, self.d_k * self.n_heads)
        return self.fc(A)


class EncoderBlock(nn.Module):
    def __init__(self, d_k, d_model, n_heads, max_len, dropout_prob=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.mha = MultiHeadAttention(d_k, d_model, n_heads, max_len, causal=False)
        self.ann = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model),
            nn.Dropout(dropout_prob),
        )
        self.dropout = nn.Dropout(p=dropout_prob)

    def forward(self, x, pad_mask=None):
        x = self.ln1(x + self.mha(x, x, x, pad_mask))
        x = self.ln2(x + self.ann(x))
        x = self.dropout(x)
        return x


class DecoderBlock(nn.Module):
    def __init__(self, d_k, d_model, n_heads, max_len, dropout_prob=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.ln3 = nn.LayerNorm(d_model)
        self.mha1 = MultiHeadAttention(d_k, d_model, n_heads, max_len, causal=True)
        self.mha2 = MultiHeadAttention(d_k, d_model, n_heads, max_len, causal=False)
        self.ann = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model),
            nn.Dropout(dropout_prob),
        )
        self.dropout = nn.Dropout(p=dropout_prob)

    def forward(self, enc_output, dec_input, enc_mask=None, dec_mask=None):
        # Masked self-attention over the decoder input, then cross-attention
        # over the encoder output.
        x = self.ln1(dec_input + self.mha1(dec_input, dec_input, dec_input, dec_mask))
        x = self.ln2(x + self.mha2(x, enc_output, enc_output, enc_mask))
        x = self.ln3(x + self.ann(x))
        x = self.dropout(x)
        return x


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=2048, dropout_prob=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout_prob)

        position = torch.arange(max_len).unsqueeze(1)
        exp_term = torch.arange(0, d_model, 2)
        div_term = torch.exp(exp_term * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)
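
# --- Optional shape sanity check (a hedged sketch, not part of the original app).
# The helper name and sizes below are illustrative only; call it manually if needed.
def _sanity_check_blocks():
    """Verify that attention and positional encoding preserve (batch, seq, d_model)."""
    mha = MultiHeadAttention(d_k=16, d_model=64, n_heads=4, max_len=512, causal=True)
    pos = PositionalEncoding(d_model=64, max_len=512)
    x = torch.randn(2, 10, 64)                # (batch, seq_len, d_model)
    assert mha(x, x, x).shape == (2, 10, 64)  # self-attention keeps the input shape
    assert pos(x).shape == (2, 10, 64)        # positional encoding is additive
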
class Encoder(nn.Module):
    def __init__(self, vocab_size, max_len, d_k, d_model, n_heads, n_layers, dropout_prob):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout_prob)
        transformer_blocks = [
            EncoderBlock(d_k, d_model, n_heads, max_len, dropout_prob)
            for _ in range(n_layers)
        ]
        # ModuleList rather than Sequential, since each block takes an extra mask argument.
        self.transformer_blocks = nn.ModuleList(transformer_blocks)
        self.ln = nn.LayerNorm(d_model)

    def forward(self, x, pad_mask=None):
        x = self.embedding(x)
        x = self.pos_encoding(x)
        for block in self.transformer_blocks:
            x = block(x, pad_mask)
        x = self.ln(x)
        return x


class Decoder(nn.Module):
    def __init__(self, vocab_size, max_len, d_k, d_model, n_heads, n_layers, dropout_prob):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout_prob)
        transformer_blocks = [
            DecoderBlock(d_k, d_model, n_heads, max_len, dropout_prob)
            for _ in range(n_layers)
        ]
        self.transformer_blocks = nn.ModuleList(transformer_blocks)
        self.ln = nn.LayerNorm(d_model)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, enc_output, dec_input, enc_mask=None, dec_mask=None):
        x = self.embedding(dec_input)
        x = self.pos_encoding(x)
        for block in self.transformer_blocks:
            x = block(enc_output, x, enc_mask, dec_mask)
        x = self.ln(x)
        x = self.fc(x)
        return x


class Transformer(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_input, dec_input, enc_mask=None, dec_mask=None):
        enc_output = self.encoder(enc_input, enc_mask)
        dec_output = self.decoder(enc_output, dec_input, enc_mask, dec_mask)
        return dec_output


# Load the tokenizer and the trained model.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = MarianTokenizer.from_pretrained("tokenizer")

encoder = Encoder(
    vocab_size=tokenizer.vocab_size + 1,  # +1 for the extra token id used beyond the base vocab
    max_len=512,
    d_k=16,
    d_model=64,
    n_heads=4,
    n_layers=2,
    dropout_prob=0.1,
)
decoder = Decoder(
    vocab_size=tokenizer.vocab_size + 1,
    max_len=512,
    d_k=16,
    d_model=64,
    n_heads=4,
    n_layers=2,
    dropout_prob=0.1,
)
transformer = Transformer(encoder, decoder)
transformer.load_state_dict(torch.load("en_spanish_translation.pth", map_location=device))
transformer.to(device)
transformer.eval()


def translate(text):
    enc_input_ids = tokenizer.encode(text, return_tensors="pt").to(device)
    enc_attn_mask = torch.ones_like(enc_input_ids).to(device)

    # Seed the decoder with the tokenizer's CLS token as the start-of-sequence token.
    dec_input_ids = torch.zeros((1, 1), dtype=torch.long).to(device) + tokenizer.cls_token_id

    # Greedy decoding: append the most likely next token until end-of-sequence.
    for _ in range(512):
        logits = transformer(enc_input_ids, dec_input_ids, enc_attn_mask)
        prediction_id = logits[:, -1].argmax(-1)
        dec_input_ids = torch.hstack((dec_input_ids, prediction_id.view(1, 1)))
        if prediction_id == 0:  # id 0 marks the end of the sequence for this tokenizer
            break

    translation = tokenizer.decode(dec_input_ids[0, 1:])
    translation = translation.replace("</s>", "").strip()  # remove </s> and strip whitespace
    return translation


iface = gr.Interface(fn=translate, inputs="text", outputs="text")
# Alternative with explicit components and placeholder text:
# iface = gr.Interface(
#     fn=translate,
#     inputs=gr.Textbox(placeholder="Enter text to translate"),
#     outputs=gr.Textbox(placeholder="Spanish Translation"),
# )
iface.launch()
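
# Optional: a quick direct call for testing without the web UI (a hedged sketch;
# the example sentence is illustrative only). Run it before iface.launch(),
# since launch() blocks the script:
#
#     print(translate("How are you today?"))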