cheesexuebao committed on
Commit
afc3372
1 Parent(s): 74b913c

Load model from hub

Browse files
Prediction.py CHANGED
@@ -66,19 +66,6 @@ def predict_single(sentence, tokenizer, model, device, max_token_len=128):
66
  y_inten = [round(i, 8) for i in y_inten]
67
  return y_inten
68
 
69
- def model_factory(local_path, device):
70
- manager = {}
71
- for model_path in glob.glob(f"{local_path}/*"):
72
- base_name = os.path.basename(model_path)
73
- model_name = os.path.splitext(base_name)[0]
74
- tokenizer = BertTokenizer.from_pretrained(model_path)
75
- model = BertForSequenceClassification.from_pretrained(model_path)
76
- model = model.to(device)
77
- manager[model_name] = {
78
- "model": model,
79
- "tokenizer": tokenizer
80
- }
81
- return manager
82
 
83
 
84
  if __name__ == "__main__":
@@ -87,9 +74,10 @@ if __name__ == "__main__":
87
  Data = Data[:20]
88
  device = torch.device('cpu')
89
 
90
- manager = model_factory("./models", device)
91
- for model_name, dct in manager.items():
92
- model, tokenizer = dct['model'], dct['tokenizer']
93
- fk_doc_result = predict_csv(Data,"content", tokenizer, model, device)
94
- single_response = predict_single("Games of the imagination teach us actions have consequences in a realm that can be reset.", tokenizer, model, device)
95
- fk_doc_result.to_csv(f"output/prediction_{model_name}.csv")
 
 
66
  y_inten = [round(i, 8) for i in y_inten]
67
  return y_inten
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
 
71
  if __name__ == "__main__":
 
74
  Data = Data[:20]
75
  device = torch.device('cpu')
76
 
77
+ # Load model directly
78
+ tokenizer = BertTokenizer.from_pretrained("Oliver12315/Brand_Tone_of_Voice")
79
+ model = BertForSequenceClassification.from_pretrained("Oliver12315/Brand_Tone_of_Voice")
80
+ model = model.to(device)
81
+ fk_doc_result = predict_csv(Data,"content", tokenizer, model, device)
82
+ single_response = predict_single("Games of the imagination teach us actions have consequences in a realm that can be reset.", tokenizer, model, device)
83
+ fk_doc_result.to_csv(f"output/prediction_Brand_Tone_of_Voice.csv")
app.py CHANGED
@@ -20,13 +20,12 @@ else:
20
  ]
21
 
22
  device = torch.device('cpu')
23
- manager = model_factory("./models", device)
 
 
24
 
25
 
26
  def single_sentence(sentence):
27
- model_name = 'All_Data'
28
- dct = manager[model_name]
29
- model, tokenizer = dct['model'], dct['tokenizer']
30
  predictions = predict_single(sentence, tokenizer, model, device)
31
  predictions.sort(reverse=True)
32
  return list(zip(LABEL_COLUMNS, predictions))
@@ -38,18 +37,15 @@ def csv_process(csv_file, attr="content"):
38
  data = data.reset_index()
39
  os.makedirs('output', exist_ok=True)
40
  outputs = []
41
- model_name = 'All_Data'
42
- dct = manager[model_name]
43
- model, tokenizer = dct['model'], dct['tokenizer']
44
  predictions = predict_csv(data, attr, tokenizer, model, device)
45
- output_path = f"output/prediction_{model_name}_{formatted_time}.csv"
46
  predictions.to_csv(output_path)
47
  outputs.append(output_path)
48
  return outputs
49
 
50
 
51
  my_theme = gr.Theme.from_hub("JohnSmith9982/small_and_pretty")
