File size: 1,263 Bytes
6e9f8ce
b7ead25
73004ed
5a72973
6e9f8ce
3022295
f198b15
84b5b05
 
6e9f8ce
84b5b05
 
 
 
 
 
 
 
6e9f8ce
84b5b05
b7ead25
84b5b05
6e9f8ce
73004ed
0b07e1e
b7ead25
e1abb31
6e9f8ce
 
b7ead25
 
 
6e9f8ce
b7ead25
e1abb31
6e9f8ce
b7ead25
 
84b5b05
a088b25
4bc419a
3022295
33d77bb
6e9f8ce
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import pipeline
import torch


@st.cache_resource()
def get_model():
    """Load and cache the fine-tuned schlager LLM and its tokenizer.

    Decorated with st.cache_resource so the (large) model is downloaded and
    loaded only once per Streamlit server process, not on every script rerun.

    Returns:
        tuple: (tokenizer, model) for the "niclasfw/schlager-bot-004" checkpoint.
    """
    model_id = "niclasfw/schlager-bot-004"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # NOTE(review): load_in_4bit requires the bitsandbytes package and a
    # CUDA-capable GPU at load time — confirm the deployment host provides both.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
        load_in_4bit=True,
    )

    return tokenizer, model

tokenizer, model = get_model()

# --- UI ---------------------------------------------------------------
st.title('Schlager Bot')
user_input = st.text_area('Enter verse: ')
button = st.button('Generate Lyrics')
verse = f"""{user_input} \n"""


if user_input and button:
    # Alpaca-style prompt format the model was fine-tuned on (German
    # instruction text reproduced verbatim — do not "correct" its spelling,
    # the checkpoint expects exactly this string).
    prompt = f"""### Instruction:
    Benuzte den gegebenen Input um ein Schlager Lied zu schreiben.

    ### Input:
    {verse}

    ### Response:
    """
    # BUG FIX: the original hard-coded .cuda(), which crashes on CPU-only
    # hosts and may mismatch the device accelerate assigned to the 4-bit
    # model. Move the inputs to wherever the model actually lives instead.
    input_ids = tokenizer(prompt, return_tensors="pt", truncation=True).input_ids.to(model.device)
    outputs = model.generate(
        input_ids=input_ids,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=500,
        do_sample=True,
        top_p=0.75,
        temperature=0.95,
        top_k=15,
    )
    # Decode the full sequence, then slice off the prompt prefix so only the
    # newly generated lyrics remain. (Assumes the decoded text starts with
    # the prompt verbatim — TODO confirm for this tokenizer.)
    lyrics = tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0][len(prompt):]

    st.write(f"""Generated Lyrics: \n {lyrics}""")