{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Import Dependencies" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "c:\\Users\\smitg\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "from transformers import GPT2LMHeadModel, GPT2Tokenizer\n", "import gradio as gr\n", "import requests" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load Model" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [
 "# Single source of truth for the checkpoint, so the tokenizer and the model\n",
 "# are always loaded from the same pretrained weights.\n",
 "MODEL_NAME = \"gpt2-large\"\n",
 "\n",
 "tokenizer = GPT2Tokenizer.from_pretrained(MODEL_NAME)\n",
 "# The tokenizer's EOS token id is reused as the model's pad token id.\n",
 "model = GPT2LMHeadModel.from_pretrained(MODEL_NAME, pad_token_id=tokenizer.eos_token_id)"
 ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Tokenize Sentences" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [
 "# Encode a sample prompt into a tensor of token ids (return_tensors='pt').\n",
 "sentence = \"What is the temperature right now?\"\n",
 "input_ids = tokenizer.encode(sentence, return_tensors='pt')"
 ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[2061, 318, 262, 5951, 826, 783, 30]])" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "input_ids" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(318)" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "input_ids[0][1]" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "' is'" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokenizer.decode(input_ids[0][1])" 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Generate and Decode Text" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [
 "# Generate a continuation of the sample prompt; one argument per line so each\n",
 "# generation setting is easy to spot and tweak.\n",
 "output = model.generate(\n",
 "    input_ids,\n",
 "    max_length=500,\n",
 "    num_beams=5,\n",
 "    no_repeat_ngram_size=2,\n",
 "    early_stopping=True,\n",
 ")"
 ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[ 2061, 318, 262, 5951, 826, 783, 30, 198, 198, 464,\n", " 5951, 318, 287, 262, 2837, 286, 4317, 284, 4019, 7370,\n", " 35935, 357, 2481, 284, 1679, 7370, 34186, 8, 290, 318,\n", " 2938, 284, 2652, 612, 329, 262, 1306, 1178, 1528, 11,\n", " 1864, 284, 262, 2351, 10692, 291, 290, 41516, 8694, 357,\n", " 15285, 3838, 8, 287, 27437, 11, 7492, 11, 543, 318,\n", " 9904, 262, 6193, 13, 383, 2351, 15615, 4809, 357, 45,\n", " 19416, 8, 468, 4884, 257, 6049, 18355, 12135, 2342, 329,\n", " 881, 286, 262, 8830, 1578, 1829, 11, 1390, 3354, 286,\n", " 3442, 11, 12087, 11, 7943, 11, 968, 5828, 11, 10202,\n", " 11, 20071, 11, 24533, 11, 18311, 11, 8819, 11, 2669,\n", " 290, 3517, 9309, 11, 355, 880, 355, 262, 5398, 17812,\n", " 286, 15555, 11, 31346, 11, 28293, 11, 10553, 11, 14778,\n", " 11, 17711, 32586, 11, 9005, 10443, 5451, 11, 290, 968,\n", " 32211, 13, 628, 198, 2061, 389, 262, 8395, 286, 6290,\n", " 290, 18355, 38563, 428, 5041, 290, 1306, 1285, 30, 50256]])" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "output" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 160])" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "output.shape" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'What is the temperature right now?\\n\\nThe temperature is in the range of 70 to 80 degrees Fahrenheit (21 to 25 degrees Celsius) and is expected to stay there for the 
next few days, according to the National Oceanic and Atmospheric Administration (NOAA) in Boulder, Colorado, which is monitoring the weather. The National Weather Service (NWS) has issued a severe thunderstorm watch for much of the western United States, including parts of California, Nevada, Arizona, New Mexico, Utah, Idaho, Wyoming, Montana, Oregon, Washington and British Columbia, as well as the Canadian provinces of Alberta, Saskatchewan, Manitoba, Ontario, Quebec, Nova Scotia, Prince Edward Island, and New Brunswick.\\n\\n\\nWhat are the chances of rain and thunderstorms this weekend and next week?'" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokenizer.decode(output[0], skip_special_tokens=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Output Result" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [
 "# Decode the first generated sequence back to text, skipping special tokens.\n",
 "text = tokenizer.decode(output[0], skip_special_tokens=True)"
 ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [
 "# Write with an explicit encoding: without it open() uses the platform default,\n",
 "# which can raise UnicodeEncodeError when the generated text is not plain ASCII.\n",
 "with open('blogpost.text', 'w', encoding='utf-8') as f:\n",
 "    f.write(text)"
 ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Deployment" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Running on local URL: http://127.0.0.1:7860\n", "IMPORTANT: You are using gradio version 3.50.2, however version 4.29.0 is available, please upgrade.\n", "--------\n", "Running on public URL: https://f37310546fa2823990.gradio.live\n", "\n", "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
 "def generate_blog_post(prompt, max_length=500):\n",
 "    \"\"\"Generate text that continues `prompt` with the loaded GPT-2 model.\n",
 "\n",
 "    Args:\n",
 "        prompt: Text used to seed generation.\n",
 "        max_length: Forwarded to `model.generate` as `max_length`.\n",
 "\n",
 "    Returns:\n",
 "        The first generated sequence, decoded with special tokens skipped.\n",
 "\n",
 "    Raises:\n",
 "        gr.Error: If `prompt` is empty.\n",
 "    \"\"\"\n",
 "    # An empty prompt encodes to zero tokens, leaving generate() nothing to\n",
 "    # continue; raise gr.Error so the UI shows a readable message instead.\n",
 "    if not prompt:\n",
 "        raise gr.Error(\"Please enter a prompt.\")\n",
 "\n",
 "    # Calling the tokenizer (rather than .encode) also returns the attention\n",
 "    # mask, which is passed explicitly because the pad id equals the EOS id.\n",
 "    encoded = tokenizer(prompt, return_tensors='pt')\n",
 "    generated = model.generate(\n",
 "        encoded['input_ids'],\n",
 "        attention_mask=encoded['attention_mask'],\n",
 "        max_length=max_length,\n",
 "        num_beams=5,\n",
 "        no_repeat_ngram_size=2,\n",
 "        early_stopping=True,\n",
 "    )\n",
 "    return tokenizer.decode(generated[0], skip_special_tokens=True)\n",
 "\n",
 "\n",
 "# Set up the Gradio interface\n",
 "iface = gr.Interface(\n",
 "    fn=generate_blog_post,\n",
 "    inputs=\"text\",\n",
 "    outputs=\"text\",\n",
 "    title=\"Blog Post Generator\",\n",
 "    description=\"Enter a prompt to generate a blog post using GPT-2.\"\n",
 ")\n",
 "\n",
 "# Launch the Gradio app\n",
 "# NOTE(review): share=True also publishes a public gradio.live URL for this\n",
 "# app; confirm that exposing it beyond localhost is intended.\n",
 "if __name__ == \"__main__\":\n",
 "    iface.launch(share=True)"
 ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.0" } }, "nbformat": 4, "nbformat_minor": 2 }