niclasfw committed on
Commit
b7ead25
1 Parent(s): 6e9f8ce

Pushing updates to huggingface space.

Browse files
Files changed (2) hide show
  1. app.py +26 -22
  2. requirements.txt +3 -0
app.py CHANGED
@@ -1,38 +1,42 @@
1
  import streamlit as st
2
  import torch
3
- from transformers import AutoModelForCausalLM
4
- from transformers import AutoTokenizer
5
 
6
- model_id = "niclasfw/schlager-bot-004"
 
 
7
 
8
- # load base LLM model and tokenizer
9
- model = AutoModelForCausalLM.from_pretrained(
 
10
  model_id,
11
  low_cpu_mem_usage=True,
12
  torch_dtype=torch.float16,
13
  load_in_4bit=True,
14
- )
15
- tokenizer = AutoTokenizer.from_pretrained(model_id)
16
 
17
- st.title('Schlager Bot')
 
 
18
 
19
- print("Enter verse (minimum of 15 words): ")
20
- sample = input()
 
21
 
22
- prompt = f"""### Instruction:
23
- Benuzte den gegebenen Input um ein Schlager Lied zu schreiben.
24
 
25
- ### Input:
26
- {sample}
 
27
 
28
- ### Response:
29
- """
30
 
31
- query = st.text_input(prompt, key="input")
 
 
 
 
 
 
32
 
33
- input_ids = tokenizer(prompt, return_tensors="pt", truncation=True).input_ids
34
- outputs = model.generate(input_ids=input_ids, pad_token_id=tokenizer.eos_token_id, max_new_tokens=500, do_sample=True, top_p=0.75, temperature=0.95, top_k=15)
35
 
36
- print(f"Prompt:\n{sample}\n")
37
- print("**************")
38
- st.write(outputs)
 
"""Streamlit app: generate German Schlager lyrics with a fine-tuned causal LM."""

import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


# NOTE(review): st.cache(allow_output_mutation=True) is deprecated and removed in
# recent Streamlit releases — st.cache_resource is the modern equivalent. Kept
# as-is to avoid changing runtime requirements; confirm the pinned Streamlit version.
@st.cache(allow_output_mutation=True)
def get_model():
    """Load and cache the fine-tuned LM and its tokenizer.

    Returns:
        (tokenizer, model): the AutoTokenizer and AutoModelForCausalLM for
        the "niclasfw/schlager-bot-004" checkpoint. Cached so the (slow,
        memory-heavy) download/load happens once per Streamlit session.
    """
    model_id = "niclasfw/schlager-bot-004"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
        # NOTE(review): 4-bit loading requires bitsandbytes (and accelerate),
        # which requirements.txt does not list — TODO confirm deps on the Space.
        load_in_4bit=True,
    )
    return tokenizer, model


tokenizer, model = get_model()

st.title('Schlager Bot')
user_input = st.text_area('Enter verse (minimum of 15 words): ')
button = st.button('Generate Lyrics')

# Only generate once the user has entered text AND clicked the button.
if user_input and button:
    # Prompt template — kept byte-identical to the (presumed) fine-tuning
    # format, including the German typo "Benuzte"; changing it could degrade
    # model output. TODO confirm against the training template.
    prompt = f"""### Instruction:
Benuzte den gegebenen Input um ein Schlager Lied zu schreiben.

### Input:
{user_input}

### Response:
"""
    input_ids = tokenizer(prompt, return_tensors="pt", truncation=True).input_ids
    # Sampling parameters as originally tuned (top_p/temperature/top_k).
    outputs = model.generate(
        input_ids=input_ids,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=500,
        do_sample=True,
        top_p=0.75,
        temperature=0.95,
        top_k=15,
    )
    st.write("Prompt: ", user_input)
    st.write("**************")
    # Bug fix: model.generate returns a tensor of token ids; the original
    # st.write(outputs) rendered that raw tensor. Decode to human-readable text.
    st.write(tokenizer.decode(outputs[0], skip_special_tokens=True))
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ streamlit
2
+ torch
3
+ transformers