File size: 1,111 Bytes
0c09017
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import csv
from typing import List, Optional

import requests


def main(csv_path: str, target_col: int = 0, source_col: int = 1,
         output_path: str = './tests/output.csv'):
    """Score text pairs from a CSV file and write the results to another CSV.

    Reads ``csv_path``, takes column ``target_col`` and column ``source_col``
    of every non-blank row, sends both lists to ``get_similarity`` and writes
    one ``[target_text, source_text, similarity]`` row per pair to
    ``output_path``.

    Args:
        csv_path: Path of the input CSV file (read as UTF-8).
        target_col: Zero-based index of the target-text column.
        source_col: Zero-based index of the source-text column.
        output_path: Path of the output CSV file; it is overwritten.

    Raises:
        IndexError: If a non-blank row has fewer columns than required.
        ValueError: If the number of similarities returned does not match
            the number of text pairs sent.
    """
    target_texts = []
    source_texts = []
    with open(csv_path, 'r', newline='', encoding='utf-8') as csvfile:
        for row in csv.reader(csvfile):
            # csv.reader yields [] for blank lines; indexing those would raise.
            if not row:
                continue
            target_texts.append(row[target_col])
            source_texts.append(row[source_col])

    similarities = get_similarity(target_texts, source_texts)
    # zip() below would silently drop unmatched pairs, so verify the counts.
    if len(similarities) != len(target_texts):
        raise ValueError(
            f"expected {len(target_texts)} similarities, "
            f"got {len(similarities)}"
        )

    with open(output_path, mode="w", newline="", encoding="utf-8") as new_file:
        writer = csv.writer(new_file)
        for target, source, similarity in zip(target_texts, source_texts,
                                              similarities):
            writer.writerow([target, source, similarity])


def get_similarity(texts1: List[str], texts2: List[str],
                   url: str = "http://localhost:8000/api/similarity",
                   timeout: Optional[float] = None):
    """Ask the similarity service to score pairs of texts.

    POSTs ``{"texts1": texts1, "texts2": texts2}`` as JSON to ``url`` and
    returns the ``'similarity'`` field of each item in the JSON response
    body, in response order.

    Args:
        texts1: First text of each pair.
        texts2: Second text of each pair.
        url: Endpoint of the similarity service.
        timeout: Seconds to wait for the server, passed to ``requests.post``;
            ``None`` (the default) waits indefinitely.

    Returns:
        A list with one ``'similarity'`` value per item of the response body.

    Raises:
        requests.HTTPError: If the service answers with a 4xx/5xx status.
    """
    response = requests.post(url, json={
        "texts1": texts1,
        "texts2": texts2,
    }, timeout=timeout)
    # Surface the HTTP failure itself rather than a confusing KeyError or
    # TypeError when an error body is not a list of {'similarity': ...} items.
    response.raise_for_status()

    return [item['similarity'] for item in response.json()]



if __name__ == "__main__":
    # Script entry point: score the bundled test input using the default
    # target/source column indices.
    main(csv_path='./tests/input.csv')