root/src/modules-lua/noit/HttpClient.lua

Revision f5982323e614fec6b6ca1900cedb708f89020355, 7.5 kB (checked in by Ryan Phillips <ryan.phillips@rackspace.com>, 3 years ago)

add 100k limit to http read lengths

  • Property mode set to 100644
Line 
1 -- Copyright (c) 2008, OmniTI Computer Consulting, Inc.
2 -- All rights reserved.
3 --
4 -- Redistribution and use in source and binary forms, with or without
5 -- modification, are permitted provided that the following conditions are
6 -- met:
7 --
8 --     * Redistributions of source code must retain the above copyright
9 --       notice, this list of conditions and the following disclaimer.
10 --     * Redistributions in binary form must reproduce the above
11 --       copyright notice, this list of conditions and the following
12 --       disclaimer in the documentation and/or other materials provided
13 --       with the distribution.
14 --     * Neither the name OmniTI Computer Consulting, Inc. nor the names
15 --       of its contributors may be used to endorse or promote products
16 --       derived from this software without specific prior written
17 --       permission.
18 --
19 -- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20 -- "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21 -- LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22 -- A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23 -- OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 -- SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25 -- LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26 -- DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27 -- THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 -- (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 -- OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
-- Prototype table for HTTP client objects.  It doubles as the metatable of
-- every instance, so missing fields on an instance resolve to the methods
-- defined on this table.
local HttpClient = {}
HttpClient.__index = HttpClient
33
-- Create a new client object.
--   hooks: optional table of callbacks; an empty table is used when omitted.
--          Other methods in this file look up the keys connected, certfile,
--          keyfile, cachain, ciphers, headers and consume on it.
-- Returns the new object, whose metatable is HttpClient.
function HttpClient:new(hooks)
    return setmetatable({ hooks = hooks or { } }, HttpClient)
end
40
-- Open a socket to target:port and optionally upgrade it to SSL.
--   target: passed to noit.socket() and to the socket's connect()
--   port:   passed to the socket's connect()
--   ssl:    nil or false leaves the connection as plain text; any other
--           value triggers an SSL upgrade after connecting
-- Returns the connect() results (rv, err) when the connect fails or no SSL
-- was requested; otherwise returns whatever ssl_upgrade_socket() returns.
-- The "connected" hook, if present, runs after a successful connect and
-- before any SSL upgrade.
function HttpClient:connect(target, port, ssl)
    self.e = noit.socket(target)
    self.target = target
    self.port = port

    local rv, err = self.e:connect(self.target, self.port)
    if rv ~= 0 then
        return rv, err
    end

    if self.hooks.connected ~= nil then
        self.hooks.connected()
    end

    if ssl == nil or ssl == false then
        return rv, err
    end

    -- Each SSL parameter comes from an optional hook; a missing hook yields
    -- nil for that argument.
    local function from_hook(name)
        local fn = self.hooks[name]
        return fn and fn()
    end

    return self.e:ssl_upgrade_socket(from_hook("certfile"),
                                     from_hook("keyfile"),
                                     from_hook("cachain"),
                                     from_hook("ciphers"))
end
57
-- Return whatever the underlying socket's ssl_ctx() returns.
-- Requires connect() to have been called so that self.e exists.
function HttpClient:ssl_ctx()
    local sock = self.e
    return sock:ssl_ctx()
