Spaces:
Running
Running
| // Copyright 2025 The ODML Authors. | |
| // | |
| // Licensed under the Apache License, Version 2.0 (the "License"); | |
| // you may not use this file except in compliance with the License. | |
| // You may obtain a copy of the License at | |
| // | |
| // http://www.apache.org/licenses/LICENSE-2.0 | |
| // | |
| // Unless required by applicable law or agreed to in writing, software | |
| // distributed under the License is distributed on an "AS IS" BASIS, | |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| // See the License for the specific language governing permissions and | |
| // limitations under the License. | |
| namespace litert::lm { | |
| absl::StatusOr<std::unique_ptr<LoraManager>> LoraManager::Create( | |
| const litert::CompiledModel& compiled_model) { | |
| return absl::WrapUnique(new LoraManager(compiled_model)); | |
| } | |
| LoraManager::LoraManager(const litert::CompiledModel& compiled_model) | |
| : compiled_model_(compiled_model) {} | |
| absl::Status LoraManager::LoadLoRA(uint32_t lora_id, | |
| const ModelAssets& model_assets) { | |
| if (lora_data_.contains(lora_id)) { | |
| return absl::AlreadyExistsError("LoRA ID already exists"); | |
| } | |
| ASSIGN_OR_RETURN(auto scoped_file, model_assets.GetOrCreateScopedFile()); | |
| ASSIGN_OR_RETURN(auto lora_data, LoraData::CreateFromScopedFile(scoped_file)); | |
| lora_data_[lora_id] = std::move(lora_data); | |
| return absl::OkStatus(); | |
| } | |
| absl::Status LoraManager::UseLoRA(uint32_t lora_id) { | |
| if (!lora_data_.contains(lora_id) && !loras_.contains(lora_id)) { | |
| return absl::NotFoundError("LoRA ID not found"); | |
| } | |
| if (!loras_.contains(lora_id)) { | |
| ASSIGN_OR_RETURN(auto lora, LoRA::Create(std::move(lora_data_[lora_id]), | |
| compiled_model_)); | |
| loras_[lora_id] = std::move(lora); | |
| lora_data_.erase(lora_id); | |
| } | |
| current_lora_id_ = lora_id; | |
| return absl::OkStatus(); | |
| } | |
| absl::StatusOr<absl::flat_hash_map<absl::string_view, litert::TensorBuffer>> | |
| LoraManager::GetLoRABuffers() const { | |
| if (!current_lora_id_.has_value()) { | |
| return absl::FailedPreconditionError("No LoRA ID is set"); | |
| } | |
| if (!loras_.contains(*current_lora_id_)) { | |
| return absl::NotFoundError("LoRA ID not found"); | |
| } | |
| return loras_.at(*current_lora_id_)->GetLoRABuffers(); | |
| } | |
| } // namespace litert::lm | |