52
- with gr.Blocks(theme=my_theme, title='Murphy') as demo:
53
  gr.HTML(
54
  """
55
  <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
@@ -60,7 +56,7 @@ with gr.Blocks(theme=my_theme, title='Murphy') as demo:
60
  <h5 style="margin: 0;">If you like our project, please give us a star ✨ on Github for the latest update.</h5>
61
  <div style="display: flex; justify-content: center; align-items: center; text-align: center;>
62
  <a href="https://arxiv.org/abs/xx.xx"><img src="https://img.shields.io/badge/Arxiv-xx.xx-red"></a>
63
- <a href='https://huggingface.co/spaces/cheesexuebao/murphy'><img src='https://img.shields.io/badge/Project_Page-Murphy/xxBert' alt='Project Page'></a>
64
  <a href='https://github.com'><img src='https://img.shields.io/badge/Github-Code-blue'></a>
65
  </div>
66
  </div>
 
20
  ]
21
 
22
  device = torch.device('cpu')
23
+ tokenizer = BertTokenizer.from_pretrained("Oliver12315/Brand_Tone_of_Voice")
24
+ model = BertForSequenceClassification.from_pretrained("Oliver12315/Brand_Tone_of_Voice")
25
+ model = model.to(device)
26
 
27
 
28
  def single_sentence(sentence):
 
 
 
29
  predictions = predict_single(sentence, tokenizer, model, device)
30
  predictions.sort(reverse=True)
31
  return list(zip(LABEL_COLUMNS, predictions))
 
37
  data = data.reset_index()
38
  os.makedirs('output', exist_ok=True)
39
  outputs = []
 
 
 
40
  predictions = predict_csv(data, attr, tokenizer, model, device)
41
+ output_path = f"output/prediction_Brand_Tone_of_Voice_{formatted_time}.csv"
42
  predictions.to_csv(output_path)
43
  outputs.append(output_path)
44
  return outputs
45
 
46
 
47
  my_theme = gr.Theme.from_hub("JohnSmith9982/small_and_pretty")
48
+ with gr.Blocks(theme=my_theme, title='Brand_Tone_of_Voice_demo') as demo:
49
  gr.HTML(
50
  """
51
  <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
 
56
  <h5 style="margin: 0;">If you like our project, please give us a star ✨ on Github for the latest update.</h5>
57
  <div style="display: flex; justify-content: center; align-items: center; text-align: center;>
58
  <a href="https://arxiv.org/abs/xx.xx"><img src="https://img.shields.io/badge/Arxiv-xx.xx-red"></a>
59
+ <a href='https://huggingface.co/spaces/Oliver12315/Brand_Tone_of_Voice_demo'><img src='https://img.shields.io/badge/Project_Page-Oliver12315/Brand_Tone_of_Voice_demo' alt='Project Page'></a>
60
  <a href='https://github.com'><img src='https://img.shields.io/badge/Github-Code-blue'></a>
61
  </div>
62
  </div>
models/All_Data/config.json DELETED
@@ -1,39 +0,0 @@
1
- {
2
- "_name_or_path": "bert-base-uncased",
3
- "architectures": [
4
- "BertForSequenceClassification"
5
- ],
6
- "attention_probs_dropout_prob": 0.1,
7
- "classifier_dropout": null,
8
- "gradient_checkpointing": false,
9
- "hidden_act": "gelu",
10
- "hidden_dropout_prob": 0.1,
11
- "hidden_size": 768,
12
- "id2label": {
13
- "0": "Assertive Tone",
14
- "1": "Conversational Tone",
15
- "2": "Emotional Tone",
16
- "3": "Informative Tone",
17
- "4": "None"
18
- },
19
- "initializer_range": 0.02,
20
- "intermediate_size": 3072,
21
- "label2id": {
22
- "Assertive Tone": 0,
23
- "Conversational Tone": 1,
24
- "Emotional Tone": 2,
25
- "Informative Tone": 3,
26
- "None": 4
27
- },
28
- "layer_norm_eps": 1e-12,
29
- "max_position_embeddings": 512,
30
- "model_type": "bert",
31
- "num_attention_heads": 12,
32
- "num_hidden_layers": 12,
33
- "pad_token_id": 0,
34
- "position_embedding_type": "absolute",
35
- "transformers_version": "4.36.2",
36
- "type_vocab_size": 2,
37
- "use_cache": true,
38
- "vocab_size": 30522
39
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/All_Data/pytorch_model.bin DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:593dc3210abcc95df5a0f63580ce571df2b60c39cc4f1d7122e371c9f37c4c64
3
- size 438024366
 
 
 
 
models/All_Data/vocab.txt DELETED
The diff for this file is too large to render. See raw diff
 
tmp.py DELETED
@@ -1,5 +0,0 @@
1
- import pandas as pd
2
-
3
- pd.read_csv('output/example.csv')
4
- pd.inde
5
- ...