import random
import numpy as np
import streamlit as st
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import time

class Organelle:
    """A compartment inside a cell, identified solely by its ``type`` string.

    Instances compare and hash by ``type`` so collections of organelles can be
    de-duplicated by value (for example with ``set``) instead of by identity.
    Treat instances as immutable: the hash depends on ``type``.
    """

    def __init__(self, type):
        # ``type`` shadows the builtin, but the parameter name is kept so that
        # existing keyword callers keep working.
        self.type = type

    def __eq__(self, other):
        if not isinstance(other, Organelle):
            return NotImplemented
        return self.type == other.type

    def __hash__(self):
        return hash(self.type)

    def __repr__(self):
        return f"Organelle({self.type!r})"

class Modification:
    """A heritable trait with a name and a multiplicative ``effect`` factor.

    Instances compare and hash by ``(name, effect)``. This lets membership
    tests and ``set`` de-duplication recognise a modification a cell already
    has, even when a fresh object with the same values is created. Treat
    instances as immutable: the hash depends on both fields.
    """

    def __init__(self, name, effect):
        self.name = name
        self.effect = effect

    def __eq__(self, other):
        if not isinstance(other, Modification):
            return NotImplemented
        return (self.name, self.effect) == (other.name, other.effect)

    def __hash__(self):
        return hash((self.name, self.effect))

    def __repr__(self):
        return f"Modification({self.name!r}, {self.effect!r})"

class Cell:
    """A simulated cell that moves, feeds, divides, fuses and acquires traits.

    Attributes:
        x, y: position on the environment grid (floats).
        energy: current energy; ``divide`` splits it once it exceeds
            ``division_threshold``.
        cell_type: "prokaryote" by default, or one of the types listed in
            ``_TYPE_PROFILES``.
        organelles: ``Organelle`` objects, kept to at most one per type.
        modifications: ``Modification`` objects, kept to at most one per name.
        size: multiplies the energy cost of moving.
        color: display colour for the current type ("purple" once modified).
        division_threshold: energy above which the cell can divide.
    """

    # Organelles, colour and size for each non-prokaryote type.
    # Prokaryotes keep the defaults assigned in ``__init__``.
    _TYPE_PROFILES = {
        "early_eukaryote": (("nucleus",), "green", 2),
        "advanced_eukaryote": (("nucleus", "mitochondria"), "red", 3),
        "plant_like": (("nucleus", "mitochondria", "chloroplast"), "darkgreen", 4),
    }

    def __init__(self, x, y, cell_type="prokaryote"):
        self.x = x
        self.y = y
        self.energy = 100
        self.cell_type = cell_type
        self.organelles = []
        self.modifications = []
        self.size = 1
        self.color = "lightblue"
        self.division_threshold = 150

        self.update_properties()

    def update_properties(self):
        """Apply the organelles, colour and size that match ``cell_type``.

        Only organelle types the cell does not already have are added. The
        method can therefore be called again after a type change (mutation,
        fusion) without accumulating duplicate organelles.
        """
        profile = self._TYPE_PROFILES.get(self.cell_type)
        if profile is None:
            return
        organelle_types, color, size = profile
        self.color = color
        self.size = size
        present = {org.type for org in self.organelles}
        for org_type in organelle_types:
            if org_type not in present:
                self.organelles.append(Organelle(org_type))
                present.add(org_type)

    def move(self, environment):
        """Take a random step of at most 1 on each axis, clamped to the grid.

        Moving costs ``0.5 * size`` energy.
        """
        dx = random.uniform(-1, 1)
        dy = random.uniform(-1, 1)
        self.x = max(0, min(environment.width - 1, self.x + dx))
        self.y = max(0, min(environment.height - 1, self.y + dy))
        self.energy -= 0.5 * self.size

    def feed(self, environment):
        """Gain energy from the nutrient under the cell, plus light for chloroplasts.

        The gain is 10% of the local nutrient value, plus ``2 * light_level``
        when the cell has a chloroplast. The total is multiplied by every
        modification's ``effect``. The local nutrient value then drops by 10%.
        """
        row, col = int(self.y), int(self.x)
        base_energy = environment.grid[row][col] * 0.1
        if any(org.type == "chloroplast" for org in self.organelles):
            base_energy += environment.light_level * 2

        for mod in self.modifications:
            base_energy *= mod.effect

        self.energy += base_energy
        environment.grid[row][col] *= 0.9

    def can_divide(self):
        """Return True when energy exceeds ``division_threshold``."""
        return self.energy > self.division_threshold

    def divide(self):
        """Split into two cells sharing this cell's energy equally.

        Returns:
            The new daughter cell (same position, type, organelles and
            modifications), or None if the cell cannot divide yet.
        """
        if not self.can_divide():
            return None
        self.energy /= 2
        new_cell = Cell(self.x, self.y, self.cell_type)
        # The daughter receives the other half. Otherwise the constructor's
        # default energy would be created out of nothing on every division.
        new_cell.energy = self.energy
        new_cell.organelles = self.organelles.copy()
        new_cell.modifications = self.modifications.copy()
        return new_cell

    def can_fuse(self, other):
        """Return True, with 0.5% probability, when *other* has the same type.

        There is no distance check: any two cells of the same type may fuse.
        """
        return (self.cell_type == other.cell_type and
                random.random() < 0.005)  # 0.5% chance of fusion

    def fuse(self, other):
        """Merge this cell with *other* into a new cell at their midpoint.

        The new cell advances one stage (prokaryote -> early_eukaryote ->
        advanced_eukaryote; other types are unchanged). It receives the
        combined energy of both parents and the union of their organelles
        (one per type) and modifications (one per name). Neither parent is
        modified by this method.
        """
        next_stage = {
            "prokaryote": "early_eukaryote",
            "early_eukaryote": "advanced_eukaryote",
        }
        new_cell_type = next_stage.get(self.cell_type, self.cell_type)

        new_cell = Cell(
            (self.x + other.x) / 2,
            (self.y + other.y) / 2,
            new_cell_type
        )
        new_cell.energy = self.energy + other.energy

        # Keyed on type / name so duplicates are removed by value, whatever
        # equality the organelle and modification objects define. The first
        # occurrence wins and order is deterministic.
        organelles_by_type = {}
        for org in self.organelles + other.organelles:
            organelles_by_type.setdefault(org.type, org)
        new_cell.organelles = list(organelles_by_type.values())

        mods_by_name = {}
        for mod in self.modifications + other.modifications:
            mods_by_name.setdefault(mod.name, mod)
        new_cell.modifications = list(mods_by_name.values())

        # Add any organelles the (possibly advanced) type needs and is missing.
        new_cell.update_properties()
        return new_cell

    def acquire_modification(self):
        """Gain one random modification, unless one with that name is already present.

        A newly acquired modification turns the cell purple.
        """
        possible_mods = [
            Modification("Enhanced metabolism", 1.2),
            Modification("Thick cell wall", 0.8),
            Modification("Efficient energy storage", 1.1),
            Modification("Rapid division", 0.9)
        ]
        new_mod = random.choice(possible_mods)
        # Compare by name: fresh objects are built on every call, so an
        # identity-based ``in`` check would never detect an existing trait.
        if all(mod.name != new_mod.name for mod in self.modifications):
            self.modifications.append(new_mod)
            self.color = "purple"  # Visual indicator of modification

