TextGen / feather_chat.py
abdullah10's picture
Upload 35 files
8bc7dc5
raw
history blame
2.02 kB
from transformers import AutoTokenizer
from langchain.chains import ConversationChain
from langchain.llms import GooglePalm
from langchain.memory import ConversationBufferMemory
import os
from dotenv import load_dotenv, find_dotenv
import streamlit as st
# Load environment variables from the nearest .env file on the search path.
load_dotenv(find_dotenv())
# Fail fast: raises KeyError at import time if GOOGLE_API_KEY is not set.
google_api_key = os.environ['GOOGLE_API_KEY']
# Used only for token counting below.
# NOTE(review): flan-ul2's tokenizer presumably approximates PaLM's token
# counts rather than matching them exactly — confirm acceptable accuracy.
tokenizer = AutoTokenizer.from_pretrained("google/flan-ul2")
def call_palm(google_api_key, temperature=0.5, max_tokens=8000, top_p=0.95, top_k=40, n_batch=9, repeat_penalty=1.1, n_ctx=8000):
    """Build a LangChain ``GooglePalm`` LLM configured for this chat app.

    Parameters
    ----------
    google_api_key : str
        API key for the Google PaLM service.
    temperature, max_tokens, top_p, top_k :
        Sampling parameters forwarded to ``GooglePalm``
        (``max_tokens`` maps to ``max_output_tokens``).
    n_batch, repeat_penalty, n_ctx :
        Accepted for backward compatibility but NOT forwarded: these are
        llama.cpp-specific options that ``GooglePalm`` does not define,
        and passing unknown fields to the pydantic model risks a
        validation error.

    Returns
    -------
    GooglePalm
        A configured LangChain LLM instance.
    """
    return GooglePalm(
        google_api_key=google_api_key,
        temperature=temperature,
        max_output_tokens=max_tokens,
        top_p=top_p,
        top_k=top_k,
    )
# --- Streamlit chat driver -------------------------------------------------
# Streamlit re-executes this entire script on every user interaction, so the
# conversation chain, its memory, and the running token total must live in
# st.session_state to survive reruns.  The original ``while True`` loop would
# spin forever inside a single run (st.text_input returns '' immediately and
# never blocks), creating an unbounded number of widgets and losing all
# memory on each rerun.
if 'conversation' not in st.session_state:
    llm = call_palm(google_api_key)
    st.session_state.conversation = ConversationChain(
        llm=llm,
        verbose=False,
        memory=ConversationBufferMemory(),
    )
    # Cumulative prompt + response tokens across the whole conversation.
    st.session_state.total_tokens = 0

new_conversation = st.session_state.conversation

message = st.text_input('Human', key='human_input')
if message == 'Exit':
    # Report total usage and stop responding; state is kept in session_state.
    st.text(f"{st.session_state.total_tokens} tokens used in total in this conversation.")
elif message:
    # Show the fully formatted prompt (history + new input) for debugging.
    formatted_prompt = new_conversation.prompt.format(
        input=message, history=new_conversation.memory.buffer)
    st.text(f'formatted_prompt is {formatted_prompt}')
    # Token accounting uses the flan-ul2 tokenizer as an approximation of
    # PaLM's tokenization (see note where ``tokenizer`` is created).
    num_tokens = len(tokenizer.tokenize(formatted_prompt))
    st.session_state.total_tokens += num_tokens
    st.text(f'tokens sent {num_tokens}')
    response = new_conversation.predict(input=message)
    st.session_state.total_tokens += len(tokenizer.tokenize(response))
    st.text(f"Featherica: {response}")