diff options
-rw-r--r-- | autotests/testolmsession.cpp | 4 | ||||
-rw-r--r-- | lib/olm/session.cpp | 41 | ||||
-rw-r--r-- | lib/olm/session.h | 6 |
3 files changed, 48 insertions, 3 deletions
diff --git a/autotests/testolmsession.cpp b/autotests/testolmsession.cpp index 2f7a82e9..fc151621 100644 --- a/autotests/testolmsession.cpp +++ b/autotests/testolmsession.cpp @@ -52,9 +52,9 @@ void TestOlmSession::olmEncryptDecrypt() QVERIFY(std::get<bool>(inboundSession->matchesInboundSession(m))); } - //const auto decrypted = inboundSession->decrypt(encrypted); + const auto decrypted = std::get<QString>(inboundSession->decrypt(encrypted)); - //QCOMPARE(decrypted, "Hello world!"); + QCOMPARE(decrypted, "Hello world!"); #endif } diff --git a/lib/olm/session.cpp b/lib/olm/session.cpp index d0493fe8..a05e0786 100644 --- a/lib/olm/session.cpp +++ b/lib/olm/session.cpp @@ -6,6 +6,7 @@ #include "olm/session.h" #include "olm/utils.h" #include "logging.h" +#include <cstring> using namespace Quotient; @@ -142,6 +143,46 @@ Message QOlmSession::encrypt(const QString &plaintext) return Message(messageBuf, messageType); } +std::variant<QString, OlmError> QOlmSession::decrypt(const Message &message) const +{ + const auto messageType = message.type(); + const auto ciphertext = message.toCiphertext(); + const auto messageTypeValue = messageType == Message::Type::General + ? OLM_MESSAGE_TYPE_MESSAGE : OLM_MESSAGE_TYPE_PRE_KEY; + + // We need to clone the message because + // olm_decrypt_max_plaintext_length destroys the input buffer + QByteArray messageBuf(ciphertext.length(), '0'); + std::copy(message.begin(), message.end(), messageBuf.begin()); + + const auto plaintextMaxLen = olm_decrypt_max_plaintext_length(m_session, messageTypeValue, + reinterpret_cast<uint8_t *>(messageBuf.data()), messageBuf.length()); + + if (plaintextMaxLen == olm_error()) { + return lastError(m_session); + } + + QByteArray plaintextBuf(plaintextMaxLen, '0'); + QByteArray messageBuf2(ciphertext.length(), '0'); + std::copy(message.begin(), message.end(), messageBuf2.begin()); + + const auto plaintextResultLen = olm_decrypt(m_session, messageTypeValue, + reinterpret_cast<uint8_t *>(messageBuf2.data()), messageBuf2.length(), + reinterpret_cast<uint8_t *>(plaintextBuf.data()), plaintextMaxLen); + + if (plaintextResultLen == olm_error()) { + const auto lastErr = lastError(m_session); + if (lastErr == OlmError::OutputBufferTooSmall) { + throw lastErr; + } + return lastErr; + } + QByteArray output(plaintextResultLen, '0'); + std::memcpy(output.data(), plaintextBuf.data(), plaintextResultLen); + plaintextBuf.clear(); + return output; +} + Message::Type QOlmSession::encryptMessageType() { const auto messageTypeResult = olm_encrypt_message_type(m_session); diff --git a/lib/olm/session.h b/lib/olm/session.h index c45b6898..3f1622c7 100644 --- a/lib/olm/session.h +++ b/lib/olm/session.h @@ -32,7 +32,11 @@ public: static std::variant<std::unique_ptr<QOlmSession>, OlmError> unpickle(QByteArray &pickled, const PicklingMode &mode); //! Encrypts a plaintext message using the session. Message encrypt(const QString &plaintext); - // TODO: WiP + + //! Decrypts a message using this session. Decoding is lossy, meaing if + //! the decrypted plaintext contains invalid UTF-8 symbols, they will + //! be returned as `U+FFFD` (�). + std::variant<QString, OlmError> decrypt(const Message &message) const; //! Get a base64-encoded identifier for this session. QByteArray sessionId() const; |