zeroMN committed
Commit 8440578 · verified · 1 Parent(s): 8125348

Update app.py

Files changed (1)
  1. app.py +1 -5
app.py CHANGED
@@ -29,11 +29,10 @@ class MultiModalModel(nn.Module):
         self.nlp_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
         self.speech_processor = AutoTokenizer.from_pretrained('facebook/wav2vec2-base-960h')
         self.vision_processor = AutoTokenizer.from_pretrained('openai/clip-vit-base-patch32')
-
+
     def forward(self, task, inputs):
         if task == 'text_generation':
             attention_mask = inputs.get('attention_mask')
-            print("Input data:", inputs)
             outputs = self.text_generator.generate(
                 inputs['input_ids'],
                 max_new_tokens=100,
@@ -44,7 +43,6 @@ class MultiModalModel(nn.Module):
                 temperature=0.8,
                 do_sample=True
             )
-            print("Generated output:", outputs)
             return self.text_tokenizer.decode(outputs[0], skip_special_tokens=True)
         elif task == 'code_generation':
             attention_mask = inputs.get('attention_mask')
@@ -59,7 +57,6 @@ class MultiModalModel(nn.Module):
                 do_sample=True
             )
             return self.code_tokenizer.decode(outputs[0], skip_special_tokens=True)
-        # Add logic for other tasks...
 
 # Define the inference function for the Gradio interface
 def gradio_inference(task, input_text):
@@ -90,4 +87,3 @@ interface = gr.Interface(
 
 # Launch the Gradio app
 interface.launch()
-
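The net effect of the commit is to drop the ad-hoc print() tracing from forward() (plus a trailing blank line and a stray comment). If that visibility is still wanted without polluting stdout, one common alternative is the standard logging module, so traces can be toggled by log level instead of code edits. A minimal sketch; the helper name and placement are illustrative and not part of app.py:

```python
import logging

logger = logging.getLogger("app")

def trace_generation(inputs, outputs):
    # Hypothetical helper standing in for the removed print() calls;
    # emits at DEBUG so production runs stay quiet by default.
    logger.debug("input data: %s", inputs)
    logger.debug("generated output: %s", outputs)

if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    trace_generation({"input_ids": [[101, 2023, 102]]}, [[101, 2023, 2003, 102]])
```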
 
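For reference, the text_generation branch shown in the hunks above boils down to a transformers generate() call followed by a decode. The self-contained sketch below is written under assumptions: gpt2 stands in for the self.text_generator checkpoint (which is not visible in these hunks) and the prompt is arbitrary; only the generation parameters max_new_tokens=100, temperature=0.8, and do_sample=True are mirrored from the diff.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")      # assumed stand-in checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

enc = tokenizer("Hello, world", return_tensors="pt")
outputs = model.generate(
    enc["input_ids"],
    attention_mask=enc["attention_mask"],
    max_new_tokens=100,                    # parameters mirrored from the diff
    temperature=0.8,
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id,   # gpt2 has no pad token; avoids a warning
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```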
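The tail of the file wraps gradio_inference(task, input_text) in a gr.Interface and calls interface.launch(). The actual interface definition sits outside the hunks shown here, so the wiring below is a hypothetical reconstruction that only mirrors the visible function signature and task names:

```python
import gradio as gr

def gradio_inference(task, input_text):
    # Placeholder body for illustration; the real function dispatches to
    # MultiModalModel.forward for the selected task.
    return f"[{task}] {input_text}"

interface = gr.Interface(
    fn=gradio_inference,
    inputs=[
        gr.Dropdown(["text_generation", "code_generation"], label="task"),
        gr.Textbox(label="input_text"),
    ],
    outputs=gr.Textbox(label="output"),
)

if __name__ == "__main__":
    interface.launch()
```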