import time
import logging

from llm.apimodels.gemini_model import Gemini
from llm.apimodels.hf_model import HF_Mistaril, HF_TinyLlama
from llm.llamacpp.lc_model import LC_Phi3, LC_TinyLlama

from typing import Optional

from langchain.chains.conversation.memory import ConversationBufferWindowMemory
from langchain.chains import ConversationChain

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# For all modules here, the log file template is "llm/logs/chelsea_{module_name}_{entity}.log"
file_handler = logging.FileHandler("logs/chelsea_llm_chat.log")

formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
file_handler.setFormatter(formatter)

logger.addHandler(file_handler)
logger.info("Getting information from chat module")

def prettify(raw_text: str) -> str:
    """Strips Markdown bold markers and surrounding whitespace from the model output."""
    pretty = raw_text.replace("**", "")
    return pretty.strip()


# Shared sliding-window memory: keeps the last 3 exchanges and answers as "Chelsea".
memory: ConversationBufferWindowMemory = ConversationBufferWindowMemory(k=3, ai_prefix="Chelsea")

DELAY: int = 300  # 5 minutes

def has_failed(conversation: ConversationChain, prompt: str) -> Optional[str]:
    """
    Runs the LLM conversation prediction and reports whether it failed.

    Args:
        conversation: The LLM conversation chain used for prediction.
        prompt: The prompt to be used for prediction.

    Returns:
        The prettified response, or None if the prediction raised an exception.
    """

    try:
        response = conversation.predict(input=prompt)
        print(f"response: {response}")
        return prettify(raw_text=response)
    except Exception as e:
        logger.error(msg="Error during prediction with conversation in has_failed function", exc_info=e)
        print(f"Error during prediction with conversation in has_failed function: {e}")
        return None


def has_delay(conversation: ConversationChain, prompt: str) -> Optional[str]:
    """
    Runs the LLM conversation prediction and checks whether it took longer than DELAY.

    The prediction itself is not interrupted; a late response is simply discarded.

    Args:
        conversation: The LLM conversation chain used for prediction.
        prompt: The prompt to be used for prediction.

    Returns:
        None if the execution time exceeds DELAY or the prediction fails,
        otherwise the prettified response from the conversation chain.
    """

    start_time = time.perf_counter()  # Start timer before prediction
    try:
        response = conversation.predict(input=prompt)
        execution_time = time.perf_counter() - start_time  # Calculate execution time

        if execution_time > DELAY:
            return None  # Too slow: discard the response so the caller can switch models

        return prettify(raw_text=response)  # Return the prettified response

    except Exception as e:
        logger.error(msg="Error during prediction with conversation in has_delay function", exc_info=e)
        print(f"Error during prediction with conversation in has_delay function: {e}")
        return None


class Conversation:
    def __init__(self, model_classes: Optional[list] = None):
        """
        Initializes the Conversation class with a list of LLM model classes.

        Args:
            model_classes (list, optional): A list of LLM model classes to try in sequence.
                Defaults to [Gemini, HF_Mistaril, HF_TinyLlama, LC_Phi3, LC_TinyLlama].
        """

        self.model_classes = model_classes or [Gemini, HF_Mistaril, HF_TinyLlama, LC_Phi3, LC_TinyLlama]
        self.current_model_index = 0

    def _get_conversation(self) -> Optional[ConversationChain]:
        """
        Creates a ConversationChain object using the current model class.

        Returns:
            The ConversationChain, or None if the chain could not be built.
        """
        try:
            current_model_class = self.model_classes[self.current_model_index]
            print("current model class is: ", current_model_class)
            return ConversationChain(llm=current_model_class().execution(), memory=memory, return_final_only=True)
        except Exception as e:
            logger.error(msg="Error during conversation chain in get_conversation function", exc_info=e)
            print(f"Error during conversation chain in get_conversation function: {e}")
            return None

    def chatting(self, prompt: str) -> str:
        """
        Carries out the conversation with the user, handling errors and delays.

        Args:
            prompt (str): The prompt to be used for prediction.

        Returns:
            str: The final conversation response, or a fallback message if all models fail.
        """

        if not prompt:
            raise ValueError(f"Prompt must be a non-empty string, got: {prompt!r}")

        while self.current_model_index < len(self.model_classes):
            conversation = self._get_conversation()
            if conversation is None:
                self.current_model_index += 1  # Chain could not be built, try the next model
                continue

            result = has_failed(conversation=conversation, prompt=prompt)
            if result is not None:
                return result
            print(f"chat - chatting result : {result}")

            result = has_delay(conversation=conversation, prompt=prompt)
            if result is None:
                self.current_model_index += 1  # Switch to the next model after a delay or failure
                continue

            return result

        return "All models failed conversation. Please, try again"
    
    def __str__(self) -> str:
        return f"{self.__class__.__name__} using {self.model_classes[self.current_model_index].__name__}"

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(model_classes={self.model_classes!r})"