diff options
Diffstat (limited to 'lib/e2ee/qolmsession.cpp')
-rw-r--r-- | lib/e2ee/qolmsession.cpp | 251 |
1 files changed, 251 insertions, 0 deletions
diff --git a/lib/e2ee/qolmsession.cpp b/lib/e2ee/qolmsession.cpp new file mode 100644 index 00000000..e575ff39 --- /dev/null +++ b/lib/e2ee/qolmsession.cpp @@ -0,0 +1,251 @@ +// SPDX-FileCopyrightText: 2021 Alexey Andreyev <aa13q@ya.ru> +// +// SPDX-License-Identifier: LGPL-2.1-or-later + +#include "qolmsession.h" +#include "e2ee/qolmutils.h" +#include "logging.h" +#include <cstring> +#include <QDebug> + +using namespace Quotient; + +QOlmError lastError(OlmSession* session) { + return fromString(olm_session_last_error(session)); +} + +Quotient::QOlmSession::~QOlmSession() +{ + olm_clear_session(m_session); + delete[](reinterpret_cast<uint8_t *>(m_session)); +} + +OlmSession* QOlmSession::create() +{ + return olm_session(new uint8_t[olm_session_size()]); +} + +std::variant<QOlmSessionPtr, QOlmError> QOlmSession::createInbound(QOlmAccount *account, const QOlmMessage &preKeyMessage, bool from, const QString &theirIdentityKey) +{ + if (preKeyMessage.type() != QOlmMessage::PreKey) { + qCCritical(E2EE) << "The message is not a pre-key in when creating inbound session" << BadMessageFormat; + } + + const auto olmSession = create(); + + QByteArray oneTimeKeyMessageBuf = preKeyMessage.toCiphertext(); + QByteArray theirIdentityKeyBuf = theirIdentityKey.toUtf8(); + size_t error = 0; + if (from) { + error = olm_create_inbound_session_from(olmSession, account->data(), theirIdentityKeyBuf.data(), theirIdentityKeyBuf.length(), oneTimeKeyMessageBuf.data(), oneTimeKeyMessageBuf.length()); + } else { + error = olm_create_inbound_session(olmSession, account->data(), oneTimeKeyMessageBuf.data(), oneTimeKeyMessageBuf.length()); + } + + if (error == olm_error()) { + const auto lastErr = lastError(olmSession); + qCWarning(E2EE) << "Error when creating inbound session" << lastErr; + return lastErr; + } + + return std::make_unique<QOlmSession>(olmSession); +} + +std::variant<QOlmSessionPtr, QOlmError> QOlmSession::createInboundSession(QOlmAccount *account, const QOlmMessage &preKeyMessage) +{ + return createInbound(account, preKeyMessage); +} + +std::variant<QOlmSessionPtr, QOlmError> QOlmSession::createInboundSessionFrom(QOlmAccount *account, const QString &theirIdentityKey, const QOlmMessage &preKeyMessage) +{ + return createInbound(account, preKeyMessage, true, theirIdentityKey); +} + +std::variant<QOlmSessionPtr, QOlmError> QOlmSession::createOutboundSession(QOlmAccount *account, const QString &theirIdentityKey, const QString &theirOneTimeKey) +{ + auto *olmOutboundSession = create(); + const auto randomLen = olm_create_outbound_session_random_length(olmOutboundSession); + QByteArray randomBuf = getRandom(randomLen); + + QByteArray theirIdentityKeyBuf = theirIdentityKey.toUtf8(); + QByteArray theirOneTimeKeyBuf = theirOneTimeKey.toUtf8(); + const auto error = olm_create_outbound_session(olmOutboundSession, + account->data(), + reinterpret_cast<uint8_t *>(theirIdentityKeyBuf.data()), theirIdentityKeyBuf.length(), + reinterpret_cast<uint8_t *>(theirOneTimeKeyBuf.data()), theirOneTimeKeyBuf.length(), + reinterpret_cast<uint8_t *>(randomBuf.data()), randomBuf.length()); + + if (error == olm_error()) { + const auto lastErr = lastError(olmOutboundSession); + if (lastErr == QOlmError::NotEnoughRandom) { + throw lastErr; + } + return lastErr; + } + + randomBuf.clear(); + return std::make_unique<QOlmSession>(olmOutboundSession); +} + +std::variant<QByteArray, QOlmError> QOlmSession::pickle(const PicklingMode &mode) +{ + QByteArray pickledBuf(olm_pickle_session_length(m_session), '0'); + QByteArray key = toKey(mode); + const auto error = olm_pickle_session(m_session, key.data(), key.length(), + pickledBuf.data(), pickledBuf.length()); + + if (error == olm_error()) { + return lastError(m_session); + } + + key.clear(); + + return pickledBuf; +} + +std::variant<QOlmSessionPtr, QOlmError> QOlmSession::unpickle(const QByteArray &pickled, const PicklingMode &mode) +{ + QByteArray pickledBuf = pickled; + auto *olmSession = create(); + QByteArray key = toKey(mode); + const auto error = olm_unpickle_session(olmSession, key.data(), key.length(), + pickledBuf.data(), pickledBuf.length()); + if (error == olm_error()) { + return lastError(olmSession); + } + + key.clear(); + return std::make_unique<QOlmSession>(olmSession); +} + +QOlmMessage QOlmSession::encrypt(const QString &plaintext) +{ + QByteArray plaintextBuf = plaintext.toUtf8(); + const auto messageMaxLen = olm_encrypt_message_length(m_session, plaintextBuf.length()); + QByteArray messageBuf(messageMaxLen, '0'); + const auto messageType = encryptMessageType(); + const auto randomLen = olm_encrypt_random_length(m_session); + QByteArray randomBuf = getRandom(randomLen); + const auto error = olm_encrypt(m_session, + reinterpret_cast<uint8_t *>(plaintextBuf.data()), plaintextBuf.length(), + reinterpret_cast<uint8_t *>(randomBuf.data()), randomBuf.length(), + reinterpret_cast<uint8_t *>(messageBuf.data()), messageBuf.length()); + + if (error == olm_error()) { + throw lastError(m_session); + } + + return QOlmMessage(messageBuf, messageType); +} + +std::variant<QString, QOlmError> QOlmSession::decrypt(const QOlmMessage &message) const +{ + const auto messageType = message.type(); + const auto ciphertext = message.toCiphertext(); + const auto messageTypeValue = messageType == QOlmMessage::Type::General + ? OLM_MESSAGE_TYPE_MESSAGE : OLM_MESSAGE_TYPE_PRE_KEY; + + // We need to clone the message because + // olm_decrypt_max_plaintext_length destroys the input buffer + QByteArray messageBuf(ciphertext.length(), '0'); + std::copy(message.begin(), message.end(), messageBuf.begin()); + + const auto plaintextMaxLen = olm_decrypt_max_plaintext_length(m_session, messageTypeValue, + reinterpret_cast<uint8_t *>(messageBuf.data()), messageBuf.length()); + + if (plaintextMaxLen == olm_error()) { + return lastError(m_session); + } + + QByteArray plaintextBuf(plaintextMaxLen, '0'); + QByteArray messageBuf2(ciphertext.length(), '0'); + std::copy(message.begin(), message.end(), messageBuf2.begin()); + + const auto plaintextResultLen = olm_decrypt(m_session, messageTypeValue, + reinterpret_cast<uint8_t *>(messageBuf2.data()), messageBuf2.length(), + reinterpret_cast<uint8_t *>(plaintextBuf.data()), plaintextMaxLen); + + if (plaintextResultLen == olm_error()) { + const auto lastErr = lastError(m_session); + if (lastErr == QOlmError::OutputBufferTooSmall) { + throw lastErr; + } + return lastErr; + } + QByteArray output(plaintextResultLen, '0'); + std::memcpy(output.data(), plaintextBuf.data(), plaintextResultLen); + plaintextBuf.clear(); + return output; +} + +QOlmMessage::Type QOlmSession::encryptMessageType() +{ + const auto messageTypeResult = olm_encrypt_message_type(m_session); + if (messageTypeResult == olm_error()) { + throw lastError(m_session); + } + if (messageTypeResult == OLM_MESSAGE_TYPE_PRE_KEY) { + return QOlmMessage::PreKey; + } + return QOlmMessage::General; +} + +QByteArray QOlmSession::sessionId() const +{ + const auto idMaxLength = olm_session_id_length(m_session); + QByteArray idBuffer(idMaxLength, '0'); + const auto error = olm_session_id(m_session, reinterpret_cast<uint8_t *>(idBuffer.data()), + idBuffer.length()); + if (error == olm_error()) { + throw lastError(m_session); + } + return idBuffer; +} + +bool QOlmSession::hasReceivedMessage() const +{ + return olm_session_has_received_message(m_session); +} + +std::variant<bool, QOlmError> QOlmSession::matchesInboundSession(const QOlmMessage &preKeyMessage) const +{ + Q_ASSERT(preKeyMessage.type() == QOlmMessage::Type::PreKey); + QByteArray oneTimeKeyBuf(preKeyMessage.data()); + const auto matchesResult = olm_matches_inbound_session(m_session, oneTimeKeyBuf.data(), oneTimeKeyBuf.length()); + + if (matchesResult == olm_error()) { + return lastError(m_session); + } + switch (matchesResult) { + case 0: + return false; + case 1: + return true; + default: + return QOlmError::Unknown; + } +} +std::variant<bool, QOlmError> QOlmSession::matchesInboundSessionFrom(const QString &theirIdentityKey, const QOlmMessage &preKeyMessage) const +{ + const auto theirIdentityKeyBuf = theirIdentityKey.toUtf8(); + auto oneTimeKeyMessageBuf = preKeyMessage.toCiphertext(); + const auto error = olm_matches_inbound_session_from(m_session, theirIdentityKeyBuf.data(), theirIdentityKeyBuf.length(), + oneTimeKeyMessageBuf.data(), oneTimeKeyMessageBuf.length()); + + if (error == olm_error()) { + return lastError(m_session); + } + switch (error) { + case 0: + return false; + case 1: + return true; + default: + return QOlmError::Unknown; + } +} + +QOlmSession::QOlmSession(OlmSession *session) + : m_session(session) +{ +} |