From faaea2eb2a090787f365eb12a42f8452bb0f07e2 Mon Sep 17 00:00:00 2001
From: Carl Schwan <carl@carlschwan.eu>
Date: Sun, 24 Jan 2021 20:12:39 +0100
Subject: Make it work

---
 lib/olm/qolminboundsession.cpp  | 48 ++++++++++++++++++++---------------------
 lib/olm/qolminboundsession.h    | 14 ++++++------
 lib/olm/qolmoutboundsession.cpp | 31 +++++++++++++++-----------
 lib/olm/qolmoutboundsession.h   |  9 ++++----
 lib/olm/utils.cpp               |  4 +++-
 5 files changed, 57 insertions(+), 49 deletions(-)

(limited to 'lib')

diff --git a/lib/olm/qolminboundsession.cpp b/lib/olm/qolminboundsession.cpp
index f0ca73c4..d3b98a63 100644
--- a/lib/olm/qolminboundsession.cpp
+++ b/lib/olm/qolminboundsession.cpp
@@ -4,7 +4,6 @@
 
 #ifdef Quotient_E2EE_ENABLED
 #include "olm/qolminboundsession.h"
-#include <QDebug>
 #include <iostream>
 using namespace Quotient;
 
@@ -15,48 +14,44 @@ OlmError lastError(OlmInboundGroupSession *session) {
     return fromString(error_raw);
 }
 
-QOlmInboundGroupSession::QOlmInboundGroupSession(OlmInboundGroupSession *session, QByteArray buffer)
+QOlmInboundGroupSession::QOlmInboundGroupSession(OlmInboundGroupSession *session)
     : m_groupSession(session)
-    , m_buffer(buffer)
 {
 }
 
 QOlmInboundGroupSession::~QOlmInboundGroupSession()
 {
     olm_clear_inbound_group_session(m_groupSession);
+    //delete[](reinterpret_cast<uint8_t *>(m_groupSession));
 }
 
-std::variant<QOlmInboundGroupSession, OlmError> QOlmInboundGroupSession::create(const QByteArray &key)
+std::unique_ptr<QOlmInboundGroupSession> QOlmInboundGroupSession::create(const QByteArray &key)
 {
-    QByteArray olmInboundGroupSessionBuf(olm_inbound_group_session_size(), '0');
-    const auto olmInboundGroupSession = olm_inbound_group_session(olmInboundGroupSessionBuf.data());
-
+    const auto olmInboundGroupSession = olm_inbound_group_session(new uint8_t[olm_inbound_group_session_size()]);
     const auto temp = key;
-
     const auto error = olm_init_inbound_group_session(olmInboundGroupSession,
             reinterpret_cast<const uint8_t *>(temp.data()), temp.size());
 
     if (error == olm_error()) {
-        return lastError(olmInboundGroupSession);
+        throw lastError(olmInboundGroupSession);
     }
 
-    return QOlmInboundGroupSession(olmInboundGroupSession, std::move(olmInboundGroupSessionBuf));
+    return std::make_unique<QOlmInboundGroupSession>(olmInboundGroupSession);
 }
 
 
-std::variant<QOlmInboundGroupSession, OlmError> QOlmInboundGroupSession::import(const QByteArray &key)
+std::unique_ptr<QOlmInboundGroupSession> QOlmInboundGroupSession::import(const QByteArray &key)
 {
-    QByteArray olmInboundGroupSessionBuf(olm_inbound_group_session_size(), '0');
-    const auto olmInboundGroupSession = olm_inbound_group_session(olmInboundGroupSessionBuf.data());
+    const auto olmInboundGroupSession = olm_inbound_group_session(new uint8_t[olm_inbound_group_session_size()]);
     QByteArray keyBuf = key;
 
     const auto error = olm_import_inbound_group_session(olmInboundGroupSession,
             reinterpret_cast<const uint8_t *>(keyBuf.data()), keyBuf.size());
     if (error == olm_error()) {
-        return lastError(olmInboundGroupSession);
+        throw lastError(olmInboundGroupSession);
     }
 
-    return QOlmInboundGroupSession(olmInboundGroupSession, std::move(olmInboundGroupSessionBuf));
+    return std::make_unique<QOlmInboundGroupSession>(olmInboundGroupSession);
 }
 
 QByteArray toKey(const PicklingMode &mode)
@@ -67,28 +62,31 @@ QByteArray toKey(const PicklingMode &mode)
     return std::get<Encrypted>(mode).key;
 }
 
