1 /*
2 Copyright (c) 2023-2024 Andrea Fontana
3 
4 Permission is hereby granted, free of charge, to any person
5 obtaining a copy of this software and associated documentation
6 files (the "Software"), to deal in the Software without
7 restriction, including without limitation the rights to use,
8 copy, modify, merge, publish, distribute, sublicense, and/or sell
9 copies of the Software, and to permit persons to whom the
10 Software is furnished to do so, subject to the following
11 conditions:
12 
13 The above copyright notice and this permission notice shall be
14 included in all copies or substantial portions of the Software.
15 
16 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
17 EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
18 OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
19 NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
20 HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
21 WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
22 FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
23 OTHER DEALINGS IN THE SOFTWARE.
24 */
25 
26 module serverino.worker;
27 
28 import serverino.common;
29 import serverino.config;
30 import serverino.interfaces;
31 import std.experimental.logger : log, warning, info, fatal, critical;
32 import std.process : environment;
33 import std.stdio : FILE;
34 import std.socket;
35 import std.datetime : dur;
36 import std.string : toStringz, indexOf, strip, toLower, empty;
37 import std.algorithm : splitter, startsWith, map;
38 import std.range : assumeSorted;
39 import std.format : format;
40 import std.conv : to;
41 
// Bindings to C library functions working on file descriptors / C streams.
// `dup2` and `fileno` are used by `Worker.wake` to redirect stdin to the null device.
extern(C) int dup(int a);             // no caller in this module's visible code
extern(C) int dup2(int a, int b);
extern(C) int fileno(FILE *stream);
45 
46 
/**
 * State and main loop of a worker process.
 *
 * A worker connects to the daemon through a unix socket, receives raw HTTP
 * requests from it, runs the user endpoints found in the given modules and
 * sends the response back on the same socket.
 */
struct Worker
{
   /**
    * Returns the lazily created `Worker` instance.
    *
    * The pointer lives in a function-level `static`, which in D is thread-local:
    * each thread calling this function gets its own instance.
    */
   static auto instance()
   {
      static Worker* _instance;
      if (_instance is null) _instance = new Worker();
      return _instance;
   }

