Lorenzob committed on
Commit 2119d2a · verified · 1 Parent(s): 216cb97

Upload folder using huggingface_hub

Files changed (2):
  1. README.md +4 -2
  2. handler.py +118 -0
README.md CHANGED
@@ -3,6 +3,8 @@
 
 This is a TRM model trained using the provided datasets.
 
- ## How to use
+ ## How to use for Inference
 
- [More detailed usage instructions can be added here]
+ You can use this model for inference via the Hugging Face Inference API or with the `transformers` library.
+
+ Make sure you have the `modelling_trm.py` file in the same directory as the model files if using the `transformers` library locally.
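
A minimal local-loading sketch of the `transformers` path described in the README text above. It mirrors the loading logic in `handler.py` from this commit: the `TRM` and `TRMConfig` class names come from `modelling_trm.py`, the local path is a placeholder for a clone of this repository, and the tokenizer is assumed to load from the same directory.

```python
# Sketch only: assumes a local clone of this repo in the current directory,
# with modelling_trm.py next to the model files (see the README note above).
import sys

import torch
from transformers import AutoTokenizer

model_path = "."  # placeholder: path to a local clone of this repository
sys.path.insert(0, model_path)  # make modelling_trm.py importable
from modelling_trm import TRM, TRMConfig

config = TRMConfig.from_pretrained(model_path)
model = TRM.from_pretrained(model_path, config=config)
model.eval()

tokenizer = AutoTokenizer.from_pretrained(model_path)  # may need adapting, as in handler.py
encoded = tokenizer("This is a test input", return_tensors="pt")
with torch.no_grad():
    outputs = model(input_ids=encoded["input_ids"])  # TRM takes only input_ids, per handler.py
```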
handler.py ADDED
@@ -0,0 +1,118 @@
+
+ import torch
+ import os
+ from transformers import AutoTokenizer
+ # Import your custom model class
+ import sys
+ # Add the local directory containing modelling_trm.py to the Python path
+ sys.path.insert(0, ".")  # Assuming the handler will be in the root of the repo
+ from modelling_trm import TRM, TRMConfig
+ sys.path.pop(0)  # Remove the path after import
+
+ class InferenceHandler:
+     def __init__(self):
+         self.model = None
+         self.tokenizer = None
+         self.device = None
+
+     def load(self, model_path="."):
+         # Load model and tokenizer
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         print(f"Loading model on device: {self.device}")
+
+         # Load the config
+         config = TRMConfig.from_pretrained(model_path)
+
+         # Load the model
+         self.model = TRM.from_pretrained(model_path, config=config)
+         self.model.to(self.device)
+         self.model.eval()  # Set model to evaluation mode
+
+         # Load the tokenizer (using a placeholder as the original had issues)
+         # You might need to adapt this based on your actual tokenizer
+         try:
+             self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+         except Exception:
+             # Fallback to a basic tokenizer if loading from path fails
+             self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
+             print("Loaded a placeholder tokenizer (bert-base-uncased) for inference.")
+
+
+     def preprocess(self, inputs):
+         # Preprocess inputs for the model
+         # 'inputs' will be the data received by the inference endpoint
+         # This needs to be adapted based on the expected input format (e.g., text string)
+         # For text generation, 'inputs' could be a string or a list of strings.
+         if isinstance(inputs, str):
+             inputs = [inputs]
+         elif not isinstance(inputs, list):
+             raise ValueError("Input must be a string or a list of strings.")
+
+         # Tokenize the input
+         # Ensure padding and truncation are handled
+         tokenized_inputs = self.tokenizer(inputs, return_tensors="pt", padding=True, truncation=True, max_length=self.model.config.seq_len)
+
+         # Move tokenized inputs to the model's device
+         tokenized_inputs = {k: v.to(self.device) for k, v in tokenized_inputs.items()}
+
+         # Return only the inputs expected by the TRM model
+         # Based on training, TRM seems to only take 'input_ids'
+         return {'input_ids': tokenized_inputs['input_ids']}
+
+
+     def inference(self, inputs):
+         # Perform inference with the model
+         # 'inputs' here is the output of the preprocess method
+         with torch.no_grad():
+             # Perform the forward pass
+             # Assuming the model only takes input_ids
+             outputs = self.model(**inputs)
+
+         # The model's output structure might differ, assuming it returns logits
+         # You might need to adapt this based on the actual TRM output for inference
+         # For text generation, you might use model.generate() instead of a simple forward pass
+         # This example performs a simple forward pass and returns logits
+         logits = outputs.logits if hasattr(outputs, 'logits') else outputs['logits']  # Adapt based on model output
+
+
+         return logits  # Or process logits further for text generation
+
+
+     def postprocess(self, outputs):
+         # Postprocess the model outputs
+         # 'outputs' here is the output of the inference method (e.g., logits)
+         # For text generation, you would typically decode the generated token IDs
+         # This is a placeholder postprocessing step (e.g., returning the raw logits as a list)
+
+         # Example: decode token IDs if using model.generate()
+         # generated_ids = outputs[0]  # Assuming outputs from generate() is a tensor
+         # generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
+         # return generated_text
+
+         # For this basic handler returning logits, just convert to CPU and list
+         return outputs.cpu().tolist()
+
+
+     def handle(self, data):
+         # Main inference handler function
+         # 'data' is the input received by the inference endpoint
+
+         # 1. Preprocess
+         model_input = self.preprocess(data)
+
+         # 2. Inference
+         model_output = self.inference(model_input)
+
+         # 3. Postprocess
+         response = self.postprocess(model_output)
+
+         return response
+
+ # Example usage (for testing locally)
+ # if __name__ == "__main__":
+ #     handler = InferenceHandler()
+ #     handler.load()
+ #     test_input = "This is a test input"
+ #     output = handler.handle(test_input)
+ #     print("Inference output:", output)
+
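
For local testing, the commented-out example at the bottom of `handler.py` can be expanded along the lines of the sketch below. The `(batch, seq_len, vocab)` logits shape and the argmax decoding are illustrative assumptions, not part of the committed code; adapt them to the actual TRM output.

```python
# Sketch only: expands the commented-out local test at the end of handler.py.
# Assumes a local clone of this repository in the current directory.
import torch

from handler import InferenceHandler

handler = InferenceHandler()
handler.load(model_path=".")  # placeholder: path to the local clone

raw = handler.handle("This is a test input")  # nested lists of logits, per postprocess()
logits = torch.tensor(raw)
predicted_ids = logits.argmax(dim=-1)  # assumes the last dimension is the vocabulary
print(handler.tokenizer.batch_decode(predicted_ids, skip_special_tokens=True))
```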