shuvom committed on
Commit
ff96a82
·
1 Parent(s): 46e5164

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +107 -0
  2. requirement.txt +76 -0
app.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, TextStreamer
3
+ from threading import Thread
4
+ import gradio as gr
5
+ from peft import PeftModel
6
+
7
# Hub identifiers: the base checkpoint and the fine-tuned PEFT adapter that is
# layered on top of it further down.
model_name_or_path = "sarvamai/OpenHathi-7B-Hi-v0.1-Base"
peft_model_id = "shuvom/OpenHathi-7B-FT-v0.1_SI"

# Base weights are loaded 4-bit quantized; device_map="auto" leaves device
# placement to the library.
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    load_in_4bit=True,
    device_map="auto",
)

# The tokenizer is taken from the adapter repository rather than the base one.
tokenizer = AutoTokenizer.from_pretrained(peft_model_id)

# The embedding matrix is resized to the tokenizer's length (rounded up to a
# multiple of 8) BEFORE the adapter is attached.
# NOTE(review): presumably the fine-tune added tokens to the vocabulary, which
# is why the resize has to precede the adapter load -- confirm.
# TODO: make embedding resizing configurable?
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)

# Wrap the resized base model with the fine-tuned adapter weights.
model = PeftModel.from_pretrained(model, peft_model_id)
17
+
18
class ChatCompletion:
    """Chat-style text generation on top of a causal LM and its tokenizer.

    Three entry points are offered:

    * ``get_completion`` -- single-shot generation that echoes tokens to
      stdout through ``self.print_streamer`` while generating.
    * ``get_chat_completion`` -- generator yielding the growing reply; this is
      the callback handed to ``gr.ChatInterface``.
    * ``get_completion_without_streaming`` -- single-shot generation returning
      the decoded text.
    """

    # Lower bound applied to caller-supplied temperatures. Sampling is always
    # enabled (do_sample=True), so a zero temperature is clamped to this value
    # instead of being passed through.
    MIN_TEMPERATURE = 1e-2

    def __init__(self, model, tokenizer, system_prompt=None):
        """Store the model/tokenizer pair and build the two streamers.

        Args:
            model: Model exposing ``generate`` and ``eval``.
            tokenizer: Tokenizer exposing ``apply_chat_template``,
                ``eos_token`` and ``eos_token_id``.
            system_prompt: Optional default system message.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True)
        self.print_streamer = TextStreamer(self.tokenizer, skip_prompt=True)
        # set the model in inference mode
        self.model.eval()
        self.system_prompt = system_prompt

    def _build_messages(self, prompt, system_prompt=None, message_history=None):
        """Return the message list for a single-shot completion.

        An explicit ``message_history`` takes precedence: when it is given, no
        system message is added here (the history is used as supplied).
        Otherwise the per-call ``system_prompt`` is used, falling back to the
        instance default. The caller's history list is copied, not mutated.
        """
        messages = []
        if message_history is not None:
            messages.extend(message_history)
        else:
            effective_system_prompt = system_prompt or self.system_prompt
            if effective_system_prompt:
                messages.append({"role": "system", "content": effective_system_prompt})
        messages.append({"role": "user", "content": prompt})
        return messages

    def _encode(self, messages, **tokenizer_kwargs):
        """Render ``messages`` with the chat template and tokenize the result.

        The encoded batch is moved to the model's device when the model
        exposes one, so that ``generate`` does not receive inputs that live on
        a different device than the weights.
        """
        chat_prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = self.tokenizer(chat_prompt, return_tensors="pt", **tokenizer_kwargs)
        device = getattr(self.model, "device", None)
        if device is not None:
            inputs = inputs.to(device)
        return inputs

    def get_completion(self, prompt, system_prompt=None, message_history=None, max_new_tokens=512, temperature=0.0):
        """Generate a reply to ``prompt`` while printing it as it is produced.

        Args:
            prompt: The user message.
            system_prompt: Optional override of the instance system prompt;
                ignored when ``message_history`` is given.
            message_history: Optional list of prior ``{"role", "content"}``
                messages placed before ``prompt``.
            max_new_tokens: Upper bound on generated tokens.
            temperature: Sampling temperature, clamped to ``MIN_TEMPERATURE``.

        Returns:
            Whatever ``model.generate`` returns (the raw generation output,
            not decoded text).
        """
        temperature = max(temperature, self.MIN_TEMPERATURE)
        messages = self._build_messages(prompt, system_prompt, message_history)
        inputs = self._encode(messages, add_special_tokens=False)
        # Generation runs synchronously here; the print streamer only echoes
        # the text while it is being generated.
        generation_kwargs = dict(
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=0.95,
            do_sample=True,
            # Use this instance's tokenizer, not a module-level global.
            eos_token_id=self.tokenizer.eos_token_id,
            repetition_penalty=1.2,
        )
        generated_text = self.model.generate(**inputs, streamer=self.print_streamer, **generation_kwargs)
        return generated_text

    def get_chat_completion(self, message, history):
        """Stream a reply to ``message`` given the prior conversation.

        Args:
            message: The latest user message.
            history: Iterable of ``(user_message, assistant_message)`` pairs
                for the earlier turns.

        Yields:
            The reply accumulated so far, with the EOS token text removed,
            once per chunk delivered by ``self.streamer``.
        """
        messages = []
        if self.system_prompt:
            messages.append({"role": "system", "content": self.system_prompt})
        for user_message, assistant_message in history:
            messages.append({"role": "user", "content": user_message})
            # Earlier model replies are replayed as assistant turns, not as
            # additional system messages.
            messages.append({"role": "assistant", "content": assistant_message})
        messages.append({"role": "user", "content": message})

        # NOTE(review): unlike the two single-shot methods this path tokenizes
        # with the tokenizer's default special-token handling; whether that is
        # intended depends on the chat template -- confirm.
        inputs = self._encode(messages)
        # Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
        generation_kwargs = dict(
            inputs,
            streamer=self.streamer,
            max_new_tokens=2048,
            temperature=0.2,
            top_p=0.95,
            eos_token_id=self.tokenizer.eos_token_id,
            do_sample=True,
            repetition_penalty=1.2,
        )
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()
        generated_text = ""
        for new_text in self.streamer:
            generated_text += new_text.replace(self.tokenizer.eos_token, "")
            yield generated_text
        thread.join()
        return generated_text

    def get_completion_without_streaming(self, prompt, system_prompt=None, message_history=None, max_new_tokens=512, temperature=0.0):
        """Generate a reply to ``prompt`` and return it as decoded text.

        Args:
            prompt: The user message.
            system_prompt: Optional override of the instance system prompt;
                ignored when ``message_history`` is given.
            message_history: Optional list of prior ``{"role", "content"}``
                messages placed before ``prompt``.
            max_new_tokens: Upper bound on generated tokens.
            temperature: Sampling temperature, clamped to ``MIN_TEMPERATURE``.

        Returns:
            ``outputs[0]`` decoded with special tokens skipped. The full first
            output sequence is decoded as returned by ``generate``; nothing is
            sliced off here.
        """
        temperature = max(temperature, self.MIN_TEMPERATURE)
        messages = self._build_messages(prompt, system_prompt, message_history)
        inputs = self._encode(messages, add_special_tokens=False)
        # Blocking call: no streamer and no worker thread on this path.
        generation_kwargs = dict(
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=0.95,
            do_sample=True,
            repetition_penalty=1.1,
        )
        outputs = self.model.generate(**inputs, **generation_kwargs)
        generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return generated_text
104
+
105
# Single shared chat engine for the app, seeded with the default system prompt
# that is prepended to every conversation.
text_generator = ChatCompletion(
    model,
    tokenizer,
    system_prompt="You are a native Hindi speaker who can converse at expert level in both Hindi and colloquial Hinglish.",
)

# Expose the streaming generator through a Gradio chat UI; the request queue is
# enabled before launching, and the app is launched with debug=True.
(
    gr.ChatInterface(text_generator.get_chat_completion)
    .queue()
    .launch(debug=True)
)
requirement.txt ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.25.0
2
+ aiofiles==23.2.1
3
+ altair==5.2.0
4
+ annotated-types==0.6.0
5
+ anyio==4.2.0
6
+ attrs==23.2.0
7
+ bitsandbytes==0.41.3.post2
8
+ certifi==2023.11.17
9
+ charset-normalizer==3.3.2
10
+ click==8.1.7
11
+ colorama==0.4.6
12
+ contourpy==1.2.0
13
+ cycler==0.12.1
14
+ fastapi==0.108.0
15
+ ffmpy==0.3.1
16
+ filelock==3.13.1
17
+ fonttools==4.47.0
18
+ fsspec==2023.12.2
19
+ gradio==4.12.0
20
+ gradio_client==0.8.0
21
+ h11==0.14.0
22
+ httpcore==1.0.2
23
+ httpx==0.26.0
24
+ huggingface-hub==0.20.1
25
+ idna==3.6
26
+ importlib-resources==6.1.1
27
+ Jinja2==3.1.2
28
+ jsonschema==4.20.0
29
+ jsonschema-specifications==2023.12.1
30
+ kiwisolver==1.4.5
31
+ markdown-it-py==3.0.0
32
+ MarkupSafe==2.1.3
33
+ matplotlib==3.8.2
34
+ mdurl==0.1.2
35
+ mpmath==1.3.0
36
+ networkx==3.2.1
37
+ numpy==1.26.2
38
+ orjson==3.9.10
39
+ packaging==23.2
40
+ pandas==2.1.4
41
+ peft==0.7.1
42
+ Pillow==10.1.0
43
+ psutil==5.9.7
44
+ pydantic==2.5.3
45
+ pydantic_core==2.14.6
46
+ pydub==0.25.1
47
+ Pygments==2.17.2
48
+ pyparsing==3.1.1
49
+ python-dateutil==2.8.2
50
+ python-multipart==0.0.6
51
+ pytz==2023.3.post1
52
+ PyYAML==6.0.1
53
+ referencing==0.32.0
54
+ regex==2023.12.25
55
+ requests==2.31.0
56
+ rich==13.7.0
57
+ rpds-py==0.16.2
58
+ safetensors==0.4.1
59
+ semantic-version==2.10.0
60
+ shellingham==1.5.4
61
+ six==1.16.0
62
+ sniffio==1.3.0
63
+ starlette==0.32.0.post1
64
+ sympy==1.12
65
+ tokenizers==0.15.0
66
+ tomlkit==0.12.0
67
+ toolz==0.12.0
68
+ torch==2.1.2
69
+ tqdm==4.66.1
70
+ transformers==4.36.2
71
+ typer==0.9.0
72
+ typing_extensions==4.9.0
73
+ tzdata==2023.4
74
+ urllib3==2.1.0
75
+ uvicorn==0.25.0
76
+ websockets==11.0.3