# jam_shield_LLM_app / trainer.py
# Last commit 61e3a25 by asataura: "Adding the description of the app"
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os

import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
from langchain import HuggingFaceHub, LLMChain, PromptTemplate

from antiJamEnv import AntiJamEnv
from DDQN import DoubleDeepQNetwork
# Hugging Face Inference API model used to explain the training graph.
repo_id = "tiiuae/falcon-7b-instruct"
# SECURITY: never commit API tokens to source control. The token that used to be
# hard-coded here was published with the source and must be revoked on
# huggingface.co. Supply it at runtime instead:
#     export HUGGINGFACEHUB_API_TOKEN=hf_...
huggingfacehub_api_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
llm = HuggingFaceHub(
    huggingfacehub_api_token=huggingfacehub_api_token,
    repo_id=repo_id,
    # Low temperature keeps the explanations focused on the supplied numbers.
    model_kwargs={"temperature": 0.2, "max_new_tokens": 2000},
)
# Prompt for the LLM; {data} receives the textual summary of the training run.
template = """You are an AI trained to analyze and provide insights about training graphs in the domain of deep
reinforcement learning. Given the following data about a graph: {data}, provide detailed insights. """
prompt = PromptTemplate(template=template, input_variables=["data"])
llm_chain = LLMChain(prompt=prompt, verbose=True, llm=llm)
# Introductory text rendered above the training progress bar.
_DEMO_DESCRIPTION = """
In this demonstration, we address the challenge of mitigating jamming attacks using Deep Reinforcement Learning (DRL).
The process comprises three main steps:
1. **DRL Training**: An agent is trained using DRL to tackle jamming attacks.
2. **Training Performance Visualization**: Post-training, the performance metrics (rewards, exploration rate, etc.) are visualized to assess the agent's proficiency.
3. **Insights Generation with Falcon 7B LLM**: Leveraging the Falcon 7B LLM, we generate insights from the training graphs, elucidating the agent's behavior and achievements.
"""

# Maximum number of episodes averaged into one point of the rolling-average curve.
_ROLLING_WINDOW = 10


