Tabish009 committed
Commit 84d8890 · verified · Parent: abc76fc

Create app.py

Files changed (1)
  1. app.py +38 -0
app.py ADDED
@@ -0,0 +1,38 @@
+ import streamlit as st
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ # Load the model and tokenizer once per session
+ @st.cache_resource
+ def load_model_and_tokenizer():
+     model_name_or_path = "mistralai/Mistral-7B-v0.1"
+     # device_map="auto" places the weights across available devices (requires the accelerate package)
+     model = AutoModelForCausalLM.from_pretrained(model_name_or_path, device_map="auto")
+     tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
+     return model, tokenizer
+
+ # Generate a response; identical prompts are answered from the cache
+ @st.cache_data
+ def generate_response(prompt):
+     model, tokenizer = load_model_and_tokenizer()
+     prompt_template = f'''
+ <|prompter|>:{prompt}
+ <|assistant|>:
+ '''
+     input_ids = tokenizer(prompt_template, return_tensors='pt').input_ids.to(model.device)
+     output = model.generate(inputs=input_ids, temperature=0.7, do_sample=True, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id, max_new_tokens=512)
+     response = tokenizer.decode(output[0], skip_special_tokens=True)
+     return response
+
+ # Streamlit app
+ def main():
+     st.title("Mistral 7B Language Model")
+     load_model_and_tokenizer()  # warm the model cache before the first query
+
+     prompt = st.text_area("Enter your query:")
+     if st.button("Submit") and prompt:
+         with st.spinner("Generating response..."):
+             response = generate_response(prompt)
+         st.write(response)
+
+ if __name__ == "__main__":
+     main()
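
For a quick sanity check of the generation path without starting Streamlit, something like the following sketch could be run as a plain script. It is a minimal example under the same assumptions as app.py (transformers, torch, and accelerate installed, and enough memory for the 7B weights); the prompt text here is illustrative only.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "mistralai/Mistral-7B-v0.1"
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Same chat markup that app.py wraps around user input
prompt = "<|prompter|>:What does this app do?\n<|assistant|>:"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
output = model.generate(
    input_ids,
    do_sample=True,
    temperature=0.7,
    max_new_tokens=64,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
)
print(tokenizer.decode(output[0], skip_special_tokens=True))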