-std::variant<QByteArray, OlmError> QOlmInboundGroupSession::pickle(const PicklingMode &mode) const
+QByteArray QOlmInboundGroupSession::pickle(const PicklingMode &mode) const
 {
     QByteArray pickledBuf(olm_pickle_inbound_group_session_length(m_groupSession), '0');
     const QByteArray key = toKey(mode);
     const auto error = olm_pickle_inbound_group_session(m_groupSession, key.data(), key.length(), pickledBuf.data(),
             pickledBuf.length());
     if (error == olm_error()) {
-        return lastError(m_groupSession);
+        throw lastError(m_groupSession);
     }
     return pickledBuf;
 }
 
-std::variant<QOlmInboundGroupSession, OlmError> QOlmInboundGroupSession::unpickle(QByteArray &picked, const PicklingMode &mode)
+std::variant<std::unique_ptr<QOlmInboundGroupSession>, OlmError> QOlmInboundGroupSession::unpickle(QByteArray &pickled, const PicklingMode &mode)
 {
-    QByteArray groupSessionBuf(olm_inbound_group_session_size(), '0');
-    auto groupSession = olm_inbound_group_session(groupSessionBuf.data());
-    const QByteArray key = toKey(mode);
-    const auto error = olm_unpickle_inbound_group_session(groupSession, key.data(), key.length(), picked.data(), picked.size());
+    QByteArray pickledBuf = pickled;
+    const auto groupSession = olm_inbound_group_session(new uint8_t[olm_inbound_group_session_size()]);
+    QByteArray key = toKey(mode);
+    const auto error = olm_unpickle_inbound_group_session(groupSession, key.data(), key.length(),
+            pickledBuf.data(), pickledBuf.size());
     if (error == olm_error()) {
         return lastError(groupSession);
     }
-    return QOlmInboundGroupSession(groupSession, std::move(groupSessionBuf));
+    key.clear();
+
+    return std::make_unique<QOlmInboundGroupSession>(groupSession);
 }
 
 std::variant<std::pair<QString, uint32_t>, OlmError> QOlmInboundGroupSession::decrypt(QString &message)
@@ -136,13 +134,13 @@ uint32_t QOlmInboundGroupSession::firstKnownIndex() const
     return olm_inbound_group_session_first_known_index(m_groupSession);
 }
 
-std::variant<QByteArray, OlmError> QOlmInboundGroupSession::sessionId() const
+QByteArray QOlmInboundGroupSession::sessionId() const
 {
     QByteArray sessionIdBuf(olm_inbound_group_session_id_length(m_groupSession), '0');
     const auto error = olm_inbound_group_session_id(m_groupSession, reinterpret_cast<uint8_t *>(sessionIdBuf.data()),
             sessionIdBuf.length());
     if (error == olm_error()) {
-        return lastError(m_groupSession);
+        throw lastError(m_groupSession);
     }
     return sessionIdBuf;
 }
diff --git a/lib/olm/qolminboundsession.h b/lib/olm/qolminboundsession.h
index 85807821..ccc53ba8 100644
--- a/lib/olm/qolminboundsession.h
+++ b/lib/olm/qolminboundsession.h
@@ -8,6 +8,7 @@
 
 #include <QByteArray>
 #include <variant>
+#include <memory>
 #include "olm/olm.h"
 #include "olm/errors.h"
 #include "olm/e2ee.h"
@@ -21,14 +22,14 @@ struct QOlmInboundGroupSession
 public:
     ~QOlmInboundGroupSession();
     //! Creates a new instance of `OlmInboundGroupSession`.
-    static std::variant<QOlmInboundGroupSession, OlmError> create(const QByteArray &key);
+    static std::unique_ptr<QOlmInboundGroupSession> create(const QByteArray &key);
     //! Import an inbound group session, from a previous export.
-    static std::variant<QOlmInboundGroupSession, OlmError> import(const QByteArray &key);
+    static std::unique_ptr<QOlmInboundGroupSession> import(const QByteArray &key);
     //! Serialises an `OlmInboundGroupSession` to encrypted Base64.
-    std::variant<QByteArray, OlmError> pickle(const PicklingMode &mode) const;
+    QByteArray pickle(const PicklingMode &mode) const;
     //! Deserialises from encrypted Base64 that was previously obtained by pickling
     //! an `OlmInboundGroupSession`.