   /**
    * Entry point of the worker process.
    *
    * Connects to the socket named by the `SERVERINO_SOCKET` environment variable,
    * switches to the configured group/user (non-Windows only), redirects stdin to
    * the null device, runs the `@onWorkerStart` hooks of `Modules` and then serves
    * requests in an endless loop: read one length-prefixed request, process it with
    * `parseHttpRequest`, reply with a `WorkerPayload` followed by headers and body.
    *
    * A background thread replies with status 504 and terminates the process when a
    * request runs longer than `config.maxRequestTime`. The process also terminates
    * on `config.maxWorkerIdling`, on `config.maxWorkerLifetime` and when the socket
    * is closed by the peer.
    *
    * The request loop has no normal exit: the process leaves through `exit()`.
    *
    * Params:
    *    config = worker settings (user, group, time limits, keep-alive)
    */
   void wake(Modules...)(WorkerConfigPtr config)
   {
      import std.conv : to;
      import std.stdio;

      daemonProcess = new ProcessInfo(environment.get("SERVERINO_DAEMON").to!int);

      // On linux the address is prefixed with a NUL byte (abstract socket namespace).
      version(linux) char[] socketAddress = char(0) ~ cast(char[])environment.get("SERVERINO_SOCKET");
      else char[] socketAddress = cast(char[]) environment.get("SERVERINO_SOCKET");

      assert(socketAddress.length > 0);

      channel = new Socket(AddressFamily.UNIX, SocketType.STREAM);
      channel.connect(new UnixAddress(socketAddress));

      // One second receive timeout: it lets the request loop below check the
      // idling/lifetime limits periodically while waiting for data.
      channel.setOption(SocketOptionLevel.SOCKET, SocketOption.RCVTIMEO, dur!"seconds"(1));

      request._internal = new Request.RequestImpl();
      output._internal = new Output.OutputImpl();


      version(Windows)
      {
         log("Worker started.");
      }
      else
      {
         import core.sys.posix.pwd;
         import core.sys.posix.grp;
         import core.sys.posix.unistd;
         import core.stdc.string : strlen;

         // Group first: once the user has been changed the process may not be
         // allowed to change its group anymore.
         if (config.group.length > 0)
         {
            auto wantedGroup = getgrnam(config.group.toStringz);
            if (wantedGroup is null) fatal("Can't find group ", config.group);
            // A failed switch must not go unnoticed: the worker would keep its original privileges.
            else if (setgid(wantedGroup.gr_gid) != 0) fatal("Can't switch to group ", config.group);
         }

         if (config.user.length > 0)
         {
            auto wantedUser = getpwnam(config.user.toStringz);
            if (wantedUser is null) fatal("Can't find user ", config.user);
            else if (setuid(wantedUser.pw_uid) != 0) fatal("Can't switch to user ", config.user);
         }

         immutable currentUid = getuid();
         immutable currentGid = getgid();

         auto ui = getpwuid(currentUid);
         auto gr = getgrgid(currentGid);

         // A uid/gid without a passwd/group entry is legal: fall back to the numeric id.
         const(char)[] userName;
         if (ui !is null) userName = ui.pw_name[0..ui.pw_name.strlen];
         else userName = currentUid.to!string;

         const(char)[] groupName;
         if (gr !is null) groupName = gr.gr_name[0..gr.gr_name.strlen];
         else groupName = currentGid.to!string;

         if (currentUid == 0) critical("Worker running as root. Is this intended? Set user/group from config to run worker as unprivileged user.");
         else log("Worker started. (user=", userName, " group=", groupName, ")");

      }

      // Prevent read from stdin
      {
         import std.stdio : stdin;

         version(Windows)  auto nullSink = fopen("NUL", "r");
         version(Posix)    auto nullSink = fopen("/dev/null", "r");

         if (nullSink !is null) dup2(fileno(nullSink), fileno(stdin.getFP));
         else warning("Can't open the null device, stdin is left untouched.");
      }

      tryInit!Modules();

      import core.thread : Thread;
      import core.stdc.stdlib : exit;
      import core.atomic : cas;

      // Shared with the watchdog thread below.
      // processedStartedAt: when the request in progress started, or CoarseTime.zero if the worker is not processing.
      // justSent: set (through cas) by whoever replies first, the request loop or the watchdog.
      __gshared CoarseTime processedStartedAt = CoarseTime.zero;
      __gshared bool justSent = false;

      // Watchdog: enforces config.maxRequestTime.
      new Thread({

         Thread.getThis().isDaemon = true;

         while(true)
         {
            // Take a single snapshot of the shared timestamp. Reading it twice lets the
            // request loop reset it to zero between the two reads, and "now - zero" would
            // then be reported as a timeout for a request that actually completed.
            auto inFlightSince = processedStartedAt;

            if (inFlightSince != CoarseTime.zero && !(CoarseTime.currTime - inFlightSince < config.maxRequestTime))
               break;

            Thread.sleep(1.dur!"seconds");
         }

         log("Request timeout.");

         if (cas(&justSent, false, true))
         {
            WorkerPayload wp;

            output._internal.clear();
            output.status = 504;
            output._internal.buildHeaders();
            wp.isKeepAlive = false;
            processedStartedAt = CoarseTime.zero;
            wp.contentLength = output._internal._headersBuffer.array.length + output._internal._sendBuffer.array.length;

            channel.send((cast(char*)&wp)[0..wp.sizeof] ~ output._internal._headersBuffer.array ~ output._internal._sendBuffer.array);
         }

         channel.close();
         exit(0);
      }).start();

      startedAt = CoarseTime.currTime;
      while(true)
      {
         justSent = false;
         output._internal.clear();
         request._internal.clear();

         ubyte[32*1024] buffer;

         idlingAt = CoarseTime.currTime;

         import serverino.databuffer;

         //TODO: Handle requests > 32kb (buffer.length)
         uint size;                    // Payload length announced by the peer, in native byte order
         size_t sizeBytesRead = 0;     // Bytes of the length prefix received so far
         bool sizeRead = false;
         ptrdiff_t recv = -1;
         static DataBuffer!ubyte data;
         data.clear();

         while(sizeRead == false || size > data.length)
         {
            recv = -1;
            while(recv == -1)
            {
               recv = channel.receive(buffer);

               if (recv == -1)
               {
                  // Receive failed: with the timeout set above this is the normal
                  // "nothing to read yet" case, used to check the worker limits.
                  immutable tm = CoarseTime.currTime;
                  if (tm - idlingAt > config.maxWorkerIdling)
                  {
                     log("Killing worker. [REASON: maxWorkerIdling]");
                     tryUninit!Modules();
                     channel.close();
                     exit(0);
                  }
                  else if (tm - startedAt > config.maxWorkerLifetime)
                  {
                     log("Killing worker. [REASON: maxWorkerLifetime]");
                     tryUninit!Modules();
                     channel.close();
                     exit(0);
                  }

                  continue;
               }
               // NOTE(review): -1 is handled by the branch above, so this branch needs a
               // negative value other than -1 to run. Confirm whether receive can return one.
               else if (recv < 0)
               {
                  tryUninit!Modules();
                  log("Killing worker. [REASON: socket error]");
                  channel.close();
                  exit(cast(int)recv);
               }
            }

            // Socket closed by the peer.
            if (recv == 0) break;

            ubyte[] chunk = buffer[0..recv];

            if (sizeRead == false)
            {
               // On a stream socket the length prefix itself may be split across reads:
               // collect it byte by byte instead of assuming it is all in the first chunk.
               auto sizeBytes = (cast(ubyte*)&size)[0..uint.sizeof];
               immutable missing = uint.sizeof - sizeBytesRead;
               immutable taken = (chunk.length < missing) ? chunk.length : missing;

               sizeBytes[sizeBytesRead..sizeBytesRead+taken] = chunk[0..taken];
               sizeBytesRead += taken;
               chunk = chunk[taken..$];

               if (sizeBytesRead < uint.sizeof) continue;

               data.reserve(size);
               sizeRead = true;
            }

            if (chunk.length > 0) data.append(chunk);
         }

         // recv == 0 means the socket was closed before the whole message arrived:
         // a truncated request must not reach the user handlers.
         if(recv == 0 || data.array.length == 0)
         {
            tryUninit!Modules();

            if (daemonProcess.isTerminated()) log("Killing worker. [REASON: daemon dead?]");
            else log("Killing worker. [REASON: socket closed?]");

            channel.close();
            exit(0);
         }

         requestId++;
         processedStartedAt = CoarseTime.currTime;

         WorkerPayload wp;
         wp.isKeepAlive = parseHttpRequest!Modules(config, data.array);
         processedStartedAt = CoarseTime.zero;
         wp.contentLength = output._internal._sendBuffer.array.length + output._internal._headersBuffer.array.length;

         // Skip the reply if the watchdog already sent its own.
         if (cas(&justSent, false, true))
            channel.send((cast(char*)&wp)[0..wp.sizeof] ~  output._internal._headersBuffer.array ~ output._internal._sendBuffer.array);
      }


   }

