harir commited on
Commit
3d3c8d1
1 Parent(s): 8cb156b

add models.py

Browse files
Files changed (1) hide show
  1. models.py +108 -0
models.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import parser
2
+ import requests
3
+
4
+ def zephyr_score(sentence):
5
+ prompt = f"""<|user|>
6
+ You are an assistant helping with paper reviews.
7
+ You will be tasked to classify sentences as 'J' or 'V'
8
+
9
+ 'J' is positive or 'J' is encouraging.
10
+ 'J' has a neutral tone or 'J' is professional.
11
+ 'V' is overly blunt or 'V' contains excessive negativity and no constructive feedback.
12
+ 'V' contains an accusatory tone or 'V' contains sweeping generalizations or 'V' contains personal attacks.
13
+
14
+ Text: "{sentence}"
15
+
16
+ Please classify this text as either 'J', 'W', or 'V'. Only output 'J', 'W', or 'V' with no additional explanation.<|endoftext|>
17
+ <|assistant|>
18
+ """
19
+ return prompt
20
+
21
+ def zephyr_revise(sentence):
22
+ prompt = f"""<|user|>
23
+ You are an assistant that helps users revise Paper Reviews.
24
+ Paper reviews exist to provide authors of academic research papers constructive critism.
25
+
26
+ This is text found in a review.
27
+ This text was classified as 'toxic':
28
+
29
+ Text: "{sentence}"
30
+
31
+ Please revise this text such that it maintains the criticism in the original text and delivers it in a friendly but professional manner. Make minimal changes to the original text.<|endoftext|>
32
+ <|assistant|>
33
+ """
34
+ return prompt
35
+
36
+ def query_model_score(sentence, api_key, model_id, prompt_fun):
37
+ API_URL = f"https://api-inference.huggingface.co/models/{model_id}"
38
+ headers = {"Authorization": f"Bearer {api_key}"}
39
+ prompt = prompt_fun(sentence)
40
+ def query(payload):
41
+ print(payload)
42
+ response = requests.post(API_URL, headers=headers, json=payload)
43
+ return response.json()
44
+ parameters = {"max_new_tokens" : 20, "temperature": 0.0, "return_full_text": False}
45
+ options = {"wait_for_model": True}
46
+ data = query({"inputs": f"{prompt}", "parameters": parameters, "options": options})
47
+ score = data[0]['generated_text']
48
+ if 'v' in score.lower():
49
+ return 1
50
+ else:
51
+ return 0
52
+
53
+ def query_model_revise(sentence, api_key, model_id, prompt_fun):
54
+ API_URL = f"https://api-inference.huggingface.co/models/{model_id}"
55
+ headers = {"Authorization": f"Bearer {api_key}"}
56
+ prompt = prompt_fun(sentence)
57
+ def query(payload):
58
+ response = requests.post(API_URL, headers=headers, json=payload)
59
+ return response.json()
60
+ parameters = {"max_new_tokens" : 200, "temperature": 0.0, "return_full_text": False}
61
+ options = {"wait_for_model": True}
62
+ data = query({"inputs": f"{prompt}", "parameters": parameters, "options": options})
63
+ revision = data[0]['generated_text']
64
+ return revision
65
+
66
+ def revise_review(review, api_key, model_id, highlight_color):
67
+ result = {
68
+ "success": False,
69
+ "data": {
70
+ "revision": "",
71
+ "score": "",
72
+ "sentence_count": "",
73
+ "revised_sentences": ""
74
+ },
75
+ "message": ""
76
+ }
77
+
78
+ try:
79
+ review = review.replace('"', "'")
80
+ sentences = parser.parse_sentences(review)
81
+ review_score = 0
82
+ revision_count = 0
83
+ review_revision = ""
84
+ for sentence in sentences:
85
+ if len(sentence) > 20:
86
+ score = query_model_score(sentence, api_key, model_id, zephyr_score)
87
+ if score == 0:
88
+ review_revision += " " + sentence
89
+ else:
90
+ review_score = 1
91
+ revision_count +=1
92
+ revision = query_model_revise(sentence, api_key, model_id, zephyr_revise)
93
+ revision = revision.strip().strip('"')
94
+ review_revision += f"<div style='background-color: {highlight_color}; display: inline;'>{revision}</div>"
95
+ else:
96
+ review_revision += " " + sentence
97
+ # end revision/prepare return json
98
+
99
+ result["success"] = True
100
+ result["message"] = "Review successfully revised!"
101
+ result["data"]["revision"] = review_revision
102
+ result["data"]["score"] = review_score
103
+ result["data"]["sentence_count"] = sum(1 for sentence in sentences if len(sentence) > 20)
104
+ result["data"]["revised_sentences"] = revision_count
105
+ except Exception as e:
106
+ result["message"] = str(e)
107
+
108
+ return result