twigs commited on
Commit
d87595a
1 Parent(s): bef3028

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -18,8 +18,8 @@ cwi_tok = AutoTokenizer.from_pretrained('twigs/cwi-regressor')
18
  cwi_model = AutoModelForSequenceClassification.from_pretrained('twigs/cwi-regressor')
19
  simpl_tok = BartTokenizer.from_pretrained('twigs/bart-text2text-simplifier')
20
  simpl_model = BartForConditionalGeneration.from_pretrained('twigs/bart-text2text-simplifier')
21
- cwi_pipe = pipeline('text-classification', model=cwi_model, tokenizer=cwi_tok, function_to_apply='none', device=0)
22
- fill_pipe = pipeline('fill-mask', model=simpl_model, tokenizer=simpl_tok, top_k=1, device=0)
23
 
24
 
25
  def id_replace_complex(s, threshold=0.4):
@@ -44,7 +44,7 @@ def id_replace_complex(s, threshold=0.4):
44
 
45
  def generate_candidate_text(s, model, tokenizer, tokenized=False):
46
 
47
- out = simpl_tok([s], max_length=256, padding="max_length", truncation=True, return_tensors='pt').to('cuda') if not tokenized else s
48
 
49
  generated_ids = model.generate(
50
  input_ids=out['input_ids'],
 
18
  cwi_model = AutoModelForSequenceClassification.from_pretrained('twigs/cwi-regressor')
19
  simpl_tok = BartTokenizer.from_pretrained('twigs/bart-text2text-simplifier')
20
  simpl_model = BartForConditionalGeneration.from_pretrained('twigs/bart-text2text-simplifier')
21
+ cwi_pipe = pipeline('text-classification', model=cwi_model, tokenizer=cwi_tok, function_to_apply='none')
22
+ fill_pipe = pipeline('fill-mask', model=simpl_model, tokenizer=simpl_tok, top_k=1)
23
 
24
 
25
  def id_replace_complex(s, threshold=0.4):
 
44
 
45
  def generate_candidate_text(s, model, tokenizer, tokenized=False):
46
 
47
+ out = simpl_tok([s], max_length=256, padding="max_length", truncation=True, return_tensors='pt') if not tokenized else s
48
 
49
  generated_ids = model.generate(
50
  input_ids=out['input_ids'],