// LibreChat/api/cache/banViolation.spec.js
// N.Achyuth Reddy · Upload 683 files · 9705b6c
// Unit under test.
const banViolation = require('./banViolation');
// No factory is supplied for these two modules, so Jest swaps in its automatic mock
// (or a manual mock from a `__mocks__` directory, if the project provides one).
jest.mock('keyv');
jest.mock('../models/Session');
// Stand-in for `./getLogStores`: the module becomes a jest.fn whose every call builds a
// fresh in-memory object shaped like a Mongo-backed Keyv adapter. The object exposes
// `opts` (with `namespace: 'bans'` and a `ttl` derived from BAN_DURATION) and jest.fn
// `get` / `set` methods.
jest.mock('./getLogStores', () =>
  jest.fn().mockImplementation(() => {
    // Required lazily inside the implementation, on every call.
    const EventEmitter = require('events');
    const math = require('../server/utils/math');
    const getStub = jest.fn();
    const setStub = jest.fn();

    class KeyvMongo extends EventEmitter {
      get = getStub;
      set = setStub;

      constructor(url = 'mongodb://127.0.0.1:27017', options) {
        super();
        this.ttlSupport = false;

        // Normalise the first argument into an options-like object:
        // string -> { url }, and a `uri` key is mirrored onto `url`.
        let connection = url ?? {};
        if (typeof connection === 'string') {
          connection = { url: connection };
        }
        if (connection.uri) {
          connection = { url: connection.uri, ...connection };
        }

        // Later spreads win: keys from `connection`, then `options`, override the defaults.
        this.opts = { url: connection, collection: 'keyv', ...connection, ...options };
      }
    }

    // BAN_DURATION is read at call time; 7200000 is the fallback handed to `math`.
    const banDuration = math(process.env.BAN_DURATION, 7200000);
    return new KeyvMongo('', { namespace: 'bans', ttl: banDuration });
  }),
);
// Unit tests for `banViolation(req, res, errorMessage)`.
//
// Each test adjusts `errorMessage.prev_count` / `violation_count` and/or the BAN_*
// environment variables, awaits `banViolation`, and then inspects `errorMessage.ban`
// and/or `res.clearCookie`.
describe('banViolation', () => {
  // Environment variables this suite overwrites; snapshotted so they can be restored.
  const BAN_ENV_KEYS = ['BAN_VIOLATIONS', 'BAN_DURATION', 'BAN_INTERVAL'];
  let savedEnv;
  let req, res, errorMessage;

  // Sets the before/after violation counters on the shared `errorMessage` fixture.
  const setCounts = (prevCount, violationCount) => {
    errorMessage.prev_count = prevCount;
    errorMessage.violation_count = violationCount;
  };

  beforeAll(() => {
    savedEnv = Object.fromEntries(BAN_ENV_KEYS.map((key) => [key, process.env[key]]));
  });

  afterAll(() => {
    // Assigning `undefined` to process.env would store the string 'undefined',
    // so variables that were originally unset are deleted instead.
    for (const key of BAN_ENV_KEYS) {
      if (savedEnv[key] === undefined) {
        delete process.env[key];
      } else {
        process.env[key] = savedEnv[key];
      }
    }
  });

  beforeEach(() => {
    req = {
      ip: '127.0.0.1',
      cookies: {
        refreshToken: 'someToken',
      },
    };
    res = {
      clearCookie: jest.fn(),
    };
    errorMessage = {
      type: 'someViolation',
      user_id: '12345',
      prev_count: 0,
      violation_count: 0,
    };
    process.env.BAN_VIOLATIONS = 'true';
    process.env.BAN_DURATION = '7200000'; // 2 hours in ms
    process.env.BAN_INTERVAL = '20';
  });

  afterEach(() => {
    jest.clearAllMocks();
  });

  // NOTE(review): the counters stay at 0 -> 0 here, which the "does not change" test below
  // also expects not to ban, so this test cannot tell the flag apart from the threshold
  // check. Confirm how the implementation reads BAN_VIOLATIONS before tightening it.
  it('should not ban if BAN_VIOLATIONS are not enabled', async () => {
    process.env.BAN_VIOLATIONS = 'false';
    await banViolation(req, res, errorMessage);
    expect(errorMessage.ban).toBeFalsy();
  });

  it('should not ban if errorMessage is not provided', async () => {
    // A rejection here fails the test, so awaiting doubles as the "does not throw" check.
    await banViolation(req, res, null);
    // The fixture object is never handed to the function, so asserting on it alone would be
    // vacuous; the observable ban side effect is the cookie being cleared.
    expect(res.clearCookie).not.toHaveBeenCalled();
    expect(errorMessage.ban).toBeFalsy();
  });

  // Fixed boundary values instead of Math.random(): titles and inputs are stable across
  // runs, so any failure is reproducible. 20 is the first value past the threshold and
  // 119 is the top of the range the random case used to draw from.
  it.each([
    [19, 39],
    [19, 20],
    [19, 119],
  ])(
    'should ban if violation_count crosses the interval threshold: %i -> %i',
    async (prevCount, violationCount) => {
      setCounts(prevCount, violationCount);
      await banViolation(req, res, errorMessage);
      expect(errorMessage.ban).toBeTruthy();
    },
  );

  it('should handle invalid BAN_INTERVAL and default to 20', async () => {
    process.env.BAN_INTERVAL = 'invalid';
    setCounts(19, 39);
    await banViolation(req, res, errorMessage);
    expect(errorMessage.ban).toBeTruthy();
  });

  it('should ban if BAN_DURATION is invalid as default is 2 hours', async () => {
    process.env.BAN_DURATION = 'invalid';
    setCounts(19, 39);
    await banViolation(req, res, errorMessage);
    expect(errorMessage.ban).toBeTruthy();
  });

  it('should not ban if BAN_DURATION is 0 but should clear cookies', async () => {
    process.env.BAN_DURATION = '0';
    setCounts(19, 39);
    await banViolation(req, res, errorMessage);
    expect(res.clearCookie).toHaveBeenCalledWith('refreshToken');
    // The title promises "should not ban"; assert it rather than only the cookie side effect.
    expect(errorMessage.ban).toBeFalsy();
  });

  it('should not ban if violation_count does not change', async () => {
    setCounts(0, 0);
    await banViolation(req, res, errorMessage);
    expect(errorMessage.ban).toBeFalsy();
  });

  // 1 and 19 are the ends of the range the random "under threshold" case used to draw from.
  it.each([
    [0, 19],
    [0, 1],
    [0, 10],
  ])(
    'should not ban if violation_count does not cross the interval threshold: %i -> %i',
    async (prevCount, violationCount) => {
      setCounts(prevCount, violationCount);
      await banViolation(req, res, errorMessage);
      expect(errorMessage.ban).toBeFalsy();
    },
  );

  it('[EDGE CASE] should not ban if violation_count is lower', async () => {
    setCounts(0, -10);
    await banViolation(req, res, errorMessage);
    expect(errorMessage.ban).toBeFalsy();
  });
});