end
61
-- Build an HTTP/1.1 request and write it to the connected socket.
-- Also resets the raw_bytes / content_bytes counters for the response.
--   method:  request method, e.g. "GET"
--   uri:     request target, placed verbatim on the request line
--   headers: optional table of header name -> value.  It is copied, so the
--            caller's table is never modified.  Content-Length and
--            Accept-Encoding are always set by this function; User-Agent
--            is defaulted only when absent.
--   payload: optional request body string; an empty string counts as none
-- Returns nothing.
function HttpClient:do_request(method, uri, headers, payload)
    self.raw_bytes = 0
    self.content_bytes = 0

    -- Work on a private copy: callers may reuse their header table.
    local out = {}
    for name, value in pairs(headers or {}) do
        out[name] = value
    end

    local has_payload = payload ~= nil and string.len(payload) > 0
    -- Any caller-supplied Content-Length is discarded and recomputed.
    out["Content-Length"] = nil
    if has_payload then
        out["Content-Length"] = string.len(payload)
    end
    out["Accept-Encoding"] = "gzip, deflate"
    if out["User-Agent"] == nil then
        out["User-Agent"] = "Reconnoiter/0.9"
    end

    -- Collect the pieces and join once instead of repeated concatenation.
    local parts = { method .. " " .. uri .. " " .. "HTTP/1.1\r\n" }
    for name, value in pairs(out) do
        parts[#parts + 1] = name .. ": " .. value .. "\r\n"
    end
    parts[#parts + 1] = "\r\n"
    if has_payload then
        parts[#parts + 1] = payload
    end
    self.e:write(table.concat(parts))
end
83
-- Read and parse the status line and response headers from the socket.
-- On success sets:
--   self.protocol  version string captured from the status line, e.g. "1.1"
--   self.code      numeric status code
--   self.headers   table of lower-cased header name -> value; a repeated
--                  header overwrites the earlier value, and continuation
--                  lines are appended to the previous header with a space
-- Calls the "headers" hook with self.headers when that hook is set.
-- Raises "no response", "malformed HTTP response" or
-- "malformed header line" on bad input.
function HttpClient:get_headers()
    local lasthdr
    local status = self.e:read("\n")
    if status == nil then error("no response") end
    -- The dot is escaped so only a literal "." separates the version digits.
    self.protocol, self.code = string.match(status, "^HTTP/(%d%.%d)%s+(%d+)%s+")
    if self.protocol == nil then error("malformed HTTP response") end
    self.code = tonumber(self.code)
    self.headers = {}
    while true do
        local line = self.e:read("\n")
        -- A nil read or an empty line ends the header section.
        if line == nil or line == "\r\n" or line == "\n" then break end
        line = string.gsub(line, '%s+$', '')
        local hdr, val = string.match(line, "^([-_%a%d]+):%s*(.*)$")
        if hdr == nil then
            -- Not "name: value": must be a continuation of the last header.
            if lasthdr == nil then error ("malformed header line") end
            hdr = lasthdr
            val = string.match(line, "^%s+(.+)")
            if val == nil then error("malformed header line") end
            self.headers[hdr] = self.headers[hdr] .. " " .. val
        else
            hdr = string.lower(hdr)
            self.headers[hdr] = val
            lasthdr = hdr
        end
    end
    if self.hooks.headers ~= nil then self.hooks.headers(self.headers) end
end
111
-- Identity content decoder: returns its argument unchanged.  Used by
-- get_body when the response carries no content-encoding header.
function ce_passthru(str) return str end
115
-- Read a response body that has neither chunked encoding nor a usable
-- content-length.  Reads fixed-size blocks until a read returns nil or
-- fewer bytes than requested.
--   self:             client object (uses e, hooks, raw_bytes, content_bytes)
--   content_enc_func: decoder applied to each block that was read
-- Adds the raw and decoded byte counts to self.raw_bytes and
-- self.content_bytes, and hands each decoded block (possibly nil) to the
-- "consume" hook when that hook is set.
function te_close(self, content_enc_func)
    local strlen = string.len
    local want = 32678
    while true do
        local data = self.e:read(want)
        if data == nil then
            break
        end
        self.raw_bytes = self.raw_bytes + strlen(data)
        local decoded = content_enc_func(data)
        if decoded ~= nil then
            self.content_bytes = self.content_bytes + strlen(decoded)
        end
        if self.hooks.consume ~= nil then
            self.hooks.consume(decoded)
        end
        -- A short block means the stream has ended.
        if strlen(data) ~= want then
            break
        end
    end
end
131
-- Read a response body whose size is given by the content-length header,
-- capped at 102400 bytes of raw input.
--   self:             client object (uses e, headers, hooks, raw_bytes,
--                     content_bytes)
--   content_enc_func: decoder applied to each block that was read
-- Reading stops when the (capped) length has been consumed or when a read
-- returns nil (premature end of stream).  Each decoded block is passed to
-- the "consume" hook when that hook is set.
-- Raises "bad content-length" when the header is missing or not numeric.
function te_length(self, content_enc_func)
    local limit = 102400
    local len = tonumber(self.headers["content-length"])
    if len == nil then error("bad content-length") end
    if len > limit then len = limit end
    repeat
        local str = self.e:read(len)
        -- A nil read means the stream ended early; stop without decoding.
        if str == nil then break end
        self.raw_bytes = self.raw_bytes + string.len(str)
        len = len - string.len(str)
        local decoded = content_enc_func(str)
        -- The decoder may return nil; only count output that exists.
        if decoded ~= nil then
            self.content_bytes = self.content_bytes + string.len(decoded)
        end
        if self.hooks.consume ~= nil then self.hooks.consume(decoded) end
    until len == 0
end
146
-- Read a response body sent with chunked transfer encoding.
--   self:             client object (uses e, hooks, raw_bytes, content_bytes)
--   content_enc_func: decoder applied to each chunk's data
-- Each decoded chunk is passed to the "consume" hook when that hook is
-- set; the terminating zero-size chunk is reported as consume("").
-- Once more than 102400 decoded bytes have been accumulated, the function
-- returns early without reading the rest of the body or the trailers.
-- Raises on malformed chunk framing ("bad chunk transfer",
-- "bad chunk length: ...", "short chunked read",
-- "short chunked boundary read", "bad chunk trailers").
function te_chunked(self, content_enc_func)
    local limit = 102400
    while true do
        local str = self.e:read("\n")
        if str == nil then error("bad chunk transfer") end
        local hexlen = string.match(str, "^([0-9a-fA-F]+)")
        if hexlen == nil then error("bad chunk length: " .. str) end
        local len = tonumber(hexlen, 16)
        if len == 0 then
          if self.hooks.consume ~= nil then self.hooks.consume("") end
          break
        end
        str = self.e:read(len)
        if string.len(str or "") ~= len then error("short chunked read") end
        self.raw_bytes = self.raw_bytes + string.len(str)
        local decoded = content_enc_func(str)
        -- The decoder may return nil; only count output that exists.
        if decoded ~= nil then
            self.content_bytes = self.content_bytes + string.len(decoded)
        end
        if self.hooks.consume ~= nil then self.hooks.consume(decoded) end
        -- each chunk ('cept a 0 size one) is followed by a \r\n
        str = self.e:read("\n")
        if str ~= "\r\n" and str ~= "\n" then error("short chunked boundary read") end
        -- Compare the byte counter itself, not the length of its decimal
        -- representation.  Return rather than fall through: the stream is
        -- still mid-body, so what follows is not the trailer section.
        if self.content_bytes > limit then
          return
        end
    end
    -- read trailers
    while true do
        local str = self.e:read("\n")
        if str == nil then error("bad chunk trailers") end
        if str == "\r\n" or str == "\n" then break end
    end
end
178
-- Read the response body using the framing announced by the headers.
-- Requires self.headers to have been filled in by get_headers().
-- Decoder selection (content-encoding, compared case-insensitively):
--   absent            -> ce_passthru
--   gzip or deflate   -> the function returned by noit.gunzip()
--   anything else     -> raises "unknown content-encoding: <value>"
-- Framing selection, in priority order:
--   transfer-encoding equal to "chunked" (case-insensitive) -> te_chunked
--   a numeric content-length                                -> te_length
--   otherwise                                               -> te_close
-- Returns whatever the selected reader returns.
function HttpClient:get_body()
    local cefunc = ce_passthru
    local ce = self.headers["content-encoding"]
    if ce ~= nil then
        -- Coding names are case-insensitive in HTTP, so normalise first.
        local coding = string.lower(ce)
        if coding == "gzip" or coding == "deflate" then
            cefunc = noit.gunzip()
        else
            error("unknown content-encoding: " .. ce)
        end
    end
    local te = self.headers["transfer-encoding"]
    local cl = self.headers["content-length"]
    if te ~= nil and string.lower(te) == "chunked" then
        return te_chunked(self, cefunc)
    elseif cl ~= nil and tonumber(cl) ~= nil then
        return te_length(self, cefunc)
    end
    return te_close(self, cefunc)
end
200
-- Read a complete response: parse the headers first, then read the body.
-- Returns whatever get_body() returns; errors raised by either step
-- propagate to the caller.
function HttpClient:get_response()
    self:get_headers()
    local read_body = self.get_body
    return read_body(self)
end
205
206 return HttpClient
Note: See TracBrowser for help on using the browser.