Switch to Sage
- app.py +11 -32
- config.json +2 -2
- requirements.txt +2 -1
app.py
CHANGED
@@ -3,8 +3,7 @@ import time
 import json
 import requests
 import gradio as gr
-import
-from transformers import GPT2LMHeadModel, GPT2Tokenizer
+import poe
 
 with open("config.json", "r") as f:
     config = json.load(f)
@@ -15,41 +14,21 @@ max_attempts = config["MAX_ATTEMPS"]
 wait_time = config["WAIT_TIME"]
 chatgpt_url = config["CHATGPT_URL"]
 system_prompt = config["SYSTEM_PROMPT"]
-
-
-
-
-
-if use_sber:
-    tokenizer = GPT2Tokenizer.from_pretrained(sber_gpt)
-    model = GPT2LMHeadModel.from_pretrained(sber_gpt).to(DEVICE)
-
-def generate(
-    model, tok, text,
-    do_sample=True, max_length=10000, repetition_penalty=5.0,
-    top_k=5, top_p=0.95, temperature=1,
-    num_beams=10,
-    no_repeat_ngram_size=3
-):
-    input_ids = tok.encode(text, return_tensors="pt").to(DEVICE)
-    out = model.generate(
-        input_ids.to(DEVICE),
-        max_length=max_length,
-        repetition_penalty=repetition_penalty,
-        do_sample=do_sample,
-        top_k=top_k, top_p=top_p, temperature=temperature,
-        num_beams=num_beams, no_repeat_ngram_size=no_repeat_ngram_size
-    )
-    return list(map(tok.decode, out))[0]
+use_sage = config["USE_SAGE"]
+sage_token = config["SAGE_TOKEN"]
+
+client = poe.Client(sage_token)
 
 def get_answer(question: str) -> Dict[str, Any]:
-
-
+
+    if use_sage:
+        for chunk in client.send_message("capybara", question, with_chat_break=True):
+            pass
     return {
        'status': True,
-       'content':
+       'content': chunk["text"]
     }
-
+
 headers = {
     'Content-Type': 'application/json; charset=utf-8'
 }
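For reference, a minimal standalone sketch of the new code path (a sketch, not the full app.py: it assumes poe-api's send_message returns a generator of chunks whose "text" field holds the accumulated reply so far, and that "capybara" is Poe's handle for the Sage bot; ask_sage is an illustrative helper name, not a function in app.py):

import json
import poe

with open("config.json", "r") as f:
    config = json.load(f)

# Authenticate the Poe client with the token stored in config.json.
client = poe.Client(config["SAGE_TOKEN"])

def ask_sage(question: str) -> str:
    # Drain the streaming generator; the last chunk's "text" field
    # contains the complete answer.
    answer = ""
    for chunk in client.send_message("capybara", question, with_chat_break=True):
        answer = chunk["text"]
    return answer

Unlike the loop in the diff, this sketch initializes answer before iterating, so it does not depend on the loop variable being set when the stream is empty.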
config.json
CHANGED
@@ -4,7 +4,7 @@
 "MAX_ATTEMPS": 5,
 "WAIT_TIME": 1,
 "CHATGPT_URL": "https://free.churchless.tech/v1/chat/completions",
-"
-"
+"USE_SAGE": 1,
+"SAGE_TOKEN": "PGUXiyEZKRHcMoij9AjxXw%3D%3D",
 "SYSTEM_PROMPT": "Your task is to give the most detailed answer to the question posed. At the beginning of the question, there are tags in square brackets specifying the subject of the question. It is necessary to answer in the language of the user's question"
 }
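The two new keys are read from the same config.json that app.py already loads. A minimal sketch of how they are consumed (the comment on the token is an assumption: it looks like a URL-encoded Poe session cookie value, since "%3D%3D" decodes to "=="):

import json

with open("config.json", "r") as f:
    config = json.load(f)

use_sage = config["USE_SAGE"]      # 1/0 flag, used directly as truthy in `if use_sage:`
sage_token = config["SAGE_TOKEN"]  # passed verbatim to poe.Client(...)

With USE_SAGE set to 1, questions are routed through Poe's Sage bot; with 0, app.py would presumably fall back to the CHATGPT_URL endpoint configured above (that path is unchanged and not shown in this diff).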
requirements.txt
CHANGED
@@ -1,2 +1,3 @@
 torch
-transformers
+transformers
+poe-api
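For local testing, the updated dependency set installs with pip install -r requirements.txt; poe-api is the only new package, while torch and transformers remain listed from the previous setup.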