xiaokunchen commited on
Commit
ba2194a
1 Parent(s): a46e4f2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -2
app.py CHANGED
@@ -1,4 +1,30 @@
1
  import streamlit as st
 
2
 
3
- x = st.slider('Select a value')
4
- st.write(x, 'squared is', x * x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
 
4
+ # Title of the Streamlit app
5
+ st.title("Neo Scalinglaw 250M Model")
6
+
7
+ # Text input for user prompt
8
+ user_input = st.text_input("Enter your prompt:")
9
+
10
+ # Load the tokenizer and model
11
+ @st.cache_resource
12
+ def load_model():
13
+ model_path = 'm-a-p/neo_scalinglaw_250M'
14
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)
15
+ model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto", torch_dtype='auto').eval()
16
+ return tokenizer, model
17
+
18
+ tokenizer, model = load_model()
19
+
20
+ # Generate text when the user inputs a prompt and presses the button
21
+ if st.button("Generate"):
22
+ if user_input:
23
+ with st.spinner("Generating response..."):
24
+ input_ids = tokenizer(user_input, add_generation_prompt=True, return_tensors='pt').to(model.device)
25
+ output_ids = model.generate(**input_ids, max_new_tokens=20)
26
+ response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
27
+ st.success("Generated response:")
28
+ st.write(response)
29
+ else:
30
+ st.error("Please enter a prompt.")