alfiannajih committed
Commit 74c1449
1 Parent(s): a2c9add
Delete g_retriever
Files changed:
- g_retriever/.gitkeep +0 -0
- g_retriever/__init__.py +0 -0
- g_retriever/g_retriever_config.py +0 -31
- g_retriever/g_retriever_model.py +0 -215
- g_retriever/g_retriever_pipeline.py +0 -51
g_retriever/.gitkeep
DELETED
File without changes
g_retriever/__init__.py
DELETED
File without changes
g_retriever/g_retriever_config.py
DELETED
@@ -1,31 +0,0 @@
-from transformers import LlamaConfig
-
-class GRetrieverConfig(LlamaConfig):
-    model_type = "llama"
-
-    def __init__(
-        self,
-        max_txt_len: int = 1024,
-        max_new_tokens: int = 256,
-        gnn_num_layers: int = 4,
-        gnn_in_dim: int = 768,
-        gnn_hidden_dim: int = 1024,
-        gnn_num_heads: int = 4,
-        gnn_dropout: int = 0,
-        bos_id: list = [128000, 128006, 882, 128007],
-        **kwargs
-    ):
-        pretrained_config = LlamaConfig.from_pretrained("NousResearch/Hermes-3-Llama-3.1-8B")
-        pretrained_config.update(kwargs)
-
-        self.max_txt_len = max_txt_len
-        self.max_new_tokens = max_new_tokens
-        self.gnn_num_layers = gnn_num_layers
-        self.gnn_in_dim = gnn_in_dim
-        self.gnn_hidden_dim = gnn_hidden_dim
-        self.gnn_num_heads = gnn_num_heads
-        self.gnn_dropout = gnn_dropout
-        self.bos_id = bos_id
-
-        super().__init__(**pretrained_config.to_dict())
-        self.pad_token_id = pretrained_config.eos_token_id
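For reference, a minimal usage sketch of the config that was removed above (not part of the commit itself). It assumes the module were still importable and that the base config NousResearch/Hermes-3-Llama-3.1-8B can still be fetched from the Hub; the keyword values shown are illustrative overrides, not required settings.

from g_retriever.g_retriever_config import GRetrieverConfig

# Hypothetical sketch: instantiating the deleted config to inspect its fields.
# Requires network access to pull the Hermes-3-Llama-3.1-8B base config.
config = GRetrieverConfig(
    max_txt_len=512,       # cap on the tokenized textualized-graph length
    gnn_num_layers=4,      # depth of the GAT graph encoder
    gnn_hidden_dim=1024,   # GAT output width, also the projector's input size
)

print(config.model_type)    # "llama": the class re-uses LlamaConfig's registry key
print(config.gnn_in_dim)    # 768, the node/edge feature size expected by the GAT
print(config.pad_token_id)  # mirrors the base model's eos_token_id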
g_retriever/g_retriever_model.py
DELETED
@@ -1,215 +0,0 @@
-import torch
-from torch import nn
-import torch.nn.functional as F
-
-from transformers import LlamaForCausalLM
-from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.cache_utils import StaticCache
-from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask_with_cache_position
-from .g_retriever_config import GRetrieverConfig
-from .gnn import GAT
-
-from functools import wraps
-from torch_geometric.nn.pool import global_mean_pool
-
-class GRetrieverModel(LlamaForCausalLM):
-    config_class = GRetrieverConfig
-
-    def __init__(self, config):
-        super().__init__(config)
-        self.graph_encoder = GAT(
-            in_channels=config.gnn_in_dim,
-            out_channels=config.gnn_hidden_dim,
-            hidden_channels=config.gnn_hidden_dim,
-            num_layers=config.gnn_num_layers,
-            dropout=config.gnn_dropout,
-            num_heads=config.gnn_num_heads,
-        ).to(self.model.dtype)
-
-        self.projector = nn.Sequential(
-            nn.Linear(config.gnn_hidden_dim, 2048),
-            nn.Sigmoid(),
-            nn.Linear(2048, self.get_input_embeddings().embedding_dim),
-        ).to(self.model.dtype)
-
-    def encode_graphs(self, graph):
-        n_embeds, _ = self.graph_encoder(
-            graph.x.to(self.model.dtype),
-            graph.edge_index.long(),
-            graph.edge_attr.to(self.model.dtype)
-        )
-
-        # mean pooling
-        g_embeds = global_mean_pool(n_embeds, graph.batch.to(n_embeds.device))
-
-        return g_embeds
-
-    @wraps(LlamaForCausalLM.forward)
-    def forward(
-        self,
-        input_ids=None,
-        graph=None,
-        attention_mask=None,
-        position_ids=None,
-        past_key_values=None,
-        inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        cache_position=None
-    ):
-        inputs = input_ids.clone()
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        if (inputs == -1).any():
-            # embed bos prompt
-            bos_embeds = self.get_input_embeddings()(torch.tensor(
-                self.config.bos_id,
-                device=self.model.device
-            ))
-
-            # encode graph
-            graph_embeds = self.encode_graphs(graph)
-            graph_embeds = self.projector(graph_embeds).to(self.model.device)
-
-            # prepare for reserved ids (bos+graph)
-            non_tokenized_ids = (inputs == -1).nonzero()
-            non_tokenized_shape = non_tokenized_ids[:, 0], non_tokenized_ids[:, 1]
-
-            # embed inputs
-            inputs[non_tokenized_shape] = self.config.pad_token_id
-            temp_inputs_embeds = self.get_input_embeddings()(inputs)
-            non_tokenized_embeds = torch.cat([bos_embeds.repeat(len(inputs), 1, 1), graph_embeds.unsqueeze(1)], dim=1)
-
-            # replace reserved ids with bos+graph
-            inputs_embeds = temp_inputs_embeds.clone()
-            inputs_embeds[non_tokenized_shape] = non_tokenized_embeds.view(len(non_tokenized_ids), -1)
-
-        else:
-            inputs_embeds = self.get_input_embeddings()(inputs)
-
-        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-        outputs = self.model(
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            cache_position=cache_position,
-        )
-
-        hidden_states = outputs[0]
-
-        if self.config.pretraining_tp > 1:
-            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
-            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
-            logits = torch.cat(logits, dim=-1)
-        else:
-            logits = self.lm_head(hidden_states)
-        logits = logits.float()
-
-        loss = None
-        if labels is not None:
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = nn.CrossEntropyLoss()
-            shift_logits = shift_logits.view(-1, self.config.vocab_size)
-            shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
-            shift_labels = shift_labels.to(shift_logits.device)
-            loss = loss_fct(shift_logits, shift_labels)
-
-        if not return_dict:
-            output = (logits,) + outputs[1:]
-            return (loss,) + output if loss is not None else output
-
-        return CausalLMOutputWithPast(
-            loss=loss,
-            logits=logits,
-            past_key_values=outputs.past_key_values,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-        )
-
-    def prepare_inputs_for_generation(
-        self,
-        input_ids,
-        graph=None,
-        past_key_values=None,
-        attention_mask=None,
-        inputs_embeds=None,
-        cache_position=None,
-        position_ids=None,
-        use_cache=True,
-        **kwargs,
-    ):
-        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
-        # Exception 1: when passing input_embeds, input_ids may be missing entries
-        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
-        if past_key_values is not None:
-            if inputs_embeds is not None:  # Exception 1
-                input_ids = input_ids[:, -cache_position.shape[0] :]
-            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
-                input_ids = input_ids[:, cache_position]
-
-        if attention_mask is not None and position_ids is None:
-            # create position_ids on the fly for batch generation
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1] :]
-
-                # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
-                position_ids = position_ids.clone(memory_format=torch.contiguous_format)
-
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
-        else:
-            # The clone here is for the same reason as for `position_ids`.
-            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
-
-        if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-            if model_inputs["inputs_embeds"] is not None:
-                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
-                device = model_inputs["inputs_embeds"].device
-            else:
-                batch_size, sequence_length = model_inputs["input_ids"].shape
-                device = model_inputs["input_ids"].device
-
-            dtype = self.lm_head.weight.dtype
-            min_dtype = torch.finfo(dtype).min
-
-            attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
-                attention_mask,
-                sequence_length=sequence_length,
-                target_length=past_key_values.get_max_length(),
-                dtype=dtype,
-                device=device,
-                min_dtype=min_dtype,
-                cache_position=cache_position,
-                batch_size=batch_size,
-            )
-
-        model_inputs.update(
-            {
-                "graph": graph,
-                "position_ids": position_ids,
-                "cache_position": cache_position,
-                "past_key_values": past_key_values,
-                "use_cache": use_cache,
-                "attention_mask": attention_mask,
-            }
-        )
-        return model_inputs
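To show how the removed model was meant to be driven, here is a hedged end-to-end sketch (again, not part of the commit). It assumes the g_retriever package were still installed together with a GAT implementation at g_retriever.gnn, and that there is enough memory to materialize the 8B Llama base weights; the graph features and token ids below are random placeholders for illustration only.

import torch
from torch_geometric.data import Batch, Data

from g_retriever.g_retriever_config import GRetrieverConfig
from g_retriever.g_retriever_model import GRetrieverModel

# Hypothetical sketch: randomly initialized weights; a real run would use
# GRetrieverModel.from_pretrained(...) on a trained checkpoint instead.
config = GRetrieverConfig()
model = GRetrieverModel(config)

# Toy graph: 3 nodes and 2 edges with feature size config.gnn_in_dim (768).
graph = Batch.from_data_list([Data(
    x=torch.randn(3, config.gnn_in_dim),
    edge_index=torch.tensor([[0, 1], [1, 2]]),
    edge_attr=torch.randn(2, config.gnn_in_dim),
)])

# The first len(bos_id) + 1 positions are reserved with -1; forward() swaps them
# for the embedded BOS prompt plus the pooled, projected graph embedding.
reserved = [-1] * (len(config.bos_id) + 1)
question_ids = [791, 2768, 374]  # arbitrary placeholder token ids
input_ids = torch.tensor([reserved + question_ids])

output_ids = model.generate(input_ids=input_ids, graph=graph, max_new_tokens=8)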
g_retriever/g_retriever_pipeline.py
DELETED
@@ -1,51 +0,0 @@
-from transformers import Pipeline, AutoTokenizer
-from torch_geometric.data import Batch
-import torch
-
-class GRetrieverPipeline(Pipeline):
-    def __init__(self, **kwargs):
-        Pipeline.__init__(self, **kwargs)
-
-        self.tokenizer = AutoTokenizer.from_pretrained(self.model.config._name_or_path)
-        self.eos_user = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
-        self.max_txt_len = self.model.config.max_txt_len
-        self.bos_length = len(self.model.config.bos_id)
-
-    def _sanitize_parameters(self, **kwargs):
-        preprocess_kwargs = {}
-        if "textualized_graph" in kwargs:
-            preprocess_kwargs["textualized_graph"] = kwargs["textualized_graph"]
-
-        if "graph" in kwargs:
-            preprocess_kwargs["graph"] = kwargs["graph"]
-
-        return preprocess_kwargs, {}, {}
-
-    def preprocess(self, inputs, textualized_graph, graph):
-        textualized_graph_ids = self.tokenizer(textualized_graph, add_special_tokens=False)["input_ids"][:self.max_txt_len]
-        question_ids = self.tokenizer(inputs, add_special_tokens=False)["input_ids"]
-        eos_user_ids = self.tokenizer(self.eos_user, add_special_tokens=False)["input_ids"]
-
-        input_ids = torch.tensor([
-            [-1]*(self.bos_length + 1)
-            + textualized_graph_ids
-            + question_ids
-            + eos_user_ids
-        ])
-        model_inputs = {
-            "input_ids": input_ids,
-            "attention_mask": torch.ones_like(input_ids)
-        }
-        model_inputs.update({
-            "graph": Batch.from_data_list([graph])
-        })
-
-        return model_inputs
-
-    def _forward(self, model_inputs):
-        model_outputs = self.model.generate(**model_inputs)
-
-        return model_outputs
-
-    def postprocess(self, model_outputs):
-        return model_outputs
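Finally, a matching sketch of how the removed pipeline tied everything together (not part of the commit). MODEL_ID is a hypothetical stand-in for a trained checkpoint whose config._name_or_path resolves to a tokenizer with the Llama 3 chat special tokens; preprocess() prepends the reserved -1 slots, appends the assistant header, and batches the graph before generate() is called.

import torch
from torch_geometric.data import Data

from g_retriever.g_retriever_model import GRetrieverModel
from g_retriever.g_retriever_pipeline import GRetrieverPipeline

MODEL_ID = "path/to/g-retriever-checkpoint"  # hypothetical checkpoint location
model = GRetrieverModel.from_pretrained(MODEL_ID)
pipe = GRetrieverPipeline(model=model)

# A tiny graph plus its textual description; features sized to config.gnn_in_dim.
graph = Data(
    x=torch.randn(3, model.config.gnn_in_dim),
    edge_index=torch.tensor([[0, 1], [1, 2]]),
    edge_attr=torch.randn(2, model.config.gnn_in_dim),
)

output_ids = pipe(
    "Which node is reachable from node 0 in two hops?",
    textualized_graph="node 0 -> node 1; node 1 -> node 2",
    graph=graph,
)
# postprocess() returns generate()'s raw output, so decode with pipe.tokenizer if needed.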