niclasfw commited on
Commit
84b5b05
1 Parent(s): a088b25

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -17
app.py CHANGED
@@ -3,27 +3,27 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
3
  from transformers import pipeline
4
 
5
 
6
- # @st.cache(allow_output_mutation=True)
7
- # def get_model():
8
- # # load base LLM model and tokenizer
9
 
10
- # model_id = "niclasfw/schlager-bot-004"
11
- # tokenizer = AutoTokenizer.from_pretrained(model_id)
12
- # model = AutoModelForCausalLM.from_pretrained(
13
- # model_id,
14
- # # low_cpu_mem_usage=True,
15
- # # torch_dtype=torch.float16,
16
- # # load_in_4bit=True,
17
- # )
18
 
19
- # return tokenizer, model
20
 
21
- # tokenizer, model = get_model()
22
 
23
- model_id = "niclasfw/schlager-bot-004"
24
 
25
- model = AutoModelForCausalLM.from_pretrained(model_id)
26
- tokenizer = AutoTokenizer.from_pretrained(model_id)
27
 
28
  # generator = pipeline(task="text-generation", model=model_id, tokenizer=model_id)
29
 
@@ -46,7 +46,7 @@ if user_input and button:
46
  # input = tokenizer(prompt, padding=True, return_tensors="pt")
47
  # generate_ids = model.generate(input.input_ids, max_length=500, top_p=0.75, temperature=0.95, top_k=15)
48
  # output = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
49
- input_ids = tokenizer(prompt, return_tensors="pt", truncation=True)
50
  outputs = model.generate(input_ids=input_ids, pad_token_id=tokenizer.eos_token_id, max_new_tokens=500, do_sample=True, top_p=0.75, temperature=0.95, top_k=15)
51
 
52
  st.write(output)
 
3
  from transformers import pipeline
4
 
5
 
6
+ @st.cache(allow_output_mutation=True)
7
+ def get_model():
8
+ # load base LLM model and tokenizer
9
 
10
+ model_id = "niclasfw/schlager-bot-004"
11
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
12
+ model = AutoModelForCausalLM.from_pretrained(
13
+ model_id,
14
+ low_cpu_mem_usage=True,
15
+ torch_dtype=torch.float16,
16
+ load_in_4bit=True,
17
+ )
18
 
19
+ return tokenizer, model
20
 
21
+ tokenizer, model = get_model()
22
 
23
+ # model_id = "niclasfw/schlager-bot-004"
24
 
25
+ # model = AutoModelForCausalLM.from_pretrained(model_id)
26
+ # tokenizer = AutoTokenizer.from_pretrained(model_id)
27
 
28
  # generator = pipeline(task="text-generation", model=model_id, tokenizer=model_id)
29
 
 
46
  # input = tokenizer(prompt, padding=True, return_tensors="pt")
47
  # generate_ids = model.generate(input.input_ids, max_length=500, top_p=0.75, temperature=0.95, top_k=15)
48
  # output = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
49
+ input_ids = tokenizer(prompt, return_tensors="pt", truncation=True).input_ids.cuda()
50
  outputs = model.generate(input_ids=input_ids, pad_token_id=tokenizer.eos_token_id, max_new_tokens=500, do_sample=True, top_p=0.75, temperature=0.95, top_k=15)
51
 
52
  st.write(output)