Spaces:
Sleeping
Sleeping
import os | |
import json | |
import openai | |
from collections import defaultdict | |
# Module settings: the Azure OpenAI deployment targeted for chat completions,
# the JSON file holding the quality-marker records, and the fence string
# placed around user input in prompts.
deployment_name, data_file, delimiter = "gpt-35-turbo", "data.json", "####"
def get_completion_from_messages(messages, engine=deployment_name, temperature=0, max_tokens=500):
    """Send a chat conversation to the configured OpenAI endpoint and return the reply text.

    The ``openai`` client is (re)configured from the environment on every call,
    reading API_KEY, API_BASE, API_TYPE and API_VERSION in that order. A missing
    variable raises KeyError before any request is sent.

    Args:
        messages: chat messages passed straight to ``openai.ChatCompletion.create``.
        engine: deployment/engine name; defaults to ``deployment_name``.
        temperature: sampling temperature (0 by default).
        max_tokens: maximum number of tokens to generate (500 by default).

    Returns:
        The ``content`` field of the first choice's message.
    """
    # (client attribute, environment variable) pairs, applied in this order.
    settings = (
        ("api_key", "API_KEY"),
        ("api_base", "API_BASE"),
        ("api_type", "API_TYPE"),
        ("api_version", "API_VERSION"),
    )
    for attribute, env_var in settings:
        setattr(openai, attribute, os.environ[env_var])

    reply = openai.ChatCompletion.create(
        engine=engine,
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
    )
    first_choice = reply.choices[0]
    return first_choice.message["content"]
def get_qms_and_pages():
    """Group quality-marker names by the page they appear on.

    Returns:
        A plain ``dict`` mapping each page value found in the data to the list
        of quality-marker names on that page, in the data's iteration order.
        Markers whose ``page`` entry is missing or falsy are skipped.
    """
    grouped = defaultdict(list)
    for name, info in get_qms().items():
        page_key = info.get('page')
        if not page_key:
            continue
        grouped[page_key].append(name)
    return dict(grouped)
def get_qms(path=None):
    """Load every quality-marker record from the JSON data file.

    Args:
        path: file to read. It defaults to the module-level ``data_file``.
            Passing a path lets callers (and tests) read another file without
            touching the global.

    Returns:
        The decoded JSON document. The other helpers in this module treat it
        as a mapping of quality-marker name -> record dict.

    Raises:
        OSError (e.g. FileNotFoundError) if the file cannot be opened.
        json.JSONDecodeError if its contents are not valid JSON.
    """
    if path is None:
        path = data_file
    # JSON text is UTF-8 by specification (RFC 8259). Don't depend on the
    # platform's locale default encoding, which differs e.g. on Windows.
    with open(path, 'r', encoding='utf-8') as file:
        return json.load(file)
def find_pages_and_qms_only(user_input):
    """Ask the chat model which pages and/or quality markers a query refers to.

    The system prompt instructs the model to reply with only a Python-style
    list of dicts. Each dict carries a 'page' and a 'quality_marker' list,
    restricted to the allowed markers per page enumerated in the prompt.

    Args:
        user_input: the raw customer query. It is wrapped in ``####`` fences
            and sent as the user turn.

    Returns:
        The model's reply text, unparsed, as returned by
        ``get_completion_from_messages``.
    """
    # Fence around the user's text; the system prompt tells the model about it.
    delimiter = "####"
    # Few-shot examples, interpolated at the end of the system prompt.
    example_string = """
Customer query: What topics are related to quality markers Q10c and Q5b? \
Your response: [{'page': '3', 'quality_marker': ['Q5b']}, {'page': '5', 'quality_marker': ['Q10c']}] \
Customer query: What topics are related to page 2? \
Your response: [{'page': '2', 'quality_marker': ['Q2a', 'Q2b', 'Q2c', 'Q2d']}]
"""
    system_message = f"""
You will be provided with a customer query. \
The customer query will be delimited with {delimiter} characters. \
The customer is expected to query about some details related to pages or quality markers in Behaviour Support Plan (PBS) summary documents. \
For example, the customer may want to know which domains, models, topics, or descriptions are associated with a given page or quality marker. \
Your task is to output a python list of objects (dicts), where each object has the following keys: \
- 'page' \
- 'quality_marker' \
Rules: \
- 'page' value must be one of: 1, 2, 3, 4, 5. \
- 'quality_marker' value is a list of allowed quality markers \
- If a specific page number is mentioned, it must be associated with all the correct quality markers in the allowed quality markers list below. \
- If specific quality marker(s) are mentioned, the 'quality_marker' list in your response must only include the mentioned quality marker(s). \
- If no pages or quality markers are found, output an empty list. \
- Only output the list of objects (dicts), nothing else.
Allowed quality markers:
For page 1:
Q1
For page 2:
Q2a
Q2b
Q2c
Q2d
For page 3:
Q3a
Q3b
Q3c
Q3d
Q3e
Q4a
Q4b
Q5b
For page 4:
Q6b
Q6c
Q7i
Q7ii
Q7iii
Q7iv
Q7a
Q7b
Q8a
Q8b
Q8c
Q8d
Q8e
Q8f
For page 5:
Q9a
Q9b
Q9c
Q10a
Q10b
Q10c
Q10d
Q11a
Q11b
Q11c
Examples:
{example_string}
"""
    user_turn = "{0}{1}{0}".format(delimiter, user_input)
    return get_completion_from_messages([
        dict(role='system', content=system_message),
        dict(role='user', content=user_turn),
    ])
