EveSa committed on
Commit 4cbd001
1 Parent(s): 8dba466

Initial commit

Files changed (3)
  1. requirements.txt +11 -0
  2. src/inference.py +6 -3
  3. src/train.py +1 -0
requirements.txt CHANGED
@@ -1,26 +1,37 @@
 anyio==3.6.2
+certifi==2022.12.7
+charset-normalizer==3.1.0
 click==8.1.3
 fastapi==0.92.0
+filelock==3.9.0
 h11==0.14.0
+huggingface-hub==0.13.1
 idna==3.4
 Jinja2==3.1.2
 joblib==1.2.0
 MarkupSafe==2.1.2
+nltk==3.8.1
 numpy==1.24.2
 nvidia-cublas-cu11==11.10.3.66
 nvidia-cuda-nvrtc-cu11==11.7.99
 nvidia-cuda-runtime-cu11==11.7.99
 nvidia-cudnn-cu11==8.5.0.96
+packaging==23.0
 pandas==1.5.3
 pydantic==1.10.5
 python-dateutil==2.8.2
 python-multipart==0.0.6
 pytz==2022.7.1
+PyYAML==6.0
 regex==2022.10.31
+requests==2.28.2
 six==1.16.0
 sniffio==1.3.0
 starlette==0.25.0
+tokenizers==0.13.2
 torch==1.13.1
 tqdm==4.65.0
+transformers==4.26.1
 typing_extensions==4.5.0
+urllib3==1.26.15
 uvicorn==0.20.0
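
The added pins pull in the Hugging Face stack (huggingface-hub, tokenizers, transformers) plus nltk and the requests/urllib3 transport chain. A minimal post-install sanity check, as a sketch that assumes only the packages pinned above:

# Sketch: confirm the newly pinned packages resolve to the expected versions.
import huggingface_hub
import nltk
import tokenizers
import transformers

for name, module in [
    ("huggingface-hub", huggingface_hub),
    ("nltk", nltk),
    ("tokenizers", tokenizers),
    ("transformers", transformers),
]:
    print(f"{name}=={module.__version__}")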
src/inference.py CHANGED
@@ -7,6 +7,7 @@ import torch
 
 import dataloader
 from model import Decoder, Encoder, EncoderDecoderModel
+from transformers import AutoModel
 
 with open("model/vocab.pkl", "rb") as vocab:
     words = pickle.load(vocab)
@@ -33,6 +34,8 @@ def inferenceAPI(text: str) -> str:
     decoder.to(device)
 
     # Instantiate the model
+    model = AutoModel.from_pretrained("EveSa/SummaryProject-LSTM", revision="main")
+    model = AutoModel.PretrainedConfig()
     model = EncoderDecoderModel(encoder, decoder, vectoriser, device)
 
     model.load_state_dict(torch.load("model/model.pt", map_location=device))
@@ -51,6 +54,6 @@ def inferenceAPI(text: str) -> str:
     return vectoriser.decode(output)
 
 
-# if __name__ == "__main__":
-# # inference()
-# print(inferenceAPI("If you choose to use these attributes in logged messages, you need to exercise some care. In the above example, for instance, the Formatter has been set up with a format string which expects ‘clientip’ and ‘user’ in the attribute dictionary of the LogRecord. If these are missing, the message will not be logged because a string formatting exception will occur. So in this case, you always need to pass the extra dictionary with these keys."))
+if __name__ == "__main__":
+    # inference()
+    print(inferenceAPI("If you choose to use these attributes in logged messages, you need to exercise some care. In the above example, for instance, the Formatter has been set up with a format string which expects ‘clientip’ and ‘user’ in the attribute dictionary of the LogRecord. If these are missing, the message will not be logged because a string formatting exception will occur. So in this case, you always need to pass the extra dictionary with these keys."))
src/train.py CHANGED
@@ -194,6 +194,7 @@ if __name__ == "__main__":
 
     torch.save(trained_classifier.state_dict(), "model/model.pt")
     vectoriser.save("model/vocab.pkl")
+    trained_classifier.config.to_json_file("config.json")
 
     print(f"test summary : {vectoriser.decode(dev_dataset[6][1])}")
     print(
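
to_json_file is a transformers.PretrainedConfig method, so the added line only works if trained_classifier.config is such an object; a plain nn.Module exposes no .config attribute. A minimal fallback sketch for that case (the attribute names in the dictionary are hypothetical placeholders):

# Sketch: write a config.json even when the classifier is a plain nn.Module.
# "hidden_size" and "vocab_size" are hypothetical attribute names used for illustration.
import json

if hasattr(trained_classifier, "config"):
    trained_classifier.config.to_json_file("config.json")
else:
    with open("config.json", "w") as f:
        json.dump(
            {
                "hidden_size": getattr(trained_classifier, "hidden_size", None),
                "vocab_size": getattr(trained_classifier, "vocab_size", None),
            },
            f,
            indent=2,
        )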