import streamlit as st from transformers import AutoModelForCausalLM, AutoTokenizer from transformers import pipeline import torch @st.cache_resource() def get_model(): # load base LLM model and tokenizer model_id = "niclasfw/schlager-bot-004" tokenizer = AutoTokenizer.from_pretrained(model_id) model = AutoModelForCausalLM.from_pretrained( model_id, low_cpu_mem_usage=True, torch_dtype=torch.float16, load_in_4bit=True, ) return tokenizer, model tokenizer, model = get_model() st.title('Schlager Bot') user_input = st.text_area('Enter verse: ') button = st.button('Generate Lyrics') verse = f"""{user_input} \n""" if user_input and button: prompt = f"""### Instruction: Benuzte den gegebenen Input um ein Schlager Lied zu schreiben. ### Input: {verse} ### Response: """ input_ids = tokenizer(prompt, return_tensors="pt", truncation=True).input_ids.cuda() outputs = model.generate(input_ids=input_ids, pad_token_id=tokenizer.eos_token_id, max_new_tokens=500, do_sample=True, top_p=0.75, temperature=0.95, top_k=15) lyrics = tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0][len(prompt):] st.write(f"""Generated Lyrics: \n {lyrics}""")