from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
import re
import solara
from typing import List
from typing_extensions import TypedDict


class MessageDict(TypedDict):
    role: str
    content: str


# Load the model and tokenizer.
# AutoModelForCausalLM (not the bare AutoModel) is needed so the model exposes .generate().
model = AutoModelForCausalLM.from_pretrained("Ammartatox/newqwen1e")
tokenizer = AutoTokenizer.from_pretrained("Ammartatox/newqwen1e")


def response_generator(message, stream_length=1):
    """Yield generated text chunks for a single user message."""
    text = tokenizer.apply_chat_template(
        [{"role": "user", "content": message}],
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(text, return_tensors="pt").to(model.device)

    # A fresh streamer per generation avoids picking up leftover tokens from a previous request.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=512)

    # Run generate() in a background thread so the streamer can be consumed here.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # stream_length caps how many streamed chunks are forwarded to the UI.
    for i, chunk in enumerate(streamer):
        if i >= stream_length:
            break
        yield chunk


def add_chunk_to_ai_message(chunk: str):
    """Append a streamed chunk to the last (assistant) message."""
    messages.value = [
        *messages.value[:-1],
        {
            "role": "assistant",
            "content": messages.value[-1]["content"] + chunk,
        },
    ]


messages: solara.Reactive[List[MessageDict]] = solara.reactive([])


@solara.component
def Page():
    solara.lab.theme.themes.light.primary = "#0000ff"
    solara.lab.theme.themes.light.secondary = "#0000ff"
    solara.lab.theme.themes.dark.primary = "#0000ff"
    solara.lab.theme.themes.dark.secondary = "#0000ff"

    title = "Mawared hr gpt"
    with solara.Head():
        solara.Title(title)

    with solara.Column(align="center"):
        user_message_count = len([m for m in messages.value if m["role"] == "user"])

        def send(message):
            messages.value = [*messages.value, {"role": "user", "content": message}]

        def response(message):
            # Start with an empty assistant message and fill it chunk by chunk.
            messages.value = [*messages.value, {"role": "assistant", "content": ""}]
            for chunk in response_generator(message, stream_length=10):  # Adjust stream_length as needed.
                add_chunk_to_ai_message(chunk)

        def result():
            if messages.value:
                response(messages.value[-1]["content"])

        # Re-run the generation task whenever a new user message arrives.
        task = solara.lab.use_task(result, dependencies=[user_message_count])

        with solara.lab.ChatBox(
            style={
                "position": "fixed",
                "overflow-y": "scroll",
                "scrollbar-width": "none",
                "-ms-overflow-style": "none",
                "top": "0",
                "bottom": "10rem",
                "width": "70%",
            }
        ):
            for item in messages.value:
                with solara.lab.ChatMessage(
                    user=item["role"] == "user",
                    name="User" if item["role"] == "user" else "AI",
                    avatar_background_color="#33cccc" if item["role"] == "assistant" else "#ff991f",
                    border_radius="20px",
                    style="background-color:darkgrey!important;"
                    if solara.lab.theme.dark_effective
                    else "background-color:lightgrey!important;",
                ):
                    # Strip the chat template's end-of-turn marker before rendering.
                    content = re.sub(r"<\|im_end\|>", "", item["content"])
                    solara.Markdown(content)
        solara.lab.ChatInput(send_callback=send, style={"position": "fixed", "bottom": "3rem", "width": "70%"})
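
# Usage note (assuming this file is saved as app.py): the app can be served locally
# with the Solara CLI, e.g.
#   solara run app.py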