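Pivot Tables¶
We have seen how the GroupBy abstraction lets us explore relationships within a dataset. A pivot table is a similar operation that is commonly seen in spreadsheets such as Excel and in other programs that work with tabular data. It takes simple column-wise data as input and groups the entries into a two-dimensional table.
The difference between a pivot table and GroupBy can be hard to pin down. It helps to think of a pivot table as a multidimensional version of GroupBy aggregation: you still split, apply, and combine, but the split and the combine happen not along a one-dimensional index but across a two-dimensional grid.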
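As a minimal sketch of that idea, the following uses a small made-up DataFrame (the `toy` frame and its `cls` column are illustrative assumptions, not the Titanic data used later). Grouping on one key yields a one-dimensional result, while `pivot_table` with one key per axis yields a grid.

```python
import pandas as pd

# Toy data, made up for illustration only
toy = pd.DataFrame({
    "sex":      ["female", "female", "male", "male", "male", "female"],
    "cls":      ["First", "Third", "First", "Third", "Third", "First"],
    "survived": [1, 0, 1, 0, 1, 1],
})

# One-dimensional split: a single key gives a single index of groups
by_sex = toy.groupby("sex")["survived"].mean()

# Two-dimensional split: one key per axis gives a grid of groups
grid = toy.pivot_table("survived", index="sex", columns="cls")

print(by_sex)
print(grid)
```

Each cell of `grid` is the mean of `survived` over the rows that share both the row label and the column label.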
Motivating Pivot Tables¶
To illustrate pivot tables, we use the Titanic passenger dataset, which is available through the Seaborn library.
import numpy as np
import pandas as pd
import seaborn as sns
titanic = sns.load_dataset('titanic')
#titanic.to_csv("Titanic.csv")
#titanic = pd.read_csv("Titanic.csv")
titanic
| | survived | pclass | sex | age | sibsp | parch | fare | embarked | class | who | adult_male | deck | embark_town | alive | alone |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 3 | male | 22.0 | 1 | 0 | 7.2500 | S | Third | man | True | NaN | Southampton | no | False |
| 1 | 1 | 1 | female | 38.0 | 1 | 0 | 71.2833 | C | First | woman | False | C | Cherbourg | yes | False |
| 2 | 1 | 3 | female | 26.0 | 0 | 0 | 7.9250 | S | Third | woman | False | NaN | Southampton | yes | True |
| 3 | 1 | 1 | female | 35.0 | 1 | 0 | 53.1000 | S | First | woman | False | C | Southampton | yes | False |
| 4 | 0 | 3 | male | 35.0 | 0 | 0 | 8.0500 | S | Third | man | True | NaN | Southampton | no | True |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 886 | 0 | 2 | male | 27.0 | 0 | 0 | 13.0000 | S | Second | man | True | NaN | Southampton | no | True |
| 887 | 1 | 1 | female | 19.0 | 0 | 0 | 30.0000 | S | First | woman | False | B | Southampton | yes | True |
| 888 | 0 | 3 | female | NaN | 1 | 2 | 23.4500 | S | Third | woman | False | NaN | Southampton | no | False |
| 889 | 1 | 1 | male | 26.0 | 0 | 0 | 30.0000 | C | First | man | True | C | Cherbourg | yes | True |
| 890 | 0 | 3 | male | 32.0 | 0 | 0 | 7.7500 | Q | Third | man | True | NaN | Queenstown | no | True |
891 rows × 15 columns
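This table contains a wealth of information on each passenger of that ill-fated voyage, including sex, age, class, fare paid, and more.
Pivot Tables by Hand¶
To start, we use the earlier approach to look at the survival rate by sex: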
titanic.groupby('sex')[['survived']].mean()
| sex | survived |
|---|---|
| female | 0.742038 |
| male | 0.188908 |
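This shows clearly that about three of every four women on board survived (74%), while fewer than one in five men did (19%). To add another variable such as passenger class, we can still use groupby: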
titanic.groupby(['sex', 'class'],observed=True)['survived'].aggregate('mean').unstack()
| class | First | Second | Third |
|---|---|---|---|
| sex | |||
| female | 0.968085 | 0.921053 | 0.500000 |
| male | 0.368852 | 0.157407 | 0.135447 |
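This gives a clear view of how both sex and class affected survival, but the code is starting to look a bit convoluted. Long chained expressions like this one are often called a pipeline (method chaining), a very common idiom in Pandas. Because this kind of two-dimensional GroupBy is so common, Pandas provides the pivot_table method, which handles this type of multidimensional aggregation succinctly.
Pivot Table Syntax¶
Here is the equivalent of the preceding GroupBy pipeline, written with the DataFrame pivot_table method: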
titanic.pivot_table('survived', index='sex', columns='class',observed=True)
| class | First | Second | Third |
|---|---|---|---|
| sex | |||
| female | 0.968085 | 0.921053 | 0.500000 |
| male | 0.368852 | 0.157407 | 0.135447 |
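This is clearly more readable than the groupby approach and produces the same result. The table shows that first-class women had the highest survival rate (about 97%), while only about 14% of third-class men survived.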
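The claim that both spellings give the same table can be checked directly. The sketch below assumes a small made-up stand-in for the passenger table rather than the real dataset, and compares labels and values of the two results.

```python
import numpy as np
import pandas as pd

# Toy stand-in for the passenger table (values are made up)
toy = pd.DataFrame({
    "sex":      ["female", "male", "female", "male", "male", "female", "male"],
    "class":    ["First", "First", "Third", "Third", "Third", "First", "First"],
    "survived": [1, 0, 1, 0, 1, 1, 1],
})

# Same aggregation written two ways
via_groupby = toy.groupby(["sex", "class"])["survived"].mean().unstack()
via_pivot = toy.pivot_table("survived", index="sex", columns="class")

same_labels = (list(via_groupby.index) == list(via_pivot.index)
               and list(via_groupby.columns) == list(via_pivot.columns))
same_values = np.allclose(via_groupby.to_numpy(), via_pivot.to_numpy())
print(same_labels, same_values)
```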
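Multilevel Pivot Tables¶
Just as in GroupBy, the grouping in a pivot table can be specified with multiple levels through a number of options. For example, to add age as a third dimension, we first bin the ages with the pd.cut function and then pass the result to the pivot table: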
age = pd.cut(titanic['age'], [0, 18, 80])
age
0 (18.0, 80.0]
1 (18.0, 80.0]
2 (18.0, 80.0]
3 (18.0, 80.0]
4 (18.0, 80.0]
...
886 (18.0, 80.0]
887 (18.0, 80.0]
888 NaN
889 (18.0, 80.0]
890 (18.0, 80.0]
Name: age, Length: 891, dtype: category
Categories (2, interval[int64, right]): [(0, 18] < (18, 80]]
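To make the binning rule concrete, here is a small sketch on made-up ages (the `ages` Series is an assumption for illustration). By default `pd.cut` builds right-inclusive intervals, and missing inputs stay missing.

```python
import numpy as np
import pandas as pd

ages = pd.Series([0.5, 18, 18.5, 80, np.nan])  # made-up ages
binned = pd.cut(ages, [0, 18, 80])

# Codes index the categories: 0 is (0, 18], 1 is (18, 80], -1 marks a missing value.
# An age of exactly 18 lands in the first bin because the right edge is included.
print(binned.cat.codes.tolist())  # prints [0, 0, 1, 1, -1]
```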
titanic.pivot_table('survived', ['sex', age], 'class',observed=True)
| | class | First | Second | Third |
|---|---|---|---|---|
| sex | age | | | |
| female | (0, 18] | 0.909091 | 1.000000 | 0.511628 |
| | (18, 80] | 0.972973 | 0.900000 | 0.423729 |
| male | (0, 18] | 0.800000 | 0.600000 | 0.215686 |
| | (18, 80] | 0.375000 | 0.071429 | 0.133663 |
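We can apply the same strategy to add further factors. For example, the pd.qcut function splits the ticket fares into two equal-sized groups based on quantiles: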
fare = pd.qcut(titanic['fare'], 2)
fare
0 (-0.001, 14.454]
1 (14.454, 512.329]
2 (-0.001, 14.454]
3 (14.454, 512.329]
4 (-0.001, 14.454]
...
886 (-0.001, 14.454]
887 (14.454, 512.329]
888 (14.454, 512.329]
889 (14.454, 512.329]
890 (-0.001, 14.454]
Name: fare, Length: 891, dtype: category
Categories (2, interval[float64, right]): [(-0.001, 14.454] < (14.454, 512.329]]
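The difference between `pd.cut` and `pd.qcut` is easiest to see on skewed data. This sketch uses made-up fares (an assumption, not values from the dataset): `pd.cut` with an integer makes bins of equal width, whereas `pd.qcut` makes bins holding roughly equal numbers of rows.

```python
import pandas as pd

fares = pd.Series([5.0, 7.0, 8.0, 10.0, 30.0, 50.0, 80.0, 500.0])  # made-up fares

by_width = pd.cut(fares, 2)    # two bins of equal width
by_count = pd.qcut(fares, 2)   # two bins split at the median

# Count how many fares fall in each bin, in bin order
width_counts = by_width.cat.codes.value_counts().sort_index().tolist()
count_counts = by_count.cat.codes.value_counts().sort_index().tolist()
print(width_counts, count_counts)  # prints [7, 1] [4, 4]
```

With a single very large fare, equal-width bins put almost everything in the first bin, which is why quantile bins are the more useful choice here.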
#fare = pd.qcut(titanic['fare'], 2)
titanic.pivot_table('survived', ['sex', age], [fare, 'class'],observed=True)
| | fare | (-0.001, 14.454] | | | (14.454, 512.329] | | |
|---|---|---|---|---|---|---|---|
| | class | First | Second | Third | First | Second | Third |
| sex | age | | | | | | |
| female | (0, 18] | NaN | 1.000000 | 0.714286 | 0.909091 | 1.000000 | 0.318182 |
| | (18, 80] | NaN | 0.880000 | 0.444444 | 0.972973 | 0.914286 | 0.391304 |
| male | (0, 18] | NaN | 0.000000 | 0.260870 | 0.800000 | 0.818182 | 0.178571 |
| | (18, 80] | 0.0 | 0.098039 | 0.125000 | 0.391304 | 0.030303 | 0.192308 |
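The result is a four-dimensional aggregation with hierarchical indices, laid out as a grid that shows how the values relate to one another.
Additional Pivot Table Options¶
The full call signature of pivot_table is shown below, here in its function form pandas.pivot_table, where data is the DataFrame: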
# call signature as of Pandas 2.2.3
pandas.pivot_table(data, values=None, index=None, columns=None, aggfunc='mean', fill_value=None, margins=False, dropna=True, margins_name='All', observed=<no_default>, sort=True)
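We have already seen the values, index, and columns arguments; here we look at the remaining ones. fill_value and dropna deal with missing data and work much like the missing-data tools covered earlier. The aggfunc keyword controls which aggregation is applied, and it defaults to the mean. As with groupby, the aggregation can be given as a string naming one of several common choices ('sum', 'mean', 'count', 'min', 'max', etc.) or as a function that implements an aggregation (e.g., np.sum, min, sum). It can also be given as a dictionary that maps each column to its own aggregation: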
titanic.pivot_table(index='sex', columns='class',
aggfunc={'survived':'sum', 'fare':'mean'},observed=True)
| | fare | | | survived | | |
|---|---|---|---|---|---|---|
| class | First | Second | Third | First | Second | Third |
| sex | | | | | | |
| female | 106.125798 | 21.970121 | 16.118810 | 91 | 70 | 72 |
| male | 67.226127 | 19.741782 | 12.661633 | 45 | 17 | 47 |
Notice that we omitted the values keyword here; when a mapping is specified for aggfunc, the values to pivot are determined automatically from its keys.
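The fill_value option is not demonstrated on the Titanic data, so here is a minimal sketch on a made-up frame in which one combination of sex and class never occurs. Without fill_value that cell is NaN; with fill_value=0 it becomes 0.

```python
import pandas as pd

# Toy data (made up): there is no male passenger in Third class
toy = pd.DataFrame({
    "sex":      ["female", "female", "male", "male"],
    "class":    ["First", "Third", "First", "First"],
    "survived": [1, 0, 0, 1],
})

# The empty (male, Third) cell has nothing to aggregate
counts = toy.pivot_table("survived", index="sex", columns="class", aggfunc="sum")

# fill_value replaces the missing cell in the resulting table
filled = toy.pivot_table("survived", index="sex", columns="class",
                         aggfunc="sum", fill_value=0)
print(counts)
print(filled)
```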
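When totals along each grouping are needed, the margins keyword is the tool to use. Here margins=True adds a summary row and column computed with the same aggregation (the mean survival rate), and margins_name="合计" ("Total") sets their label: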
titanic.pivot_table('survived', index='sex', columns='class', margins=True,margins_name="合计",observed=True)
| class | First | Second | Third | 合计 |
|---|---|---|---|---|
| sex | ||||
| female | 0.968085 | 0.921053 | 0.500000 | 0.742038 |
| male | 0.368852 | 0.157407 | 0.135447 | 0.188908 |
| 合计 | 0.629630 | 0.472826 | 0.242363 | 0.383838 |
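One detail worth checking is what the margin holds when the aggregation is a mean. The sketch below, on made-up data with the hypothetical label "Total", suggests that the margin is the mean over all underlying rows of the group, not the average of the cell means.

```python
import pandas as pd

# Toy data (made up): two first-class women, one third-class woman, one man
toy = pd.DataFrame({
    "sex":      ["female", "female", "female", "male"],
    "class":    ["First", "First", "Third", "Third"],
    "survived": [1, 1, 0, 0],
})

table = toy.pivot_table("survived", index="sex", columns="class",
                        margins=True, margins_name="Total")

# The female row has cell means 1.0 (First) and 0.0 (Third), but its margin is
# the mean over all three female rows, 2/3, rather than (1.0 + 0.0) / 2.
print(table)
```

This matches the Titanic table above, where the overall figure of 0.383838 is the survival rate of all passengers.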
Example: Birthrate Data¶
As a more interesting example, let's take a look at the freely available data on births in the United States, provided by the Centers for Disease Control (CDC). This data can be found at https://raw.githubusercontent.com/jakevdp/data-CDCbirths/master/births.csv (this dataset has been analyzed rather extensively by Andrew Gelman and his group):
# shell command to download the data:
# !curl -O https://raw.githubusercontent.com/jakevdp/data-CDCbirths/master/births.csv
data = 'https://raw.githubusercontent.com/jakevdp/data-CDCbirths/master/births.csv'
births = pd.read_csv(data)
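If the URL cannot be reached directly, load a previously downloaded local copy instead.
births = pd.read_csv("data/births.csv",index_col=0) # use the first column as the index (assumes the local copy was saved with its index)
A quick look shows that the data is relatively simple: it contains the number of births grouped by date (year, month, day) and gender.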
births.head()
| | year | month | day | gender | births |
|---|---|---|---|---|---|
| 0 | 1969 | 1 | 1.0 | F | 4046 |
| 1 | 1969 | 1 | 1.0 | M | 4440 |
| 2 | 1969 | 1 | 2.0 | F | 4454 |
| 3 | 1969 | 1 | 2.0 | M | 4548 |
| 4 | 1969 | 1 | 3.0 | F | 4548 |
We can use a pivot table to explore this data. First we add a decade column, then look at male and female births as a function of decade.
births['decade'] = 10 * (births['year'] // 10)
births
| | year | month | day | gender | births | dayofweek | decade |
|---|---|---|---|---|---|---|---|
| 1969-01-01 | 1969 | 1 | 1 | F | 4046 | 2 | 1960 |
| 1969-01-01 | 1969 | 1 | 1 | M | 4440 | 2 | 1960 |
| 1969-01-02 | 1969 | 1 | 2 | F | 4454 | 3 | 1960 |
| 1969-01-02 | 1969 | 1 | 2 | M | 4548 | 3 | 1960 |
| 1969-01-03 | 1969 | 1 | 3 | F | 4548 | 4 | 1960 |
| ... | ... | ... | ... | ... | ... | ... | ... |
| 1988-12-29 | 1988 | 12 | 29 | M | 5944 | 3 | 1980 |
| 1988-12-30 | 1988 | 12 | 30 | F | 5742 | 4 | 1980 |
| 1988-12-30 | 1988 | 12 | 30 | M | 6095 | 4 | 1980 |
| 1988-12-31 | 1988 | 12 | 31 | F | 4435 | 5 | 1980 |
| 1988-12-31 | 1988 | 12 | 31 | M | 4698 | 5 | 1980 |
14610 rows × 7 columns
births.pivot_table('births', index='decade', columns='gender', aggfunc='sum')
| gender | F | M |
|---|---|---|
| decade | ||
| 1960 | 1753634 | 1846572 |
| 1970 | 16263075 | 17121550 |
| 1980 | 18310351 | 19243452 |
| 1990 | 19479454 | 20420553 |
| 2000 | 18229309 | 19106428 |
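We see immediately that male births outnumber female births in every decade. To show this more directly, we plot the total number of births per year: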
%matplotlib inline
import matplotlib.pyplot as plt
sns.set() # use Seaborn styles
births.pivot_table('births', index='year', columns='gender', aggfunc='sum').plot()
plt.ylabel('total births per year');
# Compute the actual figures: first the simple relative difference, then the log difference
piv_table = births.pivot_table('births', index='year', columns='gender', aggfunc='sum')
piv_table["ratio_1"] = (piv_table["M"]-piv_table["F"])/piv_table["F"]
piv_table["ratio_2"] = np.log(piv_table["M"])-np.log(piv_table["F"])
piv_table
| gender | F | M | ratio_1 | ratio_2 |
|---|---|---|---|---|
| year | ||||
| 1969 | 1753634 | 1846572 | 0.052997 | 0.051641 |
| 1970 | 1819164 | 1918636 | 0.054680 | 0.053237 |
| 1971 | 1736774 | 1826774 | 0.051820 | 0.050522 |
| 1972 | 1592347 | 1673888 | 0.051208 | 0.049940 |
| 1973 | 1533102 | 1613023 | 0.052130 | 0.050817 |
| 1974 | 1543005 | 1627626 | 0.054842 | 0.053391 |
| 1975 | 1535546 | 1618010 | 0.053703 | 0.052311 |
| 1976 | 1547613 | 1628863 | 0.052500 | 0.051168 |
| 1977 | 1623363 | 1708796 | 0.052627 | 0.051289 |
| 1978 | 1626324 | 1711976 | 0.052666 | 0.051326 |
| 1979 | 1705837 | 1793958 | 0.051659 | 0.050368 |
| 1980 | 1762459 | 1855522 | 0.052803 | 0.051456 |
| 1981 | 1772037 | 1863478 | 0.051602 | 0.050315 |
| 1982 | 1797239 | 1888218 | 0.050622 | 0.049382 |
| 1983 | 1775299 | 1867522 | 0.051948 | 0.050644 |
| 1984 | 1791802 | 1881766 | 0.050209 | 0.048989 |
| 1985 | 1834774 | 1930290 | 0.052059 | 0.050749 |
| 1986 | 1833708 | 1926987 | 0.050869 | 0.049617 |
| 1987 | 1860111 | 1953105 | 0.049994 | 0.048784 |
| 1988 | 1909210 | 2004583 | 0.049954 | 0.048747 |
| 1989 | 1973712 | 2071981 | 0.049789 | 0.048589 |
| 1990 | 2030966 | 2131951 | 0.049723 | 0.048526 |
| 1991 | 2011601 | 2103741 | 0.045804 | 0.044786 |
| 1992 | 1985118 | 2084310 | 0.049968 | 0.048760 |
| 1993 | 1953456 | 2051067 | 0.049968 | 0.048760 |
| 1994 | 1932234 | 2024691 | 0.047850 | 0.046740 |
| 1995 | 1904871 | 1998141 | 0.048964 | 0.047803 |
| 1996 | 1902664 | 1992210 | 0.047063 | 0.045990 |
| 1997 | 1896928 | 1987401 | 0.047694 | 0.046592 |
| 1998 | 1927106 | 2018086 | 0.047211 | 0.046130 |
| 1999 | 1934510 | 2028955 | 0.048821 | 0.047667 |
| 2000 | 1984255 | 2079568 | 0.048035 | 0.046917 |
| 2001 | 1970770 | 2060761 | 0.045663 | 0.044651 |
| 2002 | 1966519 | 2060857 | 0.047972 | 0.046857 |
| 2003 | 1999387 | 2096705 | 0.048674 | 0.047526 |
| 2004 | 2010710 | 2108197 | 0.048484 | 0.047345 |
| 2005 | 2022892 | 2122727 | 0.049353 | 0.048173 |
| 2006 | 2084957 | 2188268 | 0.049551 | 0.048362 |
| 2007 | 2111890 | 2212118 | 0.047459 | 0.046367 |
| 2008 | 2077929 | 2177227 | 0.047787 | 0.046680 |
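With a simple pivot_table and the plot() method, we can immediately see the annual trend in births by gender. By eye, over the four decades covered (1969 to 2008), male births outnumbered female births by around 5%.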
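The two ratio columns in the table are closely related, which a short worked example makes clear. The sketch below recomputes both for 1969, using the F and M totals from the first row of the table.

```python
import numpy as np

# 1969 totals taken from the table above
females, males = 1753634, 1846572

ratio_1 = (males - females) / females        # simple relative difference
ratio_2 = np.log(males) - np.log(females)    # log difference, i.e. log(M / F)

# Since M / F = 1 + ratio_1, the log difference equals log(1 + ratio_1),
# which is close to ratio_1 itself when ratio_1 is small.
print(round(ratio_1, 6), round(ratio_2, 6))
```

The log difference is slightly smaller than the simple ratio, and it has the convenient property of being symmetric: swapping M and F only flips its sign.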
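Further Data Exploration¶
Although pivot tables are not strictly required for what follows, they help bring out some interesting features of this data. First the data needs some cleaning to remove outliers caused by data entry errors (e.g., June 31st). A simple way to do this is to cut the outliers directly. A more robust way is sigma-clipping, which keeps only the values within a chosen number of standard deviations of the center, assuming a roughly normal distribution; scipy.stats.sigmaclip, for example, defaults to four standard deviations.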
quartiles = np.percentile(births['births'], [25, 50, 75])
mu = quartiles[1]
sig = 0.74 * (quartiles[2] - quartiles[0])
quartiles
array([4358. , 4814. , 5289.5])
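Here mu is the median, a robust estimate of the center, and the line defining sig is a robust estimate of the sample standard deviation. The factor 0.74 comes from the interquartile range of a Gaussian distribution, which is about 1.349 standard deviations, so one standard deviation is about 0.74 times the interquartile range. Using this range in the query() method filters out the rows whose birth counts fall outside it.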
births = births.query('(births > @mu - 5 * @sig) & (births < @mu + 5 * @sig)')
births
| | year | month | day | gender | births |
|---|---|---|---|---|---|
| 0 | 1969 | 1 | 1.0 | F | 4046 |
| 1 | 1969 | 1 | 1.0 | M | 4440 |
| 2 | 1969 | 1 | 2.0 | F | 4454 |
| 3 | 1969 | 1 | 2.0 | M | 4548 |
| 4 | 1969 | 1 | 3.0 | F | 4548 |
| ... | ... | ... | ... | ... | ... |
| 15062 | 1988 | 12 | 29.0 | M | 5944 |
| 15063 | 1988 | 12 | 30.0 | F | 5742 |
| 15064 | 1988 | 12 | 30.0 | M | 6095 |
| 15065 | 1988 | 12 | 31.0 | F | 4435 |
| 15066 | 1988 | 12 | 31.0 | M | 4698 |
14610 rows × 5 columns
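As a sanity check on the 0.74 factor and the clipping step, here is a small sketch. The `toy` counts are made up, and the standard library's `statistics.NormalDist` is used only to recover the constant; it is not part of the original workflow.

```python
from statistics import NormalDist

import numpy as np
import pandas as pd

# For a standard normal, the interquartile range is about 1.349 sigma,
# so sigma is about 0.74 * IQR
iqr_of_unit_normal = NormalDist().inv_cdf(0.75) - NormalDist().inv_cdf(0.25)
factor = 1 / iqr_of_unit_normal

# Made-up daily counts with one obviously bad entry
toy = pd.DataFrame({"births": [4800, 4700, 4900, 5000, 4600, 4850, 150000]})
q25, q50, q75 = np.percentile(toy["births"], [25, 50, 75])
mu, sig = q50, 0.74 * (q75 - q25)

# Keep only rows within five robust standard deviations of the median
kept = toy.query("(births > @mu - 5 * @sig) & (births < @mu + 5 * @sig)")
print(round(factor, 2), len(kept))  # prints 0.74 6
```

Because the median and the quartiles barely move when a single extreme value is added, the bad entry does not inflate the range used to reject it.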
births.describe()
| | year | month | day | births |
|---|---|---|---|---|
| count | 14610.000000 | 14610.000000 | 14610.000000 | 14610.000000 |
| mean | 1978.501027 | 6.522930 | 15.729637 | 4824.470089 |
| std | 5.766538 | 3.448821 | 8.800393 | 579.996983 |
| min | 1969.000000 | 1.000000 | 1.000000 | 3249.000000 |
| 25% | 1974.000000 | 4.000000 | 8.000000 | 4383.000000 |
| 50% | 1979.000000 | 7.000000 | 16.000000 | 4812.000000 |
| 75% | 1984.000000 | 10.000000 | 23.000000 | 5259.000000 |
| max | 1988.000000 | 12.000000 | 31.000000 | 6527.000000 |
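Next we set the day column to integers. Before filtering it could not be an integer column, because some entries in the dataset are 'null'; Pandas reads these as missing values, which is why the days appear as floats (1.0, 2.0) in the output above.
# astype is a convenient way to change the dtype of a column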
births['day'] = births['day'].astype(int)
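Finally, we can combine the year, month, and day into a date index, which lets us quickly compute the weekday of each row.
# create a datetime index from the year, month, and day columns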
births.index = pd.to_datetime(10000 * births.year +
100 * births.month +
births.day, format='%Y%m%d')
births['dayofweek'] = births.index.dayofweek
births.index
DatetimeIndex(['1969-01-01', '1969-01-01', '1969-01-02', '1969-01-02',
'1969-01-03', '1969-01-03', '1969-01-04', '1969-01-04',
'1969-01-05', '1969-01-05',
...
'1988-12-27', '1988-12-27', '1988-12-28', '1988-12-28',
'1988-12-29', '1988-12-29', '1988-12-30', '1988-12-30',
'1988-12-31', '1988-12-31'],
dtype='datetime64[ns]', length=14610, freq=None)
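The same integer encoding also offers a way to detect impossible dates. The sketch below is an alternative illustration on made-up rows, not the cleaning method used above: it converts the encoded dates to strings and passes errors="coerce", so that a date such as June 31st becomes NaT instead of raising an error.

```python
import pandas as pd

# Made-up rows: one valid date and one impossible date (June 31st)
toy = pd.DataFrame({"year": [1969, 1969], "month": [1, 6], "day": [1, 31]})

# Encode each date as YYYYMMDD, then parse it with an explicit format
stamp = (10000 * toy.year + 100 * toy.month + toy.day).astype(str)
dates = pd.to_datetime(stamp, format="%Y%m%d", errors="coerce")

# 19690101 parses normally; 19690631 does not exist and becomes NaT
print(dates.isna().tolist())   # prints [False, True]
print(dates[0].dayofweek)      # prints 2 (Monday is 0, so this is a Wednesday)
```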
Using this, we can plot the mean number of births on each day of the week, decade by decade.
import matplotlib.pyplot as plt
import matplotlib as mpl
sns.set()
births.pivot_table('births', index='dayofweek',
columns='decade', aggfunc='mean').plot()
plt.gca().set_xticks(range(7))  # fix the tick positions (0 = Monday) before labeling them
plt.gca().set_xticklabels(['Mon', 'Tues', 'Wed', 'Thurs', 'Fri', 'Sat', 'Sun'])
plt.ylabel('mean births by day');
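Births are clearly less common on weekends than on weekdays. Note that the 1990s and 2000s are missing because, starting in 1989, the CDC data contains only the month of birth, so those rows were removed by the filtering step above.
Another interesting view is the mean number of births for each day of the year. We first group the data by month and day: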
births_by_date = births.pivot_table('births',
[births.index.month, births.index.day])
births_by_date.head()
| | | births |
|---|---|---|
| 1 | 1 | 4009.225 |
| | 2 | 4247.400 |
| | 3 | 4500.900 |
| | 4 | 4571.350 |
| | 5 | 4603.625 |
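The result is a table with a multi-index over months and days.
To make this easy to plot, we attach the months and days to a dummy year such as 2012. A leap year is chosen so that February 29th is handled correctly.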
births_by_date.index = [pd.Timestamp(2012, month, day)
for (month, day) in births_by_date.index]
births_by_date.head()
| | births |
|---|---|
| 2012-01-01 | 4009.225 |
| 2012-01-02 | 4247.400 |
| 2012-01-03 | 4500.900 |
| 2012-01-04 | 4571.350 |
| 2012-01-05 | 4603.625 |
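The need for a leap year can be seen in a short check: `pd.Timestamp` accepts February 29th for 2012 but rejects it for a non-leap year. The year 2011 below is only an illustrative counterexample.

```python
import pandas as pd

leap_day = pd.Timestamp(2012, 2, 29)   # valid: 2012 is a leap year

try:
    pd.Timestamp(2011, 2, 29)          # 2011 has no February 29th
    non_leap_ok = True
except ValueError:
    non_leap_ok = False

print(leap_day.date(), non_leap_ok)
```

With a non-leap dummy year, the list comprehension above would fail as soon as it reached the (2, 29) entry of the index.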
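Focusing on the month and day only, we now have a time series of the average number of births for each date of the year. We can draw it with plot(), and the figure reveals several interesting patterns: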
# Plot the results
fig, ax = plt.subplots(figsize=(12, 4))
births_by_date.plot(ax=ax);
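The plot shows a sharp dip in births around US holidays (Independence Day, Thanksgiving, Christmas, New Year's Day). What might explain this? That question is left for the reader to consider.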