diff options
-rw-r--r-- | lib/connection.cpp | 12 | ||||
-rw-r--r-- | lib/connection.h | 6 | ||||
-rw-r--r-- | lib/database.cpp | 23 | ||||
-rw-r--r-- | lib/database.h | 6 | ||||
-rw-r--r-- | lib/room.cpp | 51 |
5 files changed, 47 insertions, 51 deletions
diff --git a/lib/connection.cpp b/lib/connection.cpp index 66e21a2a..dba18cb1 100644 --- a/lib/connection.cpp +++ b/lib/connection.cpp @@ -2214,9 +2214,9 @@ void Connection::saveMegolmSession(const Room* room, session.senderId(), session.olmSessionId()); } -QStringList Connection::devicesForUser(User* user) const +QStringList Connection::devicesForUser(const QString& userId) const { - return d->deviceKeys[user->id()].keys(); + return d->deviceKeys[userId].keys(); } QString Connection::curveKeyForUserDevice(const QString& user, const QString& device) const @@ -2238,15 +2238,15 @@ bool Connection::isKnownCurveKey(const QString& user, const QString& curveKey) return query.next(); } -bool Connection::hasOlmSession(User* user, const QString& deviceId) const +bool Connection::hasOlmSession(const QString& user, const QString& deviceId) const { - const auto& curveKey = curveKeyForUserDevice(user->id(), deviceId); + const auto& curveKey = curveKeyForUserDevice(user, deviceId); return d->olmSessions.contains(curveKey) && !d->olmSessions[curveKey].empty(); } -QPair<QOlmMessage::Type, QByteArray> Connection::olmEncryptMessage(User* user, const QString& device, const QByteArray& message) +QPair<QOlmMessage::Type, QByteArray> Connection::olmEncryptMessage(const QString& user, const QString& device, const QByteArray& message) { - const auto& curveKey = curveKeyForUserDevice(user->id(), device); + const auto& curveKey = curveKeyForUserDevice(user, device); QOlmMessage::Type type = d->olmSessions[curveKey][0]->encryptMessageType(); auto result = d->olmSessions[curveKey][0]->encrypt(message); auto pickle = d->olmSessions[curveKey][0]->pickle(picklingMode()); diff --git a/lib/connection.h b/lib/connection.h index 5b266aad..f8744752 100644 --- a/lib/connection.h +++ b/lib/connection.h @@ -323,14 +323,14 @@ public: const Room* room); void saveMegolmSession(const Room* room, const QOlmInboundGroupSession& session); - bool hasOlmSession(User* user, const QString& deviceId) const; + bool hasOlmSession(const QString& user, const QString& deviceId) const; QOlmOutboundGroupSessionPtr loadCurrentOutboundMegolmSession(Room* room); void saveCurrentOutboundMegolmSession(Room *room, const QOlmOutboundGroupSessionPtr& data); //This assumes that an olm session with (user, device) exists - QPair<QOlmMessage::Type, QByteArray> olmEncryptMessage(User* user, const QString& device, const QByteArray& message); + QPair<QOlmMessage::Type, QByteArray> olmEncryptMessage(const QString& userId, const QString& device, const QByteArray& message); void createOlmSession(const QString& theirIdentityKey, const QString& theirOneTimeKey); #endif // Quotient_E2EE_ENABLED Q_INVOKABLE Quotient::SyncJob* syncJob() const; @@ -694,7 +694,7 @@ public Q_SLOTS: PicklingMode picklingMode() const; QJsonObject decryptNotification(const QJsonObject ¬ification); - QStringList devicesForUser(User* user) const; + QStringList devicesForUser(const QString& user) const; QString curveKeyForUserDevice(const QString &user, const QString& device) const; QString edKeyForUserDevice(const QString& user, const QString& device) const; bool isKnownCurveKey(const QString& user, const QString& curveKey); diff --git a/lib/database.cpp b/lib/database.cpp index 99c6f358..3255e5e7 100644 --- a/lib/database.cpp +++ b/lib/database.cpp @@ -13,9 +13,7 @@ #include "e2ee/e2ee.h" #include "e2ee/qolmsession.h" #include "e2ee/qolminboundsession.h" -#include "connection.h" -#include "user.h" -#include "room.h" +#include "e2ee/qolmoutboundsession.h" using namespace Quotient; Database::Database(const QString& matrixId, const QString& deviceId, QObject* parent) @@ -348,17 +346,16 @@ QOlmOutboundGroupSessionPtr Database::loadCurrentOutboundMegolmSession(const QSt return nullptr; } -void Database::setDevicesReceivedKey(const QString& roomId, QHash<User *, QStringList> devices, const QString& sessionId, int index) +void Database::setDevicesReceivedKey(const QString& roomId, const QHash<QString, QList<std::pair<QString, QString>>>& devices, const QString& sessionId, int index) { - auto connection = dynamic_cast<Connection *>(parent()); transaction(); for (const auto& user : devices.keys()) { - for (const auto& device : devices[user]) { + for (const auto& [device, curveKey] : devices[user]) { auto query = prepareQuery(QStringLiteral("INSERT INTO sent_megolm_sessions(roomId, userId, deviceId, identityKey, sessionId, i) VALUES(:roomId, :userId, :deviceId, :identityKey, :sessionId, :i);")); query.bindValue(":roomId", roomId); - query.bindValue(":userId", user->id()); + query.bindValue(":userId", user); query.bindValue(":deviceId", device); - query.bindValue(":identityKey", connection->curveKeyForUserDevice(user->id(), device)); + query.bindValue(":identityKey", curveKey); query.bindValue(":sessionId", sessionId); query.bindValue(":i", index); execute(query); @@ -367,16 +364,10 @@ void Database::setDevicesReceivedKey(const QString& roomId, QHash<User *, QStrin commit(); } -QHash<QString, QStringList> Database::devicesWithoutKey(Room* room, const QString &sessionId) +QHash<QString, QStringList> Database::devicesWithoutKey(const QString& roomId, QHash<QString, QStringList>& devices, const QString &sessionId) { - auto connection = dynamic_cast<Connection *>(parent()); - QHash<QString, QStringList> devices; - for (const auto& user : room->users()) { - devices[user->id()] = connection->devicesForUser(user); - } - auto query = prepareQuery(QStringLiteral("SELECT userId, deviceId FROM sent_megolm_sessions WHERE roomId=:roomId AND sessionId=:sessionId")); - query.bindValue(":roomId", room->id()); + query.bindValue(":roomId", roomId); query.bindValue(":sessionId", sessionId); transaction(); execute(query); diff --git a/lib/database.h b/lib/database.h index 00002204..8bef332f 100644 --- a/lib/database.h +++ b/lib/database.h @@ -44,9 +44,9 @@ public: void saveCurrentOutboundMegolmSession(const QString& roomId, const PicklingMode& picklingMode, const QOlmOutboundGroupSessionPtr& data); void updateOlmSession(const QString& senderKey, const QString& sessionId, const QByteArray& pickle); - // Returns a map User -> [Device] that have not received key yet - QHash<QString, QStringList> devicesWithoutKey(Room* room, const QString &sessionId); - void setDevicesReceivedKey(const QString& roomId, QHash<User *, QStringList> devices, const QString& sessionId, int index); + // Returns a map UserId -> [DeviceId] that have not received key yet + QHash<QString, QStringList> devicesWithoutKey(const QString& roomId, QHash<QString, QStringList>& devices, const QString &sessionId); + void setDevicesReceivedKey(const QString& roomId, const QHash<QString, QList<std::pair<QString, QString>>>& devices, const QString& sessionId, int index); private: void migrateTo1(); diff --git a/lib/room.cpp b/lib/room.cpp index d77bf9ef..5d3ae329 100644 --- a/lib/room.cpp +++ b/lib/room.cpp @@ -455,16 +455,16 @@ public: addInboundGroupSession(currentOutboundMegolmSession->sessionId(), *sessionKey, q->localUser()->id(), "SELF"_ls); } - std::unique_ptr<EncryptedEvent> payloadForUserDevice(User* user, const QString& device, const QByteArray& sessionId, const QByteArray& sessionKey) + std::unique_ptr<EncryptedEvent> payloadForUserDevice(QString user, const QString& device, const QByteArray& sessionId, const QByteArray& sessionKey) { // Noisy but nice for debugging //qCDebug(E2EE) << "Creating the payload for" << user->id() << device << sessionId << sessionKey.toHex(); const auto event = makeEvent<RoomKeyEvent>("m.megolm.v1.aes-sha2", q->id(), sessionId, sessionKey, q->localUser()->id()); QJsonObject payloadJson = event->fullJson(); - payloadJson["recipient"] = user->id(); + payloadJson["recipient"] = user; payloadJson["sender"] = connection->user()->id(); QJsonObject recipientObject; - recipientObject["ed25519"] = connection->edKeyForUserDevice(user->id(), device); + recipientObject["ed25519"] = connection->edKeyForUserDevice(user, device); payloadJson["recipient_keys"] = recipientObject; QJsonObject senderObject; senderObject["ed25519"] = QString(connection->olmAccount()->identityKeys().ed25519); @@ -472,22 +472,21 @@ public: payloadJson["sender_device"] = connection->deviceId(); auto cipherText = connection->olmEncryptMessage(user, device, QJsonDocument(payloadJson).toJson(QJsonDocument::Compact)); QJsonObject encrypted; - encrypted[connection->curveKeyForUserDevice(user->id(), device)] = QJsonObject{{"type", cipherText.first}, {"body", QString(cipherText.second)}}; + encrypted[connection->curveKeyForUserDevice(user, device)] = QJsonObject{{"type", cipherText.first}, {"body", QString(cipherText.second)}}; return makeEvent<EncryptedEvent>(encrypted, connection->olmAccount()->identityKeys().curve25519); } - QHash<User*, QStringList> getDevicesWithoutKey() const + QHash<QString, QStringList> getDevicesWithoutKey() const { - QHash<User*, QStringList> devices; - auto rawDevices = q->connection()->database()->devicesWithoutKey(q, QString(currentOutboundMegolmSession->sessionId())); - for (const auto& user : rawDevices.keys()) { - devices[q->connection()->user(user)] = rawDevices[user]; + QHash<QString, QStringList> devices; + for (const auto& user : q->users()) { + devices[user->id()] = q->connection()->devicesForUser(user->id()); } - return devices; + return q->connection()->database()->devicesWithoutKey(q->id(), devices, QString(currentOutboundMegolmSession->sessionId())); } - void sendRoomKeyToDevices(const QByteArray& sessionId, const QByteArray& sessionKey, const QHash<User*, QStringList> devices, int index) + void sendRoomKeyToDevices(const QByteArray& sessionId, const QByteArray& sessionKey, const QHash<QString, QStringList> devices, int index) { qCDebug(E2EE) << "Sending room key to devices" << sessionId, sessionKey.toHex(); QHash<QString, QHash<QString, QString>> hash; @@ -500,7 +499,7 @@ public: } } if (!u.isEmpty()) { - hash[user->id()] = u; + hash[user] = u; } } if (hash.isEmpty()) { @@ -512,38 +511,44 @@ public: const auto data = job->jsonData(); for(const auto &user : devices.keys()) { for(const auto &device : devices[user]) { - const auto recipientCurveKey = connection->curveKeyForUserDevice(user->id(), device); + const auto recipientCurveKey = connection->curveKeyForUserDevice(user, device); if (!connection->hasOlmSession(user, device)) { qCDebug(E2EE) << "Creating a new session for" << user << device; - if(data["one_time_keys"][user->id()][device].toObject().isEmpty()) { + if(data["one_time_keys"][user][device].toObject().isEmpty()) { qWarning() << "No one time key for" << user << device; continue; } - const auto keyId = data["one_time_keys"][user->id()][device].toObject().keys()[0]; - const auto oneTimeKey = data["one_time_keys"][user->id()][device][keyId]["key"].toString(); - const auto signature = data["one_time_keys"][user->id()][device][keyId]["signatures"][user->id()][QStringLiteral("ed25519:") + device].toString().toLatin1(); - auto signedData = data["one_time_keys"][user->id()][device][keyId].toObject(); + const auto keyId = data["one_time_keys"][user][device].toObject().keys()[0]; + const auto oneTimeKey = data["one_time_keys"][user][device][keyId]["key"].toString(); + const auto signature = data["one_time_keys"][user][device][keyId]["signatures"][user][QStringLiteral("ed25519:") + device].toString().toLatin1(); + auto signedData = data["one_time_keys"][user][device][keyId].toObject(); signedData.remove("unsigned"); signedData.remove("signatures"); - auto signatureMatch = QOlmUtility().ed25519Verify(connection->edKeyForUserDevice(user->id(), device).toLatin1(), QJsonDocument(signedData).toJson(QJsonDocument::Compact), signature); + auto signatureMatch = QOlmUtility().ed25519Verify(connection->edKeyForUserDevice(user, device).toLatin1(), QJsonDocument(signedData).toJson(QJsonDocument::Compact), signature); if (!signatureMatch) { - qCWarning(E2EE) << "Failed to verify one-time-key signature for" << user->id() << device << ". Skipping this device."; + qCWarning(E2EE) << "Failed to verify one-time-key signature for" << user << device << ". Skipping this device."; continue; } else { } connection->createOlmSession(recipientCurveKey, oneTimeKey); } - usersToDevicesToEvents[user->id()][device] = payloadForUserDevice(user, device, sessionId, sessionKey); + usersToDevicesToEvents[user][device] = payloadForUserDevice(user, device, sessionId, sessionKey); } } if (!usersToDevicesToEvents.empty()) { connection->sendToDevices("m.room.encrypted", usersToDevicesToEvents); - connection->database()->setDevicesReceivedKey(q->id(), devices, sessionId, index); + QHash<QString, QList<std::pair<QString, QString>>> receivedDevices; + for (const auto& user : devices.keys()) { + for (const auto& device : devices[user]) { + receivedDevices[user] += {device, q->connection()->curveKeyForUserDevice(user, device) }; + } + } + connection->database()->setDevicesReceivedKey(q->id(), receivedDevices, sessionId, index); } }); } - void sendMegolmSession(const QHash<User *, QStringList>& devices) { + void sendMegolmSession(const QHash<QString, QStringList>& devices) { // Save the session to this device const auto sessionId = currentOutboundMegolmSession->sessionId(); const auto _sessionKey = currentOutboundMegolmSession->sessionKey(); |