aboutsummaryrefslogtreecommitdiff
path: root/lib/e2ee/qolmsession.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'lib/e2ee/qolmsession.cpp')
-rw-r--r--lib/e2ee/qolmsession.cpp251
1 files changed, 251 insertions, 0 deletions
diff --git a/lib/e2ee/qolmsession.cpp b/lib/e2ee/qolmsession.cpp
new file mode 100644
index 00000000..e575ff39
--- /dev/null
+++ b/lib/e2ee/qolmsession.cpp
@@ -0,0 +1,251 @@
+// SPDX-FileCopyrightText: 2021 Alexey Andreyev <aa13q@ya.ru>
+//
+// SPDX-License-Identifier: LGPL-2.1-or-later
+
+#include "qolmsession.h"
+#include "e2ee/qolmutils.h"
+#include "logging.h"
+#include <cstring>
+#include <QDebug>
+
+using namespace Quotient;
+
+QOlmError lastError(OlmSession* session) {
+ return fromString(olm_session_last_error(session));
+}
+
+Quotient::QOlmSession::~QOlmSession()
+{
+ olm_clear_session(m_session);
+ delete[](reinterpret_cast<uint8_t *>(m_session));
+}
+
+OlmSession* QOlmSession::create()
+{
+ return olm_session(new uint8_t[olm_session_size()]);
+}
+
+std::variant<QOlmSessionPtr, QOlmError> QOlmSession::createInbound(QOlmAccount *account, const QOlmMessage &preKeyMessage, bool from, const QString &theirIdentityKey)
+{
+ if (preKeyMessage.type() != QOlmMessage::PreKey) {
+ qCCritical(E2EE) << "The message is not a pre-key in when creating inbound session" << BadMessageFormat;
+ }
+
+ const auto olmSession = create();
+
+ QByteArray oneTimeKeyMessageBuf = preKeyMessage.toCiphertext();
+ QByteArray theirIdentityKeyBuf = theirIdentityKey.toUtf8();
+ size_t error = 0;
+ if (from) {
+ error = olm_create_inbound_session_from(olmSession, account->data(), theirIdentityKeyBuf.data(), theirIdentityKeyBuf.length(), oneTimeKeyMessageBuf.data(), oneTimeKeyMessageBuf.length());
+ } else {
+ error = olm_create_inbound_session(olmSession, account->data(), oneTimeKeyMessageBuf.data(), oneTimeKeyMessageBuf.length());
+ }
+
+ if (error == olm_error()) {
+ const auto lastErr = lastError(olmSession);
+ qCWarning(E2EE) << "Error when creating inbound session" << lastErr;
+ return lastErr;
+ }
+
+ return std::make_unique<QOlmSession>(olmSession);
+}
+
+std::variant<QOlmSessionPtr, QOlmError> QOlmSession::createInboundSession(QOlmAccount *account, const QOlmMessage &preKeyMessage)
+{
+ return createInbound(account, preKeyMessage);
+}
+
+std::variant<QOlmSessionPtr, QOlmError> QOlmSession::createInboundSessionFrom(QOlmAccount *account, const QString &theirIdentityKey, const QOlmMessage &preKeyMessage)
+{
+ return createInbound(account, preKeyMessage, true, theirIdentityKey);
+}
+
+std::variant<QOlmSessionPtr, QOlmError> QOlmSession::createOutboundSession(QOlmAccount *account, const QString &theirIdentityKey, const QString &theirOneTimeKey)
+{
+ auto *olmOutboundSession = create();
+ const auto randomLen = olm_create_outbound_session_random_length(olmOutboundSession);
+ QByteArray randomBuf = getRandom(randomLen);
+
+ QByteArray theirIdentityKeyBuf = theirIdentityKey.toUtf8();
+ QByteArray theirOneTimeKeyBuf = theirOneTimeKey.toUtf8();
+ const auto error = olm_create_outbound_session(olmOutboundSession,
+ account->data(),
+ reinterpret_cast<uint8_t *>(theirIdentityKeyBuf.data()), theirIdentityKeyBuf.length(),
+ reinterpret_cast<uint8_t *>(theirOneTimeKeyBuf.data()), theirOneTimeKeyBuf.length(),
+ reinterpret_cast<uint8_t *>(randomBuf.data()), randomBuf.length());
+
+ if (error == olm_error()) {
+ const auto lastErr = lastError(olmOutboundSession);
+ if (lastErr == QOlmError::NotEnoughRandom) {
+ throw lastErr;
+ }
+ return lastErr;
+ }
+
+ randomBuf.clear();
+ return std::make_unique<QOlmSession>(olmOutboundSession);
+}
+
+std::variant<QByteArray, QOlmError> QOlmSession::pickle(const PicklingMode &mode)
+{
+ QByteArray pickledBuf(olm_pickle_session_length(m_session), '0');
+ QByteArray key = toKey(mode);
+ const auto error = olm_pickle_session(m_session, key.data(), key.length(),
+ pickledBuf.data(), pickledBuf.length());
+
+ if (error == olm_error()) {
+ return lastError(m_session);
+ }
+
+ key.clear();
+
+ return pickledBuf;
+}
+
+std::variant<QOlmSessionPtr, QOlmError> QOlmSession::unpickle(const QByteArray &pickled, const PicklingMode &mode)
+{
+ QByteArray pickledBuf = pickled;
+ auto *olmSession = create();
+ QByteArray key = toKey(mode);
+ const auto error = olm_unpickle_session(olmSession, key.data(), key.length(),
+ pickledBuf.data(), pickledBuf.length());
+ if (error == olm_error()) {
+ return lastError(olmSession);
+ }
+
+ key.clear();
+ return std::make_unique<QOlmSession>(olmSession);
+}
+
+QOlmMessage QOlmSession::encrypt(const QString &plaintext)
+{
+ QByteArray plaintextBuf = plaintext.toUtf8();
+ const auto messageMaxLen = olm_encrypt_message_length(m_session, plaintextBuf.length());
+ QByteArray messageBuf(messageMaxLen, '0');
+ const auto messageType = encryptMessageType();
+ const auto randomLen = olm_encrypt_random_length(m_session);
+ QByteArray randomBuf = getRandom(randomLen);
+ const auto error = olm_encrypt(m_session,
+ reinterpret_cast<uint8_t *>(plaintextBuf.data()), plaintextBuf.length(),
+ reinterpret_cast<uint8_t *>(randomBuf.data()), randomBuf.length(),
+ reinterpret_cast<uint8_t *>(messageBuf.data()), messageBuf.length());
+
+ if (error == olm_error()) {
+ throw lastError(m_session);
+ }
+
+ return QOlmMessage(messageBuf, messageType);
+}
+
+std::variant<QString, QOlmError> QOlmSession::decrypt(const QOlmMessage &message) const
+{
+ const auto messageType = message.type();
+ const auto ciphertext = message.toCiphertext();
+ const auto messageTypeValue = messageType == QOlmMessage::Type::General
+ ? OLM_MESSAGE_TYPE_MESSAGE : OLM_MESSAGE_TYPE_PRE_KEY;
+
+ // We need to clone the message because
+ // olm_decrypt_max_plaintext_length destroys the input buffer
+ QByteArray messageBuf(ciphertext.length(), '0');
+ std::copy(message.begin(), message.end(), messageBuf.begin());
+
+ const auto plaintextMaxLen = olm_decrypt_max_plaintext_length(m_session, messageTypeValue,
+ reinterpret_cast<uint8_t *>(messageBuf.data()), messageBuf.length());
+
+ if (plaintextMaxLen == olm_error()) {
+ return lastError(m_session);
+ }
+
+ QByteArray plaintextBuf(plaintextMaxLen, '0');
+ QByteArray messageBuf2(ciphertext.length(), '0');
+ std::copy(message.begin(), message.end(), messageBuf2.begin());
+
+ const auto plaintextResultLen = olm_decrypt(m_session, messageTypeValue,
+ reinterpret_cast<uint8_t *>(messageBuf2.data()), messageBuf2.length(),
+ reinterpret_cast<uint8_t *>(plaintextBuf.data()), plaintextMaxLen);
+
+ if (plaintextResultLen == olm_error()) {
+ const auto lastErr = lastError(m_session);
+ if (lastErr == QOlmError::OutputBufferTooSmall) {
+ throw lastErr;
+ }
+ return lastErr;
+ }
+ QByteArray output(plaintextResultLen, '0');
+ std::memcpy(output.data(), plaintextBuf.data(), plaintextResultLen);
+ plaintextBuf.clear();
+ return output;
+}
+
+QOlmMessage::Type QOlmSession::encryptMessageType()
+{
+ const auto messageTypeResult = olm_encrypt_message_type(m_session);
+ if (messageTypeResult == olm_error()) {
+ throw lastError(m_session);
+ }
+ if (messageTypeResult == OLM_MESSAGE_TYPE_PRE_KEY) {
+ return QOlmMessage::PreKey;
+ }
+ return QOlmMessage::General;
+}
+
+QByteArray QOlmSession::sessionId() const
+{
+ const auto idMaxLength = olm_session_id_length(m_session);
+ QByteArray idBuffer(idMaxLength, '0');
+ const auto error = olm_session_id(m_session, reinterpret_cast<uint8_t *>(idBuffer.data()),
+ idBuffer.length());
+ if (error == olm_error()) {
+ throw lastError(m_session);
+ }
+ return idBuffer;
+}
+
+bool QOlmSession::hasReceivedMessage() const
+{
+ return olm_session_has_received_message(m_session);
+}
+
+std::variant<bool, QOlmError> QOlmSession::matchesInboundSession(const QOlmMessage &preKeyMessage) const
+{
+ Q_ASSERT(preKeyMessage.type() == QOlmMessage::Type::PreKey);
+ QByteArray oneTimeKeyBuf(preKeyMessage.data());
+ const auto matchesResult = olm_matches_inbound_session(m_session, oneTimeKeyBuf.data(), oneTimeKeyBuf.length());
+
+ if (matchesResult == olm_error()) {
+ return lastError(m_session);
+ }
+ switch (matchesResult) {
+ case 0:
+ return false;
+ case 1:
+ return true;
+ default:
+ return QOlmError::Unknown;
+ }
+}
+std::variant<bool, QOlmError> QOlmSession::matchesInboundSessionFrom(const QString &theirIdentityKey, const QOlmMessage &preKeyMessage) const
+{
+ const auto theirIdentityKeyBuf = theirIdentityKey.toUtf8();
+ auto oneTimeKeyMessageBuf = preKeyMessage.toCiphertext();
+ const auto error = olm_matches_inbound_session_from(m_session, theirIdentityKeyBuf.data(), theirIdentityKeyBuf.length(),
+ oneTimeKeyMessageBuf.data(), oneTimeKeyMessageBuf.length());
+
+ if (error == olm_error()) {
+ return lastError(m_session);
+ }
+ switch (error) {
+ case 0:
+ return false;
+ case 1:
+ return true;
+ default:
+ return QOlmError::Unknown;
+ }
+}
+
+QOlmSession::QOlmSession(OlmSession *session)
+ : m_session(session)
+{
+}