summaryrefslogtreecommitdiff
path: root/vendor/tokio/tests/rt_poll_callbacks.rs
blob: 8ccff385772e240631dc2a02d30abaa007f34707 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
#![allow(unknown_lints, unexpected_cfgs)]
#![cfg(tokio_unstable)]

use std::sync::{atomic::AtomicUsize, Arc, Mutex};

use tokio::task::yield_now;

/// Verifies that `on_before_task_poll` / `on_after_task_poll` fire for a task
/// spawned on the multi-threaded runtime.
///
/// Checks that both hooks last observed the spawned task's id and that each
/// hook ran exactly once per poll of that task.
#[cfg(not(target_os = "wasi"))]
#[test]
fn callbacks_fire_multi_thread() {
    use std::sync::atomic::Ordering;

    // The spawned task awaits `yield_now` three times, so it is polled three
    // times without completing plus one final poll that completes it.
    const EXPECTED_POLLS: usize = 4;

    let poll_start_counter = Arc::new(AtomicUsize::new(0));
    let poll_stop_counter = Arc::new(AtomicUsize::new(0));
    let poll_start = poll_start_counter.clone();
    let poll_stop = poll_stop_counter.clone();

    // Most recent task id seen by each hook.
    let before_task_poll_callback_task_id: Arc<Mutex<Option<tokio::task::Id>>> =
        Arc::new(Mutex::new(None));
    let after_task_poll_callback_task_id: Arc<Mutex<Option<tokio::task::Id>>> =
        Arc::new(Mutex::new(None));

    let before_task_poll_callback_task_id_ref = Arc::clone(&before_task_poll_callback_task_id);
    let after_task_poll_callback_task_id_ref = Arc::clone(&after_task_poll_callback_task_id);
    let rt = tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .on_before_task_poll(move |task_meta| {
            before_task_poll_callback_task_id_ref
                .lock()
                .unwrap()
                .replace(task_meta.id());
            poll_start_counter.fetch_add(1, Ordering::Relaxed);
        })
        .on_after_task_poll(move |task_meta| {
            after_task_poll_callback_task_id_ref
                .lock()
                .unwrap()
                .replace(task_meta.id());
            poll_stop_counter.fetch_add(1, Ordering::Relaxed);
        })
        .build()
        .unwrap();
    let task = rt.spawn(async {
        yield_now().await;
        yield_now().await;
        yield_now().await;
    });

    let spawned_task_id = task.id();

    rt.block_on(task).expect("task should succeed");
    // We need to drop the runtime to guarantee the workers have exited (and thus called the callback)
    drop(rt);

    assert_eq!(
        before_task_poll_callback_task_id
            .lock()
            .unwrap()
            .expect("on_before_task_poll never fired"),
        spawned_task_id
    );
    assert_eq!(
        after_task_poll_callback_task_id
            .lock()
            .unwrap()
            .expect("on_after_task_poll never fired"),
        spawned_task_id
    );
    assert_eq!(
        poll_start.load(Ordering::Relaxed),
        EXPECTED_POLLS,
        "unexpected number of poll starts"
    );
    assert_eq!(
        poll_stop.load(Ordering::Relaxed),
        EXPECTED_POLLS,
        "unexpected number of poll stops"
    );
}

/// Verifies that `on_before_task_poll` / `on_after_task_poll` fire for a task
/// spawned on the current-thread runtime.
///
/// Checks that both hooks last observed the spawned task's id and that each
/// hook ran exactly once per poll of that task.
#[test]
fn callbacks_fire_current_thread() {
    use std::sync::atomic::Ordering;

    // The spawned task awaits `yield_now` three times, so it is polled three
    // times without completing plus one final poll that completes it.
    const EXPECTED_POLLS: usize = 4;

    let poll_start_counter = Arc::new(AtomicUsize::new(0));
    let poll_stop_counter = Arc::new(AtomicUsize::new(0));
    let poll_start = poll_start_counter.clone();
    let poll_stop = poll_stop_counter.clone();

    // Most recent task id seen by each hook.
    let before_task_poll_callback_task_id: Arc<Mutex<Option<tokio::task::Id>>> =
        Arc::new(Mutex::new(None));
    let after_task_poll_callback_task_id: Arc<Mutex<Option<tokio::task::Id>>> =
        Arc::new(Mutex::new(None));

    let before_task_poll_callback_task_id_ref = Arc::clone(&before_task_poll_callback_task_id);
    let after_task_poll_callback_task_id_ref = Arc::clone(&after_task_poll_callback_task_id);
    let rt = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .on_before_task_poll(move |task_meta| {
            before_task_poll_callback_task_id_ref
                .lock()
                .unwrap()
                .replace(task_meta.id());
            poll_start_counter.fetch_add(1, Ordering::Relaxed);
        })
        .on_after_task_poll(move |task_meta| {
            after_task_poll_callback_task_id_ref
                .lock()
                .unwrap()
                .replace(task_meta.id());
            poll_stop_counter.fetch_add(1, Ordering::Relaxed);
        })
        .build()
        .unwrap();

    let task = rt.spawn(async {
        yield_now().await;
        yield_now().await;
        yield_now().await;
    });

    let spawned_task_id = task.id();

    // Surface a task failure (e.g. a panic) here rather than silently ignoring
    // the `JoinError` and failing later on an unrelated assertion.
    rt.block_on(task).expect("task should succeed");
    // Shut the runtime down before inspecting what the hooks recorded.
    drop(rt);

    assert_eq!(
        before_task_poll_callback_task_id
            .lock()
            .unwrap()
            .expect("on_before_task_poll never fired"),
        spawned_task_id
    );
    assert_eq!(
        after_task_poll_callback_task_id
            .lock()
            .unwrap()
            .expect("on_after_task_poll never fired"),
        spawned_task_id
    );
    assert_eq!(
        poll_start.load(Ordering::Relaxed),
        EXPECTED_POLLS,
        "unexpected number of poll starts"
    );
    assert_eq!(
        poll_stop.load(Ordering::Relaxed),
        EXPECTED_POLLS,
        "unexpected number of poll stops"
    );
}