File size: 3,150 Bytes
aa4f694
8537019
9c0dccd
 
3e68ccf
 
 
 
724babe
 
 
9c0dccd
aa4f694
9c0dccd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e68ccf
 
9c0dccd
724babe
 
 
9c0dccd
 
724babe
 
6d7d653
 
 
 
9c0dccd
6d7d653
 
 
724babe
 
3e68ccf
 
 
 
9c0dccd
 
3e68ccf
 
6d7d653
e55d16a
 
3e68ccf
724babe
9c0dccd
 
724babe
 
 
 
 
 
 
 
9c0dccd
724babe
 
 
 
 
 
 
 
 
 
 
 
 
9c0dccd
469fc38
724babe
3e68ccf
 
8537019
 
 
 
 
3e68ccf
8537019
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import logging
import requests
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
from langchain_core.language_models import LLM

from global_config import GlobalConfig


# Hugging Face Inference API URL for the model named in GlobalConfig.HF_LLM_MODEL_NAME.
HF_API_URL = (
    "https://api-inference.huggingface.co/models/"
    f"{GlobalConfig.HF_LLM_MODEL_NAME}"
)
# Bearer-token header attached to every request that hf_api_query sends to HF_API_URL.
HF_API_HEADERS = {
    "Authorization": f"Bearer {GlobalConfig.HUGGINGFACEHUB_API_TOKEN}",
}

logger = logging.getLogger(__name__)


def get_hf_endpoint() -> LLM:
    """
    Build a LangChain ``HuggingFaceEndpoint`` for the configured HF model.

    The model id, API token, output length and temperature are read from
    ``GlobalConfig``; the other generation settings are fixed here. Streaming
    is enabled, ``return_full_text`` is off, and ``</s>`` is passed as a stop
    sequence.

    :return: The configured LLM instance.
    """

    logger.debug('Getting LLM via HF endpoint')

    endpoint_kwargs = {
        'repo_id': GlobalConfig.HF_LLM_MODEL_NAME,
        'huggingfacehub_api_token': GlobalConfig.HUGGINGFACEHUB_API_TOKEN,
        # Project-configurable generation settings.
        'max_new_tokens': GlobalConfig.LLM_MODEL_MAX_OUTPUT_LENGTH,
        'temperature': GlobalConfig.LLM_MODEL_TEMPERATURE,
        # Fixed sampling / decoding settings.
        'top_k': 40,
        'top_p': 0.95,
        'repetition_penalty': 1.03,
        'streaming': True,
        'return_full_text': False,
        'stop_sequences': ['</s>'],
    }
    return HuggingFaceEndpoint(**endpoint_kwargs)


def hf_api_query(payload: dict, timeout: float = 15) -> dict:
    """
    Invoke HF inference end-point API.

    POSTs ``payload`` as JSON to ``HF_API_URL`` with ``HF_API_HEADERS`` and
    returns the decoded JSON body. The call is best-effort: if the request
    times out, or the response body is not valid JSON, the problem is logged
    and an empty dict is returned instead of raising. A non-2xx HTTP status is
    logged, but its (JSON) body is still returned unchanged.

    :param payload: The prompt for the LLM and related parameters.
    :param timeout: Seconds to wait for the HTTP request before giving up.
    :return: The decoded JSON response, or ``{}`` on timeout / invalid JSON.
        NOTE(review): callers index the result as ``result[0]``, so a
        successful text-generation reply is presumably a list despite the
        ``dict`` annotation — confirm against the model's task.
    """

    try:
        response = requests.post(
            HF_API_URL, headers=HF_API_HEADERS, json=payload, timeout=timeout
        )
    except requests.exceptions.Timeout as te:
        logger.error('*** Error: hf_api_query timeout! %s', str(te))
        return {}

    if not response.ok:
        # Make server-side failures visible in the logs; the body is still
        # decoded and returned below so callers see the same value as before.
        logger.error(
            'hf_api_query: HTTP %s from HF API: %s',
            response.status_code,
            response.text[:500],
        )

    try:
        return response.json()
    except ValueError as ve:
        # requests' JSON decode error subclasses ValueError in all versions.
        # A non-JSON body (e.g. an HTML error page) used to escape uncaught
        # here, bypassing the best-effort "log and return {}" contract above.
        logger.error('hf_api_query: response is not valid JSON: %s', str(ve))
        return {}


def generate_slides_content(topic: str) -> str:
    """
    Generate the outline/contents of slides for a presentation on a given topic.

    Reads the prompt template from ``GlobalConfig.SLIDES_TEMPLATE_FILE``,
    substitutes ``topic`` for ``<REPLACE_PLACEHOLDER>``, sends it to the HF
    Inference API via ``hf_api_query``, and cuts the generated text at its
    last triple-backtick fence.

    :param topic: Topic on which slides are to be generated.
    :return: The content in JSON format.
    :raises KeyError, IndexError, TypeError: If the API response is not shaped
        like ``[{'generated_text': ...}, ...]`` — e.g. the ``{}`` that
        ``hf_api_query`` returns on failure, or an error dict from the API.
        The offending response is logged before the exception propagates.
    """

    with open(GlobalConfig.SLIDES_TEMPLATE_FILE, 'r', encoding='utf-8') as in_file:
        template_txt = in_file.read().strip()
    template_txt = template_txt.replace('<REPLACE_PLACEHOLDER>', topic)

    output = hf_api_query({
        'inputs': template_txt,
        'parameters': {
            'temperature': GlobalConfig.LLM_MODEL_TEMPERATURE,
            'min_length': GlobalConfig.LLM_MODEL_MIN_OUTPUT_LENGTH,
            'max_length': GlobalConfig.LLM_MODEL_MAX_OUTPUT_LENGTH,
            'max_new_tokens': GlobalConfig.LLM_MODEL_MAX_OUTPUT_LENGTH,
            'num_return_sequences': 1,
            'return_full_text': False,
        },
        'options': {
            'wait_for_model': True,
            'use_cache': True
        }
    })

    try:
        generated = output[0]['generated_text']
    except (KeyError, IndexError, TypeError):
        # Without this, an error payload surfaced only as an opaque
        # "KeyError: 0"; log what the API actually sent, then re-raise the
        # same exception so callers' handling is unchanged.
        logger.error('generate_slides_content: unexpected HF API response: %s', output)
        raise

    output = generated.strip()

    # Drop everything from the last ``` onward (the closing code fence and
    # any trailing text after it).
    json_end_idx = output.rfind('```')
    if json_end_idx != -1:
        output = output[:json_end_idx]

    logger.debug('generate_slides_content: output: %s', output)

    return output


if __name__ == '__main__':
    # This module has no standalone entry point; running it directly does
    # nothing. (Stale commented-out calls to functions not defined in this
    # module were removed.)
    pass