aboutsummaryrefslogtreecommitdiff
path: root/lib/connection.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'lib/connection.cpp')
-rw-r--r--lib/connection.cpp120
1 files changed, 85 insertions, 35 deletions
diff --git a/lib/connection.cpp b/lib/connection.cpp
index 0ef002ca..28377dd9 100644
--- a/lib/connection.cpp
+++ b/lib/connection.cpp
@@ -119,6 +119,8 @@ public:
PicklingMode picklingMode = Unencrypted {};
Database *database = nullptr;
QHash<QString, int> oneTimeKeysCount;
+ std::vector<std::unique_ptr<EncryptedEvent>> pendingEncryptedEvents;
+ void handleEncryptedToDeviceEvent(const EncryptedEvent& event);
// A map from SenderKey to vector of InboundSession
UnorderedMap<QString, std::vector<QOlmSessionPtr>> olmSessions;
@@ -219,7 +221,7 @@ public:
}
q->database()->saveOlmSession(senderKey, session->sessionId(), std::get<QByteArray>(pickleResult), QDateTime::currentDateTime());
}
- QString sessionDecryptPrekey(const QOlmMessage& message, const QString &senderKey, std::unique_ptr<QOlmAccount>& olmAccount)
+ std::pair<QString, QString> sessionDecryptPrekey(const QOlmMessage& message, const QString &senderKey, std::unique_ptr<QOlmAccount>& olmAccount)
{
Q_ASSERT(message.type() == QOlmMessage::PreKey);
for (size_t i = 0; i < olmSessions[senderKey].size(); i++) {
@@ -239,7 +241,7 @@ public:
auto s = std::move(session);
olmSessions[senderKey].erase(olmSessions[senderKey].begin() + i);
olmSessions[senderKey].insert(olmSessions[senderKey].begin(), std::move(s));
- return std::get<QString>(result);
+ return { std::get<QString>(result), session->sessionId() };
} else {
qCDebug(E2EE) << "Failed to decrypt prekey message";
return {};
@@ -258,16 +260,17 @@ public:
qWarning(E2EE) << "Failed to remove one time key for session" << newSession->sessionId();
}
const auto result = newSession->decrypt(message);
+ QString sessionId = newSession->sessionId();
saveSession(newSession, senderKey);
olmSessions[senderKey].insert(olmSessions[senderKey].begin(), std::move(newSession));
if(std::holds_alternative<QString>(result)) {
- return std::get<QString>(result);
+ return { std::get<QString>(result), sessionId };
} else {
qCDebug(E2EE) << "Failed to decrypt prekey message with new session";
return {};
}
}
- QString sessionDecryptGeneral(const QOlmMessage& message, const QString &senderKey)
+ std::pair<QString, QString> sessionDecryptGeneral(const QOlmMessage& message, const QString &senderKey)
{
Q_ASSERT(message.type() == QOlmMessage::General);
for (size_t i = 0; i < olmSessions[senderKey].size(); i++) {
@@ -284,31 +287,36 @@ public:
auto s = std::move(session);
olmSessions[senderKey].erase(olmSessions[senderKey].begin() + i);
olmSessions[senderKey].insert(olmSessions[senderKey].begin(), std::move(s));
- return std::get<QString>(result);
+ return { std::get<QString>(result), session->sessionId() };
}
}
qCWarning(E2EE) << "Failed to decrypt message";
return {};
}
- QString sessionDecryptMessage(
+ std::pair<QString, QString> sessionDecryptMessage(
const QJsonObject& personalCipherObject, const QByteArray& senderKey, std::unique_ptr<QOlmAccount>& account)
{
QString decrypted;
+ QString olmSessionId;
int type = personalCipherObject.value(TypeKeyL).toInt(-1);
QByteArray body = personalCipherObject.value(BodyKeyL).toString().toLatin1();
if (type == QOlmMessage::PreKey) {
QOlmMessage preKeyMessage(body, QOlmMessage::PreKey);
- decrypted = sessionDecryptPrekey(preKeyMessage, senderKey, account);
+ auto result = sessionDecryptPrekey(preKeyMessage, senderKey, account);
+ decrypted = result.first;
+ olmSessionId = result.second;
} else if (type == QOlmMessage::General) {
QOlmMessage message(body, QOlmMessage::General);
- decrypted = sessionDecryptGeneral(message, senderKey);
+ auto result = sessionDecryptGeneral(message, senderKey);
+ decrypted = result.first;
+ olmSessionId = result.second;
}
- return decrypted;
+ return { decrypted, olmSessionId };
}
#endif
- EventPtr sessionDecryptMessage(const EncryptedEvent& encryptedEvent)
+ std::pair<EventPtr, QString> sessionDecryptMessage(const EncryptedEvent& encryptedEvent)
{
#ifndef Quotient_E2EE_ENABLED
qCWarning(E2EE) << "End-to-end encryption (E2EE) support is turned off.";
@@ -324,7 +332,7 @@ public:
qCDebug(E2EE) << "Encrypted event is not for the current device";
return {};
}
- const auto decrypted = sessionDecryptMessage(
+ const auto [decrypted, olmSessionId] = sessionDecryptMessage(
personalCipherObject, encryptedEvent.senderKey().toLatin1(), olmAccount);
if (decrypted.isEmpty()) {
qCDebug(E2EE) << "Problem with new session from senderKey:"
@@ -343,9 +351,17 @@ public:
<< "in Olm plaintext";
return {};
}
- //TODO make this do the check mentioned in the E2EE Implementation guide instead
- if (decryptedEvent->fullJson()["keys"]["ed25519"].toString().isEmpty()) {
- qCDebug(E2EE) << "Event does not contain an ed25519 key";
+
+ auto query = database->prepareQuery(QStringLiteral("SELECT edKey FROM tracked_devices WHERE curveKey=:curveKey;"));
+ query.bindValue(":curveKey", encryptedEvent.contentJson()["sender_key"].toString());
+ database->execute(query);
+ if (!query.next()) {
+ qCWarning(E2EE) << "Received olm message from unknown device" << encryptedEvent.contentJson()["sender_key"].toString();
+ return {};
+ }
+ auto edKey = decryptedEvent->fullJson()["keys"]["ed25519"].toString();
+ if (edKey.isEmpty() || query.value(QStringLiteral("edKey")).toString() != edKey) {
+ qCDebug(E2EE) << "Received olm message with invalid ed key";
return {};
}
@@ -367,7 +383,7 @@ public:
return {};
}
- return std::move(decryptedEvent);
+ return { std::move(decryptedEvent), olmSessionId };
#endif // Quotient_E2EE_ENABLED
}
#ifdef Quotient_E2EE_ENABLED
@@ -950,30 +966,46 @@ void Connection::Private::consumeToDeviceEvents(Events&& toDeviceEvents)
qCDebug(E2EE) << "Unsupported algorithm" << event.id() << "for event" << event.algorithm();
return;
}
- const auto decryptedEvent = sessionDecryptMessage(event);
- if(!decryptedEvent) {
- qCWarning(E2EE) << "Failed to decrypt event" << event.id();
+ if (q->isKnownCurveKey(event.senderId(), event.senderKey())) {
+ handleEncryptedToDeviceEvent(event);
return;
}
-
- switchOnType(*decryptedEvent,
- [this, senderKey = event.senderKey()](const RoomKeyEvent& roomKeyEvent) {
- if (auto* detectedRoom = q->room(roomKeyEvent.roomId())) {
- detectedRoom->handleRoomKeyEvent(roomKeyEvent, senderKey);
- } else {
- qCDebug(E2EE) << "Encrypted event room id" << roomKeyEvent.roomId()
- << "is not found at the connection" << q->objectName();
- }
- },
- [](const Event& evt) {
- qCDebug(E2EE) << "Skipping encrypted to_device event, type"
- << evt.matrixType();
- });
+ trackedUsers += event.senderId();
+ outdatedUsers += event.senderId();
+ encryptionUpdateRequired = true;
+ pendingEncryptedEvents.push_back(std::make_unique<EncryptedEvent>(event.fullJson()));
+ }, [](const Event& e){
+ // Unhandled
});
}
#endif
}
+#ifdef Quotient_E2EE_ENABLED
+void Connection::Private::handleEncryptedToDeviceEvent(const EncryptedEvent& event)
+{
+ const auto [decryptedEvent, olmSessionId] = sessionDecryptMessage(event);
+ if(!decryptedEvent) {
+ qCWarning(E2EE) << "Failed to decrypt event" << event.id();
+ return;
+ }
+
+ switchOnType(*decryptedEvent,
+ [this, senderKey = event.senderKey(), &event, olmSessionId = olmSessionId](const RoomKeyEvent& roomKeyEvent) {
+ if (auto* detectedRoom = q->room(roomKeyEvent.roomId())) {
+ detectedRoom->handleRoomKeyEvent(roomKeyEvent, event.senderId(), olmSessionId);
+ } else {
+ qCDebug(E2EE) << "Encrypted event room id" << roomKeyEvent.roomId()
+ << "is not found at the connection" << q->objectName();
+ }
+ },
+ [](const Event& evt) {
+ qCDebug(E2EE) << "Skipping encrypted to_device event, type"
+ << evt.matrixType();
+ });
+}
+#endif
+
void Connection::Private::consumeDevicesList(DevicesList&& devicesList)
{
#ifdef Quotient_E2EE_ENABLED
@@ -2059,6 +2091,15 @@ void Connection::Private::loadOutdatedUserDevices()
outdatedUsers -= user;
}
saveDevicesList();
+
+ for(size_t i = 0; i < pendingEncryptedEvents.size();) {
+ if (q->isKnownCurveKey(pendingEncryptedEvents[i]->fullJson()[SenderKeyL].toString(), pendingEncryptedEvents[i]->contentJson()["sender_key"].toString())) {
+ handleEncryptedToDeviceEvent(*(pendingEncryptedEvents[i].get()));
+ pendingEncryptedEvents.erase(pendingEncryptedEvents.begin() + i);
+ } else {
+ i++;
+ }
+ }
});
}
@@ -2180,14 +2221,14 @@ Database* Connection::database()
return d->database;
}
-UnorderedMap<std::pair<QString, QString>, QOlmInboundGroupSessionPtr> Connection::loadRoomMegolmSessions(Room* room)
+UnorderedMap<QString, QOlmInboundGroupSessionPtr> Connection::loadRoomMegolmSessions(Room* room)
{
return database()->loadMegolmSessions(room->id(), picklingMode());
}
-void Connection::saveMegolmSession(Room* room, const QString& senderKey, QOlmInboundGroupSession* session, const QString& ed25519Key)
+void Connection::saveMegolmSession(Room* room, QOlmInboundGroupSession* session)
{
- database()->saveMegolmSession(room->id(), senderKey, session->sessionId(), ed25519Key, session->pickle(picklingMode()));
+ database()->saveMegolmSession(room->id(), session->sessionId(), session->pickle(picklingMode()), session->senderId(), session->olmSessionId());
}
QStringList Connection::devicesForUser(User* user) const
@@ -2246,4 +2287,13 @@ void Connection::saveCurrentOutboundMegolmSession(Room *room, const QOlmOutbound
d->database->saveCurrentOutboundMegolmSession(room->id(), d->picklingMode, data);
}
+bool Connection::isKnownCurveKey(const QString& user, const QString& curveKey)
+{
+ auto query = database()->prepareQuery(QStringLiteral("SELECT * FROM tracked_devices WHERE matrixId=:matrixId AND curveKey=:curveKey"));
+ query.bindValue(":matrixId", user);
+ query.bindValue(":curveKey", curveKey);
+ database()->execute(query);
+ return query.next();
+}
+
#endif