class Environment:
    def __init__(self, width, height):
        self.width = width
        self.height = height
        self.grid = np.random.rand(height, width) * 10
        self.light_level = 5
        self.cells = []
        self.time = 0
        self.population_history = {
            "prokaryote": [], "early_eukaryote": [],
            "advanced_eukaryote": [], "plant_like": [], "modified": []
        }

    def add_cell(self, cell):
        self.cells.append(cell)

    def update(self):
        self.time += 1
        self.grid += np.random.rand(self.height, self.width) * 0.1
        self.light_level = 5 + np.sin(self.time / 100) * 2

        new_cells = []
        cells_to_remove = []

        for cell in self.cells:
            cell.move(self)
            cell.feed(self)

            if cell.energy <= 0:
                cells_to_remove.append(cell)
            elif cell.can_divide():
                new_cell = cell.divide()
                if new_cell:
                    new_cells.append(new_cell)

        # Handle cell fusion
        for i, cell1 in enumerate(self.cells):
            for cell2 in self.cells[i+1:]:
                if cell1.can_fuse(cell2):
                    new_cell = cell1.fuse(cell2)
                    new_cells.append(new_cell)
                    cells_to_remove.extend([cell1, cell2])

        # Add new cells and remove dead/fused cells
        self.cells.extend(new_cells)
        self.cells = [cell for cell in self.cells if cell not in cells_to_remove]

        # Introduce mutations and modifications
        for cell in self.cells:
            if random.random() < 0.0001:  # 0.01% chance of mutation
                if cell.cell_type == "early_eukaryote":
                    cell.cell_type = "advanced_eukaryote"
                elif cell.cell_type == "advanced_eukaryote" and random.random() < 0.5:
                    cell.cell_type = "plant_like"
                cell.update_properties()
            
            if random.random() < 0.0005:  # 0.05% chance of acquiring a modification
                cell.acquire_modification()

        # Record population counts
        for cell_type in self.population_history.keys():
            if cell_type != "modified":
                count = len([cell for cell in self.cells if cell.cell_type == cell_type and not cell.modifications])
            else:
                count = len([cell for cell in self.cells if cell.modifications])
            self.population_history[cell_type].append(count)

    def get_visualization_data(self):
        cell_data = {
            "prokaryote": {"x": [], "y": [], "size": [], "color": "lightblue", "symbol": "circle"},
            "early_eukaryote": {"x": [], "y": [], "size": [], "color": "green", "symbol": "square"},
            "advanced_eukaryote": {"x": [], "y": [], "size": [], "color": "red", "symbol": "diamond"},
            "plant_like": {"x": [], "y": [], "size": [], "color": "darkgreen", "symbol": "star"},
            "modified": {"x": [], "y": [], "size": [], "color": "purple", "symbol": "cross"}
        }

        for cell in self.cells:
            cell_type = "modified" if cell.modifications else cell.cell_type
            cell_data[cell_type]["x"].append(cell.x)
            cell_data[cell_type]["y"].append(cell.y)
            cell_data[cell_type]["size"].append(cell.size * 3)

        return cell_data, self.population_history

