Makaria commited on
Commit
ee6ad7f
·
1 Parent(s): 0adbdde
Files changed (1) hide show
  1. app.py +30 -26
app.py CHANGED
@@ -1,40 +1,44 @@
 
1
  import torch
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import gradio as gr
4
 
5
- # Загрузка модели и токенизатора
6
- model_name = "DialoGPT-medium" # Используем Diálogo GPT для диалогов
7
- tokenizer = AutoTokenizer.from_pretrained(model_name)
8
- model = AutoModelForCausalLM.from_pretrained(model_name)
 
 
 
9
 
10
  # Функция для общения с моделью
11
- def chat_with_model(user_input, chat_history):
12
- # Токенизация пользовательского ввода
13
- new_user_input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors='pt')
14
 
15
- # Объединение истории чата и нового ввода
16
- if chat_history:
17
  bot_input_ids = torch.cat([torch.tensor(chat_history), new_user_input_ids], dim=-1)
18
  else:
19
  bot_input_ids = new_user_input_ids
20
-
21
  # Генерация ответа от модели
22
- chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id,
23
- temperature=0.8, top_p=0.9, do_sample=True)
24
-
25
- # Получаем ответ бота и обновляем историю чата
26
- bot_output = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
27
-
28
- # Возвращаем новый ввод и обновлённую историю чата без лишних цифр
29
- return bot_output, chat_history_ids.tolist()
30
 
31
  # Интерфейс Gradio
32
- iface = gr.Interface(
33
- fn=chat_with_model,
34
- inputs=[gr.inputs.Textbox(label="Ваше сообщение"), gr.inputs.State()],
35
- outputs=[gr.outputs.Textbox(label="Ответ бота"), gr.outputs.State()],
36
- title="Чат-бот с сарказмом и абсурдом",
37
- description="Побеседуй со своим ботом, который использует сарказм и немного абсурда. Напиши ему что угодно!",
38
- )
 
 
 
39
 
40
- iface.launch()
 
1
+ import os
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import gradio as gr
5
 
6
+ # Модель и токен
7
+ model_name = "microsoft/DialoGPT-medium"
8
+ huggingface_token = os.getenv('HUGGINGFACE_TOKEN')
9
+
10
+ # Загрузка токенайзера и модели с использованием токена
11
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=huggingface_token)
12
+ model = AutoModelForCausalLM.from_pretrained(model_name, use_auth_token=huggingface_token)
13
 
14
  # Функция для общения с моделью
15
+ def chat_with_model(input_text, chat_history=[]):
16
+ new_user_input_ids = tokenizer.encode(input_text + tokenizer.eos_token, return_tensors="pt")
 
17
 
18
+ # Если есть история чата, объединяем её с новым вводом
19
+ if len(chat_history) > 0:
20
  bot_input_ids = torch.cat([torch.tensor(chat_history), new_user_input_ids], dim=-1)
21
  else:
22
  bot_input_ids = new_user_input_ids
23
+
24
  # Генерация ответа от модели
25
+ chat_history = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token)
26
+
27
+ # Получение текста и вывод ответа
28
+ response = tokenizer.decode(chat_history[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
29
+
30
+ return response, chat_history
 
 
31
 
32
  # Интерфейс Gradio
33
+ with gr.Blocks() as demo:
34
+ chatbot = gr.Chatbot()
35
+ msg = gr.Textbox()
36
+ state = gr.State([]) # Для сохранения истории чата
37
+
38
+ def respond(message, chat_history):
39
+ response, chat_history = chat_with_model(message, chat_history)
40
+ return chatbot.update([message, response]), chat_history
41
+
42
+ msg.submit(respond, [msg, state], [chatbot, state])
43
 
44
+ demo.launch()