Task(dev): Update MindGen to handle new GenOptions fields

This commit is contained in:
Triston Armstrong 2024-12-10 12:12:05 -05:00
parent a5905ae31c
commit fa70d465d4
Signed by: tristonarmstrong
GPG Key ID: A23B48AE45EB6EFE

View File

@ -1,11 +1,12 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
// this is a test comment
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
#[allow(unused)] #[allow(unused)]
pub struct GenRes { pub struct GenRes {
model: String, model: String,
created_at: String, created_at: String,
response: String, pub response: String,
done: bool, done: bool,
total_duration: u64, total_duration: u64,
load_duration: u64, load_duration: u64,
@ -19,6 +20,9 @@
struct GenOptions { struct GenOptions {
temperature: f32, temperature: f32,
num_predict: u8, num_predict: u8,
repeat_last_n: u8,
top_k: u8,
top_p: f32
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
@ -33,16 +37,19 @@
impl MindGen { impl MindGen {
#[allow(unused)] #[allow(unused)]
pub fn new(input: &str) -> Self { pub fn new(input: String) -> Self {
Self { Self {
model: String::from("llama3.1"), model: String::from("llama3.1"),
stream: false, stream: false,
raw: false, raw: false,
prompt: String::from(input), prompt: input,
system: String::from("You are a commit message generator. You will generate commit messages following this format Task(<task #>): <commit message> making sure to never go over 90 characters"), system: String::from("create a short commit message from this diff, with format Task(<branch_name>): <commit_message>. only respond with the commit message"),
options: GenOptions { options: GenOptions {
temperature: 0.5, temperature: 0.1,
num_predict: 90 num_predict: 0,
repeat_last_n: 0,
top_k: 10,
top_p: 0.5
} }
} }
} }