Spaces:
Sleeping
Sleeping
File size: 1,801 Bytes
6e9f8ce b7ead25 73004ed 6e9f8ce 3022295 84b5b05 6e9f8ce 84b5b05 6e9f8ce 84b5b05 b7ead25 84b5b05 6e9f8ce 84b5b05 73004ed 84b5b05 703c573 a088b25 73004ed b7ead25 6e9f8ce b7ead25 6e9f8ce b7ead25 6e9f8ce b7ead25 a088b25 73004ed 84b5b05 a088b25 3022295 6e9f8ce |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import pipeline
# st.cache(allow_output_mutation=True) is deprecated (removed in recent
# Streamlit); st.cache_resource is the supported way to cache model objects
# so they are loaded once per server process, not on every rerun.
@st.cache_resource
def get_model():
    """Load and cache the fine-tuned schlager-lyrics model and tokenizer.

    Returns:
        tuple: (tokenizer, model) for the "niclasfw/schlager-bot-004"
        causal LM, loaded in half precision with 4-bit quantization.
    """
    model_id = "niclasfw/schlager-bot-004"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,  # half precision to reduce memory footprint
        load_in_4bit=True,  # NOTE(review): requires bitsandbytes — confirm it is installed
    )
    return tokenizer, model
# Load (or fetch from cache) the tokenizer and model once at startup.
tokenizer, model = get_model()

st.title('Schlager Bot')
user_input = st.text_area('Enter verse (minimum of 15 words): ')
button = st.button('Generate Lyrics')

if user_input and button:
    # Alpaca-style prompt format the model was fine-tuned on; the string is
    # kept byte-for-byte as trained (including the German instruction).
    prompt = f"""### Instruction:
Benuzte den gegebenen Input um ein Schlager Lied zu schreiben.
### Input:
{user_input}
### Response:
"""
    input_ids = tokenizer(prompt, return_tensors="pt", truncation=True).input_ids
    # Only move tensors to the GPU when one exists, so the app also runs on
    # CPU-only hosts instead of crashing on .cuda().
    if torch.cuda.is_available():
        input_ids = input_ids.cuda()
    outputs = model.generate(
        input_ids=input_ids,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=500,
        do_sample=True,
        top_p=0.75,
        temperature=0.95,
        top_k=15,
    )
    # BUG FIX: the original called st.write(output) but `output` was never
    # defined (only `outputs`, a tensor of token ids) — decode to text first.
    output = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    st.write(output)