Spaces:
Runtime error
Runtime error
Added Sber Gpt
Browse files- app.py +35 -0
- config.json +2 -0
- requirements.txt +2 -0
app.py
CHANGED
@@ -3,6 +3,8 @@ import time
|
|
3 |
import json
|
4 |
import requests
|
5 |
import gradio as gr
|
|
|
|
|
6 |
|
7 |
with open("config.json", "r") as f:
|
8 |
config = json.load(f)
|
@@ -13,8 +15,41 @@ max_attempts = config["MAX_ATTEMPS"]
|
|
13 |
wait_time = config["WAIT_TIME"]
|
14 |
chatgpt_url = config["CHATGPT_URL"]
|
15 |
system_prompt = config["SYSTEM_PROMPT"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
def get_answer(question: str) -> Dict[str, Any]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
headers = {
|
19 |
'Content-Type': 'application/json; charset=utf-8'
|
20 |
}
|
|
|
3 |
import json
|
4 |
import requests
|
5 |
import gradio as gr
|
6 |
+
import torch
|
7 |
+
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
8 |
|
9 |
with open("config.json", "r") as f:
|
10 |
config = json.load(f)
|
|
|
15 |
wait_time = config["WAIT_TIME"]
|
16 |
chatgpt_url = config["CHATGPT_URL"]
|
17 |
system_prompt = config["SYSTEM_PROMPT"]
|
18 |
+
sber_gpt = config["SBER_GRT"]
|
19 |
+
use_sber = config["USE_SBER"]
|
20 |
+
|
21 |
+
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
22 |
+
|
23 |
+
if use_sber:
|
24 |
+
tokenizer = GPT2Tokenizer.from_pretrained(sber_gpt)
|
25 |
+
model = GPT2LMHeadModel.from_pretrained(sber_gpt).to(DEVICE)
|
26 |
+
|
27 |
+
def generate(
|
28 |
+
model, tok, text,
|
29 |
+
do_sample=True, max_length=10000, repetition_penalty=5.0,
|
30 |
+
top_k=5, top_p=0.95, temperature=1,
|
31 |
+
num_beams=None,
|
32 |
+
no_repeat_ngram_size=3
|
33 |
+
):
|
34 |
+
input_ids = tok.encode(text, return_tensors="pt").to(DEVICE)
|
35 |
+
out = model.generate(
|
36 |
+
input_ids.to(DEVICE),
|
37 |
+
max_length=max_length,
|
38 |
+
repetition_penalty=repetition_penalty,
|
39 |
+
do_sample=do_sample,
|
40 |
+
top_k=top_k, top_p=top_p, temperature=temperature,
|
41 |
+
num_beams=num_beams, no_repeat_ngram_size=no_repeat_ngram_size
|
42 |
+
)
|
43 |
+
return list(map(tok.decode, out))[0]
|
44 |
|
45 |
def get_answer(question: str) -> Dict[str, Any]:
|
46 |
+
if use_sber:
|
47 |
+
content = generate(model, tokenizer, question)
|
48 |
+
return {
|
49 |
+
'status': True,
|
50 |
+
'content': content
|
51 |
+
}
|
52 |
+
|
53 |
headers = {
|
54 |
'Content-Type': 'application/json; charset=utf-8'
|
55 |
}
|
config.json
CHANGED
@@ -4,5 +4,7 @@
|
|
4 |
"MAX_ATTEMPS": 5,
|
5 |
"WAIT_TIME": 1,
|
6 |
"CHATGPT_URL": "https://free.churchless.tech/v1/chat/completions",
|
|
|
|
|
7 |
"SYSTEM_PROMPT": "Your task is to give the most detailed answer to the question posed. At the beginning of the question, there are tags in square brackets specifying the subject of the question. It is necessary to answer in the language of the user's question"
|
8 |
}
|
|
|
4 |
"MAX_ATTEMPS": 5,
|
5 |
"WAIT_TIME": 1,
|
6 |
"CHATGPT_URL": "https://free.churchless.tech/v1/chat/completions",
|
7 |
+
"SBER_GRT": "sberbank-ai/ruT5-large",
|
8 |
+
"USE_SBER": 1,
|
9 |
"SYSTEM_PROMPT": "Your task is to give the most detailed answer to the question posed. At the beginning of the question, there are tags in square brackets specifying the subject of the question. It is necessary to answer in the language of the user's question"
|
10 |
}
|
requirements.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
transformers
|