Thouph commited on
Commit
8e922f8
1 Parent(s): a4e7805

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -4
app.py CHANGED
@@ -7,7 +7,7 @@ import gradio as gr
7
  from datetime import datetime
8
  torch.set_grad_enabled(False)
9
 
10
- model = Qwen2ForSequenceClassification.from_pretrained("Thouph/prompt2tag-qwen2-0.5b-v0.1", num_labels = 9940).to("cuda")
11
  model.eval()
12
  tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
13
 
@@ -25,9 +25,6 @@ def create_tags(prompt, threshold):
25
  return_tensors="pt",
26
  )
27
 
28
- for k in inputs.keys():
29
- inputs[k] = inputs[k].to("cuda")
30
- # Generate
31
  output = model(**inputs).logits
32
  output = torch.nn.functional.sigmoid(output)
33
  indices = torch.where(output > threshold)
 
7
  from datetime import datetime
8
  torch.set_grad_enabled(False)
9
 
10
+ model = Qwen2ForSequenceClassification.from_pretrained("Thouph/prompt2tag-qwen2-0.5b-v0.1", num_labels = 9940, map_location=torch.device('cpu'))
11
  model.eval()
12
  tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
13
 
 
25
  return_tensors="pt",
26
  )
27
 
 
 
 
28
  output = model(**inputs).logits
29
  output = torch.nn.functional.sigmoid(output)
30
  indices = torch.where(output > threshold)