Lev McKinney commited on
Commit
4c39b84
1 Parent(s): 21e4838

Attempting to render plot

Browse files
Files changed (1) hide show
  1. app.py +12 -3
app.py CHANGED
@@ -1,8 +1,17 @@
1
  from platform import python_version
 
 
 
 
2
  import gradio as gr
3
 
4
- def greet(name):
5
- return "Hello " + name + "!!" + "Using python version: " + python_version() + "."
 
 
 
 
 
6
 
7
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
8
  iface.launch()
 
1
  from platform import python_version
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ from tuned_lens.nn import TunedLens
4
+ from tuned_lens.plotting import plot_lens
5
+
6
  import gradio as gr
7
 
8
+ LENS_PATH = '<PATH TO LENS>'
9
+
10
+ def plot_lens_outputs(text):
11
+ model = AutoModelForCausalLM.from_pretrained('gpt2')
12
+ tokenizer = AutoTokenizer.from_pretrained('gpt2')
13
+ #lens = TunedLens.load(LENS_PATH)
14
+ return gr.outputs.Plot(plot_lens(model, tokenizer, text=text))
15
 
16
+ iface = gr.Interface(fn=plot_lens_outputs, inputs="text", outputs=gr.outputs.Plot(type="auto"))
17
  iface.launch()