import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# The model card for rinna/japanese-gpt2-small loads the tokenizer with
# T5Tokenizer; AutoTokenizer picks up that tokenizer class from the repo config.
model_name = "rinna/japanese-gpt2-small"

# App title (set_page_config must be the first Streamlit call in the script)
st.set_page_config(page_title="ChatBot")


# Load the pre-trained model and tokenizer once and cache them across
# Streamlit reruns, so they are not reloaded on every user interaction
@st.cache_resource
def load_model(name):
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = AutoModelForCausalLM.from_pretrained(name)
    return tokenizer, model


tokenizer, model = load_model(model_name)

# Initialize the chat history with a greeting from the assistant
if "messages" not in st.session_state:
    st.session_state.messages = [{"role": "assistant", "content": "How may I help you?"}]

# Display chat messages
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.write(message["content"])


# Function for generating the LLM response
def generate_response(prompt, max_length=50):
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    # Use the tokenizer's own pad/eos ids rather than the hard-coded GPT-2
    # English id (50256), which does not match this model's vocabulary
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    # Generate a continuation; no gradients are needed for inference.
    # Note that max_length counts the prompt tokens as well as the new ones.
    with torch.no_grad():
        output = model.generate(
            input_ids,
            max_length=max_length,
            num_return_sequences=1,
            pad_token_id=pad_id,
        )
    # Decode only the newly generated tokens so the reply does not echo the prompt
    response = tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True)
    return response


# User-provided prompt
if prompt := st.chat_input():
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.write(prompt)

# Generate a new response if the last message is not from the assistant
if st.session_state.messages[-1]["role"] != "assistant":
    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            response = generate_response(prompt)
            st.write(response)
    message = {"role": "assistant", "content": response}
    st.session_state.messages.append(message)
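
# A minimal sketch (not wired into the app above) of conditioning generation on
# the whole conversation instead of only the latest prompt, since the app as
# written ignores earlier turns. The helper name and the "User:"/"Assistant:"
# framing are assumptions for illustration, not part of the original app or
# the rinna model card.
def build_history_prompt(messages, max_turns=5):
    # Keep only the most recent turns so the prompt stays within max_length
    lines = []
    for m in messages[-max_turns:]:
        speaker = "User" if m["role"] == "user" else "Assistant"
        lines.append(f"{speaker}: {m['content']}")
    # End with the assistant tag so the model continues as the assistant
    lines.append("Assistant:")
    return "\n".join(lines)

# Possible usage inside the assistant branch above:
#   response = generate_response(build_history_prompt(st.session_state.messages))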