{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PDF Question Answering with a Local ONNX Model\n",
    "\n",
    "Parses a PDF into a list of context chunks, builds a search index over them via the local `pre_processing` module, and answers questions interactively with an ONNX Runtime GenAI model. The best-matching chunk is appended to every question before it is sent to the model."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import Libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import onnxruntime_genai as og\n",
    "from pypdf import PdfReader\n",
    "\n",
    "import pre_processing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Embedding Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Kept as its own step: this import is what brings the embedding model\n",
    "# from the local pre_processing module into the notebook namespace.\n",
    "from pre_processing import embedding_model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Process Doc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The PDF is expected in the directory the notebook is started from.\n",
    "base_path = os.getcwd()\n",
    "file_name = 'attention_is_all_you_need.pdf'\n",
    "full_path = os.path.join(base_path, file_name)\n",
    "reader = PdfReader(full_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Section markers handed to pre_processing.parese_doc in the next cell.\n",
    "first_section = \"abstract\"\n",
    "ignore_after = \"references\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `parese_doc` (sic) is the name exported by pre_processing.\n",
    "context_list = pre_processing.parese_doc(reader, first_section, ignore_after)\n",
    "index = pre_processing.create_embedding(context_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Linking ONNX Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model directory, resolved relative to the working directory.\n",
    "phi3_model_path = 'cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4'\n",
    "full_model_path = os.path.join(base_path, phi3_model_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = og.Model(full_model_path)\n",
    "tokenizer = og.Tokenizer(model)\n",
    "# One decoding stream is shared by every answer printed below.\n",
    "tokenizer_stream = tokenizer.create_stream()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `{input}` is filled with the question plus the retrieved context.\n",
    "chat_template = '<|user|>\\n{input} <|end|>\\n<|assistant|>'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Options forwarded to GeneratorParams.set_search_options for every question.\n",
    "# 'max_length' is deliberately not set; add it to this dict to cap the sequence length.\n",
    "search_options = {'temperature': 1}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Ask Questions\n",
    "\n",
    "Each question is answered with the best-matching chunk of the document appended to it. Submit an empty input to end the session; interrupt the kernel to abort an answer that is still being generated."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def retrieve_context(question, top_k=1):\n",
    "    \"\"\"Return the `top_k` best-matching chunks of `context_list`, joined by '. '.\n",
    "\n",
    "    The question is encoded with `embedding_model` and looked up in `index`.\n",
    "    \"\"\"\n",
    "    query_embedding = embedding_model.encode(question).reshape(1, -1)\n",
    "    _scores, hit_ids = index.search(query_embedding, top_k)\n",
    "    # A negative id would silently index context_list from the end, so skip it.\n",
    "    # NOTE(review): assumes the index reports missing hits as -1 (FAISS-style) - confirm.\n",
    "    return '. '.join(context_list[idx] for idx in hit_ids[0] if idx >= 0)\n",
    "\n",
    "\n",
    "def stream_answer(prompt):\n",
    "    \"\"\"Generate a reply to `prompt` and print it token by token.\n",
    "\n",
    "    A KeyboardInterrupt during generation is caught: the partial answer is\n",
    "    kept on screen and the function returns normally.\n",
    "    \"\"\"\n",
    "    input_tokens = tokenizer.encode(prompt)\n",
    "\n",
    "    params = og.GeneratorParams(model)\n",
    "    params.try_graph_capture_with_max_batch_size(1)\n",
    "    params.set_search_options(**search_options)\n",
    "    params.input_ids = input_tokens\n",
    "    # The generator stays local so it is released when this call returns,\n",
    "    # before a generator for the next question is created.\n",
    "    generator = og.Generator(model, params)\n",
    "\n",
    "    print()\n",
    "    print(\"Output: \", end='', flush=True)\n",
    "\n",
    "    try:\n",
    "        while not generator.is_done():\n",
    "            generator.compute_logits()\n",
    "            generator.generate_next_token()\n",
    "            new_token = generator.get_next_tokens()[0]\n",
    "            print(tokenizer_stream.decode(new_token), end='', flush=True)\n",
    "    except KeyboardInterrupt:\n",
    "        print(\" --control+c pressed, aborting generation--\")\n",
    "    print()\n",
    "    print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "while True:\n",
    "    text = input(\"Input: \")\n",
    "    if not text:\n",
    "        # An empty input is the way out of the loop: report it and stop.\n",
    "        print(\"Error, input cannot be empty\")\n",
    "        break\n",
    "\n",
    "    context = retrieve_context(text)\n",
    "    # Build a new string instead of overwriting the user's question.\n",
    "    augmented_text = text + \" With respect to context: \" + context\n",
    "    stream_answer(chat_template.format(input=augmented_text))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".phi3_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}