bczhou commited on
Commit
bde270a
1 Parent(s): d520d83

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -5,9 +5,11 @@ import torch
5
 
6
  os.environ['CURL_CA_BUNDLE'] = ''
7
 
 
 
8
  config = LinearMappingConfig()
9
  model = LinearMapping(config)
10
- model.load_state_dict(torch.load("pytorch_model.bin"))
11
  processor = LinearMappingProcessor(config)
12
  processor.tokenizer.padding_side = 'left'
13
  processor.tokenizer.pad_token_id = processor.tokenizer.eos_token_id
@@ -53,4 +55,4 @@ demo = gr.Interface(
53
  description=description
54
  )
55
 
56
- demo.launch(share=True)
 
5
 
6
  os.environ['CURL_CA_BUNDLE'] = ''
7
 
8
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
9
+
10
  config = LinearMappingConfig()
11
  model = LinearMapping(config)
12
+ model.load_state_dict(torch.load("pytorch_model.bin", map_location=device))
13
  processor = LinearMappingProcessor(config)
14
  processor.tokenizer.padding_side = 'left'
15
  processor.tokenizer.pad_token_id = processor.tokenizer.eos_token_id
 
55
  description=description
56
  )
57
 
58
+ demo.launch()