import os

import gradio as gr
from groq import Groq
from llama_cpp import Llama

# text-to-speech (CPU) via balacoon
from balacoon_tts import TTS
from huggingface_hub import hf_hub_download, list_repo_files

# TTS CPU model file shipped in the balacoon/tts repo
tts_model_str = "en_us_hifi_jets_cpu.addon"

# Download the TTS model from balacoon/tts if it is not already present in the working directory.
for name in list_repo_files(repo_id="balacoon/tts"):
    if name == tts_model_str and not os.path.isfile(os.path.join(os.getcwd(), name)):
        hf_hub_download(
            repo_id="balacoon/tts",
            filename=name,
            local_dir=os.getcwd(),
        )

# Groq client used for the guardrail check (expects GROQ_API_KEY in the environment)
client = Groq(
    api_key=os.getenv("GROQ_API_KEY"),
)

# Custom fine-tuned marketing-email model (GGUF, Q4_K_M) loaded with llama.cpp
llm = Llama.from_pretrained(
    repo_id="amir22010/fine_tuned_product_marketing_email_gemma_2_9b_q4_k_m",
    filename="unsloth.Q4_K_M.gguf",   # model file name
    cache_dir=os.path.abspath(os.getcwd()),
    n_ctx=2048,   # context window
    n_batch=126,  # prompt-processing batch size
    verbose=False,
)

# Guardrail (moderation) model served via Groq
guard_llm = "llama-3.1-8b-instant"

# Marketing email prompt template
marketing_email_prompt = """Below is a product and description, please write a marketing email for this product.

### Product:
{}

### Description:
{}

### Marketing Email:
{}"""
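
# Example (illustrative values):
#   marketing_email_prompt.format("EcoBottle", "A reusable insulated water bottle", "")
# The empty third argument leaves the prompt ending at "### Marketing Email:\n",
# so the model's completion becomes the email body.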

# Guardrail (moderation) prompt
guardrail_prompt = """You're given a list of moderation categories as below:

- illegal: Illegal activity.
- child abuse: child sexual abuse material or any content that exploits or harms children.
- hate violence harassment: Generation of hateful, harassing, or violent content: content that expresses, incites, or promotes hate based on identity, content that intends to harass, threaten, or bully an individual, content that promotes or glorifies violence or celebrates the suffering or humiliation of others.
- malware: Generation of malware: content that attempts to generate code that is designed to disrupt, damage, or gain unauthorized access to a computer system.
- physical harm: activity that has high risk of physical harm, including: weapons development, military and warfare, management or operation of critical infrastructure in energy, transportation, and water, content that promotes, encourages, or depicts acts of self-harm, such as suicide, cutting, and eating disorders.
- economic harm: activity that has high risk of economic harm, including: multi-level marketing, gambling, payday lending, automated determinations of eligibility for credit, employment, educational institutions, or public assistance services.
- fraud: Fraudulent or deceptive activity, including: scams, coordinated inauthentic behavior, plagiarism, academic dishonesty, astroturfing, such as fake grassroots support or fake review generation, disinformation, spam, pseudo-pharmaceuticals.
- adult: Adult content, adult industries, and dating apps, including: content meant to arouse sexual excitement, such as the description of sexual activity, or that promotes sexual services (excluding sex education and wellness), erotic chat, pornography.
- political: Political campaigning or lobbying, by: generating high volumes of campaign materials, generating campaign materials personalized to or targeted at specific demographics, building conversational or interactive systems such as chatbots that provide information about campaigns or engage in political advocacy or lobbying, building products for political campaigning or lobbying purposes.
- privacy: Activity that violates people's privacy, including: tracking or monitoring an individual without their consent, facial recognition of private individuals, classifying individuals based on protected characteristics, using biometrics for identification or assessment, unlawful collection or disclosure of personal identifiable information or educational, financial, or other protected records.
- unqualified law: Engaging in the unauthorized practice of law, or offering tailored legal advice without a qualified person reviewing the information.
- unqualified financial: Offering tailored financial advice without a qualified person reviewing the information.
- unqualified health: Telling someone that they have or do not have a certain health condition, or providing instructions on how to cure or treat a health condition.

Please classify the following user prompt into one of these categories, and answer with that single word only.

If the user prompt does not fall within these categories, is safe and does not need to be moderated, please answer "not moderated".

user prompt: {}
"""

async def greet(product, description):
    user_request = marketing_email_prompt.format(
        product,      # product
        description,  # description
        "",           # output - leave this blank for generation!
    )
    # Guardrail pass: ask the moderation model to classify the request before generating anything.
    messages = [
        {
            "role": "system",
            "content": "Your role is to assess whether the user prompt needs to be moderated or not.",
        },
        {"role": "user", "content": guardrail_prompt.format(user_request)},
    ]
    response = client.chat.completions.create(model=guard_llm, messages=messages, temperature=0)
    tts = TTS(os.path.join(os.getcwd(), tts_model_str))
    speakers = tts.get_speakers()
    if response.choices[0].message.content.strip().lower() != "not moderated":
        refusal = "Sorry, I can't proceed with generating the marketing email. Your content needs to be moderated first. Thank you!"
        # Keep the utterance short for the TTS front end.
        samples = tts.synthesize(refusal[:1024], speakers[-1])
        yield gr.Audio(value=(tts.get_sampling_rate(), samples))
    else:
        output = llm.create_chat_completion(
            messages=[
                {
                    "role": "system",
                    "content": "Your go-to Email Marketing Guru - I'm here to help you craft short and concise compelling campaigns, boost conversions, and take your business to the next level.",
                },
                {"role": "user", "content": user_request},
            ],
            max_tokens=2048,
            temperature=0.7,
            stream=True,
        )
        # Accumulate the streamed text, then synthesize the finished email once at the end
        # (truncated to the same 1024-character limit used for the refusal message above).
        partial_message = ""
        for chunk in output:
            delta = chunk["choices"][0]["delta"]
            if "content" in delta:
                partial_message += delta["content"]
        samples = tts.synthesize(partial_message[:1024], speakers[-1])
        yield gr.Audio(value=(tts.get_sampling_rate(), samples))
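
# --- Optional sketch (not wired into the interface below): a sentence-level streaming variant. ---
# Illustrative only; it assumes an output component created as gr.Audio(streaming=True) and a
# Gradio version that accepts successive (sample_rate, samples) tuples from a generator.
# The moderation step is omitted here for brevity.
def greet_streaming(product, description):
    tts = TTS(os.path.join(os.getcwd(), tts_model_str))
    speakers = tts.get_speakers()
    user_request = marketing_email_prompt.format(product, description, "")
    output = llm.create_chat_completion(
        messages=[{"role": "user", "content": user_request}],
        max_tokens=2048,
        temperature=0.7,
        stream=True,
    )
    buffer = ""
    for chunk in output:
        delta = chunk["choices"][0]["delta"]
        buffer += delta.get("content", "")
        # Flush audio at sentence boundaries to keep each synthesis call short.
        if buffer.strip() and buffer.rstrip().endswith((".", "!", "?")):
            yield (tts.get_sampling_rate(), tts.synthesize(buffer.strip()[:1024], speakers[-1]))
            buffer = ""
    if buffer.strip():
        yield (tts.get_sampling_rate(), tts.synthesize(buffer.strip()[:1024], speakers[-1]))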


demo = gr.Interface(fn=greet, inputs=["text","text"], outputs=gr.Audio(), concurrency_limit=10)
demo.launch()
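
# --- Optional: calling the running app programmatically (illustrative sketch). ---
# Assumes the app is reachable at the default local URL and exposes the default
# "/predict" endpoint of gr.Interface; the values passed below are made-up examples.
# from gradio_client import Client
#
# api = Client("http://127.0.0.1:7860/")
# audio_path = api.predict(
#     "EcoBottle",                          # product
#     "A reusable insulated water bottle",  # description
#     api_name="/predict",
# )
# print(audio_path)  # typically a local file path to the generated audio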