aboutsummaryrefslogtreecommitdiff
path: root/lib/crypto/qolmoutboundsession.cpp
blob: bc572ba54849b2211d60e0c02914966c9660726f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
// SPDX-FileCopyrightText: 2021 Carl Schwan <carlschwan@kde.org>
//
// SPDX-License-Identifier: LGPL-2.1-or-later

#include "qolmoutboundsession.h"
#include "crypto/qolmutils.h"

using namespace Quotient;

// Fetch the most recent error recorded by libolm on the given outbound group
// session and convert it to a QOlmError.
//
// The raw error text comes from olm_outbound_group_session_last_error() and
// is passed to fromString() for the conversion.
QOlmError lastError(OlmOutboundGroupSession *session) {
    const std::string olmErrorText(olm_outbound_group_session_last_error(session));
    return fromString(olmErrorText);
}

QOlmOutboundGroupSession::QOlmOutboundGroupSession(OlmOutboundGroupSession *session)
    : m_groupSession(session)
{
}

// Tear down the wrapped session.
//
// libolm is asked to clear the session first; afterwards the underlying
// storage is released. The storage is allocated as a uint8_t array in
// create() and unpickle(), hence the cast back before delete[].
QOlmOutboundGroupSession::~QOlmOutboundGroupSession()
{
    olm_clear_outbound_group_session(m_groupSession);

    auto *sessionStorage = reinterpret_cast<uint8_t *>(m_groupSession);
    delete[] sessionStorage;
}

// Create a brand new outbound group session seeded with fresh random data.
//
// Returns the owning wrapper around the initialised libolm session.
// Throws a QOlmError (obtained via lastError()) if libolm fails to initialise
// the session.
std::unique_ptr<QOlmOutboundGroupSession> QOlmOutboundGroupSession::create()
{
    auto *olmOutboundGroupSession = olm_outbound_group_session(new uint8_t[olm_outbound_group_session_size()]);
    // Hand the storage to the wrapper immediately: from here on any exception
    // (including the one thrown below) runs the wrapper's destructor, which
    // clears the session and frees the storage instead of leaking it.
    auto session = std::make_unique<QOlmOutboundGroupSession>(olmOutboundGroupSession);

    const auto randomLength = olm_init_outbound_group_session_random_length(olmOutboundGroupSession);
    QByteArray randomBuf = getRandom(randomLength);

    const auto error = olm_init_outbound_group_session(olmOutboundGroupSession,
            reinterpret_cast<uint8_t *>(randomBuf.data()), randomBuf.length());

    // The random seed is not needed once libolm has consumed it; drop it on
    // both the success and the failure path.
    // NOTE(review): QByteArray::clear() releases the buffer but is not a
    // guaranteed secure wipe - confirm whether zeroing is required here.
    randomBuf.clear();

    if (error == olm_error()) {
        // lastError() is evaluated before unwinding destroys `session`.
        throw lastError(olmOutboundGroupSession);
    }

    return session;
}

std::variant<QByteArray, QOlmError> QOlmOutboundGroupSession::pickle(const PicklingMode &mode)
{
    QByteArray pickledBuf(olm_pickle_outbound_group_session_length(m_groupSession), '0');
    QByteArray key = toKey(mode);
    const auto error = olm_pickle_outbound_group_session(m_groupSession, key.data(), key.length(),
            pickledBuf.data(), pickledBuf.length());

    if (error == olm_error()) {
        return lastError(m_groupSession);
    }

    key.clear();

    return pickledBuf;
}


// Restore an outbound group session from its pickled form.
//
// `pickled` holds the pickled session; `mode` selects the key (via toKey())
// that the pickle is decrypted with.
//
// Returns the owning wrapper around the restored session on success, or the
// QOlmError reported by libolm on failure.
std::variant<std::unique_ptr<QOlmOutboundGroupSession>, QOlmError> QOlmOutboundGroupSession::unpickle(QByteArray &pickled, const PicklingMode &mode)
{
    // libolm's unpickle call destroys the buffer it is given, so it must be
    // fed the private copy rather than the caller's data.
    QByteArray pickledBuf = pickled;

    auto *olmOutboundGroupSession = olm_outbound_group_session(new uint8_t[olm_outbound_group_session_size()]);
    // Give ownership to the wrapper right away so that the error return below
    // (or any exception) clears the session and frees its storage.
    auto session = std::make_unique<QOlmOutboundGroupSession>(olmOutboundGroupSession);

    QByteArray key = toKey(mode);
    const auto error = olm_unpickle_outbound_group_session(olmOutboundGroupSession, key.data(), key.length(),
            pickledBuf.data(), pickledBuf.length());
    // The key is not needed any more, whatever the outcome.
    key.clear();

    if (error == olm_error()) {
        // The error is read before `session` goes out of scope and releases
        // the half-initialised libolm session.
        return lastError(olmOutboundGroupSession);
    }

    return session;
}

std::variant<QByteArray, QOlmError> QOlmOutboundGroupSession::encrypt(const QString &plaintext)
{
    QByteArray plaintextBuf = plaintext.toUtf8();
    const auto messageMaxLength = olm_group_encrypt_message_length(m_groupSession, plaintextBuf.length());
    QByteArray messageBuf(messageMaxLength, '0');
    const auto error = olm_group_encrypt(m_groupSession, reinterpret_cast<uint8_t *>(plaintextBuf.data()),
            plaintextBuf.length(), reinterpret_cast<uint8_t *>(messageBuf.data()), messageBuf.length());

    if (error == olm_error()) {
        return lastError(m_groupSession);
    }

    return messageBuf;
}

// Return the message index that libolm currently reports for this session.
uint32_t QOlmOutboundGroupSession::sessionMessageIndex() const
{
    const uint32_t messageIndex = olm_outbound_group_session_message_index(m_groupSession);
    return messageIndex;
}

// Return the identifier of this session as written by libolm.
//
// Throws a QOlmError (obtained via lastError()) if libolm reports an error
// while producing the identifier.
QByteArray QOlmOutboundGroupSession::sessionId() const
{
    // Buffer sized as libolm requests, pre-filled with the character '0'.
    QByteArray sessionIdBuf(olm_outbound_group_session_id_length(m_groupSession), '0');

    auto *output = reinterpret_cast<uint8_t *>(sessionIdBuf.data());
    const auto status = olm_outbound_group_session_id(m_groupSession, output, sessionIdBuf.length());

    if (status != olm_error()) {
        return sessionIdBuf;
    }
    throw lastError(m_groupSession);
}

std::variant<QByteArray, QOlmError> QOlmOutboundGroupSession::sessionKey() const
{
    const auto keyMaxLength = olm_outbound_group_session_key_length(m_groupSession);
    QByteArray keyBuffer(keyMaxLength, '0');
    const auto error = olm_outbound_group_session_key(m_groupSession, reinterpret_cast<uint8_t *>(keyBuffer.data()),
            keyMaxLength);
    if (error == olm_error()) {
        return lastError(m_groupSession);
    }
    return keyBuffer;
}