jmdu committed on
Commit
56ba8e8
1 Parent(s): b4d7feb

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -0
app.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModel
3
+ import torch
4
+
5
+ # Load the model and tokenizer
6
def _cache_resource(func):
    """Wrap *func* with the best resource cache this Streamlit provides."""
    decorator = getattr(st, "cache_resource", None)
    if decorator is None:
        # Older Streamlit releases only ship the legacy caching API.
        decorator = st.cache(allow_output_mutation=True)
    return decorator(func)


@_cache_resource
def load_model():
    """Load the SFR-Embedding-Mistral tokenizer and model.

    The result is cached through Streamlit so the weights are not reloaded
    on every script rerun. ``st.cache`` has been deprecated since
    Streamlit 1.18, so ``st.cache_resource`` is used when it exists.

    Returns:
        A ``(tokenizer, model)`` tuple created with ``from_pretrained``.
    """
    model_name = "Salesforce/SFR-Embedding-Mistral"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    return tokenizer, model
11
+
12
# Module-level handles used by embed_text(); load_model() is cached, so
# Streamlit reruns reuse the same objects instead of reloading them.
tokenizer, model = load_model()
13
+
14
def embed_text(text):
    """Return a mean-pooled embedding of *text* as a NumPy array.

    The text is tokenized with the module-level ``tokenizer`` (truncated to
    at most 32768 tokens), passed through the module-level ``model``, and
    the model's ``last_hidden_state`` is averaged over dimension 1.

    Args:
        text: The text to embed.

    Returns:
        A NumPy array holding the mean of ``last_hidden_state`` over
        dimension 1 (presumably the token axis, giving one vector per
        input -- TODO confirm against the model's output layout).
    """
    inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=32768)
    # Inference only: without no_grad() PyTorch records an autograd graph
    # for the whole forward pass, which costs memory and is never used.
    with torch.no_grad():
        outputs = model(**inputs)
    # NOTE(review): the model card for SFR-Embedding-Mistral appears to use
    # last-token pooling (and a shorter max_length) rather than a plain
    # mean -- confirm whether mean pooling is intended here.
    return outputs.last_hidden_state.mean(dim=1).detach().numpy()
18
+
19
def main():
    """Render the embedding demo page.

    Shows a title, a text area, and a "Get Embeddings" button. When the
    button is pressed with non-blank text, the text is embedded via
    ``embed_text`` and the result is written to the page; otherwise a
    warning asks the user to enter some text.
    """
    st.title("Text Embedding using Salesforce/SFR-Embedding-Mistral")

    # Text input
    text = st.text_area("Enter text here:", height=150)

    if not st.button("Get Embeddings"):
        return

    # Treat whitespace-only input like empty input, so the user gets the
    # warning instead of an embedding of blank text.
    if not text or not text.strip():
        st.warning("Please enter some text to process.")
        return

    with st.spinner('Fetching embeddings...'):
        embeddings = embed_text(text)
        st.write(embeddings)
32
+
33
# Script entry point: build the page when this file is executed directly.
if __name__ == "__main__":
    main()