BucketOfFish committed on
Commit
16cc769
1 Parent(s): a420fe7

Updated imports in script

Browse files
Files changed (1) hide show
  1. streaming_inference.py +6 -6
streaming_inference.py CHANGED
@@ -3,8 +3,8 @@ from threading import Thread
3
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
4
  import torch
5
 
6
- from .configuration_phi import PhiConfig
7
- from .modeling_phi import PhiForCausalLM
8
 
9
 
10
  # This works, but is not streaming
@@ -12,8 +12,8 @@ from .modeling_phi import PhiForCausalLM
12
  if __name__ == "__main__":
13
  device = "cuda"
14
 
15
- model_config = PhiConfig(**json.load(open("simplified_phi2/config.json")))
16
- model = PhiForCausalLM(model_config).to(device)
17
  phi_model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", trust_remote_code=True)
18
  model.load_state_dict(phi_model.state_dict())
19
 
@@ -45,8 +45,8 @@ if __name__ == "__main__":
45
 
46
  # make model and run model.generate(streamer=TextIteratorStreamer) on a thread
47
  device = "cuda"
48
- model_config = PhiConfig(**json.load(open("simplified_phi2/config.json")))
49
- model = PhiForCausalLM(model_config).to(device)
50
  phi_model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", trust_remote_code=True)
51
  model.load_state_dict(phi_model.state_dict())
52
  thread = Thread(
 
3
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
4
  import torch
5
 
6
+ from .phi2_configuration import Phi2Config
7
+ from .phi2_model import Phi2ModelForCausalLM
8
 
9
 
10
  # This works, but is not streaming
 
12
  if __name__ == "__main__":
13
  device = "cuda"
14
 
15
+ model_config = Phi2Config(**json.load(open("simplified_phi2/config.json")))
16
+ model = Phi2ModelForCausalLM(model_config).to(device)
17
  phi_model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", trust_remote_code=True)
18
  model.load_state_dict(phi_model.state_dict())
19
 
 
45
 
46
  # make model and run model.generate(streamer=TextIteratorStreamer) on a thread
47
  device = "cuda"
48
+ model_config = Phi2Config(**json.load(open("simplified_phi2/config.json")))
49
+ model = Phi2ModelForCausalLM(model_config).to(device)
50
  phi_model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", trust_remote_code=True)
51
  model.load_state_dict(phi_model.state_dict())
52
  thread = Thread(