dimitris committed on
Commit a3df6eb · 1 Parent(s): 223c4db

update model

Files changed (2)
  1. app.py +7 -18
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,17 +1,12 @@
 import gradio as gr
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import T5Tokenizer, T5ForConditionalGeneration
 import torch

-model_name = "EleutherAI/pythia-1b"
-model = AutoModelForCausalLM.from_pretrained(
-    model_name,
-    cache_dir="./pythia-1b",
-)

-tokenizer = AutoTokenizer.from_pretrained(
-    model_name,
-    cache_dir="./pythia-1b",
-)
+
+model_name = "google/flan-t5-base"
+model = T5ForConditionalGeneration.from_pretrained(model_name)
+tokenizer = T5Tokenizer.from_pretrained(model_name)


 def predict(message, history):
@@ -23,16 +18,10 @@ def predict(message, history):
     input_ids = tokenizer.encode(message, return_tensors="pt")
     chat_history_ids = model.generate(
         input_ids,
-        min_length=20,
-        max_new_tokens=128,
+        max_length=512,
        do_sample=True,
-        top_p=0.95,
-        top_k=50,
-        temperature=0.75,
-        num_return_sequences=5,
-        pad_token_id=tokenizer.eos_token_id,
     )
-    response = tokenizer.decode(chat_history_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
+    response = tokenizer.decode(chat_history_ids[0], skip_special_tokens=True)
     yield response
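Taken together, the post-commit app.py is short enough to read in one piece. Here is a sketch of the full file after this change; the gr.ChatInterface wiring at the bottom is an assumption, since the diff hunks stop inside predict():

import gradio as gr
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch


model_name = "google/flan-t5-base"
model = T5ForConditionalGeneration.from_pretrained(model_name)
tokenizer = T5Tokenizer.from_pretrained(model_name)


def predict(message, history):
    # Flan-T5 is an encoder-decoder model: generate() returns only decoder
    # output, so there is no prompt prefix to slice off, unlike the removed
    # causal-LM code that dropped the first input_ids.shape[-1] tokens.
    input_ids = tokenizer.encode(message, return_tensors="pt")
    chat_history_ids = model.generate(
        input_ids,
        max_length=512,
        # With top_p/top_k/temperature removed, sampling falls back to the
        # defaults in the model's generation config.
        do_sample=True,
    )
    response = tokenizer.decode(chat_history_ids[0], skip_special_tokens=True)
    yield response


# Assumed, not shown in the hunks: the chat UI wrapping predict().
gr.ChatInterface(predict).launch()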
requirements.txt CHANGED
@@ -2,3 +2,4 @@ gradio
 transformers
 # -i https://download.pytorch.org/whl/cpu
 torch
+sentencepiece
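The new dependency tracks the model swap: T5Tokenizer is the slow, SentencePiece-backed tokenizer, and transformers refuses to load it if the sentencepiece package is absent. A minimal smoke test, reusing the model id from app.py:

from transformers import T5Tokenizer

# Raises an ImportError pointing at SentencePiece if the new
# requirements.txt entry is missing from the environment.
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
print(tokenizer.tokenize("translate English to German: How old are you?"))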