Holycanolies123 committed on
Commit
6a12e94
1 Parent(s): 5689401

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -23
app.py CHANGED
@@ -1,31 +1,33 @@
1
- import sagemaker
2
- from sagemaker.huggingface import HuggingFace
3
 
4
- # gets role for executing training job
5
- role = sagemaker.get_execution_role()
6
- hyperparameters = {
7
- 'model_name_or_path':'PygmalionAI/pygmalion-6b',
8
- 'output_dir':'/opt/ml/model'
9
- # add your remaining hyperparameters
10
- # more info here https://github.com/huggingface/transformers/tree/v4.17.0/examples/pytorch/language-modeling
11
  }
12
 
13
- # git configuration to download our fine-tuning script
14
- git_config = {'repo': 'https://github.com/huggingface/transformers.git','branch': 'v4.17.0'}
15
-
16
- # creates Hugging Face estimator
17
- huggingface_estimator = HuggingFace(
18
- entry_point='run_clm.py',
19
- source_dir='./examples/pytorch/language-modeling',
20
- instance_type='ml.p3.2xlarge',
21
- instance_count=1,
22
- role=role,
23
- git_config=git_config,
24
  transformers_version='4.17.0',
25
  pytorch_version='1.10.2',
26
  py_version='py38',
27
- hyperparameters = hyperparameters
 
 
 
 
 
 
 
28
  )
29
 
30
- # starting the train job
31
- huggingface_estimator.fit()
 
 
 
 
 
 
1
+ from sagemaker.huggingface import HuggingFaceModel
2
+ import boto3
3
 
4
+ iam_client = boto3.client('iam')
5
+ role = iam_client.get_role(RoleName='{IAM_ROLE_WITH_SAGEMAKER_PERMISSIONS}')['Role']['Arn']
6
+ # Hub Model configuration. https://huggingface.co/models
7
+ hub = {
8
+ 'HF_MODEL_ID':'PygmalionAI/pygmalion-6b',
9
+ 'HF_TASK':'conversational'
 
10
  }
11
 
12
+ # create Hugging Face Model Class
13
+ huggingface_model = HuggingFaceModel(
 
 
 
 
 
 
 
 
 
14
  transformers_version='4.17.0',
15
  pytorch_version='1.10.2',
16
  py_version='py38',
17
+ env=hub,
18
+ role=role,
19
+ )
20
+
21
+ # deploy model to SageMaker Inference
22
+ predictor = huggingface_model.deploy(
23
+ initial_instance_count=1, # number of instances
24
+ instance_type='ml.m5.xlarge' # ec2 instance type
25
  )
26
 
27
+ predictor.predict({
28
+ 'inputs': {
29
+ "past_user_inputs": ["Which movie is the best ?"],
30
+ "generated_responses": ["It's Die Hard for sure."],
31
+ "text": "Can you explain why ?"
32
+ }
33
+ })