Python 标准库源码之 threading 模块

引言

虽然说 Python 受限于 CPython 的实现,存在的 GIL 会导致我们在使用多线程的时候,没法利用多核跑多线程。但是有的时候还是会用到线程的,尤其是针对一些 I/O 密集型的任务,也可以使用它们。

在使用多线程编程时,我们随时需要注意竞态条件(race condition)数据竞争(data race)的问题,前者会导致我们在不同的时间点运行程序得到的输出可能不同;而后者则更为可怕,容易导致共享的数据结构被错误修改,甚至导致程序崩溃或者出现莫名其妙的 Bug。这个时候自然就要用到 Python threading 模块为我们提供的若干同步原语了。

那么,我们常用的 Lock、RLock、条件变量(Condition Variables)、信号量(Semaphore)等是如何实现的呢?接下来的源码学习是基于 CPython master 分支的线程模块。希望在学习完它们的实现后,能够加深理解,合理运用。

源码学习

CPython 的 threading 模块实际上是基于 Java 的线程模型实现的,所以熟悉 Java 的话,自然也不会对该模块的实现感到陌生。该模块是基于更底层的 _thread 模块,抽象出更加方便使用的线程模型,核心包括 threading.Thread 线程类封装,便于用户继承或组合;此外还有一些同步原语的实现。Python/thread_nt.h 文件中是 C 语言实现的底层和线程有关的函数(如锁的创建和维护、线程的创建和管理)。

同步原语

Lock

该模块中,Lock 其实是使用了底层 _thread.allocate_lock 函数来创建锁的。代码也很简单:

1
Lock = _allocate_lock

Lock 为我们提供了 acquire()release() 这两个主要的方法。当一个线程持有锁时,其它线程调用 acquire() 方法时会被阻塞(此时线程一般就是睡眠等待了),直到主动 release() 后,等待锁的线程会被唤醒。

关于 Lock 有两点值得注意:

  1. 该锁是不可重入的,也就是如果在一个函数中递归 acquire() 会导致死锁的问题。为了避免这种问题,一般会使用 RLock 来代替
  2. Lock 并非 Mutex(互斥锁),且它底层是通过信号量那样实现的,本身不会记录谁持有了该锁,也就是说 Lock 可以在不同的线程中被引用,可以在主线程获取,而在子线程释放它。具体可以在 CPython/Python/thread_nt.h:PyThread_allocate_lock 可以看到它的实现如下:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    /*
    * Lock support. It has to be implemented as semaphores.
    * I [Dag] tried to implement it with mutex but I could find a way to
    * tell whether a thread already own the lock or not.
    * Lock 支持:它必须以信号量的方式来实现。我尝试使用互斥锁实现过,但是我
    * 发现了另外一种方式可以得知一个线程是否持有了锁。
    */
    PyThread_type_lock
    PyThread_allocate_lock(void)
    {
    PNRMUTEX aLock;
    dprintf(("PyThread_allocate_lock called\n"));
    if (!initialized)
    PyThread_init_thread();
    aLock = AllocNonRecursiveMutex() ;
    dprintf(("%lu: PyThread_allocate_lock() -> %p\n", PyThread_get_thread_ident(), aLock));
    return (PyThread_type_lock) aLock;
    }

    // 其中 PNRMUTEX 定义如下,它并不会告诉我们当前是哪个线程
    // 持有了锁
    typedef struct _NRMUTEX
    {
    PyMUTEX_T cs;
    PyCOND_T cv;
    int locked;
    } NRMUTEX;
    typedef NRMUTEX *PNRMUTEX;

比较有趣的是,其实 PyCOND 即条件变量是通过信号量来实现的;而接下来我们会看到,在 Python 的 threading 模块中,我们使用了 Condition 实现了信号量。

RLock

