import gradio as gr
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn as nn


class LogisticRegressionTorch(nn.Module):

    def __init__(self,
                 input_dim: int,
                 output_dim: int):

        super().__init__()
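        # Normalize the pooled features before the linear layer. Note that
        # BatchNorm1d fails on a batch of size 1 in training mode, so
        # single-sequence inference must run under model.eval().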
        self.batch_norm = nn.BatchNorm1d(num_features=input_dim)
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        x = self.batch_norm(x)
        out = self.linear(x)
        return out

class BertClassifier(nn.Module):

    def __init__(self,
                 bert_model: nn.Module,
                 classifier: LogisticRegressionTorch,
                 num_labels: int):

        super().__init__()
        self.bert = bert_model  # Assume bert_model is an instance of a pre-trained BertModel
        self.classifier = classifier
        self.num_labels = num_labels

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None,
                token_type_ids: torch.Tensor = None, labels: torch.Tensor = None):
        # Run the BERT backbone; token_type_ids is accepted for API compatibility but not used
        outputs = self.bert(input_ids, attention_mask=attention_mask, output_hidden_states=True)
        
        # Pool by taking the last layer's hidden state of the first ([CLS]) token
        pooled_output = outputs.hidden_states[-1][:, 0, :]

        # Sanity check: the bert-base backbone has a hidden size of 768
        assert pooled_output.shape == (input_ids.shape[0], 768), \
            f"Expected shape ({input_ids.shape[0]}, 768), but got {pooled_output.shape}"
        # TODO: read the hidden size from the model config instead of hard-coding 768

        # Pass the pooled output to the classifier to get the logits
        logits = self.classifier(pooled_output)

        # Compute the cross-entropy loss when labels are provided
        loss = None

        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            pred = logits.view(-1, self.num_labels)
            observed = labels.view(-1)
            loss = loss_fct(pred, observed)

        # Return the loss and logits
        return loss, logits



# Load the Hugging Face model and tokenizer

metadata_features = 0
N_UNIQUE_CLASSES = 38

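# GENA-LM ships custom modeling code on the Hub, hence trust_remote_code=True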
base_model = AutoModel.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True, output_hidden_states=True)
tokenizer = AutoTokenizer.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True)

# Initialize the classifier
input_size = 768 + metadata_features # featurizer output size + metadata size
log_reg = LogisticRegressionTorch(input_dim=input_size, output_dim=N_UNIQUE_CLASSES)

# Load Weights
model_weights_path = 'gena-blastln-bs33-lr4e-05-S168.pth'
weights = torch.load(model_weights_path, map_location=torch.device('cpu'))

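# The checkpoint holds separate state dicts for the backbone and the classifier head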
base_model.load_state_dict(weights['model_state_dict'])
log_reg.load_state_dict(weights['log_reg_state_dict'])

# Create the full model and switch to evaluation mode so BatchNorm uses running stats
# (training-mode BatchNorm1d raises an error on a batch of size 1)
model = BertClassifier(base_model, log_reg, num_labels=N_UNIQUE_CLASSES)
model.eval()


# Process a DNA sequence and return the top-5 predicted classes
def analyze_dna(sequence):
    # Tokenize the input sequence
    inputs = tokenizer(sequence, truncation=True, padding='max_length', max_length=512,
                       return_tensors="pt", return_token_type_ids=False)

    # Forward pass without gradient tracking (inference only)
    with torch.no_grad():
        _, logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])

    # Convert logits to probabilities
    probabilities = torch.nn.functional.softmax(logits, dim=-1).squeeze()

    # Return the top 5 most likely classes as (class_index, probability) pairs
    top_probs, top_indices = torch.topk(probabilities, k=5)
    return [(int(i), float(p)) for i, p in zip(top_indices, top_probs)]
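
# Optional smoke test before serving; the sequence below is a made-up placeholder:
# print(analyze_dna("ATGGCGTACGTTAGCTAGCATCG"))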

# Create a Gradio interface
demo = gr.Interface(fn=analyze_dna, inputs="text", outputs="json")
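# The "json" output component renders the list of (class_index, probability) pairs as JSON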

# Launch the interface
demo.launch()