-    static std::variant<QOlmInboundGroupSession, OlmError> unpickle(QByteArray &picked, const PicklingMode &mode);
+    static std::variant<std::unique_ptr<QOlmInboundGroupSession>, OlmError> unpickle(QByteArray &picked, const PicklingMode &mode);
     //! Decrypts ciphertext received for this group session.
     std::variant<std::pair<QString, uint32_t>, OlmError> decrypt(QString &message);
     //! Export the base64-encoded ratchet key for this session, at the given index,
@@ -37,12 +38,11 @@ public:
     //! Get the first message index we know how to decrypt.
     uint32_t firstKnownIndex() const;
     //! Get a base64-encoded identifier for this session.
-    std::variant<QByteArray, OlmError> sessionId() const;
+    QByteArray sessionId() const;
     bool isVerified() const;
+    QOlmInboundGroupSession(OlmInboundGroupSession *session);
 private:
-    QOlmInboundGroupSession(OlmInboundGroupSession *session, QByteArray buffer);
     OlmInboundGroupSession *m_groupSession;
-    QByteArray m_buffer;
 };
 } // namespace Quotient
 #endif
diff --git a/lib/olm/qolmoutboundsession.cpp b/lib/olm/qolmoutboundsession.cpp
index 60126469..ba8be4f6 100644
--- a/lib/olm/qolmoutboundsession.cpp
+++ b/lib/olm/qolmoutboundsession.cpp
@@ -14,34 +14,38 @@ OlmError lastError(OlmOutboundGroupSession *session) {
     return fromString(error_raw);
 }
 
-QOlmOutboundGroupSession::QOlmOutboundGroupSession(OlmOutboundGroupSession *session, const QByteArray &buffer)
+QOlmOutboundGroupSession::QOlmOutboundGroupSession(OlmOutboundGroupSession *session)
     : m_groupSession(session)
-    , m_buffer(buffer)
 {
 }
 
 QOlmOutboundGroupSession::~QOlmOutboundGroupSession()
 {
     olm_clear_outbound_group_session(m_groupSession);
+    //delete[](reinterpret_cast<uint8_t *>(m_groupSession));
 }
 
