# utils.py (commit 54c36b3, "Update utils.py", by Sijuade; verified, 3 kB)
import torch
from config import Config
from networks import peft_model
# Shared tokenizer and model handles used by every function below, all taken from Config.
tokenizer = Config.tokenizer
# Padding reuses the EOS token.
tokenizer.pad_token = tokenizer.eos_token
# Extra marker token; presumably the one behind Config.QUESTION_ANSWER_SEPARATOR_ID — confirm.
# NOTE(review): no resize of the model's embedding table is visible here; confirm it covers the new id.
tokenizer.add_tokens('<question-answer>')
# NOTE(review): this rebinds the `peft_model` imported from `networks`; the Config instance is the one used below.
peft_model = Config.peft_model
audio_model = Config.audio_model
clip_model = Config.clip_model
projection = Config.projection
def _provided(value):
    """Return True when *value* counts as a supplied input.

    ``None`` and the empty string mean "not supplied". Any other value is
    accepted without calling ``bool()`` on it, because array-like inputs
    (e.g. numpy images or audio) raise on truth-testing.
    """
    if value is None:
        return False
    if isinstance(value, str):
        return value != ''
    return True


def _transcribe_audio(audio_input):
    """Transcribe *audio_input* with ``audio_model`` and return the stripped text.

    The ``'text'`` of every entry in the result's ``'segments'`` list is
    concatenated in order, then leading/trailing whitespace is removed.
    """
    transcription = audio_model.transcribe(audio_input)
    return ''.join(segment['text'] for segment in transcription['segments']).strip()


def _embed_image(image_input):
    """Encode *image_input* with ``clip_model`` and map it through ``projection``.

    Runs entirely under ``torch.no_grad()`` (inference only) and returns the
    projected tensor cast to float16.
    """
    image_processed = Config.processor(images=image_input, return_tensors="pt")
    with torch.no_grad():
        outputs = clip_model(**image_processed.to(Config.device))
        # Drop sequence position 0 (presumably the class token — TODO confirm), keep the rest.
        last_hidden_state = outputs.last_hidden_state[:, 1:, :]
        return projection(last_hidden_state.to(Config.device)).to(torch.float16)


def _embed_text(text):
    """Tokenize *text*, add the separator ids, and embed it with the PEFT model.

    The id sequence is ``Config.IMAGE_SEPARATOR_TOKENS + tokens +
    [Config.QUESTION_ANSWER_SEPARATOR_ID]``. The result gets a leading
    batch dimension of 1.
    """
    token_ids = (
        Config.IMAGE_SEPARATOR_TOKENS
        + tokenizer.encode(text)
        + [Config.QUESTION_ANSWER_SEPARATOR_ID]
    )
    with torch.no_grad():
        token_tensor = torch.tensor(token_ids).to(Config.device)
        return peft_model.model.model.embed_tokens(token_tensor).unsqueeze(0)


def prepare_inputs(text_input=None, image_input=None, audio_input=None):
    """Build the input-embedding sequence for the language model.

    Text and transcribed audio are merged into one string (text first, joined
    by a single space), then embedded by ``_embed_text``. When an image is
    given, its projected embedding is placed before the text embedding.

    Args:
        text_input: Optional prompt string; ``None`` or ``''`` means no text.
        image_input: Optional image accepted by ``Config.processor``. Only
            ``None`` (or ``''``) means "no image"; arrays are never truth-tested.
        audio_input: Optional audio accepted by ``audio_model.transcribe``.
            ``None`` or ``''`` means no audio.

    Returns:
        The image and/or text embeddings, concatenated along dim 1 when both
        exist. Presumably shaped ``(1, seq_len, hidden)``; TODO confirm
        against the ``projection`` output.

    Raises:
        ValueError: if there is neither non-empty text/transcript nor an
            image, e.g. only audio was given and it transcribed to nothing.
    """
    text_parts = []
    if _provided(text_input):
        text_parts.append(text_input)
    if _provided(audio_input):
        text_parts.append(_transcribe_audio(audio_input))
    text_audio = ' '.join(text_parts)

    # Decide from what was actually produced. This avoids concatenating None
    # or leaving the result unbound when a transcript comes back empty.
    embeds = []
    if _provided(image_input):
        embeds.append(_embed_image(image_input))
    if text_audio:
        embeds.append(_embed_text(text_audio))

    if not embeds:
        raise ValueError(
            "prepare_inputs() needs non-empty text, audio that transcribes to text, or an image"
        )
    return torch.cat(embeds, dim=1) if len(embeds) > 1 else embeds[0]
def chatbot_response(text_input, image_input, audio_input, max_tokens=60):
    """Answer a text/image/audio query and return the decoded model reply.

    Args:
        text_input: Prompt text; ``''`` is treated the same as ``None``.
        image_input: Image, or ``None`` when absent.
        audio_input: Audio recording, or ``None`` when absent.
        max_tokens: Upper bound on generated tokens. Defaults to 60, the
            value previously hard-coded here.

    Returns:
        The decoded reply string, or a prompt asking for input when all
        three inputs are missing.
    """
    if text_input == '':
        text_input = None
    # Use `is None`, not `== None`: comparing a numpy image to None is
    # element-wise, and truth-testing the resulting array raises ValueError.
    if text_input is None and image_input is None and audio_input is None:
        return "Please enter text, upload an image, or record audio."
    combined_embeds = prepare_inputs(text_input, image_input, audio_input)
    generated_tokens = generate_tokens(combined_embeds, max_tokens=max_tokens)
    return tokenizer.decode(generated_tokens)
def generate_tokens(combined_embeds, max_tokens=100, eos_token_id=50256):
    """Greedily decode token ids, starting from a prefix of input embeddings.

    Each step re-runs ``peft_model`` on the whole embedding sequence (no KV
    cache). It takes the argmax of the last position's logits and appends
    that token's embedding to the sequence.

    Args:
        combined_embeds: Prompt embeddings, as returned by ``prepare_inputs``.
            The ``.item()`` calls assume a batch size of 1.
        max_tokens: Maximum number of tokens to generate.
        eos_token_id: Token id that stops generation and is not included in
            the output. Defaults to 50256, the value previously hard-coded.

    Returns:
        List of generated token ids (ints).
    """
    pred_tokens = []
    combined_embed = combined_embeds
    # Inference only. Without no_grad around the forward pass, every step can
    # record an autograd graph over the growing sequence (e.g. when adapter
    # weights require grad), wasting memory.
    with torch.no_grad():
        for _ in range(max_tokens):
            logits = peft_model(inputs_embeds=combined_embed).logits[:, -1, :]
            next_token_id = logits.argmax(dim=-1)
            token = next_token_id.item()
            if token == eos_token_id:
                break
            pred_tokens.append(token)
            next_token_embed = peft_model.model.model.embed_tokens(next_token_id.unsqueeze(0))
            combined_embed = torch.cat((combined_embed, next_token_embed), dim=1)
    return pred_tokens