bhaveshgoel07 commited on
Commit
f3b0572
·
1 Parent(s): 8772405

fixed model loading

Browse files
Files changed (3) hide show
  1. =3 +0 -0
  2. app.py +16 -6
  3. requirements.txt +2 -2
=3 ADDED
File without changes
app.py CHANGED
@@ -1,17 +1,27 @@
 
1
 
2
- import keras_nlp
3
  import tensorflow as tf
4
  import gradio as gr
5
  import json
6
  import os
 
7
 
8
- # Load model configuration
9
- with open("gemma_finetuned_model/config.json", "r") as f:
 
 
 
 
10
  config = json.load(f)
11
 
12
- # Load the fine-tuned model
13
- gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en", **config)
14
- gemma_lm.load_weights("gemma_finetuned_model/model_weights.h5")
 
 
 
 
 
15
 
16
  # Gradio app function
17
  def generate_response(prompt):
 
1
+ import keras
2
 
 
3
  import tensorflow as tf
4
  import gradio as gr
5
  import json
6
  import os
7
+ from keras_hub.models import GemmaCausalLM
8
 
9
+ # Path to your locally saved model and config
10
+ model_weights_path = "gemma_finetuned_model/finetunedmodel.weights.h5"
11
+ config_path = "gemma_finetuned_model/config.json"
12
+
13
+ # Load the model using the local config and weights
14
+ with open(config_path, "r") as f:
15
  config = json.load(f)
16
 
17
+ gemma_lm = GemmaCausalLM.from_config(config)
18
+ gemma_lm.load_weights(model_weights_path)
19
+
20
+
21
+ #
22
+ # # Load the fine-tuned model
23
+ # gemma_lm = GemmaCausalLM.from_preset("gemma_2b_en", **config)
24
+ # gemma_lm.load_weights("gemma_finetuned_model/model_weights.h5")
25
 
26
  # Gradio app function
27
  def generate_response(prompt):
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
-
2
- keras-nlp
3
  tensorflow
4
  gradio
 
1
+ keras
2
+ keras-hub
3
  tensorflow
4
  gradio