import base64
import hmac
import html
import io
import os

import gradio as gr
import huggingface_hub
import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
import transformers
from huggingface_hub import hf_hub_download, login
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel

# Read the pickled label -> integer mapping and invert it so that predicted
# class indices can be turned back into label strings.
label_to_int = pd.read_pickle('label_to_int.pkl')
int_to_label = {index: name for name, index in label_to_int.items()}

# Shorten selected labels for display. Rules are tried in order and the first
# keyword contained in a label decides its replacement; labels that match no
# keyword are kept unchanged.
_LABEL_ALIASES = (
    ("KOREA", "KOREA"),
    ("KINGDOM", "UK"),
    ("RUSSIAN", "RUSSIA"),
)
int_to_label = {
    index: next(
        (alias for keyword, alias in _LABEL_ALIASES if keyword in name),
        name,
    )
    for index, name in int_to_label.items()
}

class LogisticRegressionTorch(nn.Module):
    """Batch-normalised linear classification head.

    The input features are passed through ``nn.BatchNorm1d`` and then a single
    ``nn.Linear`` layer. The output is the raw linear result (logits); no
    softmax or sigmoid is applied here.

    Args:
        input_dim: Number of input features per sample.
        output_dim: Number of output scores per sample.
    """

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        # Attribute names are part of the state-dict keys, so they must not
        # be renamed without re-exporting saved checkpoints.
        self.batch_norm = nn.BatchNorm1d(num_features=input_dim)
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Return the linear scores of the batch-normalised input ``x``."""
        normalised = self.batch_norm(x)
        return self.linear(normalised)

class BertClassifier(nn.Module):
    """Combine an encoder model with a classification head.

    Args:
        bert_model: Encoder whose output exposes ``hidden_states`` when called
            with ``output_hidden_states=True``.
        classifier: Head applied to the extracted per-sequence feature vector.
        num_labels: Number of target classes; stored on the instance only.
    """

    def __init__(self, bert_model: AutoModel, classifier: LogisticRegressionTorch, num_labels: int):
        super().__init__()
        self.bert = bert_model
        self.classifier = classifier
        self.num_labels = num_labels

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None):
        """Encode the tokens and return the classifier output (logits).

        The feature vector fed to the classifier is position 0 along the
        second axis of the last entry in ``hidden_states`` (presumably the
        [CLS] token of the final layer -- confirm against the tokenizer).
        """
        encoded = self.bert(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        final_layer = encoded.hidden_states[-1]
        first_position = final_layer[:, 0, :]
        return self.classifier(first_position)

def load_model():
    """Build the DNA classifier and its tokenizer, loading fine-tuned weights.

    The GENA-LM base encoder and tokenizer are fetched from the Hugging Face
    Hub, a ``LogisticRegressionTorch`` head is created, and both are populated
    from a checkpoint downloaded from the ``mawairon/noo_test`` repository.

    Returns:
        A ``(model, tokenizer)`` tuple; the model is a ``BertClassifier`` in
        evaluation mode.

    Raises:
        ValueError: If the ``HUGGINGFACE_TOKEN`` environment variable is not set.
    """
    # Check the credential first so a missing token fails fast, before any
    # model download is attempted.
    token = os.getenv('HUGGINGFACE_TOKEN')
    if token is None:
        raise ValueError("HUGGINGFACE_TOKEN environment variable is not set")

    metadata_features = 0  # no extra features are appended to the embedding
    N_UNIQUE_CLASSES = 38

    base_model = AutoModel.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True, output_hidden_states=True)
    tokenizer = AutoTokenizer.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True)

    # 768 is the feature size the classification head is built for.
    input_size = 768 + metadata_features
    log_reg = LogisticRegressionTorch(input_dim=input_size, output_dim=N_UNIQUE_CLASSES)

    login(token=token)
    # `token` is the supported keyword; `use_auth_token` is deprecated.
    file_path = hf_hub_download(
        repo_id="mawairon/noo_test",
        filename="gena-blastln-bs33-lr4e-05-S168.pth",
        token=token,
    )
    # NOTE(security): torch.load unpickles the file, so the checkpoint must
    # only ever come from a trusted repository.
    weights = torch.load(file_path, map_location=torch.device('cpu'))

    base_model.load_state_dict(weights['model_state_dict'])
    log_reg.load_state_dict(weights['log_reg_state_dict'])

    model = BertClassifier(base_model, log_reg, num_labels=N_UNIQUE_CLASSES)
    # Evaluation mode: batch norm uses running statistics, dropout is disabled.
    model.eval()

    return model, tokenizer

# Load the model and tokenizer once, when the module is executed; analyze_dna
# reads these module-level names on every call.
model, tokenizer = load_model()

def _error_html(message: str) -> str:
    """Return *message* as an escaped HTML snippet for the single HTML output."""
    return f'<p style="color: red;"><strong>Error:</strong> {html.escape(message)}</p>'


def _password_is_valid(password) -> bool:
    """Check *password* against the PASSWORD environment variable."""
    expected = os.getenv('PASSWORD')
    # Fail closed: without a configured password, or without a string to
    # compare, nobody is authenticated.
    if expected is None or not isinstance(password, str):
        return False
    # Constant-time comparison; bytes are used because compare_digest rejects
    # non-ASCII str arguments.
    return hmac.compare_digest(password.encode('utf-8'), expected.encode('utf-8'))


def _bar_chart_png_base64(labels, probs) -> str:
    """Draw a horizontal bar chart and return it as a base64-encoded PNG."""
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.barh(labels, probs, color='skyblue')
        ax.set_xlabel('Probability')
        ax.set_title('Top 5 Most Likely Labels')
        # Put the most likely label at the top of the chart.
        ax.invert_yaxis()
        with io.BytesIO() as buf:
            fig.savefig(buf, format='png')
            return base64.b64encode(buf.getvalue()).decode('utf-8')
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # without this each request would leak one figure.
        plt.close(fig)


def analyze_dna(sequence, password):
    """Classify a DNA sequence and return the result as an HTML snippet.

    Args:
        sequence: DNA sequence made only of the characters ``A``, ``C``, ``T``,
            ``G`` and ``N``, at least 300 characters long.
        password: Password that must match the ``PASSWORD`` environment variable.

    Returns:
        An HTML string. On success it is an ``<img>`` tag embedding a bar chart
        of the five most probable labels; on any failure (wrong password,
        invalid input, or an exception during inference) it is a short error
        paragraph. A single string is always returned so the value fits the
        interface's single HTML output.
    """
    if not _password_is_valid(password):
        return _error_html("Invalid password")

    try:
        if not all(nucleotide in 'ACTGN' for nucleotide in sequence):
            return _error_html("Sequence contains invalid characters")

        if len(sequence) < 300:
            return _error_html("Sequence needs to be at least 300 nucleotides long")

        inputs = tokenizer(sequence, truncation=True, padding='max_length', max_length=512, return_tensors="pt", return_token_type_ids=False)
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])

        # One sequence is submitted, so drop the leading batch dimension.
        probabilities = torch.nn.functional.softmax(logits, dim=-1).squeeze(0).tolist()
        ranked = sorted(enumerate(probabilities), key=lambda pair: pair[1], reverse=True)[:5]
        top_5_labels = [int_to_label[index] for index, _ in ranked]
        top_5_probs = [prob for _, prob in ranked]

        image_base64 = _bar_chart_png_base64(top_5_labels, top_5_probs)
        return f'<img src="data:image/png;base64,{image_base64}" />'

    except Exception as e:
        # Top-level UI boundary: report the failure to the user instead of
        # crashing the request.
        return _error_html(str(e))

# Build the web UI: a sequence field and a masked password field feed
# analyze_dna, whose return value is shown in a single HTML component.
sequence_input = gr.Textbox(label="DNA Sequence")
password_input = gr.Textbox(label="Password", type="password")

demo = gr.Interface(
    fn=analyze_dna,
    inputs=[sequence_input, password_input],
    outputs=["html"],
)

# Start serving the interface
demo.launch()