Mohammed-Altaf commited on
Commit
bb10374
1 Parent(s): 444b5a5

added accelerate to the requirements

Browse files
Files changed (2) hide show
  1. load_meta_data.py +1 -1
  2. main.py +4 -2
load_meta_data.py CHANGED
@@ -14,7 +14,7 @@ class ChatBot:
14
 
15
  def load_from_hub(self,model_id: str):
16
  self.tokenizer = AutoTokenizer.from_pretrained(model_id,)
17
- self.model = AutoModelForCausalLM.from_pretrained(model_id,)
18
 
19
  def get_response(self,text: UserQuery) -> str:
20
  if not self.model or not self.tokenizer:
 
14
 
15
  def load_from_hub(self,model_id: str):
16
  self.tokenizer = AutoTokenizer.from_pretrained(model_id,)
17
+ self.model = AutoModelForCausalLM.from_pretrained(model_id,ignore_mismatched_sizes=True)
18
 
19
  def get_response(self,text: UserQuery) -> str:
20
  if not self.model or not self.tokenizer:
main.py CHANGED
@@ -1,10 +1,12 @@
1
  # Custom Imports
2
  from load_meta_data import ChatBot
 
 
3
 
4
  # Built-in Imports
5
  from enum import Enum
6
- from fastapi import FastAPI
7
- from fastapi.middleware.cors import CORSMiddleware
8
 
9
  # Enum to save the Model Id, which is Constant
10
  class Repo_ID(Enum):
 
1
  # Custom Imports
2
  from load_meta_data import ChatBot
3
+ from fastapi import FastAPI
4
+ from fastapi.middleware.cors import CORSMiddleware
5
 
6
  # Built-in Imports
7
  from enum import Enum
8
+ import warnings
9
+ warnings.filterwarnings('ignore')
10
 
11
  # Enum to save the Model Id, which is Constant
12
  class Repo_ID(Enum):