def setup_figure(env):
    """Build the 2x3 Plotly dashboard for the current state of *env*.

    Subplot layout (row, col), matching ``subplot_titles`` row by row:
        (1, 1) scatter of cell positions   (1, 2) all population histories
        (1, 3) prokaryotes                 (2, 1) early eukaryotes
        (2, 2) advanced eukaryotes         (2, 3) plant-like and modified

    Traces are added in three groups, each with one trace per category in
    ``cell_types`` order: scatter (indices 0-4), combined population lines
    (5-9) and individual population lines (10-14). Code that updates
    ``fig.data`` by index depends on this order.

    Args:
        env: object providing ``get_visualization_data()`` and
            ``population_history``.

    Returns:
        The configured ``plotly.graph_objects.Figure``.
    """
    cell_types = ["prokaryote", "early_eukaryote", "advanced_eukaryote", "plant_like", "modified"]
    # Subplot holding each category's individual population chart; must line
    # up with the titles below (make_subplots assigns them row by row).
    individual_positions = {
        "prokaryote": (1, 3),
        "early_eukaryote": (2, 1),
        "advanced_eukaryote": (2, 2),
        "plant_like": (2, 3),
        "modified": (2, 3),
    }
    fig = make_subplots(rows=2, cols=3,
                        subplot_titles=("Cell Distribution", "Total Population",
                                        "Prokaryotes", "Early Eukaryotes",
                                        "Advanced Eukaryotes", "Plant-like & Modified"),
                        vertical_spacing=0.1,
                        horizontal_spacing=0.05)

    cell_data = env.get_visualization_data()[0]

    # Cell distribution. These are the only legend entries; legendgroup lets
    # one legend click toggle all three traces of a category.
    for cell_type, data in cell_data.items():
        fig.add_trace(go.Scatter(
            x=data["x"], y=data["y"], mode='markers',
            marker=dict(color=data["color"], size=data["size"], symbol=data["symbol"]),
            name=cell_type, legendgroup=cell_type,
        ), row=1, col=1)

    # Total population over time, coloured like the scatter markers.
    for cell_type, counts in env.population_history.items():
        fig.add_trace(go.Scatter(
            y=counts, mode='lines', name=cell_type, legendgroup=cell_type,
            line=dict(color=cell_data[cell_type]["color"]), showlegend=False,
        ), row=1, col=2)

    # Individual population charts
    for cell_type in cell_types:
        row, col = individual_positions[cell_type]
        fig.add_trace(go.Scatter(
            y=env.population_history[cell_type], mode='lines', name=cell_type,
            legendgroup=cell_type, line=dict(color=cell_data[cell_type]["color"]),
            showlegend=False,
        ), row=row, col=col)

    fig.update_xaxes(title_text="X", row=1, col=1)
    fig.update_yaxes(title_text="Y", row=1, col=1)
    # Every other subplot is a population time series.
    for row, col in [(1, 2), (1, 3), (2, 1), (2, 2), (2, 3)]:
        fig.update_xaxes(title_text="Time", row=row, col=col)
        fig.update_yaxes(title_text="Population", row=row, col=col)

    fig.update_layout(height=800, width=1200, title_text="Advanced Cell Evolution Simulation")

    return fig

# Streamlit app
st.title("Advanced Cell Evolution Simulation")

num_steps = st.slider("Number of simulation steps", 100, 2000, 1000)
initial_cells = st.slider("Initial number of cells", 10, 200, 100)
update_interval = st.slider("Update interval (milliseconds)", 50, 500, 100)

if st.button("Run Simulation"):
    env = Environment(100, 100)

    # Seed the world with default (prokaryote) cells at uniformly random positions.
    for _ in range(initial_cells):
        env.add_cell(Cell(random.uniform(0, env.width), random.uniform(0, env.height)))

    fig = setup_figure(env)
    # st.empty() is a single-element container: each plotly_chart call on it
    # replaces the previous chart instead of appending a new one.
    placeholder = st.empty()
    placeholder.plotly_chart(fig, use_container_width=True)

    # setup_figure adds one trace per category in three consecutive groups:
    # scatter [0, n), combined population [n, 2n), individual population [2n, 3n).
    n_categories = len(env.population_history)

    for _ in range(num_steps):
        env.update()
        cell_data, population_history = env.get_visualization_data()

        with fig.batch_update():
            for i, data in enumerate(cell_data.values()):
                scatter = fig.data[i]
                scatter.x = data["x"]
                scatter.y = data["y"]
                scatter.marker.size = data["size"]

            for i, counts in enumerate(population_history.values()):
                fig.data[n_categories + i].y = counts
                fig.data[2 * n_categories + i].y = counts

        fig.layout.title.text = f"Advanced Cell Evolution Simulation (Time: {env.time})"

        placeholder.plotly_chart(fig, use_container_width=True)

        time.sleep(update_interval / 1000)  # Convert milliseconds to seconds

    st.write("Simulation complete!")