from langchain import HuggingFacePipeline
from langchain import PromptTemplate, LLMChain
from transformers import AutoTokenizer, AutoModelForCausalLM
import transformers
import os
import torch
import gradio as gr
import subprocess
#command = 'pip install git+https://github.com/huggingface/transformers'
#subprocess.run(command, shell=True)
# check if cuda is available
print(f"CUDA available: {torch.cuda.is_available()}")
# define the model id
# model_id = "tiiuae/falcon-40b-instruct"
model_id = "tiiuae/falcon-7b-instruct"
# load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)
# load the model
## params:
## cache_dir: path to a directory in which the downloaded pretrained model should be cached if the standard cache should not be used
## device_map: ensures the model is moved to your GPU(s)
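## torch_dtype: load the weights in bfloat16 to roughly halve memory use vs. float32
## trust_remote_code: Falcon ships custom modelling code, so this must be enabled to load it
## offload_folder: directory used to offload any weights that do not fit in GPU/CPU memory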
cache_dir = "./workspace/"
torch_dtype = torch.bfloat16
trust_remote_code = True
device_map = "auto"
offload_folder = "offload"
model = AutoModelForCausalLM.from_pretrained(model_id, cache_dir=cache_dir, torch_dtype=torch_dtype,
                                             trust_remote_code=trust_remote_code, device_map=device_map,
                                             offload_folder=offload_folder)
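# rough memory estimate (assumption, not measured): ~7B params x 2 bytes in bfloat16 is about 14 GB of weights,
# so anything that does not fit on the GPU gets offloaded to offload_folder by accelerate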
# set pt model to inference mode
model.eval()
# build the hf transformers pipeline
task = "text-generation"
max_length = 400
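# note: max_length counts prompt tokens plus generated tokens; max_new_tokens would cap only the generated part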
do_sample = True
top_k = 10
num_return_sequences = 1
eos_token_id = tokenizer.eos_token_id
pipeline = transformers.pipeline(task, model=model, tokenizer=tokenizer,
                                 device_map=device_map, max_length=max_length,
                                 do_sample=do_sample, top_k=top_k,
                                 num_return_sequences=num_return_sequences,
                                 eos_token_id=eos_token_id)
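# optional sanity check (assumes the model fits in available GPU/CPU memory):
# print(pipeline("What is a falcon?")[0]["generated_text"])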
# set up the prompt template (an identity template: the user input is passed to the model unchanged)
template = PromptTemplate(input_variables = ['input'], template = '{input}')
# pass the hf pipeline to the langchain wrapper class
llm = HuggingFacePipeline(pipeline=pipeline)
# build the stacked llm chain, i.e. prompt formatting + llm
chain = LLMChain(llm=llm, prompt=template)
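# example call (hypothetical prompt): chain.run("Explain instruction tuning in one sentence.")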
# create the generate function
def generate(prompt):
    # the prompt gets passed to the llm chain, which returns the model's response
    return chain.run(prompt)
title = "Falcon-7b-Instruct 🦅"
description = "Web application using the open-source `Falcon-7b-Instruct` LLM"
# build gradio interface
gr.Interface(fn=generate,
             inputs=["text"],
             outputs=["text"],
             title=title,
             description=description).launch()
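# note: the defaults above assume the standard Hugging Face Spaces runtime; on other hosts you may
# want launch(server_name="0.0.0.0") to listen on all interfaces or launch(share=True) for a public link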