matt HOFFNER commited on
Commit
a21e1b7
1 Parent(s): 87495f8

add hf model support

Browse files
components/Playground/index.tsx CHANGED
@@ -36,7 +36,7 @@ const Playground = () => {
36
  const [prevMarkdownCode, setPrevMarkdownCode] = useState(markdownCode);
37
  const [isSystemInputVisible, setSystemInputVisible] = useState(false);
38
  const [isModelInputVisible, setModelInputVisible] = useState(false);
39
- const [selectedModel, setSelectedModel] = useState("GPT-4");
40
 
41
  const [systemMessage, setSystemMessage] = useState(
42
  DEFAULT_PROMPT
@@ -48,8 +48,8 @@ const Playground = () => {
48
 
49
  const modifiedHandleSubmit = async (e: FormEvent<HTMLFormElement>, chatRequestOptions?: ChatRequestOptions) => {
50
  e.preventDefault();
51
- // Simply append the user message
52
- await handleSubmit(e, chatRequestOptions);
53
  };
54
 
55
  useEffect(() => {
@@ -67,7 +67,8 @@ const Playground = () => {
67
 
68
  const { append, messages, input, setInput, handleSubmit, ...rest } = useChat({
69
  body: {
70
- systemMessage: systemMessage
 
71
  },
72
  onError: (error) => {
73
  console.error(error);
@@ -149,11 +150,12 @@ const Playground = () => {
149
  {isModelInputVisible && (
150
  <div className="my-4">
151
  <select
152
- value={selectedModel}
153
- onChange={(e) => setSelectedModel(e.target.value)}
154
- className="border p-2 rounded-md shadow-sm w-full bg-transparent text-gray-700"
155
- >
156
- <option value="GPT-4">GPT-4</option>
 
157
  </select>
158
  </div>
159
  )}
 
36
  const [prevMarkdownCode, setPrevMarkdownCode] = useState(markdownCode);
37
  const [isSystemInputVisible, setSystemInputVisible] = useState(false);
38
  const [isModelInputVisible, setModelInputVisible] = useState(false);
39
+ const [aiProvider, setAIProvider] = useState<string>("openai");
40
 
41
  const [systemMessage, setSystemMessage] = useState(
42
  DEFAULT_PROMPT
 
48
 
49
  const modifiedHandleSubmit = async (e: FormEvent<HTMLFormElement>, chatRequestOptions?: ChatRequestOptions) => {
50
  e.preventDefault();
51
+ // Pass the aiProvider in chatRequestOptions
52
+ await handleSubmit(e, { ...chatRequestOptions, aiProvider } as any);
53
  };
54
 
55
  useEffect(() => {
 
67
 
68
  const { append, messages, input, setInput, handleSubmit, ...rest } = useChat({
69
  body: {
70
+ systemMessage: systemMessage,
71
+ aiProvider: aiProvider
72
  },
73
  onError: (error) => {
74
  console.error(error);
 
150
  {isModelInputVisible && (
151
  <div className="my-4">
152
  <select
153
+ value={aiProvider}
154
+ onChange={(e) => setAIProvider(e.target.value)}
155
+ className="border p-2 rounded-md shadow-sm w-full bg-transparent text-gray-700 mt-2"
156
+ >
157
+ <option value="openai">OpenAI</option>
158
+ <option value="huggingface">Hugging Face</option>
159
  </select>
160
  </div>
161
  )}
package-lock.json CHANGED
@@ -14,6 +14,7 @@
14
  "@emotion/react": "^11.10.4",
15
  "@graphql-codegen/cli": "^2.6.2",
16
  "@graphql-codegen/typescript-react-apollo": "^3.3.3",
 
17
  "@monaco-editor/react": "^4.2.0",
18
  "@reduxjs/toolkit": "^1.6.0",
19
  "@types/apollo-upload-client": "^17.0.1",
@@ -2234,6 +2235,14 @@
2234
  "graphql": "^0.8.0 || ^0.9.0 || ^0.10.0 || ^0.11.0 || ^0.12.0 || ^0.13.0 || ^14.0.0 || ^15.0.0 || ^16.0.0"
2235
  }
2236
  },
 
 
 
 
 
 
 
 
2237
  "node_modules/@jridgewell/gen-mapping": {
2238
  "version": "0.3.3",
2239
  "resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.3.tgz",
 
14
  "@emotion/react": "^11.10.4",
15
  "@graphql-codegen/cli": "^2.6.2",
16
  "@graphql-codegen/typescript-react-apollo": "^3.3.3",
17
+ "@huggingface/inference": "^2.6.4",
18
  "@monaco-editor/react": "^4.2.0",
19
  "@reduxjs/toolkit": "^1.6.0",
20
  "@types/apollo-upload-client": "^17.0.1",
 
2235
  "graphql": "^0.8.0 || ^0.9.0 || ^0.10.0 || ^0.11.0 || ^0.12.0 || ^0.13.0 || ^14.0.0 || ^15.0.0 || ^16.0.0"
2236
  }
2237
  },
2238
+ "node_modules/@huggingface/inference": {
2239
+ "version": "2.6.4",
2240
+ "resolved": "https://registry.npmjs.org/@huggingface/inference/-/inference-2.6.4.tgz",
2241
+ "integrity": "sha512-Xna7arltBSBoKaH3diGi3sYvkExgJMd/pF4T6vl2YbmDccbr1G/X5EPZ2048p+YgrJYG1jTYFCtY6Dr3HvJaow==",
2242
+ "engines": {
2243
+ "node": ">=18"
2244
+ }
2245
+ },
2246
  "node_modules/@jridgewell/gen-mapping": {
2247
  "version": "0.3.3",
2248
  "resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.3.tgz",
package.json CHANGED
@@ -16,6 +16,7 @@
16
  "@emotion/react": "^11.10.4",
17
  "@graphql-codegen/cli": "^2.6.2",
18
  "@graphql-codegen/typescript-react-apollo": "^3.3.3",
 
19
  "@monaco-editor/react": "^4.2.0",
20
  "@reduxjs/toolkit": "^1.6.0",
21
  "@types/apollo-upload-client": "^17.0.1",
 
16
  "@emotion/react": "^11.10.4",
17
  "@graphql-codegen/cli": "^2.6.2",
18
  "@graphql-codegen/typescript-react-apollo": "^3.3.3",
19
+ "@huggingface/inference": "^2.6.4",
20
  "@monaco-editor/react": "^4.2.0",
21
  "@reduxjs/toolkit": "^1.6.0",
22
  "@types/apollo-upload-client": "^17.0.1",
pages/api/chat/index.ts CHANGED
@@ -1,15 +1,22 @@
1
  import { OpenAIStream, StreamingTextResponse } from "ai";
2
  import { Configuration, OpenAIApi } from "openai-edge";
 
 
 
3
 
4
- const config = new Configuration({
 
5
  apiKey: process.env.OPENAI_API_KEY,
6
  });
7
- const openai = new OpenAIApi(config);
8
 
9
- export const runtime = "edge";
 
 
 
10
 
11
  export default async function(req: Request) {
12
- let { messages, systemMessage } = await req.json();
13
 
14
  // Prepend the system message if it's not already there
15
  if (messages.length === 0 || messages[0].role !== "system") {
@@ -19,12 +26,31 @@ export default async function(req: Request) {
19
  }, ...messages];
20
  }
21
 
22
- const response = await openai.createChatCompletion({
23
- model: 'gpt-4',
24
- stream: true,
25
- messages
26
- })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
- const stream = OpenAIStream(response);
29
- return new StreamingTextResponse(stream);
 
 
 
30
  }
 
1
  import { OpenAIStream, StreamingTextResponse } from "ai";
2
  import { Configuration, OpenAIApi } from "openai-edge";
3
+ import { HfInference } from '@huggingface/inference';
4
+ import { HuggingFaceStream } from 'ai';
5
+ import { experimental_buildLlama2Prompt } from 'ai/prompts'
6
 
7
+ // Configurations for OpenAI
8
+ const openaiConfig = new Configuration({
9
  apiKey: process.env.OPENAI_API_KEY,
10
  });
11
+ const openai = new OpenAIApi(openaiConfig);
12
 
13
+ // Create a new HuggingFace Inference instance
14
+ const Hf = new HfInference(process.env.HUGGINGFACE_API_KEY);
15
+
16
+ export const runtime = 'edge';
17
 
18
  export default async function(req: Request) {
19
+ let { messages, aiProvider = 'openai', systemMessage } = await req.json();
20
 
21
  // Prepend the system message if it's not already there
22
  if (messages.length === 0 || messages[0].role !== "system") {
 
26
  }, ...messages];
27
  }
28
 
29
+ if (aiProvider === 'openai') {
30
+ const response = await openai.createChatCompletion({
31
+ model: 'gpt-4',
32
+ stream: true,
33
+ messages
34
+ });
35
+
36
+ const stream = OpenAIStream(response);
37
+ return new StreamingTextResponse(stream);
38
+ } else if (aiProvider === 'huggingface') {
39
+ const response = Hf.textGenerationStream({
40
+ // @ts-ignore
41
+ model: 'meta-llama/Llama-2-7b-chat-hf',
42
+ inputs: experimental_buildLlama2Prompt(messages),
43
+ parameters: {
44
+ max_new_tokens: 500,
45
+ repetition_penalty: 1,
46
+ truncate: 4000,
47
+ return_full_text: false
48
+ }
49
+ })
50
 
51
+ const stream = HuggingFaceStream(response);
52
+ return new StreamingTextResponse(stream);
53
+ } else {
54
+ throw new Error(`Unsupported AI provider: ${aiProvider}`);
55
+ }
56
  }