mrm8488 commited on
Commit
45d922e
1 Parent(s): 89fe8bb

Add device support

Browse files
Files changed (2) hide show
  1. app.py +4 -1
  2. requirements.txt +1 -1
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
 
3
  from ui import title, description
4
  from langs import LANGS
5
 
@@ -9,6 +10,7 @@ CKPT = "facebook/nllb-200-distilled-600M"
9
  model = AutoModelForSeq2SeqLM.from_pretrained(CKPT)
10
  tokenizer = AutoTokenizer.from_pretrained(CKPT)
11
 
 
12
 
13
  def translate(text, src_lang, tgt_lang, max_length=400):
14
  """
@@ -19,7 +21,8 @@ def translate(text, src_lang, tgt_lang, max_length=400):
19
  tokenizer=tokenizer,
20
  src_lang=src_lang,
21
  tgt_lang=tgt_lang,
22
- max_length=max_length)
 
23
 
24
  result = translation_pipeline(text)
25
  return result[0]['translation_text']
 
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
3
+ import torch
4
  from ui import title, description
5
  from langs import LANGS
6
 
 
10
  model = AutoModelForSeq2SeqLM.from_pretrained(CKPT)
11
  tokenizer = AutoTokenizer.from_pretrained(CKPT)
12
 
13
+ device = 0 if torch.cuda.is_available() else -1
14
 
15
  def translate(text, src_lang, tgt_lang, max_length=400):
16
  """
 
21
  tokenizer=tokenizer,
22
  src_lang=src_lang,
23
  tgt_lang=tgt_lang,
24
+ max_length=max_length,
25
+ device=device)
26
 
27
  result = translation_pipeline(text)
28
  return result[0]['translation_text']
requirements.txt CHANGED
@@ -1,3 +1,3 @@
1
- transformers
2
  gradio
3
  torch
 
1
+ git+https://github.com/huggingface/transformers
2
  gradio
3
  torch