daniloedu committed on
Commit
94a4897
·
1 Parent(s): 8c5a3b3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -30
app.py CHANGED
@@ -2,47 +2,37 @@ import os
2
  import requests
3
  import gradio as gr
4
  from dotenv import load_dotenv
5
- from transformers import AutoTokenizer
6
 
7
  load_dotenv()
8
 
9
- model_name = "tiiuae/falcon-7b-instruct"
10
- tokenizer = AutoTokenizer.from_pretrained(model_name)
 
11
 
12
- API_URL = "https://api-inference.huggingface.co/models/tiiuae/falcon-7b-instruct"
13
  headers = {"Authorization": f"Bearer {os.getenv('HF_API_KEY')}"}
14
 
15
- def format_chat_prompt(message, instruction):
16
- prompt = f"System:{instruction}\nUser: {message}\nAssistant:"
17
- return prompt
18
-
19
- def query(payload):
20
- response = requests.post(API_URL, headers=headers, json=payload)
21
  return response.json()
22
-
23
- def respond(message, instruction="A conversation between a user and an AI assistant. The assistant gives helpful and honest answers."):
24
- MAX_TOKENS = 1024 # limit for the model
25
- prompt = format_chat_prompt(message, instruction)
26
- # Check if the prompt is too long and, if so, truncate it
27
- num_tokens = len(tokenizer.encode(prompt))
28
- if num_tokens > MAX_TOKENS:
29
- # Truncate the prompt to fit within the token limit
30
- prompt = tokenizer.decode(tokenizer.encode(prompt)[-MAX_TOKENS:])
31
-
32
- response = query({"inputs": prompt})
33
- generated_text = response[0]['generated_text']
34
- assistant_message = generated_text.split("Assistant:")[-1]
35
- assistant_message = assistant_message.split("User:")[0].strip() # Only keep the text before the first "User:"
36
- return assistant_message
37
 
38
  iface = gr.Interface(
39
  respond,
40
- inputs=[
41
- gr.inputs.Textbox(label="Your question"),
42
- gr.inputs.Textbox(label="System message", lines=2, default="A conversation between a user and an AI assistant. The assistant gives helpful and honest answers.")
43
- ],
44
  outputs=[
45
- gr.outputs.Textbox(label="AI's response")
 
 
46
  ],
47
  )
48
 
 
2
  import requests
3
  import gradio as gr
4
  from dotenv import load_dotenv
 
5
 
6
  load_dotenv()
7
 
8
# Hugging Face Inference API endpoints for the three chat models that the
# UI queries side-by-side for comparison.
API_URL_FALCON = "https://api-inference.huggingface.co/models/tiiuae/falcon-7b-instruct"
API_URL_GUANACO = "https://api-inference.huggingface.co/models/timdettmers/guanaco-33b-merged"
API_URL_PYTHIA = "https://api-inference.huggingface.co/models/OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5"

# Bearer token read from the environment; `load_dotenv()` above makes a local
# .env file visible here. If HF_API_KEY is unset this becomes "Bearer None"
# and the API will reject requests — NOTE(review): consider failing fast.
headers = {"Authorization": f"Bearer {os.getenv('HF_API_KEY')}"}
13
 
14
def query(api_url, payload, timeout=30):
    """POST `payload` as JSON to `api_url` and return the parsed JSON reply.

    Uses the module-level `headers` for the HF bearer-token auth.

    Args:
        api_url: full Inference API endpoint URL for one model.
        payload: JSON-serializable request body, e.g. {"inputs": prompt}.
        timeout: seconds to wait for the HTTP response. Without a timeout,
            requests can block forever on a stalled endpoint (e.g. while a
            model is cold-loading), freezing the Gradio app.

    Returns:
        The decoded JSON response — a list of generations on success, or an
        error dict such as {"error": ...} from the Inference API.
    """
    response = requests.post(api_url, headers=headers, json=payload, timeout=timeout)
    return response.json()
17
+
18
def respond(message):
    """Send the same prompt to all three models and return their completions.

    Args:
        message: the raw user prompt forwarded as {"inputs": message}.

    Returns:
        A 3-tuple of strings: (falcon_text, guanaco_text, pythia_text).
        On an API failure the corresponding slot carries a readable error
        message instead of crashing the interface.
    """

    def _extract_text(response):
        # Successful Inference API calls return [{"generated_text": ...}];
        # failures (model loading, rate limiting, bad token) return
        # {"error": ...}. Indexing a dict with [0] would raise, so branch
        # on the actual shape.
        if isinstance(response, list) and response and 'generated_text' in response[0]:
            return response[0]['generated_text']
        if isinstance(response, dict) and 'error' in response:
            return f"Error: {response['error']}"
        return f"Unexpected response: {response}"

    generated_text_falcon = _extract_text(query(API_URL_FALCON, {"inputs": message}))
    generated_text_guanaco = _extract_text(query(API_URL_GUANACO, {"inputs": message}))
    generated_text_pythia = _extract_text(query(API_URL_PYTHIA, {"inputs": message}))

    return generated_text_falcon, generated_text_guanaco, generated_text_pythia
 
 
 
 
28
 
29
# Build the comparison UI: one prompt box in, one output box per model.
# `gr.inputs.*` / `gr.outputs.*` are the legacy pre-3.x namespaces (removed
# in modern Gradio); the plain `gr.Textbox` component works for both roles.
iface = gr.Interface(
    respond,
    inputs=gr.Textbox(label="Prompt"),
    outputs=[
        gr.Textbox(label="Falcon Response"),
        gr.Textbox(label="Guanaco Response"),
        gr.Textbox(label="Pythia Response"),
    ],
)
38