{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Pin the visible GPU before torch is imported, so the variable is already set\n",
    "# whenever CUDA gets initialised.\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n",
    "\n",
    "import torch\n",
    "import transformers\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "\n",
    "# Prompt templates are read from ../config/api/prompt_templates relative to the\n",
    "# current working directory.\n",
    "PROMPT_TEMPLATES_DIR = os.path.join(\n",
    "    os.path.dirname(os.path.abspath(os.getcwd())),\n",
    "    'config', 'api', 'prompt_templates',\n",
    ")\n",
    "\n",
    "# Seed for the sampling-based generation below (arbitrary value).\n",
    "SEED = 42"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): 'sythia' (not 'synthia') is kept unchanged because it has to\n",
    "# match the template file name on disk -- confirm the spelling is intentional.\n",
    "prompt_template_name = 'sythia_v1.3'\n",
    "prompt_template_path = os.path.join(PROMPT_TEMPLATES_DIR, f'{prompt_template_name}.txt')\n",
    "\n",
    "# The template name and the template text live in separate variables so that\n",
    "# neither one changes meaning halfway through the cell.\n",
    "with open(prompt_template_path, 'r', encoding='utf-8') as f:\n",
    "    prompt_template = f.read()\n",
    "\n",
    "context = ''\n",
    "question = 'How to fix a bike?'\n",
    "\n",
    "# str.format() substitutes the context and question fields of the template.\n",
    "prompt = prompt_template.format(context=context, question=question)\n",
    "print(f'prompt len: {len(prompt)}\\n')\n",
    "print(prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loading the checkpoint is the slow step; it has its own cell so the\n",
    "# generation cell further down can be re-run without reloading the weights.\n",
    "model_id = 'migtissera/SynthIA-7B-v1.3'\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
    "\n",
    "# NOTE(review): trust_remote_code=True allows Python code shipped in the hub\n",
    "# repository to run at load time -- confirm this checkpoint really needs it.\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_id,\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    trust_remote_code=True,\n",
    "    load_in_8bit=False,\n",
    "    device_map='auto',\n",
    "    resume_download=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generation settings are passed once, when the pipeline is built, so every\n",
    "# later call of the pipeline uses the same settings.\n",
    "pipeline = transformers.pipeline(\n",
    "    'text-generation',\n",
    "    model=model,\n",
    "    tokenizer=tokenizer,\n",
    "    device_map='auto',\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    eos_token_id=tokenizer.eos_token_id,\n",
    "    # The EOS token id is reused as the padding token id.\n",
    "    pad_token_id=tokenizer.eos_token_id,\n",
    "    min_new_tokens=64,\n",
    "    max_new_tokens=800,\n",
    "    temperature=0.5,\n",
    "    do_sample=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sampling is enabled (do_sample=True), so seed the RNGs right before\n",
    "# generating; re-running this cell then starts from the same RNG state.\n",
    "transformers.set_seed(SEED)\n",
    "\n",
    "# return_full_text=False asks the pipeline for the completion only. Removing\n",
    "# the prompt afterwards with str.replace(prompt + newline, '') left the whole\n",
    "# prompt in the output whenever the completion did not begin with a newline.\n",
    "output_text = pipeline(prompt, return_full_text=False)[0]['generated_text']\n",
    "\n",
    "# Drop the single newline that may separate the prompt from the answer.\n",
    "output_text = output_text.removeprefix('\\n')\n",
    "print(output_text)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "hf_qa_bot",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
"mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.5" }, "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "e769ac600d1c65682759767682b2a946c0eaa09d353302f712fe4c2e822e15df" } } }, "nbformat": 4, "nbformat_minor": 2 }