File size: 5,381 Bytes
d67dca9
 
 
 
 
a5f2527
9ca06b7
c4ae7c9
88cee09
8d7cbbe
88cee09
 
6e5f7b8
88cee09
 
 
 
 
 
 
 
d2b8088
88cee09
d67dca9
 
c4ae7c9
76efa5e
 
 
 
 
 
 
 
 
8b7d658
90f9ad1
 
6b43fcc
c4ae7c9
d67dca9
 
 
 
 
 
61e3a25
d67dca9
 
6b43fcc
d67dca9
 
 
 
 
 
 
6b43fcc
 
d67dca9
 
 
6b43fcc
d67dca9
6b43fcc
d67dca9
 
 
 
6b43fcc
d67dca9
 
 
 
6b43fcc
 
 
 
8d7cbbe
 
6b43fcc
 
 
d67dca9
6b43fcc
d67dca9
 
 
8b7d658
6b43fcc
d67dca9
a5f2527
8d7cbbe
6b43fcc
 
 
 
8d7cbbe
c4ae7c9
6b43fcc
 
 
 
 
a5f2527
88cee09
8d7cbbe
2b93f0a
 
58a0e00
 
 
 
 
 
 
8d7cbbe
a5f2527
 
d67dca9
8d7cbbe
 
 
88cee09
 
8d7cbbe
a86d2bc
 
 
 
8d7cbbe
 
88cee09
d779fa9
88cee09
 
8d7cbbe
a86d2bc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os

import numpy as np
import matplotlib.pyplot as plt
import streamlit as st
from langchain import HuggingFaceHub, PromptTemplate, LLMChain

from DDQN import DoubleDeepQNetwork
from antiJamEnv import AntiJamEnv


repo_id = "tiiuae/falcon-7b-instruct"

# The Hugging Face API token is a credential and must not be committed to
# source control; supply it through the environment instead, e.g.
#   export HUGGINGFACEHUB_API_TOKEN=hf_...
# NOTE(review): a token was previously hard-coded in this file — revoke it.
huggingfacehub_api_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
if not huggingfacehub_api_token:
    raise RuntimeError(
        "HUGGINGFACEHUB_API_TOKEN is not set; export your Hugging Face API token "
        "before starting the app."
    )

# Hub-hosted Falcon-7B-Instruct model used to explain training graphs
# (temperature 0.2, up to 2000 generated tokens).
llm = HuggingFaceHub(huggingfacehub_api_token=huggingfacehub_api_token,
                     repo_id=repo_id,
                     model_kwargs={"temperature": 0.2, "max_new_tokens": 2000})

# Prompt asking the LLM for insights; a textual summary of the graph is
# substituted for {data}.
template = """You are an AI trained to analyze and provide insights about training graphs in the domain of deep 
reinforcement learning. Given the following data about a graph: {data}, provide detailed insights. """

prompt = PromptTemplate(template=template, input_variables=["data"])
llm_chain = LLMChain(prompt=prompt, verbose=True, llm=llm)


