Annaamalai committed on
Commit
ff1bd0f
1 Parent(s): 1ab3e78

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -30
app.py CHANGED
@@ -1,38 +1,37 @@
1
  import streamlit as st
2
- from transformers import GPT2LMHeadModel, GPT2Tokenizer
3
 
4
- # Load the Mistral model for text generation
5
- model_name = "salesforce/gpt-3.5-mistral-small"
6
- tokenizer = GPT2Tokenizer.from_pretrained(model_name)
7
- model = GPT2LMHeadModel.from_pretrained(model_name)
8
 
9
  # Streamlit app
10
  def main():
11
- st.title("Cherry Bot")
12
- st.markdown("*Your Companion for Suicide Prevention*")
13
-
14
- # Add a sidebar to the app
15
- with st.sidebar:
16
- st.markdown("# Text Generation Settings")
17
- # Slider for adjusting the maximum length of generated text
18
- max_length = st.slider("Max Length of Generated Text:", min_value=10, max_value=200, value=50, step=10)
19
-
20
- # Text input for user to input starting text for generation
21
- starting_text = st.text_area("Enter starting text for generation:", height=100)
22
-
23
- # Perform text generation when the user clicks the button
24
- if st.button("Generate Text"):
25
- if starting_text:
26
- # Tokenize input text
27
- input_ids = tokenizer.encode(starting_text, return_tensors="pt")
28
-
29
- # Generate text using the loaded model
30
- generated_ids = model.generate(input_ids, max_length=max_length, num_return_sequences=1)
31
- generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
32
-
33
- # Display the generated text
34
- st.write("Generated Text:")
35
- st.write(generated_text)
36
 
37
  if __name__ == "__main__":
38
  main()
 
1
  import streamlit as st
2
+ from transformers import BigBirdForQuestionAnswering, BigBirdTokenizer
3
 
4
+ # Load the BigBird model and tokenizer
5
+ model_name = "google/bigbird-base-trivia-itc"
6
+ model = BigBirdForQuestionAnswering.from_pretrained(model_name)
7
+ tokenizer = BigBirdTokenizer.from_pretrained(model_name)
8
 
9
  # Streamlit app
10
  def main():
11
+ st.title("CherryBot")
12
+
13
+ # Text input for user to input the question
14
+ question = st.text_input("Enter your question:")
15
+
16
+ # Text area for user to input the context
17
+ context = st.text_area("Enter the context for answering:", height=200)
18
+
19
+ # Perform question answering when the user clicks the button
20
+ if st.button("Get Answer"):
21
+ if question and context:
22
+ # Tokenize input question and context
23
+ encoded_input = tokenizer(question, context, return_tensors='pt')
24
+
25
+ # Perform question answering using the loaded model
26
+ output = model(**encoded_input)
27
+
28
+ # Extract and display the answer
29
+ answer_start = output.start_logits.argmax()
30
+ answer_end = output.end_logits.argmax() + 1
31
+ answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(encoded_input.input_ids[0][answer_start:answer_end]))
32
+
33
+ st.write("Answer:")
34
+ st.write(answer)
 
35
 
36
  if __name__ == "__main__":
37
  main()