import torch
from config import Config
from networks import peft_model

tokenizer = Config.tokenizer
tokenizer.pad_token = tokenizer.eos_token
# NOTE(review): registering an empty-string token looks suspicious -- possibly a
# placeholder token whose text was lost; confirm the intended token.
tokenizer.add_tokens('')


def _transcribe_audio(audio_model, audio_input):
    """Transcribe *audio_input* with *audio_model* and return the stripped text."""
    audio_transcribed = audio_model.transcribe(audio_input)
    return "".join(segment['text'] for segment in audio_transcribed['segments']).strip()


def _embed_image(clip_model, projection, image_input):
    """Encode *image_input* with ``clip_model`` and map it through ``projection``."""
    image_processed = Config.processor(images=image_input, return_tensors="pt")
    with torch.no_grad():
        outputs = clip_model(**image_processed)
        # Drop sequence position 0 (presumably the CLS token -- confirm).
        last_hidden_state = outputs.last_hidden_state[:, 1:, :]
        return projection(last_hidden_state.to(Config.device)).to(torch.float16)


def _embed_text(peft_model, text):
    """Tokenize *text*, wrap it in the separator ids and embed it with a batch axis of 1."""
    token_ids = (
        Config.IMAGE_SEPARATOR_TOKENS
        + tokenizer.encode(text)
        + [Config.QUESTION_ANSWER_SEPARATOR_ID]
    )
    with torch.no_grad():
        token_tensor = torch.tensor(token_ids).to(Config.device)
        return peft_model.model.model.embed_tokens(token_tensor).unsqueeze(0)


def prepare_inputs(peft_model, audio_model, clip_model, projection,
                   text_input=None, image_input=None, audio_input=None):
    """Build the ``inputs_embeds`` tensor fed to the language model.

    Audio (if given) is transcribed and appended to the text prompt with a
    single space. Text and image embeddings are concatenated along dim 1,
    with the image part first.

    Args:
        peft_model: Model whose ``model.model.embed_tokens`` embeds text ids.
        audio_model: Object with ``transcribe()``; required when
            *audio_input* is given.
        clip_model: Vision encoder; required when *image_input* is given.
        projection: Maps vision hidden states into the LM embedding space;
            required when *image_input* is given.
        text_input: Prompt text, or ``None``.
        image_input: Image accepted by ``Config.processor``, or ``None``.
        audio_input: Audio accepted by ``audio_model.transcribe``, or ``None``.

    Returns:
        The image embedding, the text embedding, or both concatenated on dim 1.

    Raises:
        ValueError: If a modality is supplied without its model, or if no
            usable input remains (e.g. only empty text / an empty transcript).
    """
    # ``is None`` rather than ``== None``/truthiness: image/audio may be numpy
    # arrays, for which both forms are elementwise or ambiguous and raise.
    processed_audio = None
    if audio_input is not None:
        if audio_model is None:
            raise ValueError("audio_model is required when audio_input is given")
        processed_audio = _transcribe_audio(audio_model, audio_input)

    image_embed = None
    if image_input is not None:
        if clip_model is None or projection is None:
            raise ValueError("clip_model and projection are required when image_input is given")
        image_embed = _embed_image(clip_model, projection, image_input)

    # Join whichever of text/transcript is non-empty; an all-empty prompt is
    # treated as "no text" so no empty embedding is concatenated.
    text_audio = " ".join(part for part in (text_input, processed_audio) if part) or None
    text_embed = _embed_text(peft_model, text_audio) if text_audio is not None else None

    embeds = [embed for embed in (image_embed, text_embed) if embed is not None]
    if not embeds:
        raise ValueError("prepare_inputs needs non-empty text, an image, or audio")
    return embeds[0] if len(embeds) == 1 else torch.cat(embeds, dim=1)


def chatbot_response(text_input, image_input, audio_input, *,
                     audio_model=None, clip_model=None, projection=None,
                     max_tokens=60):
    """Answer a multimodal query and return the decoded reply text.

    Args:
        text_input: Prompt text; ``''`` is treated as no text.
        image_input: Image or ``None``.
        audio_input: Audio or ``None``.
        audio_model: Speech model forwarded to ``prepare_inputs`` (needed for audio).
        clip_model: Vision encoder forwarded to ``prepare_inputs`` (needed for images).
        projection: Vision-to-LM projection forwarded to ``prepare_inputs``.
        max_tokens: Maximum number of tokens to generate.

    Returns:
        The decoded reply, or a prompt string when no input was supplied.
    """
    if text_input == '':
        text_input = None
    if text_input is None and image_input is None and audio_input is None:
        return "Please enter text, upload an image, or record audio."
    # prepare_inputs takes the models first; the previous call passed the user
    # inputs into the model slots.
    combined_embeds = prepare_inputs(
        peft_model, audio_model, clip_model, projection,
        text_input=text_input, image_input=image_input, audio_input=audio_input,
    )
    generated_tokens = generate_tokens(combined_embeds, max_tokens=max_tokens)
    return tokenizer.decode(generated_tokens)


def generate_tokens(combined_embeds, max_tokens=100, eos_token_id=50256):
    """Greedily decode up to *max_tokens* ids from *combined_embeds*.

    Uses ``.item()`` on the argmax, so it assumes a batch size of 1.

    Args:
        combined_embeds: Prompt embeddings passed as ``inputs_embeds``.
        max_tokens: Maximum number of tokens to generate.
        eos_token_id: Id that stops generation (not included in the output).
            The default matches the previously hard-coded 50256 -- presumably
            the tokenizer's EOS; confirm against ``tokenizer.eos_token_id``.

    Returns:
        List of generated token ids.
    """
    pred_tokens = []
    embeds = combined_embeds
    # Inference only: without no_grad every step would record an autograd graph
    # over the whole growing sequence.
    with torch.no_grad():
        for _ in range(max_tokens):
            logits = peft_model(inputs_embeds=embeds).logits[:, -1, :]
            next_token_id = logits.argmax(dim=-1)
            token = next_token_id.item()
            if token == eos_token_id:
                break
            pred_tokens.append(token)
            next_token_embed = peft_model.model.model.embed_tokens(next_token_id.unsqueeze(0))
            embeds = torch.cat((embeds, next_token_embed), dim=1)
    return pred_tokens