diff options
Diffstat (limited to 'lib/connection.cpp')
-rw-r--r-- | lib/connection.cpp | 120 |
1 files changed, 85 insertions, 35 deletions
diff --git a/lib/connection.cpp b/lib/connection.cpp index 0ef002ca..28377dd9 100644 --- a/lib/connection.cpp +++ b/lib/connection.cpp @@ -119,6 +119,8 @@ public: PicklingMode picklingMode = Unencrypted {}; Database *database = nullptr; QHash<QString, int> oneTimeKeysCount; + std::vector<std::unique_ptr<EncryptedEvent>> pendingEncryptedEvents; + void handleEncryptedToDeviceEvent(const EncryptedEvent& event); // A map from SenderKey to vector of InboundSession UnorderedMap<QString, std::vector<QOlmSessionPtr>> olmSessions; @@ -219,7 +221,7 @@ public: } q->database()->saveOlmSession(senderKey, session->sessionId(), std::get<QByteArray>(pickleResult), QDateTime::currentDateTime()); } - QString sessionDecryptPrekey(const QOlmMessage& message, const QString &senderKey, std::unique_ptr<QOlmAccount>& olmAccount) + std::pair<QString, QString> sessionDecryptPrekey(const QOlmMessage& message, const QString &senderKey, std::unique_ptr<QOlmAccount>& olmAccount) { Q_ASSERT(message.type() == QOlmMessage::PreKey); for (size_t i = 0; i < olmSessions[senderKey].size(); i++) { @@ -239,7 +241,7 @@ public: auto s = std::move(session); olmSessions[senderKey].erase(olmSessions[senderKey].begin() + i); olmSessions[senderKey].insert(olmSessions[senderKey].begin(), std::move(s)); - return std::get<QString>(result); + return { std::get<QString>(result), session->sessionId() }; } else { qCDebug(E2EE) << "Failed to decrypt prekey message"; return {}; @@ -258,16 +260,17 @@ public: qWarning(E2EE) << "Failed to remove one time key for session" << newSession->sessionId(); } const auto result = newSession->decrypt(message); + QString sessionId = newSession->sessionId(); saveSession(newSession, senderKey); olmSessions[senderKey].insert(olmSessions[senderKey].begin(), std::move(newSession)); if(std::holds_alternative<QString>(result)) { - return std::get<QString>(result); + return { std::get<QString>(result), sessionId }; } else { qCDebug(E2EE) << "Failed to decrypt prekey message with new session"; return {}; } } - QString sessionDecryptGeneral(const QOlmMessage& message, const QString &senderKey) + std::pair<QString, QString> sessionDecryptGeneral(const QOlmMessage& message, const QString &senderKey) { Q_ASSERT(message.type() == QOlmMessage::General); for (size_t i = 0; i < olmSessions[senderKey].size(); i++) { @@ -284,31 +287,36 @@ public: auto s = std::move(session); olmSessions[senderKey].erase(olmSessions[senderKey].begin() + i); olmSessions[senderKey].insert(olmSessions[senderKey].begin(), std::move(s)); - return std::get<QString>(result); + return { std::get<QString>(result), session->sessionId() }; } } qCWarning(E2EE) << "Failed to decrypt message"; return {}; } - QString sessionDecryptMessage( + std::pair<QString, QString> sessionDecryptMessage( const QJsonObject& personalCipherObject, const QByteArray& senderKey, std::unique_ptr<QOlmAccount>& account) { QString decrypted; + QString olmSessionId; int type = personalCipherObject.value(TypeKeyL).toInt(-1); QByteArray body = personalCipherObject.value(BodyKeyL).toString().toLatin1(); if (type == QOlmMessage::PreKey) { QOlmMessage preKeyMessage(body, QOlmMessage::PreKey); - decrypted = sessionDecryptPrekey(preKeyMessage, senderKey, account); + auto result = sessionDecryptPrekey(preKeyMessage, senderKey, account); + decrypted = result.first; + olmSessionId = result.second; } else if (type == QOlmMessage::General) { QOlmMessage message(body, QOlmMessage::General); - decrypted = sessionDecryptGeneral(message, senderKey); + auto result = sessionDecryptGeneral(message, senderKey); + decrypted = result.first; + olmSessionId = result.second; } - return decrypted; + return { decrypted, olmSessionId }; } #endif - EventPtr sessionDecryptMessage(const EncryptedEvent& encryptedEvent) + std::pair<EventPtr, QString> sessionDecryptMessage(const EncryptedEvent& encryptedEvent) { #ifndef Quotient_E2EE_ENABLED qCWarning(E2EE) << "End-to-end encryption (E2EE) support is turned off."; @@ -324,7 +332,7 @@ public: qCDebug(E2EE) << "Encrypted event is not for the current device"; return {}; } - const auto decrypted = sessionDecryptMessage( + const auto [decrypted, olmSessionId] = sessionDecryptMessage( personalCipherObject, encryptedEvent.senderKey().toLatin1(), olmAccount); if (decrypted.isEmpty()) { qCDebug(E2EE) << "Problem with new session from senderKey:" @@ -343,9 +351,17 @@ public: << "in Olm plaintext"; return {}; } - //TODO make this do the check mentioned in the E2EE Implementation guide instead - if (decryptedEvent->fullJson()["keys"]["ed25519"].toString().isEmpty()) { - qCDebug(E2EE) << "Event does not contain an ed25519 key"; + + auto query = database->prepareQuery(QStringLiteral("SELECT edKey FROM tracked_devices WHERE curveKey=:curveKey;")); + query.bindValue(":curveKey", encryptedEvent.contentJson()["sender_key"].toString()); + database->execute(query); + if (!query.next()) { + qCWarning(E2EE) << "Received olm message from unknown device" << encryptedEvent.contentJson()["sender_key"].toString(); + return {}; + } + auto edKey = decryptedEvent->fullJson()["keys"]["ed25519"].toString(); + if (edKey.isEmpty() || query.value(QStringLiteral("edKey")).toString() != edKey) { + qCDebug(E2EE) << "Received olm message with invalid ed key"; return {}; } @@ -367,7 +383,7 @@ public: return {}; } - return std::move(decryptedEvent); + return { std::move(decryptedEvent), olmSessionId }; #endif // Quotient_E2EE_ENABLED } #ifdef Quotient_E2EE_ENABLED @@ -950,30 +966,46 @@ void Connection::Private::consumeToDeviceEvents(Events&& toDeviceEvents) qCDebug(E2EE) << "Unsupported algorithm" << event.id() << "for event" << event.algorithm(); return; } - const auto decryptedEvent = sessionDecryptMessage(event); - if(!decryptedEvent) { - qCWarning(E2EE) << "Failed to decrypt event" << event.id(); + if (q->isKnownCurveKey(event.senderId(), event.senderKey())) { + handleEncryptedToDeviceEvent(event); return; } - - switchOnType(*decryptedEvent, - [this, senderKey = event.senderKey()](const RoomKeyEvent& roomKeyEvent) { - if (auto* detectedRoom = q->room(roomKeyEvent.roomId())) { - detectedRoom->handleRoomKeyEvent(roomKeyEvent, senderKey); - } else { - qCDebug(E2EE) << "Encrypted event room id" << roomKeyEvent.roomId() - << "is not found at the connection" << q->objectName(); - } - }, - [](const Event& evt) { - qCDebug(E2EE) << "Skipping encrypted to_device event, type" - << evt.matrixType(); - }); + trackedUsers += event.senderId(); + outdatedUsers += event.senderId(); + encryptionUpdateRequired = true; + pendingEncryptedEvents.push_back(std::make_unique<EncryptedEvent>(event.fullJson())); + }, [](const Event& e){ + // Unhandled }); } #endif } +#ifdef Quotient_E2EE_ENABLED +void Connection::Private::handleEncryptedToDeviceEvent(const EncryptedEvent& event) +{ + const auto [decryptedEvent, olmSessionId] = sessionDecryptMessage(event); + if(!decryptedEvent) { + qCWarning(E2EE) << "Failed to decrypt event" << event.id(); + return; + } + + switchOnType(*decryptedEvent, + [this, senderKey = event.senderKey(), &event, olmSessionId = olmSessionId](const RoomKeyEvent& roomKeyEvent) { + if (auto* detectedRoom = q->room(roomKeyEvent.roomId())) { + detectedRoom->handleRoomKeyEvent(roomKeyEvent, event.senderId(), olmSessionId); + } else { + qCDebug(E2EE) << "Encrypted event room id" << roomKeyEvent.roomId() + << "is not found at the connection" << q->objectName(); + } + }, + [](const Event& evt) { + qCDebug(E2EE) << "Skipping encrypted to_device event, type" + << evt.matrixType(); + }); +} +#endif + void Connection::Private::consumeDevicesList(DevicesList&& devicesList) { #ifdef Quotient_E2EE_ENABLED @@ -2059,6 +2091,15 @@ void Connection::Private::loadOutdatedUserDevices() outdatedUsers -= user; } saveDevicesList(); + + for(size_t i = 0; i < pendingEncryptedEvents.size();) { + if (q->isKnownCurveKey(pendingEncryptedEvents[i]->fullJson()[SenderKeyL].toString(), pendingEncryptedEvents[i]->contentJson()["sender_key"].toString())) { + handleEncryptedToDeviceEvent(*(pendingEncryptedEvents[i].get())); + pendingEncryptedEvents.erase(pendingEncryptedEvents.begin() + i); + } else { + i++; + } + } }); } @@ -2180,14 +2221,14 @@ Database* Connection::database() return d->database; } -UnorderedMap<std::pair<QString, QString>, QOlmInboundGroupSessionPtr> Connection::loadRoomMegolmSessions(Room* room) +UnorderedMap<QString, QOlmInboundGroupSessionPtr> Connection::loadRoomMegolmSessions(Room* room) { return database()->loadMegolmSessions(room->id(), picklingMode()); } -void Connection::saveMegolmSession(Room* room, const QString& senderKey, QOlmInboundGroupSession* session, const QString& ed25519Key) +void Connection::saveMegolmSession(Room* room, QOlmInboundGroupSession* session) { - database()->saveMegolmSession(room->id(), senderKey, session->sessionId(), ed25519Key, session->pickle(picklingMode())); + database()->saveMegolmSession(room->id(), session->sessionId(), session->pickle(picklingMode()), session->senderId(), session->olmSessionId()); } QStringList Connection::devicesForUser(User* user) const @@ -2246,4 +2287,13 @@ void Connection::saveCurrentOutboundMegolmSession(Room *room, const QOlmOutbound d->database->saveCurrentOutboundMegolmSession(room->id(), d->picklingMode, data); } +bool Connection::isKnownCurveKey(const QString& user, const QString& curveKey) +{ + auto query = database()->prepareQuery(QStringLiteral("SELECT * FROM tracked_devices WHERE matrixId=:matrixId AND curveKey=:curveKey")); + query.bindValue(":matrixId", user); + query.bindValue(":curveKey", curveKey); + database()->execute(query); + return query.next(); +} + #endif |