gbarone77 commited on
Commit
1704b25
·
1 Parent(s): 6fff48b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -3
app.py CHANGED
@@ -1,10 +1,22 @@
1
  import gradio
 
 
2
 
3
- def my_inference_function(name):
4
- return "Hello " + name + "!"
5
 
 
 
 
 
 
 
 
 
 
 
6
  gradio_interface = gradio.Interface(
7
- fn = my_inference_function,
8
  inputs = "text",
9
  outputs = "text"
10
  )
 
1
  import gradio
2
+ import transformers
3
+ import torch
4
 
5
+ tokenizer = AutoTokenizer.from_pretrained('t5-small-finetuned-wikisql-with-cols')
6
+ model = T5ForConditionalGeneration.from_pretrained('t5-small-finetuned-wikisql-with-cols')
7
 
8
+ def translate_to_sql(text):
9
+ inputs = tokenizer(text, padding='longest', max_length=64, return_tensors='pt')
10
+ input_ids = inputs.input_ids
11
+ attention_mask = inputs.attention_mask
12
+ output = model.generate(input_ids, attention_mask=attention_mask, max_length=64)
13
+ return tokenizer.decode(output[0], skip_special_tokens=True)
14
+
15
+
16
+
17
+ #Input example: 'translate to SQL: When was Olympic games held in Rome? table ID: ID, city, year, cost, attendees'
18
  gradio_interface = gradio.Interface(
19
+ fn = translate_to_sql,
20
  inputs = "text",
21
  outputs = "text"
22
  )