RLock 就是可重入锁(Reentrant Lock),它可以被持有锁的线程多次执行 acquire(),而不会发生阻塞和死锁的问题。它的实现思路很简单:

  1. 规定如果一个线程成功持有了该锁,则将该锁的所有权交给该线程,并且只有该线程可以释放锁,其它线程无法释放;
  2. 当在持有锁的线程中递归获取锁的时候,实际并不会执行底层的 _lock.acquire() 方法,而是只给计数器递增;且释放锁的时候也是先给计数器递减,直到为 0 后才会释放锁。

所以在使用 RLock 的时候一定要记得 acquire()release() 的调用次数得匹配才能真正释放锁。接下来简单看下源码实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
def RLock(*args, **kwargs):
if _CRLock is None:
# 接下来重点看 Python 版本的 RLock 实现
return _PyRLock(*args, **kwargs)
return _CRLock(*args, **kwargs)

class _RLock:
def __init__(self):
# 这里是真正的锁
self._block = _allocate_lock()
# 记录谁对该锁有所有权
self._owner = None
# 记录该锁被获取的次数,类似引用计数
self._count = 0

def acquire(self, blocking=True, timeout=-1):
me = get_ident()
if self._owner == me:
# 如果当前持有锁的线程就是当前需要获得锁的线程,计数器递增即可
self._count += 1
return 1
rc = self._block.acquire(blocking, timeout)
if rc:
# 如果成功获取到锁后,会把持有锁的线程记录下来,标记该线程是所有权拥有者
self._owner = me
self._count = 1
return rc

def release(self):
if self._owner != get_ident():
# 显而易见,非拥有者不能释放锁,想都不用想!
raise RuntimeError("cannot release un-acquired lock")

# 这里只是递减计数器,只有_count 减没了才会真正释放
self._count = count = self._count - 1
if not count:
self._owner = None
self._block.release()

# 下面的方法是用于条件变量实现时使用
def _acquire_restore(self, state):
# 恢复锁的获取,并且恢复嵌套层次
self._block.acquire()
self._count, self._owner = state

def _release_save(self):
# 需要保证不管有多少层嵌套,都能真正释放锁,但同时返回当前的嵌套状态等信息便于恢复
if self._count == 0:
raise RuntimeError("cannot release un-acquired lock")
count = self._count
self._count = 0
owner = self._owner
self._owner = None
self._block.release()
return (count, owner)

def _is_owned(self):
return self._owner == get_ident()

Condition

条件变量是后面几个同步原语实现的基础,值得重点学习下。条件变量的实现原理比较简单:所有等待的线程会被加入到等待队列中,只有在需要的时候会被唤醒(可以想想如何实现 waiter 线程的等待和唤醒呢?)。

在分析源码前,我们可以看看 Condition 类提供了哪些主要接口:

  • wait(timeout=None),线程可以调用该接口等待被唤醒
  • notify(),线程可以调用该接口通知队列中一个或多个等待线程被唤醒

接下来看看源码实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
class Condition:
def __init__(self, lock=None):
if lock is None:
lock = RLock()
self._lock = lock
# Export the lock's acquire() and release() methods
self.acquire = lock.acquire
self.release = lock.release
# If the lock defines _release_save() and/or _acquire_restore(),
# these override the default implementations (which just call
# release() and acquire() on the lock). Ditto for _is_owned().
try:
self._release_save = lock._release_save
except AttributeError:
pass
try:
self._acquire_restore = lock._acquire_restore
except AttributeError:
pass
try:
self._is_owned = lock._is_owned
except AttributeError:
pass
self._waiters = _deque()

def _release_save(self):
self._lock.release() # No state to save

def _acquire_restore(self, x):
self._lock.acquire() # Ignore saved state

def _is_owned(self):
# Return True if lock is owned by current_thread.
# This method is called only if _lock doesn't have _is_owned().
if self._lock.acquire(False):
self._lock.release()
return False
else:
return True

