{ "cells": [ { "cell_type": "markdown", "id": "fb4f9384-be8e-488a-aa51-b56b27c71213", "metadata": { "tags": [] }, "source": [ "## 1. Set up Sagemaker\n", "*Explain more later...*" ] }, { "cell_type": "code", "execution_count": null, "id": "ea107aa6-376e-4364-bceb-50aca9f30b74", "metadata": {}, "outputs": [], "source": [ "response = client.create_presigned_notebook_instance_url(\n", " NotebookInstanceName='string',\n", " SessionExpirationDurationInSeconds=123\n", ")" ] }, { "cell_type": "code", "execution_count": 4, "id": "ac706b50-8413-42ef-b5a7-5906f7f5cdf5", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "arn:aws:iam::907929678403:role/service-role/AmazonSageMaker-ExecutionRole-20230621T132010\n" ] } ], "source": [ "import json\n", "import sagemaker\n", "from sagemaker.huggingface import get_huggingface_llm_image_uri\n", "from sagemaker.huggingface import HuggingFaceModel\n", "\n", "# retrieve the llm image uri\n", "llm_image = get_huggingface_llm_image_uri(\n", " \"huggingface\",\n", " version=\"0.8.2\"\n", ")\n", "\n", "# Define Model and Endpoint configuration parameter\n", "role = sagemaker.get_execution_role()\n", "print(role)\n", "endpoint_name = \"falcon-40b-instruct-demo\"\n", "aws_region = \"us-east-1\"\n", "hf_model_id = \"tiiuae/falcon-40b-instruct\" # model id from huggingface.co/models\n", "instance_type = \"ml.g5.12xlarge\" # instance type to use for deployment\n", "number_of_gpu = 4 # number of gpus to use for inference and tensor parallelism\n", "health_check_timeout = 600 # Increase the timeout for the health check to 5 minutes for downloading the model\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "2ce504d1-0bc3-43ce-bb39-b925a59718cc", "metadata": { "tags": [] }, "outputs": [], "source": [ "# create HuggingFaceModel with the image uri\n", "llm_model = HuggingFaceModel(\n", " role=role,\n", " image_uri=llm_image,\n", " env={\n", " 'HF_MODEL_ID': hf_model_id,\n", " # 'HF_MODEL_QUANTIZE': \"bitsandbytes\", # comment in to quantize\n", " 'SM_NUM_GPUS': json.dumps(number_of_gpu),\n", " 'MAX_INPUT_LENGTH': json.dumps(1024), # Max length of input text\n", " 'MAX_TOTAL_TOKENS': json.dumps(2048), # Max length of the generation (including input text)\n", " }\n", ")" ] }, { "cell_type": "code", "execution_count": 6, "id": "00664be7-3d08-4c68-9048-ba1e602c44c2", "metadata": { "tags": [] }, "outputs": [ { "ename": "ResourceLimitExceeded", "evalue": "An error occurred (ResourceLimitExceeded) when calling the CreateEndpoint operation: The account-level service limit 'ml.g5.12xlarge for endpoint usage' is 2 Instances, with current utilization of 2 Instances and a request delta of 1 Instances. Please use AWS Service Quotas to request an increase for this quota. If AWS Service Quotas is not available, contact AWS support to request an increase for this quota.", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mResourceLimitExceeded\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[6], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m llm \u001b[38;5;241m=\u001b[39m \u001b[43mllm_model\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdeploy\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2\u001b[0m \u001b[43m \u001b[49m\u001b[43minitial_instance_count\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3\u001b[0m \u001b[43m \u001b[49m\u001b[43minstance_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minstance_type\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4\u001b[0m \u001b[43m \u001b[49m\u001b[43mcontainer_startup_health_check_timeout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhealth_check_timeout\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 5\u001b[0m \u001b[43m \u001b[49m\u001b[43mendpoint_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mendpoint_name\u001b[49m\n\u001b[1;32m 6\u001b[0m \u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/sagemaker/huggingface/model.py:311\u001b[0m, in \u001b[0;36mHuggingFaceModel.deploy\u001b[0;34m(self, initial_instance_count, instance_type, serializer, deserializer, accelerator_type, endpoint_name, tags, kms_key, wait, data_capture_config, async_inference_config, serverless_inference_config, volume_size, model_data_download_timeout, container_startup_health_check_timeout, inference_recommendation_id, explainer_config, **kwargs)\u001b[0m\n\u001b[1;32m 305\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mimage_uri \u001b[38;5;129;01mand\u001b[39;00m instance_type \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m instance_type\u001b[38;5;241m.\u001b[39mstartswith(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mml.inf\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 306\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mimage_uri \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mserving_image_uri(\n\u001b[1;32m 307\u001b[0m region_name\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msagemaker_session\u001b[38;5;241m.\u001b[39mboto_session\u001b[38;5;241m.\u001b[39mregion_name,\n\u001b[1;32m 308\u001b[0m instance_type\u001b[38;5;241m=\u001b[39minstance_type,\n\u001b[1;32m 309\u001b[0m )\n\u001b[0;32m--> 311\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mHuggingFaceModel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdeploy\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 312\u001b[0m \u001b[43m \u001b[49m\u001b[43minitial_instance_count\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 313\u001b[0m \u001b[43m \u001b[49m\u001b[43minstance_type\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 314\u001b[0m \u001b[43m \u001b[49m\u001b[43mserializer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 315\u001b[0m \u001b[43m \u001b[49m\u001b[43mdeserializer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 316\u001b[0m \u001b[43m \u001b[49m\u001b[43maccelerator_type\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 317\u001b[0m \u001b[43m \u001b[49m\u001b[43mendpoint_name\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 318\u001b[0m \u001b[43m \u001b[49m\u001b[43mtags\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 319\u001b[0m \u001b[43m \u001b[49m\u001b[43mkms_key\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 320\u001b[0m \u001b[43m \u001b[49m\u001b[43mwait\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 321\u001b[0m \u001b[43m \u001b[49m\u001b[43mdata_capture_config\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 322\u001b[0m \u001b[43m \u001b[49m\u001b[43masync_inference_config\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 323\u001b[0m \u001b[43m \u001b[49m\u001b[43mserverless_inference_config\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 324\u001b[0m \u001b[43m \u001b[49m\u001b[43mvolume_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mvolume_size\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 325\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodel_data_download_timeout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodel_data_download_timeout\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 326\u001b[0m \u001b[43m \u001b[49m\u001b[43mcontainer_startup_health_check_timeout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcontainer_startup_health_check_timeout\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 327\u001b[0m \u001b[43m \u001b[49m\u001b[43minference_recommendation_id\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minference_recommendation_id\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 328\u001b[0m \u001b[43m \u001b[49m\u001b[43mexplainer_config\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mexplainer_config\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 329\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/sagemaker/model.py:1347\u001b[0m, in \u001b[0;36mModel.deploy\u001b[0;34m(self, initial_instance_count, instance_type, serializer, deserializer, accelerator_type, endpoint_name, tags, kms_key, wait, data_capture_config, async_inference_config, serverless_inference_config, volume_size, model_data_download_timeout, container_startup_health_check_timeout, inference_recommendation_id, explainer_config, **kwargs)\u001b[0m\n\u001b[1;32m 1344\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m is_explainer_enabled:\n\u001b[1;32m 1345\u001b[0m explainer_config_dict \u001b[38;5;241m=\u001b[39m explainer_config\u001b[38;5;241m.\u001b[39m_to_request_dict()\n\u001b[0;32m-> 1347\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msagemaker_session\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mendpoint_from_production_variants\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1348\u001b[0m \u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mendpoint_name\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1349\u001b[0m \u001b[43m \u001b[49m\u001b[43mproduction_variants\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m[\u001b[49m\u001b[43mproduction_variant\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1350\u001b[0m \u001b[43m \u001b[49m\u001b[43mtags\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtags\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1351\u001b[0m \u001b[43m \u001b[49m\u001b[43mkms_key\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkms_key\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1352\u001b[0m \u001b[43m \u001b[49m\u001b[43mwait\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mwait\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1353\u001b[0m \u001b[43m \u001b[49m\u001b[43mdata_capture_config_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_capture_config_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1354\u001b[0m \u001b[43m \u001b[49m\u001b[43mexplainer_config_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mexplainer_config_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1355\u001b[0m \u001b[43m \u001b[49m\u001b[43masync_inference_config_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43masync_inference_config_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1356\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1358\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpredictor_cls:\n\u001b[1;32m 1359\u001b[0m predictor \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpredictor_cls(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mendpoint_name, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msagemaker_session)\n", "File \u001b[0;32m~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/sagemaker/session.py:4641\u001b[0m, in \u001b[0;36mSession.endpoint_from_production_variants\u001b[0;34m(self, name, production_variants, tags, kms_key, wait, data_capture_config_dict, async_inference_config_dict, explainer_config_dict)\u001b[0m\n\u001b[1;32m 4638\u001b[0m LOGGER\u001b[38;5;241m.\u001b[39minfo(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCreating endpoint-config with name \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m\"\u001b[39m, name)\n\u001b[1;32m 4639\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msagemaker_client\u001b[38;5;241m.\u001b[39mcreate_endpoint_config(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mconfig_options)\n\u001b[0;32m-> 4641\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcreate_endpoint\u001b[49m\u001b[43m(\u001b[49m\u001b[43mendpoint_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mconfig_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtags\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtags\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mwait\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mwait\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/sagemaker/session.py:4030\u001b[0m, in \u001b[0;36mSession.create_endpoint\u001b[0;34m(self, endpoint_name, config_name, tags, wait)\u001b[0m\n\u001b[1;32m 4027\u001b[0m tags \u001b[38;5;241m=\u001b[39m tags \u001b[38;5;129;01mor\u001b[39;00m []\n\u001b[1;32m 4028\u001b[0m tags \u001b[38;5;241m=\u001b[39m _append_project_tags(tags)\n\u001b[0;32m-> 4030\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msagemaker_client\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcreate_endpoint\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 4031\u001b[0m \u001b[43m \u001b[49m\u001b[43mEndpointName\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mendpoint_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mEndpointConfigName\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mconfig_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mTags\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtags\u001b[49m\n\u001b[1;32m 4032\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 4033\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m wait:\n\u001b[1;32m 4034\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mwait_for_endpoint(endpoint_name)\n", "File \u001b[0;32m~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/botocore/client.py:530\u001b[0m, in \u001b[0;36mClientCreator._create_api_method.._api_call\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 526\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\n\u001b[1;32m 527\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpy_operation_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m() only accepts keyword arguments.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 528\u001b[0m )\n\u001b[1;32m 529\u001b[0m \u001b[38;5;66;03m# The \"self\" in this scope is referring to the BaseClient.\u001b[39;00m\n\u001b[0;32m--> 530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_make_api_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43moperation_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/botocore/client.py:964\u001b[0m, in \u001b[0;36mBaseClient._make_api_call\u001b[0;34m(self, operation_name, api_params)\u001b[0m\n\u001b[1;32m 962\u001b[0m error_code \u001b[38;5;241m=\u001b[39m parsed_response\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mError\u001b[39m\u001b[38;5;124m\"\u001b[39m, {})\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCode\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 963\u001b[0m error_class \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mexceptions\u001b[38;5;241m.\u001b[39mfrom_code(error_code)\n\u001b[0;32m--> 964\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m error_class(parsed_response, operation_name)\n\u001b[1;32m 965\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 966\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m parsed_response\n", "\u001b[0;31mResourceLimitExceeded\u001b[0m: An error occurred (ResourceLimitExceeded) when calling the CreateEndpoint operation: The account-level service limit 'ml.g5.12xlarge for endpoint usage' is 2 Instances, with current utilization of 2 Instances and a request delta of 1 Instances. Please use AWS Service Quotas to request an increase for this quota. If AWS Service Quotas is not available, contact AWS support to request an increase for this quota." ] } ], "source": [ "llm = llm_model.deploy(\n", " initial_instance_count=1,\n", " instance_type=instance_type,\n", " container_startup_health_check_timeout=health_check_timeout,\n", " endpoint_name=endpoint_name\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "50f556f8-06b4-450e-9db3-9bc9c979e8ab", "metadata": { "tags": [] }, "outputs": [], "source": [ "llm2.delete_endpoint()" ] }, { "cell_type": "code", "execution_count": null, "id": "c341f368-a9e7-441c-886e-0576c3f2f432", "metadata": { "tags": [] }, "outputs": [], "source": [ "\n", "from langchain.chains.question_answering import load_qa_chain\n", "from langchain.memory import ConversationBufferMemory\n", "from langchain import PromptTemplate\n", "from typing import Dict\n", "\n", "class ContentHandler(LLMContentHandler):\n", " content_type = \"application/json\"\n", " accepts = \"application/json\"\n", " len_prompt = 0\n", "\n", " def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:\n", " self.len_prompt = len(prompt)\n", " input_str = json.dumps(\n", " {\"inputs\": prompt,\n", " \"parameters\": {\n", " \"do_sample\": True,\n", " \"top_p\": 0.9,\n", " \"temperature\": 0.8,\n", " \"max_new_tokens\": 1024,\n", " \"repetition_penalty\": 1.03,\n", " \"stop\": [\"\\n\\n\", \"Human:\", \"<|endoftext|>\", \"\"]\n", " }})\n", " return input_str.encode('utf-8')\n", "\n", " def transform_output(self, output: bytes) -> str:\n", " response_json = output.read()\n", " res = json.loads(response_json)\n", " ans = res[0]['generated_text'][self.len_prompt:]\n", " ans = ans[:ans.rfind(\"Human\")].strip()\n", " return ans\n", "\n", "\n", "content_handler = ContentHandler()\n", "\n", "\n", "prompt = PromptTemplate(\n", " input_variables=[\"history\", \"context\", \"new_message\"],\n", " template=\"\"\"\n", "The following is a given context and history of a conversation between a human and a helpful AI assistant. \n", "The AI assistant answers the human's messages with full, informational responses by inferring on the information given in the context, but does not extrapolate or hallucinate any facts that are not given.\n", "When answering questions with unclear answers, the AI assistant qualifies its responses as such.\n", "The AI assistant pretends it is also human and that the context is objectively true and well-known and therefore avoids mentioning the context as a source.\n", "\n", "Context: {context}\n", "\n", "{history}\n", "Human: {new_message}\n", "AI:\"\"\")\n", "\n", "def load_chain():\n", " llm = SagemakerEndpoint(\n", " endpoint_name=endpoint_name,\n", " region_name=aws_region,\n", " content_handler=content_handler\n", " )\n", " chain = load_qa_chain(llm=llm, chain_type=\"stuff\", verbose=True, memory=ConversationBufferMemory(memory_key=\"history\", input_key=\"new_message\"), prompt=prompt)\n", " return chain\n", "\n", "\n", "dachain = load_chain()" ] }, { "cell_type": "code", "execution_count": null, "id": "b0a557a0-ca6d-45db-97b4-f89317a5e500", "metadata": { "tags": [] }, "outputs": [], "source": [ "query = \"What is Becton?\"\n", "dachain({\"input_documents\": docsearch.similarity_search(query, k=3), \"new_message\": query}, return_only_outputs=True)['output_text'].strip()" ] }, { "cell_type": "markdown", "id": "c4ac2fca-820b-412d-9e90-47848e046236", "metadata": {}, "source": [ "## Load DSS Website Data into ChromaDB\n", "`urls` object defines what URLs are to be considered in the context database." ] }, { "cell_type": "code", "execution_count": null, "id": "c596a7a6-cca1-46c5-a914-e8b5ce9eba17", "metadata": { "tags": [] }, "outputs": [], "source": [ "from langchain.document_loaders import UnstructuredURLLoader\n", "from langchain.text_splitter import RecursiveCharacterTextSplitter\n", "from langchain.vectorstores import Chroma\n", "from langchain.embeddings import HuggingFaceInstructEmbeddings\n", "\n", "# define URL sources\n", "urls = [\n", " 'https://www.dssinc.com/blog/2022/6/21/suicide-prevention-manager-enabling-the-veterans-affairs-to-achieve-high-reliability-in-suicide-risk-identification',\n", " 'https://www.dssinc.com/blog/2022/8/9/dss-inc-announces-appointment-of-brion-bailey-as-director-of-federal-business-development', \n", " 'https://www.dssinc.com/blog/2022/3/21/march-22-is-diabetes-alertness-day-a-helpful-reminder-to-monitor-and-prevent-diabetes',\n", " 'https://www.dssinc.com/blog/2023/5/24/supporting-the-vas-high-reliability-organization-journey-through-suicide-prevention',\n", " 'https://www.dssinc.com/blog/2022/12/19/dss-theradoc-helps-battle-super-bugs-for-better-veteran-health',\n", " 'https://www.dssinc.com/blog/2022/9/21/dss-inc-chosen-for-phase-two-of-mission-daybreak-vas-suicide-prevention-challenge',\n", " 'https://www.dssinc.com/blog/2022/9/19/crescenz-va-medical-center-cmcvamc-deploys-the-dss-iconic-data-patient-case-manager-pcm-solution',\n", " 'https://www.dssinc.com/blog/2022/5/9/federal-news-network-the-importance-of-va-supply-chain-modernization']\n", "\n", "# load and split\n", "loaders = UnstructuredURLLoader(urls=urls)\n", "data = loaders.load()\n", "text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)\n", "texts = text_splitter.split_documents(data)\n", "print(\"Sources split into the following number of \\\"texts\\\":\", len(texts))\n", "\n", "# load embedding model\n", "print(\"Loading embedding model...\")\n", "embeddings = HuggingFaceInstructEmbeddings(model_name=\"hkunlp/instructor-xl\")\n", "\n", "docsearch = Chroma.from_texts([t.page_content for t in texts], embeddings)" ] }, { "cell_type": "code", "execution_count": null, "id": "6fe72b28-0d34-47c5-82f5-7318576e4ec8", "metadata": {}, "outputs": [], "source": [ "print(\"Getting AI response... @ \", datetime.datetime.now().strftime(\"%H:%M:%S\"))\n", "print(chain({\"input_documents\": docsearch.similarity_search(query, k=3), \"new_message\": query}, return_only_outputs=True)['output_text'].strip())" ] }, { "cell_type": "code", "execution_count": null, "id": "f9c38a37-9aa0-4584-8b06-cee2932d14cf", "metadata": { "tags": [] }, "outputs": [], "source": [ "\n", "llm2.delete_endpoint()" ] }, { "cell_type": "code", "execution_count": null, "id": "881f2cb5-41c5-4fd2-b942-d3c289dda758", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "202caf50-c00a-4555-8150-4fc7a779aa0a", "metadata": { "tags": [] }, "outputs": [], "source": [ "from sagemaker.predictor import Predictor\n", "\n", "llm2 = Predictor(endpoint_name)" ] }, { "cell_type": "code", "execution_count": null, "id": "31f54a95-0324-4deb-be0e-89c453004f6c", "metadata": { "tags": [] }, "outputs": [], "source": [ "dom = \"d-bipui5yzbvlc\"\n", "print(f'https://{dom}.studio.{aws_region}.sagemaker.aws/studiolab/default/jupyter/proxy/6006/')" ] } ], "metadata": { "kernelspec": { "display_name": "conda_pytorch_p310", "language": "python", "name": "conda_pytorch_p310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.10" } }, "nbformat": 4, "nbformat_minor": 5 }