File size: 1,801 Bytes
6e9f8ce
b7ead25
73004ed
6e9f8ce
3022295
84b5b05
 
 
6e9f8ce
84b5b05
 
 
 
 
 
 
 
6e9f8ce
84b5b05
b7ead25
84b5b05
6e9f8ce
84b5b05
73004ed
84b5b05
 
703c573
a088b25
73004ed
 
b7ead25
 
6e9f8ce
 
b7ead25
 
 
6e9f8ce
b7ead25
 
6e9f8ce
b7ead25
 
a088b25
73004ed
 
 
 
84b5b05
a088b25
3022295
 
6e9f8ce
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import pipeline


@st.cache(allow_output_mutation=True)
def get_model():
    """Load the fine-tuned Schlager tokenizer and causal language model.

    Wrapped in Streamlit's cache so the load is not repeated on every
    script rerun; ``allow_output_mutation=True`` tells Streamlit not to
    hash-check the returned objects for mutation.

    NOTE(review): ``st.cache`` is deprecated in recent Streamlit releases in
    favour of ``st.cache_resource`` -- confirm the pinned Streamlit version
    before switching.

    Returns:
        tuple: ``(tokenizer, model)`` as returned by
        ``AutoTokenizer.from_pretrained`` and
        ``AutoModelForCausalLM.from_pretrained`` for the same model id.
    """
    model_id = "niclasfw/schlager-bot-004"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Requires ``torch`` to be imported at module level for ``torch.float16``.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
        load_in_4bit=True,
    )

    return tokenizer, model

# Obtain the tokenizer/model pair via the cached loader.
tokenizer, model = get_model()

# Page layout: a title, a free-text box for the user's verse, and a button
# that triggers generation.
st.title(body="Schlager Bot")
user_input = st.text_area(label="Enter verse (minimum of 15 words): ")
button = st.button(label="Generate Lyrics")


# Generate only when the user has typed something AND pressed the button.
if user_input and button:
    # The prompt text (including its indentation and wording) is passed to the
    # model verbatim, so it is intentionally left untouched.
    prompt = f"""### Instruction:
    Benuzte den gegebenen Input um ein Schlager Lied zu schreiben.

    ### Input:
    {user_input}

    ### Response:
    """
    # Tokenize and place the ids on whatever device the model lives on,
    # instead of assuming CUDA unconditionally.
    input_ids = tokenizer(
        prompt, return_tensors="pt", truncation=True
    ).input_ids.to(model.device)
    outputs = model.generate(
        input_ids=input_ids,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=500,
        do_sample=True,
        top_p=0.75,
        temperature=0.95,
        top_k=15,
    )

    # For a causal LM, generate() returns prompt ids followed by the new ids;
    # drop the prompt part so only the generated lyrics are shown.
    generated_ids = outputs[0][input_ids.shape[-1]:]
    output = tokenizer.decode(generated_ids, skip_special_tokens=True)

    st.write(output)