brancengregory committed on
Commit 0ef555c
1 parent: 8fd318d

Add scripts

Files changed (5):
  1. .gitignore +1 -0
  2. pyproject.toml +21 -0
  3. scripts/main.py +8 -0
  4. scripts/prep.py +54 -0
  5. scripts/upload.py +36 -0
.gitignore ADDED
@@ -0,0 +1 @@
+ data/*.csv
pyproject.toml ADDED
@@ -0,0 +1,21 @@
+ [tool.poetry]
+ name = "demo-argilla"
+ version = "0.1.0"
+ description = ""
+ authors = ["Brancen Gregory <brancengregory@gmail.com>"]
+ readme = "README.md"
+ packages = [{include = "demo_argilla"}]
+
+ [tool.poetry.dependencies]
+ python = "^3.10, <3.11"
+ pandas = "^1.5.3"
+ argilla = "^1.3.0"
+ spacy = {extras = ["apple", "transformers"], version = "^3.5.0"}
+ datasets = "^2.9.0"
+
+ [tool.poetry.dependencies.en_core_web_trf]
+ url = "https://github.com/explosion/spacy-models/releases/download/en_core_web_trf-3.5.0/en_core_web_trf-3.5.0.tar.gz"
+
+ [build-system]
+ requires = ["poetry-core"]
+ build-backend = "poetry.core.masonry.api"
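
The config pins Python to the 3.10 series and installs en_core_web_trf straight from the spaCy model release URL. A minimal sanity check after `poetry install` (a sketch, not part of the commit) is to confirm the pinned dependencies and the URL-pinned model actually resolved:

# Sketch: verify the Poetry environment matches the pins above.
from importlib.metadata import version

for pkg in ("pandas", "argilla", "spacy", "datasets"):
    print(pkg, version(pkg))

import spacy
spacy.load("en_core_web_trf")  # raises OSError if the model did not install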
scripts/main.py ADDED
@@ -0,0 +1,8 @@
+ import argilla as rg
+
+ rg.init(
+     api_url='https://brancengregory-demo-argilla.hf.space',
+     api_key='team.apikey'
+ )
+
+ dataset = rg.load("plaintiff_sample").prepare_for_training()
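
For context, rg.load() pulls the logged records back down and prepare_for_training() converts them into a training-ready object. As a hedged sketch (assuming the Argilla 1.x default of framework="transformers", which returns a datasets.Dataset; that default is an assumption from memory, not confirmed by the commit), the result can be inspected like this:

# Sketch, assuming Argilla 1.x defaults; not part of the commit.
import argilla as rg

dataset = rg.load("plaintiff_sample").prepare_for_training()
print(dataset.features)  # text plus an encoded label column, under that assumption
print(dataset[0])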
scripts/prep.py ADDED
@@ -0,0 +1,54 @@
+ import os
+ import pandas
+ import psycopg2
+
+
+ def connect():
+     if not os.getenv('NEW_OJO_HOST'):
+         print("No configuration for the OJO database was found. Please create one now using `ojo_auth()`.")
+         return
+     else:
+         conn = psycopg2.connect(
+             host=os.getenv('NEW_OJO_HOST'),
+             database="ojodb",
+             user=os.getenv('NEW_OJO_DEFAULT_USER'),
+             password=os.getenv('NEW_OJO_DEFAULT_PASS'),
+             port=os.getenv('NEW_OJO_PORT'),
+             sslmode=os.getenv('NEW_OJO_SSL_MODE'),
+             sslrootcert=os.getenv('NEW_OJO_SSL_ROOT_CERT'),
+             sslcert=os.getenv('NEW_OJO_SSL_CERT'),
+             sslkey=os.getenv('NEW_OJO_SSL_KEY')
+         )
+         return conn
+
+ # Return the list of distinct plaintiffs. Takes a parameter n, the number
+ # of plaintiffs to return; if n is None, all plaintiffs are returned.
+ def plaintiffs(n=None):
+     conn = connect()
+     with conn:
+         if n is None:
+             sql = """select distinct(filed_by) from eviction_addresses.case c left join public.issue i on c.id = i.case_id;"""
+         else:
+             sql = """select distinct(filed_by) from eviction_addresses.case c left join public.issue i on c.id = i.case_id limit {};""".format(n)
+
+         data = pandas.read_sql_query(sql, conn)
+     conn.close()
+     return data
+
+ data = plaintiffs().dropna()
+ data.to_csv('data/plaintiffs.csv', index=False, header=True)
+
+ def minutes(n=None):
+     conn = connect()
+     with conn:
+         if n is None:
+             sql = """select distinct(description) from eviction_addresses.case c left join public.minute m on c.id = m.case_id;"""
+         else:
+             sql = """select distinct(description) from eviction_addresses.case c left join public.minute m on c.id = m.case_id limit {};""".format(n)
+
+         data = pandas.read_sql_query(sql, conn)
+     conn.close()
+     return data
+
+ data = minutes().dropna()
+ data.to_csv('data/minutes.csv', index=False, header=True)
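
One note on the two query functions: the LIMIT value is spliced in with str.format(), which works for an integer n but bakes the value into the SQL string. A hedged alternative sketch (same schema as above, with the connection passed in rather than opened inside; plaintiffs_limited is a hypothetical name, not the committed function) uses psycopg2-style %s placeholders through pandas' params argument:

# Sketch: parameterized LIMIT instead of .format(); not part of the commit.
import pandas

def plaintiffs_limited(n, conn):  # hypothetical helper for illustration
    sql = """select distinct(filed_by)
             from eviction_addresses.case c
             left join public.issue i on c.id = i.case_id
             limit %s;"""
    return pandas.read_sql_query(sql, conn, params=(n,))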
scripts/upload.py ADDED
@@ -0,0 +1,36 @@
+ import pandas as pd
+ import argilla as rg
+ import spacy
+ from datasets import Dataset
+
+
+ # Configuration
+ rg.init(
+     api_url='https://brancengregory-demo-argilla.hf.space',
+     api_key='team.apikey'
+ )
+
+
+ # Plaintiffs
+ data = pd.read_csv("data/labelled_plaintiffs.csv")
+ data = data.rename(columns={"filed_by": "text"})
+
+ dataset = rg.read_pandas(data, task="TextClassification")
+
+ rg.log(dataset, "plaintiff_sample")
+
+
+ # Minutes
+ dataset = Dataset.from_csv("data/minutes.csv").rename_column("description", "text")
+
+ nlp = spacy.load("en_core_web_trf")
+
+ def tokenize(row):
+     tokens = [token.text for token in nlp(row["text"])]
+     return {"tokens": tokens}
+
+ dataset = dataset.map(tokenize)
+
+ dataset = rg.read_datasets(dataset, task="TokenClassification")
+
+ rg.log(dataset, "minutes_sample")