davidberenstein1957 HF staff commited on
Commit
5579546
1 Parent(s): 088561c

small code updates

Browse files
Files changed (4) hide show
  1. app copy.py +1 -5
  2. app.py +1 -0
  3. chat_interface_preference.py +5 -9
  4. test.py +0 -2
app copy.py CHANGED
@@ -9,10 +9,6 @@ import spaces
9
  import torch
10
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
11
 
12
- from chat_interface_preference import ChatInterface
13
-
14
- MAX_MAX_NEW_TOKENS = 2048
15
- DEFAULT_MAX_NEW_TOKENS = 1024
16
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "8192"))
17
 
18
  if torch.cuda.is_available():
@@ -63,7 +59,7 @@ def generate(
63
  yield "".join(outputs)
64
 
65
 
66
- chat_interface = ChatInterface(
67
  fn=generate,
68
  chatbot=gr.Chatbot(
69
  height=450, label="GEITje-SPIN", show_share_button=True, avatar_images=(None, "geitje-logo.jpg")
 
9
  import torch
10
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
11
 
 
 
 
 
12
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "8192"))
13
 
14
  if torch.cuda.is_available():
 
59
  yield "".join(outputs)
60
 
61
 
62
+ chat_interface = gr.ChatInterface(
63
  fn=generate,
64
  chatbot=gr.Chatbot(
65
  height=450, label="GEITje-SPIN", show_share_button=True, avatar_images=(None, "geitje-logo.jpg")
app.py CHANGED
@@ -29,6 +29,7 @@ if torch.cuda.is_available():
29
  model_id = "Qwen/Qwen2-0.5B-Instruct"
30
  model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
31
  tokenizer = AutoTokenizer.from_pretrained(model_id)
 
32
  style = "<style>.user-message,.system-message{display:flex;margin:10px}.user-message .message-content{background-color:#c2e3f7;color:#000}.system-message .message-content{background-color:#f5f5f5;color:#000}.message-content{padding:10px;border-radius:10px;max-width:70%;word-wrap:break-word}.container{display:flex;justify-content:space-between}.column{width:48%}</style>"
33
 
34
  client = rg.Argilla(api_url="https://davidberenstein1957-argilla-gradio.hf.space", api_key="owner.apikey")
 
29
  model_id = "Qwen/Qwen2-0.5B-Instruct"
30
  model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
31
  tokenizer = AutoTokenizer.from_pretrained(model_id)
32
+
33
  style = "<style>.user-message,.system-message{display:flex;margin:10px}.user-message .message-content{background-color:#c2e3f7;color:#000}.system-message .message-content{background-color:#f5f5f5;color:#000}.message-content{padding:10px;border-radius:10px;max-width:70%;word-wrap:break-word}.container{display:flex;justify-content:space-between}.column{width:48%}</style>"
34
 
35
  client = rg.Argilla(api_url="https://davidberenstein1957-argilla-gradio.hf.space", api_key="owner.apikey")
chat_interface_preference.py CHANGED
@@ -23,8 +23,9 @@ from gradio.components import (
23
  get_component_instance,
24
  )
25
  from gradio.events import Dependency, on
26
- from gradio.helpers import Error, Info, Warning, special_args
27
  from gradio.helpers import create_examples as Examples # noqa: N812
 
28
  from gradio.layouts import Accordion, Group, Row
29
  from gradio.routes import Request
30
  from gradio.themes import ThemeClass as Theme
@@ -629,10 +630,6 @@ class ChatInterface(Blocks):
629
  n_generations: int = 1,
630
  *args,
631
  ) -> AsyncGenerator:
632
- _, response = history_with_input[-1]
633
- if self._check_if_two_responses(response):
634
- raise Error("Two options detected: undo, log or random pick continuation.")
635
-
636
  if self.multimodal and isinstance(message, dict):
637
  remove_input = len(message["files"]) + 1 if message["text"] is not None else len(message["files"])
638
  history = history_with_input[:-remove_input]
@@ -646,14 +643,13 @@ class ChatInterface(Blocks):
646
  generator = self.fn(*inputs)
647
  else:
648
  generator = await anyio.to_thread.run_sync(self.fn, *inputs, limiter=self.limiter)
649
- generator = SyncToAsyncIterator(generator, self.limiter)
650
  first_response = await async_iteration(generator)
651
  if n_generations == 2:
652
  first_response_formatted = self._get_chat_message_comparison(first_response, "")
653
  if self.multimodal and isinstance(message, dict):
654
  for x in message["files"]:
655
  history.append([(x,), None])
656
-
657
  update = history + [[message["text"], first_response_formatted]]
658
  yield update, update
659
  else:
@@ -670,10 +666,10 @@ class ChatInterface(Blocks):
670
  if n_generations == 2:
671
  response_formatted = self._get_chat_message_comparison(response, "")
672
  if self.multimodal and isinstance(message, dict):
673
- update = history + [[message["text"], response_formatted]]
674
  yield update, update
675
  else:
676
- update = history + [[message, response_formatted]]
677
  yield update, update
678
 
679
  if n_generations == 2:
 
23
  get_component_instance,
24
  )
25
  from gradio.events import Dependency, on
26
+ from gradio.helpers import Error, Info, Warning
27
  from gradio.helpers import create_examples as Examples # noqa: N812
28
+ from gradio.helpers import special_args
29
  from gradio.layouts import Accordion, Group, Row
30
  from gradio.routes import Request
31
  from gradio.themes import ThemeClass as Theme
 
630
  n_generations: int = 1,
631
  *args,
632
  ) -> AsyncGenerator:
 
 
 
 
633
  if self.multimodal and isinstance(message, dict):
634
  remove_input = len(message["files"]) + 1 if message["text"] is not None else len(message["files"])
635
  history = history_with_input[:-remove_input]
 
643
  generator = self.fn(*inputs)
644
  else:
645
  generator = await anyio.to_thread.run_sync(self.fn, *inputs, limiter=self.limiter)
646
+ generator = SyncToAsyncIterator(generator, self.limiter)
647
  first_response = await async_iteration(generator)
648
  if n_generations == 2:
649
  first_response_formatted = self._get_chat_message_comparison(first_response, "")
650
  if self.multimodal and isinstance(message, dict):
651
  for x in message["files"]:
652
  history.append([(x,), None])
 
653
  update = history + [[message["text"], first_response_formatted]]
654
  yield update, update
655
  else:
 
666
  if n_generations == 2:
667
  response_formatted = self._get_chat_message_comparison(response, "")
668
  if self.multimodal and isinstance(message, dict):
669
+ update = history + [[message["text"], response]]
670
  yield update, update
671
  else:
672
+ update = history + [[message, response]]
673
  yield update, update
674
 
675
  if n_generations == 2:
test.py CHANGED
@@ -1,8 +1,6 @@
1
  import random
2
 
3
 
4
- import argilla as rg
5
-
6
  from chat_interface_preference import ChatInterface
7
 
8
 
 
1
  import random
2
 
3
 
 
 
4
  from chat_interface_preference import ChatInterface
5
 
6