def train(jammer_type, channel_switching_cost):
    """Train a Double DQN agent against a jammer and present the results in Streamlit.

    Shows an introduction, a live progress bar while training, a graph of the
    per-episode rewards (with rolling average, solved threshold and scaled
    epsilon), and an LLM-generated explanation of that graph.

    Args:
        jammer_type: Passed to ``AntiJamEnv``; also shown in the graph title.
        channel_switching_cost: Passed to ``AntiJamEnv``; also shown in the
            graph title.

    Returns:
        The trained ``DoubleDeepQNetwork`` agent.
    """
    st.markdown("""
    In this demonstration, we address the challenge of mitigating jamming attacks using Deep Reinforcement Learning (DRL). 
    The process comprises three main steps:

    1. **DRL Training**: An agent is trained using DRL to tackle jamming attacks.
    2. **Training Performance Visualization**: Post-training, the performance metrics (rewards, exploration rate, etc.) are visualized to assess the agent's proficiency.
    3. **Insights Generation with Falcon 7B LLM**: Leveraging the Falcon 7B LLM, we generate insights from the training graphs, elucidating the agent's behavior and achievements.

    """, unsafe_allow_html=True)
    st.subheader("DRL Training Progress")
    progress_bar = st.progress(0)
    status_text = st.empty()

    env = AntiJamEnv(jammer_type, channel_switching_cost)
    s_size = env.observation_space.shape[0]
    a_size = env.action_space.n
    max_env_steps = 100
    train_episodes = 25
    env._max_episode_steps = max_env_steps

    epsilon = 1.0
    epsilon_min = 0.01
    epsilon_decay = 0.999
    discount_rate = 0.95
    lr = 0.001
    batch_size = 32

    # Window size shared by the early-stopping average and the plotted rolling average.
    rolling_window = 10
    # Training counts as solved once the recent average reward reaches 90% of
    # the per-episode step budget; this single value drives both early
    # stopping and the "Solved Line" on the graph.
    solved_threshold = max_env_steps - 0.10 * max_env_steps

    ddqn_agent = DoubleDeepQNetwork(s_size, a_size, lr, discount_rate, epsilon, epsilon_min, epsilon_decay)
    rewards = []
    epsilons = []

    for episode in range(train_episodes):
        state = np.reshape(env.reset(), [1, s_size])
        tot_rewards = 0
        for step in range(max_env_steps):
            action = ddqn_agent.action(state)
            next_state, reward, done, _ = env.step(action)
            next_state = np.reshape(next_state, [1, s_size])
            tot_rewards += reward
            ddqn_agent.store(state, action, reward, next_state, done)
            state = next_state

            # Only replay once the memory holds more than one batch of transitions.
            if len(ddqn_agent.memory) > batch_size:
                ddqn_agent.experience_replay(batch_size)

            if done or step == max_env_steps - 1:
                rewards.append(tot_rewards)
                epsilons.append(ddqn_agent.epsilon)
                status_text.text(
                    f"Episode: {episode + 1}/{train_episodes}, Reward: {tot_rewards}, Epsilon: {ddqn_agent.epsilon:.3f}")
                progress_bar.progress((episode + 1) / train_episodes)
                break

        ddqn_agent.update_target_from_model()

        if len(rewards) > rolling_window and np.average(rewards[-rolling_window:]) >= solved_threshold:
            break

    st.sidebar.success("DRL Training completed!")

    # mode='valid' keeps only full windows: rolling_average[i] is the mean of
    # rewards[i:i + rolling_window]. Plot each value at its window's last
    # episode so the curve lines up with the per-episode rewards.
    rolling_average = np.convolve(rewards, np.ones(rolling_window) / rolling_window, mode='valid')
    rolling_x = np.arange(rolling_window - 1, rolling_window - 1 + len(rolling_average))

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        ax.plot(rewards, label='Rewards')
        ax.plot(rolling_x, rolling_average, color='black', label='Rolling Average')
        ax.axhline(y=solved_threshold, color='r', linestyle='-', label='Solved Line')
        # Epsilon starts at 1.0 and decays toward epsilon_min; scale it by 100
        # so it is visible on the reward axis.
        ax.plot([100 * eps for eps in epsilons], color='g', linestyle='-', label='Epsilons')
        ax.set_xlabel('Episodes')
        ax.set_ylabel('Rewards')
        ax.set_title(f'Training Rewards - {jammer_type}, CSC: {channel_switching_cost}')
        ax.legend()

        with st.container():
            col1, col2 = st.columns(2)

            with col1:
                st.subheader("Training Graph")
                st.pyplot(fig)

            with col2:
                st.subheader("Graph Explanation")
                # The LLM query goes to the Hugging Face Hub and may be slow,
                # so it runs only after the graph has been rendered.
                insights = generate_insights_langchain(rewards, rolling_average, epsilons, solved_threshold)
                st.write(insights)
    finally:
        # Release the figure even if plotting or the LLM call fails.
        plt.close(fig)

    return ddqn_agent


def generate_insights_langchain(rewards, rolling_average, epsilons, solved_threshold, *, chain=None):
    """Ask the LLM to explain a training graph from summary statistics of its series.

    Each series is summarized by its minimum, maximum and mean. An empty
    series is reported as unavailable instead of raising ``ValueError``.

    Args:
        rewards: Per-episode total rewards.
        rolling_average: Rolling average of ``rewards``.
        epsilons: Exploration rate recorded at the end of each episode.
        solved_threshold: Reward level drawn as the "solved" line on the graph.
        chain: Object with a ``predict(data=...)`` method used to query the
            LLM; defaults to the module-level ``llm_chain``.

    Returns:
        Whatever ``chain.predict`` returns (the generated text for an ``LLMChain``).
    """
    if chain is None:
        chain = llm_chain

    data_description = (
        "The graph represents training rewards over episodes. "
        + _summarize_series("The actual rewards", rewards, "an average of")
        + _summarize_series("The rolling average values", rolling_average, "an average of")
        + _summarize_series("The epsilon values", epsilons, "an average exploration rate of")
        + f"The solved threshold is set at {solved_threshold:.2f}."
    )

    return chain.predict(data=data_description)


def _summarize_series(subject, values, mean_phrase):
    """Return one sentence (with a trailing space) giving the range and mean of *values*."""
    arr = np.asarray(values, dtype=float)
    if arr.size == 0:
        return f"{subject} are not available. "
    return (f"{subject} range from {arr.min():.2f} to {arr.max():.2f} "
            f"with {mean_phrase} {arr.mean():.2f}. ")