# Quality-marker (QM) lookup helpers: fetch one marker by name, or all markers on a page.
def get_qm_by_name(name):
    """Return the record stored under quality-marker *name*, or None if absent.

    The data file is re-read on every call via ``get_qms()``.
    """
    return get_qms().get(name)
def _page_key(value):
    """Normalize a page identifier for comparison: None stays None, else stripped text."""
    return None if value is None else str(value).strip()


def get_qms_by_page(page, qms=None):
    """Return every quality-marker record whose ``page`` matches *page*.

    Page values are compared as text, so ``2`` and ``"2"`` select the same
    page. The extraction prompt describes page values as the integers 1-5,
    while its own few-shot examples emit strings (``'page': '3'``). A model
    reply may therefore use either type, and a strict ``==`` would silently
    return an empty list on a type mismatch.

    Args:
        page: page identifier (int or str). None matches records without a
            ``page`` entry.
        qms: optional already-loaded mapping of marker name -> record. When
            omitted, the data file is read via ``get_qms()``. Pass it to avoid
            re-reading the file on repeated lookups.

    Returns:
        List of matching record dicts, in the mapping's iteration order.
    """
    if qms is None:
        qms = get_qms()
    wanted = _page_key(page)
    return [qm for qm in qms.values() if _page_key(qm.get('page')) == wanted]
def read_string_to_list(input_string):
    """Parse the model's list-of-dicts reply into Python objects.

    The extraction prompt asks for a *Python* list, so replies usually use
    single quotes (``[{'page': '2', ...}]``), which is not valid JSON.
    Blindly swapping every ``'`` for ``"`` breaks any value containing an
    apostrophe. Parsing is therefore attempted in this order:

    1. ``json.loads`` on the text as-is. This handles genuine JSON, including
       ``true``/``null`` and strings containing apostrophes.
    2. ``ast.literal_eval``. This safely parses Python literals (no code is
       executed), including values such as ``"don't"``.
    3. The legacy fallback: replace ``'`` with ``"`` and retry ``json.loads``.

    Args:
        input_string: text to parse, or None.

    Returns:
        The parsed object (normally a list of dicts). Returns None when the
        input is None or no strategy can parse it; in the latter case an error
        is printed.
    """
    import ast  # local import: only this helper needs it

    if input_string is None:
        return None
    text = input_string.strip()
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass
    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        pass
    try:
        # Last resort (the original strategy): assume single quotes are the
        # only obstacle to valid JSON, e.g. "{'flag': true}".
        return json.loads(text.replace("'", "\""))
    except json.JSONDecodeError:
        print("Error: Invalid JSON string")
        return None
def generate_output_string(data_list):
    """Render the quality-marker records selected by a parsed model reply.

    Each entry of *data_list* should be a dict with one of two keys:

    - 'quality_marker': resolved marker by marker via ``get_qm_by_name``.
    - 'page' (checked only when 'quality_marker' is absent): expands to every
      record on that page via ``get_qms_by_page``.

    Every record found is emitted as JSON indented by 4 spaces, followed by a
    newline. Problems are reported with ``print`` and the offending entry (or
    marker) is skipped.

    Args:
        data_list: iterable of dicts, or None.

    Returns:
        The concatenated JSON text. It is "" when *data_list* is None or empty,
        or when nothing matched.
    """
    if data_list is None:
        return ""
    chunks = []  # joined once at the end instead of repeated string +=
    for data in data_list:
        if not isinstance(data, dict):
            print("Error: Invalid object format")
            continue
        try:
            if "quality_marker" in data:
                qm_list = data["quality_marker"]
                # A bare string such as "Q5b" would otherwise be iterated one
                # character at a time ("Q", "5", "b").
                if isinstance(qm_list, str):
                    qm_list = [qm_list]
                for qm_name in qm_list:
                    qm = get_qm_by_name(qm_name)
                    if qm:
                        chunks.append(json.dumps(qm, indent=4) + "\n")
                    else:
                        print(f"Error: Quality marker '{qm_name}' not found")
            elif "page" in data:
                for qm in get_qms_by_page(data["page"]):
                    chunks.append(json.dumps(qm, indent=4) + "\n")
            else:
                print("Error: Invalid object format")
        except Exception as e:
            # Best-effort rendering: report the problem and continue with the
            # next entry (records already rendered are kept).
            print(f"Error: {e}")
    return "".join(chunks)