Spaces:
Sleeping
Sleeping
File size: 1,263 Bytes
6e9f8ce b7ead25 73004ed 5a72973 6e9f8ce 3022295 f198b15 84b5b05 6e9f8ce 84b5b05 6e9f8ce 84b5b05 b7ead25 84b5b05 6e9f8ce 73004ed 0b07e1e b7ead25 e1abb31 6e9f8ce b7ead25 6e9f8ce b7ead25 e1abb31 6e9f8ce b7ead25 84b5b05 a088b25 4bc419a 3022295 33d77bb 6e9f8ce |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import pipeline
import torch
@st.cache_resource()
def get_model():
    """Load the Schlager-bot tokenizer and 4-bit quantized causal LM.

    Decorated with ``st.cache_resource`` so the (expensive) load happens once
    and the same objects are reused across Streamlit reruns.

    Returns:
        tuple: ``(tokenizer, model)`` for ``niclasfw/schlager-bot-004``.
    """
    # Local import keeps this block self-contained. BitsAndBytesConfig is the
    # supported way to request bitsandbytes quantization; passing
    # ``load_in_4bit=True`` directly to ``from_pretrained`` is deprecated.
    from transformers import BitsAndBytesConfig

    model_id = "niclasfw/schlager-bot-004"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        low_cpu_mem_usage=True,
        # dtype for the modules that are not quantized
        torch_dtype=torch.float16,
        # same default 4-bit config the deprecated ``load_in_4bit`` kwarg built
        quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    )
    return tokenizer, model
# Module-level Streamlit script: the whole block re-runs on every widget
# interaction; get_model() is cached so the model is not reloaded each time.
tokenizer, model = get_model()

st.title('Schlager Bot')
user_input = st.text_area('Enter verse: ')
button = st.button('Generate Lyrics')
verse = f"""{user_input} \n"""

# Only generate on an explicit click with non-blank input; whitespace-only
# text would otherwise send an effectively empty "### Input:" section.
if user_input.strip() and button:
    # Instruction-style prompt. NOTE(review): "Benuzte" looks like a typo for
    # "Benutze"; it is left unchanged in case it matches the template the
    # model was fine-tuned on — confirm before editing.
    prompt = f"""### Instruction:
Benuzte den gegebenen Input um ein Schlager Lied zu schreiben.
### Input:
{verse}
### Response:
"""
    # Move the prompt to the model's device instead of hard-coding CUDA.
    input_ids = tokenizer(
        prompt, return_tensors="pt", truncation=True
    ).input_ids.to(model.device)
    outputs = model.generate(
        input_ids=input_ids,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=500,
        do_sample=True,
        top_p=0.75,
        temperature=0.95,
        top_k=15,
    )
    # For a causal LM, generate() returns prompt + continuation. Strip the
    # prompt by token count rather than by len(prompt) characters: decoding
    # need not round-trip the prompt text exactly (space clean-up,
    # truncation), so a character offset can cut into the lyrics or leave
    # prompt text behind.
    new_tokens = outputs[0, input_ids.shape[1]:]
    lyrics = tokenizer.decode(new_tokens, skip_special_tokens=True)
    st.write(f"""Generated Lyrics: \n {lyrics}""")
|