def _plot_training(rewards, rolling_average, epsilons, solved_threshold, reward_scale, title):
    """Build the rewards / rolling-average / epsilon figure for one training run.

    Args:
        rewards: Total reward of each episode.
        rolling_average: Result of a ``mode='valid'`` convolution over ``rewards``.
        epsilons: Exploration rate recorded at the end of each episode.
        solved_threshold: Reward level drawn as a horizontal "Solved Line".
        reward_scale: Factor applied to epsilon so it is visible on the reward axis.
        title: Figure title.

    Returns:
        The matplotlib ``Figure``; the caller is responsible for closing it.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(rewards, label='Rewards')
    # rolling_average[i] averages a window that ends `offset` episodes after i;
    # shift it right so each point sits under the last episode it covers.
    offset = len(rewards) - len(rolling_average)
    ax.plot(np.arange(offset, len(rewards)), rolling_average, color='black', label='Rolling Average')
    ax.axhline(y=solved_threshold, color='r', linestyle='-', label='Solved Line')
    # Epsilon starts at 1.0 in train(); scale it onto the reward axis.
    ax.plot([reward_scale * eps for eps in epsilons], color='g', linestyle='-', label='Epsilons')
    ax.set_xlabel('Episodes')
    ax.set_ylabel('Rewards')
    ax.set_title(title)
    ax.legend()
    return fig


def train(jammer_type, channel_switching_cost):
    """Train a Double DQN agent against a jammer and report the results in Streamlit.

    Renders a description of the demo, a progress bar during training, a plot
    of per-episode rewards / rolling average / epsilon, and an LLM-generated
    explanation of that plot.

    Args:
        jammer_type: Jammer configuration forwarded to ``AntiJamEnv``.
        channel_switching_cost: Channel-switching cost forwarded to ``AntiJamEnv``.

    Returns:
        The trained ``DoubleDeepQNetwork`` agent. It is returned even if the
        insight generation fails; that failure is shown in the UI instead.
    """
    st.markdown(_DEMO_DESCRIPTION, unsafe_allow_html=True)
    st.subheader("DRL Training Progress")
    progress_bar = st.progress(0)
    status_text = st.empty()

    env = AntiJamEnv(jammer_type, channel_switching_cost)
    s_size = env.observation_space.shape[0]
    a_size = env.action_space.n

    max_env_steps = 100
    TRAIN_Episodes = 25
    env._max_episode_steps = max_env_steps
    # Reward level treated as "solved" (90% of the per-episode step budget).
    solved_threshold = max_env_steps - 0.10 * max_env_steps

    epsilon = 1.0
    epsilon_min = 0.01
    epsilon_decay = 0.999
    discount_rate = 0.95
    lr = 0.001
    batch_size = 32

    DDQN_agent = DoubleDeepQNetwork(s_size, a_size, lr, discount_rate, epsilon, epsilon_min, epsilon_decay)
    rewards = []
    epsilons = []

    for e in range(TRAIN_Episodes):
        state = np.reshape(env.reset(), [1, s_size])
        tot_rewards = 0
        for step in range(max_env_steps):
            action = DDQN_agent.action(state)
            next_state, reward, done, _ = env.step(action)
            next_state = np.reshape(next_state, [1, s_size])
            tot_rewards += reward
            DDQN_agent.store(state, action, reward, next_state, done)
            state = next_state
            # Learn only once the replay memory holds more than one batch.
            if len(DDQN_agent.memory) > batch_size:
                DDQN_agent.experience_replay(batch_size)
            if done or step == max_env_steps - 1:
                rewards.append(tot_rewards)
                epsilons.append(DDQN_agent.epsilon)
                status_text.text(
                    f"Episode: {e + 1}/{TRAIN_Episodes}, Reward: {tot_rewards}, Epsilon: {DDQN_agent.epsilon:.3f}")
                progress_bar.progress((e + 1) / TRAIN_Episodes)
                break
        DDQN_agent.update_target_from_model()
        # Stop early once more than 10 episodes ran and the last 10 average "solved".
        if len(rewards) > 10 and np.average(rewards[-10:]) >= solved_threshold:
            break

    st.sidebar.success("DRL Training completed!")

    # Cap the window at the number of episodes so a short run still yields a
    # genuine average (np.convolve would otherwise swap its operands).
    window = min(_ROLLING_WINDOW, len(rewards))
    rolling_average = np.convolve(rewards, np.ones(window) / window, mode='valid')

    fig = _plot_training(
        rewards,
        rolling_average,
        epsilons,
        solved_threshold,
        reward_scale=max_env_steps,
        title=f'Training Rewards - {jammer_type}, CSC: {channel_switching_cost}',
    )

    with st.container():
        col1, col2 = st.columns(2)
        with col1:
            st.subheader("Training Graph")
            try:
                st.pyplot(fig)
            finally:
                plt.close(fig)  # release the figure even if rendering fails
        with col2:
            st.subheader("Graph Explanation")
            # The graph is already on screen. The networked LLM call runs
            # afterwards, and its failure must not discard the trained agent.
            try:
                with st.spinner("Generating insights..."):
                    insights = generate_insights_langchain(rewards, rolling_average, epsilons, solved_threshold)
            except Exception as err:  # UI boundary: report the failure instead of aborting
                st.error(f"Could not generate insights: {err}")
            else:
                st.write(insights)
    return DDQN_agent
def generate_insights_langchain(rewards, rolling_average, epsilons, solved_threshold, *, chain=None):
    """Ask the LLM to explain summary statistics of a training run.

    Args:
        rewards: Total reward of each training episode.
        rolling_average: Moving average of ``rewards``.
        epsilons: Exploration rate recorded at the end of each episode.
        solved_threshold: Reward level that counts as "solved".
        chain: Object exposing ``predict(data=...)``. Defaults to the
            module-level ``llm_chain``. Pass another chain to use a different
            model, or a stub for testing.

    Returns:
        Whatever ``chain.predict`` returns (the LLM's explanation text).

    Raises:
        ValueError: If ``rewards``, ``rolling_average`` or ``epsilons`` is empty.
    """
    # len() rather than truthiness: rolling_average may be a NumPy array.
    for name, values in (("rewards", rewards), ("rolling_average", rolling_average), ("epsilons", epsilons)):
        if len(values) == 0:
            raise ValueError(f"{name} must contain at least one value")

    data_description = (
        "The graph represents training rewards over episodes. "
        f"The actual rewards range from {min(rewards):.2f} to {max(rewards):.2f} "
        f"with an average of {np.mean(rewards):.2f}. "
        f"The rolling average values range from {min(rolling_average):.2f} to {max(rolling_average):.2f} "
        f"with an average of {np.mean(rolling_average):.2f}. "
        f"The epsilon values range from {min(epsilons):.2f} to {max(epsilons):.2f} "
        f"with an average exploration rate of {np.mean(epsilons):.2f}. "
        f"The solved threshold is set at {solved_threshold:.2f}."
    )
    if chain is None:
        chain = llm_chain
    return chain.predict(data=data_description)