aboutsummaryrefslogtreecommitdiff
path: root/lib/e2ee/qolmsession.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'lib/e2ee/qolmsession.cpp')
-rw-r--r--lib/e2ee/qolmsession.cpp231
1 files changed, 231 insertions, 0 deletions
diff --git a/lib/e2ee/qolmsession.cpp b/lib/e2ee/qolmsession.cpp
new file mode 100644
index 00000000..e3f69132
--- /dev/null
+++ b/lib/e2ee/qolmsession.cpp
@@ -0,0 +1,231 @@
+// SPDX-FileCopyrightText: 2021 Alexey Andreyev <aa13q@ya.ru>
+//
+// SPDX-License-Identifier: LGPL-2.1-or-later
+
+#include "qolmsession.h"
+
+#include "e2ee/qolmutils.h"
+#include "logging.h"
+
+#include <cstring>
+#include <olm/olm.h>
+
+using namespace Quotient;
+
+OlmErrorCode QOlmSession::lastErrorCode() const {
+ return olm_session_last_error_code(m_session);
+}
+
+const char* QOlmSession::lastError() const
+{
+ return olm_session_last_error(m_session);
+}
+
+Quotient::QOlmSession::~QOlmSession()
+{
+ olm_clear_session(m_session);
+ delete[](reinterpret_cast<uint8_t *>(m_session));
+}
+
+OlmSession* QOlmSession::create()
+{
+ return olm_session(new uint8_t[olm_session_size()]);
+}
+
+QOlmExpected<QOlmSessionPtr> QOlmSession::createInbound(
+ QOlmAccount* account, const QOlmMessage& preKeyMessage, bool from,
+ const QString& theirIdentityKey)
+{
+ if (preKeyMessage.type() != QOlmMessage::PreKey) {
+ qCCritical(E2EE) << "The message is not a pre-key; will try to create "
+ "the inbound session anyway";
+ }
+
+ const auto olmSession = create();
+
+ QByteArray oneTimeKeyMessageBuf = preKeyMessage.toCiphertext();
+ QByteArray theirIdentityKeyBuf = theirIdentityKey.toUtf8();
+ size_t error = 0;
+ if (from) {
+ error = olm_create_inbound_session_from(olmSession, account->data(), theirIdentityKeyBuf.data(), theirIdentityKeyBuf.length(), oneTimeKeyMessageBuf.data(), oneTimeKeyMessageBuf.length());
+ } else {
+ error = olm_create_inbound_session(olmSession, account->data(), oneTimeKeyMessageBuf.data(), oneTimeKeyMessageBuf.length());
+ }
+
+ if (error == olm_error()) {
+ // FIXME: the QOlmSession object should be created earlier
+ const auto lastErr = olm_session_last_error_code(olmSession);
+ qCWarning(E2EE) << "Error when creating inbound session" << lastErr;
+ return lastErr;
+ }
+
+ return std::make_unique<QOlmSession>(olmSession);
+}
+
+QOlmExpected<QOlmSessionPtr> QOlmSession::createInboundSession(
+ QOlmAccount* account, const QOlmMessage& preKeyMessage)
+{
+ return createInbound(account, preKeyMessage);
+}
+
+QOlmExpected<QOlmSessionPtr> QOlmSession::createInboundSessionFrom(
+ QOlmAccount* account, const QString& theirIdentityKey,
+ const QOlmMessage& preKeyMessage)
+{
+ return createInbound(account, preKeyMessage, true, theirIdentityKey);
+}
+
+QOlmExpected<QOlmSessionPtr> QOlmSession::createOutboundSession(
+ QOlmAccount* account, const QByteArray& theirIdentityKey,
+ const QByteArray& theirOneTimeKey)
+{
+ auto* olmOutboundSession = create();
+ if (const auto randomLength =
+ olm_create_outbound_session_random_length(olmOutboundSession);
+ olm_create_outbound_session(
+ olmOutboundSession, account->data(), theirIdentityKey.data(),
+ theirIdentityKey.length(), theirOneTimeKey.data(),
+ theirOneTimeKey.length(), RandomBuffer(randomLength), randomLength)
+ == olm_error()) {
+ // FIXME: the QOlmSession object should be created earlier
+ const auto lastErr = olm_session_last_error_code(olmOutboundSession);
+ QOLM_FAIL_OR_LOG_X(lastErr == OLM_NOT_ENOUGH_RANDOM,
+ "Failed to create an outbound Olm session",
+ olm_session_last_error(olmOutboundSession));
+ return lastErr;
+ }
+
+ return std::make_unique<QOlmSession>(olmOutboundSession);
+}
+
+QByteArray QOlmSession::pickle(const PicklingMode &mode) const
+{
+ QByteArray pickledBuf(olm_pickle_session_length(m_session), '\0');
+ QByteArray key = toKey(mode);
+ if (olm_pickle_session(m_session, key.data(), key.length(),
+ pickledBuf.data(), pickledBuf.length())
+ == olm_error())
+ QOLM_INTERNAL_ERROR("Failed to pickle an Olm session");
+
+ key.clear();
+ return pickledBuf;
+}
+
+QOlmExpected<QOlmSessionPtr> QOlmSession::unpickle(QByteArray&& pickled,
+ const PicklingMode& mode)
+{
+ auto *olmSession = create();
+ auto key = toKey(mode);
+ if (olm_unpickle_session(olmSession, key.data(), key.length(),
+ pickled.data(), pickled.length())
+ == olm_error()) {
+ // FIXME: the QOlmSession object should be created earlier
+ const auto errorCode = olm_session_last_error_code(olmSession);
+ QOLM_FAIL_OR_LOG_X(errorCode == OLM_OUTPUT_BUFFER_TOO_SMALL,
+ "Failed to unpickle an Olm session",
+ olm_session_last_error(olmSession));
+ return errorCode;
+ }
+
+ key.clear();
+ return std::make_unique<QOlmSession>(olmSession);
+}
+
+QOlmMessage QOlmSession::encrypt(const QByteArray& plaintext)
+{
+ const auto messageMaxLength =
+ olm_encrypt_message_length(m_session, plaintext.length());
+ QByteArray messageBuf(messageMaxLength, '\0');
+ // NB: The type has to be calculated before calling olm_encrypt()
+ const auto messageType = olm_encrypt_message_type(m_session);
+ if (const auto randomLength = olm_encrypt_random_length(m_session);
+ olm_encrypt(m_session, plaintext.data(), plaintext.length(),
+ RandomBuffer(randomLength), randomLength, messageBuf.data(),
+ messageBuf.length())
+ == olm_error()) {
+ QOLM_INTERNAL_ERROR("Failed to encrypt the message");
+ }
+
+ return QOlmMessage(messageBuf, QOlmMessage::Type(messageType));
+}
+
+QOlmExpected<QByteArray> QOlmSession::decrypt(const QOlmMessage &message) const
+{
+ const auto ciphertext = message.toCiphertext();
+ const auto messageTypeValue = message.type();
+
+ // We need to clone the message because
+ // olm_decrypt_max_plaintext_length destroys the input buffer
+ QByteArray messageBuf(ciphertext.length(), '\0');
+ std::copy(message.begin(), message.end(), messageBuf.begin());
+
+ const auto plaintextMaxLen = olm_decrypt_max_plaintext_length(
+ m_session, messageTypeValue, messageBuf.data(), messageBuf.length());
+ if (plaintextMaxLen == olm_error()) {
+ qWarning(E2EE) << "Couldn't calculate decrypted message length:"
+ << lastError();
+ return lastErrorCode();
+ }
+
+ QByteArray plaintextBuf(plaintextMaxLen, '\0');
+ QByteArray messageBuf2(ciphertext.length(), '\0');
+ std::copy(message.begin(), message.end(), messageBuf2.begin());
+
+ const auto plaintextResultLen =
+ olm_decrypt(m_session, messageTypeValue, messageBuf2.data(),
+ messageBuf2.length(), plaintextBuf.data(), plaintextMaxLen);
+ if (plaintextResultLen == olm_error()) {
+ QOLM_FAIL_OR_LOG(OLM_OUTPUT_BUFFER_TOO_SMALL,
+ "Failed to decrypt the message");
+ return lastErrorCode();
+ }
+ plaintextBuf.truncate(plaintextResultLen);
+ return plaintextBuf;
+}
+
+QByteArray QOlmSession::sessionId() const
+{
+ const auto idMaxLength = olm_session_id_length(m_session);
+ QByteArray idBuffer(idMaxLength, '\0');
+ if (olm_session_id(m_session, idBuffer.data(), idMaxLength) == olm_error())
+ QOLM_INTERNAL_ERROR("Failed to obtain Olm session id");
+
+ return idBuffer;
+}
+
+bool QOlmSession::hasReceivedMessage() const
+{
+ return olm_session_has_received_message(m_session);
+}
+
+bool QOlmSession::matchesInboundSession(const QOlmMessage& preKeyMessage) const
+{
+ Q_ASSERT(preKeyMessage.type() == QOlmMessage::Type::PreKey);
+ QByteArray oneTimeKeyBuf(preKeyMessage.data());
+ const auto maybeMatches =
+ olm_matches_inbound_session(m_session, oneTimeKeyBuf.data(),
+ oneTimeKeyBuf.length());
+ if (maybeMatches == olm_error())
+ qWarning(E2EE) << "Error matching an inbound session:" << lastError();
+
+ return maybeMatches == 1; // Any errors are treated as non-match
+}
+
+bool QOlmSession::matchesInboundSessionFrom(
+ const QString& theirIdentityKey, const QOlmMessage& preKeyMessage) const
+{
+ const auto theirIdentityKeyBuf = theirIdentityKey.toUtf8();
+ auto oneTimeKeyMessageBuf = preKeyMessage.toCiphertext();
+ const auto maybeMatches = olm_matches_inbound_session_from(
+ m_session, theirIdentityKeyBuf.data(), theirIdentityKeyBuf.length(),
+ oneTimeKeyMessageBuf.data(), oneTimeKeyMessageBuf.length());
+
+ if (maybeMatches == olm_error())
+ qCWarning(E2EE) << "Error matching an inbound session:" << lastError();
+
+ return maybeMatches == 1;
+}
+
+QOlmSession::QOlmSession(OlmSession *session)
+ : m_session(session)
+{}