|
from enum import Enum |
|
from typing import List, Optional |
|
|
|
from llama_index import ServiceContext |
|
from llama_index.llms import OpenAI |
|
from llama_index.llms.base import LLM |
|
from llama_index.llms.openai_utils import to_openai_function |
|
from pydantic import BaseModel, ValidationError |
|
|
|
|
|
class SentimentEnum(str, Enum):

    """

    Enum for predicted overall sentiment of a discussion thread

    """

    # str-valued members so they serialize cleanly in pydantic/JSON schemas.
    # UNKNOWN is the fallback when the model cannot determine a sentiment.

    POSITIVE = "POSITIVE"

    NEGATIVE = "NEGATIVE"

    NEUTRAL = "NEUTRAL"

    MIXED = "MIXED"

    UNKNOWN = "UNKNOWN"
|
|
|
|
|
class DiscussionStatusEnum(str, Enum):

    """

    Enum for representing the predicted status of the discussion thread

    """

    # str-valued members so they serialize cleanly in pydantic/JSON schemas.
    # UNKNOWN is the fallback when the model cannot determine a status.

    ON_GOING = "ON_GOING"

    RESOLVED_OR_CLOSED = "RESOLVED_OR_CLOSED"

    UNKNOWN = "UNKNOWN"
|
|
|
|
|
class ThreadMetadata(BaseModel):

    # NOTE: this docstring is runtime data — to_openai_function() sends it to the
    # model as the function description, so keep its wording in sync with the prompt.
    """ Metadata of a discussion thread for topics and sentiment. Topics must be under 50 characters. """

    # Topics discussed with a positive tone (each expected under 50 characters).
    list_of_positive_topics: List[str]

    # Topics discussed with a negative tone (each expected under 50 characters).
    list_of_negative_topics: List[str]

    # Overall thread sentiment; None when the model omits it.
    overall_sentiment: Optional[SentimentEnum]

    # Predicted thread status; None when the model omits it.
    discussion_status_enum: Optional[DiscussionStatusEnum]
|
|
|
|
|
class MetadataExtractor:
    """Extracts structured ThreadMetadata from a thread summary via OpenAI function calling."""

    def __init__(self, llm: LLM):
        # LLM used for completions; expected to support the OpenAI `functions` kwarg.
        self.llm = llm

    def extract_metadata(self, thread_summary: str) -> Optional[ThreadMetadata]:
        """
        Extracts the metadata from the thread summary.

        :param thread_summary: summary text of the discussion thread
        :return: parsed ThreadMetadata, or None when the model did not invoke
            the function or its arguments fail validation
        """
        api_spec = to_openai_function(ThreadMetadata)
        response = self.llm.complete(
            "Analyze the thread summary: " + thread_summary,
            functions=[api_spec],
        )

        # The model is not guaranteed to invoke the function; guard against a
        # missing "function_call" entry instead of raising KeyError.
        function_call = response.additional_kwargs.get("function_call")
        if not function_call or "arguments" not in function_call:
            print(f"No function call returned for summary: {thread_summary!r}")
            return None

        function_call_resp = function_call["arguments"]
        try:
            return ThreadMetadata.parse_raw(function_call_resp)
        except ValidationError:
            print(f"Error while parsing the detected question metadata: {function_call_resp}")
            return None
|
|
|
|
|
if __name__ == "__main__":
    import csv

    # Deterministic extraction: temperature 0, function-calling-capable model.
    gpt_turbo: OpenAI = OpenAI(temperature=0, model="gpt-3.5-turbo")
    metadata_extractor = MetadataExtractor(gpt_turbo)

    input_csv = "csv/platform-engg.csv"
    output_csv = "csv/platform-engg-updated.csv"

    column_to_read = "Summary"
    new_column_header = "Predicted Status"

    headers = []
    rows = []

    with open(input_csv, mode='r', newline='', encoding='utf-8') as infile:
        csvreader = csv.reader(infile)
        headers = next(csvreader)

        if column_to_read not in headers:
            print(f"Error: Column '{column_to_read}' not found in the CSV file.")
            exit(1)

        index_to_read = headers.index(column_to_read)

        # Materialize all data rows while the file is still open.
        rows = list(csvreader)

    headers.append(new_column_header)

    for row in rows:
        summary = row[index_to_read]
        metadata = metadata_extractor.extract_metadata(summary)
        # Default to UNKNOWN when extraction fails or no status was predicted.
        new_value = "UNKNOWN"
        if metadata is not None and metadata.discussion_status_enum is not None:
            # Every DiscussionStatusEnum member is a valid column value, so no
            # additional whitelist check is needed here.
            new_value = metadata.discussion_status_enum.value
        row.append(new_value)

    with open(output_csv, mode='w', newline='', encoding='utf-8') as outfile:
        csvwriter = csv.writer(outfile)
        csvwriter.writerow(headers)
        csvwriter.writerows(rows)

    print(f"Successfully added a new column '{new_column_header}' to the '{output_csv}' file.")
|
|