root/src/modules-lua/noit/HttpClient.lua

Revision cec1cb615495ea336dfb3d405ee59926b654fa1f, 5.6 kB (checked in by Theo Schlossnagle <jesus@omniti.com>, 5 years ago)

Patches from Dan. With all the +1s, I'm not sure why he didn't commit these himself.

  • Property mode set to 100644
Line 
-- Class table for the HTTP client. Instances are given this table as their
-- metatable, and __index points back at it so method lookups on an
-- instance fall through to the functions defined below.
local HttpClient = {}
rawset(HttpClient, "__index", HttpClient)
3
--- Create a client wrapping a fresh noit socket.
-- @param hooks optional table of callbacks (connected, headers, consume,
--              certfile, keyfile, cachain, ciphers); defaults to {}.
-- @return a new HttpClient instance
function HttpClient:new(hooks)
    local sock = noit.socket()
    return setmetatable({ e = sock, hooks = hooks or {} }, HttpClient)
end
11
--- Connect the socket to target:port, optionally upgrading it to TLS.
-- Records target and port on the instance. If the plain connect returns a
-- non-zero code, that code and its error are returned unchanged. On
-- success hooks.connected() is called, if present. When ssl is nil or
-- false, the connect results are returned. Otherwise the result of
-- ssl_upgrade_socket is returned, fed by the optional certfile / keyfile /
-- cachain / ciphers hooks.
-- @param target host to connect to
-- @param port   port to connect to
-- @param ssl    truthy to upgrade the connection to TLS
function HttpClient:connect(target, port, ssl)
    self.target, self.port = target, port
    local rv, err = self.e:connect(target, port)
    if rv ~= 0 then return rv, err end

    local hooks = self.hooks
    if hooks.connected ~= nil then hooks.connected() end
    if not ssl then return rv, err end

    -- Evaluate an optional TLS-parameter hook; yields nil when it is absent.
    local function optional(name)
        local hook = hooks[name]
        return hook and hook()
    end
    return self.e:ssl_upgrade_socket(optional("certfile"),
                                     optional("keyfile"),
                                     optional("cachain"),
                                     optional("ciphers"))
end
27
--- Return whatever the underlying noit socket reports as its TLS context.
function HttpClient:ssl_ctx()
    local sock = self.e
    return sock:ssl_ctx()