   /**
    * Parses one raw HTTP request, fills `request`, calls the user handlers and
    * leaves the response in `output`.
    *
    * Whatever the outcome, the response headers are built on exit and the body is
    * dropped when `output._internal._sendBody` is false.
    *
    * Status set by this function:
    * 400 for a missing header terminator, an unknown method, an invalid
    * content-length, an `URIException` while processing the request, a
    * `UTFException` or a generic parsing error; 422 for an invalid body; 404 when
    * no handler wrote anything; 500 when an `Exception` escapes.
    *
    * Params:
    *    config = worker settings, `config.keepAlive` is read here
    *    data   = request bytes: headers, blank line and the optional body.
    *             Header names are matched as-is against lowercase literals.
    *
    * Returns: `true` if the connection can be kept alive, `false` otherwise.
    *
    * Throws: any `Throwable` that is not an `Exception` coming from the handlers is
    *         logged and rethrown.
    */
   bool parseHttpRequest(Modules...)(WorkerConfigPtr config, ubyte[] data)
   {

      scope(exit) {
         output._internal.buildHeaders();
         if (!output._internal._sendBody)
            output._internal._sendBuffer.clear();
      }

      import std.conv : ConvException;
      import std.utf : UTFException;

      version(debugRequest) log("-- START RECEIVING");
      try
      {
         size_t contentLength = 0;

         char[] method;
         char[] path;
         char[] httpVersion;

         char[] requestLine;
         char[] headers;

         // Marks the request as malformed: 400, no body, connection not kept alive.
         bool badRequest()
         {
            output._internal._httpVersion = (httpVersion == "HTTP/1.1")?HttpVersion.HTTP11:HttpVersion.HTTP10;
            output._internal._sendBody = false;
            output.status = 400;
            return false;
         }

         headers = cast(char[]) data;
         auto headersEnd = headers.indexOf("\r\n\r\n");

         // Headers completed? If not, the request can't be parsed at all: answer with
         // an explicit 400 rather than returning with whatever status is in `output`.
         if (headersEnd <= 0)
         {
            debug warning("Malformed request: end of headers not found");
            return badRequest();
         }

         version(debugRequest) log("-- HEADERS COMPLETED");

         headers.length = headersEnd;
         data = data[headersEnd+4..$];

         auto headersLines = headers.splitter("\r\n");

         requestLine = headersLines.front;

         // Request line: METHOD PATH VERSION
         {
            auto fields = requestLine.splitter(' ');

            if (!fields.empty)
            {
               method = fields.front;
               fields.popFront;
            }

            if (!fields.empty)
            {
               path = fields.front;
               fields.popFront;
            }

            if (!fields.empty)
            {
               httpVersion = fields.front;
               fields.popFront;
            }
         }

         // The list must stay sorted: it is searched with a binary search.
         if (["CONNECT", "DELETE", "GET", "HEAD", "OPTIONS", "PATCH", "POST", "PUT", "TRACE"].assumeSorted.contains(method) == false)
         {
            debug warning("HTTP method unknown: ", method);
            return badRequest();
         }

         headersLines.popFront;

         foreach(const ref l; headersLines)
         {
            auto firstColon = l.indexOf(':');
            if (firstColon <= 0) continue;

            switch(l[0..firstColon])
            {
               case "content-length":
                  // The value may be surrounded by whitespace ("content-length: 12"),
                  // which std.conv.to does not accept.
                  try { contentLength = l[firstColon+1..$].strip.to!size_t; }
                  catch (ConvException e)
                  {
                     debug warning("Invalid content-length: ", l);
                     return badRequest();
                  }
                  break;

               default:
            }
         }

         // If no content-length, we don't read body.
         if (contentLength == 0)
         {
            version(debugRequest) log("-- NO CONTENT LENGTH, SKIP DATA");
            data.length = 0;
         }

         else if (data.length >= contentLength)
         {
            version(debugRequest) log("-- DATA ALREADY READ.");
            data.length = contentLength;
         }

         version(debugRequest) log("-- PARSING DATA");

         {
            import std.uni : sicmp;

            request._internal._httpVersion    = (httpVersion == "HTTP/1.1")?(HttpVersion.HTTP11):(HttpVersion.HTTP10);
            request._internal._data           = cast(char[])data;
            request._internal._rawHeaders     = headers.to!string;
            request._internal._rawRequestLine = requestLine.to!string;

            output._internal._httpVersion = request._internal._httpVersion;

            // Split the request target into path and query string.
            // A fragment should never be sent to a server: if one shows up it is dropped,
            // so it can't leak into the path or into the query string.
            char[] target = path;

            immutable fragmentAt = target.indexOf('#');
            if (fragmentAt >= 0) target = target[0..fragmentAt];

            char[] rawQuery = target[$..$];

            immutable queryAt = target.indexOf('?');
            if (queryAt >= 0)
            {
               rawQuery = target[queryAt+1..$];
               target = target[0..queryAt];
            }

            // Just to prevent uri attack like
            // GET /../../non_public_file
            //
            // Segments are visited from the last to the first: each ".." cancels the
            // first following segment that is neither "." nor "..". The result always
            // starts with "/".
            auto normalize(string uri)
            {
               import std.range : retro, join;
               import std.algorithm : filter;
               import std.array : array;
               import std.typecons : tuple;
               size_t skips = 0;
               string norm = uri
                  .splitter('/')
                  .retro
                  .map!(
                        (x)
                        {
                           if (x == "..") skips++;
                           else if(x != ".")
                           {
                              if (skips == 0) return tuple(x, true);
                              else skips--;
                           }

                           return tuple(x, false);
                        }
                  )
                  .filter!(x => x[1] == true)
                  .map!(x => x[0])
                  .array
                  .retro
                  .join('/');

                  if (norm.startsWith("/")) return norm;
                  else return "/" ~ norm;
            }

            request._internal._uri            = normalize(cast(string)target);
            request._internal._rawQueryString = cast(string)rawQuery;
            request._internal._method         = method.to!string;

            // No body in the response for these methods (sorted list, binary search).
            output._internal._sendBody = (!["CONNECT", "HEAD", "TRACE"].assumeSorted.contains(request._internal._method));

            import std.uri : URIException;
            try { request._internal.process(); }
            catch (URIException e)
            {
               output.status = 400;
               output._internal._sendBody = false;
               return false;
            }

            // Keep-alive needs: config enabled, HTTP/1.1 and a "connection" header
            // equal to keep-alive (case-insensitive), which is also the value assumed
            // when the header is missing.
            output._internal._keepAlive =
               config.keepAlive &&
               output._internal._httpVersion == HttpVersion.HTTP11 &&
               sicmp(request.header.read("connection", "keep-alive"), "keep-alive") == 0;

            version(debugRequest) log("-- REQ: ", request.uri);
            version(debugRequest) log("-- PARSING STATUS: ", request._internal._parsingStatus);

            try {
               version(debugRequest) log("-- REQ: ", request);
            }
            catch (Exception e ) {log("EX:", e);}

            if (request._internal._parsingStatus == Request.ParsingStatus.OK)
            {
               try
               {
                  callHandlers!Modules(request, output);

                  // No handler wrote anything.
                  if (!output._internal._dirty)
                  {
                     output.status = 404;
                     output._internal._sendBody = false;
                  }

                  return (output._internal._keepAlive);

               }

               // Unhandled Exception escaped from user code
               catch (Exception e)
               {
                  critical(format("%s:%s Uncatched exception: %s", e.file, e.line, e.msg));
                  critical(format("-------\n%s",e.info));

                  output.status = 500;
                  return (output._internal._keepAlive);
               }

               // Even worse.
               catch (Throwable t)
               {
                  critical(format("%s:%s Throwable: %s", t.file, t.line, t.msg));
                  critical(format("-------\n%s",t.info));

                  // Rethrow
                  throw t;
               }
            }
            else
            {
               debug warning("Parsing error:", request._internal._parsingStatus);

               if (request._internal._parsingStatus == Request.ParsingStatus.InvalidBody) output.status = 422;
               else output.status = 400;

               output._internal._sendBody = false;
               return (output._internal._keepAlive);
            }

         }
      }
      catch(UTFException e)
      {
         output.status = 400;
         output._internal._sendBody = false;

         debug warning("UTFException: ", e.toString);
      }
      catch(Exception e) {

         output.status = 500;
         output._internal._sendBody = false;

         debug critical("Unhandled exception: ", e.toString);
      }

      return false;
   }

