# Hugging Face Space
# Status at capture time: Sleeping
import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from transformers import pipeline
@st.cache_resource
def get_model():
    """Load the ``niclasfw/schlager-bot-004`` tokenizer and causal LM.

    Cached with ``st.cache_resource`` because Streamlit re-executes the whole
    script on every widget interaction; without the cache the model would be
    downloaded/loaded and quantized again on each button click.

    Returns:
        tuple: ``(tokenizer, model)``, with the model loaded in 4-bit
        (bitsandbytes) and fp16 for the non-quantized tensors.
    """
    model_id = "niclasfw/schlager-bot-004"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # 4-bit loading goes through a BitsAndBytesConfig: passing
    # ``load_in_4bit=True`` directly to ``from_pretrained`` is deprecated.
    quantization_config = BitsAndBytesConfig(load_in_4bit=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
        quantization_config=quantization_config,
    )
    return tokenizer, model
tokenizer, model = get_model()

st.title('Schlager Bot')
user_input = st.text_area('Enter verse: ')
button = st.button('Generate Lyrics')
verse = f"""{user_input} \n"""

if user_input and button:
    # Instruction-style prompt. The wording (including the spelling
    # "Benuzte") is kept verbatim; presumably it mirrors the template the
    # model was trained on — TODO confirm before editing it.
    prompt = f"""### Instruction:
Benuzte den gegebenen Input um ein Schlager Lied zu schreiben.
### Input:
{verse}
### Response:
"""
    encoded = tokenizer(prompt, return_tensors="pt", truncation=True)
    # Follow the model's actual placement instead of hard-coding .cuda().
    input_ids = encoded.input_ids.to(model.device)
    attention_mask = encoded.attention_mask.to(model.device)
    # input_ids/attention_mask are passed explicitly (not **encoded) because
    # some tokenizers also emit token_type_ids, which generate() rejects.
    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=500,
        do_sample=True,
        top_p=0.75,
        temperature=0.95,
        top_k=15,
    )
    # generate() returns prompt + continuation; decode only the new tokens.
    # Slicing the decoded text by len(prompt) is fragile because decoding
    # does not necessarily reproduce the prompt character-for-character.
    prompt_length = input_ids.shape[1]
    lyrics = tokenizer.decode(outputs[0, prompt_length:], skip_special_tokens=True)
    st.write("Generated Lyrics:")
    # st.text keeps the lyrics' line breaks; st.write renders Markdown, where
    # single newlines collapse into spaces.
    st.text(lyrics)