justinhl committed on
Commit dd89ecf
1 Parent(s): 5736c83

Upload hybrid_test.ipynb

Files changed (1)
  1. hybrid_test.ipynb +410 -0
hybrid_test.ipynb ADDED
@@ -0,0 +1,410 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "id": "0paOn0yhDB63"
+ },
+ "outputs": [],
+ "source": [
+ "from hybrid_pipe import HybridQAPipeline\n",
+ "from transformers import pipeline\n",
+ "from transformers.pipelines import PIPELINE_REGISTRY\n",
+ "\n",
+ "from transformers import AutoModelForQuestionAnswering, TFAutoModelForQuestionAnswering\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "id": "DuwOF8yjDB66"
+ },
+ "outputs": [],
+ "source": [
+ "# Register new pipe\n",
+ "PIPELINE_REGISTRY.register_pipeline(\n",
+ " \"hybrid-qa\",\n",
+ " pipeline_class=HybridQAPipeline,\n",
+ " pt_model=AutoModelForQuestionAnswering,\n",
+ " tf_model=TFAutoModelForQuestionAnswering\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "id": "pf_tBYQsDB67",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "2d75ec1b-a844-441b-ca84-7859dd8eedc5"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Create pipe instance\n",
+ "# Note: the model specified here does not matter; we just need to\n",
+ "# pass something valid to satisfy the pipeline class\n",
+ "hybrid_pipe = pipeline(\"hybrid-qa\", model='datarpit/distilbert-base-uncased-finetuned-natural-questions')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "KKv6ZS2LDB67",
+ "outputId": "58f78991-1204-4714-af1c-bad70d120118"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.\n",
+ "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1141: UserWarning: Using the model-agnostic default `max_length` (=20) to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.\n",
+ " warnings.warn(\n"
+ ]
+ },
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "{'guess': 'Oslo', 'confidence': 2.0940363768613864e-14}"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 5
+ }
+ ],
+ "source": [
+ "# Inference testing!\n",
+ "hybrid_pipe(question=\"What is the capital of Norway?\",context=\"The capital of Norway is Oslo\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 53
+ },
+ "id": "sgrDgs9-DB68",
+ "outputId": "7fe9f733-f19b-43cb-e68c-33e302b2be43"
+ },
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "CommitInfo(commit_url='https://huggingface.co/justinhl/hybrid-qa/commit/7019d3e4971d6c754e9529b5a3de9a0425c3cccf', commit_message='Upload HybridQAPipeline', commit_description='', oid='7019d3e4971d6c754e9529b5a3de9a0425c3cccf', pr_url=None, pr_revision=None, pr_num=None)"
+ ],
+ "application/vnd.google.colaboratory.intrinsic+json": {
+ "type": "string"
+ }
+ },
+ "metadata": {},
+ "execution_count": 6
+ }
+ ],
+ "source": [
+ "# Pushing to hub\n",
+ "hybrid_pipe.push_to_hub(\"hybrid-qa\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "PPOf6vUhDB68",
+ "outputId": "1fa601e1-dfc7-4128-c430-44652916aa87"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "Some weights of the model checkpoint at justinhl/hybrid-qa were not used when initializing DistilBertForQuestionAnswering: ['model_extractive.distilbert.embeddings.LayerNorm.bias', 'model_extractive.distilbert.embeddings.LayerNorm.weight', 'model_extractive.distilbert.embeddings.position_embeddings.weight', 'model_extractive.distilbert.embeddings.word_embeddings.weight', 'model_extractive.distilbert.transformer.layer.0.attention.k_lin.bias', 'model_extractive.distilbert.transformer.layer.0.attention.k_lin.weight', 'model_extractive.distilbert.transformer.layer.0.attention.out_lin.bias', 'model_extractive.distilbert.transformer.layer.0.attention.out_lin.weight', 'model_extractive.distilbert.transformer.layer.0.attention.q_lin.bias', 'model_extractive.distilbert.transformer.layer.0.attention.q_lin.weight', 'model_extractive.distilbert.transformer.layer.0.attention.v_lin.bias', 'model_extractive.distilbert.transformer.layer.0.attention.v_lin.weight', 'model_extractive.distilbert.transformer.layer.0.ffn.lin1.bias', 'model_extractive.distilbert.transformer.layer.0.ffn.lin1.weight', 'model_extractive.distilbert.transformer.layer.0.ffn.lin2.bias', 'model_extractive.distilbert.transformer.layer.0.ffn.lin2.weight', 'model_extractive.distilbert.transformer.layer.0.output_layer_norm.bias', 'model_extractive.distilbert.transformer.layer.0.output_layer_norm.weight', 'model_extractive.distilbert.transformer.layer.0.sa_layer_norm.bias', 'model_extractive.distilbert.transformer.layer.0.sa_layer_norm.weight', 'model_extractive.distilbert.transformer.layer.1.attention.k_lin.bias', 'model_extractive.distilbert.transformer.layer.1.attention.k_lin.weight', 'model_extractive.distilbert.transformer.layer.1.attention.out_lin.bias', 'model_extractive.distilbert.transformer.layer.1.attention.out_lin.weight', 'model_extractive.distilbert.transformer.layer.1.attention.q_lin.bias', 'model_extractive.distilbert.transformer.layer.1.attention.q_lin.weight', 'model_extractive.distilbert.transformer.layer.1.attention.v_lin.bias', 'model_extractive.distilbert.transformer.layer.1.attention.v_lin.weight', 'model_extractive.distilbert.transformer.layer.1.ffn.lin1.bias', 'model_extractive.distilbert.transformer.layer.1.ffn.lin1.weight', 'model_extractive.distilbert.transformer.layer.1.ffn.lin2.bias', 'model_extractive.distilbert.transformer.layer.1.ffn.lin2.weight', 'model_extractive.distilbert.transformer.layer.1.output_layer_norm.bias', 'model_extractive.distilbert.transformer.layer.1.output_layer_norm.weight', 'model_extractive.distilbert.transformer.layer.1.sa_layer_norm.bias', 'model_extractive.distilbert.transformer.layer.1.sa_layer_norm.weight', 'model_extractive.distilbert.transformer.layer.2.attention.k_lin.bias', 'model_extractive.distilbert.transformer.layer.2.attention.k_lin.weight', 'model_extractive.distilbert.transformer.layer.2.attention.out_lin.bias', 'model_extractive.distilbert.transformer.layer.2.attention.out_lin.weight', 'model_extractive.distilbert.transformer.layer.2.attention.q_lin.bias', 'model_extractive.distilbert.transformer.layer.2.attention.q_lin.weight', 'model_extractive.distilbert.transformer.layer.2.attention.v_lin.bias', 'model_extractive.distilbert.transformer.layer.2.attention.v_lin.weight', 'model_extractive.distilbert.transformer.layer.2.ffn.lin1.bias', 'model_extractive.distilbert.transformer.layer.2.ffn.lin1.weight', 'model_extractive.distilbert.transformer.layer.2.ffn.lin2.bias', 'model_extractive.distilbert.transformer.layer.2.ffn.lin2.weight', 
'model_extractive.distilbert.transformer.layer.2.output_layer_norm.bias', 'model_extractive.distilbert.transformer.layer.2.output_layer_norm.weight', 'model_extractive.distilbert.transformer.layer.2.sa_layer_norm.bias', 'model_extractive.distilbert.transformer.layer.2.sa_layer_norm.weight', 'model_extractive.distilbert.transformer.layer.3.attention.k_lin.bias', 'model_extractive.distilbert.transformer.layer.3.attention.k_lin.weight', 'model_extractive.distilbert.transformer.layer.3.attention.out_lin.bias', 'model_extractive.distilbert.transformer.layer.3.attention.out_lin.weight', 'model_extractive.distilbert.transformer.layer.3.attention.q_lin.bias', 'model_extractive.distilbert.transformer.layer.3.attention.q_lin.weight', 'model_extractive.distilbert.transformer.layer.3.attention.v_lin.bias', 'model_extractive.distilbert.transformer.layer.3.attention.v_lin.weight', 'model_extractive.distilbert.transformer.layer.3.ffn.lin1.bias', 'model_extractive.distilbert.transformer.layer.3.ffn.lin1.weight', 'model_extractive.distilbert.transformer.layer.3.ffn.lin2.bias', 'model_extractive.distilbert.transformer.layer.3.ffn.lin2.weight', 'model_extractive.distilbert.transformer.layer.3.output_layer_norm.bias', 'model_extractive.distilbert.transformer.layer.3.output_layer_norm.weight', 'model_extractive.distilbert.transformer.layer.3.sa_layer_norm.bias', 'model_extractive.distilbert.transformer.layer.3.sa_layer_norm.weight', 'model_extractive.distilbert.transformer.layer.4.attention.k_lin.bias', 'model_extractive.distilbert.transformer.layer.4.attention.k_lin.weight', 'model_extractive.distilbert.transformer.layer.4.attention.out_lin.bias', 'model_extractive.distilbert.transformer.layer.4.attention.out_lin.weight', 'model_extractive.distilbert.transformer.layer.4.attention.q_lin.bias', 'model_extractive.distilbert.transformer.layer.4.attention.q_lin.weight', 'model_extractive.distilbert.transformer.layer.4.attention.v_lin.bias', 'model_extractive.distilbert.transformer.layer.4.attention.v_lin.weight', 'model_extractive.distilbert.transformer.layer.4.ffn.lin1.bias', 'model_extractive.distilbert.transformer.layer.4.ffn.lin1.weight', 'model_extractive.distilbert.transformer.layer.4.ffn.lin2.bias', 'model_extractive.distilbert.transformer.layer.4.ffn.lin2.weight', 'model_extractive.distilbert.transformer.layer.4.output_layer_norm.bias', 'model_extractive.distilbert.transformer.layer.4.output_layer_norm.weight', 'model_extractive.distilbert.transformer.layer.4.sa_layer_norm.bias', 'model_extractive.distilbert.transformer.layer.4.sa_layer_norm.weight', 'model_extractive.distilbert.transformer.layer.5.attention.k_lin.bias', 'model_extractive.distilbert.transformer.layer.5.attention.k_lin.weight', 'model_extractive.distilbert.transformer.layer.5.attention.out_lin.bias', 'model_extractive.distilbert.transformer.layer.5.attention.out_lin.weight', 'model_extractive.distilbert.transformer.layer.5.attention.q_lin.bias', 'model_extractive.distilbert.transformer.layer.5.attention.q_lin.weight', 'model_extractive.distilbert.transformer.layer.5.attention.v_lin.bias', 'model_extractive.distilbert.transformer.layer.5.attention.v_lin.weight', 'model_extractive.distilbert.transformer.layer.5.ffn.lin1.bias', 'model_extractive.distilbert.transformer.layer.5.ffn.lin1.weight', 'model_extractive.distilbert.transformer.layer.5.ffn.lin2.bias', 'model_extractive.distilbert.transformer.layer.5.ffn.lin2.weight', 'model_extractive.distilbert.transformer.layer.5.output_layer_norm.bias', 
'model_extractive.distilbert.transformer.layer.5.output_layer_norm.weight', 'model_extractive.distilbert.transformer.layer.5.sa_layer_norm.bias', 'model_extractive.distilbert.transformer.layer.5.sa_layer_norm.weight', 'model_extractive.qa_outputs.bias', 'model_extractive.qa_outputs.weight', 'model_generative.decoder.block.0.layer.0.SelfAttention.k.weight', 'model_generative.decoder.block.0.layer.0.SelfAttention.o.weight', 'model_generative.decoder.block.0.layer.0.SelfAttention.q.weight', 'model_generative.decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight', 'model_generative.decoder.block.0.layer.0.SelfAttention.v.weight', 'model_generative.decoder.block.0.layer.0.layer_norm.weight', 'model_generative.decoder.block.0.layer.1.EncDecAttention.k.weight', 'model_generative.decoder.block.0.layer.1.EncDecAttention.o.weight', 'model_generative.decoder.block.0.layer.1.EncDecAttention.q.weight', 'model_generative.decoder.block.0.layer.1.EncDecAttention.v.weight', 'model_generative.decoder.block.0.layer.1.layer_norm.weight', 'model_generative.decoder.block.0.layer.2.DenseReluDense.wi.weight', 'model_generative.decoder.block.0.layer.2.DenseReluDense.wo.weight', 'model_generative.decoder.block.0.layer.2.layer_norm.weight', 'model_generative.decoder.block.1.layer.0.SelfAttention.k.weight', 'model_generative.decoder.block.1.layer.0.SelfAttention.o.weight', 'model_generative.decoder.block.1.layer.0.SelfAttention.q.weight', 'model_generative.decoder.block.1.layer.0.SelfAttention.v.weight', 'model_generative.decoder.block.1.layer.0.layer_norm.weight', 'model_generative.decoder.block.1.layer.1.EncDecAttention.k.weight', 'model_generative.decoder.block.1.layer.1.EncDecAttention.o.weight', 'model_generative.decoder.block.1.layer.1.EncDecAttention.q.weight', 'model_generative.decoder.block.1.layer.1.EncDecAttention.v.weight', 'model_generative.decoder.block.1.layer.1.layer_norm.weight', 'model_generative.decoder.block.1.layer.2.DenseReluDense.wi.weight', 'model_generative.decoder.block.1.layer.2.DenseReluDense.wo.weight', 'model_generative.decoder.block.1.layer.2.layer_norm.weight', 'model_generative.decoder.block.10.layer.0.SelfAttention.k.weight', 'model_generative.decoder.block.10.layer.0.SelfAttention.o.weight', 'model_generative.decoder.block.10.layer.0.SelfAttention.q.weight', 'model_generative.decoder.block.10.layer.0.SelfAttention.v.weight', 'model_generative.decoder.block.10.layer.0.layer_norm.weight', 'model_generative.decoder.block.10.layer.1.EncDecAttention.k.weight', 'model_generative.decoder.block.10.layer.1.EncDecAttention.o.weight', 'model_generative.decoder.block.10.layer.1.EncDecAttention.q.weight', 'model_generative.decoder.block.10.layer.1.EncDecAttention.v.weight', 'model_generative.decoder.block.10.layer.1.layer_norm.weight', 'model_generative.decoder.block.10.layer.2.DenseReluDense.wi.weight', 'model_generative.decoder.block.10.layer.2.DenseReluDense.wo.weight', 'model_generative.decoder.block.10.layer.2.layer_norm.weight', 'model_generative.decoder.block.11.layer.0.SelfAttention.k.weight', 'model_generative.decoder.block.11.layer.0.SelfAttention.o.weight', 'model_generative.decoder.block.11.layer.0.SelfAttention.q.weight', 'model_generative.decoder.block.11.layer.0.SelfAttention.v.weight', 'model_generative.decoder.block.11.layer.0.layer_norm.weight', 'model_generative.decoder.block.11.layer.1.EncDecAttention.k.weight', 'model_generative.decoder.block.11.layer.1.EncDecAttention.o.weight', 'model_generative.decoder.block.11.layer.1.EncDecAttention.q.weight', 
'model_generative.decoder.block.11.layer.1.EncDecAttention.v.weight', 'model_generative.decoder.block.11.layer.1.layer_norm.weight', 'model_generative.decoder.block.11.layer.2.DenseReluDense.wi.weight', 'model_generative.decoder.block.11.layer.2.DenseReluDense.wo.weight', 'model_generative.decoder.block.11.layer.2.layer_norm.weight', 'model_generative.decoder.block.2.layer.0.SelfAttention.k.weight', 'model_generative.decoder.block.2.layer.0.SelfAttention.o.weight', 'model_generative.decoder.block.2.layer.0.SelfAttention.q.weight', 'model_generative.decoder.block.2.layer.0.SelfAttention.v.weight', 'model_generative.decoder.block.2.layer.0.layer_norm.weight', 'model_generative.decoder.block.2.layer.1.EncDecAttention.k.weight', 'model_generative.decoder.block.2.layer.1.EncDecAttention.o.weight', 'model_generative.decoder.block.2.layer.1.EncDecAttention.q.weight', 'model_generative.decoder.block.2.layer.1.EncDecAttention.v.weight', 'model_generative.decoder.block.2.layer.1.layer_norm.weight', 'model_generative.decoder.block.2.layer.2.DenseReluDense.wi.weight', 'model_generative.decoder.block.2.layer.2.DenseReluDense.wo.weight', 'model_generative.decoder.block.2.layer.2.layer_norm.weight', 'model_generative.decoder.block.3.layer.0.SelfAttention.k.weight', 'model_generative.decoder.block.3.layer.0.SelfAttention.o.weight', 'model_generative.decoder.block.3.layer.0.SelfAttention.q.weight', 'model_generative.decoder.block.3.layer.0.SelfAttention.v.weight', 'model_generative.decoder.block.3.layer.0.layer_norm.weight', 'model_generative.decoder.block.3.layer.1.EncDecAttention.k.weight', 'model_generative.decoder.block.3.layer.1.EncDecAttention.o.weight', 'model_generative.decoder.block.3.layer.1.EncDecAttention.q.weight', 'model_generative.decoder.block.3.layer.1.EncDecAttention.v.weight', 'model_generative.decoder.block.3.layer.1.layer_norm.weight', 'model_generative.decoder.block.3.layer.2.DenseReluDense.wi.weight', 'model_generative.decoder.block.3.layer.2.DenseReluDense.wo.weight', 'model_generative.decoder.block.3.layer.2.layer_norm.weight', 'model_generative.decoder.block.4.layer.0.SelfAttention.k.weight', 'model_generative.decoder.block.4.layer.0.SelfAttention.o.weight', 'model_generative.decoder.block.4.layer.0.SelfAttention.q.weight', 'model_generative.decoder.block.4.layer.0.SelfAttention.v.weight', 'model_generative.decoder.block.4.layer.0.layer_norm.weight', 'model_generative.decoder.block.4.layer.1.EncDecAttention.k.weight', 'model_generative.decoder.block.4.layer.1.EncDecAttention.o.weight', 'model_generative.decoder.block.4.layer.1.EncDecAttention.q.weight', 'model_generative.decoder.block.4.layer.1.EncDecAttention.v.weight', 'model_generative.decoder.block.4.layer.1.layer_norm.weight', 'model_generative.decoder.block.4.layer.2.DenseReluDense.wi.weight', 'model_generative.decoder.block.4.layer.2.DenseReluDense.wo.weight', 'model_generative.decoder.block.4.layer.2.layer_norm.weight', 'model_generative.decoder.block.5.layer.0.SelfAttention.k.weight', 'model_generative.decoder.block.5.layer.0.SelfAttention.o.weight', 'model_generative.decoder.block.5.layer.0.SelfAttention.q.weight', 'model_generative.decoder.block.5.layer.0.SelfAttention.v.weight', 'model_generative.decoder.block.5.layer.0.layer_norm.weight', 'model_generative.decoder.block.5.layer.1.EncDecAttention.k.weight', 'model_generative.decoder.block.5.layer.1.EncDecAttention.o.weight', 'model_generative.decoder.block.5.layer.1.EncDecAttention.q.weight', 'model_generative.decoder.block.5.layer.1.EncDecAttention.v.weight', 
'model_generative.decoder.block.5.layer.1.layer_norm.weight', 'model_generative.decoder.block.5.layer.2.DenseReluDense.wi.weight', 'model_generative.decoder.block.5.layer.2.DenseReluDense.wo.weight', 'model_generative.decoder.block.5.layer.2.layer_norm.weight', 'model_generative.decoder.block.6.layer.0.SelfAttention.k.weight', 'model_generative.decoder.block.6.layer.0.SelfAttention.o.weight', 'model_generative.decoder.block.6.layer.0.SelfAttention.q.weight', 'model_generative.decoder.block.6.layer.0.SelfAttention.v.weight', 'model_generative.decoder.block.6.layer.0.layer_norm.weight', 'model_generative.decoder.block.6.layer.1.EncDecAttention.k.weight', 'model_generative.decoder.block.6.layer.1.EncDecAttention.o.weight', 'model_generative.decoder.block.6.layer.1.EncDecAttention.q.weight', 'model_generative.decoder.block.6.layer.1.EncDecAttention.v.weight', 'model_generative.decoder.block.6.layer.1.layer_norm.weight', 'model_generative.decoder.block.6.layer.2.DenseReluDense.wi.weight', 'model_generative.decoder.block.6.layer.2.DenseReluDense.wo.weight', 'model_generative.decoder.block.6.layer.2.layer_norm.weight', 'model_generative.decoder.block.7.layer.0.SelfAttention.k.weight', 'model_generative.decoder.block.7.layer.0.SelfAttention.o.weight', 'model_generative.decoder.block.7.layer.0.SelfAttention.q.weight', 'model_generative.decoder.block.7.layer.0.SelfAttention.v.weight', 'model_generative.decoder.block.7.layer.0.layer_norm.weight', 'model_generative.decoder.block.7.layer.1.EncDecAttention.k.weight', 'model_generative.decoder.block.7.layer.1.EncDecAttention.o.weight', 'model_generative.decoder.block.7.layer.1.EncDecAttention.q.weight', 'model_generative.decoder.block.7.layer.1.EncDecAttention.v.weight', 'model_generative.decoder.block.7.layer.1.layer_norm.weight', 'model_generative.decoder.block.7.layer.2.DenseReluDense.wi.weight', 'model_generative.decoder.block.7.layer.2.DenseReluDense.wo.weight', 'model_generative.decoder.block.7.layer.2.layer_norm.weight', 'model_generative.decoder.block.8.layer.0.SelfAttention.k.weight', 'model_generative.decoder.block.8.layer.0.SelfAttention.o.weight', 'model_generative.decoder.block.8.layer.0.SelfAttention.q.weight', 'model_generative.decoder.block.8.layer.0.SelfAttention.v.weight', 'model_generative.decoder.block.8.layer.0.layer_norm.weight', 'model_generative.decoder.block.8.layer.1.EncDecAttention.k.weight', 'model_generative.decoder.block.8.layer.1.EncDecAttention.o.weight', 'model_generative.decoder.block.8.layer.1.EncDecAttention.q.weight', 'model_generative.decoder.block.8.layer.1.EncDecAttention.v.weight', 'model_generative.decoder.block.8.layer.1.layer_norm.weight', 'model_generative.decoder.block.8.layer.2.DenseReluDense.wi.weight', 'model_generative.decoder.block.8.layer.2.DenseReluDense.wo.weight', 'model_generative.decoder.block.8.layer.2.layer_norm.weight', 'model_generative.decoder.block.9.layer.0.SelfAttention.k.weight', 'model_generative.decoder.block.9.layer.0.SelfAttention.o.weight', 'model_generative.decoder.block.9.layer.0.SelfAttention.q.weight', 'model_generative.decoder.block.9.layer.0.SelfAttention.v.weight', 'model_generative.decoder.block.9.layer.0.layer_norm.weight', 'model_generative.decoder.block.9.layer.1.EncDecAttention.k.weight', 'model_generative.decoder.block.9.layer.1.EncDecAttention.o.weight', 'model_generative.decoder.block.9.layer.1.EncDecAttention.q.weight', 'model_generative.decoder.block.9.layer.1.EncDecAttention.v.weight', 'model_generative.decoder.block.9.layer.1.layer_norm.weight', 
'model_generative.decoder.block.9.layer.2.DenseReluDense.wi.weight', 'model_generative.decoder.block.9.layer.2.DenseReluDense.wo.weight', 'model_generative.decoder.block.9.layer.2.layer_norm.weight', 'model_generative.decoder.embed_tokens.weight', 'model_generative.decoder.final_layer_norm.weight', 'model_generative.encoder.block.0.layer.0.SelfAttention.k.weight', 'model_generative.encoder.block.0.layer.0.SelfAttention.o.weight', 'model_generative.encoder.block.0.layer.0.SelfAttention.q.weight', 'model_generative.encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight', 'model_generative.encoder.block.0.layer.0.SelfAttention.v.weight', 'model_generative.encoder.block.0.layer.0.layer_norm.weight', 'model_generative.encoder.block.0.layer.1.DenseReluDense.wi.weight', 'model_generative.encoder.block.0.layer.1.DenseReluDense.wo.weight', 'model_generative.encoder.block.0.layer.1.layer_norm.weight', 'model_generative.encoder.block.1.layer.0.SelfAttention.k.weight', 'model_generative.encoder.block.1.layer.0.SelfAttention.o.weight', 'model_generative.encoder.block.1.layer.0.SelfAttention.q.weight', 'model_generative.encoder.block.1.layer.0.SelfAttention.v.weight', 'model_generative.encoder.block.1.layer.0.layer_norm.weight', 'model_generative.encoder.block.1.layer.1.DenseReluDense.wi.weight', 'model_generative.encoder.block.1.layer.1.DenseReluDense.wo.weight', 'model_generative.encoder.block.1.layer.1.layer_norm.weight', 'model_generative.encoder.block.10.layer.0.SelfAttention.k.weight', 'model_generative.encoder.block.10.layer.0.SelfAttention.o.weight', 'model_generative.encoder.block.10.layer.0.SelfAttention.q.weight', 'model_generative.encoder.block.10.layer.0.SelfAttention.v.weight', 'model_generative.encoder.block.10.layer.0.layer_norm.weight', 'model_generative.encoder.block.10.layer.1.DenseReluDense.wi.weight', 'model_generative.encoder.block.10.layer.1.DenseReluDense.wo.weight', 'model_generative.encoder.block.10.layer.1.layer_norm.weight', 'model_generative.encoder.block.11.layer.0.SelfAttention.k.weight', 'model_generative.encoder.block.11.layer.0.SelfAttention.o.weight', 'model_generative.encoder.block.11.layer.0.SelfAttention.q.weight', 'model_generative.encoder.block.11.layer.0.SelfAttention.v.weight', 'model_generative.encoder.block.11.layer.0.layer_norm.weight', 'model_generative.encoder.block.11.layer.1.DenseReluDense.wi.weight', 'model_generative.encoder.block.11.layer.1.DenseReluDense.wo.weight', 'model_generative.encoder.block.11.layer.1.layer_norm.weight', 'model_generative.encoder.block.2.layer.0.SelfAttention.k.weight', 'model_generative.encoder.block.2.layer.0.SelfAttention.o.weight', 'model_generative.encoder.block.2.layer.0.SelfAttention.q.weight', 'model_generative.encoder.block.2.layer.0.SelfAttention.v.weight', 'model_generative.encoder.block.2.layer.0.layer_norm.weight', 'model_generative.encoder.block.2.layer.1.DenseReluDense.wi.weight', 'model_generative.encoder.block.2.layer.1.DenseReluDense.wo.weight', 'model_generative.encoder.block.2.layer.1.layer_norm.weight', 'model_generative.encoder.block.3.layer.0.SelfAttention.k.weight', 'model_generative.encoder.block.3.layer.0.SelfAttention.o.weight', 'model_generative.encoder.block.3.layer.0.SelfAttention.q.weight', 'model_generative.encoder.block.3.layer.0.SelfAttention.v.weight', 'model_generative.encoder.block.3.layer.0.layer_norm.weight', 'model_generative.encoder.block.3.layer.1.DenseReluDense.wi.weight', 'model_generative.encoder.block.3.layer.1.DenseReluDense.wo.weight', 
'model_generative.encoder.block.3.layer.1.layer_norm.weight', 'model_generative.encoder.block.4.layer.0.SelfAttention.k.weight', 'model_generative.encoder.block.4.layer.0.SelfAttention.o.weight', 'model_generative.encoder.block.4.layer.0.SelfAttention.q.weight', 'model_generative.encoder.block.4.layer.0.SelfAttention.v.weight', 'model_generative.encoder.block.4.layer.0.layer_norm.weight', 'model_generative.encoder.block.4.layer.1.DenseReluDense.wi.weight', 'model_generative.encoder.block.4.layer.1.DenseReluDense.wo.weight', 'model_generative.encoder.block.4.layer.1.layer_norm.weight', 'model_generative.encoder.block.5.layer.0.SelfAttention.k.weight', 'model_generative.encoder.block.5.layer.0.SelfAttention.o.weight', 'model_generative.encoder.block.5.layer.0.SelfAttention.q.weight', 'model_generative.encoder.block.5.layer.0.SelfAttention.v.weight', 'model_generative.encoder.block.5.layer.0.layer_norm.weight', 'model_generative.encoder.block.5.layer.1.DenseReluDense.wi.weight', 'model_generative.encoder.block.5.layer.1.DenseReluDense.wo.weight', 'model_generative.encoder.block.5.layer.1.layer_norm.weight', 'model_generative.encoder.block.6.layer.0.SelfAttention.k.weight', 'model_generative.encoder.block.6.layer.0.SelfAttention.o.weight', 'model_generative.encoder.block.6.layer.0.SelfAttention.q.weight', 'model_generative.encoder.block.6.layer.0.SelfAttention.v.weight', 'model_generative.encoder.block.6.layer.0.layer_norm.weight', 'model_generative.encoder.block.6.layer.1.DenseReluDense.wi.weight', 'model_generative.encoder.block.6.layer.1.DenseReluDense.wo.weight', 'model_generative.encoder.block.6.layer.1.layer_norm.weight', 'model_generative.encoder.block.7.layer.0.SelfAttention.k.weight', 'model_generative.encoder.block.7.layer.0.SelfAttention.o.weight', 'model_generative.encoder.block.7.layer.0.SelfAttention.q.weight', 'model_generative.encoder.block.7.layer.0.SelfAttention.v.weight', 'model_generative.encoder.block.7.layer.0.layer_norm.weight', 'model_generative.encoder.block.7.layer.1.DenseReluDense.wi.weight', 'model_generative.encoder.block.7.layer.1.DenseReluDense.wo.weight', 'model_generative.encoder.block.7.layer.1.layer_norm.weight', 'model_generative.encoder.block.8.layer.0.SelfAttention.k.weight', 'model_generative.encoder.block.8.layer.0.SelfAttention.o.weight', 'model_generative.encoder.block.8.layer.0.SelfAttention.q.weight', 'model_generative.encoder.block.8.layer.0.SelfAttention.v.weight', 'model_generative.encoder.block.8.layer.0.layer_norm.weight', 'model_generative.encoder.block.8.layer.1.DenseReluDense.wi.weight', 'model_generative.encoder.block.8.layer.1.DenseReluDense.wo.weight', 'model_generative.encoder.block.8.layer.1.layer_norm.weight', 'model_generative.encoder.block.9.layer.0.SelfAttention.k.weight', 'model_generative.encoder.block.9.layer.0.SelfAttention.o.weight', 'model_generative.encoder.block.9.layer.0.SelfAttention.q.weight', 'model_generative.encoder.block.9.layer.0.SelfAttention.v.weight', 'model_generative.encoder.block.9.layer.0.layer_norm.weight', 'model_generative.encoder.block.9.layer.1.DenseReluDense.wi.weight', 'model_generative.encoder.block.9.layer.1.DenseReluDense.wo.weight', 'model_generative.encoder.block.9.layer.1.layer_norm.weight', 'model_generative.encoder.embed_tokens.weight', 'model_generative.encoder.final_layer_norm.weight', 'model_generative.lm_head.weight', 'model_generative.shared.weight']\n",
+ "- This IS expected if you are initializing DistilBertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+ "- This IS NOT expected if you are initializing DistilBertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
+ "Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at justinhl/hybrid-qa and are newly initialized: ['embeddings.LayerNorm.bias', 'embeddings.LayerNorm.weight', 'embeddings.position_embeddings.weight', 'embeddings.word_embeddings.weight', 'qa_outputs.bias', 'qa_outputs.weight', 'transformer.layer.0.attention.k_lin.bias', 'transformer.layer.0.attention.k_lin.weight', 'transformer.layer.0.attention.out_lin.bias', 'transformer.layer.0.attention.out_lin.weight', 'transformer.layer.0.attention.q_lin.bias', 'transformer.layer.0.attention.q_lin.weight', 'transformer.layer.0.attention.v_lin.bias', 'transformer.layer.0.attention.v_lin.weight', 'transformer.layer.0.ffn.lin1.bias', 'transformer.layer.0.ffn.lin1.weight', 'transformer.layer.0.ffn.lin2.bias', 'transformer.layer.0.ffn.lin2.weight', 'transformer.layer.0.output_layer_norm.bias', 'transformer.layer.0.output_layer_norm.weight', 'transformer.layer.0.sa_layer_norm.bias', 'transformer.layer.0.sa_layer_norm.weight', 'transformer.layer.1.attention.k_lin.bias', 'transformer.layer.1.attention.k_lin.weight', 'transformer.layer.1.attention.out_lin.bias', 'transformer.layer.1.attention.out_lin.weight', 'transformer.layer.1.attention.q_lin.bias', 'transformer.layer.1.attention.q_lin.weight', 'transformer.layer.1.attention.v_lin.bias', 'transformer.layer.1.attention.v_lin.weight', 'transformer.layer.1.ffn.lin1.bias', 'transformer.layer.1.ffn.lin1.weight', 'transformer.layer.1.ffn.lin2.bias', 'transformer.layer.1.ffn.lin2.weight', 'transformer.layer.1.output_layer_norm.bias', 'transformer.layer.1.output_layer_norm.weight', 'transformer.layer.1.sa_layer_norm.bias', 'transformer.layer.1.sa_layer_norm.weight', 'transformer.layer.2.attention.k_lin.bias', 'transformer.layer.2.attention.k_lin.weight', 'transformer.layer.2.attention.out_lin.bias', 'transformer.layer.2.attention.out_lin.weight', 'transformer.layer.2.attention.q_lin.bias', 'transformer.layer.2.attention.q_lin.weight', 'transformer.layer.2.attention.v_lin.bias', 'transformer.layer.2.attention.v_lin.weight', 'transformer.layer.2.ffn.lin1.bias', 'transformer.layer.2.ffn.lin1.weight', 'transformer.layer.2.ffn.lin2.bias', 'transformer.layer.2.ffn.lin2.weight', 'transformer.layer.2.output_layer_norm.bias', 'transformer.layer.2.output_layer_norm.weight', 'transformer.layer.2.sa_layer_norm.bias', 'transformer.layer.2.sa_layer_norm.weight', 'transformer.layer.3.attention.k_lin.bias', 'transformer.layer.3.attention.k_lin.weight', 'transformer.layer.3.attention.out_lin.bias', 'transformer.layer.3.attention.out_lin.weight', 'transformer.layer.3.attention.q_lin.bias', 'transformer.layer.3.attention.q_lin.weight', 'transformer.layer.3.attention.v_lin.bias', 'transformer.layer.3.attention.v_lin.weight', 'transformer.layer.3.ffn.lin1.bias', 'transformer.layer.3.ffn.lin1.weight', 'transformer.layer.3.ffn.lin2.bias', 'transformer.layer.3.ffn.lin2.weight', 'transformer.layer.3.output_layer_norm.bias', 'transformer.layer.3.output_layer_norm.weight', 'transformer.layer.3.sa_layer_norm.bias', 'transformer.layer.3.sa_layer_norm.weight', 'transformer.layer.4.attention.k_lin.bias', 'transformer.layer.4.attention.k_lin.weight', 'transformer.layer.4.attention.out_lin.bias', 'transformer.layer.4.attention.out_lin.weight', 'transformer.layer.4.attention.q_lin.bias', 'transformer.layer.4.attention.q_lin.weight', 'transformer.layer.4.attention.v_lin.bias', 'transformer.layer.4.attention.v_lin.weight', 'transformer.layer.4.ffn.lin1.bias', 'transformer.layer.4.ffn.lin1.weight', 
'transformer.layer.4.ffn.lin2.bias', 'transformer.layer.4.ffn.lin2.weight', 'transformer.layer.4.output_layer_norm.bias', 'transformer.layer.4.output_layer_norm.weight', 'transformer.layer.4.sa_layer_norm.bias', 'transformer.layer.4.sa_layer_norm.weight', 'transformer.layer.5.attention.k_lin.bias', 'transformer.layer.5.attention.k_lin.weight', 'transformer.layer.5.attention.out_lin.bias', 'transformer.layer.5.attention.out_lin.weight', 'transformer.layer.5.attention.q_lin.bias', 'transformer.layer.5.attention.q_lin.weight', 'transformer.layer.5.attention.v_lin.bias', 'transformer.layer.5.attention.v_lin.weight', 'transformer.layer.5.ffn.lin1.bias', 'transformer.layer.5.ffn.lin1.weight', 'transformer.layer.5.ffn.lin2.bias', 'transformer.layer.5.ffn.lin2.weight', 'transformer.layer.5.output_layer_norm.bias', 'transformer.layer.5.output_layer_norm.weight', 'transformer.layer.5.sa_layer_norm.bias', 'transformer.layer.5.sa_layer_norm.weight']\n",
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Importing from remote\n",
+ "imported_pipe = pipeline(\"hybrid-qa\", model=\"justinhl/hybrid-qa\", trust_remote_code=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Inference testing!\n",
+ "imported_pipe(question=\"What is the capital of Norway?\",context=\"The capital of Norway is Oslo\")"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "sQsoT-UpPp0O",
+ "outputId": "dd922309-bd21-4684-caee-c4d4499bf69b"
+ },
+ "execution_count": 8,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.\n",
+ "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1141: UserWarning: Using the model-agnostic default `max_length` (=20) to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.\n",
+ " warnings.warn(\n"
+ ]
+ },
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "{'guess': 'Oslo', 'confidence': 2.0940363768613864e-14}"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 8
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "print(\"Model loaded:\", imported_pipe.model)"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "GEmtld6OVT7W",
+ "outputId": "93217a25-668e-4a46-8fc9-9db440693a1c"
+ },
+ "execution_count": 9,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Model loaded: HybridQAModel(\n",
+ " (model_extractive): DistilBertForQuestionAnswering(\n",
+ " (distilbert): DistilBertModel(\n",
+ " (embeddings): Embeddings(\n",
+ " (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
+ " (position_embeddings): Embedding(512, 768)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (transformer): Transformer(\n",
+ " (layer): ModuleList(\n",
+ " (0-5): 6 x TransformerBlock(\n",
+ " (attention): MultiHeadSelfAttention(\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
+ " )\n",
+ " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (ffn): FFN(\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (activation): GELUActivation()\n",
+ " )\n",
+ " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (qa_outputs): Linear(in_features=768, out_features=2, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (model_generative): T5ForConditionalGeneration(\n",
+ " (shared): Embedding(32128, 768)\n",
+ " (encoder): T5Stack(\n",
+ " (embed_tokens): Embedding(32128, 768)\n",
+ " (block): ModuleList(\n",
+ " (0): T5Block(\n",
+ " (layer): ModuleList(\n",
+ " (0): T5LayerSelfAttention(\n",
+ " (SelfAttention): T5Attention(\n",
+ " (q): Linear(in_features=768, out_features=768, bias=False)\n",
+ " (k): Linear(in_features=768, out_features=768, bias=False)\n",
+ " (v): Linear(in_features=768, out_features=768, bias=False)\n",
+ " (o): Linear(in_features=768, out_features=768, bias=False)\n",
+ " (relative_attention_bias): Embedding(32, 12)\n",
+ " )\n",
+ " (layer_norm): T5LayerNorm()\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (1): T5LayerFF(\n",
+ " (DenseReluDense): T5DenseActDense(\n",
+ " (wi): Linear(in_features=768, out_features=3072, bias=False)\n",
+ " (wo): Linear(in_features=3072, out_features=768, bias=False)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " (act): ReLU()\n",
+ " )\n",
+ " (layer_norm): T5LayerNorm()\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (1-11): 11 x T5Block(\n",
+ " (layer): ModuleList(\n",
+ " (0): T5LayerSelfAttention(\n",
+ " (SelfAttention): T5Attention(\n",
+ " (q): Linear(in_features=768, out_features=768, bias=False)\n",
+ " (k): Linear(in_features=768, out_features=768, bias=False)\n",
+ " (v): Linear(in_features=768, out_features=768, bias=False)\n",
+ " (o): Linear(in_features=768, out_features=768, bias=False)\n",
+ " )\n",
+ " (layer_norm): T5LayerNorm()\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (1): T5LayerFF(\n",
+ " (DenseReluDense): T5DenseActDense(\n",
+ " (wi): Linear(in_features=768, out_features=3072, bias=False)\n",
+ " (wo): Linear(in_features=3072, out_features=768, bias=False)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " (act): ReLU()\n",
+ " )\n",
+ " (layer_norm): T5LayerNorm()\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (final_layer_norm): T5LayerNorm()\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (decoder): T5Stack(\n",
+ " (embed_tokens): Embedding(32128, 768)\n",
+ " (block): ModuleList(\n",
+ " (0): T5Block(\n",
+ " (layer): ModuleList(\n",
+ " (0): T5LayerSelfAttention(\n",
+ " (SelfAttention): T5Attention(\n",
+ " (q): Linear(in_features=768, out_features=768, bias=False)\n",
+ " (k): Linear(in_features=768, out_features=768, bias=False)\n",
+ " (v): Linear(in_features=768, out_features=768, bias=False)\n",
+ " (o): Linear(in_features=768, out_features=768, bias=False)\n",
+ " (relative_attention_bias): Embedding(32, 12)\n",
+ " )\n",
+ " (layer_norm): T5LayerNorm()\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (1): T5LayerCrossAttention(\n",
+ " (EncDecAttention): T5Attention(\n",
+ " (q): Linear(in_features=768, out_features=768, bias=False)\n",
+ " (k): Linear(in_features=768, out_features=768, bias=False)\n",
+ " (v): Linear(in_features=768, out_features=768, bias=False)\n",
+ " (o): Linear(in_features=768, out_features=768, bias=False)\n",
+ " )\n",
+ " (layer_norm): T5LayerNorm()\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (2): T5LayerFF(\n",
+ " (DenseReluDense): T5DenseActDense(\n",
+ " (wi): Linear(in_features=768, out_features=3072, bias=False)\n",
+ " (wo): Linear(in_features=3072, out_features=768, bias=False)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " (act): ReLU()\n",
+ " )\n",
+ " (layer_norm): T5LayerNorm()\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (1-11): 11 x T5Block(\n",
+ " (layer): ModuleList(\n",
+ " (0): T5LayerSelfAttention(\n",
+ " (SelfAttention): T5Attention(\n",
+ " (q): Linear(in_features=768, out_features=768, bias=False)\n",
+ " (k): Linear(in_features=768, out_features=768, bias=False)\n",
+ " (v): Linear(in_features=768, out_features=768, bias=False)\n",
+ " (o): Linear(in_features=768, out_features=768, bias=False)\n",
+ " )\n",
+ " (layer_norm): T5LayerNorm()\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (1): T5LayerCrossAttention(\n",
+ " (EncDecAttention): T5Attention(\n",
+ " (q): Linear(in_features=768, out_features=768, bias=False)\n",
+ " (k): Linear(in_features=768, out_features=768, bias=False)\n",
+ " (v): Linear(in_features=768, out_features=768, bias=False)\n",
+ " (o): Linear(in_features=768, out_features=768, bias=False)\n",
+ " )\n",
+ " (layer_norm): T5LayerNorm()\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (2): T5LayerFF(\n",
+ " (DenseReluDense): T5DenseActDense(\n",
+ " (wi): Linear(in_features=768, out_features=3072, bias=False)\n",
+ " (wo): Linear(in_features=3072, out_features=768, bias=False)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " (act): ReLU()\n",
+ " )\n",
+ " (layer_norm): T5LayerNorm()\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (final_layer_norm): T5LayerNorm()\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (lm_head): Linear(in_features=768, out_features=32128, bias=False)\n",
+ " )\n",
+ ")\n"
+ ]
+ }
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.7"
+ },
+ "colab": {
+ "provenance": []
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+ }
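
Note on the custom pipeline class: the notebook imports HybridQAPipeline from hybrid_pipe.py, which is not part of this commit, so only the registration, inference, and loading calls are visible above. For context, a custom transformers pipeline registered through PIPELINE_REGISTRY.register_pipeline must subclass Pipeline and implement _sanitize_parameters, preprocess, _forward, and postprocess. The sketch below illustrates that generic contract only; the class name, task name, and the span-selection logic are illustrative assumptions, not the contents of hybrid_pipe.py.

    # Minimal sketch of the custom-pipeline contract, assuming a plain
    # extractive QA model. This is NOT the notebook's hybrid_pipe.py.
    from transformers import AutoModelForQuestionAnswering, Pipeline, pipeline
    from transformers.pipelines import PIPELINE_REGISTRY

    class SpanQAPipelineSketch(Pipeline):
        def __call__(self, question=None, context=None, **kwargs):
            # Bundle the keyword arguments into the single positional input
            # that the base Pipeline machinery expects.
            return super().__call__({"question": question, "context": context}, **kwargs)

        def _sanitize_parameters(self, **kwargs):
            # No stage-specific parameters: empty dicts for the
            # preprocess / forward / postprocess stages.
            return {}, {}, {}

        def preprocess(self, inputs, **kwargs):
            return self.tokenizer(
                inputs["question"], inputs["context"],
                return_tensors="pt", truncation=True,
            )

        def _forward(self, model_inputs, **kwargs):
            # Keep input_ids around so postprocess can decode the chosen span.
            return {"outputs": self.model(**model_inputs),
                    "input_ids": model_inputs["input_ids"]}

        def postprocess(self, model_outputs, **kwargs):
            outputs = model_outputs["outputs"]
            start_probs = outputs.start_logits.softmax(-1)
            end_probs = outputs.end_logits.softmax(-1)
            # Independent argmax over start/end logits; a real implementation
            # would search valid (start, end) pairs jointly.
            start = int(start_probs[0].argmax())
            end = int(end_probs[0].argmax())
            answer = self.tokenizer.decode(model_outputs["input_ids"][0][start:end + 1])
            # Mirror the notebook's output shape: a guess plus a confidence score.
            return {"guess": answer,
                    "confidence": float(start_probs[0, start] * end_probs[0, end])}

    # Same registration pattern as the notebook's second cell; the task name
    # "span-qa-sketch" is hypothetical.
    PIPELINE_REGISTRY.register_pipeline(
        "span-qa-sketch",
        pipeline_class=SpanQAPipelineSketch,
        pt_model=AutoModelForQuestionAnswering,
    )

    # Hypothetical usage, mirroring the notebook's calls:
    # qa = pipeline("span-qa-sketch", model="distilbert-base-cased-distilled-squad")
    # qa(question="What is the capital of Norway?", context="The capital of Norway is Oslo")

Because push_to_hub uploads the custom pipeline code alongside the weights, reloading the pipeline from the Hub later requires trust_remote_code=True, exactly as the "Importing from remote" cell above does.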