Pgohari committed on
Commit
a66e61b
1 Parent(s): d354ed6

update the model to mt5-small

Browse files
Files changed (1) hide show
  1. app.py +14 -8
app.py CHANGED
@@ -1,20 +1,26 @@
1
  import gradio as gr
2
- from transformers import AutoModelForCausalLM, AutoTokenizer
3
 
4
- # Load the Mistral AI model and tokenizer from Hugging Face
5
- model_name = "mistralai/Mistral-7B"
6
  tokenizer = AutoTokenizer.from_pretrained(model_name)
7
- model = AutoModelForCausalLM.from_pretrained(model_name)
8
 
9
- # Define the chatbot function
10
  def chatbot(user_input):
11
- inputs = tokenizer(user_input, return_tensors="pt")
12
- outputs = model.generate(inputs['input_ids'], max_length=50)
 
 
 
 
 
13
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
14
  return response
15
 
16
  # Set up the Gradio interface
17
- demo = gr.Interface(fn=chatbot, inputs="text", outputs="text", title="Mistral AI Chatbot")
18
 
19
  # Launch the app
20
  demo.launch()
 
 
1
  import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
 
4
# Hugging Face Hub checkpoint; the tokenizer and the seq2seq model are both
# loaded from this one identifier.
model_name = "google/mt5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
8
 
9
# Define the chatbot function for summarization and answering questions
def chatbot(user_input: str) -> str:
    """Generate a text response to *user_input* with the module-level model.

    Args:
        user_input: Raw text from the UI. The tokenizer is called with
            ``max_length=512`` and ``truncation=True``, so longer input is
            truncated before generation.

    Returns:
        The first generated sequence decoded with special tokens skipped,
        or an empty string when the input is empty or only whitespace.
    """
    # Nothing to respond to: skip tokenization and beam search entirely.
    if not user_input or not user_input.strip():
        return ""

    # Tokenize the user input
    inputs = tokenizer(
        user_input,
        return_tensors="pt",
        max_length=512,
        truncation=True,
    )

    # Forward the whole encoding so generate() receives the attention mask
    # together with the input ids, not the ids alone.
    # (max_length and num_beams can be tuned for different outputs.)
    outputs = model.generate(
        **inputs,
        max_length=150,
        num_beams=2,
        early_stopping=True,
    )

    # Decode and return the generated text
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
20
 
21
  # Set up the Gradio interface
22
# Single text box in, single text box out, backed by chatbot().
demo = gr.Interface(
    fn=chatbot,
    inputs="text",
    outputs="text",
    title="mT5-Small Chatbot",
)

# Launch the app
demo.launch()
26
+