引言
虽然说 Python 受限于 CPython 的实现,存在的 GIL 会导致我们在使用多线程的时候,没法利用多核跑多线程。但是有的时候还是会用到线程的,尤其是针对一些 I/O 密集型的任务,也可以使用它们。
在使用多线程编程时,我们随时需要注意竞态条件(race condition)和数据竞争(data race)的问题,前者会导致我们在不同的时间点运行程序得到的输出可能不同;而后者则更为可怕,容易导致共享的数据结构被错误修改,甚至导致程序崩溃或者出现莫名其妙的 Bug。这个时候自然就要用到 Python threading 模块为我们提供的若干同步原语了。
那么,我们常用的 Lock、RLock、条件变量(Condition Variables)、信号量(Semaphore)等是如何实现的呢?接下来的源码学习是基于 CPython master 分支的线程模块。希望在学习完它们的实现后,能够加深理解,合理运用。
源码学习
CPython 的 threading 模块实际上是基于 Java 的线程模型实现的,所以熟悉 Java 的话,自然也不会对该模块的实现感到陌生。该模块是基于更底层的 _thread
模块,抽象出更加方便使用的线程模型,核心包括 threading.Thread
线程类封装,便于用户继承或组合;此外还有一些同步原语的实现。Python/thread_nt.h
文件中是 C 语言实现的底层和线程有关的函数(如锁的创建和维护、线程的创建和管理)。
同步原语
Lock
该模块中,Lock
其实是使用了底层 _thread.allocate_lock
函数来创建锁的。代码也很简单:
1 | Lock = _allocate_lock |
Lock 为我们提供了 acquire()
和 release()
这两个主要的方法。当一个线程持有锁时,其它线程调用 acquire()
方法时会被阻塞(此时线程一般就是睡眠等待了),直到主动 release()
后,等待锁的线程会被唤醒。
关于 Lock 有两点值得注意:
- 该锁是不可重入的,也就是如果在一个函数中递归
acquire()
会导致死锁的问题。为了避免这种问题,一般会使用RLock
来代替 - Lock 并非 Mutex(互斥锁),且它底层是通过信号量那样实现的,本身不会记录谁持有了该锁,也就是说 Lock 可以在不同的线程中被引用,可以在主线程获取,而在子线程释放它。具体可以在
CPython/Python/thread_nt.h:PyThread_allocate_lock
可以看到它的实现如下:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28/*
* Lock support. It has to be implemented as semaphores.
* I [Dag] tried to implement it with mutex but I could find a way to
* tell whether a thread already own the lock or not.
* Lock 支持:它必须以信号量的方式来实现。我尝试使用互斥锁实现过,但是我
* 发现了另外一种方式可以得知一个线程是否持有了锁。
*/
PyThread_type_lock
PyThread_allocate_lock(void)
{
PNRMUTEX aLock;
dprintf(("PyThread_allocate_lock called\n"));
if (!initialized)
PyThread_init_thread();
aLock = AllocNonRecursiveMutex() ;
dprintf(("%lu: PyThread_allocate_lock() -> %p\n", PyThread_get_thread_ident(), aLock));
return (PyThread_type_lock) aLock;
}
// 其中 PNRMUTEX 定义如下,它并不会告诉我们当前是哪个线程
// 持有了锁
typedef struct _NRMUTEX
{
PyMUTEX_T cs;
PyCOND_T cv;
int locked;
} NRMUTEX;
typedef NRMUTEX *PNRMUTEX;
比较有趣的是,其实 PyCOND
即条件变量是通过信号量来实现的;而接下来我们会看到,在 Python 的 threading 模块中,我们使用了 Condition 实现了信号量。
RLock
RLock 就是可重入锁(Reentrant Lock),它可以被持有锁的线程多次执行 acquire()
,而不会发生阻塞和死锁的问题。它的实现思路很简单:
- 规定如果一个线程成功持有了该锁,则将该锁的所有权交给该线程,并且只有该线程可以释放锁,其它线程无法释放;
- 当在持有锁的线程中递归获取锁的时候,实际并不会执行底层的
_lock.acquire()
方法,而是只给计数器递增;且释放锁的时候也是先给计数器递减,直到为 0 后才会释放锁。
所以在使用 RLock 的时候一定要记得 acquire()
和 release()
的调用次数得匹配才能真正释放锁。接下来简单看下源码实现:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58def RLock(*args, **kwargs):
if _CRLock is None:
# 接下来重点看 Python 版本的 RLock 实现
return _PyRLock(*args, **kwargs)
return _CRLock(*args, **kwargs)
class _RLock:
def __init__(self):
# 这里是真正的锁
self._block = _allocate_lock()
# 记录谁对该锁有所有权
self._owner = None
# 记录该锁被获取的次数,类似引用计数
self._count = 0
def acquire(self, blocking=True, timeout=-1):
me = get_ident()
if self._owner == me:
# 如果当前持有锁的线程就是当前需要获得锁的线程,计数器递增即可
self._count += 1
return 1
rc = self._block.acquire(blocking, timeout)
if rc:
# 如果成功获取到锁后,会把持有锁的线程记录下来,标记该线程是所有权拥有者
self._owner = me
self._count = 1
return rc
def release(self):
if self._owner != get_ident():
# 显而易见,非拥有者不能释放锁,想都不用想!
raise RuntimeError("cannot release un-acquired lock")
# 这里只是递减计数器,只有_count 减没了才会真正释放
self._count = count = self._count - 1
if not count:
self._owner = None
self._block.release()
# 下面的方法是用于条件变量实现时使用
def _acquire_restore(self, state):
# 恢复锁的获取,并且恢复嵌套层次
self._block.acquire()
self._count, self._owner = state
def _release_save(self):
# 需要保证不管有多少层嵌套,都能真正释放锁,但同时返回当前的嵌套状态等信息便于恢复
if self._count == 0:
raise RuntimeError("cannot release un-acquired lock")
count = self._count
self._count = 0
owner = self._owner
self._owner = None
self._block.release()
return (count, owner)
def _is_owned(self):
return self._owner == get_ident()
Condition
条件变量是后面几个同步原语实现的基础,值得重点学习下。条件变量的实现原理比较简单:所有等待的线程会被加入到等待队列中,只有在需要的时候会被唤醒(可以想想如何实现 waiter 线程的等待和唤醒呢?)。
在分析源码前,我们可以看看 Condition
类提供了哪些主要接口:
wait(timeout=None)
,线程可以调用该接口等待被唤醒notify()
,线程可以调用该接口通知队列中一个或多个等待线程被唤醒
接下来看看源码实现:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95class Condition:
def __init__(self, lock=None):
if lock is None:
lock = RLock()
self._lock = lock
# Export the lock's acquire() and release() methods
self.acquire = lock.acquire
self.release = lock.release
# If the lock defines _release_save() and/or _acquire_restore(),
# these override the default implementations (which just call
# release() and acquire() on the lock). Ditto for _is_owned().
try:
self._release_save = lock._release_save
except AttributeError:
pass
try:
self._acquire_restore = lock._acquire_restore
except AttributeError:
pass
try:
self._is_owned = lock._is_owned
except AttributeError:
pass
self._waiters = _deque()
def _release_save(self):
self._lock.release() # No state to save
def _acquire_restore(self, x):
self._lock.acquire() # Ignore saved state
def _is_owned(self):
# Return True if lock is owned by current_thread.
# This method is called only if _lock doesn't have _is_owned().
if self._lock.acquire(False):
self._lock.release()
return False
else:
return True
def wait(self, timeout=None):
if not self._is_owned():
raise RuntimeError("cannot wait on un-acquired lock")
waiter = _allocate_lock()
waiter.acquire()
self._waiters.append(waiter)
saved_state = self._release_save()
gotit = False
try: # restore state no matter what (e.g., KeyboardInterrupt)
if timeout is None:
waiter.acquire()
gotit = True
else:
if timeout > 0:
gotit = waiter.acquire(True, timeout)
else:
gotit = waiter.acquire(False)
return gotit
finally:
self._acquire_restore(saved_state)
if not gotit:
try:
self._waiters.remove(waiter)
except ValueError:
pass
def wait_for(self, predicate, timeout=None):
endtime = None
waittime = timeout
result = predicate()
while not result:
if waittime is not None:
if endtime is None:
endtime = _time() + waittime
else:
waittime = endtime - _time()
if waittime <= 0:
break
self.wait(waittime)
result = predicate()
return result
def notify(self, n=1):
if not self._is_owned():
raise RuntimeError("cannot notify on un-acquired lock")
all_waiters = self._waiters
waiters_to_notify = _deque(_islice(all_waiters, n))
if not waiters_to_notify:
return
for waiter in waiters_to_notify:
waiter.release()
try:
all_waiters.remove(waiter)
except ValueError:
pass
Semaphore
1 | class Semaphore: |
Event
1 | class Event: |
Barrier
通常可以使用 Barrier 实现并发初始化,然后一切就绪后才会进入下一个阶段。应用示例如下:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31# coding: utf-8
from threading import get_ident as get_ident
from threading import Barrier, Thread
def signal_prepared():
print("All are ready")
barrier = Barrier(parties=4, action=signal_prepared)
def main():
Thread(target=load_disk_files).start()
Thread(target=make_cache).start()
Thread(target=init_db_pool).start()
print("I'm ready, wait for other workers")
barrier.wait()
print("Time to start our server")
def load_disk_files():
print(u"[{}] load_disk_files".format(get_ident()))
barrier.wait()
def make_cache():
print(u"[{}] make cache".format(get_ident()))
barrier.wait()
def init_db_pool():
print(u"[{}] init db pool".format(get_ident()))
barrier.wait()
if __name__ == '__main__':
main()
运行效果:
接下来看看 Barrier 是如何实现的:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149# Barrier 是基于部分的 `pthread_barrier_*` API 和 Java 中的 `CyclicBarrier`
# 参考:
# 1. http://sourceware.org/pthreads-win32/manual/pthread_barrier_init.html and
# 2. http://java.sun.com/j2se/1.5.0/docs/api/java/util/concurrent/CyclicBarrier.html
# 在内部维护了两种主要的状态:`filling` 和 `draining`,从而让屏障变成可循环使用的。
# 只有上一个周期完全排水(`drained`)完毕才可以允许新的线程进入(对比下漏桶限流算法)
# 此外,这里还提供了 `resetting` 状态,它类似于 `draining`,但是会让线程离开的时候抛出 `BrokeBarrierError`
# `broken` 状态表示所有的线程都产生了异常
class Barrier:
"""I
我们通常可以使用屏障让多个线程在相同的同步点同步开始(可以想象下有个水缸,水不断地流进来
但是会在某个点一起放开,形成洪流...)。所有调用了 `wait()` 的线程会在条件满足时几乎同时被唤醒,
然后大家就可以一起快乐地干活了。
"""
def __init__(self, parties, action=None, timeout=None):
"""Create a barrier, initialised to 'parties' threads.
'action' is a callable which, when supplied, will be called by one of
the threads after they have all entered the barrier and just prior to
releasing them all. If a 'timeout' is provided, it is used as the
default for all subsequent 'wait()' calls.
"""
self._cond = Condition(Lock())
self._action = action
self._timeout = timeout
self._parties = parties
self._state = 0 #0 filling, 1, draining, -1 resetting, -2 broken
self._count = 0
def wait(self, timeout=None):
"""Wait for the barrier.
When the specified number of threads have started waiting, they are all
simultaneously awoken. If an 'action' was provided for the barrier, one
of the threads will have executed that callback prior to returning.
Returns an individual index number from 0 to 'parties-1'.
"""
if timeout is None:
timeout = self._timeout
with self._cond:
self._enter() # Block while the barrier drains.
index = self._count
self._count += 1
try:
if index + 1 == self._parties:
# We release the barrier
self._release()
else:
# We wait until someone releases us
self._wait(timeout)
return index
finally:
self._count -= 1
# Wake up any threads waiting for barrier to drain.
self._exit()
# Block until the barrier is ready for us, or raise an exception
# if it is broken.
def _enter(self):
while self._state in (-1, 1):
# It is draining or resetting, wait until done
self._cond.wait()
#see if the barrier is in a broken state
if self._state < 0:
raise BrokenBarrierError
assert self._state == 0
# Optionally run the 'action' and release the threads waiting
# in the barrier.
def _release(self):
try:
if self._action:
self._action()
# enter draining state
self._state = 1
self._cond.notify_all()
except:
#an exception during the _action handler. Break and reraise
self._break()
raise
# Wait in the barrier until we are released. Raise an exception
# if the barrier is reset or broken.
def _wait(self, timeout):
if not self._cond.wait_for(lambda : self._state != 0, timeout):
#timed out. Break the barrier
self._break()
raise BrokenBarrierError
if self._state < 0:
raise BrokenBarrierError
assert self._state == 1
# If we are the last thread to exit the barrier, signal any threads
# waiting for the barrier to drain.
def _exit(self):
if self._count == 0:
if self._state in (-1, 1):
#resetting or draining
self._state = 0
self._cond.notify_all()
def reset(self):
"""Reset the barrier to the initial state.
Any threads currently waiting will get the BrokenBarrier exception
raised.
"""
with self._cond:
if self._count > 0:
if self._state == 0:
#reset the barrier, waking up threads
self._state = -1
elif self._state == -2:
#was broken, set it to reset state
#which clears when the last thread exits
self._state = -1
else:
self._state = 0
self._cond.notify_all()
def abort(self):
"""Place the barrier into a 'broken' state.
Useful in case of error. Any currently waiting threads and threads
attempting to 'wait()' will have BrokenBarrierError raised.
"""
with self._cond:
self._break()
def _break(self):
# An internal error was detected. The barrier is set to
# a broken state all parties awakened.
self._state = -2
self._cond.notify_all()
def parties(self):
"""Return the number of threads required to trip the barrier."""
return self._parties
def n_waiting(self):
"""Return the number of threads currently waiting at the barrier."""
# We don't need synchronization here since this is an ephemeral result
# anyway. It returns the correct value in the steady state.
if self._state == 0:
return self._count
return 0
def broken(self):
"""Return True if the barrier is in a broken state."""
return self._state == -2
总结
Python 源码的注释太丰富了,以至于我都不想翻译成中文。所以结合注释看代码即可~