# slackdemo / metadata_extracter.py
# Author: svummidi -- commit a31ba66 "POC for passive monitoring" (4.27 kB)
from enum import Enum
from typing import List, Optional
from llama_index import ServiceContext
from llama_index.llms import OpenAI
from llama_index.llms.base import LLM
from llama_index.llms.openai_utils import to_openai_function
from pydantic import BaseModel, ValidationError
class SentimentEnum(str, Enum):
    """
    Enum for predicted overall sentiment of a discussion thread
    """
    # Mixing in ``str`` makes every member compare equal to its plain string
    # value. The declaration order below is also the iteration order, so it
    # is kept as-is.
    POSITIVE = 'POSITIVE'
    NEGATIVE = 'NEGATIVE'
    NEUTRAL = 'NEUTRAL'
    MIXED = 'MIXED'
    # Fallback when the sentiment cannot be determined.
    UNKNOWN = 'UNKNOWN'
class DiscussionStatusEnum(str, Enum):
    """
    Enum for representing the predicted status of the discussion thread
    """
    # ``str`` mixin: members compare equal to their string values.
    # Declaration order is the iteration order, so keep it stable.
    ON_GOING = 'ON_GOING'
    RESOLVED_OR_CLOSED = 'RESOLVED_OR_CLOSED'
    # Fallback when the status cannot be determined.
    UNKNOWN = 'UNKNOWN'
class ThreadMetadata(BaseModel):
    """ Metadata of a discussion thread for topics and sentiment. Topics must be under 50 characters. """
    # NOTE(review): this model is passed to to_openai_function, so the field
    # names and types form the LLM function's parameter schema. The docstring
    # above presumably becomes the function description (pydantic puts the
    # class docstring into the schema "description"). Treat all of it as
    # prompt text -- TODO confirm against the llama_index version in use.
    list_of_positive_topics: List[str]
    list_of_negative_topics: List[str]
    # Explicit None defaults: under pydantic v2, an Optional[...] annotation
    # without a default is still *required*. A function call that omits these
    # fields would then fail validation and discard the whole result.
    overall_sentiment: Optional[SentimentEnum] = None
    discussion_status_enum: Optional[DiscussionStatusEnum] = None
class MetadataExtractor:
    """
    Extracts ThreadMetadata from a thread summary using LLM function calling.
    """

    def __init__(self, llm: LLM):
        """
        :param llm: model whose ``complete`` accepts a ``functions=`` list of
            OpenAI function specs (as the OpenAI wrapper in ``__main__`` does)
        """
        self.llm = llm

    def extract_metadata(self, thread_summary: str) -> Optional[ThreadMetadata]:
        """
        Extracts the metadata from the thread summary

        :param thread_summary: summary text of the thread
        :return: metadata of the thread, or None when the LLM reply carried no
            function call or its arguments did not validate as ThreadMetadata
        """
        api_spec = to_openai_function(ThreadMetadata)
        response = self.llm.complete(
            "Analyze the thread summary: " + thread_summary,
            functions=[api_spec],
        )
        arguments = self._function_call_arguments(response)
        if arguments is None:
            # The model may answer in plain text instead of calling the
            # function. Report it and return None rather than raising KeyError.
            print(f"No function call in the LLM response: {response}")
            return None
        try:
            return ThreadMetadata.parse_raw(arguments)
        except ValidationError:
            print(f"Error while parsing the detected thread metadata: {arguments}")
            return None

    @staticmethod
    def _function_call_arguments(response) -> Optional[str]:
        """Returns the raw JSON ``arguments`` of the response's function call, or None."""
        function_call = response.additional_kwargs.get("function_call")
        if function_call is None:
            return None
        # NOTE(review): older openai clients put a dict here; newer ones may
        # use an object with an ``arguments`` attribute. Accept either.
        if isinstance(function_call, dict):
            return function_call.get("arguments")
        return getattr(function_call, "arguments", None)
def predict_status(metadata: "Optional[ThreadMetadata]") -> str:
    """
    Maps extracted thread metadata to the value written to the status column.

    :param metadata: result of MetadataExtractor.extract_metadata (may be None)
    :return: the discussion status enum's value, or "UNKNOWN" when no
        metadata or no status was extracted
    """
    if metadata is None or metadata.discussion_status_enum is None:
        return "UNKNOWN"
    # DiscussionStatusEnum only defines the allowed values, so its value can
    # be written directly without a whitelist check.
    return metadata.discussion_status_enum.value


def main(
    input_csv: str = "csv/platform-engg.csv",
    output_csv: str = "csv/platform-engg-updated.csv",
    column_to_read: str = "Summary",
    new_column_header: str = "Predicted Status",
) -> None:
    """
    Adds a predicted discussion status column to a CSV of thread summaries.

    Reads ``column_to_read`` from every row of ``input_csv``. For each
    non-empty summary it asks the LLM for the thread metadata. It then writes
    all original columns plus ``new_column_header`` to ``output_csv``.
    Exits with status 1 if the input is empty or lacks ``column_to_read``.
    """
    import csv
    import sys

    gpt_turbo: OpenAI = OpenAI(temperature=0, model="gpt-3.5-turbo")
    metadata_extractor = MetadataExtractor(gpt_turbo)

    with open(input_csv, mode='r', newline='', encoding='utf-8') as infile:
        csvreader = csv.reader(infile)
        # next(..., None) avoids an opaque StopIteration on an empty file.
        headers = next(csvreader, None)
        if headers is None:
            sys.exit(f"Error: '{input_csv}' is empty.")
        if column_to_read not in headers:
            # sys.exit(msg) prints msg to stderr and exits with status 1.
            sys.exit(f"Error: Column '{column_to_read}' not found in the CSV file.")
        index_to_read = headers.index(column_to_read)
        rows = list(csvreader)

    width = len(headers)
    for row in rows:
        # csv.reader yields [] for blank lines and short lists for ragged rows.
        # Padding prevents an IndexError and keeps the new value aligned
        # under the new header.
        row.extend([""] * (width - len(row)))
        summary = row[index_to_read]
        # Skip the LLM round-trip when there is nothing to analyze.
        metadata = metadata_extractor.extract_metadata(summary) if summary.strip() else None
        row.append(predict_status(metadata))

    headers.append(new_column_header)

    with open(output_csv, mode='w', newline='', encoding='utf-8') as outfile:
        csvwriter = csv.writer(outfile)
        csvwriter.writerow(headers)
        csvwriter.writerows(rows)
    print(f"Successfully added a new column '{new_column_header}' to the '{output_csv}' file.")


if __name__ == "__main__":
    main()