Avimanyu committed
Commit
8814dc8
Parent(s): 2d13aca

Add application file

Files changed (1): app.py +59 -0
app.py ADDED
@@ -0,0 +1,59 @@
+ import gradio as gr
+ import torch
+
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
+
+ # Define hyperparameters
+ max_seq_length = 512
+ max_output_length = 1024
+ num_beams = 16
+ length_penalty = 1.4
+ no_repeat_ngram_size = 2
+ temperature = 0.7
+ top_k = 150
+ top_p = 0.92
+ repetition_penalty = 2.1
+ early_stopping = True
+
+ # Load the pre-trained model and tokenizer; use DataParallel when more than one GPU is available
+ model_name = "google/flan-t5-large"
+ tokenizer = T5Tokenizer.from_pretrained(model_name, model_max_length=max_seq_length)
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ if torch.cuda.device_count() > 1:
+     device_ids = list(range(torch.cuda.device_count()))
+     model = torch.nn.DataParallel(T5ForConditionalGeneration.from_pretrained(model_name, return_dict=True), device_ids=device_ids)
+ else:
+     model = T5ForConditionalGeneration.from_pretrained(model_name, return_dict=True)
+
+ model.to(device)
+
+ # Define a function to generate a response to user input
+ def chatbot(text):
+     with torch.no_grad():
+         # Tokenize the input text and move the resulting tensor to the model's device
+         input_ids = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=max_seq_length).input_ids.to(device)
+
+         # Generate a response (DataParallel exposes generate() via model.module)
+         if torch.cuda.device_count() > 1:
+             outputs = model.module.generate(input_ids, min_length=max_seq_length, max_new_tokens=max_output_length, num_beams=num_beams, length_penalty=length_penalty, no_repeat_ngram_size=no_repeat_ngram_size, temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, early_stopping=early_stopping)
+         else:
+             outputs = model.generate(input_ids, min_length=max_seq_length, max_new_tokens=max_output_length, num_beams=num_beams, length_penalty=length_penalty, no_repeat_ngram_size=no_repeat_ngram_size, temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, early_stopping=early_stopping)
+
+         # Decode the generated tokens and return the text
+         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+         return response
+
+ chat_help_text = "Welcome! This ChatBot is designed to answer questions about a wide range " \
+                  "of topics. To use it, simply type your question in the text box below and " \
+                  "hit Enter or click the button. Please keep in mind that the ChatBot is not " \
+                  "perfect: it may misunderstand some questions and may provide inaccurate or " \
+                  "incomplete answers, so it is best suited for simple factual questions " \
+                  "rather than complex or nuanced inquiries."
+
+ # Create a Gradio interface
+ iface = gr.Interface(fn=chatbot, inputs="text", outputs="text", title="NuNet Inferencing Demo",
+                      description=chat_help_text)
+
+ iface.launch(share=True)
+
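Two notes on the generation settings above. First, temperature, top_k, and top_p only take effect when do_sample=True is passed to generate(); under pure beam search (num_beams=16) recent transformers releases ignore them and emit a warning. Second, min_length=max_seq_length forces every reply to be at least 512 tokens long, which is probably longer than intended for short factual answers. Below is a minimal sketch of a sampled-generation variant; it reuses the model, tokenizer, device, and hyperparameters defined in app.py, and the chatbot_sampled name and the do_sample flag are illustrative additions, not part of this commit:

    # Sketch: enable sampling so temperature/top_k/top_p actually apply
    # (assumes model, tokenizer, device, and the hyperparameters from app.py).
    def chatbot_sampled(text):
        with torch.no_grad():
            input_ids = tokenizer(text, return_tensors="pt", truncation=True,
                                  max_length=max_seq_length).input_ids.to(device)
            # DataParallel wraps the model, so generate() lives on model.module
            generator = model.module if isinstance(model, torch.nn.DataParallel) else model
            outputs = generator.generate(input_ids,
                                         max_new_tokens=max_output_length,
                                         do_sample=True,
                                         temperature=temperature,
                                         top_k=top_k,
                                         top_p=top_p,
                                         repetition_penalty=repetition_penalty,
                                         no_repeat_ngram_size=no_repeat_ngram_size)
            return tokenizer.decode(outputs[0], skip_special_tokens=True)

To run the file locally you need gradio, torch, transformers, and sentencepiece installed (the slow T5Tokenizer depends on sentencepiece); python app.py then starts the interface, and share=True additionally requests a temporary public gradio.live link.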