   /**
    * Calls the user endpoints declared in `modules`.
    *
    * The list of endpoints is computed at compile time:
    * $(UL
    *    $(LI if at least one symbol is tagged with `@endpoint`, the tagged symbols
    *         are called in descending `@priority` order until one of them leaves
    *         `output._internal._dirty` set;)
    *    $(LI otherwise exactly one untagged function callable with
    *         `(request, output)`, `(request)` or `(output)` must exist and it is
    *         the one called;)
    *    $(LI anything else is a compile-time error.)
    * )
    *
    * Params:
    *    request = the request passed to the handlers
    *    output  = the output passed to the handlers
    */
   void callHandlers(modules...)(Request request, Output output)
   {
      import std.algorithm : sort;
      import std.array : array;
      import std.traits : getUDAs, ParameterStorageClass, ParameterStorageClassTuple, fullyQualifiedName, getSymbolsByUDA;

      // Compile-time description of a handler: where it is and when to call it.
      struct FunctionPriority
      {
         string   name;
         long     priority;
         string   mod;
      }

      // Every member of the modules that can be called like an endpoint and that
      // has no `ref` parameter. All of them get priority 0.
      auto getUntaggedHandlers()
      {
         FunctionPriority[] fps;
         static foreach(m; modules)
         {{
            alias globalNs = m;

            foreach(sy; __traits(allMembers, globalNs))
            {
               alias s = __traits(getMember, globalNs, sy);
               static if
               (
                  (
                     __traits(compiles, s(request, output)) ||
                     __traits(compiles, s(request)) ||
                     __traits(compiles, s(output))
                  )
               )
               {


                  // ValidSTC is declared as false by the first `ref` parameter found...
                  static foreach(p; ParameterStorageClassTuple!s)
                  {
                     static if (p == ParameterStorageClass.ref_)
                     {
                        static if(!is(typeof(ValidSTC)))
                           enum ValidSTC = false;
                     }
                  }

                  // ...and as true if no parameter declared it.
                  static if(!is(typeof(ValidSTC)))
                     enum ValidSTC = true;


                  static if (ValidSTC)
                  {
                     FunctionPriority fp;
                     fp.name = __traits(identifier,s);
                     fp.priority = 0;
                     fp.mod = fullyQualifiedName!m;

                     fps ~= fp;
                  }
               }
            }
         }}

         return fps.sort!((a,b) => a.priority > b.priority).array;
      }

      // Every symbol tagged with `@endpoint`, sorted by descending priority.
      // A tagged symbol with a wrong signature is a compile-time error.
      auto getTaggedHandlers()
      {
         FunctionPriority[] fps;

         static foreach(m; modules)
         {{
            alias globalNs = m;

            foreach(s; getSymbolsByUDA!(globalNs, endpoint))
            {
               static if
               (
                  !__traits(compiles, s(request, output)) &&
                  !__traits(compiles, s(request)) &&
                  !__traits(compiles, s(output))
               )
               {
                  static assert(0, fullyQualifiedName!s ~ " is not a valid endpoint. Wrong params. Try to change its signature to `" ~ __traits(identifier,s) ~ "(Request request, Output output)`.");
               }

               static foreach(p; ParameterStorageClassTuple!s)
               {
                  static if (p == ParameterStorageClass.ref_)
                  {
                     static assert(0, fullyQualifiedName!s ~ " is not a valid endpoint. Wrong storage class for params. Try to change its signature to `" ~ __traits(identifier,s) ~ "(Request request, Output output)`.");
                  }
               }

               FunctionPriority fp;

               fp.name = __traits(identifier,s);
               fp.mod = fullyQualifiedName!m;

               // Use the value of the first `@priority(...)` attached, if any.
               // Otherwise the priority keeps its default value.
               static if (getUDAs!(s, priority).length > 0 && !is(getUDAs!(s, priority)[0]))
                  fp.priority = getUDAs!(s, priority)[0].priority;


               fps ~= fp;
            }
         }}

         return fps.sort!((a,b) => a.priority > b.priority).array;

      }

      // Both lists are evaluated at compile time.
      enum taggedHandlers = getTaggedHandlers();
      enum untaggedHandlers = getUntaggedHandlers();


      static if (taggedHandlers !is null && taggedHandlers.length>0)
      {
         // Calls the tagged handlers in order. Returns true as soon as the output
         // is dirty, false if every handler has been visited.
         bool callUntilIsDirty(FunctionPriority[] taggedHandlers)()
         {
            static foreach(ff; taggedHandlers)
            {
               {
                  mixin(`import ` ~ ff.mod ~ ";");
                  alias currentMod = mixin(ff.mod);
                  alias f = __traits(getMember,currentMod, ff.name);

                  import std.traits : hasUDA, TemplateOf, getUDAs;

                  // A handler with `@route` attributes runs only if at least one
                  // of them accepts the request.
                  bool willLaunch = true;
                  static if (hasUDA!(f, route))
                  {
                     willLaunch = false;
                     static foreach(attr;  getUDAs!(f, route))
                     {
                        {
                           if(attr.apply(request)) willLaunch = true;
                        }
                     }
                  }

                  if (willLaunch)
                  {
                     static if (__traits(compiles, f(request, output))) f(request, output);
                     else static if (__traits(compiles, f(request))) f(request);
                     else f(output);
                  }

                  // NOTE(review): the handler is recorded in the route even when
                  // willLaunch is false. Confirm this is intended.
                  request._internal._route ~= ff.mod ~ "." ~ ff.name;
               }

               if (output._internal._dirty) return true;
            }

            return false;
         }

        callUntilIsDirty!taggedHandlers;
      }
      else static if (untaggedHandlers !is null)
      {

         // Without tags there is no way to choose between several candidates.
         static if (untaggedHandlers.length != 1)
         {
            static assert(0, "Please tag each valid endpoint with @endpoint UDA.");
         }
         else
         {
            {
               mixin(`import ` ~ untaggedHandlers[0].mod ~ ";");
               alias currentMod = mixin(untaggedHandlers[0].mod);
               alias f = __traits(getMember,currentMod, untaggedHandlers[0].name);

               if (!output._internal._dirty)
               {
                  static if (__traits(compiles, f(request, output))) f(request, output);
                  else static if (__traits(compiles, f(request))) f(request);
                  else f(output);

                  request._internal._route ~= untaggedHandlers[0].mod ~ "." ~ untaggedHandlers[0].name;
               }
            }
         }
      }
      else static assert(0, "Please add at least one endpoint. Try this: `void hello(Request req, Output output) { output ~= req.dump(); }`");
   }

