PeteBleackley commited on
Commit
679a7b2
1 Parent(s): 1aab673

Download training data from S3

Browse files
Files changed (2) hide show
  1. app.py +7 -3
  2. scripts.py +11 -0
app.py CHANGED
@@ -10,8 +10,10 @@ import gradio as gr
10
  import scripts
11
  import pandas
12
 
13
- def greet(name):
14
- return "Hello " + name + "!!"
 
 
15
 
16
  def train():
17
  history = scripts.train_models('PlayfulTechnology')
@@ -19,8 +21,10 @@ def train():
19
 
20
 
21
  with gr.Blocks() as trainer:
22
- training_button = gr.Button(value="Train models")
 
23
  loss_plot = gr.Plot()
 
24
  training_button.click(train,inputs=[],outputs=[loss_plot])
25
 
26
  trainer.launch()
 
10
  import scripts
11
  import pandas
12
 
13
+ def download(button):
14
+ scripts.download_training_data()
15
+ return gr.Button.update(interactive=True)
16
+
17
 
18
  def train():
19
  history = scripts.train_models('PlayfulTechnology')
 
21
 
22
 
23
  with gr.Blocks() as trainer:
24
+ download_button = gr.Button(value='Doenload training_data')
25
+ training_button = gr.Button(value="Train models",interactive=False)
26
  loss_plot = gr.Plot()
27
+ download_button.click(download,inputs=download_button,outputs=training_button)
28
  training_button.click(train,inputs=[],outputs=[loss_plot])
29
 
30
  trainer.launch()
scripts.py CHANGED
@@ -19,6 +19,7 @@ import scipy.spatial
19
  import seaborn
20
  import tqdm
21
  import gradio
 
22
 
23
  class SequenceCrossEntropyLoss(torch.nn.Module):
24
  def __init__(self):
@@ -55,6 +56,16 @@ def clean_question(doc):
55
  words.append('?')
56
  return ''.join(words)
57
 
 
 
 
 
 
 
 
 
 
 
58
  def prepare_wiki_qa(filename,outfilename):
59
  data = pandas.read_csv(filename,sep='\t')
60
  data['QNum']=data['QuestionID'].apply(lambda x: int(x[1:]))
 
19
  import seaborn
20
  import tqdm
21
  import gradio
22
+ import boto3
23
 
24
  class SequenceCrossEntropyLoss(torch.nn.Module):
25
  def __init__(self):
 
56
  words.append('?')
57
  return ''.join(words)
58
 
59
+ def download_training_data():
60
+ if not os.path.exists('corpora'):
61
+ os.makedirs('corpora')
62
+ s3 = boto3.client('s3',
63
+ aws_access_key_id=os.environ['AWS_KEY'],
64
+ aws_secret_access_key=os.evviron['AWS_SECRET'])
65
+ for obj in s3.list_objects(Bucket='qarac')['Contents']:
66
+ filename = obj['Key']
67
+ s3.download_file('qarac',filename,'corpora/{}'.format(filename))
68
+
69
  def prepare_wiki_qa(filename,outfilename):
70
  data = pandas.read_csv(filename,sep='\t')
71
  data['QNum']=data['QuestionID'].apply(lambda x: int(x[1:]))