end
31
--- Send an HTTP/1.1 request: request line, headers, blank line, body.
-- Resets the per-response byte counters. Adds Content-Length when a
-- non-empty payload is given, always advertises "gzip, deflate" in
-- Accept-Encoding, and supplies a default User-Agent.
-- The caller's headers table is NOT modified. The additions go into a
-- private copy, so a table reused across requests does not carry a stale
-- Content-Length into a later request without a body.
-- @param method  request method, e.g. "GET"
-- @param uri     request target
-- @param headers optional table of header name -> value (may be nil)
-- @param payload optional request body string
function HttpClient:do_request(method, uri, headers, payload)
    self.raw_bytes = 0
    self.content_bytes = 0
    local has_payload = payload ~= nil and string.len(payload) > 0

    local hdrs = {}
    for name, value in pairs(headers or {}) do
        hdrs[name] = value
    end
    if has_payload then
        hdrs["Content-Length"] = string.len(payload)
    end
    hdrs["Accept-Encoding"] = "gzip, deflate"
    if hdrs["User-Agent"] == nil then
        hdrs["User-Agent"] = "Reconnoiter/0.9"
    end

    -- Build the request head and send it in a single write rather than
    -- one small write per line.
    local head = { method .. " " .. uri .. " " .. "HTTP/1.1\r\n" }
    for name, value in pairs(hdrs) do
        head[#head + 1] = name .. ": " .. value .. "\r\n"
    end
    head[#head + 1] = "\r\n"
    self.e:write(table.concat(head))

    if has_payload then
        self.e:write(payload)
    end
end
51
--- Read the status line and header block of an HTTP response.
-- Sets self.protocol (the version after "HTTP/", e.g. "1.1"), self.code
-- (numeric status) and self.headers. Headers are keyed by the lower-cased
-- name; a repeated header keeps only its last value. A line beginning with
-- whitespace continues the previous header and is appended to its value
-- after a single space. Calls hooks.headers(self.headers) when set.
-- Raises "no response", "malformed HTTP response" or
-- "malformed header line".
function HttpClient:get_headers()
    local status_line = self.e:read("\n")
    if status_line == nil then error("no response") end
    -- '%.' so the version separator must be a literal dot
    self.protocol, self.code =
        string.match(status_line, "^HTTP/(%d%.%d)%s+(%d+)")
    if self.protocol == nil then error("malformed HTTP response") end
    self.code = tonumber(self.code)

    self.headers = {}
    local lasthdr
    while true do
        local line = self.e:read("\n")
        -- EOF, or the empty line terminating the header block
        if line == nil or line == "\r\n" or line == "\n" then break end
        line = string.gsub(line, "%s+$", "")
        -- Whitespace after the colon is optional and the value may be
        -- empty (RFC 7230 section 3.2), e.g. "Content-Length:5", "X-Empty:".
        local hdr, val = string.match(line, "^([^%s:]+):%s*(.*)$")
        if hdr == nil then
            -- obsolete line folding: continuation of the previous header
            if lasthdr == nil then error("malformed header line") end
            hdr = lasthdr
            val = string.match(line, "^%s+(.+)")
            if val == nil then error("malformed header line") end
            self.headers[hdr] = self.headers[hdr] .. " " .. val
        else
            hdr = string.lower(hdr)
            self.headers[hdr] = val
            lasthdr = hdr
        end
    end
    if self.hooks.headers ~= nil then self.hooks.headers(self.headers) end
end
79
--- Identity content decoder: returns each body fragment unchanged.
-- Used when the response has no Content-Encoding.
ce_passthru = function(str)
    return str
end
83
--- Read a body delimited by connection close (no usable Content-Length and
-- no chunked Transfer-Encoding). Each block read is counted in
-- self.raw_bytes, passed through content_enc_func, and counted in
-- self.content_bytes when the decoder returns data. The decoded result is
-- then handed to hooks.consume, if set.
-- Reading stops at EOF (read returns nil) or at the first read that yields
-- fewer bytes than requested.
-- NOTE(review): the short-read stop assumes self.e:read(n) only returns
-- fewer than n bytes at end of stream -- confirm against noit's socket API.
-- @param self             HttpClient instance (uses e, hooks, byte counters)
-- @param content_enc_func decoder applied to each raw block
function te_close(self, content_enc_func)
    local block_size = 32768
    repeat
        -- Everything is processed inside the loop: a block read here is
        -- decoded and consumed immediately, so nothing is lost between
        -- iterations.
        local str = self.e:read(block_size)
        if str ~= nil then
            self.raw_bytes = self.raw_bytes + string.len(str)
            local decoded = content_enc_func(str)
            if decoded ~= nil then
                self.content_bytes = self.content_bytes + string.len(decoded)
            end
            if self.hooks.consume ~= nil then self.hooks.consume(decoded) end
        end
    until str == nil or string.len(str) ~= block_size
end
99
--- Read a body whose size is given by the Content-Length response header.
-- Reads until exactly that many raw bytes have arrived, counting them in
-- self.raw_bytes. Each piece goes through content_enc_func, is counted in
-- self.content_bytes when the decoder returns data, and is handed to
-- hooks.consume, if set. A length of 0 reads nothing.
-- Raises "bad content-length: <value>" for a missing, negative or
-- fractional length. Raises "short content-length read" if the stream ends
-- (or yields an empty read) before the full body arrived; an empty read
-- would otherwise loop forever.
-- @param self             HttpClient instance (uses e, headers, hooks, counters)
-- @param content_enc_func decoder applied to each raw piece
function te_length(self, content_enc_func)
    local header = self.headers["content-length"]
    local len = tonumber(header)
    if len == nil or len < 0 or len % 1 ~= 0 then
        error("bad content-length: " .. tostring(header))
    end
    while len > 0 do
        local str = self.e:read(len)
        if str == nil or string.len(str) == 0 then
            error("short content-length read")
        end
        self.raw_bytes = self.raw_bytes + string.len(str)
        len = len - string.len(str)
        local decoded = content_enc_func(str)
        if decoded ~= nil then
            self.content_bytes = self.content_bytes + string.len(decoded)
        end
        if self.hooks.consume ~= nil then self.hooks.consume(decoded) end
    end
end
113
--- Read a body using chunked Transfer-Encoding.
-- For each chunk, parses the hex size line (extensions after the hex digits
-- are ignored) and reads exactly that many bytes, counting them in
-- self.raw_bytes. The data goes through content_enc_func, is counted in
-- self.content_bytes when the decoder returns data, and is handed to
-- hooks.consume, if set. Then the CRLF ending the chunk is read. The
-- zero-size chunk triggers hooks.consume("") and is followed by skipping
-- trailer lines up to the terminating empty line.
-- Raises "bad chunk transfer", "bad chunk length: ...", "short chunked read",
-- "short chunked boundary read" or "bad chunk trailers" on malformed input.
-- @param self             HttpClient instance (uses e, hooks, byte counters)
-- @param content_enc_func decoder applied to each chunk's data
function te_chunked(self, content_enc_func)
    while true do
        local size_line = self.e:read("\n")
        if size_line == nil then error("bad chunk transfer") end
        local hexlen = string.match(size_line, "^([0-9a-fA-F]+)")
        if hexlen == nil then error("bad chunk length: " .. size_line) end
        local len = tonumber(hexlen, 16)
        if len == 0 then
            if self.hooks.consume ~= nil then self.hooks.consume("") end
            break
        end
        local data = self.e:read(len)
        if string.len(data or "") ~= len then error("short chunked read") end
        self.raw_bytes = self.raw_bytes + string.len(data)
        local decoded = content_enc_func(data)
        -- a decoder may return nil (nothing to emit yet); don't crash on it
        if decoded ~= nil then
            self.content_bytes = self.content_bytes + string.len(decoded)
        end
        if self.hooks.consume ~= nil then self.hooks.consume(decoded) end
        -- each chunk (except the 0-size one) is followed by a CRLF
        local boundary = self.e:read("\n")
        if boundary ~= "\r\n" and boundary ~= "\n" then
            error("short chunked boundary read")
        end
    end
    -- read trailers
    while true do
        local trailer = self.e:read("\n")
        if trailer == nil then error("bad chunk trailers") end
        if trailer == "\r\n" or trailer == "\n" then break end
    end
end
142
--- Read the response body and pass it, decoded, to hooks.consume.
-- Must be called after get_headers (it reads self.code and self.headers).
-- The content decoder is chosen from Content-Encoding. The framing is
-- chunked when chunked is the final Transfer-Encoding, otherwise a numeric
-- Content-Length, otherwise read-until-close.
-- Raises "unknown content-encoding: <value>" for unsupported encodings.
function HttpClient:get_body()
    -- RFC 7230 section 3.3.3: 204 and 304 responses end at the blank line
    -- after the headers even when they carry Content-Length. Reading further
    -- would block on a kept-alive connection.
    if self.code == 204 or self.code == 304 then return end

    local cefunc = ce_passthru
    local ce = self.headers["content-encoding"]
    if ce ~= nil then
        -- content-coding names are case-insensitive (RFC 7231 3.1.2.1)
        local coding = string.lower(ce)
        if coding == "gzip" or coding == "x-gzip" then
            cefunc = noit.gunzip()
        elseif coding == "deflate" then
            -- NOTE(review): deflate goes to the same noit.gunzip() decoder as
            -- gzip; presumably it auto-detects zlib framing -- confirm.
            cefunc = noit.gunzip()
        elseif coding ~= "identity" then
            error("unknown content-encoding: " .. ce)
        end
    end

    local te = self.headers["transfer-encoding"]
    local cl = self.headers["content-length"]
    local is_chunked = false
    if te ~= nil then
        -- transfer-coding names are case-insensitive and chunked must be
        -- the last coding applied, e.g. "chunked" or "gzip, chunked"
        local coding = string.lower(te)
        is_chunked = coding == "chunked"
            or string.find(coding, ",%s*chunked%s*$") ~= nil
    end
    if is_chunked then
        return te_chunked(self, cefunc)
    elseif cl ~= nil and tonumber(cl) ~= nil then
        return te_length(self, cefunc)
    end
    return te_close(self, cefunc)
end
164
--- Read a complete response: the status line and headers, then the body.
-- Dispatches through self, so per-instance or derived overrides of
-- get_headers / get_body are honoured.
function HttpClient.get_response(self)
    self:get_headers()
    return self:get_body()
end
169
-- The module's value is the class table; callers construct clients with
-- HttpClient:new(hooks).
local module = HttpClient
return module
Note: See TracBrowser for help on using the browser.