Integrating the falconLLM
- app.py +18 -23
- tester.py +6 -42
- trainer.py +11 -10
app.py CHANGED
@@ -5,6 +5,8 @@ import streamlit as st
 import os
 from trainer import train
 from tester import test
+import transformers
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
 def main():
@@ -24,32 +26,25 @@ def main():
     st.sidebar.write(f"Jammer Type: {jammer_type}")
     st.sidebar.write(f"Channel Switching Cost: {channel_switching_cost}")
 
-
-    test_button = st.sidebar.button('Test')
-
-    if train_button or test_button:
-        agent_name = f'DDQNAgent_{jammer_type}_csc_{channel_switching_cost}'
-        if os.path.exists(agent_name):
-            if train_button:
-                st.warning("Agent has been trained already! Do you want to retrain?")
-                retrain = st.sidebar.button('Yes')
-                if retrain:
-                    perform_training(jammer_type, channel_switching_cost)
-            elif test_button:
-                perform_testing(jammer_type, channel_switching_cost)
-        else:
-            if train_button:
-                perform_training(jammer_type, channel_switching_cost)
-            elif test_button:
-                st.warning("Agent has not been trained yet. Click Train First!!!")
+    start_button = st.sidebar.button('Start')
 
+    if start_button:
+        agent = perform_training(jammer_type, channel_switching_cost)
+        test(agent, jammer_type, channel_switching_cost)
 
 def perform_training(jammer_type, channel_switching_cost):
-    train(jammer_type, channel_switching_cost)
-
-
-
-
+    agent = train(jammer_type, channel_switching_cost)
+    model_name = "tiiuae/falcon-7b-instruct"
+    model = AutoModelForCausalLM.from_pretrained(model_name)
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    pipeline = transformers.pipeline("text-generation", model=model, tokenizer=tokenizer, max_length=100, temperature=0.7)
+    text = pipeline("Discuss this topic: Integrating LLMs to DRL-based anti-jamming.")
+    st.write(text)
+    return agent
+
+
+def perform_testing(agent, jammer_type, channel_switching_cost):
+    test(agent, jammer_type, channel_switching_cost)
 
 
 if __name__ == "__main__":
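Note that the committed perform_training calls AutoModelForCausalLM.from_pretrained and AutoTokenizer.from_pretrained on every press of the Start button, so the Falcon-7B-Instruct weights are re-downloaded or re-loaded on each run. A minimal sketch of one way to avoid that, assuming Streamlit's st.cache_resource is available (Streamlit 1.18+); the helper name load_falcon_pipeline is illustrative and not part of this commit:

# Sketch only: cache the Falcon pipeline so it is built once per Space process
# instead of on every click. load_falcon_pipeline is a hypothetical helper name.
import streamlit as st
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
from trainer import train  # same local module used by the commit

@st.cache_resource  # assumption: Streamlit >= 1.18 on the Space
def load_falcon_pipeline(model_name="tiiuae/falcon-7b-instruct"):
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return transformers.pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_length=100,
        temperature=0.7,
    )

def perform_training(jammer_type, channel_switching_cost):
    agent = train(jammer_type, channel_switching_cost)
    generator = load_falcon_pipeline()
    text = generator("Discuss this topic: Integrating LLMs to DRL-based anti-jamming.")
    st.write(text)
    return agent

The committed version also reuses the name pipeline for the pipeline object, which shadows transformers.pipeline inside the function; a distinct name such as generator avoids that.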
tester.py CHANGED
@@ -9,26 +9,17 @@ from DDQN import DoubleDeepQNetwork
 from antiJamEnv import AntiJamEnv
 
 
-def test(jammer_type, channel_switching_cost):
+def test(agent, jammer_type, channel_switching_cost):
     env = AntiJamEnv(jammer_type, channel_switching_cost)
     ob_space = env.observation_space
     ac_space = env.action_space
 
     s_size = ob_space.shape[0]
     a_size = ac_space.n
-    max_env_steps =
-    TEST_Episodes =
+    max_env_steps = 3
+    TEST_Episodes = 1
     env._max_episode_steps = max_env_steps
-
-    epsilon = 1.0  # exploration rate
-    epsilon_min = 0.01
-    epsilon_decay = 0.999
-    discount_rate = 0.95
-    lr = 0.001
-
-    agentName = f'./data/DDQNAgent_{jammer_type}_csc_{channel_switching_cost}'
-    DDQN_agent = DoubleDeepQNetwork(s_size, a_size, lr, discount_rate, epsilon, epsilon_min, epsilon_decay)
-    DDQN_agent.model = DDQN_agent.load_saved_model(agentName)
+    DDQN_agent = agent
     rewards = []  # Store rewards for graphing
     epsilons = []  # Store the Explore/Exploit
 
@@ -47,35 +38,8 @@ def test(jammer_type, channel_switching_cost):
                 break
             next_state = np.reshape(next_state, [1, s_size])
             tot_rewards += reward
+
+            st.write(f"The state is: {state}, action taken is: {action}, obtained reward is: {reward}")
             # DON'T STORE ANYTHING DURING TESTING
             state = next_state
 
-    # Plotting
-    rolling_average = np.convolve(rewards, np.ones(10) / 10, mode='valid')
-
-    # Create a new Streamlit figure
-    fig = plt.figure()
-    plt.plot(rewards, label='Rewards')
-    plt.plot(rolling_average, color='black', label='Rolling Average')
-    plt.axhline(y=max_env_steps - 0.10 * max_env_steps, color='r', linestyle='-', label='Solved Line')
-    eps_graph = [100 * x for x in epsilons]
-    plt.plot(eps_graph, color='g', linestyle='-', label='Epsilons')
-    plt.xlabel('Episodes')
-    plt.ylabel('Rewards')
-    plt.title(f'Testing Rewards - {jammer_type}, CSC: {channel_switching_cost}')
-    plt.legend()
-
-    # Display the Streamlit figure using streamlit.pyplot
-    st.set_option('deprecation.showPyplotGlobalUse', False)
-    st.pyplot(fig)
-
-    # Save the figure
-    plot_name = f'./data/test_rewards_{jammer_type}_csc_{channel_switching_cost}.png'
-    plt.savefig(plot_name, bbox_inches='tight')
-    plt.close(fig)  # Close the figure to release resources
-
-    # Save Results
-    # Rewards
-    fileName = f'./data/test_rewards_{jammer_type}_csc_{channel_switching_cost}.json'
-    with open(fileName, 'w') as f:
-        json.dump(rewards, f)
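With this change test() consumes the in-memory agent returned by train() (DDQN_agent = agent), so the exploration/learning-rate setup and the load_saved_model call are gone, and the reward plotting and JSON export were removed along with them. If the reward plot is still wanted, the deleted block could be kept as a small helper that test() calls after the episode loop; a sketch under that assumption (the helper name plot_test_rewards is illustrative), keeping in mind that with TEST_Episodes = 1 the 10-episode rolling average is not meaningful:

# Sketch only: the plotting logic deleted by this commit, factored into a helper.
# Not part of the committed code.
import numpy as np
import matplotlib.pyplot as plt
import streamlit as st

def plot_test_rewards(rewards, epsilons, max_env_steps, jammer_type, channel_switching_cost):
    rolling_average = np.convolve(rewards, np.ones(10) / 10, mode='valid')
    fig = plt.figure()
    plt.plot(rewards, label='Rewards')
    plt.plot(rolling_average, color='black', label='Rolling Average')
    plt.axhline(y=max_env_steps - 0.10 * max_env_steps, color='r', linestyle='-', label='Solved Line')
    plt.plot([100 * x for x in epsilons], color='g', linestyle='-', label='Epsilons')
    plt.xlabel('Episodes')
    plt.ylabel('Rewards')
    plt.title(f'Testing Rewards - {jammer_type}, CSC: {channel_switching_cost}')
    plt.legend()
    st.pyplot(fig)   # render inside the Streamlit app
    plt.close(fig)   # release the figure resources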
trainer.py CHANGED
@@ -21,7 +21,7 @@ def train(jammer_type, channel_switching_cost):
     s_size = ob_space.shape[0]
     a_size = ac_space.n
     max_env_steps = 100
-    TRAIN_Episodes =
+    TRAIN_Episodes = 25
     env._max_episode_steps = max_env_steps
 
     epsilon = 1.0
@@ -85,16 +85,17 @@ def train(jammer_type, channel_switching_cost):
     st.pyplot(fig)
 
     # Save the figure
-    plot_name = f'./data/train_rewards_{jammer_type}_csc_{channel_switching_cost}.png'
-    plt.savefig(plot_name, bbox_inches='tight')
+    # plot_name = f'./data/train_rewards_{jammer_type}_csc_{channel_switching_cost}.png'
+    # plt.savefig(plot_name, bbox_inches='tight')
     plt.close(fig)  # Close the figure to release resources
 
     # Save Results
     # Rewards
-    fileName = f'./data/train_rewards_{jammer_type}_csc_{channel_switching_cost}.json'
-    with open(fileName, 'w') as f:
-        json.dump(rewards, f)
-
-    # Save the agent as a SavedAgent.
-    agentName = f'./data/DDQNAgent_{jammer_type}_csc_{channel_switching_cost}'
-    DDQN_agent.save_model(agentName)
+    # fileName = f'./data/train_rewards_{jammer_type}_csc_{channel_switching_cost}.json'
+    # with open(fileName, 'w') as f:
+    #     json.dump(rewards, f)
+    #
+    # # Save the agent as a SavedAgent.
+    # agentName = f'./data/DDQNAgent_{jammer_type}_csc_{channel_switching_cost}'
+    # DDQN_agent.save_model(agentName)
+    return DDQN_agent
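train() now returns the live DDQN_agent while the figure, JSON, and saved-model writes are commented out, so nothing is persisted to ./data/ after this commit. If artifacts are still wanted alongside the returned agent, the commented-out block could be reinstated behind a small helper; a sketch under the assumption that DoubleDeepQNetwork still exposes the save_model() method shown in the removed lines (the helper name save_training_artifacts is illustrative):

# Sketch only: persist rewards and the trained agent before train() returns it.
# Assumes DDQN_agent.save_model() from the pre-commit code still exists.
import os
import json

def save_training_artifacts(DDQN_agent, rewards, jammer_type, channel_switching_cost):
    os.makedirs('./data', exist_ok=True)  # the Space may not ship with this folder
    fileName = f'./data/train_rewards_{jammer_type}_csc_{channel_switching_cost}.json'
    with open(fileName, 'w') as f:
        json.dump(rewards, f)
    agentName = f'./data/DDQNAgent_{jammer_type}_csc_{channel_switching_cost}'
    DDQN_agent.save_model(agentName)

train() could call this helper right before return DDQN_agent, keeping the new in-memory train-then-test flow while restoring reproducible artifacts on disk.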