-std::variant<QOlmOutboundGroupSession, OlmError> QOlmOutboundGroupSession::create()
+std::unique_ptr<QOlmOutboundGroupSession> QOlmOutboundGroupSession::create()
 {
-    QByteArray sessionBuffer(olm_outbound_group_session_size(), '0');
-    auto *olmOutboundGroupSession = olm_outbound_group_session(sessionBuffer.data());
+    auto *olmOutboundGroupSession = olm_outbound_group_session(new uint8_t[olm_outbound_group_session_size()]);
     const auto randomLen = olm_init_outbound_group_session_random_length(olmOutboundGroupSession);
     QByteArray randomBuf = getRandom(randomLen);
 
-    const auto error = olm_init_outbound_group_session(olmOutboundGroupSession, 
+    const auto error = olm_init_outbound_group_session(olmOutboundGroupSession,
             reinterpret_cast<uint8_t *>(randomBuf.data()), randomBuf.length());
 
     if (error == olm_error()) {
-        return lastError(olmOutboundGroupSession);
+        throw lastError(olmOutboundGroupSession);
     }
 
+    const auto keyMaxLength = olm_outbound_group_session_key_length(olmOutboundGroupSession);
+    QByteArray keyBuffer(keyMaxLength, '0');
+    olm_outbound_group_session_key(olmOutboundGroupSession, reinterpret_cast<uint8_t *>(keyBuffer.data()),
+            keyMaxLength);
+
     randomBuf.clear();
 
-    return QOlmOutboundGroupSession(olmOutboundGroupSession, sessionBuffer);
+    return std::make_unique<QOlmOutboundGroupSession>(olmOutboundGroupSession);
 }
 
 std::variant<QByteArray, OlmError> QOlmOutboundGroupSession::pickle(const PicklingMode &mode)
@@ -61,20 +65,23 @@ std::variant<QByteArray, OlmError> QOlmOutboundGroupSession::pickle(const Pickli
 }
 
 
-std::variant<QOlmOutboundGroupSession, OlmError> QOlmOutboundGroupSession::unpickle(QByteArray &pickled, const PicklingMode &mode)
+std::variant<std::unique_ptr<QOlmOutboundGroupSession>, OlmError> QOlmOutboundGroupSession::unpickle(QByteArray &pickled, const PicklingMode &mode)
 {
     QByteArray pickledBuf = pickled;
-    QByteArray olmOutboundGroupSessionBuf(olm_outbound_group_session_size(), '0');
+    auto *olmOutboundGroupSession = olm_outbound_group_session(new uint8_t[olm_outbound_group_session_size()]);
     QByteArray key = toKey(mode);
-    auto olmOutboundGroupSession = olm_outbound_group_session(reinterpret_cast<uint8_t *>(olmOutboundGroupSessionBuf.data()));
     const auto error = olm_unpickle_outbound_group_session(olmOutboundGroupSession, key.data(), key.length(),
             pickled.data(), pickled.length());
     if (error == olm_error()) {
         return lastError(olmOutboundGroupSession);
     }
+    const auto idMaxLength = olm_outbound_group_session_id_length(olmOutboundGroupSession);
+    QByteArray idBuffer(idMaxLength, '0');
+    olm_outbound_group_session_id(olmOutboundGroupSession, reinterpret_cast<uint8_t *>(idBuffer.data()),
+            idBuffer.length());
 
     key.clear();
-    return QOlmOutboundGroupSession(olmOutboundGroupSession, olmOutboundGroupSessionBuf);
+    return std::make_unique<QOlmOutboundGroupSession>(olmOutboundGroupSession);
 }
 
 std::variant<QString, OlmError> QOlmOutboundGroupSession::encrypt(QString &plaintext)
diff --git a/lib/olm/qolmoutboundsession.h b/lib/olm/qolmoutboundsession.h
index 2e1439d3..29776a3d 100644
--- a/lib/olm/qolmoutboundsession.h
+++ b/lib/olm/qolmoutboundsession.h
@@ -7,6 +7,7 @@
 #include "olm/olm.h" // from Olm
 #include "olm/errors.h"
 #include "olm/e2ee.h"
+#include <memory>
 
 namespace Quotient {
 
@@ -17,12 +18,13 @@ class QOlmOutboundGroupSession
 public:
     ~QOlmOutboundGroupSession();
     //! Creates a new instance of `QOlmOutboundGroupSession`.
-    static std::variant<QOlmOutboundGroupSession, OlmError> create();
+    //! Throw OlmError on errors
+    static std::unique_ptr<QOlmOutboundGroupSession> create();
     //! Serialises an `QOlmOutboundGroupSession` to encrypted Base64.
     std::variant<QByteArray, OlmError> pickle(const PicklingMode &mode);
     //! Deserialises from encrypted Base64 that was previously obtained by
     //! pickling a `QOlmOutboundGroupSession`.
-    static std::variant<QOlmOutboundGroupSession, OlmError> unpickle(QByteArray &pickled, const PicklingMode &mode);
+    static std::variant<std::unique_ptr<QOlmOutboundGroupSession>, OlmError> unpickle(QByteArray &pickled, const PicklingMode &mode);
     //! Encrypts a plaintext message using the session.
     std::variant<QString, OlmError> encrypt(QString &plaintext);
 
@@ -40,10 +42,9 @@ public:
     //! Each message is sent with a different ratchet key. This function returns the
     //! ratchet key that will be used for the next message.
     std::variant<QByteArray, OlmError> sessionKey() const;
+    QOlmOutboundGroupSession(OlmOutboundGroupSession *groupSession);
 private:
-    QOlmOutboundGroupSession(OlmOutboundGroupSession *groupSession, const QByteArray &groupSessionBuf);
     OlmOutboundGroupSession *m_groupSession;
-    QByteArray m_buffer;
 };
 }
 #endif
diff --git a/lib/olm/utils.cpp b/lib/olm/utils.cpp
index 15def1d7..227e6d84 100644
--- a/lib/olm/utils.cpp
+++ b/lib/olm/utils.cpp
@@ -4,6 +4,8 @@
 
 #ifdef Quotient_E2EE_ENABLED
 #include "olm/utils.h"
+#include <QDebug>
+#include <openssl/rand.h>
 
 using namespace Quotient;
 
@@ -18,7 +20,7 @@ QByteArray Quotient::toKey(const Quotient::PicklingMode &mode)
 QByteArray Quotient::getRandom(size_t bufferSize)
 {
     QByteArray buffer(bufferSize, '0');
-    std::generate(buffer.begin(), buffer.end(), std::rand);
+    RAND_bytes(reinterpret_cast<uint8_t *>(buffer.data()), buffer.size());
     return buffer;
 }
 #endif
-- 
cgit v1.2.3