S-Dreamer committed
Commit c766fcb · verified · 1 Parent(s): 0a023ed

Update app.py

Files changed (1)
  1. app.py +24 -50
app.py CHANGED
@@ -1,89 +1,63 @@
-
 import gradio as gr
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 import torch
 
-# Define a class named `CodeGenerator` that will be responsible for generating code based on a given prompt.
 class CodeGenerator:
-    # The constructor initializes the CodeGenerator object with a pre-trained model name.
-    # The default model name is "Salesforce/codet5-base".
-    def __init__(self, model_name="Salesforce/codet5-base"):
-        # Load the pre-trained tokenizer from the specified model name.
+    def __init__(self, model_name="Salesforce/codet5-base", device=None):
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-        # Load the pre-trained sequence-to-sequence language model from the specified model name.
         self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+        if device:
+            self.model = self.model.to(device)
 
-    # This method generates code based on the given prompt.
-    # The method takes two parameters: `prompt` (the input text) and `max_length` (the maximum length of the generated code).
     def generate_code(self, prompt, max_length=100):
-        # Encode the prompt into input IDs that the model can understand.
-        input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
-        # Generate the output sequence using the pre-trained model.
-        # The `generate` method takes the input IDs, the maximum length of the output, and the number of output sequences to return (in this case, 1).
-        output = self.model.generate(input_ids, max_length=max_length, num_return_sequences=1)
-        # Decode the output sequence and return the generated code.
-        return self.tokenizer.decode(output[0], skip_special_tokens=True)
+        try:
+            input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
+            output = self.model.generate(input_ids, max_length=max_length, num_return_sequences=1)
+            return self.tokenizer.decode(output[0], skip_special_tokens=True)
+        except Exception as e:
+            return f"Error generating code: {str(e)}"
 
-# Define a class named `ChatHandler` that will be responsible for managing the chat history.
 class ChatHandler:
-    # The constructor initializes the ChatHandler object with an empty chat history.
-    def __init__(self):
+    def __init__(self, code_generator):
         self.history = []
+        self.code_generator = code_generator  # Store the generator reference
 
-    # This method handles incoming messages and generates responses using the provided CodeGenerator.
-    # The method takes two parameters: `message` (the user's input message) and `code_generator` (an instance of the CodeGenerator class).
-    def handle_message(self, message, code_generator):
-        # Generate the response using the provided CodeGenerator.
-        response = code_generator.generate_code(message)
-        # Append the message-response pair to the chat history.
+    def handle_message(self, message):
+        if not message.strip():
+            return "", self.history
+        response = self.code_generator.generate_code(message)
         self.history.append((message, response))
-        # Return the empty message input and the updated chat history.
         return "", self.history
 
-# Define a function named `create_gradio_interface` that creates a Gradio interface for the chat application.
+    def clear_history(self):
+        self.history = []
+        return []
+
 def create_gradio_interface():
-    # Create an instance of the CodeGenerator class.
-    code_generator = CodeGenerator()
-    # Create an instance of the ChatHandler class.
-    chat_handler = ChatHandler()
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    code_generator = CodeGenerator(device=device)
+    chat_handler = ChatHandler(code_generator)
 
-    # Create a Gradio Blocks interface with a soft theme.
     with gr.Blocks(theme=gr.themes.Soft()) as demo:
-        # Display a Markdown title for the chat interface.
         gr.Markdown("# S-Dreamer Salesforce/codet5-base Chat Interface")
 
-        # Create a row with two columns.
         with gr.Row():
-            # The first column will contain the chat interface.
             with gr.Column(scale=3):
-                # Create a chatbot component to display the chat history.
                 chatbot = gr.Chatbot(height=400)
-                # Create a textbox for the user to input their message.
                 message_input = gr.Textbox(label="Enter your code-related query", placeholder="Type your message here...")
-                # Create a submit button to send the message.
                 submit_button = gr.Button("Submit")
 
-            # The second column will contain the features.
             with gr.Column(scale=1):
-                # Display a Markdown title for the features section.
                 gr.Markdown("## Features")
-                # Define a list of features.
                 features = ["Code generation", "Code completion", "Code explanation", "Error correction"]
-                # Display each feature as a Markdown list item.
                 for feature in features:
                     gr.Markdown(f"- {feature}")
-                # Create a button to clear the chat history.
                 clear_button = gr.Button("Clear Chat")
 
-        # Connect the submit button to the `handle_message` method of the ChatHandler.
-        submit_button.click(chat_handler.handle_message, inputs=[message_input], outputs=[message_input, chatbot])
-        # Connect the clear button to a function that clears the chat history.
-        clear_button.click(lambda: None, outputs=[chatbot], inputs=[])
+        submit_button.click(chat_handler.handle_message, inputs=message_input, outputs=[message_input, chatbot])
+        clear_button.click(lambda: (None, chat_handler.clear_history()), inputs=[], outputs=[message_input, chatbot])
 
-    # Launch the Gradio interface.
     demo.launch()
 
-# This is the entry point of the application.
 if __name__ == "__main__":
-    # Call the `create_gradio_interface` function to start the chat application.
     create_gradio_interface()
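
For quick verification, here is a minimal sketch of how the refactored classes could be exercised outside the Gradio UI. It assumes `app.py` is importable from the working directory and that `demo.launch()` stays behind the `__main__` guard as in this commit; the prompt string is purely illustrative.

    import torch

    from app import ChatHandler, CodeGenerator

    # Select a device the same way create_gradio_interface() does.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Build the generator and inject it into the handler (the new constructor signature).
    generator = CodeGenerator(device=device)
    handler = ChatHandler(generator)

    # handle_message() now takes only the message and returns ("", updated_history).
    _, history = handler.handle_message("Write a Python function that reverses a string.")
    print(history[-1][1])  # generated response text

    # clear_history() resets the stored conversation and returns an empty list for the Chatbot.
    handler.clear_history()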