mouryachinta committed on
Commit 4e37510 · 1 Parent(s): 5ece8e0

Update app.py

Files changed (1)
  1. app.py +40 -10
app.py CHANGED
@@ -1,25 +1,55 @@
 import gradio as gr
-from transformers import AutoModelForCausalLM, AutoTokenizer
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+from langchain.llms.base import LLM  # still required: CustomLLM below subclasses LLM
 
-def initialize_model_and_tokenizer(model_name="mouryachinta/llama-2-7b-mourya"):
-    model = AutoModelForCausalLM.from_pretrained(model_name)
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
-    return model, tokenizer
-model, tokenizer = initialize_model_and_tokenizer()
+base_model_name = "mouryachinta/llama-2-7b-mourya"
 
-from langchain.llms.base import LLM
+# Tokenizer
+llama_tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
+llama_tokenizer.pad_token = llama_tokenizer.eos_token
+llama_tokenizer.padding_side = "right"  # Fix for fp16
 
+# Quantization Config
+quant_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_quant_type="nf4",
+    bnb_4bit_compute_dtype=torch.float16,
+    bnb_4bit_use_double_quant=False
+)
+
+# Model Initialization
+base_model = AutoModelForCausalLM.from_pretrained(
+    base_model_name,
+    quantization_config=quant_config,
+    device_map={"": 0}
+)
+base_model.config.use_cache = False
+base_model.config.pretraining_tp = 1
+
+# Define CustomLLM class
 class CustomLLM(LLM):
+    def __init__(self):
+        super().__init__()
+
     def _call(self, prompt, stop=None, run_manager=None) -> str:
-        inputs = tokenizer(prompt, return_tensors="pt")
-        result = model.generate(input_ids=inputs.input_ids, max_new_tokens=20)
-        result = tokenizer.decode(result[0])
+        inputs = llama_tokenizer(prompt, return_tensors="pt")
+        input_ids = inputs.input_ids.to(base_model.device)
+        attention_mask = inputs.attention_mask.to(base_model.device) if "attention_mask" in inputs else None
+        # `stop` is repurposed here as a dict of generation kwargs; guard against the None default
+        if stop is None:
+            stop = {}
+        if "max_length" not in stop:
+            stop["max_length"] = 20
+        result = base_model.generate(input_ids=input_ids, attention_mask=attention_mask, **stop)
+        result = llama_tokenizer.decode(result[0], skip_special_tokens=True)
         return result
 
     @property
     def _llm_type(self) -> str:
         return "custom"
 
+# Instantiate CustomLLM
 llm = CustomLLM()
 
 from langchain import PromptTemplate
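
The main change in this commit is loading the base model with 4-bit NF4 quantization instead of a full-precision from_pretrained call. A quick way to confirm the quantized load took effect is to inspect the model's memory footprint; this is a minimal sketch using standard transformers/PyTorch calls, not part of the commit:

# get_memory_footprint() reports parameter + buffer memory in bytes;
# a 7B model loaded in 4-bit should come in around 4 GB, versus roughly 14 GB in fp16.
print(f"{base_model.get_memory_footprint() / 1e9:.1f} GB")
# device_map={"": 0} places every module on GPU 0, so this should report cuda:0
print(base_model.device)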
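
The diff stops right at the PromptTemplate import, so the rest of app.py is not visible here. Below is a minimal sketch of how the new llm object might be wired into a prompt template, a chain, and the Gradio UI that "import gradio as gr" suggests; the template text, the chain, and the answer/Interface wiring are assumptions for illustration (using the classic pre-0.1 langchain API that the import style implies), not code from this commit:

from langchain import PromptTemplate, LLMChain
import gradio as gr

template = PromptTemplate(
    input_variables=["question"],
    template="Q: {question}\nA:",
)
chain = LLMChain(llm=llm, prompt=template)

def answer(question):
    # LLMChain.run returns the string from CustomLLM._call; note that _call
    # decodes the full generated sequence, so the output includes the prompt.
    return chain.run(question=question)

gr.Interface(fn=answer, inputs="text", outputs="text").launch()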