Commit
•
a38b23a
1
Parent(s):
7c45d36
Update error message duplicate mistakes
Browse files- README.md +1 -1
- app.py +8 -4
- chat_interface_preference.py +87 -77
README.md
CHANGED
@@ -4,7 +4,7 @@ emoji: 🦾💪🏽
|
|
4 |
colorFrom: pink
|
5 |
colorTo: blue
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 4.
|
8 |
app_file: app.py
|
9 |
pinned: true
|
10 |
license: mit
|
|
|
4 |
colorFrom: pink
|
5 |
colorTo: blue
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 4.39
|
8 |
app_file: app.py
|
9 |
pinned: true
|
10 |
license: mit
|
app.py
CHANGED
@@ -1,13 +1,17 @@
|
|
1 |
#!/usr/bin/env python
|
2 |
import os
|
3 |
import random
|
4 |
-
from threading import Thread
|
5 |
from typing import Iterator
|
6 |
|
7 |
import gradio as gr
|
8 |
import spaces
|
9 |
-
import torch
|
10 |
-
from transformers import
|
|
|
|
|
|
|
|
|
11 |
|
12 |
from chat_interface_preference import ChatInterface
|
13 |
|
@@ -118,7 +122,7 @@ chat_interface = ChatInterface(
|
|
118 |
title="💪🏽🦾 Human Feedback Collector | Meta-Llama-3.1-8B-Instruct | (DPO) 🦾💪🏽",
|
119 |
description="".join(
|
120 |
[
|
121 |
-
"This is an adaptation of the [`gr.ChatInferface`](https://www.gradio.app/docs/gradio/chatinterface)
|
122 |
"Another cool tool for capturing Gradio interactions is the [`gr.HuggingFaceDatasetSaver`](https://www.gradio.app/guides/using-flagging#the-hugging-face-dataset-saver-callback). ",
|
123 |
"This demo shows how you might capture human feedback directly from applications within Gradio. ",
|
124 |
"The captured feedback can directly be used for fine-tuning LLMs within framework like [transformers](https://github.com/huggingface/transformers), [TRL](https://github.com/huggingface/trl) or [AutoTrain](https://huggingface.co/autotrain), ",
|
|
|
1 |
#!/usr/bin/env python
|
2 |
import os
|
3 |
import random
|
4 |
+
from threading import Thread # noqa
|
5 |
from typing import Iterator
|
6 |
|
7 |
import gradio as gr
|
8 |
import spaces
|
9 |
+
import torch # noqa
|
10 |
+
from transformers import (
|
11 |
+
AutoModelForCausalLM, # noqa
|
12 |
+
AutoTokenizer, # noqa
|
13 |
+
TextIteratorStreamer, # noqa
|
14 |
+
)
|
15 |
|
16 |
from chat_interface_preference import ChatInterface
|
17 |
|
|
|
122 |
title="💪🏽🦾 Human Feedback Collector | Meta-Llama-3.1-8B-Instruct | (DPO) 🦾💪🏽",
|
123 |
description="".join(
|
124 |
[
|
125 |
+
"This is an adaptation of the [`gr.ChatInferface`](https://www.gradio.app/docs/gradio/chatinterface) which also uses the [`huggingface_hub.CommitScheduler`](https://huggingface.co/docs/huggingface_hub/main/en/package_reference/hf_api#huggingface_hub.CommitScheduler) to allow for human feedback collection. ",
|
126 |
"Another cool tool for capturing Gradio interactions is the [`gr.HuggingFaceDatasetSaver`](https://www.gradio.app/guides/using-flagging#the-hugging-face-dataset-saver-callback). ",
|
127 |
"This demo shows how you might capture human feedback directly from applications within Gradio. ",
|
128 |
"The captured feedback can directly be used for fine-tuning LLMs within framework like [transformers](https://github.com/huggingface/transformers), [TRL](https://github.com/huggingface/trl) or [AutoTrain](https://huggingface.co/autotrain), ",
|
chat_interface_preference.py
CHANGED
@@ -607,7 +607,7 @@ class ChatInterface(Blocks):
|
|
607 |
if turn[-1]:
|
608 |
conversation += self._get_chat_message(turn[-1], role="user", turn=(idx + 1))
|
609 |
|
610 |
-
return "<body>" +
|
611 |
|
612 |
def _get_conversation_in_openai_format(self, history):
|
613 |
conversation = []
|
@@ -644,6 +644,7 @@ class ChatInterface(Blocks):
|
|
644 |
|
645 |
@staticmethod
|
646 |
def _check_if_two_responses(response):
|
|
|
647 |
if response:
|
648 |
matches = pattern.findall(response)
|
649 |
return matches
|
@@ -683,30 +684,34 @@ class ChatInterface(Blocks):
|
|
683 |
|
684 |
self._check_message(message)
|
685 |
self._check_num_turns(history)
|
686 |
-
|
|
|
|
|
|
|
687 |
if self._check_if_two_responses(response):
|
688 |
-
|
|
|
|
|
|
|
689 |
|
690 |
-
|
|
|
|
|
|
|
|
|
|
|
691 |
|
692 |
-
|
693 |
-
|
694 |
-
response = await self.fn(*inputs)
|
695 |
else:
|
696 |
-
|
697 |
-
|
698 |
|
699 |
-
|
700 |
-
|
701 |
-
|
702 |
-
|
703 |
-
|
704 |
-
|
705 |
-
if self.multimodal and isinstance(message, dict):
|
706 |
-
self._append_multimodal_history(message, response, history)
|
707 |
-
elif isinstance(message, str):
|
708 |
-
history.append([message, response])
|
709 |
-
return history, history
|
710 |
|
711 |
async def _stream_fn(
|
712 |
self,
|
@@ -723,67 +728,35 @@ class ChatInterface(Blocks):
|
|
723 |
history = history_with_input[:-1]
|
724 |
self._check_message(message)
|
725 |
self._check_num_turns(history)
|
726 |
-
_, response = history_with_input[-1]
|
727 |
-
if self._check_if_two_responses(response):
|
728 |
-
raise Error("Two options detected: undo, log or random pick continuation.")
|
729 |
-
|
730 |
-
inputs, _, _ = special_args(self.fn, inputs=[message, history, *args], request=request)
|
731 |
|
732 |
-
|
733 |
-
|
734 |
-
|
735 |
-
|
736 |
-
|
737 |
-
|
738 |
-
|
739 |
-
|
740 |
-
|
741 |
-
else:
|
742 |
-
first_response_formatted = first_response
|
743 |
-
if self.multimodal and isinstance(message, dict):
|
744 |
-
for x in message["files"]:
|
745 |
-
history.append([(x,), None])
|
746 |
-
update = history + [[message["text"], first_response_formatted]]
|
747 |
-
yield update, update
|
748 |
-
else:
|
749 |
-
update = history + [[message, first_response_formatted]]
|
750 |
-
yield update, update
|
751 |
-
except StopIteration:
|
752 |
-
if self.multimodal and isinstance(message, dict):
|
753 |
-
self._append_multimodal_history(message, None, history)
|
754 |
-
yield history, history
|
755 |
-
else:
|
756 |
-
update = history + [[message, None]]
|
757 |
-
yield update, update
|
758 |
-
async for response in generator:
|
759 |
-
if n_generations == 2:
|
760 |
-
response_formatted = self._get_chat_message_comparison(response, "")
|
761 |
-
else:
|
762 |
-
response_formatted = response
|
763 |
-
if self.multimodal and isinstance(message, dict):
|
764 |
-
update = history + [[message["text"], response_formatted]]
|
765 |
-
yield update, update
|
766 |
-
else:
|
767 |
-
update = history + [[message, response_formatted]]
|
768 |
-
yield update, update
|
769 |
|
770 |
-
if n_generations == 2:
|
771 |
-
if self.is_async:
|
772 |
-
generator_two = self.fn(*inputs)
|
773 |
-
else:
|
774 |
-
generator_two = await anyio.to_thread.run_sync(self.fn, *inputs, limiter=self.limiter)
|
775 |
-
generator_two = SyncToAsyncIterator(generator_two, self.limiter)
|
776 |
try:
|
777 |
-
|
778 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
779 |
if self.multimodal and isinstance(message, dict):
|
780 |
for x in message["files"]:
|
781 |
history.append([(x,), None])
|
782 |
-
|
783 |
-
update = history + [[message["text"], first_response_two_formatted]]
|
784 |
yield update, update
|
785 |
else:
|
786 |
-
update = history + [[message,
|
787 |
yield update, update
|
788 |
except StopIteration:
|
789 |
if self.multimodal and isinstance(message, dict):
|
@@ -792,15 +765,52 @@ class ChatInterface(Blocks):
|
|
792 |
else:
|
793 |
update = history + [[message, None]]
|
794 |
yield update, update
|
795 |
-
async for
|
796 |
-
|
|
|
|
|
|
|
797 |
if self.multimodal and isinstance(message, dict):
|
798 |
-
update = history + [[message["text"],
|
799 |
yield update, update
|
800 |
else:
|
801 |
-
update = history + [[message,
|
802 |
yield update, update
|
803 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
804 |
async def _log_fn(
|
805 |
self, message: str | dict[str, list], history: list[list[str | tuple | None]], log: str
|
806 |
) -> tuple[
|
|
|
607 |
if turn[-1]:
|
608 |
conversation += self._get_chat_message(turn[-1], role="user", turn=(idx + 1))
|
609 |
|
610 |
+
return "<body>" + conversation + "</body>"
|
611 |
|
612 |
def _get_conversation_in_openai_format(self, history):
|
613 |
conversation = []
|
|
|
644 |
|
645 |
@staticmethod
|
646 |
def _check_if_two_responses(response):
|
647 |
+
print(response)
|
648 |
if response:
|
649 |
matches = pattern.findall(response)
|
650 |
return matches
|
|
|
684 |
|
685 |
self._check_message(message)
|
686 |
self._check_num_turns(history)
|
687 |
+
if history:
|
688 |
+
_, response = history[-1]
|
689 |
+
else:
|
690 |
+
response = None
|
691 |
if self._check_if_two_responses(response):
|
692 |
+
Info("Two options detected: provide preference, undo or clear to continue conversation.")
|
693 |
+
return history, history
|
694 |
+
else:
|
695 |
+
inputs, _, _ = special_args(self.fn, inputs=[message, history, *args], request=request)
|
696 |
|
697 |
+
async def _get_response():
|
698 |
+
if self.is_async:
|
699 |
+
response = await self.fn(*inputs)
|
700 |
+
else:
|
701 |
+
response = await anyio.to_thread.run_sync(self.fn, *inputs, limiter=self.limiter)
|
702 |
+
return response
|
703 |
|
704 |
+
if n_generations == 1:
|
705 |
+
response = await _get_response()
|
|
|
706 |
else:
|
707 |
+
response_one, response_two = await _get_response(), await _get_response()
|
708 |
+
response = self._get_chat_message_comparison(response_one, response_two)
|
709 |
|
710 |
+
if self.multimodal and isinstance(message, dict):
|
711 |
+
self._append_multimodal_history(message, response, history)
|
712 |
+
elif isinstance(message, str):
|
713 |
+
history.append([message, response])
|
714 |
+
return history, history
|
|
|
|
|
|
|
|
|
|
|
|
|
715 |
|
716 |
async def _stream_fn(
|
717 |
self,
|
|
|
728 |
history = history_with_input[:-1]
|
729 |
self._check_message(message)
|
730 |
self._check_num_turns(history)
|
|
|
|
|
|
|
|
|
|
|
731 |
|
732 |
+
if history:
|
733 |
+
_, response = history[-1]
|
734 |
+
else:
|
735 |
+
response = None
|
736 |
+
if self._check_if_two_responses(response):
|
737 |
+
Info("Two options detected: provide preference, undo or clear to continue conversation.")
|
738 |
+
yield history, history
|
739 |
+
else:
|
740 |
+
inputs, _, _ = special_args(self.fn, inputs=[message, history, *args], request=request)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
741 |
|
|
|
|
|
|
|
|
|
|
|
|
|
742 |
try:
|
743 |
+
if self.is_async:
|
744 |
+
generator = self.fn(*inputs)
|
745 |
+
else:
|
746 |
+
generator = await anyio.to_thread.run_sync(self.fn, *inputs, limiter=self.limiter)
|
747 |
+
generator = SyncToAsyncIterator(generator, self.limiter)
|
748 |
+
first_response = await async_iteration(generator)
|
749 |
+
if n_generations == 2:
|
750 |
+
first_response_formatted = self._get_chat_message_comparison(first_response, "")
|
751 |
+
else:
|
752 |
+
first_response_formatted = first_response
|
753 |
if self.multimodal and isinstance(message, dict):
|
754 |
for x in message["files"]:
|
755 |
history.append([(x,), None])
|
756 |
+
update = history + [[message["text"], first_response_formatted]]
|
|
|
757 |
yield update, update
|
758 |
else:
|
759 |
+
update = history + [[message, first_response_formatted]]
|
760 |
yield update, update
|
761 |
except StopIteration:
|
762 |
if self.multimodal and isinstance(message, dict):
|
|
|
765 |
else:
|
766 |
update = history + [[message, None]]
|
767 |
yield update, update
|
768 |
+
async for response in generator:
|
769 |
+
if n_generations == 2:
|
770 |
+
response_formatted = self._get_chat_message_comparison(response, "")
|
771 |
+
else:
|
772 |
+
response_formatted = response
|
773 |
if self.multimodal and isinstance(message, dict):
|
774 |
+
update = history + [[message["text"], response_formatted]]
|
775 |
yield update, update
|
776 |
else:
|
777 |
+
update = history + [[message, response_formatted]]
|
778 |
yield update, update
|
779 |
|
780 |
+
if n_generations == 2:
|
781 |
+
if self.is_async:
|
782 |
+
generator_two = self.fn(*inputs)
|
783 |
+
else:
|
784 |
+
generator_two = await anyio.to_thread.run_sync(self.fn, *inputs, limiter=self.limiter)
|
785 |
+
generator_two = SyncToAsyncIterator(generator_two, self.limiter)
|
786 |
+
try:
|
787 |
+
first_response_two = await async_iteration(generator_two)
|
788 |
+
first_response_two_formatted = self._get_chat_message_comparison(response, first_response_two)
|
789 |
+
if self.multimodal and isinstance(message, dict):
|
790 |
+
for x in message["files"]:
|
791 |
+
history.append([(x,), None])
|
792 |
+
|
793 |
+
update = history + [[message["text"], first_response_two_formatted]]
|
794 |
+
yield update, update
|
795 |
+
else:
|
796 |
+
update = history + [[message, first_response_two_formatted]]
|
797 |
+
yield update, update
|
798 |
+
except StopIteration:
|
799 |
+
if self.multimodal and isinstance(message, dict):
|
800 |
+
self._append_multimodal_history(message, None, history)
|
801 |
+
yield history, history
|
802 |
+
else:
|
803 |
+
update = history + [[message, None]]
|
804 |
+
yield update, update
|
805 |
+
async for response_two in generator_two:
|
806 |
+
response_two = self._get_chat_message_comparison(response, response_two)
|
807 |
+
if self.multimodal and isinstance(message, dict):
|
808 |
+
update = history + [[message["text"], response_two]]
|
809 |
+
yield update, update
|
810 |
+
else:
|
811 |
+
update = history + [[message, response_two]]
|
812 |
+
yield update, update
|
813 |
+
|
814 |
async def _log_fn(
|
815 |
self, message: str | dict[str, list], history: list[list[str | tuple | None]], log: str
|
816 |
) -> tuple[
|