S-Dreamer committed
Commit 0a023ed · verified · 1 Parent(s): a68d729

Update app.py

Files changed (1)
  1. app.py +81 -77
app.py CHANGED
@@ -1,85 +1,89 @@
 
  import gradio as gr
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
  import torch

- # Load the Salesforce/codet5-base model and tokenizer
- # We are using the 'Salesforce/codet5-base' model, which is a pre-trained model for code-related tasks.
- # The AutoTokenizer and AutoModelForSeq2SeqLM classes from the Transformers library are used to load the model and tokenizer.
- model_name = "Salesforce/codet5-base"
- tokenizer = AutoTokenizer.from_pretrained(model_name)
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

- # Function to generate code
- # This function takes a prompt (code-related query) as input and generates code based on that prompt.
- # It uses the loaded model and tokenizer to encode the input, generate the output, and then decode the generated text.
- def generate_code(prompt, max_length=100):
-     # Encode the input prompt using the tokenizer
-     input_ids = tokenizer.encode(prompt, return_tensors="pt")
-
-     # Generate the output using the model
-     # The `max_length` parameter sets the maximum length of the generated sequence.
-     # The `num_return_sequences` parameter specifies the number of output sequences to generate (in this case, 1).
-     output = model.generate(input_ids, max_length=max_length, num_return_sequences=1)
-
-     # Decode the generated output to get the actual code
-     # `skip_special_tokens=True` removes special tokens (e.g., start/end-of-sequence tokens) from the output.
-     generated_code = tokenizer.decode(output[0], skip_special_tokens=True)
-
-     # Return the generated code
-     return generated_code

- # Function to handle chat interaction
- # This function manages the chat interaction between the user and the system.
- # It takes the user's message and the chat history as input, and returns the system's response and the updated chat history.
- def chat_interaction(message, history):
-     # Initialize the chat history if it's not provided
-     history = history or []
-
-     # Generate the response using the `generate_code` function
-     response = generate_code(message)
-
-     # Update the chat history by appending the user's message and the system's response
-     history.append((message, response))
-
-     # Return an empty message (to clear the input field) and the updated chat history
-     return "", history

- # Create the Gradio interface
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
-     # Add a Markdown title for the interface
-     gr.Markdown("# S-Dreamer Salesforce/codet5-base Chat Interface")
-
-     # Create a row with two columns
-     with gr.Row():
-         # Left column for the chat area
-         with gr.Column(scale=3):
-             # Chatbot component to display the chat history
-             chatbot = gr.Chatbot(height=400)
-             # Text input field for the user to enter messages
-             message = gr.Textbox(label="Enter your code-related query", placeholder="Type your message here...")
-             # Submit button
-             submit_button = gr.Button("Submit")
-
-         # Right column for the feature list
-         with gr.Column(scale=1):
-             gr.Markdown("## Features")
-             gr.Markdown("- Code generation")
-             gr.Markdown("- Code completion")
-             gr.Markdown("- Code explanation")
-             gr.Markdown("- Error correction")
-
-     # Add a clear button to reset the chat
-     clear_button = gr.Button("Clear Chat")
-
-     # Connect the submit button to the `chat_interaction` function
-     submit_button.click(chat_interaction, inputs=[message, chatbot], outputs=[message, chatbot])
-
-     # Connect the clear button to a lambda function that clears the chat
-     clear_button.click(lambda: None, outputs=[chatbot], inputs=[])

- # Launch the Gradio interface
- demo.launch()
+
  import gradio as gr
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
  import torch

+ # The CodeGenerator class is responsible for generating code from a given prompt.
+ class CodeGenerator:
+     # The constructor loads the pre-trained tokenizer and model.
+     # The default model name is "Salesforce/codet5-base".
+     def __init__(self, model_name="Salesforce/codet5-base"):
+         # Load the pre-trained tokenizer for the specified model.
+         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+         # Load the pre-trained sequence-to-sequence language model.
+         self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+
+     # Generate code from the given prompt.
+     # `prompt` is the input text; `max_length` caps the length of the generated sequence.
+     def generate_code(self, prompt, max_length=100):
+         # Encode the prompt into input IDs the model can process.
+         input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
+         # Generate a single output sequence of at most `max_length` tokens.
+         output = self.model.generate(input_ids, max_length=max_length, num_return_sequences=1)
+         # Decode the output sequence and return the generated code.
+         return self.tokenizer.decode(output[0], skip_special_tokens=True)
+
+ # The ChatHandler class manages the chat history.
+ class ChatHandler:
+     # The constructor starts with an empty chat history.
+     def __init__(self):
+         self.history = []
+
+     # Handle an incoming message and generate a response with the provided CodeGenerator.
+     def handle_message(self, message, code_generator):
+         # Generate the response using the provided CodeGenerator.
+         response = code_generator.generate_code(message)
+         # Append the message-response pair to the chat history.
+         self.history.append((message, response))
+         # Return an empty string (to clear the input box) and the updated chat history.
+         return "", self.history
+
+ # Create the Gradio interface for the chat application.
+ def create_gradio_interface():
+     # Create the CodeGenerator and ChatHandler instances.
+     code_generator = CodeGenerator()
+     chat_handler = ChatHandler()
+
+     # Create a Gradio Blocks interface with a soft theme.
+     with gr.Blocks(theme=gr.themes.Soft()) as demo:
+         # Display a Markdown title for the chat interface.
+         gr.Markdown("# S-Dreamer Salesforce/codet5-base Chat Interface")
+
+         # Create a row with two columns.
+         with gr.Row():
+             # The first column contains the chat interface.
+             with gr.Column(scale=3):
+                 # Chatbot component that displays the chat history.
+                 chatbot = gr.Chatbot(height=400)
+                 # Textbox for the user to enter their message.
+                 message_input = gr.Textbox(label="Enter your code-related query", placeholder="Type your message here...")
+                 # Submit button to send the message.
+                 submit_button = gr.Button("Submit")

+             # The second column lists the supported features.
+             with gr.Column(scale=1):
+                 # Display a Markdown title for the features section.
+                 gr.Markdown("## Features")
+                 # Display each feature as a Markdown list item.
+                 features = ["Code generation", "Code completion", "Code explanation", "Error correction"]
+                 for feature in features:
+                     gr.Markdown(f"- {feature}")
+                 # Button to clear the chat history.
+                 clear_button = gr.Button("Clear Chat")

+         # Connect the submit button to the ChatHandler; a lambda supplies the CodeGenerator instance,
+         # since Gradio only passes the values of the declared input components to the callback.
+         submit_button.click(lambda message: chat_handler.handle_message(message, code_generator),
+                             inputs=[message_input], outputs=[message_input, chatbot])
+         # Connect the clear button to a function that resets the stored history and empties the chatbot.
+         def clear_chat():
+             chat_handler.history = []
+             return None
+         clear_button.click(clear_chat, outputs=[chatbot], inputs=[])

+     # Launch the Gradio interface.
+     demo.launch()

+ # This is the entry point of the application.
+ if __name__ == "__main__":
+     # Call `create_gradio_interface` to start the chat application.
+     create_gradio_interface()
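
For quick manual verification, a minimal sketch (not part of the commit) of how the refactored classes can be exercised without launching the Gradio UI, assuming the file is importable as `app`; the prompt below is an arbitrary example, and the first call downloads the Salesforce/codet5-base weights.

# Usage sketch (hypothetical, not part of app.py): drive CodeGenerator and ChatHandler directly.
# Importing app is safe because demo.launch() only runs under the __main__ guard.
from app import CodeGenerator, ChatHandler

code_generator = CodeGenerator()   # downloads/caches Salesforce/codet5-base on first construction
chat_handler = ChatHandler()

# handle_message returns ("", history); history is a list of (message, response) tuples.
_, history = chat_handler.handle_message("def add(a, b):", code_generator)
print(history[-1][1])   # the model's generated continuation for the example prompt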