1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

use crate::{
    arrays::{
        QirArray, __quantum__rt__array_concatenate, __quantum__rt__array_update_reference_count,
    },
    tuples::{__quantum__rt__tuple_copy, __quantum__rt__tuple_update_reference_count},
    update_counts,
};
use std::{cell::RefCell, mem::ManuallyDrop, rc::Rc};

/// Runtime representation of a QIR callable value.
///
/// Instances are heap-allocated behind an `Rc` and handed across the C ABI as raw pointers.
/// The functor state (`is_adj`, `ctls_count`) lives in `RefCell`s so it can be mutated through
/// a shared reference obtained from such a raw pointer.
#[derive(Clone, Debug)]
pub struct Callable {
    /// Table of `extern "C" fn(capture, args, result)` entry points, indexed by functor state:
    /// 0 = body, 1 = adjoint, 2 = controlled, 3 = controlled adjoint.
    func_table: *mut *mut u8,
    /// Nullable table of capture-tuple count functions: entry 0 updates reference counts,
    /// entry 1 updates alias counts. Either entry may itself be null.
    mem_table: *mut *mut u8,
    /// Captured values, passed as the first argument to every function-table entry.
    cap_tuple: *mut u8,
    /// Toggled each time the adjoint functor is applied.
    is_adj: RefCell<bool>,
    /// Number of times the controlled functor has been applied.
    ctls_count: RefCell<u32>,
}

/// Creates a new callable over the given function table, memory-management table, and capture
/// tuple, starting with no functors applied (not adjoint, zero control layers).
///
/// The returned pointer owns one strong `Rc` reference; release it through the callable
/// reference-count API rather than freeing it directly.
#[no_mangle]
pub extern "C" fn __quantum__rt__callable_create(
    func_table: *mut *mut u8,
    mem_table: *mut *mut u8,
    cap_tuple: *mut u8,
) -> *const Callable {
    // Fresh callables start in the plain-body functor state.
    let fresh = Callable {
        is_adj: RefCell::new(false),
        ctls_count: RefCell::new(0),
        func_table,
        mem_table,
        cap_tuple,
    };
    let owned = Rc::new(fresh);
    Rc::into_raw(owned)
}

/// Invokes `callable` with the argument tuple `args_tup`, writing any results into `res_tup`.
///
/// The function-table entry is chosen from the callable's functor state: index 0 for the plain
/// body, 1 when the adjoint flag is set, 2 when at least one control layer has been applied,
/// and 3 for both. If the callable has been made controlled more than once, the arguments arrive
/// as nested `(controls, (controls, (..., args)))` tuples. These are flattened into one
/// `(controls, args)` tuple whose control list concatenates every layer.
///
/// # Safety
///
/// `callable` must be a live pointer produced by `__quantum__rt__callable_create` or
/// `__quantum__rt__callable_copy`. `args_tup` must be null or a valid runtime tuple. For a
/// controlled callable, it is presumably non-null, with a control array in its first field and
/// (for nested controls) the inner argument tuple in its second. TODO confirm against the
/// code generator.
#[no_mangle]
// Opaque `*mut u8` tuple pointers are reinterpreted as pointers to pointer-sized fields, which
// this lint flags; presumably tuples are allocated with pointer alignment — TODO confirm.
#[allow(clippy::cast_ptr_alignment)]
pub unsafe extern "C" fn __quantum__rt__callable_invoke(
    callable: *const Callable,
    args_tup: *mut u8,
    res_tup: *mut u8,
) {
    let call = &*callable;
    // Read the functor state exactly once. The callee may mutate this same callable (e.g. via
    // `__quantum__rt__callable_make_controlled`). Re-reading the flags after the call could then
    // decrement a control list that was never incremented, or dereference a non-array field.
    let is_adj = *call.is_adj.borrow();
    let ctls_count = *call.ctls_count.borrow();
    let index = usize::from(is_adj) + (if ctls_count > 0 { 2 } else { 0 });

    // Collect any nested controls into a single control list.
    let mut args_copy: *mut *const Vec<u8> = std::ptr::null_mut();
    if !args_tup.is_null() {
        // Copy the tuple so we can potentially edit it.
        args_copy = __quantum__rt__tuple_copy(args_tup.cast::<*const Vec<u8>>(), true);

        if ctls_count > 0 {
            // If there are any controls, increment the reference count on the control list.
            // This balances the decrement that happens in the loop and after the callable
            // returns, so the original, non-owned list is not incorrectly cleaned up.
            __quantum__rt__array_update_reference_count(*args_copy.cast::<*const QirArray>(), 1);

            let mut remaining_layers = ctls_count;
            while remaining_layers > 1 {
                let ctls = *args_copy.cast::<*const QirArray>();
                // The second field of a nested controlled tuple is the inner argument tuple.
                let inner_tuple = *args_copy
                    .cast::<*const QirArray>()
                    .wrapping_add(1)
                    .cast::<*mut *const Vec<u8>>();
                let inner_ctls = *inner_tuple.cast::<*const QirArray>();
                let new_ctls = __quantum__rt__array_concatenate(ctls, inner_ctls);
                let new_args = __quantum__rt__tuple_copy(inner_tuple, true);
                *new_args.cast::<*const QirArray>() = new_ctls;

                // This decrement releases either the extra count added above or the list
                // created by the previous concatenation. In the latter case the intermediate
                // list is cleaned up, so memory does not leak.
                __quantum__rt__array_update_reference_count(
                    *args_copy.cast::<*const QirArray>(),
                    -1,
                );
                // Release the superseded tuple copy; `new_args` now carries the merged controls.
                __quantum__rt__tuple_update_reference_count(args_copy, -1);
                args_copy = new_args;
                remaining_layers -= 1;
            }
        }
    }

    let entry = *call
        .func_table
        .wrapping_add(index)
        .cast::<extern "C" fn(*mut u8, *mut u8, *mut u8)>();
    entry(call.cap_tuple, args_copy.cast::<u8>(), res_tup);

    // Release what this function acquired: the extra control-list count (if any) and the
    // argument-tuple copy. Both live inside `args_copy`, which stays null when `args_tup` was
    // null, so the whole cleanup is skipped in that case instead of dereferencing null.
    if !args_copy.is_null() {
        if ctls_count > 0 {
            __quantum__rt__array_update_reference_count(*args_copy.cast::<*const QirArray>(), -1);
        }
        __quantum__rt__tuple_update_reference_count(args_copy, -1);
    }
}

