Frorozcol commited on
Commit
56e1ed7
1 Parent(s): 9208454

Plus: Addes the .csv from huawei

Browse files
src/__pycache__/predict.cpython-310.pyc CHANGED
Binary files a/src/__pycache__/predict.cpython-310.pyc and b/src/__pycache__/predict.cpython-310.pyc differ
 
src/predict.py CHANGED
@@ -1,5 +1,6 @@
1
  from pathlib import Path
2
  import torch
 
3
 
4
  from .tokenizer import load_tokenizer, preprocessing_text
5
  from .model import load_model
@@ -17,6 +18,24 @@ tokenizer = load_tokenizer(model_name)
17
  model = load_model(checkpoint_path, model_name, num_labels, divice)
18
 
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  def get_predict(text):
21
  inputs = preprocessing_text(text, tokenizer)
22
  input_ids = inputs["input_ids"].to(divice)
@@ -24,4 +43,5 @@ def get_predict(text):
24
  token_type_ids = inputs["token_type_ids"].to(divice)
25
  outputs = model(input_ids, attention_mask, token_type_ids)
26
  preds = torch.sigmoid(outputs).detach().cpu().numpy()
 
27
  return preds
 
1
  from pathlib import Path
2
  import torch
3
+ import numpy as np
4
 
5
  from .tokenizer import load_tokenizer, preprocessing_text
6
  from .model import load_model
 
18
  model = load_model(checkpoint_path, model_name, num_labels, divice)
19
 
20
 
21
+ RETURN_VALUES =[
22
+ "target_sentiment_negative",
23
+ "target_sentiment_neutral",
24
+ "target_sentiment_positive",
25
+ "companies_sentiment_negative",
26
+ "companies_sentiment_neutral",
27
+ "companies_sentiment_positive",
28
+ "consumers_sentiment_negative",
29
+ "consumers_sentiment_neutral",
30
+ "consumers_sentiment_positive"
31
+ ]
32
+
33
+ def filter(preds, threshold=0.5):
34
+ bool = preds > threshold
35
+ indices = np.where(bool)[0]
36
+ filtered_values = {RETURN_VALUES[index]: preds[index] for index in indices}
37
+ return filtered_values
38
+
39
  def get_predict(text):
40
  inputs = preprocessing_text(text, tokenizer)
41
  input_ids = inputs["input_ids"].to(divice)
 
43
  token_type_ids = inputs["token_type_ids"].to(divice)
44
  outputs = model(input_ids, attention_mask, token_type_ids)
45
  preds = torch.sigmoid(outputs).detach().cpu().numpy()
46
+ preds = filter(preds[0])
47
  return preds