# llmgemma / app.py
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
# Initialize the model and tokenizer
set_seed(1234)
model_id = "Vikalp026var/gemma-2b-it-pythoncodegen"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda")

# Initialize chat history in session state so it persists across Streamlit reruns
if "chat" not in st.session_state:
    st.session_state.chat = []
chat = st.session_state.chat
st.title('Chat with a Language Model')
st.write("This is a simple chatbot using a fine-tuned model on GPT-2. Type 'exit' to end the conversation.")
# Text input
user_input = st.text_input("User:", key="user_input")
# Function to handle a single chat turn
def handle_chat(user_input):
    if user_input:
        # Record the user's turn and build the prompt from the full history via the chat template
        user_turn = {"role": "user", "content": user_input}
        chat.append(user_turn)
        token_inputs = tokenizer.apply_chat_template(chat, tokenize=True, return_tensors="pt", add_generation_prompt=True).to("cuda")
        # Sample up to 500 new tokens, then decode only the newly generated part of the sequence
        token_outputs = model.generate(input_ids=token_inputs, do_sample=True, max_new_tokens=500, temperature=0.5)
        new_tokens = token_outputs[0][token_inputs.shape[-1]:]
        decoded_output = tokenizer.decode(new_tokens, skip_special_tokens=True)
        # Record the model's turn so the context carries into the next request
        model_turn = {"role": "model", "content": decoded_output}
        chat.append(model_turn)
        return decoded_output
    return ""
# Display model response
if user_input.lower() == "exit":
    st.stop()
else:
    response = handle_chat(user_input)
    st.text_area("Model:", value=response, height=200, key="model_response")
# Streamlit executes this script top to bottom on each interaction,
# so no explicit entry point is needed; launch the app with: streamlit run app.py
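
# --- Optional reference: a minimal sketch of the same chat loop outside Streamlit ---
# Illustrative only, not part of the app above. It assumes the same model id, the
# Gemma "user"/"model" chat roles, a CUDA-capable GPU, and mirrors the generation
# parameters used in handle_chat.
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed

def cli_chat():
    set_seed(1234)
    model_id = "Vikalp026var/gemma-2b-it-pythoncodegen"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda")
    chat = []
    while True:
        user_input = input("User: ")
        if user_input.lower() == "exit":
            break
        # Append the user turn and format the whole history with the model's chat template
        chat.append({"role": "user", "content": user_input})
        token_inputs = tokenizer.apply_chat_template(
            chat, tokenize=True, return_tensors="pt", add_generation_prompt=True
        ).to("cuda")
        token_outputs = model.generate(
            input_ids=token_inputs, do_sample=True, max_new_tokens=500, temperature=0.5
        )
        # Decode only the newly generated tokens and keep them in the history
        reply = tokenizer.decode(token_outputs[0][token_inputs.shape[-1]:], skip_special_tokens=True)
        chat.append({"role": "model", "content": reply})
        print("Model:", reply)

# cli_chat()  # uncomment to chat from a terminal instead of the Streamlit UI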