add

AI-Driven Development
git clone git://git.lair.cx/add
Log | Files | Refs | README

main.go (2874B)


      1 package main
      2 
      3 import (
      4 	"bytes"
      5 	"encoding/json"
      6 	"errors"
      7 	"flag"
      8 	"fmt"
      9 	"log"
     10 	"net/http"
     11 	"os"
     12 	"strings"
     13 )
     14 
     15 // Config represents the configuration for the program.
     16 type Config struct {
     17 	Lang       string         `json:"lang"`
     18 	Endpoint   string         `json:"endpoint"`
     19 	Token      string         `json:"token"`
     20 	Parameters map[string]any `json:"parameters"`
     21 	Options    map[string]any `json:"options"`
     22 }
     23 
     24 func main() {
     25 	flag.Parse()
     26 
     27 	args := flag.Args()
     28 	if len(args) < 1 {
     29 		log.Fatalln("Usage: addgen [source.yaml]")
     30 	}
     31 
     32 	// Read the configuration file
     33 	data, err := os.ReadFile("add.json")
     34 	if err != nil {
     35 		log.Fatal(err)
     36 	}
     37 
     38 	var config Config
     39 	err = json.Unmarshal(data, &config)
     40 	if err != nil {
     41 		log.Fatal(err)
     42 	}
     43 
     44 	// Verify the configuration
     45 	if len(config.Lang) == 0 || len(config.Endpoint) == 0 || len(config.Token) == 0 {
     46 		log.Fatalf("Invalid configuration: lang=%q, endpoint=%q, token=%q\n", config.Lang, config.Endpoint, config.Token)
     47 	}
     48 
     49 	// Read the prompt file
     50 	prompt, err := os.ReadFile("prompt.txt")
     51 	if err != nil {
     52 		log.Fatal(err)
     53 	}
     54 
     55 	// Read the yaml file
     56 	inputData, err := os.ReadFile(args[0])
     57 	if err != nil {
     58 		log.Fatal(err)
     59 	}
     60 
     61 	input := appendToPrompt(
     62 		// Replace placeholders in the prompt
     63 		strings.ReplaceAll(string(prompt), "{{ content }}", string(inputData)),
     64 
     65 		// Append the language identifier to the prompt
     66 		fmt.Sprintf("```%s\n", config.Lang),
     67 	)
     68 
     69 	// Query the endpoint repeatedly until there is no more results
     70 	for {
     71 		result, err := query(config, input)
     72 		if err != nil {
     73 			log.Fatal(err)
     74 		}
     75 		if len(result) == 0 {
     76 			break
     77 		}
     78 
     79 		input = appendToPrompt(input, result)
     80 	}
     81 }
     82 
     83 func appendToPrompt(prompt string, s string) string {
     84 	fmt.Print(s)
     85 	return prompt + s
     86 }
     87 
     88 func query(config Config, prompt string) (string, error) {
     89 	// Create the HTTP client
     90 	client := http.Client{}
     91 
     92 	var queryParams = map[string]any{
     93 		"inputs":     prompt,
     94 		"parameters": config.Parameters,
     95 		"options":    config.Options,
     96 	}
     97 
     98 	body, err := json.Marshal(queryParams)
     99 	if err != nil {
    100 		return "", nil
    101 	}
    102 
    103 	// Set up the request
    104 	req, err := http.NewRequest("POST", config.Endpoint, bytes.NewBuffer(body))
    105 	if err != nil {
    106 		return "", err
    107 	}
    108 
    109 	// Add the authorization header
    110 	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", config.Token))
    111 	req.Header.Set("Content-Type", "application/json")
    112 
    113 	// Make the request
    114 	resp, err := client.Do(req)
    115 	if err != nil {
    116 		return "", err
    117 	}
    118 
    119 	defer resp.Body.Close()
    120 
    121 	// Check the status code
    122 	if resp.StatusCode != http.StatusOK {
    123 		return "", errors.New(fmt.Sprintf("Unexpected status code: %d", resp.StatusCode))
    124 	}
    125 
    126 	// Decode the response
    127 	var response []struct {
    128 		GeneratedText string `json:"generated_text"`
    129 	}
    130 
    131 	err = json.NewDecoder(resp.Body).Decode(&response)
    132 	if err != nil {
    133 		return "", err
    134 	}
    135 
    136 	// Return the generated text
    137 	return response[0].GeneratedText, nil
    138 }