From 3bb66753076f4037883b7c71ce2fb8e78f8b1194 Mon Sep 17 00:00:00 2001
From: Anton Kling <anton@kling.gg>
Date: Mon, 9 Dec 2024 23:24:25 +0100
Subject: kernel: Add ksnprintf

---
 kernel/libc/include/stdio.h |   2 +
 kernel/libc/stdio/print.c   | 116 +++++++++++++++++++++++++++++++-------------
 2 files changed, 85 insertions(+), 33 deletions(-)

diff --git a/kernel/libc/include/stdio.h b/kernel/libc/include/stdio.h
index 9dd496e..0c6a0a3 100644
--- a/kernel/libc/include/stdio.h
+++ b/kernel/libc/include/stdio.h
@@ -1,10 +1,12 @@
 #ifndef STDIO_H
 #define STDIO_H
 #include <stdarg.h>
+#include <stddef.h>
 
 void putc(const char c);
 void delete_characther(void);
 int kprintf(const char *format, ...);
 int vkprintf(const char *format, va_list list);
+int ksnprintf(char *out, size_t size, const char *format, ...);
 
 #endif
diff --git a/kernel/libc/stdio/print.c b/kernel/libc/stdio/print.c
index 3df6b5f..d4ec48a 100644
--- a/kernel/libc/stdio/print.c
+++ b/kernel/libc/stdio/print.c
@@ -1,28 +1,25 @@
 #include <assert.h>
+#include <kmalloc.h>
 #include <log.h>
+#include <math.h>
 #include <stdio.h>
 #include <string.h>
 
 #define TAB_SIZE 8
 
-inline void putc(const char c) {
-  log_char(c);
-}
-
-void put_string(const char *s, int l) {
-  for (; l > 0; l--, s++) {
-    log_char(*s);
-  }
-}
+struct print_context {
+  void *data;
+  void (*write)(struct print_context *, const char *, int);
+};
 
-#define WRITE(_s, _l, _r)                                                      \
+#define WRITE(s, l, r)                                                         \
   {                                                                            \
-    put_string(_s, _l);                                                        \
-    *(int *)(_r) += _l;                                                        \
+    ctx->write(ctx, s, l);                                                     \
+    *(int *)(r) += l;                                                          \
   }
 
