avreymi committed on
Commit
3fd8457
1 Parent(s): 415fc0f

fix model name

Browse files
Files changed (3) hide show
  1. app.py +1 -1
  2. model.py +1 -1
  3. pipline.py +2 -2
app.py CHANGED
@@ -22,7 +22,7 @@ def __run_pipline():
22
  def __run_model():
23
  st.text(f"input_text: {state.input_text}")
24
  st.markdown(":green[Running model]")
25
- st.text(model.model(state.input_text))
26
 
27
 
28
  st.text_area("input_text", key="input_text")
 
22
  def __run_model():
23
  st.text(f"input_text: {state.input_text}")
24
  st.markdown(":green[Running model]")
25
+ st.text(model.run(state.input_text))
26
 
27
 
28
  st.text_area("input_text", key="input_text")
model.py CHANGED
@@ -12,7 +12,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_name)
12
  model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto").to(device=device)
13
 
14
 
15
- def model(text,**kargs):
16
  inputs = tokenizer.encode(text=text, return_tensors="pt").to(device=device)
17
  outputs = model.generate(inputs,**kargs)
18
  return tokenizer.decode(outputs[0])
 
12
  model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto").to(device=device)
13
 
14
 
15
+ def run(text,**kargs):
16
  inputs = tokenizer.encode(text=text, return_tensors="pt").to(device=device)
17
  outputs = model.generate(inputs,**kargs)
18
  return tokenizer.decode(outputs[0])
pipline.py CHANGED
@@ -2,7 +2,7 @@ import langchain as lc
2
  from langchain import PromptTemplate
3
  from langchain.prompts import load_prompt
4
  import wikipedia
5
- from model import model
6
 
7
  # save templates to a file
8
 
@@ -24,7 +24,7 @@ def pipeline(text, word):
24
  model_output = ""
25
  input_text = prompt.format(adjective="funny", content=text)
26
  while word not in model_output:
27
- model_output = model(input_text)
28
  wikipedia_entry = wikipedia.search(word)[1]
29
  wiki = wikipedia.summary(wikipedia_entry, auto_suggest=False, redirect=True)
30
  input_text += model_output + wiki
 
2
  from langchain import PromptTemplate
3
  from langchain.prompts import load_prompt
4
  import wikipedia
5
+ import model
6
 
7
  # save templates to a file
8
 
 
24
  model_output = ""
25
  input_text = prompt.format(adjective="funny", content=text)
26
  while word not in model_output:
27
+ model_output = model.run(input_text)
28
  wikipedia_entry = wikipedia.search(word)[1]
29
  wiki = wikipedia.summary(wikipedia_entry, auto_suggest=False, redirect=True)
30
  input_text += model_output + wiki