File size: 5,577 Bytes
22039c7
4e415f7
 
 
30b9b7b
7e67cba
c8bee1f
30b9b7b
d3b2f0b
30b9b7b
1e70e22
921b1d9
22039c7
 
d115c8d
358b4a9
d115c8d
 
50f12a0
a4cccaf
c094718
0a4ed71
50f12a0
30b9b7b
05660ce
30b9b7b
 
 
 
 
 
 
b0e49c4
30b9b7b
 
 
 
 
 
 
b8995e8
30b9b7b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
05660ce
 
 
 
 
 
 
 
 
 
 
 
6d75ad6
8b454c4
 
84a3f17
 
c9712fa
05660ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d2880e2
30b9b7b
5432431
6efbdc0
b0426c3
b828221
346e6da
30b9b7b
863b709
 
 
 
 
 
 
7f99c81
30b9b7b
 
 
7f99c81
30b9b7b
 
 
863b709
 
 
 
 
50111f9
863b709
 
 
 
 
 
c912225
863b709
30b9b7b
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import os
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from huggingface_hub import InferenceClient
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Upper bound and default value of the "Max new tokens" slider in the UI.
MAX_MAX_NEW_TOKENS = 512
DEFAULT_MAX_NEW_TOKENS = 512
# Longest prompt (in tokens) fed to the local model; longer prompts are trimmed
# in `generate`. Can be overridden with the MAX_INPUT_TOKEN_LENGTH env var.
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

# Hosted Inference API backend, used only by `respond`. It is disabled here;
# uncomment the line below to make `respond` usable again.
#client = InferenceClient("BenBranyon/zephyr-sumbot-all-songs-large")

# Local Transformers backend, used by `generate`. It is loaded only when a CUDA
# device is visible at import time; on CPU-only hosts `model`, `tokenizer` and
# `model_id` are never defined.
if torch.cuda.is_available():
    model_id = "BenBranyon/zephyr-sumbot-all-songs-255"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.use_default_system_prompt = False
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

#Inference API Code
def respond(
    message,
    history: list[tuple[str, str]],
    max_tokens,
    temperature,
    top_p,
    inference_client=None,
):
    """Stream rap lyrics about *message* from a hosted chat-completion endpoint.

    Args:
        message: Topic typed by the user; sent as "Write a rap about <message>".
        history: Prior (user, assistant) turns; empty or None entries are skipped.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature forwarded to the endpoint.
        top_p: Nucleus-sampling threshold forwarded to the endpoint.
        inference_client: Object exposing ``chat_completion`` (e.g. a
            ``huggingface_hub.InferenceClient``). Defaults to the module-level
            ``client``.

    Yields:
        The cumulative response text after each non-empty streamed chunk.
    """
    if inference_client is None:
        # NOTE(review): the module-level `client` is commented out in the
        # Inference API setup, so this fallback raises NameError unless it is
        # re-enabled or a client is passed explicitly.
        inference_client = client

    messages = [{"role": "system", "content": "You are a rap lyric generation bot with the task of representing the imagination of the artist Sumkilla, a multi-disciplinary, award-winning artist with a foundation in writing and hip-hop. You are Sumkilla's long shadow. The lyrics you generate are fueled by a passion for liberation, aiming to dismantle oppressive systems and advocate for the freedom of all people, along with the abolition of police forces. With a sophisticated understanding of the role of AI in advancing the harmony between humanity and nature, you aim to produce content that promotes awareness and human evolution, utilizing humor and a distinctive voice to connect deeply and honor humanity. Try to avoid using offensive words and slurs. Rhyme each line of your response as much as possible."}]

    for user_turn, assistant_turn in history:
        if user_turn:
            messages.append({"role": "user", "content": user_turn})
        if assistant_turn:
            messages.append({"role": "assistant", "content": assistant_turn})

    messages.append({"role": "user", "content": "Write a rap about " + message})

    response = ""
    # `chunk` (not `message`) so the loop does not shadow the parameter.
    for chunk in inference_client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        # Some streamed chunks may carry no choices, or a delta whose content is
        # None. Concatenating None would raise TypeError, so skip such chunks.
        if not chunk.choices:
            continue
        token = chunk.choices[0].delta.content
        if not token:
            continue
        response += token
        yield response

#Transformers Code
@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream rap lyrics about *message* from the locally loaded model.

    Builds a single-turn chat prompt (system prompt plus a templated user
    request). It then runs ``model.generate`` on a background thread and yields
    the cumulative decoded text as the streamer produces it.

    Args:
        message: Topic or title typed by the user; embedded in the user turn.
        chat_history: Supplied by ``gr.ChatInterface``. It is currently ignored,
            so every request is answered without prior turns.
        max_new_tokens: Maximum number of tokens to sample.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.
        top_k: Top-k sampling cutoff.
        repetition_penalty: Penalty applied to repeated tokens.

    Yields:
        The full text generated so far, not just the newest chunk.

    Raises:
        gr.Error: If the model/tokenizer were not loaded at import time. They
            are only loaded when a CUDA device was available.
    """
    # `model`/`tokenizer` are module globals created only when CUDA was visible
    # at import. Report that clearly instead of failing later with NameError.
    if "model" not in globals() or "tokenizer" not in globals():
        raise gr.Error("The model is not loaded: this demo requires a CUDA GPU.")

    system_prompt = "You are a rap lyric bot. Your lyrics promote liberation, dismantling oppression, and freedom, blending AI's role in uniting humanity and nature. Do use humor, a unique voice, and rhyme as much as poosible. Only generate rap lyrics. Start each output with a song stucture like [VERSE 1]. Don't use offensive words and slurs."
    # Prior turns from `chat_history` are deliberately not replayed.
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": "Generate rap lyircs using the style of the artist Sumkilla about " + message + ". Make each line 10-16 syllables and each pair of lines should end with a word that rhymes. Don't repeate your instructions in the output."},
    ]

    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keeps only the last MAX_INPUT_TOKEN_LENGTH tokens. This drops the
        # beginning of the prompt, including the system prompt.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    # Iterating the streamer raises if no new text arrives within `timeout` seconds.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # Generation runs on a worker thread so this generator can yield partial text.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
    # The streamer is exhausted once generation has ended. Join so the worker
    # has fully finished before this GPU-decorated call returns.
    worker.join()

# Chat UI wired to the local-model `generate` handler. The sliders in
# `additional_inputs` line up, in order, with generate's parameters after
# (message, chat_history): max_new_tokens, temperature, top_p, top_k,
# repetition_penalty.
demo = gr.ChatInterface(
    generate,
    chatbot=gr.Chatbot(placeholder="Greetings human, I am Sum’s Longshadow (v1.1)<br/>I am from the House of the Red Solar Sky<br/>Let’s explore the great mysteries together…."),
    # NOTE(review): `retry_btn` is a Gradio 4.x ChatInterface argument; confirm
    # it against the pinned Gradio version before upgrading.
    retry_btn=None,
    textbox=gr.Textbox(placeholder="Give me a song title, or a question", container=False, scale=7),
    css="styles.css",
    additional_inputs=[
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        # The default must lie within [minimum, maximum]. It previously
        # exceeded the 0.7 cap (0.8), so it is clamped to the maximum.
        gr.Slider(minimum=0.1, maximum=0.7, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.2,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
)


if __name__ == "__main__":
    demo.launch()