File size: 5,296 Bytes
8824f88 46218ed 05c0b89 8824f88 0adfa8b ba0ce5d 8824f88 c4bb11b 8824f88 05c0b89 ba0ce5d 8824f88 c4bb11b 8824f88 46218ed 05c0b89 8824f88 05c0b89 8824f88 05c0b89 8824f88 05c0b89 8824f88 ba0ce5d 8824f88 17aeee6 d355157 e466982 c4bb11b 8824f88 c4bb11b 8824f88 b1038ed 8824f88 2c31d1f 8824f88 ba0ce5d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
#!/usr/bin/env python
import os
import random
from threading import Thread
from typing import Iterator
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from chat_interface_preference import ChatInterface
# Limits surfaced through the "Max new tokens" slider further down.
MAX_MAX_NEW_TOKENS = 2048  # slider upper bound
DEFAULT_MAX_NEW_TOKENS = 1024  # slider initial value

# Longest prompt (in tokens) handed to the model; can be overridden through the
# MAX_INPUT_TOKEN_LENGTH environment variable.
MAX_INPUT_TOKEN_LENGTH = int(os.environ.get("MAX_INPUT_TOKEN_LENGTH", "8192"))

# The model is only loaded when a CUDA device is visible. Without one,
# ``model_id``, ``model`` and ``tokenizer`` are never defined, so any code that
# uses them will raise NameError.
if torch.cuda.is_available():
    model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
    # fp16 weights with automatic device placement.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.float16,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
# Style words for the randomized system prompt ("Communicate <style>.").
_SYSTEM_STYLES = ("concise", "explicit", "simple", "complex", "useful", "helpful")


def _tokenize_conversation(
    message: str,
    chat_history: list[tuple[str, str]],
    system_message: str,
):
    """Build the chat messages for one request and tokenize them with the chat template.

    The message list is: a system prompt ``"Communicate <system_message>."``,
    then each earlier ``(user, assistant)`` turn in order, then the new user
    ``message``. Returns the tensor produced by ``tokenizer.apply_chat_template``
    with the generation prompt appended; callers read its token length from
    ``shape[1]``.
    """
    conversation = [{"role": "system", "content": f"Communicate {system_message}."}]
    for user, assistant in chat_history:
        conversation.append({"role": "user", "content": user})
        conversation.append({"role": "assistant", "content": assistant})
    conversation.append({"role": "user", "content": message})
    return tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")


@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    temperature: float = 0.06,
    top_p: float = 0.95,
    top_k: int = 40,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream a sampled model reply to ``message`` given the earlier chat turns.

    A random style word from ``_SYSTEM_STYLES`` goes into the system prompt on
    every call. Presumably this makes repeated generations for the same prompt
    differ, so the preference UI gets contrasting candidates (confirm).

    If the templated prompt is longer than ``MAX_INPUT_TOKEN_LENGTH`` tokens, the
    oldest turns are dropped first. This keeps the system prompt and the chat
    template header intact. Only if the prompt is still too long with no history
    left is it left-truncated at the token level. A ``gr.Warning`` is shown
    whenever anything is trimmed.

    Args:
        message: The new user message.
        chat_history: Earlier ``(user, assistant)`` turns, oldest first.
        max_new_tokens, temperature, top_p, top_k, repetition_penalty:
            Sampling settings forwarded to ``model.generate``.

    Yields:
        The reply generated so far (cumulative text, not deltas), updated as
        each chunk arrives from the streamer.
    """
    system_message = random.choice(_SYSTEM_STYLES)

    turns = list(chat_history)
    dropped = 0
    input_ids = _tokenize_conversation(message, turns, system_message)
    # Drop whole turns (oldest first) instead of slicing tokens, so the prompt
    # keeps a well-formed chat-template prefix.
    while input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH and dropped < len(turns):
        dropped += 1
        input_ids = _tokenize_conversation(message, turns[dropped:], system_message)
    trimmed = dropped > 0
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Even the system prompt plus the new message is too long: keep the tail.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        trimmed = True
    if trimmed:
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    # Iteration over the streamer gives up if no new text arrives within 10 s.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # model.generate blocks, so it runs in a worker thread while this generator
    # consumes the streamer.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
    # The streamer is exhausted once generation has ended, so this returns
    # promptly; it just reaps the worker thread.
    thread.join()
# Preference-collecting chat UI wired to ``generate``. The additional_inputs
# sliders are passed to ``generate`` in order, after (message, history):
# max_new_tokens, temperature, top_p, top_k, repetition_penalty.
chat_interface = ChatInterface(
    fn=generate,
    # Keyword spelling ("prefence") is defined by chat_interface_preference.ChatInterface.
    prefence_techniques="dpo",
    min_turns=1,
    max_turns=10,
    repo_id="llm-human-feedback-collector-chat-interface-dpo",
    chatbot=gr.Chatbot(height=450, label="Meta-Llama-3.1-8B-Instruct", show_share_button=True),
    cache_examples=False,
    additional_inputs=[
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.05,
            maximum=1.2,
            step=0.05,
            value=0.2,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    examples=[
        ["""What word doesn't make sense in this row: "car, airplane, lama, bus"?"""],
        ["Write a news article about the usage of Lama's by the CSI"],
        ["What are great things to cook when getting started with Asian cooking?"],
        ["Who was Anthony Bourdain?"],
    ],
    title="💪🏽🦾 Human Feedback Collector | Meta-Llama-3.1-8B-Instruct | (DPO) 🦾💪🏽",
    description="".join(
        [
            "This is an adaptation of the [`gr.ChatInterface`](https://www.gradio.app/docs/gradio/chatinterface) and [`huggingface_hub.CommitScheduler`](https://huggingface.co/docs/huggingface_hub/main/en/package_reference/hf_api#huggingface_hub.CommitScheduler) which allows for human feedback collection. ",
            "Another cool tool for capturing Gradio interactions is the [`gr.HuggingFaceDatasetSaver`](https://www.gradio.app/guides/using-flagging#the-hugging-face-dataset-saver-callback). ",
            "This demo shows how you might capture human feedback directly from applications within Gradio. ",
            "The captured feedback can directly be used for fine-tuning LLMs within frameworks like [transformers](https://github.com/huggingface/transformers), [TRL](https://github.com/huggingface/trl) or [AutoTrain](https://huggingface.co/autotrain), ",
            "however, it might benefit from additional data curation with something like [Argilla](https://github.com/argilla-io/argilla/) for human feedback and/or [distilabel](https://github.com/argilla-io/distilabel/) for AI feedback. Argilla can even be [deployed for free on Hugging Face Spaces](https://argilla-io.github.io/argilla/latest/getting_started/huggingface-spaces/).",
        ]
    ),
)
# Top-level Blocks app that hosts the chat interface, with custom CSS from style.css.
with gr.Blocks(css="style.css") as demo:
    chat_interface.render()

if __name__ == "__main__":
    # Cap the number of pending events in the request queue at 20, then start
    # the server. queue() configures ``demo`` in place.
    demo.queue(max_size=20)
    demo.launch()
|