clezcano commited on
Commit
11d6fdc
2 Parent(s): d9461c7 b6a6e61

Merge remote-tracking branch 'origin/main'

Browse files
Files changed (1) hide show
  1. app.py +31 -21
app.py CHANGED
@@ -1,33 +1,43 @@
1
- # Import necessary libraries
2
  import streamlit as st
3
- import transformers
4
  import torch
5
- from transformers import pipeline
6
 
7
- # Set up the Streamlit app
8
  st.title("Emotion Detection with Transformers")
9
 
10
- # Create a text input widget
11
  user_input = st.text_area("Enter your text:")
12
 
13
 
14
- # Define a function for sentiment analysis using transformers
15
- @st.cache_data
16
- def load_model():
17
- return pipeline("sentiment-analysis")
 
 
 
18
 
19
 
20
- # Load the sentiment analysis model
21
- sentiment_analyzer = load_model()
22
 
23
- # Create a button to analyze the emotion
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  if st.button("Analyze Emotion"):
25
- if user_input:
26
- # Perform sentiment analysis on user input
27
- result = sentiment_analyzer(user_input)
28
-
29
- # Display the result
30
- emotion = result[0]['label']
31
- st.write(f"Emotion: {emotion}")
32
- else:
33
- st.warning("Please enter some text to analyze.")
 
 
1
  import streamlit as st
 
2
  import torch
3
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
 
5
+ # Set up Streamlit
6
  st.title("Emotion Detection with Transformers")
7
 
8
+ # Text input
9
  user_input = st.text_area("Enter your text:")
10
 
11
 
12
+ # Function to load model and tokenizer using @st.cache_data
13
+ @st.cache_data()
14
+ def load_model_and_tokenizer():
15
+ model_name = "mrm8488/t5-base-finetuned-emotion"
16
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
17
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
18
+ return tokenizer, model
19
 
20
 
21
+ tokenizer, model = load_model_and_tokenizer()
 
22
 
23
+
24
+ # Function to analyze emotion
25
+ def analyze_emotion(text):
26
+ if text.strip() == "":
27
+ return "Please enter some text to analyze."
28
+
29
+ input_ids = tokenizer.encode(text + '</s>', return_tensors='pt')
30
+
31
+ output = model.generate(input_ids=input_ids,
32
+ max_length=2)
33
+
34
+ dec = [tokenizer.decode(ids, skip_special_tokens=True) for ids in output]
35
+ label = dec[0]
36
+
37
+ return f"Emotion: {label.capitalize()}"
38
+
39
+
40
+ # Analyze button
41
  if st.button("Analyze Emotion"):
42
+ result = analyze_emotion(user_input)
43
+ st.write(result)