def wait(self, timeout=None):
if not self._is_owned():
raise RuntimeError("cannot wait on un-acquired lock")
waiter = _allocate_lock()
waiter.acquire()
self._waiters.append(waiter)
saved_state = self._release_save()
gotit = False
try: # restore state no matter what (e.g., KeyboardInterrupt)
if timeout is None:
waiter.acquire()
gotit = True
else:
if timeout > 0:
gotit = waiter.acquire(True, timeout)
else:
gotit = waiter.acquire(False)
return gotit
finally:
self._acquire_restore(saved_state)
if not gotit:
try:
self._waiters.remove(waiter)
except ValueError:
pass

def wait_for(self, predicate, timeout=None):
endtime = None
waittime = timeout
result = predicate()
while not result:
if waittime is not None:
if endtime is None:
endtime = _time() + waittime
else:
waittime = endtime - _time()
if waittime <= 0:
break
self.wait(waittime)
result = predicate()
return result

def notify(self, n=1):
if not self._is_owned():
raise RuntimeError("cannot notify on un-acquired lock")
all_waiters = self._waiters
waiters_to_notify = _deque(_islice(all_waiters, n))
if not waiters_to_notify:
return
for waiter in waiters_to_notify:
waiter.release()
try:
all_waiters.remove(waiter)
except ValueError:
pass

Semaphore

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
class Semaphore:
"""This class implements semaphore objects.
Semaphores manage a counter representing the number of release() calls minus
the number of acquire() calls, plus an initial value. The acquire() method
blocks if necessary until it can return without making the counter
negative. If not given, value defaults to 1.
"""
# After Tim Peters' semaphore class, but not quite the same (no maximum)
def __init__(self, value=1):
if value < 0:
raise ValueError("semaphore initial value must be >= 0")
self._cond = Condition(Lock())
self._value = value

def acquire(self, blocking=True, timeout=None):
"""Acquire a semaphore, decrementing the internal counter by one.
When invoked without arguments: if the internal counter is larger than
zero on entry, decrement it by one and return immediately. If it is zero
on entry, block, waiting until some other thread has called release() to
make it larger than zero. This is done with proper interlocking so that
if multiple acquire() calls are blocked, release() will wake exactly one
of them up. The implementation may pick one at random, so the order in
which blocked threads are awakened should not be relied on. There is no
return value in this case.
When invoked with blocking set to true, do the same thing as when called
without arguments, and return true.
When invoked with blocking set to false, do not block. If a call without
an argument would block, return false immediately; otherwise, do the
same thing as when called without arguments, and return true.
When invoked with a timeout other than None, it will block for at
most timeout seconds. If acquire does not complete successfully in
that interval, return false. Return true otherwise.
"""
if not blocking and timeout is not None:
raise ValueError("can't specify timeout for non-blocking acquire")
rc = False
endtime = None
with self._cond:
while self._value == 0:
if not blocking:
break
if timeout is not None:
if endtime is None:
endtime = _time() + timeout
else:
timeout = endtime - _time()
if timeout <= 0:
break
self._cond.wait(timeout)
else:
self._value -= 1
rc = True
return rc

def release(self, n=1):
"""Release a semaphore, incrementing the internal counter by one or more.
When the counter is zero on entry and another thread is waiting for it
to become larger than zero again, wake up that thread.
"""
if n < 1:
raise ValueError('n must be one or more')
with self._cond:
self._value += n
for i in range(n):
self._cond.notify()

class BoundedSemaphore(Semaphore):
"""Implements a bounded semaphore.
A bounded semaphore checks to make sure its current value doesn't exceed its
initial value. If it does, ValueError is raised. In most situations
semaphores are used to guard resources with limited capacity.
If the semaphore is released too many times it's a sign of a bug. If not
given, value defaults to 1.
Like regular semaphores, bounded semaphores manage a counter representing
the number of release() calls minus the number of acquire() calls, plus an
initial value. The acquire() method blocks if necessary until it can return
without making the counter negative. If not given, value defaults to 1.
"""
def __init__(self, value=1):
Semaphore.__init__(self, value)
self._initial_value = value

