File size: 3,272 Bytes
9705b6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
const BaseClient = require('../BaseClient');
const { getModelMaxTokens } = require('../../../utils');

class FakeClient extends BaseClient {
  /**
   * Minimal BaseClient implementation used by tests: stubs out completion
   * and message-building and provides trivial token counting.
   * @param {string} apiKey - API key forwarded to BaseClient.
   * @param {object} [options] - Client options; see setOptions for merge rules.
   */
  constructor(apiKey, options = {}) {
    super(apiKey, options);
    this.sender = 'AI Assistant';
    this.setOptions(options);
  }

  /**
   * Merge (or replace) client options and derive model options and the
   * context-token limit. Unless `replaceOptions` is set on the existing
   * options, incoming options are shallow-merged, with `modelOptions`
   * merged separately.
   * @param {object} options - Options to merge into `this.options`.
   */
  setOptions(options) {
    if (this.options && !this.options.replaceOptions) {
      this.options.modelOptions = {
        ...this.options.modelOptions,
        ...options.modelOptions,
      };
      // NOTE(review): mutates the caller's `options` object; kept as-is since
      // tests may rely on this matching the real client's behavior.
      delete options.modelOptions;
      this.options = {
        ...this.options,
        ...options,
      };
    } else {
      this.options = options;
    }

    if (this.options.openaiApiKey) {
      this.apiKey = this.options.openaiApiKey;
    }

    const modelOptions = this.options.modelOptions || {};
    // Model options are only initialized once; later setOptions calls keep
    // the first `this.modelOptions` (only maxContextTokens is re-derived).
    if (!this.modelOptions) {
      this.modelOptions = {
        ...modelOptions,
        model: modelOptions.model || 'gpt-3.5-turbo',
        temperature:
          typeof modelOptions.temperature === 'undefined' ? 0.8 : modelOptions.temperature,
        top_p: typeof modelOptions.top_p === 'undefined' ? 1 : modelOptions.top_p,
        presence_penalty:
          typeof modelOptions.presence_penalty === 'undefined' ? 1 : modelOptions.presence_penalty,
        stop: modelOptions.stop,
      };
    }

    this.maxContextTokens = getModelMaxTokens(this.modelOptions.model) ?? 4097;
  }

  getCompletion() {}

  buildMessages() {}

  /**
   * Fake token counter: one token per character.
   * @param {string} str
   * @returns {number}
   */
  getTokenCount(str) {
    return str.length;
  }

  /**
   * Fake per-message token count: the length of `message.content` when
   * present, otherwise the message's own length (e.g. a plain string),
   * otherwise 0.
   *
   * Fixed: the original used `||`, so an empty-string `content` (length 0)
   * fell through to `message.length` (undefined for plain objects), and a
   * nullish `message` threw a TypeError instead of returning a count.
   * @param {{content?: string}|string|null|undefined} message
   * @returns {number}
   */
  getTokenCountForMessage(message) {
    return message?.content?.length ?? message?.length ?? 0;
  }
}

/**
 * Build a FakeClient wired with jest mocks for use in BaseClient tests.
 * @param {string} apiKey - API key passed to the FakeClient constructor.
 * @param {object} options - Options assigned directly onto the client.
 * @param {Array<object>} fakeMessages - Backing message store for loadHistory.
 * @returns {FakeClient} The mocked client instance.
 */
const initializeFakeClient = (apiKey, options, fakeMessages) => {
  const TestClient = new FakeClient(apiKey);
  TestClient.options = options;
  TestClient.abortController = { abort: jest.fn() };
  TestClient.saveMessageToDatabase = jest.fn();

  // Resolves the ordered conversation thread from the fake message store;
  // an absent conversationId yields an empty history.
  TestClient.loadHistory = jest.fn((conversationId, parentMessageId = null) => {
    if (!conversationId) {
      TestClient.currentMessages = [];
      return Promise.resolve([]);
    }

    const threadedMessages = TestClient.constructor.getMessagesForConversation({
      messages: fakeMessages,
      parentMessageId,
    });
    TestClient.currentMessages = threadedMessages;
    return Promise.resolve(threadedMessages);
  });

  TestClient.getSaveOptions = jest.fn(() => ({}));

  TestClient.getBuildMessagesOptions = jest.fn(() => ({}));

  TestClient.sendCompletion = jest.fn().mockResolvedValue('Mock response text');

  // Orders the thread, then normalizes each message into the
  // { role, content } shape expected by completion APIs.
  TestClient.buildMessages = jest.fn(async (messages, parentMessageId) => {
    const threadedMessages = TestClient.constructor.getMessagesForConversation({
      messages,
      parentMessageId,
    });

    const prompt = threadedMessages.map(({ role: rawRole, sender, text }) => {
      const speaker = (rawRole ?? sender)?.toLowerCase();
      return {
        role: speaker === 'user' ? 'user' : 'assistant',
        content: text ?? '',
      };
    });

    return {
      prompt,
      tokenCountMap: null, // Simplified for the mock
    };
  });

  return TestClient;
};

module.exports = { FakeClient, initializeFakeClient };