import streamlit as st
from transformers import AutoTokenizer, FalconForCausalLM
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
# FalconModel is the bare transformer without a language-modelling head and has no
# .generate() method; FalconForCausalLM is the class needed for text generation.
tokenizer = AutoTokenizer.from_pretrained("Rocketknight1/falcon-rw-1b")
model = FalconForCausalLM.from_pretrained("Rocketknight1/falcon-rw-1b")

model.to(device)

def generate_text(prompt, max_new_tokens=100, do_sample=True):
    # Tokenize the prompt, sample a continuation, and decode it back to text.
    model_inputs = tokenizer([prompt], return_tensors="pt").to(device)
    with torch.no_grad():
        generated_ids = model.generate(**model_inputs, max_new_tokens=max_new_tokens, do_sample=do_sample)
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

st.title("KviGPT - Hugging Face Chat")

user_input = st.text_input("You:", value="My favourite condiment is ")

if st.button("Send"):
    prompt = user_input
    model_response = generate_text(prompt)[0]
    st.write("KviGPT:", model_response)
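# Usage sketch (assumes this file is saved as app.py, the conventional entry
# point for a Streamlit app; the filename is not stated in the source):
#   streamlit run app.py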