def release(self, n=1):
"""Release a semaphore, incrementing the internal counter by one or more.
When the counter is zero on entry and another thread is waiting for it
to become larger than zero again, wake up that thread.
If the number of releases exceeds the number of acquires,
raise a ValueError.
"""
if n < 1:
raise ValueError('n must be one or more')
with self._cond:
if self._value + n > self._initial_value:
raise ValueError("Semaphore released too many times")
self._value += n
for i in range(n):
self._cond.notify()

Event

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
class Event:
"""Class implementing event objects.
Events manage a flag that can be set to true with the set() method and reset
to false with the clear() method. The wait() method blocks until the flag is
true. The flag is initially false.
"""
# After Tim Peters' event class (without is_posted())
def __init__(self):
self._cond = Condition(Lock())
self._flag = False

def _reset_internal_locks(self):
# private! called by Thread._reset_internal_locks by _after_fork()
self._cond.__init__(Lock())

def is_set(self):
"""Return true if and only if the internal flag is true."""
return self._flag

isSet = is_set
def set(self):
"""Set the internal flag to true.
All threads waiting for it to become true are awakened. Threads
that call wait() once the flag is true will not block at all.
"""
with self._cond:
self._flag = True
self._cond.notify_all()

def clear(self):
"""Reset the internal flag to false.
Subsequently, threads calling wait() will block until set() is called to
set the internal flag to true again.
"""
with self._cond:
self._flag = False

def wait(self, timeout=None):
"""Block until the internal flag is true.
If the internal flag is true on entry, return immediately. Otherwise,
block until another thread calls set() to set the flag to true, or until
the optional timeout occurs.
When the timeout argument is present and not None, it should be a
floating point number specifying a timeout for the operation in seconds
(or fractions thereof).
This method returns the internal flag on exit, so it will always return
True except if a timeout is given and the operation times out.
"""
with self._cond:
signaled = self._flag
if not signaled:
signaled = self._cond.wait(timeout)
return signaled

Barrier

通常可以使用 Barrier 实现并发初始化,然后一切就绪后才会进入下一个阶段。应用示例如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# coding: utf-8
from threading import get_ident as get_ident
from threading import Barrier, Thread

def signal_prepared():
print("All are ready")

barrier = Barrier(parties=4, action=signal_prepared)

def main():
Thread(target=load_disk_files).start()
Thread(target=make_cache).start()
Thread(target=init_db_pool).start()
print("I'm ready, wait for other workers")
barrier.wait()
print("Time to start our server")

def load_disk_files():
print(u"[{}] load_disk_files".format(get_ident()))
barrier.wait()

def make_cache():
print(u"[{}] make cache".format(get_ident()))
barrier.wait()

def init_db_pool():
print(u"[{}] init db pool".format(get_ident()))
barrier.wait()

if __name__ == '__main__':
main()

运行效果:

接下来看看 Barrier 是如何实现的:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
# Barrier 是基于部分的 `pthread_barrier_*` API 和 Java 中的 `CyclicBarrier`
# 参考:
# 1. http://sourceware.org/pthreads-win32/manual/pthread_barrier_init.html and
# 2. http://java.sun.com/j2se/1.5.0/docs/api/java/util/concurrent/CyclicBarrier.html
# 在内部维护了两种主要的状态:`filling` 和 `draining`,从而让屏障变成可循环使用的。
# 只有上一个周期完全排水(`drained`)完毕才可以允许新的线程进入(对比下漏桶限流算法)
# 此外,这里还提供了 `resetting` 状态,它类似于 `draining`,但是会让线程离开的时候抛出 `BrokeBarrierError`
# `broken` 状态表示所有的线程都产生了异常
class Barrier:
"""I
我们通常可以使用屏障让多个线程在相同的同步点同步开始(可以想象下有个水缸,水不断地流进来
但是会在某个点一起放开,形成洪流...)。所有调用了 `wait()` 的线程会在条件满足时几乎同时被唤醒,
然后大家就可以一起快乐地干活了。
"""
def __init__(self, parties, action=None, timeout=None):
"""Create a barrier, initialised to 'parties' threads.
'action' is a callable which, when supplied, will be called by one of
the threads after they have all entered the barrier and just prior to
releasing them all. If a 'timeout' is provided, it is used as the
default for all subsequent 'wait()' calls.
"""
self._cond = Condition(Lock())
self._action = action
self._timeout = timeout
self._parties = parties
self._state = 0 #0 filling, 1, draining, -1 resetting, -2 broken
self._count = 0