-int print_num(long long n, int base, char *char_set, int prefix,
-              int zero_padding, int right_padding) {
+int print_num(struct print_context *ctx, long long n, int base, char *char_set,
+              int prefix, int zero_padding, int right_padding) {
   int c = 0;
   char str[32];
   int i = 0;
@@ -69,21 +66,25 @@ int print_num(long long n, int base, char *char_set, int prefix,
   return c;
 }
 
-int print_int(long long n, int prefix, int zero_padding, int right_padding) {
-  return print_num(n, 10, "0123456789", prefix, zero_padding, right_padding);
+int print_int(struct print_context *ctx, long long n, int prefix,
+              int zero_padding, int right_padding) {
+  return print_num(ctx, n, 10, "0123456789", prefix, zero_padding,
+                   right_padding);
 }
 
-int print_hex(long long n, int prefix, int zero_padding, int right_padding) {
-  return print_num(n, 16, "0123456789abcdef", prefix, zero_padding,
+int print_hex(struct print_context *ctx, long long n, int prefix,
+              int zero_padding, int right_padding) {
+  return print_num(ctx, n, 16, "0123456789abcdef", prefix, zero_padding,
                    right_padding);
 }
 
-int print_octal(long long n, int prefix, int zero_padding, int right_padding) {
-  return print_num(n, 8, "012345678", prefix, zero_padding, right_padding);
+int print_octal(struct print_context *ctx, long long n, int prefix,
+                int zero_padding, int right_padding) {
+  return print_num(ctx, n, 8, "012345678", prefix, zero_padding, right_padding);
 }
 
-int print_string(const char *s, int *rc, int prefix, int right_padding,
-                 int precision) {
+int print_string(struct print_context *ctx, const char *s, int *rc, int prefix,
+                 int right_padding, int precision) {
   int l = strlen(s);
   char t = ' ';
   int c = 0;
@@ -138,7 +139,7 @@ int parse_precision(const char **fmt) {
   return rc;
 }
 
-int vkprintf(const char *fmt, va_list ap) {
+int vkcprintf(struct print_context *ctx, const char *fmt, va_list ap) {
   int rc = 0;
   const char *s = fmt;
   int prefix = 0;
@@ -212,20 +213,21 @@ int vkprintf(const char *fmt, va_list ap) {
         right_padding = 0;
       }
       if (2 == long_level) {
-        rc += print_int(va_arg(ap, long long), prefix, zero_padding,
+        rc += print_int(ctx, va_arg(ap, long long), prefix, zero_padding,
                         right_padding);
       } else if (1 == long_level) {
-        rc += print_int(va_arg(ap, long long), prefix, zero_padding,
+        rc += print_int(ctx, va_arg(ap, long long), prefix, zero_padding,
                         right_padding);
       } else {
-        rc += print_int(va_arg(ap, int), prefix, zero_padding, right_padding);
+        rc += print_int(ctx, va_arg(ap, int), prefix, zero_padding,
+                        right_padding);
       }
       long_level = 0;
       cont = 0;
       break;
     case 'u':
       assert(-1 == precision);
-      rc += print_int(va_arg(ap, unsigned int), prefix, zero_padding,
+      rc += print_int(ctx, va_arg(ap, unsigned int), prefix, zero_padding,
                       right_padding);
       cont = 0;
       break;
@@ -233,14 +235,14 @@ int vkprintf(const char *fmt, va_list ap) {
       assert(!zero_padding); // this is not supported to strings
       char *a = va_arg(ap, char *);
       if (!a) {
-        if (-1 ==
-            print_string("(NULL)", &rc, prefix, right_padding, precision)) {
+        if (-1 == print_string(ctx, "(NULL)", &rc, prefix, right_padding,
+                               precision)) {
           return -1;
         }
         cont = 0;
         break;
       }
-      if (-1 == print_string(a, &rc, prefix, right_padding, precision)) {
+      if (-1 == print_string(ctx, a, &rc, prefix, right_padding, precision)) {
         return -1;
       }
       cont = 0;
@@ -249,13 +251,13 @@ int vkprintf(const char *fmt, va_list ap) {
     case 'p': // TODO: Print this out in a nicer way
     case 'x':
       assert(-1 == precision);
-      rc +=
-          print_hex(va_arg(ap, const u32), prefix, zero_padding, right_padding);
+      rc += print_hex(ctx, va_arg(ap, const u32), prefix, zero_padding,
+                      right_padding);
       cont = 0;
       break;
     case 'o':
       assert(-1 == precision);
-      rc += print_octal(va_arg(ap, const u32), prefix, zero_padding,
+      rc += print_octal(ctx, va_arg(ap, const u32), prefix, zero_padding,
                         right_padding);
       cont = 0;
       break;
@@ -284,6 +286,54 @@ int vkprintf(const char *fmt, va_list ap) {
   return rc;
 }
 
+struct sn_context {
+  char *out;
+  int size;
+};
+
+void sn_write(struct print_context *_ctx, const char *s, int l) {
+  struct sn_context *ctx = (struct sn_context *)_ctx->data;
+  assert(ctx);
+  size_t k = min(l, ctx->size);
+  memcpy(ctx->out, s, k);
+  ctx->out += k;
+  ctx->size -= k;
+  *(ctx->out) = '\0';
+}
+
+int ksnprintf(char *out, size_t size, const char *format, ...) {
+  struct print_context context;
+
+  struct sn_context *ctx = context.data = kmalloc(sizeof(struct sn_context));
+  if (!ctx) {
+    return -1;
+  }
+  ctx->out = out;
+  ctx->size = size;
+  context.write = sn_write;
+
+  va_list list;
+  va_start(list, format);
+  int rc = vkcprintf(&context, format, list);
+  va_end(list);
+  kfree(ctx);
+  return rc;
+}
+
+void context_serial_write(struct print_context *ctx, const char *s, int l) {
+  (void)ctx;
+  for (; l > 0; l--, s++) {
+    log_char(*s);
+  }
+}
+
+struct print_context serial_context = {.data = NULL,
+                                       .write = context_serial_write};
+
+int vkprintf(const char *format, va_list ap) {
+  return vkcprintf(&serial_context, format, ap);
+}
+
 int kprintf(const char *format, ...) {
   va_list list;
   va_start(list, format);
-- 
cgit v1.2.3