rashmi committed on
Commit
e28c898
1 Parent(s): 7ebfc36
Files changed (1)
  1. app.py +121 -1
app.py CHANGED
@@ -41,12 +41,132 @@ theme = gr.themes.Monochrome(
     font=[gr.themes.GoogleFont("Open Sans"), "ui-sans-serif", "system-ui", "sans-serif"],
 )
 
+### Load the model
+class CFG:
+    num_workers = os.cpu_count()
+    llm_backbone = "HuggingFaceH4/zephyr-7b-beta"
+    # tokenizer_path = "HuggingFaceH4/zephyr-7b-beta"
+    tokenizer_path = "/home/rashmi/Documents/kaggle/h2oai_predict_llm/src/models_exp56/tokenizer"
+    tokenizer = AutoTokenizer.from_pretrained(
+        tokenizer_path, add_prefix_space=False, use_fast=True, trust_remote_code=True, add_eos_token=True
+    )
+    batch_size = 1
+    max_len = 650
+    seed = 42
+
+    num_labels = 7
+
+    lora = True
+    lora_r = 4
+    lora_alpha = 16
+    lora_dropout = 0.05
+    lora_target_modules = ""
+    gradient_checkpointing = True
+
+
+class CustomModel(nn.Module):
+    """
+    Model for causal language modeling problem type.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+        self.backbone_config = AutoConfig.from_pretrained(
+            CFG.llm_backbone, trust_remote_code=True
+        )
+
+        quantization_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_compute_dtype=torch.float16,
+            bnb_4bit_quant_type="nf4",
+        )
+
+        self.model = AutoModelForCausalLM.from_pretrained(
+            CFG.llm_backbone,
+            config=self.backbone_config,
+            quantization_config=quantization_config,
+        )
+
+        if CFG.lora:
+            target_modules = []
+            for name, module in self.model.named_modules():
+                if (
+                    isinstance(module, (torch.nn.Linear, torch.nn.Conv1d))
+                    and "head" not in name
+                ):
+                    name = name.split(".")[-1]
+                    if name not in target_modules:
+                        target_modules.append(name)
+
+            lora_config = LoraConfig(
+                r=CFG.lora_r,
+                lora_alpha=CFG.lora_alpha,
+                target_modules=target_modules,
+                lora_dropout=CFG.lora_dropout,
+                bias="none",
+                task_type="CAUSAL_LM",
+            )
+            if CFG.gradient_checkpointing:
+                self.model.enable_input_require_grads()
+            self.model = get_peft_model(self.model, lora_config)
+            self.model.print_trainable_parameters()
+
+        self.classification_head = nn.Linear(
+            self.backbone_config.vocab_size, CFG.num_labels, bias=False
+        )
+        self._init_weights(self.classification_head)
+
+    def _init_weights(self, module):
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=self.backbone_config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.backbone_config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+
+    def forward(
+        self,
+        batch
+    ):
+        # disable cache if gradient checkpointing is enabled
+        if CFG.gradient_checkpointing:
+            self.model.config.use_cache = False
+
+        self.model.config.pretraining_tp = 1
+
+        output = self.model(
+            input_ids=batch["input_ids"],
+            attention_mask=batch["attention_mask"],
+        )
+
+        output.logits = self.classification_head(output[0][:, -1].float())
+
+        # enable cache again if gradient checkpointing is enabled
+        if CFG.gradient_checkpointing:
+            self.model.config.use_cache = True
+
+        return output.logits
+
+model = CustomModel()
+
+
+
+### End Load the model
+
+
+
+
 def do_submit(question, response):
     full_text = question + " " + response
     # result = do_inference(full_text)
     return "result"
 
-
 @spaces.GPU
 def greet():
     pass
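
The commit wires the new model into the UI only partially: do_submit still returns a hard-coded "result" and leaves the do_inference(full_text) call commented out, with no do_inference defined in this diff. Below is a minimal sketch of what such a helper could look like, assuming the imports already present at the top of app.py (torch and the transformers/peft classes used above); the device move and the returned label index are placeholder choices, since the mapping from the 7 classes to display text is not part of this commit.

# Not part of this commit: hypothetical do_inference sketch for the call
# that do_submit leaves commented out.
def do_inference(full_text):
    # Tokenize the concatenated question + response, truncating to CFG.max_len.
    inputs = CFG.tokenizer(
        full_text,
        truncation=True,
        max_length=CFG.max_len,
        return_tensors="pt",
    )
    # Assumes the quantized backbone reports its device; move inputs to match.
    device = model.model.device
    batch = {
        "input_ids": inputs["input_ids"].to(device),
        "attention_mask": inputs["attention_mask"].to(device),
    }
    with torch.no_grad():
        # CustomModel.forward returns the classification head's logits,
        # shape (batch_size, CFG.num_labels).
        logits = model(batch)
    # Placeholder post-processing: return the index of the highest-scoring label.
    return int(logits.argmax(dim=-1).item())

With a helper along these lines, the commented-out line in do_submit would become result = do_inference(full_text) and the function would return that result instead of the literal string.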