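# Streaming chat demo for tiiuae/falcon-rw-1b built with Gradio.
# Assumed dependencies (not pinned anywhere in this file): gradio, torch,
# and transformers; a CUDA GPU is optional but markedly faster.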
import gradio as gr
import torch
from threading import Thread
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)


title = "🦅Falcon 🗨️ChatBot"
description = "Falcon-RW-1B is a 1B-parameter causal decoder-only model built by TII and trained on 350B tokens of RefinedWeb."
examples = [["How are you?"]]


# Load the tokenizer and model. float16 halves memory on GPU; fall back to
# float32 on CPU, where half-precision generation is poorly supported.
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-rw-1b")
model = AutoModelForCausalLM.from_pretrained(
    "tiiuae/falcon-rw-1b",
    trust_remote_code=True,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
).to(device)


class StopOnTokens(StoppingCriteria):
    # Halt generation as soon as the last sampled token is a stop token.
    # Token id 0 is what this demo stops on; adjust the list to match the
    # special tokens of the model in use.
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        stop_ids = [0]
        return int(input_ids[0][-1]) in stop_ids


def chat(message, history):
    # gr.ChatInterface calls this with the new user message and the prior
    # [user, assistant] turns, and expects the (partial) reply string back.
    stop = StopOnTokens()

    # Flatten the conversation into a single prompt string, ending with an
    # open "<chatbot>:" turn for the model to complete.
    prompt = "".join(
        f"<user>: {user_msg}<chatbot>: {bot_msg}" for user_msg, bot_msg in history
    )
    prompt += f"<user>: {message}<chatbot>: "

    # Tokenize the prompt and set up a streamer that yields decoded text
    # incrementally as tokens are generated.
    tokens = tokenizer([prompt], return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )

    generate_kwargs = dict(
        input_ids=tokens.input_ids,
        attention_mask=tokens.attention_mask,
        streamer=streamer,
        max_length=2048,
        do_sample=True,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        temperature=0.7,
        stopping_criteria=StoppingCriteriaList([stop]),
    )

    # Run generation on a background thread so this function can stream
    # partial output to the UI while the model is still decoding.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Accumulate the generated text and yield the running reply; each yield
    # updates the chatbot message in place.
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text

gr.ChatInterface(
    chat,
    title=title,
    description=description,
    examples=examples,
    cache_examples=True,
    retry_btn=None,
    undo_btn="Delete Previous",
    clear_btn="Clear",
    chatbot=gr.Chatbot(height=300),
    textbox=gr.Textbox(placeholder="Chat with me"),
).queue().launch()
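
# To try it locally (assuming this file is saved as app.py):
#   python app.py
# Gradio serves the UI at http://127.0.0.1:7860 by default.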