g8a9 commited on
Commit
a806f8f
1 Parent(s): a1acca8

use ferret to explain

Browse files
Files changed (2) hide show
  1. requirements.txt +2 -1
  2. single.py +7 -5
requirements.txt CHANGED
@@ -1 +1,2 @@
1
- transformers==4.20.1
 
 
1
+ transformers==4.20.1
2
+ ferret-xai>=0.1.0
single.py CHANGED
@@ -1,5 +1,6 @@
1
  import streamlit as st
2
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
 
3
 
4
 
5
  @st.cache()
@@ -7,7 +8,6 @@ def get_model(model_name):
7
  return AutoModelForSequenceClassification.from_pretrained(model_name)
8
 
9
 
10
- @st.cache()
11
  def get_tokenizer(tokenizer_name):
12
  return AutoTokenizer.from_pretrained(tokenizer_name, use_fast=True)
13
 
@@ -38,8 +38,10 @@ def body():
38
 
39
  compute = st.button("Compute")
40
 
41
- if text and compute and model_name and tokenizer_name:
42
- st.text("hellp")
 
43
 
44
- # model = get_model(model_name)
45
- # tokenizer = get_tokenizer(tokenizer_name)
 
 
1
  import streamlit as st
2
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
3
+ from ferret import Benchmark
4
 
5
 
6
  @st.cache()
 
8
  return AutoModelForSequenceClassification.from_pretrained(model_name)
9
 
10
 
 
11
  def get_tokenizer(tokenizer_name):
12
  return AutoTokenizer.from_pretrained(tokenizer_name, use_fast=True)
13
 
 
38
 
39
  compute = st.button("Compute")
40
 
41
+ if compute and model_name and tokenizer_name:
42
+ model = get_model(model_name)
43
+ tokenizer = get_tokenizer(tokenizer_name)
44
 
45
+ bench = Benchmark(model, tokenizer)
46
+ explanations = bench.explain(text)
47
+ st.dataframe(bench.show_table(explanations))