import copy
import torch
import math
import torch.nn as nn
from torch.nn.parameter import Parameter
import random
import numpy as np
from load_weights import load_weight
from sklearn.model_selection import train_test_split
from transformers import GPT2TokenizerFast
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from transformers import AdamW, get_linear_schedule_with_warmup

torch.manual_seed(42)

import nltk
nltk.download('punkt')

from transformers import GPT2Tokenizer
from torch.utils.data import Dataset, DataLoader, random_split, RandomSampler, SequentialSampler
import datetime
import time
import os

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

from tqdm import trange
import gradio as gr
import re


def gelu(x):
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))


class Conv1D(nn.Module):
    def __init__(self, nf, nx):
        super(Conv1D, self).__init__()
        self.nf = nf
        w = torch.empty(nx, nf)
        nn.init.normal_(w, std=0.02)
        self.weight = Parameter(w)
        self.bias = Parameter(torch.zeros(nf))

    def forward(self, x):
        size_out = x.size()[:-1] + (self.nf,)
        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
        x = x.view(*size_out)
        return x


class LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-12):
        """Construct a layernorm module in the TF style (epsilon inside the square root)."""
        super(LayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x + self.bias


class Attention(nn.Module):
    def __init__(self, nx, n_ctx, config, scale=False):
        super(Attention, self).__init__()
        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
        assert n_state % config.n_head == 0
        self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale
        self.c_attn = Conv1D(n_state * 3, nx)
        self.c_proj = Conv1D(n_state, nx)

    def _attn(self, q, k, v):
        w = torch.matmul(q, k)
        if self.scale:
            w = w / math.sqrt(v.size(-1))
        nd, ns = w.size(-2), w.size(-1)
        b = self.bias[:, :, ns - nd:ns, :ns]
        w = w * b - 1e10 * (1 - b)
        w = nn.Softmax(dim=-1)(w)
        return torch.matmul(w, v)

    def merge_heads(self, x):
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states

    def split_heads(self, x, k=False):
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
        if k:
            return x.permute(0, 2, 3, 1)  # (batch, head, head_features, seq_length)
        else:
            return x.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def forward(self, x, layer_past=None):
        x = self.c_attn(x)
        query, key, value = x.split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)
        if layer_past is not None:
            past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1]  # transpose back cf below
            key = torch.cat((past_key, key), dim=-1)
            value = torch.cat((past_value, value), dim=-2)
        present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking
        a = self._attn(query, key, value)
        a = self.merge_heads(a)
        a = self.c_proj(a)
        return a, present
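
# --- Added illustrative sketch (not part of the original script) ---
# Conv1D takes its output size first and its input size second (nf, nx), mirroring
# the TF-style 1D convolution GPT-2 uses, so the fused QKV projection in Attention
# is Conv1D(3 * n_embd, n_embd). The optional helper below is a toy shape check
# with made-up sizes; it is never called by the script.
def _demo_conv1d_shapes():
    qkv = Conv1D(3 * 8, 8)            # nf=24 outputs, nx=8 inputs (toy sizes)
    out = qkv(torch.zeros(2, 4, 8))   # (batch, seq, nx) -> (batch, seq, nf)
    assert out.shape == (2, 4, 24)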
class MLP(nn.Module):
    def __init__(self, n_state, config):  # in MLP: n_state=3072 (4 * n_embd)
        super(MLP, self).__init__()
        nx = config.n_embd
        self.c_fc = Conv1D(n_state, nx)
        self.c_proj = Conv1D(nx, n_state)
        self.act = gelu

    def forward(self, x):
        h = self.act(self.c_fc(x))
        h2 = self.c_proj(h)
        return h2


class Block(nn.Module):
    def __init__(self, n_ctx, config, scale=False):
        super(Block, self).__init__()
        nx = config.n_embd
        self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.attn = Attention(nx, n_ctx, config, scale)
        self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.mlp = MLP(4 * nx, config)

    def forward(self, x, layer_past=None):
        a, present = self.attn(self.ln_1(x), layer_past=layer_past)
        x = x + a
        m = self.mlp(self.ln_2(x))
        x = x + m
        return x, present


class GPT2Model(nn.Module):
    def __init__(self, config):
        super(GPT2Model, self).__init__()
        self.n_layer = config.n_layer
        self.n_embd = config.n_embd
        self.n_vocab = config.vocab_size
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        block = Block(config.n_ctx, config, scale=True)
        self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
        self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

    def set_embeddings_weights(self, model_embeddings_weights):
        embed_shape = model_embeddings_weights.shape
        self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
        self.decoder.weight = model_embeddings_weights  # Tied weights

    def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None):
        if (input_ids >= self.n_vocab).any():
            raise ValueError(f"Invalid token ID found in input_ids: {input_ids}")
        # print(f"input_ids: {input_ids}")  # Debugging statement
        # print(f"Max input_id: {input_ids.max().item()}")  # Debugging statement
        # print(f"Min input_id: {input_ids.min().item()}")  # Debugging statement
        if past is None:
            past_length = 0
            past = [None] * len(self.h)
        else:
            past_length = past[0][0].size(-2)
        if position_ids is None:
            position_ids = torch.arange(past_length, input_ids.size(-1) + past_length,
                                        dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_ids.size(-1))
        position_ids = position_ids.view(-1, position_ids.size(-1))

        inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        # print(f"inputs_embeds shape: {inputs_embeds.shape}")
        # print(f"position_embeds shape: {position_embeds.shape}")
        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
            token_type_embeds = self.wte(token_type_ids)
        else:
            token_type_embeds = 0
        hidden_states = inputs_embeds + position_embeds + token_type_embeds
        presents = []
        for block, layer_past in zip(self.h, past):
            hidden_states, present = block(hidden_states, layer_past)
            presents.append(present)
        hidden_states = self.ln_f(hidden_states)
        output_shape = input_shape + (hidden_states.size(-1),)
        return hidden_states.view(*output_shape), presents


class GPT2LMHead(nn.Module):
    def __init__(self, model_embeddings_weights, config):
        super(GPT2LMHead, self).__init__()
        self.n_embd = config.n_embd
        self.set_embeddings_weights(model_embeddings_weights)

    def set_embeddings_weights(self, model_embeddings_weights):
        embed_shape = model_embeddings_weights.shape
        self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
        self.decoder.weight = model_embeddings_weights  # Tied weights

    def forward(self, hidden_state):
        # Truncated Language modeling logits (we remove the last token)
        # h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd)
        lm_logits = self.decoder(hidden_state)
        return lm_logits
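
# --- Added illustrative sketch (not part of the original script) ---
# Two properties of the modules above, shown with a tiny stand-in config so the
# check runs instantly: the LM head's decoder weight is the *same* Parameter as
# the token-embedding matrix (weight tying), and the `presents` returned by
# GPT2Model can be fed back as `past` so a follow-up call only needs the newly
# added token. Optional helper, never called by the script.
def _demo_weight_tying_and_past_cache():
    from types import SimpleNamespace
    cfg = SimpleNamespace(vocab_size=17, n_positions=16, n_ctx=16, n_embd=8,
                          n_layer=2, n_head=2, layer_norm_epsilon=1e-5)
    body = GPT2Model(cfg)
    head = GPT2LMHead(body.wte.weight, cfg)
    assert head.decoder.weight is body.wte.weight          # tied weights
    prompt = torch.randint(0, cfg.vocab_size, (1, 5))
    hidden, presents = body(prompt)                         # full prompt pass
    step, _ = body(torch.randint(0, cfg.vocab_size, (1, 1)), past=presents)
    assert hidden.shape == (1, 5, cfg.n_embd) and step.shape == (1, 1, cfg.n_embd)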
import torch.nn.functional as F


def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.
    Args:
        logits: logits distribution shape (batch size, vocabulary size)
        top_k > 0: keep only top k tokens with highest probability (top-k filtering).
        top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
        filter_value: value to replace filtered logits.
    """
    assert logits.dim() == 2  # batch size x vocabulary size
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits


class GPT2LMHeadModel(nn.Module):
    def __init__(self, config):
        super(GPT2LMHeadModel, self).__init__()
        self.transformer = GPT2Model(config)
        self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)

    def set_tied(self):
        """Make sure we are sharing the embeddings."""
        self.lm_head.set_embeddings_weights(self.transformer.wte.weight)

    def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None):
        hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
        lm_logits = self.lm_head(hidden_states)
        outputs = (lm_logits, presents)
        if lm_labels is not None:
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = lm_labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            outputs = (loss,) + outputs
        return outputs


def generate(self, input_ids, max_length, temperature=1.0, top_k=0, top_p=0.9,
             repetition_penalty=1.0, device='cuda'):
    # Note: this sampling helper is written method-style (it takes `self`) but is not
    # called anywhere below; the Gradio demo uses `generate_text` instead. It also
    # assumes the model exposes `self.config.eos_token_id`, which the GPT2Config in
    # this script does not define.
    self.eval()
    input_ids = input_ids.to(device)
    batch_size = input_ids.shape[0]
    past = None
    generated = input_ids
    with torch.no_grad():
        for _ in range(max_length):
            outputs = self(input_ids, past=past)
            next_token_logits = outputs[0][:, -1, :]
            past = outputs[1]
            for i in range(batch_size):
                for token_id in set(generated[i].tolist()):
                    next_token_logits[i, token_id] /= repetition_penalty
            next_token_logits = next_token_logits / temperature
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            generated = torch.cat((generated, next_token), dim=1)
            if (next_token == self.config.eos_token_id).all():
                break
            input_ids = next_token
    return generated


class GPT2Config(object):
    def __init__(
        self,
        vocab_size_or_config_json_file=50257,
        n_positions=1024,
        n_ctx=1024,
        n_embd=768,
        n_layer=12,
        n_head=12,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
    ):
        self.vocab_size = vocab_size_or_config_json_file
        self.n_ctx = n_ctx
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
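
# --- Added illustrative sketch (not part of the original script) ---
# With a tiny GPT2Config, GPT2LMHeadModel returns (lm_logits, presents) by default
# and (loss, lm_logits, presents) when lm_labels are supplied, which is the tuple
# layout the sampling loops below rely on. Optional helper, never called.
def _demo_lm_head_model_outputs():
    cfg = GPT2Config(vocab_size_or_config_json_file=23, n_positions=16, n_ctx=16,
                     n_embd=8, n_layer=1, n_head=2)
    lm = GPT2LMHeadModel(cfg)
    ids = torch.randint(0, 23, (1, 6))
    logits, presents = lm(ids)                 # no labels: (lm_logits, presents)
    loss, _, _ = lm(ids, lm_labels=ids)        # with labels: (loss, lm_logits, presents)
    assert logits.shape == (1, 6, 23) and loss.dim() == 0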
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

config = GPT2Config()
model = GPT2LMHeadModel(config)
state_dict = torch.load(r'epoch_5.pth', map_location='cpu' if not torch.cuda.is_available() else None)
model = load_weight(model, state_dict)
model.to(device)
print(model)
model.eval()

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token


# Note: this redefinition overrides the top_k_top_p_filtering defined above and is
# the version used by generate_text below.
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.
    Args:
        logits: logits distribution shape (batch size x vocabulary size)
        top_k > 0: keep only top k tokens with highest probability (top-k filtering).
        top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
            Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
    """
    assert logits.dim() == 2, "Expected logits dimension to be 2 (batch size x vocabulary size)"
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(nn.Softmax(dim=-1)(sorted_logits), dim=-1)
        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        # Ensure that the dimensions match
        if sorted_indices_to_remove.size() != sorted_indices.size():
            raise ValueError(f"Size mismatch: {sorted_indices_to_remove.size()} vs {sorted_indices.size()}")
        # Map the sorted-order mask back to vocabulary order, one batch row at a time
        for batch_idx in range(logits.size(0)):
            indices_to_remove = sorted_indices[batch_idx][sorted_indices_to_remove[batch_idx]]
            logits[batch_idx, indices_to_remove] = filter_value
    return logits
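
# --- Added illustrative sketch (not part of the original script) ---
# A tiny worked example of the filtering above: with top_k=2 only the two largest
# logits in each row survive and every other entry becomes -inf, so softmax assigns
# it zero probability. The function modifies logits in place, hence the clone().
# Optional helper, never called.
def _demo_top_k_filtering():
    logits = torch.tensor([[3.0, 1.0, 2.0, 0.5]])
    filtered = top_k_top_p_filtering(logits.clone(), top_k=2, top_p=0.0)
    assert filtered[0, 0] == 3.0 and filtered[0, 2] == 2.0
    assert torch.isinf(filtered[0, 1]) and torch.isinf(filtered[0, 3])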
# Scratch generation experiment kept from development (commented out):
# prompt_text = "What is the classical conceptualisation of oxidation and reduction in redox reactions?"
# prompt = f"\n<|startoftext|>[WP] {prompt_text} \n[RESPONSE]"
# input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
# max_length = 50
# temperature = 0.7
# top_k = 50
# top_p = 0.95
# repetition_penalty = 1.0
# with torch.no_grad():
#     for _ in range(max_length):
#         outputs = model(input_ids)
#         logits = outputs[0]
#         next_token_logits = logits[:, -1, :] / temperature
#         # Apply repetition penalty
#         for i in range(input_ids.size(0)):
#             for token_id in set(input_ids[i].tolist()):
#                 next_token_logits[0, token_id] /= repetition_penalty
#         # Filter logits using top-k and/or top-p filtering
#         filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
#         next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
#         input_ids = torch.cat([input_ids, next_token], dim=-1).to(device)
# import re
# # generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
# # wp_responses = re.split(r"\[WP\].*?\n|\[RESPONSE\]", generated_text)[1:]
# print(input_ids[0])
# generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
# wp_responses = re.split(r"\[WP\].*?\n|\[RESPONSE\]", generated_text)[1:]
# print(wp_responses)


# Define the generation function
def generate_text(prompt_text, max_length=50, temperature=0.7, top_k=50, top_p=0.95, repetition_penalty=1.0):
    prompt = f"\n[WP] {prompt_text} \n[RESPONSE]"
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    with torch.no_grad():
        for _ in range(max_length):
            outputs = model(input_ids)
            logits = outputs[0]
            next_token_logits = logits[:, -1, :] / temperature
            # Apply repetition penalty
            for i in range(input_ids.size(0)):
                for token_id in set(input_ids[i].tolist()):
                    next_token_logits[i, token_id] /= repetition_penalty
            # Filter logits using top-k and/or top-p filtering
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            input_ids = torch.cat([input_ids, next_token], dim=-1).to(device)
    generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    wp_responses = re.split(r"\[WP\].*?\n|\[RESPONSE\]", generated_text)[1:]
    return wp_responses[1]


# Define example prompts
examples = [
    "What is the classical conceptualisation of oxidation and reduction in redox reactions?",
    "What is the difference between alkenes and alkynes in terms of reactivity?",
    "What is the first law of thermodynamics in thermodynamics?",
    "What is a popular type of vending machine in banking services?",
    "What has the worldwide campaign against smallpox led to?"
]

# Define the Gradio interface using Blocks
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("