|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// -----------------------------------------------------------------------------
// stdp_synapse
//
// Single synapse with trace-based weight adaptation.
//
// Each side (pre / post) keeps an unsigned "trace" register that is loaded with
// TRACE_MAX on a spike and otherwise decremented by TRACE_DECAY per clock until
// it reaches zero.
//
// When learn_enable is high:
//   * LTP: a post_spike while pre_trace is non-zero adds
//          (pre_trace >> LEARN_RATE) to the weight, saturating at WEIGHT_MAX.
//   * LTD: a pre_spike while post_trace is non-zero subtracts
//          (post_trace >> LEARN_RATE) from the weight, saturating at WEIGHT_MIN.
//   * If both conditions hold in the same cycle, LTD takes priority and the
//     LTP update is dropped.
//
// post_current is a registered output: it takes the value of `weight` (as seen
// before any update in that cycle) on a pre_spike and is 0 otherwise.
//
// Parameters
//   DATA_WIDTH   width of weight / post_current (signed)
//   TRACE_WIDTH  width of the trace registers (unsigned)
//   TRACE_MAX    value loaded into a trace on a spike
//   TRACE_DECAY  amount subtracted from a trace per clock
//   LEARN_RATE   right-shift applied to a trace to form the weight delta
//   WEIGHT_MAX   upper saturation limit for weight
//   WEIGHT_MIN   lower saturation limit for weight
//   WEIGHT_INIT  weight value after reset
//
// Ports
//   clk, rst_n       clock and asynchronous active-low reset
//   learn_enable     gates weight updates (traces always run)
//   pre_spike        presynaptic spike strobe
//   post_spike       postsynaptic spike strobe
//   weight           current synaptic weight
//   post_current     weight forwarded on a pre_spike, else 0
//   pre_trace_out    current value of the pre trace
//   post_trace_out   current value of the post trace
// -----------------------------------------------------------------------------
module stdp_synapse #(
    parameter DATA_WIDTH  = 16,
    parameter TRACE_WIDTH = 8,
    parameter TRACE_MAX   = 8'd127,
    parameter TRACE_DECAY = 8'd4,
    parameter LEARN_RATE  = 8'd4,
    parameter WEIGHT_MAX  = 16'd800,
    parameter WEIGHT_MIN  = -16'sd800,
    parameter WEIGHT_INIT = 16'd0
)(
    input  wire                         clk,
    input  wire                         rst_n,
    input  wire                         learn_enable,
    input  wire                         pre_spike,
    input  wire                         post_spike,
    output reg  signed [DATA_WIDTH-1:0] weight,
    output reg  signed [DATA_WIDTH-1:0] post_current,
    output wire [TRACE_WIDTH-1:0]       pre_trace_out,
    output wire [TRACE_WIDTH-1:0]       post_trace_out
);

    // Signed, DATA_WIDTH-wide copies of the saturation limits.
    // The default WEIGHT_MAX (16'd800) is an *unsigned* constant; comparing a
    // signed expression against it directly makes the whole comparison
    // unsigned, so a negative weight would look larger than WEIGHT_MAX.
    // Routing the limits through typed localparams keeps comparisons signed
    // regardless of how the parameters are written or overridden.
    localparam signed [DATA_WIDTH-1:0] W_MAX = WEIGHT_MAX;
    localparam signed [DATA_WIDTH-1:0] W_MIN = WEIGHT_MIN;

    reg [TRACE_WIDTH-1:0] pre_trace;
    reg [TRACE_WIDTH-1:0] post_trace;

    assign pre_trace_out  = pre_trace;
    assign post_trace_out = post_trace;

    // Weight deltas: zero-extended trace shifted right by LEARN_RATE.
    // The concatenation is unsigned, so the shift is a logical shift and the
    // result is non-negative when DATA_WIDTH > TRACE_WIDTH.
    wire signed [DATA_WIDTH-1:0] ltp_delta;
    wire signed [DATA_WIDTH-1:0] ltd_delta;

    assign ltp_delta = {{(DATA_WIDTH-TRACE_WIDTH){1'b0}}, pre_trace}  >> LEARN_RATE;
    assign ltd_delta = {{(DATA_WIDTH-TRACE_WIDTH){1'b0}}, post_trace} >> LEARN_RATE;

    // Candidate next weights, computed one bit wider than the weight so the
    // saturation test is performed on the exact (non-wrapped) result.
    wire signed [DATA_WIDTH:0] ltp_sum  = weight + ltp_delta;
    wire signed [DATA_WIDTH:0] ltd_diff = weight - ltd_delta;

    wire ltp_event = post_spike && (pre_trace  > 0);
    wire ltd_event = pre_spike  && (post_trace > 0);

    always @(posedge clk or negedge rst_n) begin
        if (!rst_n) begin
            pre_trace    <= 0;
            post_trace   <= 0;
            weight       <= WEIGHT_INIT;
            post_current <= 0;
        end else begin
            // Pre trace: reload on spike, otherwise decay and floor at zero.
            if (pre_spike)
                pre_trace <= TRACE_MAX;
            else if (pre_trace > TRACE_DECAY)
                pre_trace <= pre_trace - TRACE_DECAY;
            else
                pre_trace <= 0;

            // Post trace: same scheme as the pre trace.
            if (post_spike)
                post_trace <= TRACE_MAX;
            else if (post_trace > TRACE_DECAY)
                post_trace <= post_trace - TRACE_DECAY;
            else
                post_trace <= 0;

            // Weight update. Both tests use the trace values from before this
            // clock edge. LTD is checked first so that it wins when both
            // events occur in the same cycle.
            if (learn_enable) begin
                if (ltd_event) begin
                    if (ltd_diff < W_MIN)
                        weight <= W_MIN;
                    else
                        weight <= ltd_diff[DATA_WIDTH-1:0];
                end else if (ltp_event) begin
                    if (ltp_sum > W_MAX)
                        weight <= W_MAX;
                    else
                        weight <= ltp_sum[DATA_WIDTH-1:0];
                end
            end

            // Output current: forwards the pre-update weight on a pre spike.
            if (pre_spike)
                post_current <= weight;
            else
                post_current <= 0;
        end
    end

endmodule
|
|
|