Integrating the Falcon-7B LLM
- app.py +24 -55
- trainer.py +31 -12
app.py
CHANGED
@@ -2,73 +2,42 @@
  # -*- coding: utf-8 -*-

  import streamlit as st
- import os
  from trainer import train
  from tester import test
- import transformers
- from transformers import AutoModelForCausalLM, AutoTokenizer
- import torch


- def perform_training(jammer_type, channel_switching_cost):
-     agent = train(jammer_type, channel_switching_cost)
-     return agent
-
- …
-
- # pipeline = transformers.pipeline("text-generation", model=model, tokenizer=tokenizer, max_length=100,
- #                                  temperature=0.7)

- st.title("Beyond the Anti-Jam: Integration of DRL with LLM")
- st.sidebar.header("Make Your Environment Configuration")
- mode = st.sidebar.radio("Choose Mode", ["Auto", "Manual"])
-
- if mode == "Auto":
-     jammer_type = "dynamic"
-     channel_switching_cost = 0.1
- else:
-     jammer_type = st.sidebar.selectbox("Select Jammer Type", ["constant", "sweeping", "random", "dynamic"])
-     channel_switching_cost = st.sidebar.selectbox("Select Channel Switching Cost", [0, 0.05, 0.1, 0.15, 0.2])
-
- …
-
- start_button = st.sidebar.button('Start')
-
- if start_button:
-     agent, rewards = perform_training(jammer_type, channel_switching_cost)
-     st.subheader("Generating Insights of the DRL-Training")
-     # text = pipeline("Discuss this topic: Integrating LLMs to DRL-based anti-jamming.")
-     # st.write(text)
+ def main():
+     st.title("Beyond the Anti-Jam: Integration of DRL with LLM")
+
+     st.sidebar.header("Make Your Environment Configuration")
+     mode = st.sidebar.radio("Choose Mode", ["Auto", "Manual"])
+
+     if mode == "Auto":
+         jammer_type = "dynamic"
+         channel_switching_cost = 0.1
+     else:
+         jammer_type = st.sidebar.selectbox("Select Jammer Type", ["constant", "sweeping", "random", "dynamic"])
+         channel_switching_cost = st.sidebar.selectbox("Select Channel Switching Cost", [0, 0.05, 0.1, 0.15, 0.2])
+
+     st.sidebar.subheader("Configuration:")
+     st.sidebar.write(f"Jammer Type: {jammer_type}")
+     st.sidebar.write(f"Channel Switching Cost: {channel_switching_cost}")
+
+     start_button = st.sidebar.button('Start')
+
+     if start_button:
+         agent = perform_training(jammer_type, channel_switching_cost)
+         test(agent, jammer_type, channel_switching_cost)
+
+
+ def perform_training(jammer_type, channel_switching_cost):
+     agent = train(jammer_type, channel_switching_cost)
+     return agent
+
+
+ def perform_testing(agent, jammer_type, channel_switching_cost):
      test(agent, jammer_type, channel_switching_cost)

- # pipeline = transformers.pipeline(
- #     "text-generation",
- #     model=model,
- #     tokenizer=tokenizer,
- #     torch_dtype=torch.bfloat16,
- #     trust_remote_code=True,
- #     device_map="auto",
- # )
- # sequences = pipeline(
- #     "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:",
- #     max_length=200,
- #     do_sample=True,
- #     top_k=10,
- #     num_return_sequences=1,
- #     eos_token_id=tokenizer.eos_token_id,
- # )
- # st.title("Beyond the Anti-Jam: Integration of DRL with LLM")
- # for seq in sequences:
- #     st.write(f"Result: {seq['generated_text']}")
+
+
+ if __name__ == "__main__":
+     main()
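Note on the new app.py: the module-level transformers/torch imports and the commented-out Falcon pipeline are gone, so the LLM integration now lives entirely in trainer.py; also, perform_testing is defined but never called, since main() invokes test() directly. If the pipeline is ever restored in app.py, Streamlit's execution model matters: the script re-runs from the top on every widget interaction, so the 7B weights should be loaded once and cached. A minimal sketch, not part of this commit (load_falcon_pipeline is a hypothetical helper; st.cache_resource requires Streamlit 1.18+, and older transformers releases needed trust_remote_code=True for Falcon):

import streamlit as st
import torch
import transformers

@st.cache_resource  # load the 7B weights once per process; reruns reuse the cached object
def load_falcon_pipeline(model_name="tiiuae/falcon-7b-instruct"):
    # Falcon is a causal (decoder-only) model, so "text-generation" is the matching task
    return transformers.pipeline(
        "text-generation",
        model=model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )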
trainer.py
CHANGED
@@ -7,6 +7,11 @@ import json
  import streamlit as st
  from DDQN import DoubleDeepQNetwork
  from antiJamEnv import AntiJamEnv
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+
+ model_name = "tiiuae/falcon-7b-instruct"  # Replace with the exact model name or path
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+ tokenizer = AutoTokenizer.from_pretrained(model_name)


  def train(jammer_type, channel_switching_cost):
@@ -53,7 +58,8 @@ def train(jammer_type, channel_switching_cost):
          if done or time == max_env_steps - 1:
              rewards.append(tot_rewards)
              epsilons.append(DDQN_agent.epsilon)
-             status_text.text(f"Episode: {e + 1}/{TRAIN_Episodes}, Reward: {tot_rewards}, Epsilon: {DDQN_agent.epsilon:.3f}")
+             status_text.text(
+                 f"Episode: {e + 1}/{TRAIN_Episodes}, Reward: {tot_rewards}, Epsilon: {DDQN_agent.epsilon:.3f}")
              progress_bar.progress((e + 1) / TRAIN_Episodes)
              break

@@ -66,12 +72,12 @@ def train(jammer_type, channel_switching_cost):

      # Plotting
      rolling_average = np.convolve(rewards, np.ones(10) / 10, mode='valid')
-
+     solved_threshold = max_env_steps - 0.10 * max_env_steps
      # Create a new Streamlit figure for the training graph
      fig, ax = plt.subplots(figsize=(8, 6))
      ax.plot(rewards, label='Rewards')
      ax.plot(rolling_average, color='black', label='Rolling Average')
-     ax.axhline(y=…)
+     ax.axhline(y=solved_threshold, color='r', linestyle='-', label='Solved Line')
      eps_graph = [100 * x for x in epsilons]
      ax.plot(eps_graph, color='g', linestyle='-', label='Epsilons')
      ax.set_xlabel('Episodes')
@@ -79,23 +85,18 @@ def train(jammer_type, channel_switching_cost):
      ax.set_title(f'Training Rewards - {jammer_type}, CSC: {channel_switching_cost}')
      ax.legend()

-
+     insights = generate_insights(rewards, rolling_average, epsilons, solved_threshold)
+
      with st.container():
          col1, col2 = st.columns(2)

          with col1:
              st.subheader("Training Graph")
-             st.set_option('deprecation.showPyplotGlobalUse', False)
              st.pyplot(fig)

          with col2:
              st.subheader("Graph Explanation")
-             st.write("""
-                 The training graph shows the rewards received by the agent in each episode of the training process.
-                 The blue line represents the actual reward values, while the black line represents a rolling average.
-                 The red horizontal line indicates the threshold for considering the task solved.
-                 The green line represents the epsilon (exploration rate) values for the agent, indicating how often it takes random actions.
-             """)
+             st.write(insights)

      # Save the figure
      # plot_name = f'./data/train_rewards_{jammer_type}_csc_{channel_switching_cost}.png'
@@ -111,4 +112,22 @@ def train(jammer_type, channel_switching_cost):
      # # Save the agent as a SavedAgent.
      # agentName = f'./data/DDQNAgent_{jammer_type}_csc_{channel_switching_cost}'
      # DDQN_agent.save_model(agentName)
-     return DDQN_agent
+     return DDQN_agent
+
+
+ def generate_insights(rewards, rolling_average, epsilons, solved_threshold):
+     description = (
+         f"The graph represents training rewards over episodes. "
+         f"The actual rewards range from {min(rewards)} to {max(rewards)} with an average of {np.mean(rewards):.2f}. "
+         f"The rolling average values range from {min(rolling_average)} to {max(rolling_average)} with an average of {np.mean(rolling_average):.2f}. "
+         f"The epsilon values range from {min(epsilons)} to {max(epsilons)} with an average exploration rate of {np.mean(epsilons):.2f}. "
+         f"The solved threshold is set at {solved_threshold}. "
+         f"Provide insights based on this data."
+     )
+     input_ids = tokenizer.encode(description, return_tensors="pt")
+
+     # Generate output from model
+     output_ids = model.generate(input_ids, max_length=300, num_return_sequences=1)
+     output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+
+     return output_text
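One caveat on the model loading in the new trainer.py: tiiuae/falcon-7b-instruct is a decoder-only (causal) model, so AutoModelForSeq2SeqLM.from_pretrained should fail to load it with an unrecognized-configuration error; AutoModelForCausalLM is the class that matches Falcon. Two smaller points: generate's max_length budget includes the prompt tokens, and the description prompt here is long, so max_new_tokens is the safer limit; and a causal model echoes the prompt in its output unless the prompt tokens are sliced off before decoding. A hedged sketch of an equivalent causal-LM insight generator under those assumptions (generate_insights_causal is an illustrative name, not code from this commit):

from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "tiiuae/falcon-7b-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)  # causal-LM class matching Falcon

def generate_insights_causal(description):
    inputs = tokenizer(description, return_tensors="pt")
    output_ids = model.generate(
        **inputs,
        max_new_tokens=200,  # bounds the completion itself, not prompt + completion
        do_sample=True,
        top_k=10,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Keep only the newly generated tokens; the prompt is echoed at the front
    new_tokens = output_ids[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)

On a CPU-only Space, a 7B model is also very slow to load and to generate with; a smaller instruction-tuned model or a hosted inference endpoint is the usual workaround.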