def wait(self, timeout=None):
"""Wait for the barrier.
When the specified number of threads have started waiting, they are all
simultaneously awoken. If an 'action' was provided for the barrier, one
of the threads will have executed that callback prior to returning.
Returns an individual index number from 0 to 'parties-1'.
"""
if timeout is None:
timeout = self._timeout
with self._cond:
self._enter() # Block while the barrier drains.
index = self._count
self._count += 1
try:
if index + 1 == self._parties:
# We release the barrier
self._release()
else:
# We wait until someone releases us
self._wait(timeout)
return index
finally:
self._count -= 1
# Wake up any threads waiting for barrier to drain.
self._exit()
# Block until the barrier is ready for us, or raise an exception
# if it is broken.

def _enter(self):
while self._state in (-1, 1):
# It is draining or resetting, wait until done
self._cond.wait()
#see if the barrier is in a broken state
if self._state < 0:
raise BrokenBarrierError
assert self._state == 0

# Optionally run the 'action' and release the threads waiting
# in the barrier.
def _release(self):
try:
if self._action:
self._action()
# enter draining state
self._state = 1
self._cond.notify_all()
except:
#an exception during the _action handler. Break and reraise
self._break()
raise

# Wait in the barrier until we are released. Raise an exception
# if the barrier is reset or broken.
def _wait(self, timeout):
if not self._cond.wait_for(lambda : self._state != 0, timeout):
#timed out. Break the barrier
self._break()
raise BrokenBarrierError
if self._state < 0:
raise BrokenBarrierError
assert self._state == 1

# If we are the last thread to exit the barrier, signal any threads
# waiting for the barrier to drain.
def _exit(self):
if self._count == 0:
if self._state in (-1, 1):
#resetting or draining
self._state = 0
self._cond.notify_all()

def reset(self):
"""Reset the barrier to the initial state.
Any threads currently waiting will get the BrokenBarrier exception
raised.
"""
with self._cond:
if self._count > 0:
if self._state == 0:
#reset the barrier, waking up threads
self._state = -1
elif self._state == -2:
#was broken, set it to reset state
#which clears when the last thread exits
self._state = -1
else:
self._state = 0
self._cond.notify_all()

def abort(self):
"""Place the barrier into a 'broken' state.
Useful in case of error. Any currently waiting threads and threads
attempting to 'wait()' will have BrokenBarrierError raised.
"""
with self._cond:
self._break()

def _break(self):
# An internal error was detected. The barrier is set to
# a broken state all parties awakened.
self._state = -2
self._cond.notify_all()

@property
def parties(self):
"""Return the number of threads required to trip the barrier."""
return self._parties

@property
def n_waiting(self):
"""Return the number of threads currently waiting at the barrier."""
# We don't need synchronization here since this is an ephemeral result
# anyway. It returns the correct value in the steady state.
if self._state == 0:
return self._count
return 0

@property
def broken(self):
"""Return True if the barrier is in a broken state."""
return self._state == -2

总结

Python 源码的注释太丰富了,以至于我都不想翻译成中文。所以结合注释看代码即可~

0%