   char[]      mem;                     // Not referenced by the methods of this struct
   //SharedMemory.MemHandle   memHandle;
   ProcessInfo daemonProcess;           // Built in wake() from the SERVERINO_DAEMON environment variable

   Request           request;           // Reused for every request: cleared at the start of each loop iteration in wake()
   Output            output;            // Reused for every request: cleared at the start of each loop iteration in wake()

   CoarseTime          startedAt;       // When the request loop started (checked against maxWorkerLifetime)
   CoarseTime          idlingAt;        // When the worker started waiting for the current request (checked against maxWorkerIdling)

   Socket            channel;           // Connection opened in wake() to the SERVERINO_SOCKET address

   __gshared         requestId = 0;     // Incremented in wake() for every request received
}
756 
757 
/**
 * Calls, at worker startup, every function tagged with `@onWorkerStart`
 * found in `Modules`.
 *
 * Tagging something that is not a function, or a function that can't be called
 * without arguments, is a compile-time error.
 */
void tryInit(Modules...)()
{
   import std.traits : getSymbolsByUDA, isFunction;

   static foreach(currentModule; Modules)
      static foreach(hook; getSymbolsByUDA!(currentModule, onWorkerStart))
      {{
         // Hook name, quoted the way it is shown in the error messages.
         enum quotedName = "`" ~ __traits(identifier, hook) ~ "`";

         static assert(isFunction!hook, quotedName ~ " is marked with @onWorkerStart but it is not a function");

         static if (!__traits(compiles, hook()))
            static assert(0, quotedName ~ " is marked with @onWorkerStart but it is not callable");
         else
            hook();
      }}
}
774 
/**
 * Calls, before the worker terminates, every function tagged with
 * `@onWorkerStop` found in `Modules`.
 *
 * Tagging something that is not a function, or a function that can't be called
 * without arguments, is a compile-time error.
 */
void tryUninit(Modules...)()
{
   import std.traits : getSymbolsByUDA, isFunction;

   static foreach(currentModule; Modules)
      static foreach(hook; getSymbolsByUDA!(currentModule, onWorkerStop))
      {{
         // Hook name, quoted the way it is shown in the error messages.
         enum quotedName = "`" ~ __traits(identifier, hook) ~ "`";

         static assert(isFunction!hook, quotedName ~ " is marked with @onWorkerStop but it is not a function");

         static if (!__traits(compiles, hook()))
            static assert(0, quotedName ~ " is marked with @onWorkerStop but it is not callable");
         else
            hook();
      }}
}
791