heegyu commited on
Commit
8e1c181
·
1 Parent(s): 6ff1237

cuda handling

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import gradio as gr
 
2
  import random
3
  import time
4
  from transformers import pipeline
@@ -7,7 +8,7 @@ from transformers import pipeline
7
  generator = pipeline(
8
  'text-generation',
9
  model="heegyu/gorani-v0",
10
- device="cuda:0"
11
  )
12
 
13
  def query(message, chat_history, max_turn=4):
@@ -15,7 +16,7 @@ def query(message, chat_history, max_turn=4):
15
  if len(chat_history) > max_turn:
16
  chat_history = chat_history[-max_turn:]
17
  for user, bot in chat_history:
18
- prompt.append(f"<usr> {user}")
19
  prompt.append(f"<bot> {bot}")
20
  prompt.append(f"<usr> {message}")
21
  prompt = "\n".join(prompt) + "\n<bot>"
 
1
  import gradio as gr
2
+ import torch
3
  import random
4
  import time
5
  from transformers import pipeline
 
8
  generator = pipeline(
9
  'text-generation',
10
  model="heegyu/gorani-v0",
11
+ device="cuda:0" if torch.cuda.is_available() else 'cpu'
12
  )
13
 
14
  def query(message, chat_history, max_turn=4):
 
16
  if len(chat_history) > max_turn:
17
  chat_history = chat_history[-max_turn:]
18
  for user, bot in chat_history:
19
+ prompt.append(f"<usr> 대답해: {user}")
20
  prompt.append(f"<bot> {bot}")
21
  prompt.append(f"<usr> {message}")
22
  prompt = "\n".join(prompt) + "\n<bot>"