pushkarraj commited on
Commit
dffc8e2
1 Parent(s): c8bd88f

added device

Browse files
Files changed (1) hide show
  1. app.py +3 -1
app.py CHANGED
@@ -2,12 +2,14 @@ import gradio as gr
2
  import pandas as pd
3
  import os
4
  import time
 
5
  from transformers import pipeline, GPT2Tokenizer, OPTForCausalLM
 
6
 
7
  model=OPTForCausalLM.from_pretrained('pushkarraj/pushkar_OPT_paraphaser')
8
  tokenizer=GPT2Tokenizer.from_pretrained('pushkarraj/pushkar_OPT_paraphaser',truncation=True)
9
 
10
- generator=pipeline("text-generation",model=model,tokenizer=tokenizer,device=0)
11
 
12
  def cleaned_para(input_sentence):
13
  p=generator('<s>'+input_sentence+ '</s>>>>><p>',do_sample=True,max_length=len(input_sentence.split(" "))+200,temperature = 0.8,repetition_penalty=1.2,top_p=0.4,top_k=1)
 
2
  import pandas as pd
3
  import os
4
  import time
5
+ import torch
6
  from transformers import pipeline, GPT2Tokenizer, OPTForCausalLM
7
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
 
9
  model=OPTForCausalLM.from_pretrained('pushkarraj/pushkar_OPT_paraphaser')
10
  tokenizer=GPT2Tokenizer.from_pretrained('pushkarraj/pushkar_OPT_paraphaser',truncation=True)
11
 
12
+ generator=pipeline("text-generation",model=model,tokenizer=tokenizer,device=device)
13
 
14
  def cleaned_para(input_sentence):
15
  p=generator('<s>'+input_sentence+ '</s>>>>><p>',do_sample=True,max_length=len(input_sentence.split(" "))+200,temperature = 0.8,repetition_penalty=1.2,top_p=0.4,top_k=1)