Florian committed
Commit af6c532 · 1 Parent(s): 5b2e6a5

remove penalty alpha and just put the model

Files changed (1)
  1. app.py +5 -24
app.py CHANGED
@@ -25,23 +25,8 @@ def reset():
     _, st.session_state['logits'], _, st.session_state['head_tokens'] = generate_next_token(st.session_state.model, st.session_state.tokenizer, st.session_state['current_sentence'])
 
 @st.cache_resource
-def load_model(penalty_alpha):
-    penalty_map = {0.1:"model_20240118-144039.bin",
-                   0.5:"model_20240118-192548.bin",
-                   2:"model_20240118-211943.bin",
-                   5:"model_20240118-231333.bin",
-                   10:"model_20240119-010725.bin",
-                   20:"model_20240119-030115.bin",
-                   0:"model_20240119-135506.bin",
-                   1:"model_20240119-154900.bin",
-                   -20: "model_20240208-072350.bin",
-                   -10: "model_20240208-052958.bin",
-                   -5: "model_20240208-033606.bin",
-                   -2: "model_20240208-014211.bin",
-                   -1: "model_20240207-234817.bin",
-                   -0.5: "model_20240207-215423.bin",
-                   -0.1: "model_20240207-200020.bin"}
-
+def load_model(model_path):
+
     model_str = "susnato/phi-1_5_dev"
     model = AutoModelForCausalLM.from_pretrained(model_str).to("cuda:1")
     tokenizer = AutoTokenizer.from_pretrained(model_str)
@@ -49,19 +34,15 @@ def load_model(penalty_alpha):
     branch_locations = list(range(0, 23, 5))
     model = BranchyModel(branch_locations= branch_locations, model= model).to("cuda:1")
 
-    # Load the specific model based on penalty_alpha
-    model_path = penalty_map.get(penalty_alpha)
-    if model_path:
-        model.load_state_dict(torch.load(model_path, map_location="cuda:1"))
-    else:
-        print("Invalid penalty_alpha. Using default model weights.")
+    # Load the specific model
+    model.load_state_dict(torch.load(model_path, map_location="cuda:1"))
 
     return model, tokenizer
 
 
 if "model" not in st.session_state or "tokenizer" not in st.session_state:
     print("Loading model...")
-    st.session_state.model, st.session_state.tokenizer = load_model(penalty_alpha=-2) # Example penalty_alpha
+    st.session_state.model, st.session_state.tokenizer = load_model("model/model.bin")
     st.session_state["head_number"] = len(st.session_state.model.branch_locations) + 1
     print(f"Head number: {st.session_state['head_number']}")
     # Session state to store the current sentence
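For reference, below is a minimal sketch of load_model as it reads after this commit, stitched together from the two hunks above. The import lines (in particular the BranchyModel import path) are assumptions for illustration, since app.py's header is not part of this diff.

```python
# Sketch of app.py's load_model after this commit. Import lines are assumed;
# only the function body appears in the diff.
import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from branchy_model import BranchyModel  # hypothetical import path


@st.cache_resource
def load_model(model_path):
    # Base checkpoint: phi-1.5 dev weights from the Hugging Face Hub.
    model_str = "susnato/phi-1_5_dev"
    model = AutoModelForCausalLM.from_pretrained(model_str).to("cuda:1")
    tokenizer = AutoTokenizer.from_pretrained(model_str)

    # Wrap the base model with branches at layers 0, 5, 10, 15, 20.
    branch_locations = list(range(0, 23, 5))
    model = BranchyModel(branch_locations=branch_locations, model=model).to("cuda:1")

    # Load the fine-tuned weights from the single checkpoint path passed by
    # the caller (previously selected via the penalty_alpha-to-file map).
    model.load_state_dict(torch.load(model_path, map_location="cuda:1"))

    return model, tokenizer
```

The practical effect of the commit: the fifteen penalty_alpha-specific checkpoints are gone from the loader, and the caller now names the checkpoint directly, e.g. load_model("model/model.bin") as in the session-state guard at the bottom of the diff.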