EveSa commited on
Commit
ca615e0
·
1 Parent(s): 3805a61

inference fonctionnelle sans load de fichier

Browse files
Files changed (2) hide show
  1. .gitattributes +1 -0
  2. src/inference.py +7 -13
.gitattributes CHANGED
@@ -1 +1,2 @@
1
  *.pt filter=lfs diff=lfs merge=lfs -text
 
 
1
  *.pt filter=lfs diff=lfs merge=lfs -text
2
+ *.pkl filter=lfs diff=lfs merge=lfs -text
src/inference.py CHANGED
@@ -5,19 +5,12 @@ import pickle
5
 
6
  import torch
7
 
8
- import dataloader
9
- from model import Decoder, Encoder, EncoderDecoderModel
10
 
11
- # On doit loader les données pour avoir le Vectoriser > sauvegarder "words" dans un fichiers et le loader par la suite ??
12
- ### À CHANGER POUR N'AVOIR À LOADER QUE LE VECTORISER
13
- data1 = dataloader.Data("data/train_extract.jsonl")
14
- data2 = dataloader.Data("data/dev_extract.jsonl")
15
- words = pickle.load("model/vocab.pkl")
16
- words = data1.get_words()
17
-
18
- vectoriser = dataloader.Vectoriser()
19
- vectoriser.load("model/vocab.pkl")
20
- word_counts = vectoriser.word_count
21
 
22
 
23
  def inferenceAPI(text: str) -> str:
@@ -42,7 +35,7 @@ def inferenceAPI(text: str) -> str:
42
  )
43
 
44
  # On instancie le modèle
45
- model = EncoderDecoderModel(encoder, decoder, device)
46
 
47
  model.load_state_dict(torch.load("model/model.pt", map_location=device))
48
  model.eval()
@@ -56,6 +49,7 @@ def inferenceAPI(text: str) -> str:
56
  with torch.no_grad():
57
  output = model(source).to(device)
58
  output.to(device)
 
59
  return vectoriser.decode(output)
60
 
61
 
 
5
 
6
  import torch
7
 
8
+ from src import dataloader
9
+ from src.model import Decoder, Encoder, EncoderDecoderModel
10
 
11
+ with open ("model/vocab.pkl", 'rb') as vocab:
12
+ words = pickle.load(vocab)
13
+ vectoriser = dataloader.Vectoriser(words)
 
 
 
 
 
 
 
14
 
15
 
16
  def inferenceAPI(text: str) -> str:
 
35
  )
36
 
37
  # On instancie le modèle
38
+ model = EncoderDecoderModel(encoder, decoder, vectoriser, device)
39
 
40
  model.load_state_dict(torch.load("model/model.pt", map_location=device))
41
  model.eval()
 
49
  with torch.no_grad():
50
  output = model(source).to(device)
51
  output.to(device)
52
+ output=output.argmax(dim=-1)
53
  return vectoriser.decode(output)
54
 
55