/// Returns a callable equivalent to `callable`.
///
/// If `force` is false and the callable has no outstanding weak references, the existing
/// allocation is shared: its strong count is bumped and the same pointer is returned. Otherwise
/// a new allocation holding a field-by-field copy is returned. That copy shares the function
/// table, memory table, and capture tuple pointers and has its own functor state. No
/// capture-tuple counts are adjusted here. (Weak references presumably represent alias counts;
/// verify against `update_counts`.)
///
/// # Safety
///
/// `callable` must be a live pointer obtained from `Rc::into_raw` on an `Rc<Callable>`.
#[no_mangle]
pub unsafe extern "C" fn __quantum__rt__callable_copy(
    callable: *const Callable,
    force: bool,
) -> *const Callable {
    // Reconstitute the `Rc` only to inspect it; `ManuallyDrop` keeps the caller's strong
    // reference from being released when this binding goes out of scope.
    let existing = ManuallyDrop::new(Rc::from_raw(callable));
    let has_aliases = Rc::weak_count(&existing) > 0;

    if !force && !has_aliases {
        // Share the allocation: the caller receives one additional strong reference.
        Rc::increment_strong_count(callable);
        return callable;
    }

    let duplicate: Callable = (**existing).clone();
    Rc::into_raw(Rc::new(duplicate))
}

/// Applies the adjoint functor to `callable` in place by flipping its adjoint flag, so applying
/// it twice restores the original state.
///
/// # Safety
///
/// `callable` must point to a live `Callable`.
#[no_mangle]
pub unsafe extern "C" fn __quantum__rt__callable_make_adjoint(callable: *const Callable) {
    let mut adjoint = (*callable).is_adj.borrow_mut();
    *adjoint = !*adjoint;
}

/// Applies the controlled functor to `callable` in place by adding one control layer.
///
/// # Safety
///
/// `callable` must point to a live `Callable`.
#[no_mangle]
pub unsafe extern "C" fn __quantum__rt__callable_make_controlled(callable: *const Callable) {
    let mut layers = (*callable).ctls_count.borrow_mut();
    *layers += 1;
}

/// Adjusts the reference count of `callable` by `update`, delegating to the crate's shared
/// `update_counts` helper.
///
/// # Safety
///
/// `callable` must be a live pointer to an `Rc`-managed `Callable`.
#[no_mangle]
pub unsafe extern "C" fn __quantum__rt__callable_update_reference_count(
    callable: *const Callable,
    update: i32,
) {
    // `false` selects reference-count (as opposed to alias-count) bookkeeping.
    let is_alias = false;
    update_counts(callable, update, is_alias);
}

/// Adjusts the alias count of `callable` by `update`, delegating to the crate's shared
/// `update_counts` helper.
///
/// # Safety
///
/// `callable` must be a live pointer to an `Rc`-managed `Callable`.
#[no_mangle]
pub unsafe extern "C" fn __quantum__rt__callable_update_alias_count(
    callable: *const Callable,
    update: i32,
) {
    // `true` selects alias-count (as opposed to reference-count) bookkeeping.
    let is_alias = true;
    update_counts(callable, update, is_alias);
}

/// Forwards a reference-count update for the callable's capture tuple to entry 0 of its
/// memory-management table. Does nothing if the table or that entry is null.
///
/// # Safety
///
/// `callable` must point to a live `Callable` whose `mem_table` is null or points to at least
/// one (nullable) `extern "C" fn(*mut u8, i32)` slot.
#[no_mangle]
pub unsafe extern "C" fn __quantum__rt__capture_update_reference_count(
    callable: *const Callable,
    update: i32,
) {
    let call = &*callable;
    if call.mem_table.is_null() {
        return;
    }
    // `Option<extern "C" fn>` is guaranteed to share the representation of a nullable function
    // pointer, so a null slot reads back as `None`.
    let slot = call.mem_table.cast::<Option<extern "C" fn(*mut u8, i32)>>();
    if let Some(update_captures) = *slot {
        update_captures(call.cap_tuple, update);
    }
}

/// Forwards an alias-count update for the callable's capture tuple to entry 1 of its
/// memory-management table. Does nothing if the table or that entry is null.
///
/// # Safety
///
/// `callable` must point to a live `Callable` whose `mem_table` is null or points to at least
/// two (nullable) `extern "C" fn(*mut u8, i32)` slots.
#[no_mangle]
pub unsafe extern "C" fn __quantum__rt__capture_update_alias_count(
    callable: *const Callable,
    update: i32,
) {
    let call = &*callable;
    if call.mem_table.is_null() {
        return;
    }
    // Slot 1 holds the alias-count updater; a null slot reads back as `None` because
    // `Option<extern "C" fn>` uses the null pointer as its `None` representation.
    let slot = call
        .mem_table
        .wrapping_add(1)
        .cast::<Option<extern "C" fn(*mut u8, i32)>>();
    if let Some(update_aliases) = *slot {
        update_aliases(call.cap_tuple, update);
    }
}