Spaces:
Sleeping
Sleeping
import streamlit as st | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
from transformers import pipeline | |
# @st.cache(allow_output_mutation=True) | |
# def get_model(): | |
# # load base LLM model and tokenizer | |
# model_id = "niclasfw/schlager-bot-004" | |
# tokenizer = AutoTokenizer.from_pretrained(model_id) | |
# model = AutoModelForCausalLM.from_pretrained( | |
# model_id, | |
# # low_cpu_mem_usage=True, | |
# # torch_dtype=torch.float16, | |
# # load_in_4bit=True, | |
# ) | |
# return tokenizer, model | |
# tokenizer, model = get_model() | |
model_id = "niclasfw/bloom-1b5-schlager-bot-004" | |
model = AutoModelForCausalLM.from_pretrained(model_id) | |
tokenizer = AutoTokenizer.from_pretrained(model_id) | |
generator = pipeline(task="text-generation", model=model_id, tokenizer=model_id) | |
st.title('Schlager Bot') | |
user_input = st.text_area('Enter verse (minimum of 15 words): ') | |
button = st.button('Generate Lyrics') | |
if user_input and button: | |
prompt = f"""### Instruction: | |
Benuzte den gegebenen Input um ein Schlager Lied zu schreiben. | |
### Input: | |
{user_input} | |
### Response: | |
""" | |
output = generator(prompt, do_sample=True, max_new_tokens=500, top_p=0.75, temperature=0.95, top_k=15) | |
# st.write("Prompt: ", user_input) | |
# input = tokenizer(prompt, padding=True, return_tensors="pt") | |
# generate_ids = model.generate(input.input_ids, max_length=500, top_p=0.75, temperature=0.95, top_k=15) | |
# output = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] | |
# input_ids = tokenizer(prompt, return_tensors="pt", truncation=True) | |
# outputs = model.generate(input_ids=input_ids, pad_token_id=tokenizer.eos_token_id, max_new_tokens=500, do_sample=True, top_p=0.75, temperature=0.95, top_k=15) | |
st.write(output) | |